| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include "backprop/backward.h" |
|
|
| #include "backprop/prompt.h" |
| #include "gemma/activations.h" |
| #include "gemma/common.h" |
| #include "hwy/contrib/thread_pool/thread_pool.h" |
|
|
| |
| |
| #undef HWY_TARGET_INCLUDE |
| #define HWY_TARGET_INCLUDE "backprop/backward.cc" |
| #include "hwy/foreach_target.h" |
|
|
| #include "hwy/highway.h" |
| |
| #include "backprop/backward-inl.h" |
| #include "gemma/weights.h" |
|
|
| HWY_BEFORE_NAMESPACE(); |
| namespace gcpp { |
| namespace HWY_NAMESPACE { |
|
|
| template <typename TConfig> |
| void CrossEntropyLossBackwardPass(const Prompt& prompt, |
| const ByteStorageT& weights_u8, |
| const ByteStorageT& forward_u8, |
| ByteStorageT& grad_u8, |
| ByteStorageT& backward_u8, |
| hwy::ThreadPool& pool) { |
| using TWeights = CompressedWeights<TConfig>; |
| const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get()); |
| auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get()); |
| using TAct = ForwardPass<float, TConfig>; |
| const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get()); |
| auto& backward = *reinterpret_cast<TAct*>(backward_u8.get()); |
| CrossEntropyLossBackwardPass<TConfig, CompressedWeights, CompressedLayer>( |
| prompt, weights, forward, grad, backward, pool); |
| } |
|
|
| void CrossEntropyLossBackwardPassT(Model model, |
| const Prompt& prompt, |
| const ByteStorageT& weights, |
| const ByteStorageT& forward, |
| ByteStorageT& grad, |
| ByteStorageT& backward, |
| hwy::ThreadPool& pool) { |
| |
| switch (model) { |
| case Model::GEMMA_2B: |
| CrossEntropyLossBackwardPass<ConfigGemma2B<float>>( |
| prompt, weights, forward, grad, backward, pool); |
| break; |
| case Model::GEMMA_TINY: |
| CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>( |
| prompt, weights, forward, grad, backward, pool); |
| break; |
| default: |
| HWY_ABORT("Model type %d unknown.", static_cast<int>(model)); |
| } |
| } |
|
|
| } |
| } |
| HWY_AFTER_NAMESPACE(); |
|
|
| #if HWY_ONCE |
| namespace gcpp { |
|
|
| HWY_EXPORT(CrossEntropyLossBackwardPassT); |
|
|
| void CrossEntropyLossBackwardPass( |
| const Model& model, const Prompt& prompt, |
| const ByteStorageT& weights, const ByteStorageT& forward, |
| ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool) { |
| return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)( |
| model, prompt, weights, forward, grad, backward, pool); |
| } |
|
|
| } |
| #endif |
|
|