// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/components/stop_token_detector.h" #include #include #include #include #include "absl/log/absl_check.h" // from @com_google_absl #include "absl/log/absl_log.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/str_cat.h" // from @com_google_absl #include "absl/strings/str_format.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl #include "runtime/util/status_macros.h" namespace litert::lm { namespace { // Prints a sequence of integers. inline std::string PrintSequence(const std::vector& sequence) { std::string existing_sequence_str = "{"; for (size_t i = 0; i < sequence.size(); ++i) { absl::StrAppend(&existing_sequence_str, sequence[i]); if (i < sequence.size() - 1) { absl::StrAppend(&existing_sequence_str, ", "); } } absl::StrAppend(&existing_sequence_str, "}"); return existing_sequence_str; } } // namespace StopTokenDetector::StopTokenDetector(size_t batch_size) { ABSL_CHECK_GT(batch_size, 0) << "Batch size must be greater than 0."; ResetBatch(batch_size); } absl::Status StopTokenDetector::AddStopTokenSequence( const std::vector& stop_sequence) { if (stop_sequence.empty()) { return absl::InvalidArgumentError( "Cannot add an empty stop token sequence."); } // Check if the sequence already exists if (std::find(stop_sequences_storage_.begin(), stop_sequences_storage_.end(), stop_sequence) != stop_sequences_storage_.end()) { ABSL_LOG(INFO) << absl::StrFormat( "Stop token sequence %s already exists. Skipping " "adding the stop token sequence.", PrintSequence(stop_sequence)); return absl::OkStatus(); } stop_sequences_storage_.push_back(stop_sequence); // Add a progress tracker for the new stop sequence for each batch item. for (auto& progress_vector_for_item : batch_item_match_progress_) { progress_vector_for_item.push_back(0); } return absl::OkStatus(); } void StopTokenDetector::ResetBatch(size_t batch_size) { int new_batch_size = batch_size == 0 ? stop_token_found_.size() : batch_size; stop_token_found_.assign(new_batch_size, false); max_batch_item_match_progress_.assign(new_batch_size, 0); // Initialize progress for each batch item for all currently defined stop // sequences. batch_item_match_progress_.assign( new_batch_size, std::vector(stop_sequences_storage_.size(), 0)); matched_stop_sequence_length_.assign(new_batch_size, 0); } // Processes the latest incoming token for each sequence in the batch. absl::Status StopTokenDetector::ProcessTokens( absl::Span latest_tokens) { if (latest_tokens.size() != stop_token_found_.size()) { return absl::InvalidArgumentError(absl::StrFormat( "Size of latest_tokens (%d) does not match configured batch size (%d).", latest_tokens.size(), stop_token_found_.size())); } if (stop_sequences_storage_.empty()) { // No stop sequences to check against. return absl::InvalidArgumentError( "No stop sequences to check against. Did you forget to call " "AddStopTokenSequence()?"); } for (size_t i = 0; i < latest_tokens.size(); ++i) { if (stop_token_found_[i]) { // Already stopped, but increase the length of the matched stop sequence. matched_stop_sequence_length_[i]++; continue; } max_batch_item_match_progress_[i] = 0; int current_token_id = latest_tokens[i]; for (size_t k = 0; k < stop_sequences_storage_.size(); ++k) { const auto& stop_seq_k = stop_sequences_storage_[k]; // Guaranteed non-empty int& current_match_len_for_k = batch_item_match_progress_[i][k]; if (current_match_len_for_k < stop_seq_k.size() && stop_seq_k[current_match_len_for_k] == current_token_id) { current_match_len_for_k++; } else { // Mismatch or sequence completed; reset progress for this stop_seq_k. // Check if current token starts stop_seq_k anew. if (stop_seq_k[0] == current_token_id) { current_match_len_for_k = 1; } else { current_match_len_for_k = 0; } } if (current_match_len_for_k > 0 && current_match_len_for_k == stop_seq_k.size()) { stop_token_found_[i] = true; matched_stop_sequence_length_[i] = stop_seq_k.size(); } max_batch_item_match_progress_[i] = std::max(max_batch_item_match_progress_[i], current_match_len_for_k); } } return absl::OkStatus(); } absl::Status StopTokenDetector::ProcessTokens( const std::vector>& latest_tokens) { if (latest_tokens.size() != stop_token_found_.size()) { return absl::InvalidArgumentError(absl::StrFormat( "Size of latest_tokens (%d) does not match configured batch size (%d).", latest_tokens.size(), stop_token_found_.size())); } std::vector flattened_tokens; flattened_tokens.reserve(latest_tokens.size()); for (auto& tokens : latest_tokens) { RET_CHECK_EQ(tokens.size(), 1) << "The current implementation of ProcessTokens() requires that " "latest_tokens must contain only single tokens."; flattened_tokens.push_back(tokens[0]); } return ProcessTokens(flattened_tokens); } int StopTokenDetector::MaxPartialStopTokenLength(int index) const { return max_batch_item_match_progress_[index]; } const std::vector& StopTokenDetector::GetStepsBeforeStopTokens() const { return matched_stop_sequence_length_; } absl::StatusOr StopTokenDetector::AllDone() const { if (stop_token_found_.empty()) { return absl::FailedPreconditionError( "The Detector is not initialized with non-zero batch size. Did you " "forget to call ResetBatch() or AddStopTokenSequence() ??"); } return std::all_of(stop_token_found_.begin(), stop_token_found_.end(), [](bool found) { return found; }); } } // namespace litert::lm