File size: 5,320 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/components/sentencepiece_tokenizer.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/strings/str_cat.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "runtime/components/tokenizer.h"
#include "sentencepiece_model.pb.h"  // from @sentencepiece
#include "sentencepiece_processor.h"  // from @sentencepiece

namespace litert::lm {

absl::StatusOr<std::unique_ptr<SentencePieceTokenizer>>
SentencePieceTokenizer::CreateFromFile(absl::string_view model_path) {
  // Build the underlying SentencePiece processor from a serialized model
  // file on disk; any load failure is surfaced to the caller unchanged.
  auto sp_processor =
      std::make_unique<sentencepiece::SentencePieceProcessor>();
  if (const auto load_status = sp_processor->Load(model_path);
      !load_status.ok()) {
    return load_status;
  }
  // The tokenizer constructor is not public, so wrap the raw `new` result.
  return absl::WrapUnique(
      new SentencePieceTokenizer(std::move(sp_processor)));
}

absl::StatusOr<std::unique_ptr<SentencePieceTokenizer>>
SentencePieceTokenizer::CreateFromBuffer(absl::string_view model_buffer) {
  // Build the underlying SentencePiece processor from an in-memory
  // serialized model proto; load failures propagate to the caller.
  auto sp_processor =
      std::make_unique<sentencepiece::SentencePieceProcessor>();
  if (const auto load_status =
          sp_processor->LoadFromSerializedProto(model_buffer);
      !load_status.ok()) {
    return load_status;
  }
  // The tokenizer constructor is not public, so wrap the raw `new` result.
  return absl::WrapUnique(
      new SentencePieceTokenizer(std::move(sp_processor)));
}

absl::StatusOr<std::unique_ptr<SentencePieceTokenizer>>
SentencePieceTokenizer::CreateFromProto(
    std::unique_ptr<sentencepiece::ModelProto> model_proto) {
  // Build the underlying SentencePiece processor directly from an already
  // parsed ModelProto, taking ownership of it.
  auto sp_processor =
      std::make_unique<sentencepiece::SentencePieceProcessor>();
  if (const auto load_status = sp_processor->Load(std::move(model_proto));
      !load_status.ok()) {
    return load_status;
  }
  // The tokenizer constructor is not public, so wrap the raw `new` result.
  return absl::WrapUnique(
      new SentencePieceTokenizer(std::move(sp_processor)));
}

// Encodes the given text into a vector of token ids.
absl::StatusOr<std::vector<int>> SentencePieceTokenizer::TextToTokenIds(
    absl::string_view text) {
  std::vector<int> token_ids;
  // Delegate tokenization to the SentencePiece processor; on failure the
  // processor's status is returned as-is.
  if (const auto encode_status = processor_->Encode(text, &token_ids);
      !encode_status.ok()) {
    return encode_status;
  }
  return token_ids;
}

absl::StatusOr<int> SentencePieceTokenizer::TokenToId(absl::string_view token) {
  // PieceToId maps any out-of-vocabulary piece to the unknown id, so an
  // unk result is reported to the caller as "not found".
  // NOTE(review): this also rejects a lookup of the unknown piece itself
  // (e.g. "<unk>") — confirm that is the intended contract.
  const int piece_id = processor_->PieceToId(token);
  if (piece_id != processor_->unk_id()) {
    return piece_id;
  }
  return absl::NotFoundError(absl::StrCat("Unknown token: ", token));
}

// Decodes the given vector of token ids into a string. Byte-level (BPE)
// tokens are buffered and decoded together so that multi-byte characters
// split across several ids round-trip correctly; returns DataLossError if
// the sequence ends in the middle of such a byte run.
absl::StatusOr<std::string> SentencePieceTokenizer::TokenIdsToText(
    const std::vector<int>& token_ids) {
  std::string text = "";
  // Pending byte tokens that cannot be decoded on their own yet; flushed
  // as a single chunk once a non-byte token (or end of input) is reached.
  std::vector<int> chunk_byte_token_ids;
  for (const auto& token_id : token_ids) {
    // Validate the id against [0, vocab_size_) before handing it to the
    // processor, which would not range-check it for us.
    if (token_id >= vocab_size_ || token_id < 0) {
      return absl::NotFoundError(
          absl::StrCat("Token id ", token_id,
                       " is out of range. Vocab size is ", vocab_size_));
    }
    if (processor_->IsByte(token_id)) {
      std::string decoded = processor_->DecodeIds({token_id});
      if (Tokenizer::HasBpeSuffix(decoded)) {
        // If the token is a partial BPE token, we need to wait for more tokens
        // to be decoded before we can decode it.
        chunk_byte_token_ids.push_back(token_id);
      } else {
        // If the token is a single byte or invalid/continuation byte and not
        // bundled with other tokens, decode it immediately.
        absl::StrAppend(&text, decoded);
      }
    } else {
      // If the token is not a byte token, decode the chunk of byte tokens and
      // clear buffer.
      if (!chunk_byte_token_ids.empty()) {
        absl::StrAppend(&text, processor_->DecodeIds(chunk_byte_token_ids));
        chunk_byte_token_ids.clear();
      }
      // We are forced to use IdToPiece to account for leading whitespace.
      // Otherwise, the normalizer (depending on the configuration) would
      // remove that which makes streaming decoding impossible.
      // e.g., [[change], [_volume]] -> "change volume" vs.
      //       [[change], [volume]] -> "changevolume"
      // NOTE(review): IdToPiece returns the raw piece text; presumably a
      // later stage (or the model's piece inventory) maps the whitespace
      // meta symbol to a space — confirm against callers.
      absl::StrAppend(&text, processor_->IdToPiece(token_id));
    }
  }
  // Flush any byte tokens still buffered at the end of the sequence.
  if (!chunk_byte_token_ids.empty()) {
    std::string decoded = processor_->DecodeIds(chunk_byte_token_ids);
    if (Tokenizer::HasBpeSuffix(decoded)) {
      // The trailing bytes still do not form a complete character, so the
      // caller must supply more ids before this text can be produced.
      return absl::DataLossError(
          "The set of token IDs passed to the tokenizer is part of a BPE "
          "sequence and needs more tokens to be decoded.");
    } else {
      absl::StrAppend(&text, decoded);
    }
  }
  return text;
}

// Returns the full vocabulary as a list of piece strings, in id order as
// stored in the model proto.
std::vector<std::string> SentencePieceTokenizer::GetTokens() const {
  const auto& pieces = processor_->model_proto().pieces();
  std::vector<std::string> tokens;
  // Reserve up front: vocabularies are commonly tens of thousands of
  // entries, so this avoids repeated reallocation during the copy.
  tokens.reserve(pieces.size());
  for (const auto& piece : pieces) {
    tokens.push_back(piece.piece());
  }
  return tokens;
}

}  // namespace litert::lm