// LiteRT-LM: runtime/engine/litert_lm_main.cc
// Uploaded by SeaWolf-AI ("Upload full LiteRT-LM codebase", commit 5f923cd).
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ODML pipeline to execute or benchmark LLM graph on device.
//
// The pipeline does the following:
// 1) Reads the corresponding parameters, weight and model file paths.
// 2) Constructs a graph model with the given settings.
// 3) Executes model inference and generates the output.
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <variant>
#include "absl/base/log_severity.h" // from @com_google_absl
#include "absl/flags/flag.h" // from @com_google_absl
#include "absl/flags/parse.h" // from @com_google_absl
#include "absl/functional/any_invocable.h" // from @com_google_absl
#include "absl/log/absl_check.h" // from @com_google_absl
#include "absl/log/absl_log.h" // from @com_google_absl
#include "absl/log/globals.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/time/time.h" // from @com_google_absl
#include "nlohmann/json.hpp" // from @nlohmann_json
#include "litert/cc/internal/scoped_file.h" // from @litert
#include "runtime/conversation/conversation.h"
#include "runtime/conversation/io_types.h"
#include "runtime/engine/engine.h"
#include "runtime/engine/engine_factory.h"
#include "runtime/engine/engine_settings.h"
#include "runtime/engine/io_types.h"
#include "runtime/executor/executor_settings_base.h"
#include "runtime/util/status_macros.h"
// --- Model / execution flags -------------------------------------------------
// Backend name, converted by litert::lm::GetBackendFromString().
ABSL_FLAG(std::string, backend, "gpu",
          "Executor backend to use for LLM execution (cpu, gpu, etc.)");
// Required: path to the model to load.
ABSL_FLAG(std::string, model_path, "",
          "Model path to use for LLM execution.");

// --- Prompt flags -----------------------------------------------------------
// At most one of these two may be set. When neither is set, a built-in default
// prompt is used.
ABSL_FLAG(std::string, input_prompt, "",
          "Input prompt to use for testing LLM execution.");
ABSL_FLAG(std::string, input_prompt_file, "",
          "File path to the input prompt.");
namespace {

using ::litert::lm::Backend;
using ::litert::lm::Conversation;
using ::litert::lm::ConversationConfig;
using ::litert::lm::Engine;
using ::litert::lm::EngineSettings;
using ::litert::lm::InputData;
using ::litert::lm::JsonMessage;
using ::litert::lm::Message;
using ::litert::lm::ModelAssets;
using ::nlohmann::json;

// Prompt used when neither --input_prompt nor --input_prompt_file is set.
constexpr char kDefaultPrompt[] = "What is the tallest building in the world?";

// Returns a callback that streams response chunks from
// Conversation::SendMessageAsync to stdout.
//
// - An error status is printed as "Error: <status>" and the chunk is dropped.
// - A null JsonMessage terminates the current output line.
// - Otherwise, the string "text" field of every entry in the message's
//   "content" array is printed with no trailing newline.
// Messages without a "content" array, and entries without a string "text"
// field (e.g. non-text parts), are skipped. They are never indexed with
// operator[], because indexing a const json with a missing key is undefined
// behavior in nlohmann::json.
absl::AnyInvocable<void(absl::StatusOr<Message>)> CreateMessageCallback() {
  return [](absl::StatusOr<Message> message) {
    if (!message.ok()) {
      std::cout << "Error: " << message.status() << std::endl;
      return;
    }
    if (!std::holds_alternative<JsonMessage>(*message)) {
      return;
    }
    const auto& json_message = std::get<JsonMessage>(*message);
    if (json_message.is_null()) {
      std::cout << std::endl;  // std::endl also flushes.
      return;
    }
    // find() returns end() for missing keys and for non-object values.
    const auto content_it = json_message.find("content");
    if (content_it != json_message.end() && content_it->is_array()) {
      for (const auto& content : *content_it) {
        const auto text_it = content.find("text");
        if (text_it != content.end() && text_it->is_string()) {
          std::cout << text_it->get<std::string>();
        }
      }
    }
    std::cout << std::flush;
  };
}

// Resolves the prompt to send from the command-line flags.
//
// Returns, in order of precedence:
// - the value of --input_prompt, if set;
// - the full contents of the file named by --input_prompt_file, if set;
// - kDefaultPrompt otherwise.
// Errors:
// - InvalidArgumentError if both flags are set, or if the prompt file is empty.
// - NotFoundError if the prompt file cannot be opened.
absl::StatusOr<std::string> GetInputPrompt() {
  const std::string input_prompt = absl::GetFlag(FLAGS_input_prompt);
  const std::string input_prompt_file = absl::GetFlag(FLAGS_input_prompt_file);
  if (!input_prompt.empty() && !input_prompt_file.empty()) {
    return absl::InvalidArgumentError(
        "Only one of --input_prompt and --input_prompt_file can be "
        "specified.");
  }
  if (!input_prompt.empty()) {
    return input_prompt;
  }
  if (!input_prompt_file.empty()) {
    std::ifstream file(input_prompt_file);
    if (!file.is_open()) {
      return absl::NotFoundError("Could not open input prompt file: " +
                                 input_prompt_file);
    }
    std::stringstream buffer;
    buffer << file.rdbuf();
    std::string contents = buffer.str();
    // Report an empty file instead of silently sending an empty prompt.
    if (contents.empty()) {
      return absl::InvalidArgumentError("Input prompt file is empty: " +
                                        input_prompt_file);
    }
    return contents;
  }
  return std::string(kDefaultPrompt);
}

// Runs one prompt through the engine. Prints the streamed response, then the
// benchmark info.
//
// Parses command-line flags and builds an engine for --model_path on
// --backend. Sends the prompt from GetInputPrompt() as a single user message
// and waits up to 10 minutes for the asynchronous generation to finish.
//
// Returns InvalidArgumentError if --model_path is empty. Otherwise returns the
// first error reported by prompt resolution, engine/conversation setup,
// generation, or benchmark collection.
absl::Status MainHelper(int argc, char** argv) {
  absl::ParseCommandLine(argc, argv);
  // Suppress absl logging below ERROR and echo only FATAL messages to stderr.
  absl::SetMinLogLevel(absl::LogSeverityAtLeast::kError);
  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kFatal);

  const std::string model_path = absl::GetFlag(FLAGS_model_path);
  if (model_path.empty()) {
    return absl::InvalidArgumentError("Model path is empty.");
  }

  // Resolve the prompt before loading the model. Bad prompt flags or an
  // unreadable prompt file are then reported without paying the engine
  // start-up cost.
  ASSIGN_OR_RETURN(std::string input_prompt, GetInputPrompt());
  std::cout << "input_prompt: " << input_prompt << std::endl;

  ASSIGN_OR_RETURN(ModelAssets model_assets,  // NOLINT
                   ModelAssets::Create(model_path));
  const std::string backend_str = absl::GetFlag(FLAGS_backend);
  ASSIGN_OR_RETURN(Backend backend,
                   litert::lm::GetBackendFromString(backend_str));
  ASSIGN_OR_RETURN(
      EngineSettings engine_settings,
      EngineSettings::CreateDefault(std::move(model_assets), backend));
  // Enable benchmark by default; GetBenchmarkInfo() below reports the results.
  engine_settings.GetMutableBenchmarkParams() =
      litert::lm::proto::BenchmarkParams();

  // Create the engine.
  ASSIGN_OR_RETURN(auto engine, litert::lm::EngineFactory::CreateAny(
                                    std::move(engine_settings)));

  // Create the conversation.
  std::unique_ptr<Conversation> conversation;
  auto session_config = litert::lm::SessionConfig::CreateDefault();
  ASSIGN_OR_RETURN(auto conversation_config,
                   ConversationConfig::Builder()
                       .SetSessionConfig(session_config)
                       .Build(*engine));
  ASSIGN_OR_RETURN(conversation,
                   Conversation::Create(*engine, conversation_config));

  // Wrap the prompt as the single text part of a user message.
  json content_list = json::array();
  content_list.push_back({{"type", "text"}, {"text", input_prompt}});

  // Send the message. The callback streams the response to stdout while we
  // wait for generation to finish.
  RETURN_IF_ERROR(conversation->SendMessageAsync(
      json::object({{"role", "user"}, {"content", content_list}}),
      CreateMessageCallback()));
  RETURN_IF_ERROR(engine->WaitUntilDone(absl::Minutes(10)));

  // Print the benchmark info. A non-OK result is propagated instead of being
  // dereferenced.
  ASSIGN_OR_RETURN(auto benchmark_info, conversation->GetBenchmarkInfo());
  std::cout << std::endl << benchmark_info << std::endl;
  return absl::OkStatus();
}

}  // namespace
int main(int argc, char** argv) {
ABSL_CHECK_OK(MainHelper(argc, argv));
return 0;
}