| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include <fstream> |
| | #include <iostream> |
| | #include <string> |
| | #include <vector> |
| |
|
| | #include "compression/io.h" |
| | #include "gemma/benchmark_helper.h" |
| | #include "gemma/gemma.h" |
| | #include "util/args.h" |
| | #include "hwy/base.h" |
| | #include "nlohmann/json.hpp" |
| |
|
| | using json = nlohmann::json; |
| |
|
| | namespace gcpp { |
| |
|
| | class PromptArgs : public ArgsBase<PromptArgs> { |
| | public: |
| | PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } |
| |
|
| | Path layers_output; |
| | std::string prompt; |
| |
|
| | |
| | const char* Validate() const { |
| | if (prompt.empty()) return "Must specify --prompt"; |
| | return nullptr; |
| | } |
| |
|
| | template <class Visitor> |
| | void ForEach(const Visitor& visitor) { |
| | visitor(layers_output, "layers_output", Path(""), |
| | "Path to store layers output", 2); |
| | visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2); |
| | } |
| | }; |
| |
|
| | int Run(int argc, char** argv) { |
| | PromptArgs prompt_args(argc, argv); |
| | AbortIfInvalidArgs(prompt_args); |
| |
|
| | json json_output; |
| | GemmaEnv env(argc, argv); |
| | env.MutableConfig().layers_output = |
| | prompt_args.layers_output.Empty() |
| | ? LayersOutputFunc() |
| | : [&json_output](int pos, const std::string& key, const float* values, |
| | size_t values_len) { |
| | std::vector<float> v{values, values + values_len}; |
| | json_output[std::to_string(pos)][key] = v; |
| | }; |
| |
|
| | const auto [answer, token_count] = env.QueryModel(prompt_args.prompt); |
| | std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush; |
| |
|
| | if (env.MutableConfig().layers_output) { |
| | std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out); |
| | if (!output_f) HWY_ABORT("Opening layer output file failed"); |
| | output_f << json_output.dump(); |
| | if (!output_f) HWY_ABORT("Writing to layer output file failed"); |
| | output_f.close(); |
| | } |
| | return 0; |
| | } |
| |
|
| | } |
| |
|
| | int main(int argc, char** argv) { return gcpp::Run(argc, argv); } |
| |
|