| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ |
| #define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ |
|
|
| #include <math.h> |
| #include <stdint.h> |
|
|
| #include <string> |
|
|
| #include "compression/compress.h" |
| #include "gemma/configs.h" |
| #include "hwy/aligned_allocator.h" |
| #include "hwy/base.h" |
|
|
| namespace gcpp { |
|
|
| using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>; |
|
|
| template <typename T> |
| ByteStorageT AllocateSizeof() { |
| return hwy::AllocateAligned<uint8_t>(sizeof(T)); |
| } |
|
|
| |
// Selects the model configuration to instantiate. CallForModel and
// GEMMA_DISPATCH_MODEL map each enumerator to the Config* class template noted
// next to it. The numeric values follow declaration order (starting at 0) and
// are what the "Model type %d unknown." abort message prints.
enum class Model {
  GEMMA_2B,    // -> ConfigGemma2B
  GEMMA_7B,    // -> ConfigGemma7B
  GEMMA_9B,    // -> ConfigGemma9B
  GEMMA_27B,   // -> ConfigGemma27B
  GRIFFIN_2B,  // -> ConfigGriffin2B
  GEMMA_TINY,  // -> ConfigGemmaTiny
};
|
|
| |
// Training variant of a model.
// NOTE(review): IT / PT presumably stand for instruction-tuned / pre-trained;
// confirm against the parsing code.
enum class ModelTraining {
  GEMMA_IT,
  GEMMA_PT,
};
|
|
| |
// Weight element type. CallForModelAndWeight and GEMMA_EXPORT_AND_DISPATCH map
// kF32 -> float, kBF16 -> hwy::bfloat16_t and kSFP -> SfpStream.
enum class Type {
  kF32,
  kBF16,
  kSFP,
};
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename TWeight, template <typename TConfig> class FuncT, |
| typename... TArgs> |
| decltype(auto) CallForModel(Model model, TArgs&&... args) { |
| switch (model) { |
| case Model::GEMMA_TINY: |
| return FuncT<ConfigGemmaTiny<TWeight>>()(std::forward<TArgs>(args)...); |
| case Model::GEMMA_2B: |
| return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...); |
| case Model::GEMMA_7B: |
| return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...); |
| case Model::GEMMA_9B: |
| return FuncT<ConfigGemma9B<TWeight>>()(std::forward<TArgs>(args)...); |
| case Model::GEMMA_27B: |
| return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...); |
| case Model::GRIFFIN_2B: |
| return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...); |
| default: |
| HWY_ABORT("Model type %d unknown.", static_cast<int>(model)); |
| } |
| } |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| template <template <typename TConfig> class FuncT, typename... TArgs> |
| decltype(auto) CallForModelAndWeight(Model model, Type weight, |
| TArgs&&... args) { |
| switch (weight) { |
| case Type::kF32: |
| return CallForModel<float, FuncT, TArgs...>( |
| model, std::forward<TArgs>(args)...); |
| case Type::kBF16: |
| return CallForModel<hwy::bfloat16_t, FuncT, TArgs...>( |
| model, std::forward<TArgs>(args)...); |
| case Type::kSFP: |
| return CallForModel<SfpStream, FuncT, TArgs...>( |
| model, std::forward<TArgs>(args)...); |
| default: |
| HWY_ABORT("Weight type %d unknown.", static_cast<int>(weight)); |
| } |
| } |
|
|
| |
| |
// Statement macro: switches on MODEL and, for the matching configuration,
// expands HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<Config<TWEIGHT>>) followed
// directly by ARGS. ARGS is pasted verbatim after the dispatch expression, so
// it is expected to be a parenthesized argument list, e.g. `(a, b)`.
// The result of the dispatched call is discarded. Unknown MODEL values abort.
// NOTE(review): each case body has its own braces, presumably so that whatever
// the Highway macro declares stays scoped to that case - confirm in Highway.
#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS) \
  switch (MODEL) { \
    case Model::GEMMA_2B: { \
      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B<TWEIGHT>>) \
      ARGS; \
      break; \
    } \
    case Model::GEMMA_7B: { \
      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B<TWEIGHT>>) \
      ARGS; \
      break; \
    } \
    case Model::GEMMA_9B: { \
      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma9B<TWEIGHT>>) \
      ARGS; \
      break; \
    } \
    case Model::GEMMA_27B: { \
      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma27B<TWEIGHT>>) \
      ARGS; \
      break; \
    } \
    case Model::GRIFFIN_2B: { \
      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
      ARGS; \
      break; \
    } \
    case Model::GEMMA_TINY: { \
      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny<TWEIGHT>>) \
      ARGS; \
      break; \
    } \
    default: \
      HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
  }
|
|
| |
| |
| |
// Statement macro: switches on WEIGHT to choose the weight element type
// (kF32 -> float, kBF16 -> hwy::bfloat16_t, kSFP -> SfpStream) and then expands
// GEMMA_DISPATCH_MODEL(MODEL, <type>, FUNC, ARGS) for it. ARGS is passed through
// unchanged. Unknown WEIGHT values abort via HWY_ABORT.
#define GEMMA_EXPORT_AND_DISPATCH(MODEL, WEIGHT, FUNC, ARGS) \
  switch (WEIGHT) { \
    case Type::kF32: { \
      GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS); \
      break; \
    } \
    case Type::kBF16: { \
      GEMMA_DISPATCH_MODEL(MODEL, hwy::bfloat16_t, FUNC, ARGS); \
      break; \
    } \
    case Type::kSFP: { \
      GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS); \
      break; \
    } \
    default: \
      HWY_ABORT("Weight type %d unknown.", static_cast<int>(WEIGHT)); \
  }
|
|
| |
| |
| const char* ParseModelTypeAndTraining(const std::string& model_flag, |
| Model& model, ModelTraining& training); |
| const char* ParseType(const std::string& type_string, Type& type); |
|
|
| |
| const char* ModelString(Model model, ModelTraining training); |
| const char* StringFromType(Type type); |
|
|
| |
| |
|
|
| |
| #if HWY_COMPILER_GCC_ACTUAL |
| #define GEMMA_CONSTEXPR_SQRT constexpr |
| static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { |
| return __builtin_sqrt(x); |
| } |
| #else |
| #define GEMMA_CONSTEXPR_SQRT |
| static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); } |
| #endif |
|
|
| |
| |
// Qualifier applied to the EmbeddingScaling functions. It is HWY_BF16_CONSTEXPR
// when compiling with actual GCC and expands to nothing on every other
// compiler. This mirrors GEMMA_CONSTEXPR_SQRT, which is also only `constexpr`
// on GCC.
#if !HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING
#else
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#endif
|
|
| template <typename TConfig> |
| GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() { |
| |
| return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>( |
| Sqrt(static_cast<float>(TConfig::kModelDim)))); |
| } |
|
|
| static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling( |
| size_t model_dim) { |
| |
| return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>( |
| Sqrt(static_cast<float>(model_dim)))); |
| } |
|
|
| } |
|
|
| #endif |
|
|