// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/components/sentencepiece_tokenizer.h" #include #include #include #include #include "absl/memory/memory.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/str_cat.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "runtime/components/tokenizer.h" #include "sentencepiece_model.pb.h" // from @sentencepiece #include "sentencepiece_processor.h" // from @sentencepiece namespace litert::lm { absl::StatusOr> SentencePieceTokenizer::CreateFromFile(absl::string_view model_path) { auto processor = std::make_unique(); auto status = processor->Load(model_path); if (!status.ok()) { return status; } return absl::WrapUnique(new SentencePieceTokenizer(std::move(processor))); } absl::StatusOr> SentencePieceTokenizer::CreateFromBuffer(absl::string_view model_buffer) { auto processor = std::make_unique(); auto status = processor->LoadFromSerializedProto(model_buffer); if (!status.ok()) { return status; } return absl::WrapUnique(new SentencePieceTokenizer(std::move(processor))); } absl::StatusOr> SentencePieceTokenizer::CreateFromProto( std::unique_ptr model_proto) { auto processor = std::make_unique(); auto status = processor->Load(std::move(model_proto)); if (!status.ok()) { return status; } return absl::WrapUnique(new SentencePieceTokenizer(std::move(processor))); } // Encodes the given text into a TensorBuffer of token ids. absl::StatusOr> SentencePieceTokenizer::TextToTokenIds( absl::string_view text) { std::vector ids; auto status = processor_->Encode(text, &ids); if (!status.ok()) { return status; } return ids; } absl::StatusOr SentencePieceTokenizer::TokenToId(absl::string_view token) { int id = processor_->PieceToId(token); if (id == processor_->unk_id()) { return absl::NotFoundError(absl::StrCat("Unknown token: ", token)); } return id; } // Decodes the given TensorBuffer of token ids into a string. absl::StatusOr SentencePieceTokenizer::TokenIdsToText( const std::vector& token_ids) { std::string text = ""; std::vector chunk_byte_token_ids; for (const auto& token_id : token_ids) { if (token_id >= vocab_size_ || token_id < 0) { return absl::NotFoundError( absl::StrCat("Token id ", token_id, " is out of range. Vocab size is ", vocab_size_)); } if (processor_->IsByte(token_id)) { std::string decoded = processor_->DecodeIds({token_id}); if (Tokenizer::HasBpeSuffix(decoded)) { // If the token is a partial BPE token, we need to wait for more tokens // to be decoded before we can decode it. chunk_byte_token_ids.push_back(token_id); } else { // If the token is a single byte or invalid/continuation byte and not // bundled with other tokens, decode it immediately. absl::StrAppend(&text, decoded); } } else { // If the token is not a byte token, decode the chunk of byte tokens and // clear buffer. if (!chunk_byte_token_ids.empty()) { absl::StrAppend(&text, processor_->DecodeIds(chunk_byte_token_ids)); chunk_byte_token_ids.clear(); } // We are forced to use IdToPiece to account for leading whitespace. // Otherwise, the normalizer (depending on the configuration) would // remove that which makes streaming decoding impossible. // e.g., [[change], [_volume]] -> "change volume" vs. // [[change], [volume]] -> "changevolume" absl::StrAppend(&text, processor_->IdToPiece(token_id)); } } if (!chunk_byte_token_ids.empty()) { std::string decoded = processor_->DecodeIds(chunk_byte_token_ids); if (Tokenizer::HasBpeSuffix(decoded)) { return absl::DataLossError( "The set of token IDs passed to the tokenizer is part of a BPE " "sequence and needs more tokens to be decoded."); } else { absl::StrAppend(&text, decoded); } } return text; } std::vector SentencePieceTokenizer::GetTokens() const { std::vector tokens; for (const auto& piece : processor_->model_proto().pieces()) { tokens.push_back(piece.piece()); } return tokens; } } // namespace litert::lm