// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/GemmCommon.h>
#include <c10/util/StringUtil.h>

#include <algorithm>

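// rocBLAS exposes the solution-query entry points used below
// (rocblas_gemm_ex_get_solutions_by_type) only when this macro is defined
// before including rocblas.h.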
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>

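// Check a rocBLAS call and raise a descriptive error if it does not return
// rocblas_status_success.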
#define TORCH_ROCBLAS_CHECK(EXPR)                 \
  do {                                            \
    rocblas_status __err = EXPR;                  \
    TORCH_CHECK(__err == rocblas_status_success,  \
                "rocblas error: ",                \
                rocblas_status_to_string(__err),  \
                " when calling `" #EXPR "`");     \
  } while (0)

namespace at::cuda::tunable {

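// Maps a C++ element type to the rocblas_datatype used for the input/output
// buffers of rocblas_gemm_ex.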
template <typename T>
constexpr rocblas_datatype RocBlasDataTypeFor();

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
  return rocblas_datatype_f16_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
  return rocblas_datatype_bf16_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}

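// Maps a C++ element type to the rocblas_datatype used as the GEMM compute
// type. This can differ from the storage type; see the Half and BFloat16
// specializations below.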
template <typename T>
constexpr rocblas_datatype RocBlasComputeTypeFor();

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
  // Note that this returns the _compute_ type for a given data type.
  // As of 12/2022, using an FP16 compute type for 16-bit floats was much
  // slower than using an FP32 compute type, so FP32 compute is used even
  // for FP16 data. GEMM is implemented the same way in rocblasGemmHelper
  // (see fpgeneric.h).
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
  // See the note on the Half specialization above: FP32 compute is also
  // used for BF16 data, matching rocblasGemmHelper (see fpgeneric.h).
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}

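// rocblas_gemm_ex expects alpha/beta in the compute type, which for
// Half/BFloat16 is float (see above). Widen those scalars to float and pass
// every other type through unchanged.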
template <typename T>
auto DoCastForHalfOrBfloat16(const T fp) {
  return fp;
}

template <>
inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
  // alpha and beta must match the compute type, which is float for Half.
  float h = fp;
  return h;
}

template <>
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
  // alpha and beta must match the compute type, which is float for BFloat16.
  float h = fp;
  return h;
}

static rocblas_operation _rocblasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return rocblas_operation_none;
    case 't':
    case 'T':
      return rocblas_operation_transpose;
    case 'c':
    case 'C':
      return rocblas_operation_conjugate_transpose;
  }
  TORCH_CHECK(false,
      "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

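// A tuning candidate that runs rocblas_gemm_ex with one fixed solution index.
// GetRocBlasGemmTypeStringAndOps() below creates one instance per solution.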
template <typename T>
class RocblasGemmOp : public Callable<GemmParams<T>> {
  public:
    explicit RocblasGemmOp(int solution) : solution_{solution} {}

    TuningStatus Call(const GemmParams<T>* params) override {
      auto input_output_type = RocBlasDataTypeFor<T>();
      if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
        return FAIL;  // no support for TF32 in rocBLAS
      auto compute_type = RocBlasComputeTypeFor<T>();
      auto h_a = DoCastForHalfOrBfloat16(params->alpha);
      auto h_b = DoCastForHalfOrBfloat16(params->beta);
      auto status = rocblas_gemm_ex(
          (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
          _rocblasOpFromChar(params->transa),
          _rocblasOpFromChar(params->transb),
          params->m, params->n, params->k,
          &h_a,
          params->a, input_output_type, params->lda,
          params->b, input_output_type, params->ldb,
          &h_b,
          params->c, input_output_type, params->ldc,
          params->c, input_output_type, params->ldc,
          compute_type,
          rocblas_gemm_algo_solution_index,
          solution_,
          rocblas_gemm_flags_none);
      if (status != rocblas_status_success) {
        return FAIL;
      }
      return OK;
    }

  private:
    int solution_;
};

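// Enumerates every rocBLAS GEMM solution available for type T and wraps each
// one as a named Callable. rocblas_gemm_ex_get_solutions_by_type is called
// twice: first with a null list to query the solution count, then again to
// fill the list.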
template <typename T>
auto GetRocBlasGemmTypeStringAndOps() {
  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
  int solution_size = 0;
  auto input_output_type = RocBlasDataTypeFor<T>();
  auto compute_type = RocBlasComputeTypeFor<T>();
  // Get the number of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            nullptr,
                                                            &solution_size));
  std::vector<int> solutions(solution_size);
  // Get the list of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            solutions.data(),
                                                            &solution_size));
  // Sort the solutions in ascending order to make the solution vector
  // deterministic across runs, matching the strided-batched variant below.
  std::sort(solutions.begin(), solutions.end());

  std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
  for (size_t i = 0; i < solutions.size(); ++i) {
    auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
  }
  return ret;
}

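// Strided-batched counterpart of RocblasGemmOp: runs
// rocblas_gemm_strided_batched_ex with one fixed solution index.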
template <typename T>
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
  public:
    explicit RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}

    TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
      auto input_output_type = RocBlasDataTypeFor<T>();
      if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
        return FAIL;  // no support for TF32 in rocBLAS
      auto compute_type = RocBlasComputeTypeFor<T>();
      auto h_a = DoCastForHalfOrBfloat16(params->alpha);
      auto h_b = DoCastForHalfOrBfloat16(params->beta);
      auto status = rocblas_gemm_strided_batched_ex(
          (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
          _rocblasOpFromChar(params->transa),
          _rocblasOpFromChar(params->transb),
          params->m, params->n, params->k,
          &h_a,
          params->a, input_output_type, params->lda, params->stride_a,
          params->b, input_output_type, params->ldb, params->stride_b,
          &h_b,
          params->c, input_output_type, params->ldc, params->stride_c,
          params->c, input_output_type, params->ldc, params->stride_c,
          params->batch,
          compute_type,
          rocblas_gemm_algo_solution_index,
          solution_,
          rocblas_gemm_flags_none);
      if (status != rocblas_status_success) {
        return FAIL;
      }
      return OK;
    }

  private:
    int solution_;
};

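// Same two-call enumeration pattern as GetRocBlasGemmTypeStringAndOps() above,
// but producing ops for GemmStridedBatchedParams.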
template <typename T>
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
  int solution_size = 0;
  auto input_output_type = RocBlasDataTypeFor<T>();
  auto compute_type = RocBlasComputeTypeFor<T>();
  // Get the number of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            nullptr,
                                                            &solution_size));
  std::vector<int> solutions(solution_size);
  // Get the list of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            solutions.data(),
                                                            &solution_size));
  // Sort the solutions in ascending order to make the solution vector deterministic across runs
  std::sort(solutions.begin(), solutions.end());

  std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
  for (size_t i = 0; i < solutions.size(); ++i) {
    auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
  }
  return ret;
}
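
// A minimal usage sketch (hypothetical caller; assumes a TunableOp-style
// registry exposing RegisterOp(name, op), as in TunableOp.h):
//
//   for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<float>()) {
//     this->RegisterOp(std::move(name), std::move(op));
//   }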

}  // namespace at::cuda::tunable