File size: 4,780 Bytes
b386a25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
---
library_name: keras-hub
---
```py
import tensorflow as tf
from tokenizers import Tokenizer as HFTokenizer # pip install tokenizers
from keras_hub.tokenizers import Tokenizer as KerasTokenizerBase
class HFRustTokenizerWrapper(KerasTokenizerBase):
    """Adapt a Hugging Face `tokenizers.Tokenizer` to the keras-hub Tokenizer API.

    Wraps a Rust-backed fast tokenizer so it can be dropped in where a
    keras-hub tokenizer is expected: `tokenize` returns an int32 RaggedTensor
    and `detokenize` returns decoded string tensors.
    """

    def __init__(self, hf_tokenizer):
        """Build the wrapper.

        Args:
            hf_tokenizer: either a tokenizers.Tokenizer instance (recommended)
                or a path to a tokenizer.json that Tokenizer.from_file can load.
        """
        super().__init__()
        # A string argument is treated as a path to a serialized tokenizer.json.
        if isinstance(hf_tokenizer, str):
            self.tk = HFTokenizer.from_file(hf_tokenizer)
        else:
            # Assume an already-constructed tokenizers.Tokenizer.
            self.tk = hf_tokenizer
        # NOTE(review): the base keras-hub Tokenizer manages dtype itself;
        # writing `_dtype` directly mirrors the original code — confirm it
        # does not fight the base class's own dtype bookkeeping.
        self._dtype = "int32"

    def tokenize(self, inputs):
        """Tokenize a batch (or a single scalar) of strings.

        Args:
            inputs: tf.Tensor(dtype=string), shape [batch] or scalar.

        Returns:
            tf.RaggedTensor(dtype=int32) of token ids; a scalar input is
            treated as a batch of one.
        """
        inputs = tf.convert_to_tensor(inputs, dtype=tf.string)

        def _py_tokenize(x):
            # x: tf.Tensor[string] in eager context (inside py_function).
            arr = x.numpy()
            # A rank-0 string tensor yields a bytes scalar; iterating it
            # would produce ints, so wrap it into a batch of one first.
            if getattr(arr, "ndim", 1) == 0:
                arr = [arr.item()]
            texts = [
                s.decode("utf-8") if isinstance(s, (bytes, bytearray)) else str(s)
                for s in arr
            ]
            encs = self.tk.encode_batch(texts, add_special_tokens=False)
            ids = [enc.ids for enc in encs]
            # tf.py_function may only return a Tensor / CompositeTensor,
            # so hand back the ragged result directly.
            return tf.ragged.constant(ids, dtype=tf.int32)

        # Declaring the RaggedTensorSpec as Tout gives downstream ops the
        # static [batch, None] shape without a separate set_shape call.
        ragged = tf.py_function(
            func=_py_tokenize,
            inp=[inputs],
            Tout=tf.RaggedTensorSpec(
                shape=[None, None],
                dtype=tf.int32,
                ragged_rank=1,
            ),
        )
        return ragged

    def detokenize(self, inputs):
        """Decode token ids back into text.

        Args:
            inputs: RaggedTensor / Tensor / list of token ids; a single flat
                list of ints is accepted and treated as a batch of one.

        Returns:
            tf.Tensor(dtype=string) — batch of decoded strings, or a scalar
            string tensor when exactly one sequence was supplied.
        """
        # Normalize to a python list[list[int]].
        if isinstance(inputs, tf.RaggedTensor):
            ids_list = inputs.to_list()
        elif isinstance(inputs, tf.Tensor):
            # Possibly a fixed-length [batch, seq] tensor (eager only —
            # .numpy() is unavailable in graph mode).
            ids_list = inputs.numpy().tolist()
        else:
            ids_list = inputs
        # A single sequence of ids (like [1, 2, 3]) is wrapped into a batch.
        if ids_list and isinstance(ids_list[0], int):
            ids_list = [ids_list]
        texts = []
        for ids in ids_list:
            # tokenizers.Tokenizer provides decode(ids); some versions also
            # have decode_batch, but a loop is compatible with more versions.
            texts.append(self.tk.decode(ids, skip_special_tokens=True))
        # A single input returns a scalar string tensor, matching the
        # original call-site expectations.
        if len(texts) == 1:
            return tf.convert_to_tensor(texts[0])
        # Pin dtype so an empty batch still yields a string tensor rather
        # than the float32 default that convert_to_tensor([]) would infer.
        return tf.convert_to_tensor(texts, dtype=tf.string)

    def vocabulary_size(self):
        """Return the vocab size, falling back across tokenizers API versions."""
        # The Tokenizers API offers get_vocab_size() or len(get_vocab()).
        try:
            return self.tk.get_vocab_size()
        except Exception:
            try:
                return len(self.tk.get_vocab())
            except Exception:
                # Neither API is available: report an empty vocabulary.
                return 0

    def id_to_token(self, id_):
        """Map a token id to its string, or "" when it cannot be resolved."""
        try:
            return self.tk.id_to_token(id_)
        except Exception:
            # Some API versions only expose token_to_id; invert the vocab
            # dict manually as a fallback.
            try:
                inv = {v: k for k, v in self.tk.get_vocab().items()}
                return inv.get(int(id_), "")
            except Exception:
                return ""

    def token_to_id(self, token):
        """Map a token string to its id, or None when it cannot be resolved."""
        try:
            return self.tk.token_to_id(token)
        except Exception:
            try:
                return self.tk.get_vocab().get(token, None)
            except Exception:
                return None

    @property
    def dtype(self):
        # NOTE(review): returns tf.int32 while __init__ stores the string
        # "int32" in self._dtype — callers comparing against either form
        # should be checked.
        return tf.int32
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# Fetch the Qwen3 fast-tokenizer definition from the Hub and wrap it so it
# can stand in for the model's original keras-hub tokenizer.
tokenizer_json = hf_hub_download(
    repo_id="Qwen/Qwen3-4B-Base",
    filename="tokenizer.json",
)
wrapper = HFRustTokenizerWrapper(Tokenizer.from_file(tokenizer_json))

# Qwen uses a single endoftext id (151643) for start, end, and pad alike.
for token_attr in ("start_token_id", "end_token_id", "pad_token_id"):
    setattr(wrapper, token_attr, 151643)

# NOTE(review): `gemma_lm` is defined elsewhere; its preprocessor is pointed
# at the HF-backed wrapper, emitting only an end token (no start token).
gemma_lm.preprocessor.tokenizer = wrapper
gemma_lm.preprocessor.add_end_token = True
gemma_lm.preprocessor.add_start_token = False
``` |