|
|
#pragma once
|
|
|
|
|
|
#include <c10/util/irange.h>

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <mutex>
#include <type_traits>
|
|
|
|
|
|
namespace at::native {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Hash functor over the raw object representation of a POD-like `Params`.
//
// Every one of the sizeof(Params) bytes is fed to the hash, including any
// padding bytes. NOTE(review): callers presumably zero-initialize Params
// (e.g. via memset) so padding is deterministic — confirm at call sites.
template <typename Params>
struct ParamsHash {
  // Byte-wise hashing is only meaningful for standard-layout (POD-like) types.
  static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

  /// @brief Computes the 32-bit FNV-1a hash of the bytes of `params`.
  /// @param params Object whose sizeof(Params) bytes are hashed.
  /// @return The 32-bit hash value, widened to size_t.
  size_t operator()(const Params& params) const {
    const uint8_t* const first = reinterpret_cast<const uint8_t*>(&params);
    const uint8_t* const last = first + sizeof(Params);

    uint32_t value = 0x811C9DC5; // FNV-1a 32-bit offset basis
    for (const uint8_t* byte = first; byte != last; ++byte) {
      value ^= *byte;
      value *= 0x01000193; // FNV-1a 32-bit prime
    }
    return static_cast<size_t>(value);
  }
};
|
|
|
|
|
|
// Equality functor that treats two POD-like `Params` objects as equal exactly
// when their object representations (all sizeof(Params) bytes, padding
// included) are identical.
template <typename Params>
struct ParamsEqual {
  // Raw byte comparison requires a standard-layout (POD-like) type.
  static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

  /// @brief Byte-wise comparison of `a` and `b`.
  /// @return true iff every byte of `a` matches the corresponding byte of `b`.
  bool operator()(const Params& a, const Params& b) const {
    const void* const lhs_bytes = &a;
    const void* const rhs_bytes = &b;
    const int ordering = memcmp(lhs_bytes, rhs_bytes, sizeof(Params));
    return ordering == 0;
  }
};
|
|
|
|
|
|
|
|
|
|
|
|
// Value wrapper around a POD-like payload `pod` whose construction, copying
// and equality all operate on the payload's raw bytes.
//
// The default constructor zeroes every byte (padding included), so two
// wrappers whose fields were assigned the same values compare equal byte-wise.
template <typename T>
struct ParamsWrapper {
  // Wrapped payload; accessed directly by callers.
  T pod;
  // NOTE(review): memcpy/memset below are only well-defined for trivially
  // copyable T; standard layout alone does not guarantee that. Confirm that
  // wrapped types are trivially copyable.
  static_assert(
      std::is_standard_layout_v<T>,
      "ParamsWrapper cannot wrap non-POD data");

  // Zero-fill all bytes of `pod`, including padding, so byte-wise equality
  // and hashing do not depend on indeterminate padding contents.
  ParamsWrapper() {
    memset(&(this->pod), 0, sizeof(this->pod));
  }

  // Byte-wise copy so padding bytes are carried over as well.
  ParamsWrapper(const ParamsWrapper& other) {
    memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
  }

  // "Move" is a byte-wise copy; `other` is left unchanged.
  ParamsWrapper(ParamsWrapper&& other) noexcept {
    memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
  }

  ParamsWrapper& operator=(const ParamsWrapper& other) {
    // memcpy with overlapping source/destination is undefined behavior, and
    // self-assignment is a full overlap, so skip it.
    if (this != &other) {
      memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }
    return *this;
  }

  ParamsWrapper& operator=(ParamsWrapper&& other) noexcept {
    // Same self-assignment guard as copy-assignment (overlapping memcpy is UB).
    if (this != &other) {
      memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }
    return *this;
  }

  /// @brief Byte-wise equality of the wrapped payloads (padding included).
  inline friend bool operator==(
      const ParamsWrapper& lhs,
      const ParamsWrapper& rhs) noexcept {
    auto ptr1 = reinterpret_cast<const uint8_t*>(&(lhs.pod));
    auto ptr2 = reinterpret_cast<const uint8_t*>(&(rhs.pod));
    return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0;
  }
};
|
|
|
|
|
|
|
|
|
|
|
|
// Hash functor for wrapper types exposing a POD-like `pod` member: hashes
// the raw bytes of `pod` (padding included) with 32-bit FNV-1a.
template <typename ParamsWrapper>
struct ParamsWrapperHash {
  // Only standard-layout payloads may be hashed byte-wise.
  static_assert(
      std::is_standard_layout_v<decltype(ParamsWrapper::pod)>,
      "ParamsWrapper cannot wrap non-POD data");

  /// @brief FNV-1a (32-bit) over every byte of `params_wrapper.pod`.
  /// @param params_wrapper Wrapper whose `pod` member is hashed.
  /// @return The 32-bit hash value, widened to size_t.
  size_t operator()(const ParamsWrapper& params_wrapper) const {
    constexpr uint32_t kFnvPrime = 0x01000193;
    uint32_t hash = 0x811C9DC5; // FNV-1a 32-bit offset basis

    const uint8_t* cursor =
        reinterpret_cast<const uint8_t*>(&(params_wrapper.pod));
    const uint8_t* const end = cursor + sizeof(params_wrapper.pod);
    for (; cursor != end; ++cursor) {
      hash = (hash ^ *cursor) * kFnvPrime;
    }
    return static_cast<size_t>(hash);
  }
};
|
|
|
|
|
|
}
|
|
|
|