// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// ODML pipeline to execute or benchmark an LLM graph on device.
//
// The pipeline does the following:
// 1) Reads the corresponding parameters, weight, and model file paths.
// 2) Constructs a graph model with the settings.
// 3) Executes model inference and generates the output.

#include <fstream>
#include <iostream>
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <variant>

#include "absl/base/log_severity.h"  // from @com_google_absl
#include "absl/flags/flag.h"  // from @com_google_absl
#include "absl/flags/parse.h"  // from @com_google_absl
#include "absl/functional/any_invocable.h"  // from @com_google_absl
#include "absl/log/absl_check.h"  // from @com_google_absl
#include "absl/log/absl_log.h"  // from @com_google_absl
#include "absl/log/globals.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "absl/time/time.h"  // from @com_google_absl
#include "nlohmann/json.hpp"  // from @nlohmann_json
#include "litert/cc/internal/scoped_file.h"  // from @litert
#include "runtime/conversation/conversation.h"
#include "runtime/conversation/io_types.h"
#include "runtime/engine/engine.h"
#include "runtime/engine/engine_factory.h"
#include "runtime/engine/engine_settings.h"
#include "runtime/engine/io_types.h"
#include "runtime/executor/executor_settings_base.h"
#include "runtime/util/status_macros.h"

ABSL_FLAG(std::string, backend, "gpu",
          "Executor backend to use for LLM execution (cpu, gpu, etc.)");
ABSL_FLAG(std::string, model_path, "",
          "Model path to use for LLM execution.");
ABSL_FLAG(std::string, input_prompt, "",
          "Input prompt to use for testing LLM execution.");
ABSL_FLAG(std::string, input_prompt_file, "",
          "File path to the input prompt.");

namespace {

using ::litert::lm::Backend;
using ::litert::lm::Conversation;
using ::litert::lm::ConversationConfig;
using ::litert::lm::Engine;
using ::litert::lm::EngineSettings;
using ::litert::lm::InputData;
using ::litert::lm::JsonMessage;
using ::litert::lm::Message;
using ::litert::lm::ModelAssets;
using ::nlohmann::json;

// Creates a callback that streams each response chunk to stdout as it
// arrives.
absl::AnyInvocable<void(absl::StatusOr<Message>)> CreateMessageCallback() {
  return [](absl::StatusOr<Message> message) {
    if (!message.ok()) {
      std::cout << "Error: " << message.status() << std::endl;
      return;
    }
    if (std::holds_alternative<JsonMessage>(*message)) {
      const auto& json_message = std::get<JsonMessage>(*message);
      if (json_message.is_null()) {
        // A null message marks the end of the streamed response.
        std::cout << std::endl << std::flush;
        return;
      }
      for (const auto& content : json_message["content"]) {
        std::cout << content["text"].get<std::string>();
      }
      std::cout << std::flush;
    }
  };
}

// Gets the input prompt from the command line flag or file.
std::string GetInputPrompt() {
  const std::string input_prompt = absl::GetFlag(FLAGS_input_prompt);
  const std::string input_prompt_file = absl::GetFlag(FLAGS_input_prompt_file);
  if (!input_prompt.empty() && !input_prompt_file.empty()) {
    ABSL_LOG(FATAL) << "Only one of --input_prompt and --input_prompt_file can "
                       "be specified.";
  }
  if (!input_prompt.empty()) {
    return input_prompt;
  }
  if (!input_prompt_file.empty()) {
    std::ifstream file(input_prompt_file);
    if (!file.is_open()) {
      std::cerr << "Error: Could not open file " << input_prompt_file
                << std::endl;
      return "";
    }
    std::stringstream buffer;
    buffer << file.rdbuf();
    return buffer.str();
  }
  // If no input prompt is provided, use the default prompt.
  return "What is the tallest building in the world?";
}

absl::Status MainHelper(int argc, char** argv) {
  absl::ParseCommandLine(argc, argv);
  // Overrides the default for FLAGS_minloglevel to error.
  absl::SetMinLogLevel(absl::LogSeverityAtLeast::kError);
  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kFatal);

  const std::string model_path = absl::GetFlag(FLAGS_model_path);
  if (model_path.empty()) {
    return absl::InvalidArgumentError("Model path is empty.");
  }
  ASSIGN_OR_RETURN(ModelAssets model_assets,  // NOLINT
                   ModelAssets::Create(model_path));

  auto backend_str = absl::GetFlag(FLAGS_backend);
  ASSIGN_OR_RETURN(Backend backend,
                   litert::lm::GetBackendFromString(backend_str));
  ASSIGN_OR_RETURN(
      EngineSettings engine_settings,
      EngineSettings::CreateDefault(std::move(model_assets), backend));
  // Enable benchmark by default.
  engine_settings.GetMutableBenchmarkParams() =
      litert::lm::proto::BenchmarkParams();

  // Create the engine.
  ASSIGN_OR_RETURN(auto engine, litert::lm::EngineFactory::CreateAny(
                                    std::move(engine_settings)));

  // Create the conversation.
  std::unique_ptr<Conversation> conversation;
  auto session_config = litert::lm::SessionConfig::CreateDefault();
  ASSIGN_OR_RETURN(auto conversation_config,
                   ConversationConfig::Builder()
                       .SetSessionConfig(session_config)
                       .Build(*engine));
  ASSIGN_OR_RETURN(conversation,
                   Conversation::Create(*engine, conversation_config));

  // Prepare the message to send.
  json content_list = json::array();
  const std::string input_prompt = GetInputPrompt();
  std::cout << "input_prompt: " << input_prompt << std::endl;
  content_list.push_back({{"type", "text"}, {"text", input_prompt}});

  // Send the message and wait for the response; the callback asynchronously
  // logs the response.
  RETURN_IF_ERROR(conversation->SendMessageAsync(
      json::object({{"role", "user"}, {"content", content_list}}),
      CreateMessageCallback()));
  RETURN_IF_ERROR(engine->WaitUntilDone(absl::Minutes(10)));

  // Print the benchmark info.
  auto benchmark_info = conversation->GetBenchmarkInfo();
  std::cout << std::endl << *benchmark_info << std::endl;
  return absl::OkStatus();
}

}  // namespace

int main(int argc, char** argv) {
  ABSL_CHECK_OK(MainHelper(argc, argv));
  return 0;
}
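
// Example invocation (a sketch: the binary name and model path below are
// illustrative; only the flags are defined in this file):
//
//   ./litert_lm_pipeline \
//       --backend=gpu \
//       --model_path=/path/to/model \
//       --input_prompt="What is the tallest building in the world?"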