// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// ODML pipeline to execute or benchmark an LLM graph on device.
//
// The pipeline does the following:
//   1) Reads the corresponding parameters, weights, and model file paths.
//   2) Constructs a graph model with those settings.
//   3) Executes model inference and generates the output.

#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <variant>

#include "absl/base/log_severity.h"
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/log/globals.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include <nlohmann/json.hpp>
// The LiteRT-LM headers for Engine, Conversation, ModelAssets, and the
// ASSIGN_OR_RETURN / RETURN_IF_ERROR status macros also belong here; their
// exact paths depend on the repository layout.

ABSL_FLAG(std::string, backend, "gpu",
          "Executor backend to use for LLM execution (cpu, gpu, etc.)");
ABSL_FLAG(std::string, model_path, "", "Model path to use for LLM execution.");
ABSL_FLAG(std::string, input_prompt, "",
          "Input prompt to use for testing LLM execution.");
ABSL_FLAG(std::string, input_prompt_file, "", "File path to the input prompt.");
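
// Example invocation (the binary name and paths below are illustrative, not
// taken from the repository's docs):
//   litert_lm_main --backend=cpu --model_path=/tmp/model.litertlm \
//       --input_prompt="What is the tallest building in the world?"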

namespace {

using ::litert::lm::Backend;
using ::litert::lm::Conversation;
using ::litert::lm::ConversationConfig;
using ::litert::lm::Engine;
using ::litert::lm::EngineSettings;
using ::litert::lm::InputData;
using ::litert::lm::JsonMessage;
using ::litert::lm::Message;
using ::litert::lm::ModelAssets;
using ::nlohmann::json;
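
// Builds the callback that streams model output to stdout as it arrives.
// Non-JSON message variants are ignored; a null JsonMessage marks the end of
// the stream.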
absl::AnyInvocable<void(absl::StatusOr<Message>)> CreateMessageCallback() {
  return [](absl::StatusOr<Message> message) {
    if (!message.ok()) {
      std::cout << "Error: " << message.status() << std::endl;
      return;
    }
    if (std::holds_alternative<JsonMessage>(*message)) {
      const auto& json_message = std::get<JsonMessage>(*message);
      if (json_message.is_null()) {
        std::cout << std::endl << std::flush;
        return;
      }
      for (const auto& content : json_message["content"]) {
        std::cout << content["text"].get<std::string>();
      }
      std::cout << std::flush;
    }
  };
}
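
// For reference, each streamed chunk consumed above is expected to look
// roughly like the following (a sketch inferred from the field accesses in
// the callback, not from an API schema):
//   {"content": [{"type": "text", "text": "partial output"}], ...}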

// Gets the input prompt from the command line flag or file.
std::string GetInputPrompt() {
  const std::string input_prompt = absl::GetFlag(FLAGS_input_prompt);
  const std::string input_prompt_file = absl::GetFlag(FLAGS_input_prompt_file);
  if (!input_prompt.empty() && !input_prompt_file.empty()) {
    ABSL_LOG(FATAL) << "Only one of --input_prompt and --input_prompt_file can "
                       "be specified.";
  }
  if (!input_prompt.empty()) {
    return input_prompt;
  }
  if (!input_prompt_file.empty()) {
    std::ifstream file(input_prompt_file);
    if (!file.is_open()) {
      std::cerr << "Error: Could not open file " << input_prompt_file
                << std::endl;
      return "";
    }
    std::stringstream buffer;
    buffer << file.rdbuf();
    return buffer.str();
  }
  // If no input prompt is provided, use the default prompt.
  return "What is the tallest building in the world?";
}
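
// Runs the pipeline end to end: parses flags, configures logging, loads the
// model assets, builds the engine and a conversation, sends a single user
// message, and prints the collected benchmark info once generation completes.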
absl::Status MainHelper(int argc, char** argv) {
  absl::ParseCommandLine(argc, argv);
  // Overrides the default for FLAGS_minloglevel to error.
  absl::SetMinLogLevel(absl::LogSeverityAtLeast::kError);
  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kFatal);

  const std::string model_path = absl::GetFlag(FLAGS_model_path);
  if (model_path.empty()) {
    return absl::InvalidArgumentError("Model path is empty.");
  }
  ASSIGN_OR_RETURN(ModelAssets model_assets,  // NOLINT
                   ModelAssets::Create(model_path));

  auto backend_str = absl::GetFlag(FLAGS_backend);
  ASSIGN_OR_RETURN(Backend backend,
                   litert::lm::GetBackendFromString(backend_str));
  ASSIGN_OR_RETURN(
      EngineSettings engine_settings,
      EngineSettings::CreateDefault(std::move(model_assets), backend));
  // Enable benchmark by default.
  engine_settings.GetMutableBenchmarkParams() =
      litert::lm::proto::BenchmarkParams();
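  // Assigning a default-constructed BenchmarkParams is what switches
  // benchmarking on; the collected stats are read back via GetBenchmarkInfo()
  // near the end of this function.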

  // Create the engine.
  ASSIGN_OR_RETURN(auto engine, litert::lm::EngineFactory::CreateAny(
                                    std::move(engine_settings)));

  // Create the conversation.
  std::unique_ptr<Conversation> conversation;
  auto session_config = litert::lm::SessionConfig::CreateDefault();
  ASSIGN_OR_RETURN(auto conversation_config,
                   ConversationConfig::Builder()
                       .SetSessionConfig(session_config)
                       .Build(*engine));
  ASSIGN_OR_RETURN(conversation,
                   Conversation::Create(*engine, conversation_config));

  // Prepare the message to send.
  json content_list = json::array();
  const std::string input_prompt = GetInputPrompt();
  std::cout << "input_prompt: " << input_prompt << std::endl;
  content_list.push_back({{"type", "text"}, {"text", input_prompt}});

  // Send the message and wait for the response; the callback logs the
  // response asynchronously.
  RETURN_IF_ERROR(conversation->SendMessageAsync(
      json::object({{"role", "user"}, {"content", content_list}}),
      CreateMessageCallback()));
  RETURN_IF_ERROR(engine->WaitUntilDone(absl::Minutes(10)));
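  // At this point all queued asynchronous work has completed (or the
  // ten-minute deadline above expired, in which case an error was returned).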

  // Print the benchmark info, propagating an error if it is unavailable.
  ASSIGN_OR_RETURN(auto benchmark_info, conversation->GetBenchmarkInfo());
  std::cout << std::endl << benchmark_info << std::endl;
  return absl::OkStatus();
}

}  // namespace

int main(int argc, char** argv) {
  ABSL_CHECK_OK(MainHelper(argc, argv));
  return 0;
}