| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ |
| #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ |
|
|
| #include <memory> |
|
|
| #include "hwy/contrib/thread_pool/thread_pool.h" |
| #if HWY_OS_LINUX |
| #include <sched.h> |
| #endif |
| #include <stddef.h> |
| #include <stdio.h> |
|
|
| #include <algorithm> |
| #include <string> |
| #include <thread> |
| #include <vector> |
|
|
| #include "compression/io.h" |
| #include "gemma/common.h" |
| #include "gemma/configs.h" |
| #include "gemma/gemma.h" |
| #include "util/args.h" |
| #include "hwy/base.h" |
| #include "hwy/contrib/thread_pool/topology.h" |
|
|
| namespace gcpp { |
|
|
| static inline const char* CompiledConfig() { |
| if (HWY_IS_ASAN) { |
| return "asan"; |
| } else if (HWY_IS_MSAN) { |
| return "msan"; |
| } else if (HWY_IS_TSAN) { |
| return "tsan"; |
| } else if (HWY_IS_HWASAN) { |
| return "hwasan"; |
| } else if (HWY_IS_UBSAN) { |
| return "ubsan"; |
| } else if (HWY_IS_DEBUG_BUILD) { |
| return "dbg"; |
| } else { |
| return "opt"; |
| } |
| } |
|
|
| static inline std::vector<size_t> LpsToCpus( |
| const hwy::LogicalProcessorSet& lps) { |
| std::vector<size_t> cpus; |
| cpus.reserve(lps.Count()); |
| lps.Foreach([&cpus](size_t lp) { cpus.push_back(lp); }); |
| return cpus; |
| } |
|
|
| static inline std::vector<size_t> AssignCpusFromTopology( |
| const hwy::Topology& topology, const size_t num_workers) { |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| std::vector<std::vector<size_t>> clusters; |
| clusters.reserve(topology.packages[0].clusters.size()); |
| for (auto& cluster : topology.packages[0].clusters) { |
| clusters.push_back(LpsToCpus(cluster.lps)); |
| } |
| std::vector<size_t> assigned_cpus; |
| assigned_cpus.reserve(num_workers); |
| for (size_t i = 0; i < num_workers; ++i) { |
| size_t cluster_index = i % clusters.size(); |
| size_t cpu_index = (i / clusters.size()) % clusters[cluster_index].size(); |
| assigned_cpus.push_back(clusters[cluster_index][cpu_index]); |
| } |
| return assigned_cpus; |
| } |
|
|
| static inline void PinWorkersToCores(hwy::ThreadPool& pool) { |
| |
| hwy::Topology topology; |
| if (!topology.packages.empty()) { |
| std::vector<size_t> assigned_cpus = |
| AssignCpusFromTopology(topology, pool.NumWorkers()); |
| pool.Run(0, pool.NumWorkers(), |
| [&assigned_cpus](uint64_t , size_t thread) { |
| hwy::PinThreadToLogicalProcessor(assigned_cpus[thread]); |
| }); |
| } else { |
| pool.Run(0, pool.NumWorkers(), [](uint64_t , size_t thread) { |
| hwy::PinThreadToLogicalProcessor(thread); |
| }); |
| } |
| } |
|
|
// Application-level settings (logging, verbosity, threading) parsed from the
// command line via the ArgsBase CRTP visitor mechanism.
class AppArgs : public ArgsBase<AppArgs> {
  // Sentinel (all bits set) meaning "user did not pass --num_threads".
  static constexpr size_t kDefaultNumThreads = ~size_t{0};

  // Replaces the sentinel with an auto-detected thread count; a user-supplied
  // value is left untouched.
  void ChooseNumThreads() {
    if (num_threads == kDefaultNumThreads) {
      num_threads = GetSupportedThreadCount();
    }
  }

 public:
  // Parses argv, then resolves the thread-count default.
  AppArgs(int argc, char* argv[]) {
    InitAndParse(argc, argv);
    ChooseNumThreads();
  }

  // Returns a reasonable worker count: MaxThreads() clamped to at least 1 and
  // at most min(kMaxThreads, 18).
  // NOTE(review): the 18 cap looks like an empirically chosen limit
  // (diminishing returns beyond ~18 threads?) — confirm the rationale.
  static inline size_t GetSupportedThreadCount() {
    return std::clamp(hwy::ThreadPool::MaxThreads(), size_t{1},
                      std::min(kMaxThreads, size_t{18}));
  }

  Path log;              // log file path (not registered in ForEach here)
  int verbosity;         // 0 = generation only, 1 = standard UI, 2 = debug
  size_t num_threads;    // worker threads; resolved by ChooseNumThreads()
  std::string eot_line;  // end-of-turn marker line; empty = newline ends turn

  // Registers each flag with the visitor: (field, name, default, help[, tier]).
  // The visitor is used for parsing, defaults and help output, so the order
  // and strings here are user-visible.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(verbosity, "verbosity", 1,
            "Show verbose developer information\n 0 = only print generation "
            "output\n 1 = standard user-facing terminal ui\n 2 = show "
            "developer/debug info).\n Default = 1.",
            2);
    visitor(num_threads, "num_threads",
            kDefaultNumThreads,  // sentinel; replaced after parsing
            "Number of threads to use.\n Default = Estimate of the "
            "number of supported concurrent threads.",
            2);
    visitor(
        eot_line, "eot_line", std::string(""),
        "End of turn line. "
        "When you specify this, the prompt will be all lines "
        "before the line where only the given string appears.\n Default = "
        "When a newline is encountered, that signals the end of the turn.",
        2);
  }
};
|
|
// Flags that identify which model to load: tokenizer file, weights file,
// model type and weight type. Call Validate() before using the accessors.
struct LoaderArgs : public ArgsBase<LoaderArgs> {
  LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  // Checks the parsed flags and caches the parsed model/weight enums.
  // Returns nullptr on success, otherwise a static error message.
  // Side effect: if only --compressed_weights was given, copies it into
  // `weights` so the rest of the program can use `weights` alone.
  const char* Validate() {
    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
                                                    model_training_)) {
      return err;
    }
    if (const char* err = ParseType(weight_type_str, weight_type_)) {
      return err;
    }
    if (tokenizer.path.empty()) {
      return "Missing --tokenizer flag, a file for the tokenizer is required.";
    }
    if (!tokenizer.Exists()) {
      return "Can't open file specified with --tokenizer flag.";
    }
    if (!compressed_weights.path.empty()) {
      if (weights.path.empty()) {
        // --compressed_weights is an alias; promote it to --weights.
        weights = compressed_weights;
      } else {
        return "Only one of --weights and --compressed_weights can be "
               "specified. To create compressed weights use the "
               "compress_weights tool.";
      }
    }
    if (weights.path.empty()) {
      return "Missing --weights flag, a file for the model weights.";
    }
    if (!weights.Exists()) {
      return "Can't open file specified with --weights flag.";
    }
    return nullptr;
  }

  Path tokenizer;           // required: tokenizer model file
  Path weights;             // required: model weights (.sbs) file
  Path compressed_weights;  // deprecated alias for --weights
  std::string model_type_str;   // e.g. "2b-it"; parsed in Validate()
  std::string weight_type_str;  // "f32", "bf16" or "sfp"; parsed in Validate()

  // Registers each flag with the visitor: (field, name, default, help).
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(tokenizer, "tokenizer", Path(),
            "Path name of tokenizer model file.\n Required argument.");
    visitor(weights, "weights", Path(),
            "Path name of model weights (.sbs) file.\n Required argument.");
    visitor(compressed_weights, "compressed_weights", Path(),
            "Alias for --weights.");
    visitor(model_type_str, "model", std::string(),
            "Model type\n 2b-it = 2B parameters, instruction-tuned\n "
            "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
            "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
            "gr2b-it = griffin 2B parameters, instruction-tuned\n "
            "gr2b-pt = griffin 2B parameters, pretrained\n "
            " Required argument.");
    visitor(weight_type_str, "weight_type", std::string("sfp"),
            "Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
            " Required argument.");
  }

  // Accessors for the parsed enums; only meaningful after Validate()
  // returned nullptr.
  gcpp::Model ModelType() const { return model_type_; }
  gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
  gcpp::Type WeightType() const { return weight_type_; }

 private:
  Model model_type_;             // set by Validate()
  ModelTraining model_training_; // set by Validate()
  Type weight_type_;             // set by Validate()
};
|
|
// Constructs a Gemma model (by value) from validated loader flags.
// `loader.Validate()` should presumably have returned nullptr first — the
// accessors read enums that Validate() sets; confirm with callers.
static inline Gemma CreateGemma(const LoaderArgs& loader,
                                hwy::ThreadPool& pool) {
  return Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
               loader.WeightType(), pool);
}
|
|
// Heap-allocates a Gemma model from validated loader flags; prefer this over
// CreateGemma when the model must outlive the current scope or be stored in
// a movable handle. Same precondition: Validate() should have succeeded.
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
                                                   hwy::ThreadPool& pool) {
  return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
                                 loader.ModelType(), loader.WeightType(), pool);
}
|
|
// Flags controlling text generation: token budgets, sampling temperature,
// determinism and multi-turn behavior. Call Validate() after construction.
struct InferenceArgs : public ArgsBase<InferenceArgs> {
  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  size_t max_tokens;            // prompt + generated; must be <= kSeqLen
  size_t max_generated_tokens;  // generated only; must be <= max_tokens

  float temperature;   // softmax temperature for top-K sampling
  bool deterministic;  // fixed sampling (reproducible output)
  bool multiturn;      // keep KV cache across turns instead of clearing

  // Checks the token budgets against model limits.
  // Returns nullptr on success, otherwise a static error message.
  const char* Validate() const {
    if (max_tokens > gcpp::kSeqLen) {
      return "max_tokens is larger than the maximum sequence length (see "
             "configs.h).";
    }
    if (max_generated_tokens > max_tokens) {
      return "Maximum number of generated tokens is larger than the maximum "
             "total tokens.";
    }
    return nullptr;
  }

  // Registers each flag with the visitor: (field, name, default, help[, tier]).
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(max_tokens, "max_tokens", size_t{3072},
            "Maximum number of tokens in prompt + generation.");
    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
            "Maximum number of tokens to generate.");

    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
    visitor(deterministic, "deterministic", false,
            "Make top-k sampling deterministic", 2);
    visitor(multiturn, "multiturn", false,
            "Multiturn mode\n 0 = clear KV cache after every "
            "interaction\n 1 = continue KV cache after every interaction\n "
            " Default : 0 (conversation "
            "resets every turn)");
  }
};
|
|
| } |
|
|
| #endif |
|
|