|
|
#pragma once |
|
|
|
|
|
#include <cstdint>
#include <functional>
#include <unordered_map>

#include <cuda_fp16.h>
#include <cuda_bf16.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
|
|
|
|
|
namespace layer_norm { |
|
|
|
|
|
|
|
|
|
|
|
// Host-side description of one layer-norm launch: the kernel argument struct
// plus the device/stream it targets and the scratch sizes it needs.
//
// `Params` is the kernel argument type (FwdParams or BwdParams in this file).
// All scalar and pointer members default to 0 / nullptr, so a
// default-constructed object holds no indeterminate values; `params` runs its
// own default constructor. Default member initializers (rather than a
// constructor) keep this struct an aggregate under C++14 and later.
template<typename Params>
struct LaunchParams {
    // Per-thread element count for the chosen launch configuration
    // (presumably filled in by the launcher when configuring — confirm).
    size_t elts_per_thread = 0;

    // Size of the extra workspace the launch requires, in bytes.
    size_t workspace_bytes = 0;

    // Size of the barrier buffer the launch requires
    // (NOTE(review): units — entries vs bytes — to be confirmed in launchers).
    size_t barrier_size = 0;

    // Properties of the target device. Non-owning: nothing here frees it.
    cudaDeviceProp *props = nullptr;

    // Stream to launch on; nullptr is the CUDA null (default) stream.
    cudaStream_t stream = nullptr;

    // Kernel arguments.
    Params params;
};
|
|
|
|
|
|
|
|
|
|
|
// Fields shared by the forward (FwdParams) and backward (BwdParams) layer-norm
// kernel argument structs.
//
// Every member has a default, so a freshly constructed object never carries
// indeterminate pointers or scalars into a launch:
//   - pointers start as nullptr;
//   - counts and inverse_cols start as 0;
//   - is_rms_norm starts as false;
//   - the multiplicative factors (dropout_keep_p, dropout_scale,
//     rowscale_const) start as 1.
// Members are declared, and therefore initialized, in the original order.
struct ParamsBase {
    // User-provided (not defaulted), so the struct stays a non-aggregate
    // exactly as before; all values come from the initializers below.
    ParamsBase() {}

    // Launch geometry: CTAs per column and the logical rows x cols of the input.
    int ctas_per_col = 0;
    int rows = 0;
    int cols = 0;

    // Untyped tensor pointers.
    // NOTE(review): presumably nullptr marks an optional operand that is not
    // supplied — confirm against the kernels.
    void *x0 = nullptr;
    void *x1 = nullptr;
    void *residual = nullptr;
    void *x = nullptr;
    void *dmask = nullptr;
    void *dmask1 = nullptr;
    void *mu = nullptr;
    void *rs = nullptr;
    void *gamma = nullptr;
    void *gamma1 = nullptr;
    void *rowscale = nullptr;
    void *colscale = nullptr;
    void *x0_subset = nullptr;
    void *z_subset = nullptr;

    // Presumably 1 / cols, precomputed by the caller — confirm.
    float inverse_cols = 0.f;

    // Dropout keep probability and scale; both default to 1.
    float dropout_keep_p = 1.f;
    float dropout_scale = 1.f;

    // Constant row scale; defaults to the neutral multiplicative value 1.
    float rowscale_const = 1.f;

    // Normalization-mode flag (presumably RMSNorm when true — confirm).
    bool is_rms_norm = false;

    // Scratch buffer and inter-CTA barrier array
    // (presumably used only when a row spans several CTAs — confirm).
    void *workspace = nullptr;
    int *barrier = nullptr;
};
|
|
|
|
|
|
|
|
|
|
|
// Forward-pass kernel arguments: the shared ParamsBase fields plus the
// pointers z, z1, beta, beta1, the scalar epsilon, and the Philox RNG state.
// The four pointers start as nullptr and epsilon as 0; philox_args is
// default-initialized.
struct FwdParams : public ParamsBase {
    FwdParams() : ParamsBase() {}

    // Presumably the output buffers — confirm against the kernels.
    void *z = nullptr;
    void *z1 = nullptr;

    // Presumably the bias (shift) parameters — confirm.
    void *beta = nullptr;
    void *beta1 = nullptr;

    // Numerical-stability epsilon; starts at 0, so the caller must set it.
    float epsilon = 0.f;

    // Philox RNG state (presumably consumed by dropout — confirm).
    at::PhiloxCudaState philox_args;
};
|
|
|
|
|
|
|
|
|
|
|
// Backward-pass kernel arguments: the shared ParamsBase fields plus gradient
// inputs, gradient outputs, per-CTA partial reductions, and final reduced
// parameter gradients. Every pointer declared here starts as nullptr.
struct BwdParams : public ParamsBase {
    BwdParams() : ParamsBase() {}

    // Incoming gradients.
    void *dz = nullptr;
    void *dz1 = nullptr;

    // Gradient with respect to x.
    void *dx = nullptr;

    // Partial parameter-gradient buffers (presumably one slice per CTA,
    // reduced later into the final gradients below — confirm).
    void *dbeta_part = nullptr;
    void *dgamma_part = nullptr;
    void *dbeta1_part = nullptr;
    void *dgamma1_part = nullptr;
    void *dcolscale_part = nullptr;

    // Gradients with respect to the other inputs.
    void *dx0 = nullptr;
    void *dx1 = nullptr;
    void *dresidual = nullptr;

    // Final parameter gradients.
    void *dbeta = nullptr;
    void *dgamma = nullptr;
    void *dbeta1 = nullptr;
    void *dgamma1 = nullptr;
    void *dcolscale = nullptr;
};
|
|
|
|
|
|
|
|
|
|
|
// Registry key: type signature in the upper 32 bits, hidden size in the lower
// 32 (built by Types2Key::get).
using FunctionKey = uint64_t;

// Launcher signatures. The launch description is passed by non-const
// reference, so a launcher may write results back into it. The bool argument
// presumably selects "configure only" vs "launch" — confirm against the
// launcher implementations.
using FwdFunction = std::function<void(LaunchParams<FwdParams> &, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams> &, const bool)>;

// Lookup tables from FunctionKey to launcher.
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;

// Declared here, defined in some other translation unit.
extern FwdRegistry FWD_FUNCS;
extern FwdRegistry PARALLEL_FWD_FUNCS;
extern BwdRegistry BWD_FUNCS;
extern BwdRegistry PARALLEL_BWD_FUNCS;
|
|
|
|
|
|
|
|
|
|
|
// Short names for the element types that TypeId assigns codes to.
typedef float fp32;
typedef half fp16;
typedef nv_bfloat16 bf16;
|
|
|
|
|
|
|
|
|
|
|
// Maps an element type to a small integer code used in registry keys.
// Only fp16 (0), bf16 (1) and fp32 (2) are specialized. Any other T gets the
// empty primary template with no `Value`, so using it in a key fails to
// compile. All codes fit in 2 bits.
template <typename T>
struct TypeId {};

template <>
struct TypeId<fp16> {
    static constexpr uint32_t Value = 0;
};

template <>
struct TypeId<bf16> {
    static constexpr uint32_t Value = 1;
};

template <>
struct TypeId<fp32> {
    static constexpr uint32_t Value = 2;
};
|
|
|
|
|
|
|
|
|
|
|
// Shifts the type code of T to bit offset S within a 32-bit type signature.
template <typename T, int S>
struct Type2Key {
    static constexpr uint32_t Value = (TypeId<T>::Value << S);
};
|
|
|
|
|
|
|
|
|
|
|
// Role-specific slots in the type signature, 2 bits each:
//   bits 0-1 weight, 2-3 input, 4-5 residual, 6-7 output, 8-9 compute.
// (Inheritance from a struct is public by default.)
template <typename T>
struct WeightType2Key : Type2Key<T, 0> {};

template <typename T>
struct InputType2Key : Type2Key<T, 2> {};

template <typename T>
struct ResidualType2Key : Type2Key<T, 4> {};

template <typename T>
struct OutputType2Key : Type2Key<T, 6> {};

template <typename T>
struct ComputeType2Key : Type2Key<T, 8> {};
|
|
|
|
|
|
|
|
|
|
|
// Combines the five per-role type codes into one signature (`Value`).
// `get` builds the registry key: the signature goes in the upper 32 bits and
// hidden_size in the lower 32.
// NOTE: hidden_size is OR-ed in without masking, so a value >= 2^32 would
// overlap the signature bits.
template <typename W, typename I, typename R, typename O, typename C>
struct Types2Key {
    static constexpr uint32_t Value =
        ComputeType2Key<C>::Value
        | OutputType2Key<O>::Value
        | ResidualType2Key<R>::Value
        | InputType2Key<I>::Value
        | WeightType2Key<W>::Value;

    static constexpr inline uint64_t get(const uint64_t hidden_size) {
        return (static_cast<uint64_t>(Value) << 32) | hidden_size;
    }
};
|
|
|
|
|
|
|
|
|
|
|
// Constructing an instance adds `f` to FWD_FUNCS under the key for
// <W, I, R, O, C, HIDDEN_SIZE>. If that key is already present, the existing
// entry is kept and `f` is discarded.
// NOTE(review): if instances are namespace-scope statics in other translation
// units, this relies on FWD_FUNCS being constructed first. Cross-TU static
// initialization order is unspecified — confirm.
template <typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar {
    FwdRegistrar(FwdFunction f) {
        const auto registry_key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
        FWD_FUNCS.emplace(registry_key, f);
    }
};
|
|
|
|
|
|
|
|
|
|
|
// Constructing an instance adds `f` to BWD_FUNCS under the key for
// <W, I, R, O, C, HIDDEN_SIZE>. If that key is already present, the existing
// entry is kept and `f` is discarded.
// NOTE(review): if instances are namespace-scope statics in other translation
// units, this relies on BWD_FUNCS being constructed first. Cross-TU static
// initialization order is unspecified — confirm.
template <typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar {
    BwdRegistrar(BwdFunction f) {
        const auto registry_key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
        BWD_FUNCS.emplace(registry_key, f);
    }
};
|
|
|
|
|
|
|
|
|
|
|
// Constructing an instance adds `f` to PARALLEL_FWD_FUNCS under the key for
// <W, I, R, O, C, HIDDEN_SIZE>. If that key is already present, the existing
// entry is kept and `f` is discarded.
// NOTE(review): if instances are namespace-scope statics in other translation
// units, this relies on PARALLEL_FWD_FUNCS being constructed first. Cross-TU
// static initialization order is unspecified — confirm.
template <typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdParallelRegistrar {
    FwdParallelRegistrar(FwdFunction f) {
        const auto registry_key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
        PARALLEL_FWD_FUNCS.emplace(registry_key, f);
    }
};
|
|
|
|
|
|
|
|
|
|
|
// Constructing an instance adds `f` to PARALLEL_BWD_FUNCS under the key for
// <W, I, R, O, C, HIDDEN_SIZE>. If that key is already present, the existing
// entry is kept and `f` is discarded.
// NOTE(review): if instances are namespace-scope statics in other translation
// units, this relies on PARALLEL_BWD_FUNCS being constructed first. Cross-TU
// static initialization order is unspecified — confirm.
template <typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdParallelRegistrar {
    BwdParallelRegistrar(BwdFunction f) {
        const auto registry_key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
        PARALLEL_BWD_FUNCS.emplace(registry_key, f);
    }
};
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|