| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #pragma once |
|
|
#include "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h"
#include "utils/activation_types.h"
#include <cuda_runtime_api.h>

#include <cstddef>
#include <string>
|
|
| namespace fastertransformer { |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| template<typename T, typename WeightType> |
| class CutlassFpAIntBGemmRunner { |
| public: |
| CutlassFpAIntBGemmRunner(); |
| ~CutlassFpAIntBGemmRunner(); |
|
|
| void gemm(const T* A, |
| const WeightType* B, |
| const T* weight_scales, |
| T* C, |
| int m, |
| int n, |
| int k, |
| char* workspace_ptr, |
| const size_t workspace_bytes, |
| cudaStream_t stream); |
|
|
| void gemm_bias_act(const T* A, |
| const WeightType* B, |
| const T* weight_scales, |
| const T* biases, |
| T* C, |
| int m, |
| int n, |
| int k, |
| int bias_stride, |
| ActivationType activation_type, |
| char* workspace_ptr, |
| const size_t workspace_bytes, |
| cudaStream_t stream); |
|
|
| void gemm_bias_act_residual(const T *A, const WeightType *B, |
| const T *weight_scales, const T *biases, |
| const T *residual, T *C, int m, int n, int k, |
| const std::string& activation, const std::string& binary_op, |
| const std::string& unary_op, |
| char *workspace_ptr, |
| const size_t workspace_bytes, |
| cudaStream_t stream); |
|
|
| |
| int getWorkspaceSize(const int m, const int n, const int k); |
|
|
| private: |
| template<typename EpilogueTag> |
| void dispatch_to_arch(const T* A, |
| const WeightType* B, |
| const T* weight_scales, |
| const T* biases, |
| T* C, |
| int m, |
| int n, |
| int k, |
| int bias_stride, |
| CutlassGemmConfig gemm_config, |
| char* workspace_ptr, |
| const size_t workspace_bytes, |
| cudaStream_t stream, |
| int* occupancy = nullptr); |
|
|
| template<typename EpilogueTag> |
| void run_gemm(const T* A, |
| const WeightType* B, |
| const T* weight_scales, |
| const T* biases, |
| T* C, |
| int m, |
| int n, |
| int k, |
| int bias_stride, |
| char* workspace_ptr, |
| const size_t workspace_bytes, |
| cudaStream_t stream); |
|
|
| private: |
| static constexpr int split_k_limit = 7; |
|
|
| int sm_; |
| int multi_processor_count_; |
| }; |
|
|
| } |
|
|