File size: 13,214 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// ODML pipeline to execute or benchmark LLM graph on device.
//
// The pipeline does the following
// 1) Read the corresponding parameters, weight and model file paths.
// 2) Construct a graph model with the setting.
// 3) Execute model inference and generate the output.
//
// Consider run_llm_inference_engine.sh as an example to run on android device.

#include <fstream>
#include <iostream>
#include <memory>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "absl/flags/flag.h"  // from @com_google_absl
#include "absl/flags/parse.h"  // from @com_google_absl
#include "absl/log/absl_check.h"  // from @com_google_absl
#include "absl/log/absl_log.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/strings/numbers.h"  // from @com_google_absl
#include "absl/strings/str_cat.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "runtime/engine/litert_lm_lib.h"
#include "runtime/engine/shared_flags.h"
#include "runtime/proto/litert_lm_metrics.pb.h"
#include "runtime/util/metrics_util.h"
#include "runtime/util/status_macros.h"

// Flags declared by this binary. MainHelper below also reads many flags not
// declared here (e.g. --vision_backend, --benchmark, --max_num_tokens);
// presumably those come from runtime/engine/shared_flags.h (included above)
// — TODO confirm.
ABSL_FLAG(std::string, backend, "cpu",
          "Executor backend to use for LLM execution (cpu, gpu, etc.)");
ABSL_FLAG(std::string, model_path, "", "Model path to use for LLM execution.");
ABSL_FLAG(
    bool, load_model_from_descriptor, false,
    "Whether to load the model from a file descriptor rather than by path.");
// Exactly one of --input_prompt / --input_prompt_file may be set; see
// GetInputPrompt below, which LOG(FATAL)s when both are provided.
ABSL_FLAG(std::string, input_prompt, "",
          "Input prompt to use for testing LLM execution.");
ABSL_FLAG(std::string, input_prompt_file, "", "File path to the input prompt.");
ABSL_FLAG(std::string, metric_proto_file_path, "",
          "Path to the file where the benchmark metrics will be saved in "
          "protobuf format. Only collected when --benchmark is true.");
// -1 (the default) disables prefill chunking entirely.
ABSL_FLAG(int, prefill_chunk_size, -1,
          "Prefill chunk size for LLM execution. A positive value enables "
          "breaking the input prefill sequence into smaller chunks for "
          "incremental processing. For example, a chunk size of 128 with an "
          "input length of 300 results in 3 chunks: 128, 128, and 44 tokens. "
          "A value of -1 disables chunking. Only supported by the dynamic "
          "executor.");
ABSL_FLAG(bool, use_session, false,
          "If true, use Session instead of Conversation to run inference. "
          "Note that session does not use Jinja templates. As such, if using "
          "Jinja in LLM Metadata, the user is responsible for manually "
          "applying the prompt template to the input prompt.");

namespace {

// Converts textual batch sizes into a deduplicated, ordered set of ints.
// Returns InvalidArgumentError on the first entry that fails integer parsing.
absl::StatusOr<std::set<int>> ParsePrefillBatchSizes(
    const std::vector<std::string>& prefill_batch_sizes) {
  std::set<int> result;
  for (const std::string& size_text : prefill_batch_sizes) {
    int parsed_size = 0;
    if (!absl::SimpleAtoi(size_text, &parsed_size)) {
      return absl::InvalidArgumentError(
          absl::StrCat("Invalid prefill batch size: ", size_text));
    }
    result.insert(parsed_size);
  }
  return result;
}

std::string GetInputPrompt() {
  const std::string input_prompt = absl::GetFlag(FLAGS_input_prompt);
  const std::string input_prompt_file = absl::GetFlag(FLAGS_input_prompt_file);
  if (!input_prompt.empty() && !input_prompt_file.empty()) {
    ABSL_LOG(FATAL) << "Only one of --input_prompt and --input_prompt_file can "
                       "be specified. Currently both are specified as "
                    << input_prompt << " and " << input_prompt_file;
  }
  if (!input_prompt.empty()) {
    return input_prompt;
  }
  if (!input_prompt_file.empty()) {
    std::ifstream file(input_prompt_file);
    if (!file.is_open()) {
      std::cerr << "Error: Could not open file " << input_prompt_file
                << std::endl;
      return "";
    }
    std::stringstream buffer;
    buffer << file.rdbuf();
    return buffer.str();
  }
  // If no input prompt is provided, use the default prompt.
  return "What is the tallest building in the world?";
}

// Serializes the collected benchmark metrics to `file_path` as a binary
// protobuf. Only invoked in benchmark mode when a metric file path is given.
// Returns InvalidArgumentError when `metrics` is empty and InternalError
// when the file cannot be opened or serialization fails.
absl::Status WriteMetricsToFile(
    const std::vector<litert::lm::LitertLmMetrics>& metrics,
    const std::string& file_path) {
  if (metrics.empty()) {
    return absl::InvalidArgumentError("No metrics to write.");
  }

  // Convert the in-memory metrics into their protobuf representation.
  ASSIGN_OR_RETURN(auto proto_list, litert::lm::ToProtoList(metrics));

  std::ofstream output_stream(file_path, std::ios::out | std::ios::binary);
  if (!output_stream) {
    return absl::InternalError(
        absl::StrCat("Failed to open metric file: ", file_path));
  }
  const bool serialized = proto_list.SerializeToOstream(&output_stream);
  if (!serialized) {
    return absl::InternalError("Failed to serialize metrics to file.");
  }
  ABSL_LOG(INFO) << "Metrics written to: " << file_path;
  return absl::OkStatus();
}

// Parses command-line flags, assembles the LiteRtLmSettings consumed by
// RunLiteRtLm, executes the pipeline, and optionally persists benchmark
// metrics to disk in protobuf format.
//
// Returns InvalidArgumentError when invoked with no arguments (after
// printing usage); otherwise propagates the pipeline / metrics-write status.
//
// Fixes vs. previous revision (usage string only): closed the unbalanced
// bracket in --force_f32, corrected --cache_compiled_shader_only to the
// actual flag name --cache_compiled_shaders_only (see the GetFlag call
// below), and added the missing spaces between adjacent options so the
// printed usage no longer runs them together.
absl::Status MainHelper(int argc, char** argv) {
  absl::ParseCommandLine(argc, argv);

  if (argc <= 1) {
    // No arguments at all: print usage and bail out.
    ABSL_LOG(INFO)
        << "Example usage: ./litert_lm_main --model_path=<model_path> "
           "[--input_prompt=<input_prompt>] "
           "[--input_prompt_file=<input_prompt_file>] "
           "[--expected_output=<expected_output>] [--backend=<cpu|gpu|npu>] "
           "[--log_sink_file=<log_sink_file>] "
           "[--max_num_tokens=<max_num_tokens>] "
           "[--prefill_batch_sizes=<size1>[,<size2>,...]] "
           "[--prefill_chunk_size=<prefill_chunk_size>] "
           "[--vision_backend=<cpu|gpu>] [--audio_backend=<cpu|gpu>] "
           "[--sampler_backend=<cpu|gpu>] [--benchmark] "
           "[--benchmark_prefill_tokens=<num_prefill_tokens>] "
           "[--benchmark_decode_tokens=<num_decode_tokens>] "
           "[--async=<true|false>] [--force_f32=<true|false>] "
           "[--report_peak_memory_footprint] [--multi_turns=<true|false>] "
           "[--num_cpu_threads=<num_cpu_threads>] "
           "[--gpu_external_tensor_mode=<true|false>] "
           "[--configure_magic_numbers=<true|false>] "
           "[--verify_magic_numbers=<true|false>] "
           "[--clear_kv_cache_before_prefill=<true|false>] "
           "[--num_logits_to_print_after_decode=<num_logits_to_print>] "
           "[--score_target_text=<target_text>] "
           "[--gpu_madvise_original_shared_tensors=<true|false>] "
           "[--preferred_device_substr=<device_substr>] "
           "[--num_threads_to_upload=<num_threads_to_upload>] "
           "[--num_threads_to_compile=<num_threads_to_compile>] "
           "[--convert_weights_on_gpu=<true|false>] "
           "[--wait_for_weights_conversion_complete_in_benchmark=<true|false>] "
           "[--optimize_shader_compilation=<true|false>] "
           "[--share_constant_tensors=<true|false>] "
           "[--num_iterations=<num_iterations>] "
           "[--litert_dispatch_lib_dir=<litert_dispatch_lib_dir>] "
           "[--sampler_handles_input=<true|false>] "
           "[--disable_cache=<true|false>] "
           "[--cache_compiled_shaders_only=<true|false>] "
           "[--conv_type=<auto|float|int8>] "
           "[--enable_speculative_decoding=<true|false>]";
    ABSL_LOG(INFO)
        << "To provide data for multimodality, use [image:/path/to/image.jpg] "
           "or [audio:/path/to/audio.wav] in the input prompt. e.g. \"Describe "
           "the image: [image:/path/to/image.jpg]\", or \"Transcribe the audio "
           "[audio:/path/to/audio.wav]\"";
    return absl::InvalidArgumentError("No arguments provided.");
  }

  litert::lm::LiteRtLmSettings settings;

  // Copy every flag into the settings struct consumed by RunLiteRtLm.
  settings.backend = absl::GetFlag(FLAGS_backend);
  settings.vision_backend = absl::GetFlag(FLAGS_vision_backend);
  settings.audio_backend = absl::GetFlag(FLAGS_audio_backend);
  settings.sampler_backend = absl::GetFlag(FLAGS_sampler_backend);
  settings.model_path = absl::GetFlag(FLAGS_model_path);
  settings.load_model_from_descriptor =
      absl::GetFlag(FLAGS_load_model_from_descriptor);
  settings.input_prompt = GetInputPrompt();
  settings.expected_output = absl::GetFlag(FLAGS_expected_output);
  settings.log_sink_file = absl::GetFlag(FLAGS_log_sink_file);
  settings.max_num_tokens = absl::GetFlag(FLAGS_max_num_tokens);
  settings.max_output_tokens = absl::GetFlag(FLAGS_max_output_tokens);
  settings.max_num_images = absl::GetFlag(FLAGS_max_num_images);
  // The batch sizes arrive as strings and must parse as integers.
  ASSIGN_OR_RETURN(
      settings.prefill_batch_sizes,
      ParsePrefillBatchSizes(absl::GetFlag(FLAGS_prefill_batch_sizes)));
  settings.prefill_chunk_size = absl::GetFlag(FLAGS_prefill_chunk_size);
  settings.num_output_candidates = absl::GetFlag(FLAGS_num_output_candidates);
  settings.benchmark = absl::GetFlag(FLAGS_benchmark);
  settings.benchmark_prefill_tokens =
      absl::GetFlag(FLAGS_benchmark_prefill_tokens);
  settings.benchmark_decode_tokens =
      absl::GetFlag(FLAGS_benchmark_decode_tokens);
  settings.async = absl::GetFlag(FLAGS_async);
  settings.report_peak_memory_footprint =
      absl::GetFlag(FLAGS_report_peak_memory_footprint);
  settings.force_f32 = absl::GetFlag(FLAGS_force_f32);
  settings.multi_turns = absl::GetFlag(FLAGS_multi_turns);
  settings.num_cpu_threads = absl::GetFlag(FLAGS_num_cpu_threads);
  settings.gpu_external_tensor_mode =
      absl::GetFlag(FLAGS_gpu_external_tensor_mode);
  settings.configure_magic_numbers =
      absl::GetFlag(FLAGS_configure_magic_numbers);
  settings.verify_magic_numbers = absl::GetFlag(FLAGS_verify_magic_numbers);
  settings.clear_kv_cache_before_prefill =
      absl::GetFlag(FLAGS_clear_kv_cache_before_prefill);
  settings.num_logits_to_print_after_decode =
      absl::GetFlag(FLAGS_num_logits_to_print_after_decode);
  settings.score_target_text = absl::GetFlag(FLAGS_score_target_text);
  settings.gpu_madvise_original_shared_tensors =
      absl::GetFlag(FLAGS_gpu_madvise_original_shared_tensors);
  settings.disable_cache = absl::GetFlag(FLAGS_disable_cache);
  settings.cache_compiled_shaders_only =
      absl::GetFlag(FLAGS_cache_compiled_shaders_only);
  settings.preferred_device_substr =
      absl::GetFlag(FLAGS_preferred_device_substr);
  settings.num_threads_to_upload = absl::GetFlag(FLAGS_num_threads_to_upload);
  settings.num_threads_to_compile = absl::GetFlag(FLAGS_num_threads_to_compile);
  settings.convert_weights_on_gpu = absl::GetFlag(FLAGS_convert_weights_on_gpu);
  settings.wait_for_weights_conversion_complete_in_benchmark =
      absl::GetFlag(FLAGS_wait_for_weights_conversion_complete_in_benchmark);
  settings.optimize_shader_compilation =
      absl::GetFlag(FLAGS_optimize_shader_compilation);
  settings.share_constant_tensors = absl::GetFlag(FLAGS_share_constant_tensors);
  settings.use_session = absl::GetFlag(FLAGS_use_session);
  settings.num_iterations = absl::GetFlag(FLAGS_num_iterations);
  settings.litert_dispatch_lib_dir =
      absl::GetFlag(FLAGS_litert_dispatch_lib_dir);
  settings.sampler_handles_input = absl::GetFlag(FLAGS_sampler_handles_input);
  // Map the textual --conv_type flag onto the ConvType enum; anything other
  // than "float" or "int8" (including the documented "auto") means kAuto.
  settings.conv_type =
      absl::GetFlag(FLAGS_conv_type) == "float"  ? litert::lm::ConvType::kFloat
      : absl::GetFlag(FLAGS_conv_type) == "int8" ? litert::lm::ConvType::kInt8
                                                 : litert::lm::ConvType::kAuto;
  settings.constraint_regex = absl::GetFlag(FLAGS_constraint_regex);
  settings.use_submodel = absl::GetFlag(FLAGS_use_submodel);
  settings.enable_speculative_decoding =
      absl::GetFlag(FLAGS_enable_speculative_decoding);

  // Adjust max_num_tokens and prefill_batch_size if not set on benchmark mode.
  if (settings.benchmark && settings.benchmark_prefill_tokens > 0) {
    if (settings.max_num_tokens == 0 && settings.benchmark_decode_tokens > 0) {
      // Size the context to exactly fit the benchmark's prefill + decode.
      settings.max_num_tokens =
          settings.benchmark_prefill_tokens + settings.benchmark_decode_tokens;
    }
    if (settings.prefill_batch_sizes.empty()) {
      settings.prefill_batch_sizes.insert(settings.benchmark_prefill_tokens);
    }
  }

  // Metrics are only collected in benchmark mode with an output path set.
  std::vector<litert::lm::LitertLmMetrics> metrics;
  const std::string metric_proto_file_path =
      absl::GetFlag(FLAGS_metric_proto_file_path);
  const bool collect_metrics =
      (settings.benchmark && !metric_proto_file_path.empty());

  RETURN_IF_ERROR(
      litert::lm::RunLiteRtLm(settings, collect_metrics ? &metrics : nullptr));

  if (collect_metrics) {
    RETURN_IF_ERROR(WriteMetricsToFile(metrics, metric_proto_file_path));
  }

  return absl::OkStatus();
}

}  // namespace

int main(int argc, char** argv) {
  ABSL_CHECK_OK(MainHelper(argc, argv));
  return 0;
}