Spaces:
Running
Running
File size: 6,760 Bytes
5f923cd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 | // Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ODML pipeline to execute or benchmark LLM graph on device.
//
// The pipeline does the following
// 1) Read the corresponding parameters, weight and model file paths.
// 2) Construct a graph model with the setting.
// 3) Execute model inference and generate the output.
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <variant>
#include "absl/base/log_severity.h" // from @com_google_absl
#include "absl/flags/flag.h" // from @com_google_absl
#include "absl/flags/parse.h" // from @com_google_absl
#include "absl/functional/any_invocable.h" // from @com_google_absl
#include "absl/log/absl_check.h" // from @com_google_absl
#include "absl/log/absl_log.h" // from @com_google_absl
#include "absl/log/globals.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/time/time.h" // from @com_google_absl
#include "nlohmann/json.hpp" // from @nlohmann_json
#include "litert/cc/internal/scoped_file.h" // from @litert
#include "runtime/conversation/conversation.h"
#include "runtime/conversation/io_types.h"
#include "runtime/engine/engine.h"
#include "runtime/engine/engine_factory.h"
#include "runtime/engine/engine_settings.h"
#include "runtime/engine/io_types.h"
#include "runtime/executor/executor_settings_base.h"
#include "runtime/util/status_macros.h"
ABSL_FLAG(std::string, backend, "gpu",
"Executor backend to use for LLM execution (cpu, gpu, etc.)");
ABSL_FLAG(std::string, model_path, "", "Model path to use for LLM execution.");
ABSL_FLAG(std::string, input_prompt, "",
"Input prompt to use for testing LLM execution.");
ABSL_FLAG(std::string, input_prompt_file, "", "File path to the input prompt.");
namespace {
using ::litert::lm::Backend;
using ::litert::lm::Conversation;
using ::litert::lm::ConversationConfig;
using ::litert::lm::Engine;
using ::litert::lm::EngineSettings;
using ::litert::lm::InputData;
using ::litert::lm::JsonMessage;
using ::litert::lm::Message;
using ::litert::lm::ModelAssets;
using ::nlohmann::json;
absl::AnyInvocable<void(absl::StatusOr<Message>)> CreateMessageCallback() {
return [](absl::StatusOr<Message> message) {
if (!message.ok()) {
std::cout << "Error: " << message.status() << std::endl;
return;
}
if (std::holds_alternative<JsonMessage>(*message)) {
const auto& json_message = std::get<JsonMessage>(*message);
if (json_message.is_null()) {
std::cout << std::endl << std::flush;
return;
}
for (const auto& content : json_message["content"]) {
std::cout << content["text"].get<std::string>();
}
std::cout << std::flush;
}
};
}
// Gets the input prompt from the command line flag or file.
std::string GetInputPrompt() {
const std::string input_prompt = absl::GetFlag(FLAGS_input_prompt);
const std::string input_prompt_file = absl::GetFlag(FLAGS_input_prompt_file);
if (!input_prompt.empty() && !input_prompt_file.empty()) {
ABSL_LOG(FATAL) << "Only one of --input_prompt and --input_prompt_file can "
"be specified.";
}
if (!input_prompt.empty()) {
return input_prompt;
}
if (!input_prompt_file.empty()) {
std::ifstream file(input_prompt_file);
if (!file.is_open()) {
std::cerr << "Error: Could not open file " << input_prompt_file
<< std::endl;
return "";
}
std::stringstream buffer;
buffer << file.rdbuf();
return buffer.str();
}
// If no input prompt is provided, use the default prompt.
return "What is the tallest building in the world?";
}
absl::Status MainHelper(int argc, char** argv) {
absl::ParseCommandLine(argc, argv);
// Overrides the default for FLAGS_minloglevel to error.
absl::SetMinLogLevel(absl::LogSeverityAtLeast::kError);
absl::SetStderrThreshold(absl::LogSeverityAtLeast::kFatal);
const std::string model_path = absl::GetFlag(FLAGS_model_path);
if (model_path.empty()) {
return absl::InvalidArgumentError("Model path is empty.");
}
ASSIGN_OR_RETURN(ModelAssets model_assets, // NOLINT
ModelAssets::Create(model_path));
auto backend_str = absl::GetFlag(FLAGS_backend);
ASSIGN_OR_RETURN(Backend backend,
litert::lm::GetBackendFromString(backend_str));
ASSIGN_OR_RETURN(
EngineSettings engine_settings,
EngineSettings::CreateDefault(std::move(model_assets), backend));
// Enable benchmark by default.
engine_settings.GetMutableBenchmarkParams() =
litert::lm::proto::BenchmarkParams();
// Create the engine.
ASSIGN_OR_RETURN(auto engine, litert::lm::EngineFactory::CreateAny(
std::move(engine_settings)));
// Create the conversation.
std::unique_ptr<Conversation> conversation;
auto session_config = litert::lm::SessionConfig::CreateDefault();
ASSIGN_OR_RETURN(auto conversation_config,
ConversationConfig::Builder()
.SetSessionConfig(session_config)
.Build(*engine));
ASSIGN_OR_RETURN(conversation,
Conversation::Create(*engine, conversation_config));
// Prepare the message to send.
json content_list = json::array();
const std::string input_prompt = GetInputPrompt();
std::cout << "input_prompt: " << input_prompt << std::endl;
content_list.push_back({{"type", "text"}, {"text", input_prompt}});
// Send the message and wait for the response, asynchronously log the
// response.
RETURN_IF_ERROR(conversation->SendMessageAsync(
json::object({{"role", "user"}, {"content", content_list}}),
CreateMessageCallback()));
RETURN_IF_ERROR(engine->WaitUntilDone(absl::Minutes(10)));
// Print the benchmark info.
auto benchmark_info = conversation->GetBenchmarkInfo();
std::cout << std::endl << *benchmark_info << std::endl;
return absl::OkStatus();
}
} // namespace
int main(int argc, char** argv) {
ABSL_CHECK_OK(MainHelper(argc, argv));
return 0;
}
|