|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
#include <ATen/cuda/tunable/GemmCommon.h>
|
|
|
#ifdef USE_ROCM
|
|
|
#include <ATen/cuda/tunable/GemmHipblaslt.h>
|
|
|
#include <ATen/cuda/tunable/GemmRocblas.h>
|
|
|
#endif
|
|
|
#include <ATen/cuda/tunable/TunableOp.h>
|
|
|
#include <c10/cuda/CUDACachingAllocator.h>
|
|
|
#include <c10/util/Float8_e4m3fn.h>
|
|
|
#include <c10/util/Float8_e4m3fnuz.h>
|
|
|
#include <c10/util/Float8_e5m2.h>
|
|
|
#include <c10/util/Float8_e5m2fnuz.h>
|
|
|
#include <c10/util/Float8_e8m0fnu.h>
|
|
|
#include <c10/util/StringUtil.h>
|
|
|
#include <fmt/printf.h>
|
|
|
|
|
|
namespace at::cuda::tunable {
|
|
|
|
|
|
template <typename T>
|
|
|
class DefaultGemmOp : public Callable<GemmParams<T>> {
|
|
|
public:
|
|
|
TuningStatus Call(const GemmParams<T>* params) override {
|
|
|
at::cuda::blas::gemm_internal<T>(
|
|
|
params->transa, params->transb,
|
|
|
params->m, params->n, params->k,
|
|
|
params->alpha,
|
|
|
params->a, params->lda,
|
|
|
params->b, params->ldb,
|
|
|
params->beta,
|
|
|
params->c, params->ldc);
|
|
|
return OK;
|
|
|
}
|
|
|
};
|
|
|
|
|
|
// Maps a BLAS-style transpose character to the boolean flag expected by
// at::cuda::blas::gemm_and_bias: 't' / 'T' -> true, any other character
// (e.g. 'n' / 'N') -> false.
//
// Declared `constexpr` (which implies `inline`) rather than `static`. This
// header is included by many translation units. A `static` function gives
// every TU its own distinct entity, so the class-template member functions
// below that call it would refer to different entities in different TUs,
// which violates the one-definition rule for those templates.
constexpr bool _transposeBoolFromChar(char op) noexcept {
  return op == 't' || op == 'T';
}
|
|
|
|
|
|
template <typename T>
|
|
|
class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
|
|
|
public:
|
|
|
TuningStatus Call(const GemmAndBiasParams<T>* params) override {
|
|
|
at::cuda::blas::gemm_and_bias<T>(
|
|
|
_transposeBoolFromChar(params->transa),
|
|
|
_transposeBoolFromChar(params->transb),
|
|
|
params->m, params->n, params->k,
|
|
|
params->alpha,
|
|
|
params->a, params->lda,
|
|
|
params->b, params->ldb,
|
|
|
params->bias,
|
|
|
params->c, params->ldc,
|
|
|
params->activation);
|
|
|
return OK;
|
|
|
}
|
|
|
};
|
|
|
|
|
|
template <typename T>
|
|
|
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
|
|
|
public:
|
|
|
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
|
|
|
at::cuda::blas::bgemm_internal<T>(
|
|
|
params->transa, params->transb,
|
|
|
params->m, params->n, params->k,
|
|
|
params->alpha,
|
|
|
params->a, params->lda, params->stride_a,
|
|
|
params->b, params->ldb, params->stride_b,
|
|
|
params->beta,
|
|
|
params->c, params->ldc, params->stride_c,
|
|
|
params->batch);
|
|
|
return OK;
|
|
|
}
|
|
|
};
|
|
|
|
|
|
template <typename T>
|
|
|
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
|
|
|
public:
|
|
|
TuningStatus Call(const ScaledGemmParams<T>* params) override {
|
|
|
at::cuda::blas::scaled_gemm(
|
|
|
params->transa,
|
|
|
params->transb,
|
|
|
params->m,
|
|
|
params->n,
|
|
|
params->k,
|
|
|
params->a,
|
|
|
params->a_scale_ptr,
|
|
|
params->lda,
|
|
|
params->a_dtype,
|
|
|
params->a_scale_dtype,
|
|
|
params->b,
|
|
|
params->b_scale_ptr,
|
|
|
params->ldb,
|
|
|
params->b_dtype,
|
|
|
params->b_scale_dtype,
|
|
|
params->bias_ptr,
|
|
|
params->bias_dtype,
|
|
|
params->c,
|
|
|
params->c_scale_ptr,
|
|
|
params->ldc,
|
|
|
params->c_dtype,
|
|
|
params->use_fast_accum,
|
|
|
params->use_rowwise);
|
|
|
return OK;
|
|
|
}
|
|
|
};
|
|
|
|
|
|
// Returns true when v compares equal to zero. The primary template compares
// against the float constant 0.0f and relies on T being comparable with (or
// convertible to) float.
template <typename T>
inline bool IsZero(T v) {
  constexpr float kZero = 0.0f;
  return v == kZero;
}
|
|
|
|
|
|
template <>
|
|
|
inline bool IsZero(BFloat16 v) {
|
|
|
return v.x == 0;
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline bool IsZero(Half v) {
|
|
|
return float(v) == 0.0f;
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline bool IsZero(c10::complex<double> v) {
|
|
|
return v == 0.0;
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline bool IsZero(c10::complex<float> v) {
|
|
|
return v == 0.0f;
|
|
|
}
|
|
|
|
|
|
// Name of element type T, used when building TunableOp signature strings.
// Primary template: any type without a specialization reports "unknown".
// The argument only drives template deduction; its value is never read.
template <typename T>
inline const char* TypeName(T /*v*/) {
  constexpr const char* kUnknown = "unknown";
  return kUnknown;
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(float v) {
|
|
|
if (at::globalContext().allowTF32CuBLAS()) {
|
|
|
return "tf32";
|
|
|
} else {
|
|
|
return "float";
|
|
|
}
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(double v) {
|
|
|
return "double";
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(BFloat16 v) {
|
|
|
return "BFloat16";
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(Half v) {
|
|
|
return "Half";
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(Float8_e4m3fn v) {
|
|
|
return "Float8_e4m3fn";
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(Float8_e5m2 v) {
|
|
|
return "Float8_e5m2";
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(Float8_e4m3fnuz v) {
|
|
|
return "Float8_e4m3fnuz";
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(Float8_e5m2fnuz v) {
|
|
|
return "Float8_e5m2fnuz";
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(Float8_e8m0fnu v) {
|
|
|
return "Float8_e8m0fnu";
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(c10::complex<double> v) {
|
|
|
return "c10::complex<double>";
|
|
|
}
|
|
|
|
|
|
template <>
|
|
|
inline const char* TypeName(c10::complex<float> v) {
|
|
|
return "c10::complex<float>";
|
|
|
}
|
|
|
|
|
|
// TunableOp for plain GEMM (GemmParams<T>) with compile-time operand layouts
// ALayout / BLayout. The constructor registers the candidate implementations
// that the TunableOp base chooses between; registration order is preserved
// as written.
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmTunableOp : public TunableOp<GemmParams<T>> {
 public:
  GemmTunableOp() {
    // "Default" (DefaultGemmOp, i.e. at::cuda::blas::gemm_internal) is the
    // first registered candidate.
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());

#ifdef USE_ROCM
    // rocBLAS candidates are added unless PYTORCH_TUNABLEOP_ROCBLAS_ENABLED is
    // set and check_env reports a false value. The result is cached in a
    // function-local static, so the environment is consulted only on the
    // first construction of each template instantiation.
    static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
    if (!env_rocblas.has_value() || env_rocblas.value()) {
      for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
        this->RegisterOp(std::move(name), std::move(op));
      }
    }

    // hipBLASLt candidates follow the same opt-out pattern via
    // PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED (also cached in a static).
    static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
    if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
      // Complex element types are excluded from hipBLASLt at compile time.
      if constexpr (
          !std::is_same_v<T, c10::complex<float>> &&
          !std::is_same_v<T, c10::complex<double>>) {
        for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }
    }
#endif

    // NOTE(review): "Default" is registered a second time under the same name.
    // Whether this adds a duplicate tuning candidate, replaces the first
    // entry, or is dropped depends on RegisterOp (not visible in this header).
    // Confirm that this is intentional.
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
  }

  // Signature: "GemmTunableOp_<TypeName<T>>_<A><B>", where the two layout
  // characters come from BlasOpToString.
  std::string Signature() override {
    return fmt::sprintf("GemmTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
  }
};
|
|
|
|
|
|
// TunableOp for GEMM with fused bias/activation (GemmAndBiasParams<T>) and
// compile-time operand layouts ALayout / BLayout. The constructor registers
// the candidate implementations in the order written below.
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>> {
 public:
  GemmAndBiasTunableOp() {
    // "Default" (DefaultGemmAndBiasOp, i.e. at::cuda::blas::gemm_and_bias) is
    // the first registered candidate.
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());

#ifdef USE_ROCM
    // Only hipBLASLt candidates are considered here (no rocBLAS). They are
    // added unless PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED is set and check_env
    // reports a false value. The result is cached in a function-local static,
    // so the environment is read once per template instantiation.
    static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
    if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
      // Complex element types are excluded from hipBLASLt at compile time.
      if constexpr (
          !std::is_same_v<T, c10::complex<float>> &&
          !std::is_same_v<T, c10::complex<double>>) {
        for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }
    }
#endif

    // NOTE(review): "Default" is registered a second time under the same name.
    // The effect depends on RegisterOp (not visible in this header). Confirm
    // that this is intentional.
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
  }

  // Signature: "GemmAndBiasTunableOp_<TypeName<T>>_<A><B>".
  std::string Signature() override {
    return fmt::sprintf("GemmAndBiasTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
  }
};
|
|
|
|
|
|
// TunableOp for strided-batched GEMM (GemmStridedBatchedParams<T>) with
// compile-time operand layouts ALayout / BLayout. The constructor registers
// the candidate implementations in the order written below.
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>> {
 public:
  GemmStridedBatchedTunableOp() {
    // "Default" (DefaultGemmStridedBatchedOp, i.e.
    // at::cuda::blas::bgemm_internal) is the first registered candidate.
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());

#ifdef USE_ROCM
    // rocBLAS candidates are added unless PYTORCH_TUNABLEOP_ROCBLAS_ENABLED is
    // set and check_env reports a false value. The result is cached in a
    // function-local static (read once per template instantiation).
    static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
    if (!env_rocblas.has_value() || env_rocblas.value()) {
      for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
        this->RegisterOp(std::move(name), std::move(op));
      }
    }

    // hipBLASLt candidates follow the same opt-out pattern via
    // PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED.
    static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
    if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
      // Complex element types are excluded from hipBLASLt at compile time.
      if constexpr (
          !std::is_same_v<T, c10::complex<float>> &&
          !std::is_same_v<T, c10::complex<double>>) {
        for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }
    }
#endif

    // NOTE(review): "Default" is registered a second time under the same name.
    // The effect depends on RegisterOp (not visible in this header). Confirm
    // that this is intentional.
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
  }

  // Signature: "GemmStridedBatchedTunableOp_<TypeName<T>>_<A><B>".
  std::string Signature() override {
    return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
  }
};
|
|
|
|
|
|
// TunableOp for scaled GEMM. The input element types are AT / BT and the
// output type is CT; the params type is ScaledGemmParams<CT>, and the operand
// layouts are fixed at compile time. The constructor registers candidates in
// the order written below.
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>> {
 public:
  ScaledGemmTunableOp() {
    // "Default" (DefaultScaledGemmOp, i.e. at::cuda::blas::scaled_gemm) is the
    // first registered candidate.
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());

#ifdef USE_ROCM
    // NOTE(review): unlike the other TunableOps in this file, hipBLASLt
    // candidates are registered here without consulting
    // PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED. Confirm that this is intended.
    for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
      this->RegisterOp(std::move(name), std::move(op));
    }
#endif

    // NOTE(review): "Default" is registered a second time under the same name.
    // The effect depends on RegisterOp (not visible in this header). Confirm
    // that this is intentional.
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
  }

  // Signature: "ScaledGemmTunableOp_<TypeName<AT>>_<TypeName<BT>>_<TypeName<CT>>_<A><B>".
  std::string Signature() override {
    return fmt::sprintf("ScaledGemmTunableOp_%s_%s_%s_%c%c",
      TypeName<AT>(AT{}),
      TypeName<BT>(BT{}),
      TypeName<CT>(CT{}),
      BlasOpToString(ALayout), BlasOpToString(BLayout));
  }
};
|
|
|
|
|
|
}
|
|
|
|