telecomadm1145 committed on
Commit
b386a25
·
verified ·
1 Parent(s): 4da0f17

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +143 -0
README.md ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: keras-hub
3
+ ---
4
+
5
+ ```py
6
+ import tensorflow as tf
7
+ from tokenizers import Tokenizer as HFTokenizer # pip install tokenizers
8
+ from keras_hub.tokenizers import Tokenizer as KerasTokenizerBase
9
+
10
class HFRustTokenizerWrapper(KerasTokenizerBase):
    """Expose a Hugging Face `tokenizers.Tokenizer` through the keras-hub tokenizer API.

    The wrapped tokenizer runs in Python, so `tokenize` bridges into it with
    `tf.py_function`, while `detokenize` converts its input to Python lists
    and therefore needs eagerly available values.
    """

    def __init__(self, hf_tokenizer):
        """Wrap an existing tokenizer or load one from disk.

        Args:
            hf_tokenizer: either a `tokenizers.Tokenizer` instance
                (recommended) or a path to a `tokenizer.json` file that
                `Tokenizer.from_file` can load.
        """
        super().__init__()
        if isinstance(hf_tokenizer, str):
            # A string is treated as a path to a serialized tokenizer.
            self.tk = HFTokenizer.from_file(hf_tokenizer)
        else:
            # Anything else is assumed to be a ready-made tokenizers.Tokenizer.
            self.tk = hf_tokenizer

        self._dtype = "int32"

    def tokenize(self, inputs):
        """Convert strings to token ids without adding special tokens.

        Args:
            inputs: a string, or a 1-D batch of strings (anything
                `tf.convert_to_tensor(..., dtype=tf.string)` accepts).
                Higher-rank inputs are not supported.

        Returns:
            For a 1-D batch, a `tf.RaggedTensor` of int32 ids with shape
            `[batch, None]`. For a scalar string, the single row as a 1-D
            int32 `tf.Tensor`.
        """
        inputs = tf.convert_to_tensor(inputs, dtype=tf.string)

        # A rank-0 string's `.numpy()` is a plain `bytes` object; iterating
        # it would yield byte values rather than texts. Promote scalars to a
        # batch of one and strip the batch axis again before returning.
        is_scalar = inputs.shape.rank == 0
        if is_scalar:
            inputs = tf.expand_dims(inputs, 0)

        def _py_tokenize(x):
            # `x` is an eager string tensor here (we are inside py_function).
            texts = [
                s.decode("utf-8") if isinstance(s, (bytes, bytearray)) else str(s)
                for s in x.numpy()
            ]
            encodings = self.tk.encode_batch(texts, add_special_tokens=False)
            rows = [enc.ids for enc in encodings]

            # Build the ragged tensor from explicit row lengths so that an
            # empty batch still has ragged_rank 1 and matches `Tout` below.
            flat_ids = [token_id for row in rows for token_id in row]
            row_lengths = [len(row) for row in rows]
            return tf.RaggedTensor.from_row_lengths(
                values=tf.constant(flat_ids, dtype=tf.int32),
                row_lengths=tf.constant(row_lengths, dtype=tf.int64),
            )

        # tf.py_function can only return a Tensor / CompositeTensor, so the
        # ragged result is declared through a RaggedTensorSpec.
        ragged = tf.py_function(
            func=_py_tokenize,
            inp=[inputs],
            Tout=tf.RaggedTensorSpec(
                shape=[None, None],
                dtype=tf.int32,
                ragged_rank=1,
            ),
        )

        return ragged[0] if is_scalar else ragged

    def detokenize(self, inputs):
        """Decode token ids back to text, skipping special tokens.

        Args:
            inputs: a `tf.RaggedTensor`, a `tf.Tensor` (eager), or a Python
                list. Either one sequence of ids (`[1, 2, 3]`) or a batch of
                sequences.

        Returns:
            A string `tf.Tensor`. When exactly one text was decoded
            (a single sequence OR a batch of one) the result is a scalar;
            otherwise it is 1-D with one entry per sequence. An empty input
            yields an empty 1-D string tensor.
        """
        # Normalize to a Python list[list[int]].
        if isinstance(inputs, tf.RaggedTensor):
            ids_list = inputs.to_list()
        elif isinstance(inputs, tf.Tensor):
            # Possibly a fixed-length [batch, seq] tensor.
            ids_list = inputs.numpy().tolist()
        else:
            ids_list = inputs

        # A rank-0 tensor turns into a bare int; treat it as one sequence.
        if isinstance(ids_list, int):
            ids_list = [ids_list]

        # A single sequence such as [1, 2, 3] is wrapped into a batch of one.
        if ids_list and isinstance(ids_list[0], int):
            ids_list = [ids_list]

        # Decode one sequence at a time: `decode` is available on more
        # versions of the tokenizers library than `decode_batch`.
        texts = [
            self.tk.decode(ids, skip_special_tokens=True) for ids in ids_list
        ]

        # The dtype is explicit so an empty list becomes an empty *string*
        # tensor instead of the float32 default.
        if len(texts) == 1:
            return tf.convert_to_tensor(texts[0], dtype=tf.string)
        return tf.convert_to_tensor(texts, dtype=tf.string)

    def vocabulary_size(self):
        """Return the vocabulary size, or 0 if the tokenizer cannot report it."""
        try:
            return self.tk.get_vocab_size()
        except Exception:
            # Fallback: count the entries of the vocabulary mapping.
            try:
                return len(self.tk.get_vocab())
            except Exception:
                # Neither API is usable; report an empty vocabulary.
                return 0

    def id_to_token(self, id_):
        """Return the token string for `id_`; "" if the fallback lookup fails."""
        try:
            return self.tk.id_to_token(id_)
        except Exception:
            # Fallback: invert the token -> id vocabulary by hand.
            try:
                inverse_vocab = {v: k for k, v in self.tk.get_vocab().items()}
                return inverse_vocab.get(int(id_), "")
            except Exception:
                return ""

    def token_to_id(self, token):
        """Return the id for `token`; None if the fallback lookup fails."""
        try:
            return self.tk.token_to_id(token)
        except Exception:
            try:
                return self.tk.get_vocab().get(token, None)
            except Exception:
                return None

    @property
    def dtype(self):
        """Dtype of the token ids produced by `tokenize` (always `tf.int32`)."""
        return tf.int32
126
+
127
# Usage example: replace the tokenizer of a keras-hub causal LM with the
# Qwen3 tokenizer fetched from the Hugging Face Hub.
# NOTE: `gemma_lm` is not defined in this snippet; it is assumed to be a
# model with a `preprocessor` that was created earlier.
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# Token id marked as "endoftext" in the original snippet; it is reused for
# the start, end and padding roles.
ENDOFTEXT_ID = 151643

tokenizer_path = hf_hub_download(
    repo_id="Qwen/Qwen3-4B-Base",
    filename="tokenizer.json",
)
hf_tokenizer = Tokenizer.from_file(tokenizer_path)

wrapper = HFRustTokenizerWrapper(hf_tokenizer)
wrapper.start_token_id = wrapper.end_token_id = wrapper.pad_token_id = ENDOFTEXT_ID

# Install the wrapper and configure the preprocessor to append the end
# token without prepending a start token.
preprocessor = gemma_lm.preprocessor
preprocessor.tokenizer = wrapper
preprocessor.add_end_token = True
preprocessor.add_start_token = False
143
+ ```