File size: 5,872 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKENIZER_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKENIZER_H_

#include <cstddef>
#include <string>
#include <vector>

#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/strings/match.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "absl/types/span.h"  // from @com_google_absl
#include "litert/cc/litert_macros.h"  // from @litert
#include "litert/cc/litert_tensor_buffer.h"  // from @litert
#include "runtime/util/convert_tensor_buffer.h"

namespace litert::lm {

// A sequence of token ids produced or consumed by a tokenizer.
using TokenIds = std::vector<int>;

// Enum representing the type of tokenizer.
enum class TokenizerType {
  kUnspecified,
  kSentencePiece,
  kHuggingFace,
};

// Abstract interface for tokenizers that convert between text and token ids,
// plus static helpers for moving token ids in and out of TensorBuffers.
class Tokenizer {
 public:
  virtual ~Tokenizer() = default;

  // Returns the type of the tokenizer.
  virtual TokenizerType GetTokenizerType() const = 0;

  // Encodes the given input text to token ids. Includes tokenizer pre/post
  // processing.
  virtual absl::StatusOr<TokenIds> TextToTokenIds(absl::string_view text) = 0;

  // Converts a token string to its token id. This is a raw token look up,
  // without any tokenizer pre/post processing. The implementation is expected
  // to return absl::NotFoundError if the token is not found.
  virtual absl::StatusOr<int> TokenToId(absl::string_view token) = 0;

  // Helper function to convert a vector of token ids into a 1D
  // litert::TensorBuffer of shape [batch_size(==1), num_tokens].
  static absl::StatusOr<TensorBuffer> TokenIdsToTensorBuffer(
      const TokenIds& token_ids) {
    LITERT_ASSIGN_OR_RETURN(
        auto tensor,
        CopyToTensorBuffer<int>(absl::MakeConstSpan(token_ids),
                                {1, static_cast<int>(token_ids.size())}));
    return tensor;
  }

  // Decodes the given sequence of token ids into a string.
  // Returns absl::DataLossError if any of the tokens are part of an incomplete
  // BPE sequence.
  virtual absl::StatusOr<std::string> TokenIdsToText(
      const TokenIds& token_ids) = 0;

  // Returns the list of tokens in the tokenizer.
  virtual std::vector<std::string> GetTokens() const = 0;

  // Converts a tensor buffer of token ids into a vector of token ids. The
  // input is a 2D litert::TensorBuffer of shape [batch_size, decode_steps].
  static absl::StatusOr<std::vector<TokenIds>> TensorBufferToTokenIds(
      const TensorBuffer& token_ids_tensor) {
    LITERT_ASSIGN_OR_RETURN(auto tensor_type, token_ids_tensor.TensorType());
    auto dims = tensor_type.Layout().Dimensions();
    if (dims.size() != 2) {
      return absl::InvalidArgumentError(
          "The input tensor must have 2 dimensions.");
    }
    // Propagate a copy failure as a status instead of unconditionally
    // dereferencing the result (the previous code called .Value() without
    // checking for an error).
    LITERT_ASSIGN_OR_RETURN(auto token_ids,
                            CopyFromTensorBuffer2D<int>(token_ids_tensor));
    return token_ids;
  }

  // Merges the previous and next token ids, by appending each next token
  // id row to the corresponding previous token id row.
  // Returns absl::InvalidArgumentError if the two inputs have a different
  // number of rows.
  static absl::StatusOr<std::vector<TokenIds>> MergeTokenIds(
      const std::vector<TokenIds>& previous_token_ids,
      const std::vector<TokenIds>& next_token_ids) {
    // Validate before allocating the output.
    if (previous_token_ids.size() != next_token_ids.size()) {
      return absl::InvalidArgumentError(
          "The previous and next token ids must have the same size.");
    }
    std::vector<TokenIds> merged_token_ids(next_token_ids.size());
    for (size_t i = 0; i < previous_token_ids.size(); ++i) {
      merged_token_ids[i] = previous_token_ids[i];
      merged_token_ids[i].insert(merged_token_ids[i].end(),
                                 next_token_ids[i].begin(),
                                 next_token_ids[i].end());
    }
    return merged_token_ids;
  }

  // Decodes the given batch of token id sequences into strings. The input is a
  // 2D vector of token ids, each row of which is one sequence. The output is a
  // vector of per-row results: either the decoded string, or
  // absl::DataLossError if that row ends in an incomplete BPE sequence.
  absl::StatusOr<std::vector<absl::StatusOr<std::string>>> TokenIdsToTexts(
      int batch_size, const std::vector<TokenIds>& token_ids) {
    // Cast avoids a signed/unsigned comparison; a negative batch_size can
    // never match a vector size, so it fails the check as intended.
    if (token_ids.size() != static_cast<size_t>(batch_size)) {
      return absl::InvalidArgumentError(
          "The token ID vector must have the same number of rows as the batch "
          "size.");
    }
    std::vector<absl::StatusOr<std::string>> decoded_strings(batch_size);
    for (int i = 0; i < batch_size; ++i) {
      decoded_strings[i] = this->TokenIdsToText(token_ids[i]);
    }
    return decoded_strings;
  }

  // Returns true if `result` failed with kDataLoss, i.e. the decoded token
  // ids were part of an incomplete BPE sequence (see TokenIdsToText).
  template <typename T>
  static bool IsIncompleteBpeSequence(const absl::StatusOr<T>& result) {
    return result.status().code() == absl::StatusCode::kDataLoss;
  }

  // Checks if the decoded string ends with the replacement character (U+FFFD),
  // which indicates that the set of token IDs passed to the tokenizer is part
  // of a BPE sequence and needs more tokens to be decoded.
  static bool HasBpeSuffix(absl::string_view decoded) {
    // UTF-8 encoding of U+FFFD REPLACEMENT CHARACTER.
    static constexpr char kReplacementCharacter[] = "\xef\xbf\xbd";
    return absl::EndsWith(decoded, kReplacementCharacter);
  }
};

}  // namespace litert::lm

#endif  // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKENIZER_H_