// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_FRAMEWORK_RESOURCE_MANAGEMENT_EXECUTION_MANAGER_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_FRAMEWORK_RESOURCE_MANAGEMENT_EXECUTION_MANAGER_H_
#include <atomic>
#include <limits>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "absl/base/nullability.h" // from @com_google_absl
#include "absl/base/thread_annotations.h" // from @com_google_absl
#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/container/flat_hash_set.h" // from @com_google_absl
#include "absl/functional/any_invocable.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/synchronization/mutex.h" // from @com_google_absl
#include "absl/time/time.h" // from @com_google_absl
#include "litert/cc/litert_environment.h" // from @litert
#include "runtime/components/constrained_decoding/constraint.h"
#include "runtime/components/model_resources.h"
#include "runtime/components/sampler.h"
#include "runtime/components/stop_token_detector.h"
#include "runtime/components/tokenizer.h"
#include "runtime/engine/engine.h"
#include "runtime/engine/engine_settings.h"
#include "runtime/engine/io_types.h"
#include "runtime/executor/audio_executor.h"
#include "runtime/executor/audio_executor_settings.h"
#include "runtime/executor/llm_executor.h"
#include "runtime/executor/llm_executor_io_types.h"
#include "runtime/executor/vision_executor_settings.h"
#include "runtime/framework/resource_management/context_handler/context_handler.h"
#include "runtime/framework/resource_management/resource_manager.h"
#include "runtime/framework/threadpool.h"
namespace litert::lm {
// Identifier of a task scheduled on the ExecutionManager.
using TaskId = int;
// Identifier of a session registered with the ExecutionManager.
using SessionId = int;
// Per-session state tracked by the ExecutionManager. Field order is part of
// the aggregate-initialization contract and must not be changed.
struct SessionInfo {
  // The config the session was registered with.
  SessionConfig session_config;
  // The context handler of the session (shared ownership).
  std::shared_ptr<ContextHandler> context_handler;
  // The sampler of the session (owned by this struct).
  std::unique_ptr<Sampler> sampler;
  // The last prefill token ID of the session; starts at 0.
  int last_prefill_token_id{0};
  // The stop token detector of the session (owned by this struct).
  std::unique_ptr<StopTokenDetector> stop_token_detector;
  // The benchmark info of the session; disengaged (std::nullopt) by default.
  std::optional<BenchmarkInfo> benchmark_info;
  // IDs of the tasks of this session that are still active; empty by default.
  absl::flat_hash_set<TaskId> active_tasks;
};
// All the information about a task.
// - session_id: The ID of the session that created the task. Defaults to 0 so
//   that a default-initialized TaskInfo never holds an indeterminate value
//   (value-initialization already produced 0).
// - task: The task function. This is the function that will be executed by the
//   execution manager. Will be retrieved and moved by the queue task function.
// - task_state: The state of the task.
// - dependent_tasks: The dependent tasks that should be done before the task
//   starts.
// - following_tasks: The following tasks that are waiting for the task to
//   finish.
// - cancelled: The cancelled flag of the task, as passed in when the task was
//   created. Null for a default-constructed TaskInfo.
// - callback: The callback function. This is the function that will be called
//   when the task is done. Will be retrieved and moved by the start task
//   function.
struct TaskInfo {
  SessionId session_id = 0;
  absl::AnyInvocable<void()> task;
  TaskState task_state = TaskState::kUnknown;
  absl::flat_hash_set<TaskId> dependent_tasks = {};
  absl::flat_hash_set<TaskId> following_tasks = {};
  std::shared_ptr<std::atomic<bool>> cancelled = nullptr;
  absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback;
};
// The execution manager is responsible for managing the execution of the tasks.
// It handles the scheduling of the tasks and the dependencies between them.
// Note: The execution manager creates its own thread pools (one for executing
// tasks, one for running callbacks), so interactions with it from other
// threads must respect the locking annotations below.
class ExecutionManager {
 public:
  // Creates an ExecutionManager.
  // The ExecutionManager takes ownership of every argument passed as a
  // std::unique_ptr (llm_executor, the executor settings and audio_executor).
  // Raw-pointer arguments are not owned.
  // - tokenizer: The tokenizer used for encoding the text input. This is
  //   expected to be non-null. It is stored as a raw pointer, so it must
  //   outlive the ExecutionManager.
  // - model_resources: The model resources. May be null. Not owned.
  //   NOTE(review): lifetime requirements depend on the implementation of
  //   Create; confirm whether it must outlive the ExecutionManager.
  // - llm_executor: The executor used for prefill/decode the LLM. This is
  //   expected to be non-null.
  // - vision_executor_settings: The vision executor settings used for creating
  //   the vision executor. This can be null if no vision modality is used.
  // - audio_executor_settings: The audio executor settings used for creating
  //   the audio executor. This can be null if no audio modality is used.
  // - litert_env: The LiteRT environment used for creating the LLM context.
  //   This can be null if no LLM context is needed. It is stored as a raw
  //   pointer, so when non-null it must outlive the ExecutionManager.
  // - audio_executor: An optional pre-built audio executor. Defaults to null.
  static absl::StatusOr<std::unique_ptr<ExecutionManager>> Create(
      Tokenizer* absl_nonnull tokenizer,
      ModelResources* absl_nullable model_resources,
      std::unique_ptr<LlmExecutor> absl_nonnull llm_executor,
      std::unique_ptr<VisionExecutorSettings> absl_nullable
          vision_executor_settings,
      std::unique_ptr<AudioExecutorSettings> absl_nullable
          audio_executor_settings,
      ::litert::Environment* absl_nullable litert_env,
      std::unique_ptr<AudioExecutor> absl_nullable audio_executor = nullptr);

  ~ExecutionManager() {
    // Best-effort drain: a timeout or a failed task must not abort
    // destruction, so the returned status is intentionally ignored.
    WaitUntilAllDone(Engine::kDefaultTimeout).IgnoreError();
    // Members are destroyed in reverse declaration order, which would tear
    // down callback_thread_pool_ before execution_thread_pool_. Tasks still
    // running on the execution pool (e.g. if the wait above timed out) may
    // schedule callbacks onto the callback pool, so shut the execution pool
    // down first.
    // NOTE(review): assumes ThreadPool's destructor waits for its queued
    // work; confirm against runtime/framework/threadpool.h.
    execution_thread_pool_.reset();
    callback_thread_pool_.reset();
  }

  // Waits until the task is done or the timeout is reached.
  // Returns:
  // - OK if the task is done.
  // - DEADLINE_EXCEEDED if the timeout is reached.
  // - Other errors if the task is failed.
  absl::Status WaitUntilDone(TaskId task_id, absl::Duration timeout)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Waits until the tasks of the given session are done or the timeout is
  // reached.
  absl::Status WaitUntilSessionDone(SessionId session_id,
                                    absl::Duration timeout)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Waits until all tasks are done or the timeout is reached.
  // Returns:
  // - OK if all tasks are done.
  // - DEADLINE_EXCEEDED if the timeout is reached.
  // - Other errors if any of the tasks is failed.
  absl::Status WaitUntilAllDone(absl::Duration timeout)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Registers a new session and returns its ID.
  // The returned session ID is guaranteed to be unique.
  absl::StatusOr<SessionId> RegisterNewSession(
      SessionConfig session_config,
      std::optional<BenchmarkInfo> benchmark_info = std::nullopt)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Releases the session with the given session ID.
  absl::Status ReleaseSession(SessionId session_id)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Cancels all tasks in the session with the given session ID.
  absl::Status CancelAllTasksInSession(SessionId session_id)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Returns the session info with the given session ID.
  // Returns:
  // - The session info.
  // - INVALID_ARGUMENT if the session ID is not found.
  absl::StatusOr<std::shared_ptr<const SessionInfo>> GetSessionInfo(
      SessionId session_id) ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Returns the mutable benchmark info with the given session ID.
  // Note: The returned benchmark info is not thread-safe and should be used
  // with care to record appropriate metrics.
  // Returns:
  // - The mutable benchmark info.
  // - INVALID_ARGUMENT if the session ID is not found.
  absl::StatusOr<BenchmarkInfo*> GetMutableBenchmarkInfo(SessionId session_id)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Returns a new task ID.
  // The returned task ID is guaranteed to be unique.
  absl::StatusOr<TaskId> GetNewTaskId();

  // Adds a prefill task to the execution manager.
  // - session_id: The ID of the session that created the task.
  // - task_id: The task ID of the task.
  // - inputs: The inputs of the prefill task.
  // - dep_tasks: The dependent tasks that should be done before the prefill
  //   task starts.
  // - cancelled: The cancelled flag for the prefill task.
  // - callback: The callback function.
  // Note: AddPrefillTask will acquire the task lookup mutex.
  absl::Status AddPrefillTask(
      SessionId session_id, TaskId task_id, std::vector<InputData> inputs,
      absl::flat_hash_set<TaskId> dep_tasks,
      std::shared_ptr<std::atomic<bool>> absl_nonnull cancelled,
      absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Adds a decode task to the execution manager.
  // - session_id: The ID of the session that created the task.
  // - task_id: The task ID of the task.
  // - dep_tasks: The dependent tasks that should be done before the decode
  //   task starts.
  // - constraint: The constraint for the decode task. May be null.
  // - cancelled: The cancelled flag for the decode task.
  // - callback: The callback function.
  // - max_output_tokens: The maximum number of output tokens for the decode
  //   task. Defaults to std::numeric_limits<int>::max().
  // Note: AddDecodeTask will acquire the task lookup mutex.
  absl::Status AddDecodeTask(
      SessionId session_id, TaskId task_id,
      absl::flat_hash_set<TaskId> dep_tasks,
      Constraint* absl_nullable constraint,
      std::shared_ptr<std::atomic<bool>> absl_nonnull cancelled,
      absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback,
      int max_output_tokens = std::numeric_limits<int>::max())
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Adds a clone session task to the execution manager.
  // - session_id: The ID of the session that created the task.
  // - task_id: The task ID of the task.
  // - dep_tasks: The dependent tasks that should be done before the clone
  //   session task starts.
  // - cloned_session_id: The ID of the cloned session.
  // - cancelled: The cancelled flag for the clone session task.
  // - callback: The callback function.
  // Note: AddCloneSessionTask will acquire the task lookup mutex.
  // TODO b/409401231 - Add unit tests for this function.
  absl::Status AddCloneSessionTask(
      SessionId session_id, TaskId task_id,
      absl::flat_hash_set<TaskId> dep_tasks, SessionId cloned_session_id,
      std::shared_ptr<std::atomic<bool>> absl_nonnull cancelled,
      absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Adds a text scoring task to the execution manager.
  // - session_id: The ID of the session that created the task.
  // - task_id: The task ID of the task.
  // - dep_tasks: The dependent tasks that should be done before the text
  //   scoring task starts.
  // - target_text: The target text to be scored.
  //   NOTE(review): these are non-owning views; confirm the implementation
  //   copies them before the task runs asynchronously.
  // - store_token_lengths: Whether to store the token lengths in the
  //   responses.
  // - cancelled: The cancelled flag for the text scoring task.
  // - callback: The callback function.
  // Note: AddTextScoringTask will acquire the task lookup mutex.
  absl::Status AddTextScoringTask(
      SessionId session_id, TaskId task_id,
      absl::flat_hash_set<TaskId> dep_tasks,
      const std::vector<absl::string_view>& target_text,
      bool store_token_lengths,
      std::shared_ptr<std::atomic<bool>> absl_nonnull cancelled,
      absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Returns the current step of the session.
  // - session_info: The session info of the session.
  // Returns:
  // - The current step of the session.
  absl::StatusOr<int> GetCurrentStep(const SessionInfo& session_info);

  // Sets the current step of the session to the target step.
  // - session_info: The session info of the session.
  // - target_step: The step to set the executor's current step to.
  // Returns:
  // - OK if the current step is set successfully.
  // - INVALID_ARGUMENT if the target step is greater than the current step.
  absl::Status SetCurrentStep(const SessionInfo& session_info, int target_step);

  // Returns the audio executor properties.
  absl::StatusOr<AudioExecutorProperties> GetAudioExecutorProperties() const {
    return resource_manager_->GetAudioExecutorProperties();
  }

  // Returns the vision executor properties.
  absl::StatusOr<VisionExecutorProperties> GetVisionExecutorProperties() const {
    return resource_manager_->GetVisionExecutorProperties();
  }

 private:
  // Private constructor. Use the Create function instead.
  // Both thread pools are created with a single worker thread.
  ExecutionManager(
      Tokenizer* absl_nonnull tokenizer,
      std::unique_ptr<ResourceManager> absl_nonnull resource_manager,
      ::litert::Environment* absl_nullable litert_env = nullptr)
      : tokenizer_(tokenizer),
        resource_manager_(std::move(resource_manager)),
        litert_env_(litert_env),
        execution_thread_pool_(std::make_unique<ThreadPool>(
            /*name_prefix=*/"execution_thread_pool", /*max_num_threads=*/1)),
        callback_thread_pool_(std::make_unique<ThreadPool>(
            /*name_prefix=*/"callback_thread_pool", /*max_num_threads=*/1)) {}

  // Creates a task with the given task ID, task, dependent tasks, and callback.
  // - session_id: The ID of the session that created the task.
  // - task_id: The task ID of the task.
  // - task: The task function.
  // - dependent_tasks: The dependent tasks that should be done before the task
  //   starts.
  // - cancelled: The cancelled flag of the task.
  // - callback: The callback function.
  // Note: CreateTask will acquire the task lookup mutex.
  absl::Status CreateTask(
      SessionId session_id, TaskId task_id,
      absl::AnyInvocable<void()> absl_nonnull task,
      absl::flat_hash_set<TaskId> dependent_tasks,
      std::shared_ptr<std::atomic<bool>> absl_nonnull cancelled,
      absl::AnyInvocable<void(absl::StatusOr<Responses>)> absl_nonnull callback)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Queues the task with the given task ID.
  // - task_id: The task ID of the task.
  // Note: QueueTask expects the callers to acquire the task lookup mutex before
  // calling it.
  absl::Status QueueTask(TaskId task_id)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(session_and_task_lookup_mutex_);

  // Starts the task with the given task ID, and returns the session info and
  // callback function of the task.
  // - task_id: The task ID of the task.
  // Returns:
  // - The session info, cancelled flag and callback function of the task.
  // Note: StartTask will acquire the task lookup mutex.
  absl::StatusOr<std::tuple<
      std::shared_ptr<SessionInfo>, std::shared_ptr<std::atomic<bool>>,
      absl::AnyInvocable<void(absl::StatusOr<Responses>)>>>
  StartTask(TaskId task_id) ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Finishes the task with the given task ID, responses, and callback.
  // - task_id: The task ID of the task.
  // - responses: The responses of the task.
  // - callback: The callback function.
  // Note: FinishTask will acquire the task lookup mutex.
  absl::Status FinishTask(
      TaskId task_id, absl::StatusOr<Responses> responses,
      absl::AnyInvocable<void(absl::StatusOr<Responses>)> absl_nonnull callback)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Finishes the task with the given task ID, responses, and callback. If the
  // task fails, the error will be logged.
  // - task_id: The task ID of the task.
  // - responses: The responses of the task.
  // - callback: The callback function.
  // Note: FinishTaskAndLogErrors will acquire the task lookup mutex.
  void FinishTaskAndLogErrors(
      TaskId task_id, absl::StatusOr<Responses> responses,
      absl::AnyInvocable<void(absl::StatusOr<Responses>)> absl_nonnull callback)
      ABSL_LOCKS_EXCLUDED(session_and_task_lookup_mutex_);

  // Returns all following tasks that are waiting.
  // - task_id: The task ID of the task.
  // Returns:
  // - The set of following tasks that are waiting for dependent tasks.
  // Note: FollowingWaitingTasks expects the callers to acquire the task
  // lookup mutex before calling it.
  absl::StatusOr<absl::flat_hash_set<TaskId>> FollowingWaitingTasks(
      TaskId task_id)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(session_and_task_lookup_mutex_);

  // Updates the task state with the given task ID and task state.
  // - task_id: The task ID of the task.
  // - task_state: The state of the task.
  // Note: UpdateTaskState expects the callers to acquire the task lookup mutex
  // before calling it.
  absl::Status UpdateTaskState(TaskId task_id, TaskState task_state)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(session_and_task_lookup_mutex_);

  // Updates all tasks to the given state.
  // - task_ids: The task IDs of the tasks.
  // - task_state: The state of the tasks.
  // Note: UpdateAllTasksToState expects the callers to acquire the task lookup
  // mutex before calling it.
  absl::Status UpdateAllTasksToState(
      const absl::flat_hash_set<TaskId>& task_ids, TaskState task_state)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(session_and_task_lookup_mutex_);

  // Processes and combines the contents of the preprocessed contents.
  // - preprocessed_contents: The preprocessed contents of the task.
  // - benchmark_info: The benchmark info of the session (passed by mutable
  //   reference).
  // Returns:
  // - The processed and combined contents of the preprocessed contents.
  absl::StatusOr<ExecutorInputs> ProcessAndCombineContents(
      const std::vector<InputData>& preprocessed_contents,
      std::optional<BenchmarkInfo>& benchmark_info);

  // The next unique session ID.
  std::atomic<SessionId> next_session_id_ = 0;

  // The next unique task ID.
  std::atomic<TaskId> next_task_id_ = 0;

  // The mutex for protecting the session and task lookup.
  absl::Mutex session_and_task_lookup_mutex_;

  // The session lookup map.
  // The key is the session ID.
  // The value is the session states.
  absl::flat_hash_map<SessionId, std::shared_ptr<SessionInfo> absl_nonnull>
      session_lookup_ ABSL_GUARDED_BY(session_and_task_lookup_mutex_) = {};

  // The task lookup map.
  // The key is the task ID.
  // The value is the task info.
  absl::flat_hash_map<TaskId, TaskInfo> task_lookup_
      ABSL_GUARDED_BY(session_and_task_lookup_mutex_) = {};

  // TODO b/409401231 - Use LLM Context which will be wrapped in a session
  // state.
  int last_prefill_token_id_ = 0;

  // The tokenizer used for encoding the text input. Not owned.
  Tokenizer* absl_nonnull tokenizer_;

  // The resource manager used for managing the resources.
  std::unique_ptr<ResourceManager> absl_nonnull resource_manager_;

  // The LiteRT environment used for creating the LLM context. Not owned.
  ::litert::Environment* absl_nullable litert_env_;

  // The thread pool with a single worker thread used for executing the tasks.
  std::unique_ptr<ThreadPool> absl_nonnull execution_thread_pool_;

  // The thread pool used for running the callbacks without blocking the
  // execution thread pool.
  // TODO b/476205457 - Consider updating all the callback triggering to use
  // this thread pool, and remove the syncing logic.
  std::unique_ptr<ThreadPool> absl_nonnull callback_thread_pool_;
};
} // namespace litert::lm
#endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_FRAMEWORK_RESOURCE_MANAGEMENT_EXECUTION_MANAGER_H_