File size: 3,228 Bytes
c1af2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
#pragma once
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <mutex>
#include <numeric>
#include <type_traits>

#include <c10/util/irange.h>
namespace at::native {
// Hashing machinery for Params
// Fowler–Noll–Vo hash function
// see
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
template <typename Params>
struct ParamsHash {
  // Params must be a POD because we read out its memory
  // contents as raw bytes when hashing. Every byte of the object feeds the
  // hash, including any padding, so callers need deterministic padding
  // (ParamsWrapper in this file zero-fills it for that reason).
  // NOTE(review): standard-layout does not by itself imply trivially
  // copyable; byte-wise hashing assumes Params behaves like plain data.
  static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

  // Returns the 32-bit FNV-1a hash of the sizeof(Params) raw bytes of
  // `params`, widened to size_t.
  size_t operator()(const Params& params) const {
    const auto* bytes = reinterpret_cast<const uint8_t*>(&params);
    // 0x811C9DC5 is the 32-bit FNV offset basis.
    const uint32_t hash = std::accumulate(
        bytes,
        bytes + sizeof(Params),
        uint32_t{0x811C9DC5},
        [](uint32_t acc, uint8_t byte) -> uint32_t {
          // 32-bit FNV prime.
          constexpr uint32_t kFnvPrime = 0x01000193;
          // FNV-1a ordering: xor the byte in first, then multiply.
          // Unsigned arithmetic wraps modulo 2^32, which FNV relies on.
          return (acc ^ byte) * kFnvPrime;
        });
    return static_cast<size_t>(hash);
  }
};
template <typename Params>
struct ParamsEqual {
  // Equality here is a raw memory comparison of the two objects, so Params
  // must be plain data (standard-layout); any padding bytes take part in
  // the comparison too.
  static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

  // True iff all sizeof(Params) bytes of `a` and `b` are identical.
  bool operator()(const Params& a, const Params& b) const {
    const void* lhs_bytes = &a;
    const void* rhs_bytes = &b;
    const int ordering = memcmp(lhs_bytes, rhs_bytes, sizeof(Params));
    return ordering == 0;
  }
};
// Provide explicit byte-for-byte constructors to avoid unwittingly leaving
// padding bytes uninitialized (e.g., when passing Params by value).
// operator== below compares raw bytes of `pod`, padding included, so the
// padding must hold deterministic values; every special member therefore
// operates on the whole object representation via memset/memcpy.
template <typename T>
struct ParamsWrapper {
  T pod;
  static_assert(
      std::is_standard_layout_v<T>,
      "ParamsWrapper cannot wrap non-POD data");

  // Zero every byte of `pod`, padding included.
  ParamsWrapper() {
    memset(&(this->pod), 0, sizeof(this->pod));
  }

  // Copy construction duplicates all bytes of `other.pod`.
  ParamsWrapper(const ParamsWrapper& other) {
    memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
  }

  // Move construction performs the same byte copy; `other` is left as-is.
  ParamsWrapper(ParamsWrapper&& other) noexcept {
    memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
  }

  // Self-assignment is skipped: memcpy with overlapping (here, identical)
  // source and destination ranges is undefined behavior.
  ParamsWrapper& operator=(const ParamsWrapper& other) {
    if (this != &other) {
      memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }
    return *this;
  }

  // Same byte copy as copy-assignment, with the same self-assignment guard.
  ParamsWrapper& operator=(ParamsWrapper&& other) noexcept {
    if (this != &other) {
      memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }
    return *this;
  }

  // Byte-wise equality over the full object representation of `pod`.
  inline friend bool operator==(
      const ParamsWrapper& lhs,
      const ParamsWrapper& rhs) noexcept {
    auto ptr1 = reinterpret_cast<const uint8_t*>(&(lhs.pod));
    auto ptr2 = reinterpret_cast<const uint8_t*>(&(rhs.pod));
    return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0;
  }
};
// Hasher for wrapper types that expose their plain-data payload as a `pod`
// member: it hashes the bytes of `pod` rather than the wrapper object, so the
// wrapper is free to define its own copy/move special members for safety.
template <typename ParamsWrapper>
struct ParamsWrapperHash {
  // The payload is read back as raw bytes, so it has to be standard-layout
  // ("POD" in the historical sense).
  static_assert(
      std::is_standard_layout_v<decltype(ParamsWrapper::pod)>,
      "ParamsWrapper cannot wrap non-POD data");

  // 32-bit FNV-1a over every byte of `params_wrapper.pod` (padding
  // included), widened to size_t.
  size_t operator()(const ParamsWrapper& params_wrapper) const {
    const auto* cursor =
        reinterpret_cast<const uint8_t*>(&(params_wrapper.pod));
    const auto* const stop = cursor + sizeof(params_wrapper.pod);
    uint32_t digest = 0x811C9DC5; // 32-bit FNV offset basis
    while (cursor != stop) {
      digest ^= *cursor;
      ++cursor;
      digest *= 0x01000193; // 32-bit FNV prime; wraps modulo 2^32
    }
    return static_cast<size_t>(digest);
  }
};
} // namespace at::native
|