| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ |
| #define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ |
|
|
| #include <functional> |
| #include <memory> |
| #include <random> |
| #include <string> |
| #include <vector> |
|
|
| #include "compression/io.h" |
| #include "gemma/common.h" |
| #include "hwy/aligned_allocator.h" |
| #include "hwy/base.h" |
| #include "hwy/contrib/thread_pool/thread_pool.h" |
|
|
| namespace gcpp { |
|
|
// Compile-time settings shared by code that includes this header. How each
// value is consumed is not visible here, so the per-line notes are inferred
// from the names only.
constexpr size_t kPrefillBatchSize = 16;  // presumably tokens per prefill batch -- TODO confirm
constexpr size_t kDecodeBatchSize = 1;    // presumably tokens per decode step -- TODO confirm
constexpr bool kSystemPrompt = false;     // presumably toggles a system prompt -- TODO confirm
|
|
| struct KVCache { |
| hwy::AlignedFreeUniquePtr<float[]> |
| kv_cache; |
| hwy::AlignedFreeUniquePtr<float[]> |
| conv1d_cache; |
| hwy::AlignedFreeUniquePtr<float[]> |
| rglru_cache; |
|
|
| static KVCache Create(Model type); |
| }; |
|
|
| |
// Special token ids. EOS_ID is the default value of RuntimeConfig::eos_id.
// The names suggest end-of-sequence / beginning-of-sequence; the vocabulary
// they refer to is not visible in this header.
constexpr int EOS_ID = 1;
constexpr int BOS_ID = 2;
|
|
| class GemmaTokenizer { |
| public: |
| GemmaTokenizer(); |
| explicit GemmaTokenizer(const Path& tokenizer_path); |
|
|
| |
| ~GemmaTokenizer(); |
| GemmaTokenizer(GemmaTokenizer&& other); |
| GemmaTokenizer& operator=(GemmaTokenizer&& other); |
|
|
| bool Encode(const std::string& input, std::vector<std::string>* pieces) const; |
| bool Encode(const std::string& input, std::vector<int>* ids) const; |
| bool Decode(const std::vector<int>& ids, std::string* detokenized) const; |
|
|
| private: |
| class Impl; |
| std::unique_ptr<Impl> impl_; |
| }; |
|
|
| |
| |
| |
// Callback types stored in RuntimeConfig. Only the signatures are fixed by
// this header; the meaning of each argument is inferred from names and must
// be confirmed against the code that invokes the callbacks.

// bool(int, float) -- presumably (token id, probability), with the return
// value deciding whether generation continues. TODO confirm.
using StreamFunc = std::function<bool(int, float)>;

// bool(int, float) -- same signature as StreamFunc; presumably returns
// whether a candidate token may be generated. TODO confirm.
using AcceptFunc = std::function<bool(int, float)>;

// int(const float*, size_t) -- presumably receives an array of values and its
// length and returns the chosen token id. TODO confirm.
using SampleFunc = std::function<int(const float*, size_t)>;

// void(int, const std::string&, const float*, size_t) -- presumably
// (position, layer/tensor name, data pointer, element count), letting a
// caller observe intermediate outputs. TODO confirm.
using LayersOutputFunc =
    std::function<void(int, const std::string&, const float*, size_t)>;
|
|
| struct RuntimeConfig { |
| size_t max_tokens; |
| size_t max_generated_tokens; |
| float temperature; |
| int verbosity; |
| std::mt19937* gen; |
| StreamFunc stream_token; |
| AcceptFunc accept_token; |
| SampleFunc sample_func; |
| LayersOutputFunc layers_output; |
| int eos_id = EOS_ID; |
| }; |
|
|
// Timing results written by Gemma::Generate through its `timing_info`
// out-parameter. All fields start at zero. Units are inferred from the names:
// presumably tokens per second for the *_tok_sec fields and seconds for
// time_to_first_token -- TODO confirm where the values are written.
struct TimingInfo {
  double prefill_tok_sec = 0.0;
  double gen_tok_sec = 0.0;
  double time_to_first_token = 0.0;
};
|
|
| class Gemma { |
| public: |
| Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, |
| Type weight_type, hwy::ThreadPool& pool); |
|
|
| |
| Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type, |
| hwy::ThreadPool& pool); |
| ~Gemma(); |
|
|
| Model ModelType() const { return model_type_; } |
| const GemmaTokenizer& Tokenizer() const { return tokenizer_; } |
| const ByteStorageT& Weights() const { return weights_u8_; } |
| const ByteStorageT& Prefill() const { return prefill_u8_; } |
| const ByteStorageT& Decode() const { return decode_u8_; } |
|
|
| void Generate(const RuntimeConfig& runtime_config, |
| const std::vector<int>& prompt, size_t start_pos, |
| KVCache& kv_cache, TimingInfo& timing_info); |
|
|
| private: |
| hwy::ThreadPool& pool_; |
|
|
| GemmaTokenizer tokenizer_; |
| |
| |
| ByteStorageT weights_u8_; |
| ByteStorageT prefill_u8_; |
| ByteStorageT decode_u8_; |
| Model model_type_; |
| Type weight_type_; |
| }; |
|
|
| |
| |
| |
| std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, |
| ModelTraining training, size_t pos, |
| std::string& prompt); |
|
|
| |
| HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config, |
| const std::vector<int>& prompt, size_t start_pos, |
| KVCache& kv_cache, hwy::ThreadPool& , |
| TimingInfo& timing_info) { |
| gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info); |
| } |
|
|
| } |
|
|
| #endif |
|
|