Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
// Alias for a sequence of token ids produced/consumed by a Tokenizer.
// `using` is preferred over `typedef` in modern C++ (the file already relies
// on C++11+ features such as `enum class` and `= default`).
using TokenIds = std::vector<int>;
// Enum representing the type of tokenizer.
enum class TokenizerType {
  // Tokenizer backend not specified / unknown.
  kUnspecified,
  // SentencePiece-based tokenizer.
  kSentencePiece,
  // HuggingFace-based tokenizer.
  kHuggingFace,
};
| class Tokenizer { | |
| public: | |
| virtual ~Tokenizer() = default; | |
| // Returns the type of the tokenizer. | |
| virtual TokenizerType GetTokenizerType() const = 0; | |
| // Encodes the given input text to token ids. Includes tokenizer pre/post | |
| // processing. | |
| virtual absl::StatusOr<TokenIds> TextToTokenIds(absl::string_view text) = 0; | |
| // Converts a token string to its token id. This is a raw token look up, | |
| // without any tokenizer pre/post processing. The implementation is expected | |
| // to return absl::NotFoundError if the token is not found. | |
| virtual absl::StatusOr<int> TokenToId(absl::string_view token) = 0; | |
| // Helper function to convert a vector of token ids into a 1D | |
| // litert::TensorBuffer of shape [batch_size(==1), num_tokens]. | |
| static absl::StatusOr<TensorBuffer> TokenIdsToTensorBuffer( | |
| const TokenIds& token_ids) { | |
| LITERT_ASSIGN_OR_RETURN( | |
| auto tensor, | |
| CopyToTensorBuffer<int>(absl::MakeConstSpan(token_ids), | |
| {1, static_cast<int>(token_ids.size())})); | |
| return tensor; | |
| } | |
| // Decodes the given sequence of token ids into a string. | |
| // Returns absl::DataLossError if any of the tokens are part of an incomplete | |
| // BPE sequence. | |
| virtual absl::StatusOr<std::string> TokenIdsToText( | |
| const TokenIds& token_ids) = 0; | |
| // Returns the list of tokens in the tokenizer. | |
| virtual std::vector<std::string> GetTokens() const = 0; | |
| // Converts a tensor buffer of token ids into a vector of token ids. The input | |
| // is a 2D litert::TensorBuffer shape [batch_size, decode_steps]. | |
| static absl::StatusOr<std::vector<TokenIds>> TensorBufferToTokenIds( | |
| const TensorBuffer& token_ids_tensor) { | |
| LITERT_ASSIGN_OR_RETURN(auto tensor_type, token_ids_tensor.TensorType()); | |
| auto dims = tensor_type.Layout().Dimensions(); | |
| if (dims.size() != 2) { | |
| return absl::InvalidArgumentError( | |
| "The input tensor must have 2 dimensions."); | |
| } | |
| auto token_ids = CopyFromTensorBuffer2D<int>(token_ids_tensor); | |
| return token_ids.Value(); | |
| } | |
| // Merges the previous and next token ids, by appending each next token | |
| // id to the corresponding previous token id row by row. | |
| static absl::StatusOr<std::vector<TokenIds>> MergeTokenIds( | |
| const std::vector<TokenIds>& previous_token_ids, | |
| const std::vector<TokenIds>& next_token_ids) { | |
| std::vector<TokenIds> merged_token_ids(next_token_ids.size()); | |
| if (previous_token_ids.size() != next_token_ids.size()) { | |
| return absl::InvalidArgumentError( | |
| "The previous and next token ids must have the same size."); | |
| } | |
| for (int i = 0; i < previous_token_ids.size(); ++i) { | |
| merged_token_ids[i] = previous_token_ids[i]; | |
| merged_token_ids[i].insert(merged_token_ids[i].end(), | |
| next_token_ids[i].begin(), | |
| next_token_ids[i].end()); | |
| } | |
| return merged_token_ids; | |
| } | |
| // Decodes the given sequence of token ids into a string. The input is a 2D | |
| // vector of token ids, each of which is a sequence of token ids. The output | |
| // Tokenizer is a vector of strings, each of which is a decoded string of the | |
| // corresponding batch or absl::DataLossError if an incomplete BPE sequence. | |
| absl::StatusOr<std::vector<absl::StatusOr<std::string>>> TokenIdsToTexts( | |
| int batch_size, const std::vector<TokenIds>& token_ids) { | |
| if (token_ids.size() != batch_size) { | |
| return absl::InvalidArgumentError( | |
| "The token ID vector must have the same number of rows as the batch " | |
| "size."); | |
| } | |
| std::vector<absl::StatusOr<std::string>> decoded_strings(batch_size); | |
| for (int i = 0; i < batch_size; ++i) { | |
| decoded_strings[i] = this->TokenIdsToText(token_ids[i]); | |
| } | |
| return decoded_strings; | |
| } | |
| template <typename T> | |
| static bool IsIncompleteBpeSequence(const absl::StatusOr<T>& result) { | |
| return result.status().code() == absl::StatusCode::kDataLoss; | |
| } | |
| // Checks if the decoded string ends with the replacement character (U+FFFD), | |
| // which indicates that the set of token IDs passed to the tokenizer is part | |
| // of a BPE sequence and needs more tokens to be decoded. | |
| static bool HasBpeSuffix(absl::string_view decoded) { | |
| static const char kReplacementCharacter[] = "\xef\xbf\xbd"; | |
| return absl::EndsWith(decoded, kReplacementCharacter); | |
| } | |
| }; | |
| } // namespace litert::lm | |