File size: 7,120 Bytes
9dd3461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
#pragma once
#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <c10/util/variant.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/JitLoops.cuh>

#include <string>
#include <vector>

namespace at {
namespace native {


// X-macro helpers that apply the macro `_` to each integer literal 1..8.
//   AT_FOR_8_CASES(_)            expands to  _(1) _(2) _(3) ... _(8)
//   AT_FOR_8_CASES_WITH_COMMA(_) expands to  _(1), _(2), _(3), ..., _(8)
// The plain form is used below to stamp out `switch` cases; the comma form is
// used to build the alternative lists of the c10::variant types.
#define AT_FOR_8_CASES(_) \
  _(1) _(2) _(3) _(4) _(5) _(6) _(7) _(8)

#define AT_FOR_8_CASES_WITH_COMMA(_) \
  _(1), _(2), _(3), _(4), _(5), _(6), _(7), _(8)

// Returns one entry per element of `extra_args`, in the same order, holding
// the type name that at::cuda::jit::typeName() reports for that scalar's
// dtype (extra_args[i].type()).
//
// Declared `inline` because it is defined in a header: a non-inline
// definition would break the one-definition rule (duplicate symbol at link
// time) as soon as this header is included from more than one translation unit.
inline c10::SmallVector<std::string> get_extra_args_typenames(const c10::SmallVector<at::Scalar>& extra_args) {
  c10::SmallVector<std::string> args_typenames(extra_args.size());
  // size_t index matches the unsigned size() and avoids a signed/unsigned comparison.
  for (size_t i = 0; i < extra_args.size(); ++i) {
    args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type());
  }
  return args_typenames;
}

// Runtime-dtype dispatcher for memory::can_vectorize_up_to<ctype>(pointer):
// maps `type` to its C++ type `ctype` and forwards `pointer` unchanged.
// Presumably the result is the maximum vector width usable for accesses at
// `pointer` (see MemoryAccess.cuh) — NOTE(review): confirm there.
// Only dtypes enumerated by AT_FORALL_SCALAR_TYPES_WITH_COMPLEX are handled;
// any other dtype fails TORCH_INTERNAL_ASSERT.
//
// Declared `inline` because it is defined in a header; without it, including
// this header from more than one translation unit violates the ODR.
inline int can_vectorize_up_to(at::ScalarType type, char* pointer) {
  switch(type) {
#define DEFINE_CASE(ctype, scalartype)                                   \
    case ScalarType::scalartype : return memory::can_vectorize_up_to<ctype>(pointer);

    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE

    default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
  }
}

// jitted version of the above
// See Note [Jiterator], this relies on the assumptions enumerated there
//
// Returns the minimum of can_vectorize_up_to() over every operand of `iter`,
// interpreting each operand's data pointer as iter.common_dtype() (i.e. the
// output is assumed to share the common dtype with the inputs).
//
// Declared `inline` because it is defined in a header; without it, including
// this header from more than one translation unit violates the ODR.
inline int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) {
  const at::ScalarType common_dtype = iter.common_dtype();
  const at::ScalarType result_dtype = common_dtype;

  // Deals with output (operand 0)
  int result = can_vectorize_up_to(result_dtype, static_cast<char*>(iter.data_ptr(0)));

  // Incorporates input(s): the usable width is limited by the least-aligned operand.
  for (auto i = 1; i < iter.ntensors(); ++i) {
    result = std::min<int>(result, can_vectorize_up_to(common_dtype, static_cast<char*>(iter.data_ptr(i))));
  }

  return result;
}

// Allocates an OffsetCalculator<N> over either the N inputs (IS_INPUT == true)
// or the N outputs (IS_INPUT == false) of `iter`, built from the iterator's
// shape plus each selected operand's strides and element size.
// Operand numbering in the iterator places outputs first, so the inputs start
// at operand index iter.noutputs().
template<bool IS_INPUT, int N>
static std::unique_ptr<OffsetCalculator<N>> make_unique_offset_calculator(
          const TensorIteratorBase& iter) {
  // Keep at least one slot so the buffers are never zero-length (N may be 0).
  constexpr int num_slots = std::max<int>(N, 1);
  TORCH_INTERNAL_ASSERT(N == (IS_INPUT ? iter.ninputs() : iter.noutputs()));

  std::array<const int64_t*, num_slots> stride_ptrs;
  std::array<int64_t, num_slots> elem_sizes;
  const int first_operand = IS_INPUT ? iter.noutputs() : 0;
  for (int arg = 0; arg < N; ++arg) {
    const int operand = first_operand + arg;
    stride_ptrs[arg] = iter.strides(operand).data();
    elem_sizes[arg] = iter.element_size(operand);
  }

  return std::make_unique<OffsetCalculator<N>>(
      iter.ndim(), iter.shape().data(), stride_ptrs.data(), elem_sizes.data());
}

// Owns a heap-allocated OffsetCalculator<K>, K in [1, 8], where K is the
// number of inputs (IS_INPUT == true) or outputs (IS_INPUT == false) of the
// TensorIterator it was built from. K is only known at runtime, so the
// concrete type is chosen by a switch and stored in a c10::variant.
template <bool IS_INPUT>
struct OffsetCalculatorVariant {
#define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>
  using OffsetCalculatorTypes = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  // Fails TORCH_CHECK when the selected operand count is outside [1, 8].
  OffsetCalculatorVariant(const TensorIteratorBase& iter) {
    const int num = IS_INPUT ? iter.ninputs() : iter.noutputs();

    switch (num) {
#define DEFINE_CASE(index)                                          \
      case index:                                                   \
        v = make_unique_offset_calculator<IS_INPUT, index>(iter);   \
        break;
      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
      default:
        TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for num_tensor = ", num);
    }
  }

  // Non-owning, type-erased pointer to the held OffsetCalculator<K>.
  void* data_ptr() {
    auto raw_address = [](auto& calc_owner) { return static_cast<void*>(calc_owner.get()); };
    return c10::visit(raw_address, v);
  }

 private:
  OffsetCalculatorTypes v;
};

// Holds an at::detail::Array<char*, K> with one data pointer per operand of a
// TensorIterator, for K in [1, 16] (up to 8 inputs + 8 outputs). The array is
// stored by value inside the variant; data_ptr() returns its address.
struct ArrayVariant {
// Alternatives are interleaved as K and K + 8 so that a single
// AT_FOR_8_CASES_WITH_COMMA expansion covers every size in [1, 16].
#define DEFINE_CASE(index) at::detail::Array<char*, index>, at::detail::Array<char*, index+8>
  using ArrayTypes = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  // Fails TORCH_CHECK when iter.ntensors() is outside [1, 16].
  ArrayVariant(const TensorIteratorBase& iter) {
    const int ntensors = iter.ntensors();
    switch (ntensors) {
#define DEFINE_CASE(index)                                \
      case index:                                         \
        array = at::detail::Array<char*, index>{};        \
        break;                                            \
      case index + 8:                                     \
        array = at::detail::Array<char*, index + 8>{};    \
        break;
      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
      default:
        TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors);
    }

    // Copy each operand's data pointer into the freshly selected array.
    auto fill_pointers = [&](auto& ptrs) {
      for (int t = 0; t < ntensors; ++t) {
        ptrs[t] = static_cast<char*>(iter.data_ptr(t));
      }
    };
    c10::visit(fill_pointers, array);
  }

  // Non-owning, type-erased pointer to the held Array<char*, K>.
  void* data_ptr() {
    auto address_of = [](auto& ptrs) { return static_cast<void*>(&ptrs); };
    return c10::visit(address_of, array);
  }

private:
  ArrayTypes array;
};

// Holds a TrivialOffsetCalculator<K> by value, K in [1, 8], chosen at runtime
// from `num`; data_ptr() returns the address of the held calculator.
struct TrivialOffsetCalculatorVariant {
#define DEFINE_CASE(index) TrivialOffsetCalculator<index>
  using TrivialOffsetCalculatorTypes = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  // Fails TORCH_CHECK when `num` is outside [1, 8].
  TrivialOffsetCalculatorVariant(int num) {
    switch (num) {
#define DEFINE_CASE(index)                      \
      case index:                               \
        v = TrivialOffsetCalculator<index>();   \
        break;
      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
      default:
        TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for num_tensors = ", num);
    }
  }

  // Non-owning, type-erased pointer to the held TrivialOffsetCalculator<K>.
  void* data_ptr() {
    auto address_of = [](auto& calc) { return static_cast<void*>(&calc); };
    return c10::visit(address_of, v);
  }

private:
  TrivialOffsetCalculatorTypes v;
};

// Owns a heap-allocated memory::LoadWithCast<K> constructed from `iter`, where
// K == iter.ninputs() must lie in [1, 8]. What LoadWithCast does is defined in
// MemoryAccess.cuh (presumably dtype-converting input loads — confirm there).
struct LoadWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>
  using LoadWithCastPtr = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  // Fails TORCH_CHECK when iter.ninputs() is outside [1, 8].
  LoadWithCastVariant(const TensorIteratorBase& iter) {
    const int arity = iter.ninputs();
    switch (arity) {
#define DEFINE_CASE(index)                                        \
      case index:                                                 \
        v = std::make_unique<memory::LoadWithCast<index>>(iter);  \
        break;
      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
      default:
        TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity);
    }
  }

  // Non-owning, type-erased pointer to the held LoadWithCast<K>.
  void* data_ptr() {
    auto raw_address = [](auto& owner) { return static_cast<void*>(owner.get()); };
    return c10::visit(raw_address, v);
  }

private:
  LoadWithCastPtr v;
};

// Owns a heap-allocated memory::StoreWithCast<K> constructed from `iter`, where
// K == iter.noutputs() must lie in [1, 8]. What StoreWithCast does is defined in
// MemoryAccess.cuh (presumably dtype-converting output stores — confirm there).
struct StoreWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::StoreWithCast<index>>
  using StoreWithCastPtr = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  // Fails TORCH_CHECK when iter.noutputs() is outside [1, 8].
  StoreWithCastVariant(const TensorIteratorBase& iter) {
    const int num = iter.noutputs();
    switch (num) {
#define DEFINE_CASE(index)                                         \
      case index:                                                  \
        v = std::make_unique<memory::StoreWithCast<index>>(iter);  \
        break;
      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
      default:
        TORCH_CHECK(false, "StoreWithCastVariant is not implemented for noutputs = ", num);
    }
  }

  // Non-owning, type-erased pointer to the held StoreWithCast<K>.
  void* data_ptr() {
    auto raw_address = [](auto& owner) { return static_cast<void*>(owner.get()); };
    return c10::visit(raw_address, v);
  }

private:
  StoreWithCastPtr v;
};

}} // namespace at::native


#endif // AT_USE_JITERATOR()