---
library_name: keras-hub
---

Wrap a Hugging Face `tokenizers` (Rust) tokenizer so it can stand in for a KerasHub tokenizer, then attach the Qwen3 tokenizer to an existing KerasHub preprocessor:

```py
import tensorflow as tf
from tokenizers import Tokenizer as HFTokenizer  # pip install tokenizers
from keras_hub.tokenizers import Tokenizer as KerasTokenizerBase


class HFRustTokenizerWrapper(KerasTokenizerBase):
    def __init__(self, hf_tokenizer):
        """
        hf_tokenizer: either a tokenizers.Tokenizer instance (recommended)
            or a path to a tokenizer.json that Tokenizer.from_file can load.
        """
        super().__init__()
        # If a path string was passed, load the tokenizer from file.
        if isinstance(hf_tokenizer, str):
            self.tk = HFTokenizer.from_file(hf_tokenizer)
        else:
            # Otherwise assume an already-constructed tokenizers.Tokenizer.
            self.tk = hf_tokenizer
        self._dtype = "int32"

    def tokenize(self, inputs):
        """
        inputs: tf.Tensor(dtype=string), shape [batch] or scalar
        return: tf.RaggedTensor(dtype=int32)
        """
        inputs = tf.convert_to_tensor(inputs, dtype=tf.string)

        def _py_tokenize(x):
            # x: tf.Tensor[string] in eager context (inside py_function)
            arr = x.numpy()
            texts = [
                s.decode("utf-8") if isinstance(s, (bytes, bytearray)) else str(s)
                for s in arr
            ]
            encs = self.tk.encode_batch(texts, add_special_tokens=False)
            ids = [enc.ids for enc in encs]
            # Return the RaggedTensor components.
            return tf.ragged.constant(ids, dtype=tf.int32)

        # tf.py_function can only return a Tensor / CompositeTensor.
        ragged = tf.py_function(
            func=_py_tokenize,
            inp=[inputs],
            Tout=tf.RaggedTensorSpec(
                shape=[None, None],
                dtype=tf.int32,
                ragged_rank=1,
            ),
        )
        # Pin down the static shape (downstream code sometimes complains otherwise).
        # ragged.set_shape([None, None])
        return ragged

    def detokenize(self, inputs):
        """
        inputs: RaggedTensor / Tensor / list
        Returns: tf.Tensor(dtype=string), a batch of decoded strings,
            or a scalar string tensor if a single sequence was passed.
        """
        # Normalize to a Python list[list[int]].
        if isinstance(inputs, tf.RaggedTensor):
            ids_list = inputs.to_list()
        elif isinstance(inputs, tf.Tensor):
            # May be a fixed-length [batch, seq] tensor.
            ids_list = inputs.numpy().tolist()
        else:
            ids_list = inputs

        # If a single id sequence was passed (like [1, 2, 3]), wrap it into a batch.
        if ids_list and isinstance(ids_list[0], int):
            ids_list = [ids_list]

        texts = []
        for ids in ids_list:
            # tokenizers.Tokenizer provides decode(ids). Some versions also
            # offer decode_batch, but a loop is compatible with more versions.
            texts.append(self.tk.decode(ids, skip_special_tokens=True))

        # For a single input, return a scalar string tensor to stay close to
        # the original behavior.
        if len(texts) == 1:
            return tf.convert_to_tensor(texts[0])
        return tf.convert_to_tensor(texts)

    def vocabulary_size(self):
        # The tokenizers API offers get_vocab_size() or len(self.tk.get_vocab()).
        try:
            return self.tk.get_vocab_size()
        except Exception:
            # Fallback.
            try:
                return len(self.tk.get_vocab())
            except Exception:
                # If neither is available, return 0.
                return 0

    def id_to_token(self, id_):
        try:
            return self.tk.id_to_token(id_)
        except Exception:
            # Some API versions only expose token_to_id, so invert the
            # vocab manually.
            try:
                inv = {v: k for k, v in self.tk.get_vocab().items()}
                return inv.get(int(id_), "")
            except Exception:
                return ""

    def token_to_id(self, token):
        try:
            return self.tk.token_to_id(token)
        except Exception:
            try:
                return self.tk.get_vocab().get(token, None)
            except Exception:
                return None

    @property
    def dtype(self):
        return tf.int32


from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

tokenizer_path = hf_hub_download(
    repo_id="Qwen/Qwen3-4B-Base",
    filename="tokenizer.json",
)
hf_tokenizer = Tokenizer.from_file(tokenizer_path)
wrapper = HFRustTokenizerWrapper(hf_tokenizer)
wrapper.start_token_id = 151643  # <|endoftext|>
wrapper.end_token_id = 151643
wrapper.pad_token_id = 151643

# gemma_lm is assumed to be a previously loaded keras_hub causal LM whose
# preprocessor we patch to use the Qwen tokenizer.
gemma_lm.preprocessor.tokenizer = wrapper
gemma_lm.preprocessor.add_end_token = True
gemma_lm.preprocessor.add_start_token = False
```
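
A quick eager-mode round-trip check of the wrapper (a minimal sketch: it assumes the `wrapper` instance built above, and the example strings are arbitrary; the printed ids depend on the downloaded tokenizer rather than being fixed values):

```py
# Assumes `wrapper` from the snippet above; runs in eager mode.
batch = tf.constant(["Hello world", "KerasHub with a Rust tokenizer"])

ids = wrapper.tokenize(batch)        # tf.RaggedTensor of int32 ids
print(ids.to_list())                 # per-example id lists of varying length

decoded = wrapper.detokenize(ids)    # tf.Tensor of dtype string
print(decoded.numpy())               # should match the inputs (modulo normalization)

print(wrapper.vocabulary_size())     # vocab size reported by tokenizers
print(wrapper.id_to_token(151643))   # "<|endoftext|>" in Qwen vocabularies
```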