| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include "backprop/forward.h" |
|
|
| #include "backprop/prompt.h" |
| #include "gemma/activations.h" |
| #include "gemma/common.h" |
| #include "hwy/contrib/thread_pool/thread_pool.h" |
|
|
| |
| |
| #undef HWY_TARGET_INCLUDE |
| #define HWY_TARGET_INCLUDE "backprop/forward.cc" |
| #include "hwy/foreach_target.h" |
|
|
| #include "hwy/highway.h" |
| |
| #include "backprop/forward-inl.h" |
| #include "gemma/weights.h" |
|
|
| HWY_BEFORE_NAMESPACE(); |
| namespace gcpp { |
| namespace HWY_NAMESPACE { |
|
|
| template <typename TConfig> |
| float CrossEntropyLossForwardPass(const Prompt& prompt, |
| const ByteStorageT& weights_u8, |
| ByteStorageT& forward_u8, |
| hwy::ThreadPool& pool) { |
| const auto& weights = |
| *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get()); |
| auto& forward = |
| *reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get()); |
| return |
| CrossEntropyLossForwardPass<TConfig, CompressedWeights, CompressedLayer>( |
| prompt.tokens, prompt.context_size, weights, forward, pool); |
| } |
|
|
| float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt, |
| const ByteStorageT& weights, |
| ByteStorageT& forward, |
| hwy::ThreadPool& pool) { |
| |
| switch (model) { |
| case Model::GEMMA_2B: |
| return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(prompt, weights, |
| forward, pool); |
| case Model::GEMMA_TINY: |
| return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>( |
| prompt, weights, forward, pool); |
| default: |
| HWY_ABORT("Model type %d unknown.", static_cast<int>(model)); |
| } |
| } |
|
|
| } |
| } |
| HWY_AFTER_NAMESPACE(); |
|
|
#if HWY_ONCE
namespace gcpp {

// Builds the table of per-target CrossEntropyLossForwardPassT
// implementations (one per target compiled via foreach_target.h) so that
// HWY_DYNAMIC_DISPATCH can select among them at runtime.
HWY_EXPORT(CrossEntropyLossForwardPassT);

// Public, target-independent entry point. It forwards every argument
// unchanged to the CrossEntropyLossForwardPassT implementation selected by
// Highway's dynamic dispatch and returns its result.
float CrossEntropyLossForwardPass(
    const Model& model, const Prompt& prompt, const ByteStorageT& weights,
    ByteStorageT& forward, hwy::ThreadPool& pool) {
  // Resolve the implementation for the chosen target, then invoke it.
  const auto selected_impl = HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT);
  return selected_impl(model, prompt, weights, forward, pool);
}

}  // namespace gcpp
#endif  // HWY_ONCE
|
|