LiteRT-LM / runtime /conversation /internal_callback_util.cc
SeaWolf-AI's picture
Upload full LiteRT-LM codebase
5f923cd verified
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/conversation/internal_callback_util.h"
#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>
#include "absl/functional/any_invocable.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "nlohmann/json_fwd.hpp" // from @nlohmann_json
#include "runtime/conversation/channel_util.h"
#include "runtime/conversation/io_types.h"
#include "runtime/conversation/model_data_processor/config_registry.h"
#include "runtime/conversation/model_data_processor/model_data_processor.h"
#include "runtime/engine/io_types.h"
namespace litert::lm {
namespace {
// Computes the length of the longest suffix of `a` that is also a prefix of
// `b`. Yields 0 when either string is empty or when no such overlap exists.
size_t SuffixPrefixOverlap(absl::string_view a, absl::string_view b) {
  // Candidate lengths are tried from longest to shortest, so the first hit is
  // the maximal overlap. An empty input makes the initial candidate 0 and the
  // loop is skipped entirely.
  size_t candidate = std::min(a.size(), b.size());
  while (candidate > 0) {
    const absl::string_view tail = a.substr(a.size() - candidate);
    const absl::string_view head = b.substr(0, candidate);
    if (tail == head) {
      break;
    }
    --candidate;
  }
  return candidate;
}
// Delivers plain text (which may include tool calls) to `user_callback`.
//
// The text is wrapped in a `Responses` object with state
// `TaskState::kProcessing` and converted with
// `model_data_processor.ToMessage`. On success the resulting message is handed
// to the callback; on failure the conversion status is handed over instead.
// Empty text is dropped without invoking the callback.
void SendMessage(
    absl::AnyInvocable<void(absl::StatusOr<Message>)>& user_callback,
    absl::string_view text, const ModelDataProcessor& model_data_processor,
    DataProcessorArguments processor_args) {
  if (!text.empty()) {
    auto converted = model_data_processor.ToMessage(
        Responses(TaskState::kProcessing, {std::string(text)}),
        processor_args);
    if (converted.ok()) {
      user_callback(std::move(converted.value()));
    } else {
      user_callback(converted.status());
    }
  }
}
// Delivers streamed text that belongs to a named channel to `user_callback`.
//
// The text is not routed through `model_data_processor.ToMessage`. Instead a
// JsonMessage of the form
//   {"role": "assistant", "channels": {<channel_name>: <text>}}
// is built directly and passed to the callback. Empty text is dropped without
// invoking the callback.
void SendMessageToChannel(
    absl::AnyInvocable<void(absl::StatusOr<Message>)>& user_callback,
    absl::string_view text, absl::string_view channel_name) {
  if (!text.empty()) {
    JsonMessage payload;
    payload["role"] = "assistant";
    // "channels" is created as an explicit JSON object before the single
    // channel entry is added to it.
    auto& channel_map = payload["channels"];
    channel_map = nlohmann::ordered_json::object();
    channel_map[std::string(channel_name)] = std::string(text);
    user_callback(Message(payload));
  }
}
// Sends remaining un-flushed text at the end of generation and then invokes
// `complete_message_callback` on the final message.
void SendCompleteMessage(
absl::AnyInvocable<void(absl::StatusOr<Message>)>& user_callback,
absl::string_view accumulated_response_text,
const ModelDataProcessor& model_data_processor,
DataProcessorArguments processor_args, int cursor,
absl::AnyInvocable<void(Message)>& complete_message_callback,
const std::string& active_channel_name,
const std::vector<Channel>& channels) {
// Send remaining un-flushed text at the end of generation.
if (cursor < accumulated_response_text.size()) {
if (!active_channel_name.empty()) {
SendMessageToChannel(user_callback,
accumulated_response_text.substr(cursor),
active_channel_name);
} else {
SendMessage(user_callback, accumulated_response_text.substr(cursor),
model_data_processor, processor_args);
}
}
// Wrap the accumulated response text in a `Responses` object.
Responses responses(TaskState::kProcessing,
{std::string(accumulated_response_text)});
// Extract channel content from the responses. Modifies responses in place.
auto extracted_channels = ExtractChannelContent(channels, responses);
if (!extracted_channels.ok()) {
user_callback(extracted_channels.status());
return;
}
auto complete_message =
model_data_processor.ToMessage(responses, processor_args);
if (!complete_message.ok()) {
user_callback(complete_message.status());
return;
}
InsertChannelContentIntoMessage(*extracted_channels, *complete_message);
if (complete_message_callback) {
complete_message_callback(*complete_message);
}
user_callback(Message(JsonMessage()));
}
// Returns the complete list of channels the parser should search for, including
// any tool call code blocks (treated as a special channel with no target
// message field) and custom channels passed in by the user config.
std::vector<Channel> GetChannels(const ModelDataProcessor& model_data_processor,
const std::vector<Channel>& custom_channels) {
std::vector<Channel> channels;
// Add the tool call channel if the code fence start is not empty.
if (!model_data_processor.CodeFenceStart().empty()) {
channels.push_back({"", std::string(model_data_processor.CodeFenceStart()),
std::string(model_data_processor.CodeFenceEnd())});
}
// Add the custom channels.
for (const auto& channel : custom_channels) {
if (!channel.start.empty()) {
channels.push_back({channel.channel_name, channel.start, channel.end});
}
}
return channels;
}
// Looks in `text`, from index `cursor` onwards, for the start delimiter of
// each channel in `possible_channels` and picks the one that occurs first.
//
// On a hit, returns a pointer to that channel (owned by `possible_channels`)
// and writes the delimiter's index within `text` to `best_start_pos`. When two
// channels match at the same index, the one listed first wins.
//
// On a miss, returns nullptr and sets `best_start_pos` to `std::string::npos`.
const Channel* FindNextChannelStart(
    const std::vector<Channel>& possible_channels, absl::string_view text,
    size_t cursor, size_t& best_start_pos) {
  const Channel* earliest = nullptr;
  size_t earliest_pos = std::string::npos;
  for (const auto& candidate : possible_channels) {
    const size_t found_at = text.find(candidate.start, cursor);
    // npos is the largest size_t value, so a failed search can never compare
    // below the current best, and any successful search beats an unset best.
    if (found_at < earliest_pos) {
      earliest_pos = found_at;
      earliest = &candidate;
    }
  }
  best_start_pos = earliest_pos;
  return earliest;
}
// Determines how many trailing characters of `text` could be the beginning of
// a channel start delimiter that has not fully arrived yet.
//
// Returns the largest suffix/prefix overlap between `text` and the start
// delimiter of any channel in `channels`, or 0 if no delimiter overlaps.
size_t FindMaxOverlap(const std::vector<Channel>& channels,
                      absl::string_view text) {
  size_t longest = 0;
  for (const auto& channel : channels) {
    longest = std::max(longest, SuffixPrefixOverlap(text, channel.start));
  }
  return longest;
}
// Streams the content of the currently open named channel to `user_callback`.
//
// Any trailing characters that might be the beginning of
// `active_channel_end` are held back: the overlap is measured on the text from
// `search_start` onwards, and only the range [cursor, size - overlap) is sent.
// When something is sent, `cursor` is advanced to the end of the sent range;
// otherwise `cursor` is left unchanged.
void StreamActiveChannel(
    absl::AnyInvocable<void(absl::StatusOr<Message>)>& user_callback,
    absl::string_view accumulated_response_text, size_t search_start,
    size_t& cursor, absl::string_view active_channel_end,
    const std::string& active_channel_name) {
  const absl::string_view searchable =
      accumulated_response_text.substr(search_start);
  const size_t held_back = SuffixPrefixOverlap(searchable, active_channel_end);
  const size_t safe_end = accumulated_response_text.size() - held_back;
  if (safe_end <= cursor) {
    // Everything past the cursor may still be part of the end delimiter.
    return;
  }
  const absl::string_view streamable =
      accumulated_response_text.substr(cursor, safe_end - cursor);
  SendMessageToChannel(user_callback, streamable, active_channel_name);
  cursor = safe_end;
}
} // namespace
// Creates an internal callback that parses the model's raw text responses.
//
// This parser supports "channels" defined by start and end delimiters.
// 1. Tool Calls: These act as a special channel (defined by
//    CodeFenceStart/End). Because their content needs to be parsed as a whole
//    object (e.g. JSON), tool calls do not stream. The parser buffers the text
//    and waits for the end delimiter before using
//    `model_data_processor.ToMessage` to emit them. This is represented
//    internally by an empty `active_channel_name`.
// 2. Custom Channels: These are supplied via `custom_channels`. If a custom
//    channel specifies a `channel_name`, its text is streamed out immediately
//    as new JSON messages under that channel name as it arrives instead of
//    being buffered. A custom channel with an empty `channel_name` is buffered
//    like a tool call.
//
// Behavior of the returned callback, by input:
// - Error status: forwarded to `user_callback`; if the error is a
//   cancellation and `cancel_callback` is set, `cancel_callback` runs first.
// - `TaskState::kDone` / `TaskState::kMaxNumTokensReached`: un-flushed text is
//   sent and the complete message is reported via `SendCompleteMessage`.
// - `TaskState::kProcessing`: the first text of the response is appended to
//   the accumulated text, which is then parsed from the current cursor.
// - Any other task state is ignored.
//
// `model_data_processor` is captured by reference, so it must outlive the
// returned callback. All other state is owned by the callback.
absl::AnyInvocable<void(absl::StatusOr<Responses>)> CreateInternalCallback(
    const ModelDataProcessor& model_data_processor,
    const DataProcessorArguments processor_args,
    const std::vector<Channel>& custom_channels,
    absl::AnyInvocable<void(absl::StatusOr<Message>)> user_callback,
    absl::AnyInvocable<void()> cancel_callback,
    absl::AnyInvocable<void(Message)> complete_message_callback) {
  // `custom_channels` is only needed to build `channels` below, so it is not
  // captured on its own.
  return [&model_data_processor, processor_args,
          user_callback = std::move(user_callback),
          cancel_callback = std::move(cancel_callback),
          complete_message_callback = std::move(complete_message_callback),
          accumulated_response_text = std::string(), cursor = size_t(0),
          channels = GetChannels(model_data_processor, custom_channels),
          inside_channel = false, active_channel_end = std::string(),
          active_channel_start_pos = size_t(0),
          active_channel_start_size = size_t(0),
          active_channel_name =
              std::string()](absl::StatusOr<Responses> responses) mutable {
    if (!responses.ok()) {
      // If the error is due to cancellation, then we should trigger the cancel
      // callback for removing the last message from the history.
      if (cancel_callback && absl::IsCancelled(responses.status())) {
        cancel_callback();
      }
      user_callback(responses.status());
      return;
    }

    // If there are no more new responses, it means the model has finished
    // generating content, trigger the complete message callback and return an
    // OK status to indicate the inference is done.
    if (responses->GetTaskState() == TaskState::kDone ||
        responses->GetTaskState() == TaskState::kMaxNumTokensReached) {
      // Only pass the channel name along if generation stopped while a
      // channel was still open; otherwise the remainder is plain text.
      SendCompleteMessage(user_callback, accumulated_response_text,
                          model_data_processor, processor_args, cursor,
                          complete_message_callback,
                          inside_channel ? active_channel_name : "", channels);
      cursor = accumulated_response_text.size();
      return;
    }

    // Else, add the new response text to the accumulated text and process the
    // response text. (Which sends to the user callback accordingly.)
    if (responses->GetTaskState() == TaskState::kProcessing) {
      // If there are no new responses, it is just a state update and we can
      // return early.
      if (responses->GetTexts().empty()) {
        return;
      }

      // Append the new response text to the accumulated text.
      accumulated_response_text += responses->GetTexts()[0];

      // Non-owning view over the accumulated text. It is created after the
      // append above (which may reallocate) and the string is not modified
      // again during this invocation, so the view stays valid. Slicing the
      // view avoids allocating temporary strings for every emitted piece.
      const absl::string_view full_text(accumulated_response_text);

      // Loop through the accumulated response text and send to the user
      // callback accordingly.
      while (cursor < full_text.size()) {
        if (!inside_channel) {
          size_t channel_start_pos;
          const Channel* next_channel = FindNextChannelStart(
              channels, full_text, cursor, channel_start_pos);
          if (next_channel != nullptr) {
            // A channel start delimiter was found.
            // The text from the cursor up to the channel start is normal text
            // and can be sent to the user callback.
            SendMessage(user_callback,
                        full_text.substr(cursor, channel_start_pos - cursor),
                        model_data_processor, processor_args);

            // Move cursor up to channel start.
            cursor = channel_start_pos;
            inside_channel = true;
            active_channel_end = next_channel->end;
            active_channel_start_pos = channel_start_pos;
            active_channel_start_size = next_channel->start.size();
            active_channel_name = next_channel->channel_name;

            // For custom channels, move the cursor past the start delimiter so
            // that it is not included in the resulting streamed content.
            // For tool calls (empty channel name), we leave the cursor alone
            // to buffer the complete block including delimiters.
            if (!active_channel_name.empty()) {
              cursor += active_channel_start_size;
            }
          } else {
            // A channel start delimiter was not found. We still need to check
            // if there's a partial match of any channel start at the very end
            // of the string.
            size_t max_overlap =
                FindMaxOverlap(channels, full_text.substr(cursor));
            if (max_overlap > 0) {
              // There's a partial match of a channel at the end of the
              // string.
              size_t possible_start_pos = full_text.size() - max_overlap;

              // Call the callback with text up to the potential start of the
              // channel.
              SendMessage(
                  user_callback,
                  full_text.substr(cursor, possible_start_pos - cursor),
                  model_data_processor, processor_args);

              // Move cursor up to potential start of channel.
              cursor = possible_start_pos;

              // Break for the next token.
              break;
            } else {
              // Remaining string is text.
              SendMessage(user_callback, full_text.substr(cursor),
                          model_data_processor, processor_args);
              cursor = full_text.size();
            }
          }
        }

        if (inside_channel) {
          // Look for channel end. The search never starts inside the start
          // delimiter itself, even when the cursor still points at it (the
          // tool-call case).
          size_t search_start = std::max(
              cursor, active_channel_start_pos + active_channel_start_size);
          size_t end_pos = full_text.find(active_channel_end, search_start);
          if (end_pos != std::string::npos) {
            // A channel end delimiter was found.
            if (!active_channel_name.empty()) {
              // Flush the active stream channel.
              SendMessageToChannel(user_callback,
                                   full_text.substr(cursor, end_pos - cursor),
                                   active_channel_name);
            } else {
              // Treat as tool call: include everything up to and including the
              // end delimiter.
              SendMessage(
                  user_callback,
                  full_text.substr(
                      cursor, end_pos + active_channel_end.size() - cursor),
                  model_data_processor, processor_args);
            }

            // Move cursor past the end of the channel block.
            cursor = end_pos + active_channel_end.size();
            inside_channel = false;
          } else {
            // We're inside a channel or tool call but the end has not been
            // found.
            if (!active_channel_name.empty()) {
              // If we're inside a channel, stream the text, but stop before
              // any potential partial match of the channel's end delimiter.
              StreamActiveChannel(user_callback, full_text, search_start,
                                  cursor, active_channel_end,
                                  active_channel_name);
            }

            // Break for the next token.
            break;
          }
        }
      }
    }
  };
}
} // namespace litert::lm