// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/framework/resource_management/utils/resource_manager_utils.h" #include #include #include #include #include "absl/status/status.h" // from @com_google_absl #include "runtime/util/status_macros.h" // NOLINT namespace litert::lm { namespace { // Returns the comparable length of the vector, excluding the negative tokens, // which indicates the multi modal token and just a placeholder. From token id // level, there is no way to tell if the multi modal token is matched or not, so // we assume the multi modal token are not matched. If all tokens are not // negative, return vec.size(). size_t ComparableLengthExcludingNegativeTokens(const std::vector& vec) { for (size_t i = 0; i < vec.size(); ++i) { if (vec[i] < 0) { return i; } } return vec.size(); } } // namespace absl::Status RemoveMatchingTokens(const std::vector &processed_tokens, std::vector *input_ids, int *time_step) { RET_CHECK_NE(input_ids, nullptr) << "input_ids is null."; RET_CHECK_NE(time_step, nullptr) << "time_step is null."; RET_CHECK_GE(*time_step, 0) << "Time step is negative."; RET_CHECK_GE(processed_tokens.size(), *time_step) << "The processed tokens size is smaller than the time step."; // Determine the number of tokens available in processed_tokens from the // current time_step. *time_step is int, .size() is size_t. RET_CHECKs ensure // *time_step is valid. const size_t processed_tokens_comparable_len = processed_tokens.size() - static_cast(*time_step); // Determine how many elements to actually compare (the minimum of the two // effective sequence lengths). const size_t comparison_len = std::min(processed_tokens_comparable_len, ComparableLengthExcludingNegativeTokens(*input_ids)); // Find the first mismatch. // The first range for mismatch is [input_ids->begin(), input_ids->begin() + // comparison_len). The second range for mismatch is [processed_tokens.begin() // + *time_step, processed_tokens.begin() + *time_step + comparison_len). auto mismatch_it = std::mismatch(input_ids->begin(), input_ids->begin() + comparison_len, processed_tokens.begin() + *time_step); // Calculate the number of matching tokens. // mismatch_it.first is an iterator to the first differing element in // input_ids, or input_ids->begin() + comparison_len if all compared elements // match. int matching_tokens = std::distance(input_ids->begin(), mismatch_it.first); // Update the input_ids and time_step. input_ids->erase(input_ids->begin(), mismatch_it.first); *time_step += matching_tokens; return absl::OkStatus(); }; } // namespace litert::lm