Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| absl::Status PreprocessTokenIds(std::vector<int>& token_ids, int start_token_id, | |
| int max_num_tokens, | |
| float context_length_ratio_threhold) { | |
| if (token_ids.size() + 1 > max_num_tokens * context_length_ratio_threhold) { | |
| return absl::InvalidArgumentError(absl::StrFormat( | |
| "The input context length is too long. The input token length is %d " | |
| "and the max_num_tokens is %d.", | |
| token_ids.size() + 1, max_num_tokens)); | |
| } | |
| // Prepend the start token id to the token ids. | |
| token_ids.insert(token_ids.begin(), start_token_id); | |
| return absl::OkStatus(); | |
| } | |
| absl::StatusOr<bool> StopTokenFound(absl::Span<const int> decoded_token_ids, | |
| const std::vector<int>& stop_token_ids, | |
| std::vector<bool>& stop_token_found) { | |
| if (decoded_token_ids.size() != stop_token_found.size()) { | |
| return absl::InvalidArgumentError(absl::StrFormat( | |
| "The size of the decoded token ids is %d and the size of the stop " | |
| "token found vector is %d. They should be the same.", | |
| decoded_token_ids.size(), stop_token_found.size())); | |
| } | |
| for (int i = 0; i < decoded_token_ids.size(); ++i) { | |
| if (stop_token_found[i]) { | |
| continue; | |
| } | |
| for (const auto& stop_token_id : stop_token_ids) { | |
| if (decoded_token_ids[i] == stop_token_id) { | |
| stop_token_found[i] = true; | |
| break; | |
| } | |
| } | |
| } | |
| return std::all_of( | |
| stop_token_found.begin(), stop_token_found.end(), | |
| [](int stop_token_found) { return stop_token_found == true; }); | |
| } | |
| } // namespace litert::lm | |