---
library_name: keras-hub
---
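
The snippet below wraps a Hugging Face `tokenizers` (Rust) tokenizer so it can stand in for a keras-hub tokenizer, then swaps it into an existing causal LM preprocessor.
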
```py
import tensorflow as tf

from tokenizers import Tokenizer as HFTokenizer  # pip install tokenizers
from keras_hub.tokenizers import Tokenizer as KerasTokenizerBase


class HFRustTokenizerWrapper(KerasTokenizerBase):

    def __init__(self, hf_tokenizer):
        """
        hf_tokenizer: either a tokenizers.Tokenizer instance (recommended)
        or a path to a tokenizer.json that Tokenizer.from_file can load.
        """
        super().__init__()
        # If a path string was passed in, load from the file.
        if isinstance(hf_tokenizer, str):
            self.tk = HFTokenizer.from_file(hf_tokenizer)
        else:
            # Assume an already constructed tokenizers.Tokenizer.
            self.tk = hf_tokenizer

        self._dtype = "int32"

    def tokenize(self, inputs):
        """
        inputs: tf.Tensor(dtype=string), shape [batch] or scalar
        return: tf.RaggedTensor(dtype=int32)
        """
        inputs = tf.convert_to_tensor(inputs, dtype=tf.string)
        # Promote a scalar string to a batch of one so the decoding loop
        # below always iterates over a 1-D array.
        if inputs.shape.rank == 0:
            inputs = tf.expand_dims(inputs, 0)

        def _py_tokenize(x):
            # x: tf.Tensor[string] in eager context (inside py_function)
            arr = x.numpy()
            texts = [
                s.decode("utf-8") if isinstance(s, (bytes, bytearray)) else str(s)
                for s in arr
            ]

            encs = self.tk.encode_batch(texts, add_special_tokens=False)
            ids = [enc.ids for enc in encs]

            # Return the RaggedTensor as a single composite value.
            return tf.ragged.constant(ids, dtype=tf.int32)

        # tf.py_function can only return Tensor / CompositeTensor values,
        # so declare the ragged output via a RaggedTensorSpec.
        ragged = tf.py_function(
            func=_py_tokenize,
            inp=[inputs],
            Tout=tf.RaggedTensorSpec(
                shape=[None, None],
                dtype=tf.int32,
                ragged_rank=1,
            ),
        )

        # Fix the static shape (downstream code sometimes complains otherwise).
        # ragged.set_shape([None, None])
        return ragged

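    # Note: because tokenize() goes through tf.py_function, it works both
    # eagerly and inside a tf.data pipeline (see the sketch after this
    # listing). Assumption: a TF 2.x recent enough to accept a
    # RaggedTensorSpec as Tout; older releases only allow dense outputs.
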
    def detokenize(self, inputs):
        """
        inputs: RaggedTensor / Tensor / list
        return: tf.Tensor(dtype=string); a batch of decoded strings, or a
        scalar if a single sequence was passed
        """
        # Normalize to a Python list[list[int]].
        if isinstance(inputs, tf.RaggedTensor):
            ids_list = inputs.to_list()
        elif isinstance(inputs, tf.Tensor):
            # May be a fixed-length [batch, seq] tensor.
            ids_list = inputs.numpy().tolist()
        else:
            ids_list = inputs

        # If a single sequence of ids was passed (like [1, 2, 3]), wrap it
        # into a batch.
        if ids_list and isinstance(ids_list[0], int):
            ids_list = [ids_list]

        texts = []
        for ids in ids_list:
            # tokenizers.Tokenizer provides decode(ids). Some versions also
            # offer decode_batch, but a loop is compatible with more of them.
            texts.append(self.tk.decode(ids, skip_special_tokens=True))

        # If the input was a single sequence, return a scalar string tensor
        # to stay closer to the original behavior.
        if len(texts) == 1:
            return tf.convert_to_tensor(texts[0])
        return tf.convert_to_tensor(texts)

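    # Note: detokenize() relies on .numpy() / .to_list(), so it must run
    # eagerly; it is meant for decoding generated ids, not for graph code.
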
    def vocabulary_size(self):
        # The tokenizers API provides get_vocab_size(), with
        # len(self.tk.get_vocab()) as a fallback.
        try:
            return self.tk.get_vocab_size()
        except Exception:
            try:
                return len(self.tk.get_vocab())
            except Exception:
                # If neither is available, return 0.
                return 0

    def id_to_token(self, id_):
        try:
            return self.tk.id_to_token(id_)
        except Exception:
            # Some API versions only expose the token_to_id direction, so
            # invert the vocab by hand.
            try:
                inv = {v: k for k, v in self.tk.get_vocab().items()}
                return inv.get(int(id_), "")
            except Exception:
                return ""

    def token_to_id(self, token):
        try:
            return self.tk.token_to_id(token)
        except Exception:
            try:
                return self.tk.get_vocab().get(token, None)
            except Exception:
                return None

    @property
    def dtype(self):
        return tf.int32

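# Usage: download the Qwen3 tokenizer.json from the Hub and swap the
# wrapper into an existing keras-hub causal LM preprocessor.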
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

tokenizer_path = hf_hub_download(
    repo_id="Qwen/Qwen3-4B-Base",
    filename="tokenizer.json",
)

hf_tokenizer = Tokenizer.from_file(tokenizer_path)
wrapper = HFRustTokenizerWrapper(hf_tokenizer)
wrapper.start_token_id = 151643  # <|endoftext|>
wrapper.end_token_id = 151643
wrapper.pad_token_id = 151643
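
# Assumption: `gemma_lm` is a keras-hub causal LM built earlier in the
# notebook; re-pointing its preprocessor at the wrapper makes preprocessing
# use the Qwen vocabulary (hence start/end/pad all set to id 151643).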
gemma_lm.preprocessor.tokenizer = wrapper
gemma_lm.preprocessor.add_end_token = True
gemma_lm.preprocessor.add_start_token = False
```
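
A quick eager smoke test of the wrapper (a sketch; it assumes the `Qwen/Qwen3-4B-Base` download above succeeded, and the sample strings are made up):

```py
batch = tf.constant(["Hello world", "KerasHub with a Rust tokenizer"])
token_ids = wrapper.tokenize(batch)     # tf.RaggedTensor, dtype=int32
print(token_ids.to_list())              # per-example id lists
print(wrapper.detokenize(token_ids))    # round-trip back to strings
print(wrapper.vocabulary_size())        # size of the Qwen vocab
```

Because `tokenize` goes through `tf.py_function`, the wrapper can also run inside a `tf.data` input pipeline; a minimal sketch with placeholder text:

```py
ds = tf.data.Dataset.from_tensor_slices(
    tf.constant(["first example", "a second, longer example"])
)
ds = ds.batch(2).map(wrapper.tokenize)
for ids in ds:
    print(ids)
```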