| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| #include <iostream> |
| #include <random> |
| #include <string> |
| #include <string_view> |
| #include <vector> |
|
|
| |
| #include "gemma/benchmark_helper.h" |
| #include "gemma/common.h" |
| #include "gemma/gemma.h" |
| #include "util/app.h" |
| #include "util/args.h" |
| #include "hwy/base.h" |
| #include "hwy/contrib/thread_pool/thread_pool.h" |
| #include "hwy/highway.h" |
| #include "hwy/profiler.h" |
|
|
| #if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE |
| #error "Please update to version 1.2 of github.com/google/highway." |
| #endif |
| #if HWY_CXX_LANG < 201703L |
| #error "Gemma.cpp requires C++17, please pass -std=c++17." |
| #endif |
|
|
// Compile-time switch: when true, each prompt token id is dumped to stderr
// while the prompt is read (see the `if constexpr` block in ReplGemma).
static constexpr bool kVerboseLogTokens = false;
|
|
| namespace gcpp { |
|
|
// Startup banner. Uses an explicit "ascii" raw-string delimiter: the bare
// `R""(` form in the previous revision was ill-formed (a quote may not appear
// in the delimiter), and a named delimiter also tolerates `)"` in the art.
static constexpr std::string_view kAsciiArtBanner = R"ascii(
  __ _  ___ _ __ ___  _ __ ___   __ _   ___ _ __  _ __
 / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
| (_| |  __/ | | | | | | | | | | (_| || (__| |_) | |_) |
 \__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
  __/ |                                    | |   | |
 |___/                                     |_|   |_|
)ascii";
|
|
| |
// Interactive read-eval-print loop: reads a prompt from stdin, tokenizes it,
// runs generation on `model`, and streams decoded tokens to stdout. Returns
// on stdin EOF/failure, on a "%q"/"%Q" command, or (after the loop) when
// `args.max_tokens` is exhausted.
//
// Parameters:
//   model        - loaded Gemma model; provides tokenizer and Generate().
//   training     - model training type; controls prompt wrapping in
//                  WrapAndTokenize (e.g. instruction-tuned turn markers
//                  — presumably; verify against WrapAndTokenize).
//   kv_cache     - KV cache reused across turns so multiturn context persists.
//   pool         - thread pool (currently unused directly here; kept for
//                  interface parity with callers — TODO confirm).
//   args         - inference settings (max_tokens, temperature, multiturn,
//                  deterministic, ...).
//   verbosity    - 0 = silent prompts, >=1 prints "> " and spacing,
//                  >=2 adds end-of-turn and timing stats.
//   accept_token - per-token acceptance predicate forwarded to Generate().
//   eot_line     - if non-empty, input is multi-line and terminated by a line
//                  equal to this sentinel; otherwise one line per prompt.
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
               gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
               const InferenceArgs& args, int verbosity,
               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
  PROFILER_ZONE("Gen.misc");
  size_t abs_pos = 0;      // Absolute position across all turns (KV cache pos).
  int current_pos = 0;     // Position within the current turn only.
  int prompt_size{};       // Token count of the current prompt.

  std::mt19937 gen;
  InitGenerator(args, gen);

  // Streaming callback invoked once per token (prompt and generated alike).
  // Returning true tells Generate() to continue.
  auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
                       &model, verbosity](int token, float) {
    ++abs_pos;
    ++current_pos;
    if (current_pos <= prompt_size) {
      // Still prefilling the prompt: show progress dots instead of text.
      std::cerr << "." << std::flush;
    } else if (token == gcpp::EOS_ID) {
      if (!args.multiturn) {
        // Single-turn mode: discard context so the next prompt starts fresh.
        abs_pos = 0;
        if (args.deterministic) {
          // Re-seed so deterministic runs reproduce identical output per turn.
          gen.seed(42);
        }
      }
      if (verbosity >= 2) {
        std::cout << "\n[ End ]\n";
      }
    } else {
      std::string token_text;
      HWY_ASSERT(
          model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
      if (current_pos == prompt_size + 1) {
        // First generated token: strip leading whitespace the tokenizer may
        // emit, and visually separate the reply from the progress dots.
        token_text.erase(0, token_text.find_first_not_of(" \t\n"));
        if (verbosity >= 1) {
          std::cout << "\n\n";
        }
      }
      std::cout << token_text << std::flush;
    }
    return true;
  };

  while (abs_pos < args.max_tokens) {
    std::string prompt_string;
    current_pos = 0;
    {
      PROFILER_ZONE("Gen.input");
      if (verbosity >= 1) {
        std::cout << "> " << std::flush;
      }
      if (eot_line.empty()) {
        // Single-line prompt.
        std::getline(std::cin, prompt_string);
      } else {
        // Multi-line prompt terminated by the end-of-turn sentinel line.
        std::string line;
        while (std::getline(std::cin, line)) {
          if (line == eot_line) {
            break;
          }
          prompt_string += line + "\n";
        }
      }
    }

    // Quit on EOF / stream error or an explicit quit command.
    if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
      return;
    }

    // "%c" resets the conversation (position 0 == empty KV context).
    if (prompt_string == "%c" || prompt_string == "%C") {
      abs_pos = 0;
      continue;
    }

    const std::vector<int> prompt =
        WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
    prompt_size = prompt.size();
    std::cerr << "\n"
              << "[ Reading prompt ] " << std::flush;
    if constexpr (kVerboseLogTokens) {
      for (int i = 0; i < prompt_size; ++i) {
        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
      }
    }

    TimingInfo timing_info;
    gcpp::RuntimeConfig runtime_config = {
        .max_tokens = args.max_tokens,
        .max_generated_tokens = args.max_generated_tokens,
        .temperature = args.temperature,
        .verbosity = verbosity,
        .gen = &gen,
        .stream_token = stream_token,
        .accept_token = accept_token,
    };
    model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
    if (verbosity >= 2) {
      std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
                << "\n"
                << timing_info.prefill_tok_sec << " prefill tokens / sec"
                << "\n"
                << timing_info.gen_tok_sec << " tokens / sec" << "\n"
                << static_cast<int>(timing_info.time_to_first_token * 1000)
                << " milliseconds time to first token" << "\n";
    }
    std::cout << "\n\n";
  }
  std::cout
      << "max_tokens (" << args.max_tokens
      << ") exceeded. Use a larger value if desired using the --max_tokens "
      << "command line flag.\n";
}
|
|
| void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { |
| PROFILER_ZONE("Run.misc"); |
|
|
| hwy::ThreadPool pool(app.num_threads); |
| |
| if (app.num_threads > 10) { |
| PinWorkersToCores(pool); |
| } |
|
|
| gcpp::Gemma model = gcpp::CreateGemma(loader, pool); |
| KVCache kv_cache = KVCache::Create(loader.ModelType()); |
|
|
| if (app.verbosity >= 1) { |
| const std::string instructions = |
| std::string( |
| "*Usage*\n" |
| " Enter an instruction and press enter (%C resets conversation, " |
| "%Q quits).\n") |
| .append( |
| (inference.multiturn == 0 |
| ? std::string( |
| " Since multiturn is set to 0, conversation will " |
| "automatically reset every turn.\n\n") |
| : "\n")) |
| .append( |
| "*Examples*\n" |
| " - Write an email to grandma thanking her for the cookies.\n" |
| " - What are some historical attractions to visit around " |
| "Massachusetts?\n" |
| " - Compute the nth fibonacci number in javascript.\n" |
| " - Write a standup comedy bit about GPU programming.\n"); |
|
|
| std::cout << "\033[2J\033[1;1H" |
| << kAsciiArtBanner << "\n\n"; |
| ShowConfig(loader, inference, app); |
| std::cout << "\n" << instructions << "\n"; |
| } |
|
|
| ReplGemma(model, loader.ModelTrainingType(), kv_cache, pool, inference, |
| app.verbosity, AcceptFunc(), app.eot_line); |
| } |
|
|
| } |
|
|
| int main(int argc, char** argv) { |
| { |
| PROFILER_ZONE("Startup.misc"); |
|
|
| |
|
|
| gcpp::LoaderArgs loader(argc, argv); |
| gcpp::InferenceArgs inference(argc, argv); |
| gcpp::AppArgs app(argc, argv); |
|
|
| if (gcpp::HasHelp(argc, argv)) { |
| std::cerr << gcpp::kAsciiArtBanner; |
| gcpp::ShowHelp(loader, inference, app); |
| return 0; |
| } |
|
|
| if (const char* error = loader.Validate()) { |
| std::cerr << gcpp::kAsciiArtBanner; |
| gcpp::ShowHelp(loader, inference, app); |
| HWY_ABORT("\nInvalid args: %s", error); |
| } |
|
|
| if (const char* error = inference.Validate()) { |
| std::cerr << gcpp::kAsciiArtBanner; |
| gcpp::ShowHelp(loader, inference, app); |
| HWY_ABORT("\nInvalid args: %s", error); |
| } |
|
|
| gcpp::Run(loader, inference, app); |
| } |
| PROFILER_PRINT_RESULTS(); |
| return 0; |
| } |
|
|