| | #pragma once |
| |
|
| | #include <span> |
| |
|
| | #include <torch/torch.h> |
| | #include <ATen/cuda/CUDAContext.h> |
| | #include <c10/cuda/CUDAGuard.h> |
| | #include <kerutils/supplemental/torch_tensors.h> |
| |
|
| | #include <cutlass/bfloat16.h> |
| |
|
// Single-precision approximation of log2(e).
static constexpr float LOG_2_E{1.44269504f};
| |
|
| | |
| | template<> |
| | inline cutlass::bfloat16_t* at::TensorBase::data_ptr<cutlass::bfloat16_t>() const { |
| | return reinterpret_cast<cutlass::bfloat16_t*>(this->data_ptr()); |
| | } |
| |
|
| | |
| | struct Arch { |
| | int major; |
| | int minor; |
| | int num_sms; |
| | cudaDeviceProp* device_prop; |
| |
|
| | Arch() { |
| | device_prop = at::cuda::getCurrentDeviceProperties(); |
| | major = device_prop->major; |
| | minor = device_prop->minor; |
| | num_sms = device_prop->multiProcessorCount; |
| | } |
| |
|
| | bool is_sm90a() const { |
| | return major == 9 && minor == 0; |
| | } |
| |
|
| | bool is_sm100f() const { |
| | return major == 10; |
| | } |
| | }; |
| |
|
| | |
| | inline int int64_stride_to_int(int64_t orig_stride) { |
| | if (orig_stride > std::numeric_limits<int>::max()) { |
| | TORCH_CHECK(false, "[FlashMLA] Stride exceeds int32 limit: ", orig_stride); |
| | } |
| | return static_cast<int>(orig_stride); |
| | } |
| |
|
// Runtime-to-compile-time dispatch on the number of heads.
//
// Expands to an immediately-invoked lambda (terminated by ';') that compares
// NUM_HEADS against the supported values 128 and 64. In the matching branch it
// declares `static constexpr int CONSTEXPR_NAME` with that value and returns the
// result of calling the trailing argument(s) with no arguments. Any other value
// fails via TORCH_CHECK with "Unsupported num_heads_q: <value>".
//
// NUM_HEADS is parenthesized at every use so that an expression argument
// (e.g. `a & b`) is compared as a whole. Note that it may be evaluated more
// than once.
#define DISPATCH_NUM_HEADS(NUM_HEADS, CONSTEXPR_NAME, ...) \
[&] () { \
    if ((NUM_HEADS) == 128) { \
        static constexpr int CONSTEXPR_NAME = 128; \
        return __VA_ARGS__(); \
    } else if ((NUM_HEADS) == 64) { \
        static constexpr int CONSTEXPR_NAME = 64; \
        return __VA_ARGS__(); \
    } else { \
        TORCH_CHECK(false, "Unsupported num_heads_q: ", (NUM_HEADS)); \
    } \
} ();
| |
|
// Runtime-to-compile-time dispatch on the head dimension.
//
// Expands to an immediately-invoked lambda (terminated by ';') that compares
// HEAD_DIM against the supported values 576 and 512. In the matching branch it
// declares `static constexpr int CONSTEXPR_NAME` with that value and returns the
// result of calling the trailing argument(s) with no arguments. Any other value
// fails via TORCH_CHECK with "Unsupported head_dim_qk: <value>".
//
// HEAD_DIM is parenthesized at every use so that an expression argument is
// compared as a whole. Note that it may be evaluated more than once.
#define DISPATCH_HEAD_DIM(HEAD_DIM, CONSTEXPR_NAME, ...) \
[&] () { \
    if ((HEAD_DIM) == 576) { \
        static constexpr int CONSTEXPR_NAME = 576; \
        return __VA_ARGS__(); \
    } else if ((HEAD_DIM) == 512) { \
        static constexpr int CONSTEXPR_NAME = 512; \
        return __VA_ARGS__(); \
    } else { \
        TORCH_CHECK(false, "Unsupported head_dim_qk: ", (HEAD_DIM)); \
    } \
} ();
| |
|
// Runtime-to-compile-time dispatch on a boolean.
//
// Expands to an immediately-invoked lambda (terminated by ';'). Depending on the
// truth value of FLAG it declares `static constexpr bool CONSTEXPR_NAME` as true
// or false, then returns the result of calling the trailing argument(s) with no
// arguments.
#define DISPATCH_BOOLEAN_FLAG(FLAG, CONSTEXPR_NAME, ...) \
[&] () { \
    if (FLAG) { \
        static constexpr bool CONSTEXPR_NAME = true; \
        return __VA_ARGS__(); \
    } \
    static constexpr bool CONSTEXPR_NAME = false; \
    return __VA_ARGS__(); \
} ();
| |
|
// Runtime-to-compile-time dispatch on a ModelType value.
//
// Expands to an immediately-invoked lambda (terminated by ';') that compares
// MODEL_TYPE against ModelType::V32 and ModelType::MODEL1. In the matching
// branch it declares `static constexpr ModelType CONSTEXPR_NAME` with that
// enumerator and returns the result of calling the trailing argument(s) with no
// arguments. Any other value fails via TORCH_CHECK with
// "Unsupported model type: <int value>".
//
// MODEL_TYPE is parenthesized at every use (including the integer conversion in
// the error path) so that an expression argument is treated as a whole. Note
// that it may be evaluated more than once.
#define DISPATCH_MODEL_TYPE(MODEL_TYPE, CONSTEXPR_NAME, ...) \
[&] () { \
    if ((MODEL_TYPE) == ModelType::V32) { \
        static constexpr ModelType CONSTEXPR_NAME = ModelType::V32; \
        return __VA_ARGS__(); \
    } else if ((MODEL_TYPE) == ModelType::MODEL1) { \
        static constexpr ModelType CONSTEXPR_NAME = ModelType::MODEL1; \
        return __VA_ARGS__(); \
    } else { \
        TORCH_CHECK(false, "Unsupported model type: ", static_cast<int>(MODEL_TYPE)); \
    } \
} ();
| |
|
| | |
// Returns, as a std::string_view, the textual form of the non-type template
// argument `value` as spelled by the compiler in this function's own signature
// string, with everything up to and including the last "::" removed.
//
// GCC/Clang: the text is taken from __PRETTY_FUNCTION__, starting two characters
// after the first '=' and dropping the final character.
// MSVC: the text is taken from __FUNCSIG__, between the first '<' and the last ">(".
//
// The qualifier is stripped with rfind("::") on every compiler, so an enumerator
// of a scoped or namespaced enum (e.g. ns::Color::Red) yields just "Red". Text
// without any "::" is returned unchanged. The returned view refers to the
// compiler-provided signature string.
template<auto value>
constexpr auto get_static_enum_name() {
#if defined(__GNUC__) || defined(__clang__)
    const std::string_view signature = __PRETTY_FUNCTION__;
    const std::size_t begin = signature.find('=') + 2;
    const std::size_t end = signature.size() - 1;
#elif defined(_MSC_VER)
    const std::string_view signature = __FUNCSIG__;
    const std::size_t begin = signature.find('<') + 1;
    const std::size_t end = signature.rfind(">(");
#else
#error "get_static_enum_name: unsupported compiler (needs __PRETTY_FUNCTION__ or __FUNCSIG__)"
#endif
    const std::string_view qualified = signature.substr(begin, end - begin);
    // Strip the scope using the LAST "::" so nested scopes are fully removed.
    const std::size_t scope_pos = qualified.rfind("::");
    return scope_pos == std::string_view::npos
        ? qualified
        : qualified.substr(scope_pos + 2);
}
| |
|
| | template<typename T, std::size_t N = 0> |
| | static constexpr std::size_t get_enum_max(){ |
| | constexpr T value = static_cast<T>(N); |
| | if constexpr (get_static_enum_name<value>().find(")") == std::string_view::npos) |
| | return get_enum_max<T, N + 1>(); |
| | else |
| | return N; |
| | } |
| |
|
| | template<typename T> requires std::is_enum_v<T> |
| | static constexpr std::string get_dynamic_enum_name(T value){ |
| | constexpr std::size_t num = get_enum_max<T>(); |
| | constexpr auto names = []<std::size_t... Is>(std::index_sequence<Is...>){ |
| | return std::array<std::string_view, num>{ |
| | get_static_enum_name<static_cast<T>(Is)>()... |
| | }; |
| | }(std::make_index_sequence<num>{}); |
| | return (std::string)names[static_cast<std::size_t>(value)]; |
| | } |
| |
|
| | |
// For use inside a class body that has a `FeatureT` type and a virtual
// get_supported_features() to override. Switches to `protected:` access,
// defines a static constexpr array `features` initialized from the macro
// arguments, and overrides get_supported_features() to return a span over it.
// Access remains `protected` after the macro.
#define DECLARE_SUPPORTED_FEATURES(...) \
protected: \
    static constexpr FeatureT features[] = { __VA_ARGS__ }; \
    constexpr std::span<const FeatureT> get_supported_features() const override { \
        return std::span<const FeatureT>(features); \
    }
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | template< |
| | typename RunArgT_, |
| | typename FeatureT_ |
| | > |
| | class ImplBase { |
| | protected: |
| | using RunArgT = RunArgT_; |
| | using FeatureT = FeatureT_; |
| |
|
| | virtual inline void run_(const RunArgT ¶ms, const std::vector<FeatureT> &required_features) = 0; |
| |
|
| | constexpr virtual inline std::span<const FeatureT> get_supported_features() const = 0; |
| |
|
| | virtual ~ImplBase() = default; |
| |
|
| | public: |
| | inline bool check_if_all_features_are_supported(const std::vector<FeatureT> &required_features) { |
| | for (const auto &required_feature : required_features) { |
| | bool is_supported = false; |
| | for (const auto &supported_feature : get_supported_features()) { |
| | if (required_feature == supported_feature) { |
| | is_supported = true; |
| | break; |
| | } |
| | } |
| | if (!is_supported) { |
| | return false; |
| | } |
| | } |
| | return true; |
| | } |
| |
|
| | inline void check_if_all_features_are_supported_and_abort(const std::vector<FeatureT> &required_features) { |
| | if (!check_if_all_features_are_supported(required_features)) { |
| | fprintf(stderr, "[FlashMLA] Error: The chosen implementation does not support all required features.\n"); |
| | fprintf(stderr, "Required features:\n"); |
| | for (const auto &f : required_features) { |
| | fprintf(stderr, " - %3d: %s\n", static_cast<int>(f), get_dynamic_enum_name(f).c_str()); |
| | } |
| | fprintf(stderr, "\n"); |
| | fprintf(stderr, "Supported features:\n"); |
| | for (const auto &supported_feature : get_supported_features()) { |
| | fprintf(stderr, " - %3d: %s\n", static_cast<int>(supported_feature), get_dynamic_enum_name(supported_feature).c_str()); |
| | } |
| | fprintf(stderr, "\n"); |
| | fprintf(stderr, "Features that are required but not supported:\n"); |
| | for (const auto &required_feature : required_features) { |
| | bool is_supported = false; |
| | for (const auto &supported_feature : get_supported_features()) { |
| | if (required_feature == supported_feature) { |
| | is_supported = true; |
| | break; |
| | } |
| | } |
| | if (!is_supported) { |
| | fprintf(stderr, " - %3d: %s\n", static_cast<int>(required_feature), get_dynamic_enum_name(required_feature).c_str()); |
| | } |
| | } |
| | fprintf(stderr, "\n"); |
| | Arch cur_gpu_arch = Arch(); |
| | fprintf(stderr, "Current GPU: %s, SM %d.%d with %d SMs\n", cur_gpu_arch.device_prop->name, cur_gpu_arch.major, cur_gpu_arch.minor, cur_gpu_arch.num_sms); |
| | fprintf(stderr, "This means that the dispatcher has chosen an implementation that does not support all required features. Maybe there is a bug in the dispatcher, or you have requested an invalid combination of features.\n"); |
| | TORCH_CHECK(false, "The chosen implementation does not support all required features. See message above for details."); |
| | } |
| | } |
| |
|
| | inline void run(const RunArgT ¶ms, const std::vector<FeatureT> &required_features) { |
| | check_if_all_features_are_supported_and_abort(required_features); |
| | run_(params, required_features); |
| | } |
| | }; |
| |
|
| |
|