| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include <stddef.h> |
| #include <stdio.h> |
|
|
| #include <string> |
|
|
| #include "benchmark/benchmark.h" |
| #include "gemma/benchmark_helper.h" |
|
|
| namespace gcpp { |
|
|
| |
| |
| |
| GemmaEnv* s_env = nullptr; |
|
|
| void RunPrompt(const std::string& original_prompt, benchmark::State& state) { |
| size_t total_tokens = 0; |
| for (auto s : state) { |
| std::string prompt = original_prompt; |
| auto [response, n] = s_env->QueryModel(prompt); |
| if (s_env->Verbosity() != 0) { |
| fprintf(stdout, "|%s|\n", response.c_str()); |
| } |
| total_tokens += n; |
| } |
|
|
| state.SetItemsProcessed(total_tokens); |
| } |
|
|
| } |
|
|
| static void BM_short_prompt(benchmark::State& state) { |
| gcpp::RunPrompt("What is the capital of Spain?", state); |
| } |
|
|
| static void BM_factuality_prompt(benchmark::State& state) { |
| gcpp::RunPrompt("How does an inkjet printer work?", state); |
| } |
|
|
| static void BM_creative_prompt(benchmark::State& state) { |
| gcpp::RunPrompt("Tell me a story about a magical bunny and their TRS-80.", |
| state); |
| } |
|
|
| static void BM_coding_prompt(benchmark::State& state) { |
| gcpp::RunPrompt("Write a python program to generate a fibonacci sequence.", |
| state); |
| } |
|
|
| BENCHMARK(BM_short_prompt) |
| ->Iterations(3) |
| ->Unit(benchmark::kMillisecond) |
| ->UseRealTime(); |
|
|
| BENCHMARK(BM_factuality_prompt) |
| ->Iterations(3) |
| ->Unit(benchmark::kMillisecond) |
| ->UseRealTime(); |
|
|
| BENCHMARK(BM_creative_prompt) |
| ->Iterations(3) |
| ->Unit(benchmark::kMillisecond) |
| ->UseRealTime(); |
|
|
| BENCHMARK(BM_coding_prompt) |
| ->Iterations(3) |
| ->Unit(benchmark::kMillisecond) |
| ->UseRealTime(); |
|
|
| int main(int argc, char** argv) { |
| gcpp::GemmaEnv env(argc, argv); |
| env.SetMaxGeneratedTokens(256); |
| gcpp::s_env = &env; |
|
|
| ::benchmark::RunSpecifiedBenchmarks(); |
| ::benchmark::Shutdown(); |
| return 0; |
| } |
|
|