// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKENIZER_H_ #define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKENIZER_H_ #include #include #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/match.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl #include "litert/cc/litert_macros.h" // from @litert #include "litert/cc/litert_tensor_buffer.h" // from @litert #include "runtime/util/convert_tensor_buffer.h" namespace litert::lm { typedef std::vector TokenIds; // Enum representing the type of tokenizer. enum class TokenizerType { kUnspecified, kSentencePiece, kHuggingFace, }; class Tokenizer { public: virtual ~Tokenizer() = default; // Returns the type of the tokenizer. virtual TokenizerType GetTokenizerType() const = 0; // Encodes the given input text to token ids. Includes tokenizer pre/post // processing. virtual absl::StatusOr TextToTokenIds(absl::string_view text) = 0; // Converts a token string to its token id. This is a raw token look up, // without any tokenizer pre/post processing. The implementation is expected // to return absl::NotFoundError if the token is not found. virtual absl::StatusOr TokenToId(absl::string_view token) = 0; // Helper function to convert a vector of token ids into a 1D // litert::TensorBuffer of shape [batch_size(==1), num_tokens]. static absl::StatusOr TokenIdsToTensorBuffer( const TokenIds& token_ids) { LITERT_ASSIGN_OR_RETURN( auto tensor, CopyToTensorBuffer(absl::MakeConstSpan(token_ids), {1, static_cast(token_ids.size())})); return tensor; } // Decodes the given sequence of token ids into a string. // Returns absl::DataLossError if any of the tokens are part of an incomplete // BPE sequence. virtual absl::StatusOr TokenIdsToText( const TokenIds& token_ids) = 0; // Returns the list of tokens in the tokenizer. virtual std::vector GetTokens() const = 0; // Converts a tensor buffer of token ids into a vector of token ids. The input // is a 2D litert::TensorBuffer shape [batch_size, decode_steps]. static absl::StatusOr> TensorBufferToTokenIds( const TensorBuffer& token_ids_tensor) { LITERT_ASSIGN_OR_RETURN(auto tensor_type, token_ids_tensor.TensorType()); auto dims = tensor_type.Layout().Dimensions(); if (dims.size() != 2) { return absl::InvalidArgumentError( "The input tensor must have 2 dimensions."); } auto token_ids = CopyFromTensorBuffer2D(token_ids_tensor); return token_ids.Value(); } // Merges the previous and next token ids, by appending each next token // id to the corresponding previous token id row by row. static absl::StatusOr> MergeTokenIds( const std::vector& previous_token_ids, const std::vector& next_token_ids) { std::vector merged_token_ids(next_token_ids.size()); if (previous_token_ids.size() != next_token_ids.size()) { return absl::InvalidArgumentError( "The previous and next token ids must have the same size."); } for (int i = 0; i < previous_token_ids.size(); ++i) { merged_token_ids[i] = previous_token_ids[i]; merged_token_ids[i].insert(merged_token_ids[i].end(), next_token_ids[i].begin(), next_token_ids[i].end()); } return merged_token_ids; } // Decodes the given sequence of token ids into a string. The input is a 2D // vector of token ids, each of which is a sequence of token ids. The output // Tokenizer is a vector of strings, each of which is a decoded string of the // corresponding batch or absl::DataLossError if an incomplete BPE sequence. absl::StatusOr>> TokenIdsToTexts( int batch_size, const std::vector& token_ids) { if (token_ids.size() != batch_size) { return absl::InvalidArgumentError( "The token ID vector must have the same number of rows as the batch " "size."); } std::vector> decoded_strings(batch_size); for (int i = 0; i < batch_size; ++i) { decoded_strings[i] = this->TokenIdsToText(token_ids[i]); } return decoded_strings; } template static bool IsIncompleteBpeSequence(const absl::StatusOr& result) { return result.status().code() == absl::StatusCode::kDataLoss; } // Checks if the decoded string ends with the replacement character (U+FFFD), // which indicates that the set of token IDs passed to the tokenizer is part // of a BPE sequence and needs more tokens to be decoded. static bool HasBpeSuffix(absl::string_view decoded) { static const char kReplacementCharacter[] = "\xef\xbf\xbd"; return absl::EndsWith(decoded, kReplacementCharacter); } }; } // namespace litert::lm #endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKENIZER_H_