Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| namespace { | |
| // Returns the comparable length of the vector, excluding the negative tokens, | |
| // which indicates the multi modal token and just a placeholder. From token id | |
| // level, there is no way to tell if the multi modal token is matched or not, so | |
| // we assume the multi modal token are not matched. If all tokens are not | |
| // negative, return vec.size(). | |
| size_t ComparableLengthExcludingNegativeTokens(const std::vector<int>& vec) { | |
| for (size_t i = 0; i < vec.size(); ++i) { | |
| if (vec[i] < 0) { | |
| return i; | |
| } | |
| } | |
| return vec.size(); | |
| } | |
| } // namespace | |
| absl::Status RemoveMatchingTokens(const std::vector<int> &processed_tokens, | |
| std::vector<int> *input_ids, int *time_step) { | |
| RET_CHECK_NE(input_ids, nullptr) << "input_ids is null."; | |
| RET_CHECK_NE(time_step, nullptr) << "time_step is null."; | |
| RET_CHECK_GE(*time_step, 0) << "Time step is negative."; | |
| RET_CHECK_GE(processed_tokens.size(), *time_step) | |
| << "The processed tokens size is smaller than the time step."; | |
| // Determine the number of tokens available in processed_tokens from the | |
| // current time_step. *time_step is int, .size() is size_t. RET_CHECKs ensure | |
| // *time_step is valid. | |
| const size_t processed_tokens_comparable_len = | |
| processed_tokens.size() - static_cast<size_t>(*time_step); | |
| // Determine how many elements to actually compare (the minimum of the two | |
| // effective sequence lengths). | |
| const size_t comparison_len = | |
| std::min(processed_tokens_comparable_len, | |
| ComparableLengthExcludingNegativeTokens(*input_ids)); | |
| // Find the first mismatch. | |
| // The first range for mismatch is [input_ids->begin(), input_ids->begin() + | |
| // comparison_len). The second range for mismatch is [processed_tokens.begin() | |
| // + *time_step, processed_tokens.begin() + *time_step + comparison_len). | |
| auto mismatch_it = | |
| std::mismatch(input_ids->begin(), input_ids->begin() + comparison_len, | |
| processed_tokens.begin() + *time_step); | |
| // Calculate the number of matching tokens. | |
| // mismatch_it.first is an iterator to the first differing element in | |
| // input_ids, or input_ids->begin() + comparison_len if all compared elements | |
| // match. | |
| int matching_tokens = std::distance(input_ids->begin(), mismatch_it.first); | |
| // Update the input_ids and time_step. | |
| input_ids->erase(input_ids->begin(), mismatch_it.first); | |
| *time_step += matching_tokens; | |
| return absl::OkStatus(); | |
| }; | |
| } // namespace litert::lm | |