File size: 3,228 Bytes
c1af2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#pragma once

#include <c10/util/irange.h>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <mutex>
#include <type_traits>

namespace at::native {

// Hashing machinery for Params.
// Computes the 32-bit Fowler–Noll–Vo hash (FNV-1a variant: xor the byte in,
// then multiply by the prime) over the raw bytes of a Params object; see
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
template <typename Params>
struct ParamsHash {
  // Every byte of the object representation is fed to the hash (padding
  // included), so Params has to be a plain standard-layout type.
  static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

  // Returns the 32-bit FNV-1a hash of `params`, widened to size_t.
  size_t operator()(const Params& params) const {
    constexpr uint32_t kFnvOffsetBasis = 0x811C9DC5;
    constexpr uint32_t kFnvPrime = 0x01000193;

    const auto* cursor = reinterpret_cast<const uint8_t*>(&params);
    const auto* const last = cursor + sizeof(Params);

    uint32_t hash = kFnvOffsetBasis;
    for (; cursor != last; ++cursor) {
      // uint32_t arithmetic wraps modulo 2^32, which FNV relies on.
      hash = (hash ^ *cursor) * kFnvPrime;
    }
    return static_cast<size_t>(hash);
  }
};

// Bytewise equality predicate for Params, meant to pair with ParamsHash.
template <typename Params>
struct ParamsEqual {
  // Equality is decided by comparing the whole object representation
  // (padding included), so Params has to be a plain standard-layout type.
  static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

  // True iff all sizeof(Params) bytes of `a` and `b` are identical.
  bool operator()(const Params& a, const Params& b) const {
    const void* lhs_bytes = &a;
    const void* rhs_bytes = &b;
    const int ordering = memcmp(lhs_bytes, rhs_bytes, sizeof(Params));
    return ordering == 0;
  }
};

// Provide explicit byte-for-byte constructors to avoid unwittingly leaving
// padding bytes uninitialized (e.g., when passing Params by value).
// Every special member below clears or copies the entire object
// representation of `pod`, padding included, because operator== compares
// those raw bytes.
template <typename T>
struct ParamsWrapper {
  T pod;
  static_assert(
      std::is_standard_layout_v<T>,
      "ParamsWrapper cannot wrap non-POD data");

  // Zero-fills every byte of `pod` (padding included), so two wrappers whose
  // fields are later set to the same values also compare equal bytewise.
  ParamsWrapper() noexcept {
    memset(&(this->pod), 0, sizeof(this->pod));
  }

  ParamsWrapper(const ParamsWrapper& other) noexcept {
    memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
  }

  ParamsWrapper(ParamsWrapper&& other) noexcept {
    memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
  }

  // memcpy with overlapping source and destination is undefined behavior;
  // on self-assignment they coincide exactly, so skip the copy.
  ParamsWrapper& operator=(const ParamsWrapper& other) noexcept {
    if (this != &other) {
      memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }
    return *this;
  }

  // Same self-assignment guard as the copy assignment (covers self-move).
  ParamsWrapper& operator=(ParamsWrapper&& other) noexcept {
    if (this != &other) {
      memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }
    return *this;
  }

  // Bytewise equality over all sizeof(pod) bytes, padding included.
  friend bool operator==(
      const ParamsWrapper& lhs,
      const ParamsWrapper& rhs) noexcept {
    auto ptr1 = reinterpret_cast<const uint8_t*>(&(lhs.pod));
    auto ptr2 = reinterpret_cast<const uint8_t*>(&(rhs.pod));
    return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0;
  }
};

// Hash for wrapper types that expose a public `pod` member (such as
// ParamsWrapper). Wrapping lets the outer struct define custom copy and move
// constructors for additional safety. The hash itself is the same 32-bit
// FNV-1a computed over the raw bytes of `pod`.
template <typename Wrapper>
struct ParamsWrapperHash {
  // The wrapped object's bytes (padding included) are hashed directly, so its
  // type has to be a plain standard-layout type.
  static_assert(
      std::is_standard_layout_v<decltype(Wrapper::pod)>,
      "ParamsWrapper cannot wrap non-POD data");

  // Returns the 32-bit FNV-1a hash of `params_wrapper.pod`, widened to size_t.
  size_t operator()(const Wrapper& params_wrapper) const {
    constexpr uint32_t kFnvOffsetBasis = 0x811C9DC5;
    constexpr uint32_t kFnvPrime = 0x01000193;

    const auto* cursor =
        reinterpret_cast<const uint8_t*>(&(params_wrapper.pod));
    const auto* const last = cursor + sizeof(params_wrapper.pod);

    uint32_t hash = kFnvOffsetBasis;
    while (cursor != last) {
      // uint32_t arithmetic wraps modulo 2^32, which FNV relies on.
      hash = (hash ^ *cursor++) * kFnvPrime;
    }
    return static_cast<size_t>(hash);
  }
};

} // namespace at::native