cranky-coder08's picture
Add files using upload-large-folder tool
c1af2fa verified
// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once
#include <ATen/cuda/tunable/GemmCommon.h>
#ifdef USE_ROCM
#include <ATen/cuda/tunable/GemmHipblaslt.h>
#include <ATen/cuda/tunable/GemmRocblas.h>
#endif
#include <ATen/cuda/tunable/TunableOp.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Float8_e8m0fnu.h>
#include <c10/util/StringUtil.h>
#include <fmt/printf.h>
namespace at::cuda::tunable {
template <typename T>
class DefaultGemmOp : public Callable<GemmParams<T>> {
public:
TuningStatus Call(const GemmParams<T>* params) override {
at::cuda::blas::gemm_internal<T>(
params->transa, params->transb,
params->m, params->n, params->k,
params->alpha,
params->a, params->lda,
params->b, params->ldb,
params->beta,
params->c, params->ldc);
return OK;
}
};
static bool _transposeBoolFromChar(char op) {
return op == 't' || op == 'T';
}
template <typename T>
class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
public:
TuningStatus Call(const GemmAndBiasParams<T>* params) override {
at::cuda::blas::gemm_and_bias<T>(
_transposeBoolFromChar(params->transa),
_transposeBoolFromChar(params->transb),
params->m, params->n, params->k,
params->alpha,
params->a, params->lda,
params->b, params->ldb,
params->bias,
params->c, params->ldc,
params->activation);
return OK;
}
};
template <typename T>
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
public:
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
at::cuda::blas::bgemm_internal<T>(
params->transa, params->transb,
params->m, params->n, params->k,
params->alpha,
params->a, params->lda, params->stride_a,
params->b, params->ldb, params->stride_b,
params->beta,
params->c, params->ldc, params->stride_c,
params->batch);
return OK;
}
};
template <typename T>
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
public:
TuningStatus Call(const ScaledGemmParams<T>* params) override {
at::cuda::blas::scaled_gemm(
params->transa,
params->transb,
params->m,
params->n,
params->k,
params->a,
params->a_scale_ptr,
params->lda,
params->a_dtype,
params->a_scale_dtype,
params->b,
params->b_scale_ptr,
params->ldb,
params->b_dtype,
params->b_scale_dtype,
params->bias_ptr,
params->bias_dtype,
params->c,
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->use_fast_accum,
params->use_rowwise);
return OK;
}
};
template <typename T>
inline bool IsZero(T v) {
return v == 0.0f;
}
template <>
inline bool IsZero(BFloat16 v) {
return v.x == 0;
}
template <>
inline bool IsZero(Half v) {
return float(v) == 0.0f;
}
template <>
inline bool IsZero(c10::complex<double> v) {
return v == 0.0;
}
template <>
inline bool IsZero(c10::complex<float> v) {
return v == 0.0f;
}
template <typename T>
inline const char* TypeName(T v) {
return "unknown";
}
template <>
inline const char* TypeName(float v) {
if (at::globalContext().allowTF32CuBLAS()) {
return "tf32";
} else {
return "float";
}
}
template <>
inline const char* TypeName(double v) {
return "double";
}
template <>
inline const char* TypeName(BFloat16 v) {
return "BFloat16";
}
template <>
inline const char* TypeName(Half v) {
return "Half";
}
template <>
inline const char* TypeName(Float8_e4m3fn v) {
return "Float8_e4m3fn";
}
template <>
inline const char* TypeName(Float8_e5m2 v) {
return "Float8_e5m2";
}
template <>
inline const char* TypeName(Float8_e4m3fnuz v) {
return "Float8_e4m3fnuz";
}
template <>
inline const char* TypeName(Float8_e5m2fnuz v) {
return "Float8_e5m2fnuz";
}
template <>
inline const char* TypeName(Float8_e8m0fnu v) {
return "Float8_e8m0fnu";
}
template <>
inline const char* TypeName(c10::complex<double> v) {
return "c10::complex<double>";
}
template <>
inline const char* TypeName(c10::complex<float> v) {
return "c10::complex<float>";
}
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmTunableOp : public TunableOp<GemmParams<T>> {
public:
GemmTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
#ifdef USE_ROCM
static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
if (!env_rocblas.has_value() || env_rocblas.value()) {
for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
// disallow tuning of hipblaslt with c10::complex
if constexpr (
!std::is_same_v<T, c10::complex<float>> &&
!std::is_same_v<T, c10::complex<double>>) {
for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
}
#endif
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
}
std::string Signature() override {
return fmt::sprintf("GemmTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>> {
public:
GemmAndBiasTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
#ifdef USE_ROCM
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
// disallow tuning of hipblaslt with c10::complex
if constexpr (
!std::is_same_v<T, c10::complex<float>> &&
!std::is_same_v<T, c10::complex<double>>) {
for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
}
#endif
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
}
std::string Signature() override {
return fmt::sprintf("GemmAndBiasTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>> {
public:
GemmStridedBatchedTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
#ifdef USE_ROCM
static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
if (!env_rocblas.has_value() || env_rocblas.value()) {
for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
// disallow tuning of hipblaslt with c10::complex
if constexpr (
!std::is_same_v<T, c10::complex<float>> &&
!std::is_same_v<T, c10::complex<double>>) {
for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
}
#endif
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
}
std::string Signature() override {
return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>> {
public:
ScaledGemmTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
#ifdef USE_ROCM
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
#endif
this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
}
std::string Signature() override {
return fmt::sprintf("ScaledGemmTunableOp_%s_%s_%s_%c%c",
TypeName<AT>(AT{}),
TypeName<BT>(BT{}),
TypeName<CT>(CT{}),
BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
} // namespace at::cuda::tunable