File size: 4,780 Bytes
b386a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
---
library_name: keras-hub
---

```py
import tensorflow as tf
from tokenizers import Tokenizer as HFTokenizer  # pip install tokenizers
from keras_hub.tokenizers import Tokenizer as KerasTokenizerBase

class HFRustTokenizerWrapper(KerasTokenizerBase):
    """Adapt a Rust-backed HuggingFace ``tokenizers.Tokenizer`` to the
    keras-hub tokenizer interface.

    Implements ``tokenize`` / ``detokenize`` plus the small vocabulary
    accessors keras-hub preprocessors call; everything else falls through
    to the keras-hub base class.
    """

    def __init__(self, hf_tokenizer):
        """Wrap an HF tokenizer.

        Args:
            hf_tokenizer: either a ``tokenizers.Tokenizer`` instance
                (recommended) or a path to a ``tokenizer.json`` file that
                ``Tokenizer.from_file`` can load.
        """
        # NOTE: the docstring above was previously placed *after* the
        # super().__init__() call, which made it a dead string statement
        # rather than the method docstring.
        super().__init__()
        # If a path string was passed in, load the tokenizer from file.
        if isinstance(hf_tokenizer, str):
            self.tk = HFTokenizer.from_file(hf_tokenizer)
        else:
            # Assume an already-constructed tokenizers.Tokenizer.
            self.tk = hf_tokenizer

        # NOTE(review): kept for compatibility, but the `dtype` property
        # below returns tf.int32 directly and never reads this attribute.
        self._dtype = "int32"

    def tokenize(self, inputs):
        """Tokenize a batch (or single) string tensor.

        Args:
            inputs: tf.Tensor(dtype=string), shape [batch] or scalar.

        Returns:
            tf.RaggedTensor(dtype=int32) of shape [batch, None]. A scalar
            input is treated as a batch of one.
        """
        inputs = tf.convert_to_tensor(inputs, dtype=tf.string)
        # A rank-0 string tensor would decode to a single `bytes` object
        # inside the py_function, and iterating bytes yields ints — so
        # promote scalars to a batch of one up front.
        if inputs.shape.rank == 0:
            inputs = tf.expand_dims(inputs, 0)

        def _py_tokenize(x):
            # x: tf.Tensor[string] in eager context (inside py_function).
            arr = x.numpy()
            texts = [
                s.decode("utf-8") if isinstance(s, (bytes, bytearray)) else str(s)
                for s in arr
            ]

            encs = self.tk.encode_batch(texts, add_special_tokens=False)
            ids = [enc.ids for enc in encs]

            # Return the RaggedTensor directly; py_function can return
            # composite tensors when Tout is a RaggedTensorSpec.
            return tf.ragged.constant(ids, dtype=tf.int32)

        # tf.py_function may only return Tensors / CompositeTensors.
        ragged = tf.py_function(
            func=_py_tokenize,
            inp=[inputs],
            Tout=tf.RaggedTensorSpec(
                shape=[None, None],
                dtype=tf.int32,
                ragged_rank=1,
            ),
        )
        return ragged

    def detokenize(self, inputs):
        """Decode token ids back to text.

        Args:
            inputs: RaggedTensor / Tensor / plain python list of id lists
                (or a single flat list of ids).

        Returns:
            tf.Tensor(dtype=string) — batch of decoded strings, or a
            scalar string tensor when a single sequence was given.
        """
        # Normalize to a python list[list[int]].
        # NOTE(review): the .numpy()/.to_list() calls require eager
        # execution; this method will fail inside a tf.function — confirm
        # callers only invoke it eagerly.
        if isinstance(inputs, tf.RaggedTensor):
            ids_list = inputs.to_list()
        elif isinstance(inputs, tf.Tensor):
            # Possibly a fixed-length [batch, seq] tensor.
            ids_list = inputs.numpy().tolist()
        else:
            ids_list = inputs

        # If a single flat id sequence was passed (like [1, 2, 3]),
        # wrap it into a batch of one.
        if ids_list and isinstance(ids_list[0], int):
            ids_list = [ids_list]

        texts = []
        for ids in ids_list:
            # tokenizers.Tokenizer provides decode(ids). Some versions
            # also offer decode_batch, but a loop is more portable.
            texts.append(self.tk.decode(ids, skip_special_tokens=True))

        # For a single input, return a scalar string tensor to match the
        # original keras-hub tokenizer behavior more closely.
        if len(texts) == 1:
            return tf.convert_to_tensor(texts[0])
        return tf.convert_to_tensor(texts)

    def vocabulary_size(self):
        """Return the vocab size, tolerating older tokenizers APIs; 0 if
        no vocabulary accessor is available."""
        try:
            return self.tk.get_vocab_size()
        except Exception:
            # Fallback for versions without get_vocab_size().
            try:
                return len(self.tk.get_vocab())
            except Exception:
                # Neither accessor available.
                return 0

    def id_to_token(self, id_):
        """Map a token id to its string; '' when unknown/unavailable."""
        try:
            return self.tk.id_to_token(id_)
        except Exception:
            # Some API versions only expose the forward vocab; invert it
            # manually to look the id up.
            try:
                inv = {v: k for k, v in self.tk.get_vocab().items()}
                return inv.get(int(id_), "")
            except Exception:
                return ""

    def token_to_id(self, token):
        """Map a token string to its id; None when unknown/unavailable."""
        try:
            return self.tk.token_to_id(token)
        except Exception:
            try:
                return self.tk.get_vocab().get(token, None)
            except Exception:
                return None

    @property
    def dtype(self):
        # Token ids are always emitted as int32 (see tokenize()).
        return tf.int32

from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# Download only the tokenizer.json file from the Qwen3-4B-Base repo on
# the HuggingFace Hub (no model weights are fetched).
tokenizer_path = hf_hub_download(
    repo_id="Qwen/Qwen3-4B-Base",
    filename="tokenizer.json",
)

# Load the Rust-backed tokenizer and wrap it for keras-hub.
hf_tokenizer = Tokenizer.from_file(tokenizer_path)
wrapper = HFRustTokenizerWrapper(hf_tokenizer)
# Qwen3 uses the same <|endoftext|> token (id 151643) for start, end,
# and padding alike.
wrapper.start_token_id = 151643 # endoftext
wrapper.end_token_id = 151643
wrapper.pad_token_id = 151643
# NOTE(review): `gemma_lm` is not defined in this snippet — it is assumed
# to be a keras-hub causal LM created earlier; verify before running.
gemma_lm.preprocessor.tokenizer = wrapper
gemma_lm.preprocessor.add_end_token = True
gemma_lm.preprocessor.add_start_token = False