"""Tokenizer for the ``stl_encoder`` model, backed by a JSON vocabulary file."""

import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import AutoTokenizer, PreTrainedTokenizer


class STLTokenizer(PreTrainedTokenizer):
    """Greedy longest-match tokenizer over a fixed JSON vocabulary.

    The vocabulary is a JSON object mapping token strings to integer ids.
    Input text is wrapped with the BOS/EOS tokens, every space is replaced
    by ``'@'``, and the result is segmented by repeatedly taking the longest
    vocabulary entry that starts at the current position.
    """

    model_type = "stl_encoder"

    # Longest candidate substring (in characters) tried at each position by
    # ``_tokenize``; vocabulary entries longer than this are never matched.
    MAX_TOKEN_LENGTH = 50

    # NOTE(review): the special-token defaults below ("unk", "pad", "/s", "s")
    # look like angle-bracket tokens whose brackets were stripped -- confirm
    # against the shipped vocab.json before relying on them.
    def __init__(
        self,
        vocab_file="vocab.json",
        unk_token="unk",
        pad_token="pad",
        bos_token="/s",
        eos_token="s",
        model_max_length=512,
        **kwargs
    ):
        """Load the vocabulary and initialise the base tokenizer.

        The vocabulary path is resolved in this order:

        1. ``vocab_file`` relative to the directory containing this module;
        2. otherwise, ``vocab_file`` downloaded from the ``saracandu/stlenc``
           Hub repository;
        3. otherwise (download failed), ``vocab_file`` used as given.

        Args:
            vocab_file: Name or path of the JSON vocabulary file.
            unk_token: Token emitted for characters with no vocabulary match.
            pad_token: Padding token.
            bos_token: Token prepended to the text in ``_tokenize``.
            eos_token: Token appended to the text in ``_tokenize``.
            model_max_length: Forwarded to ``PreTrainedTokenizer``.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.

        Raises:
            OSError: If the resolved vocabulary file cannot be opened.
            json.JSONDecodeError: If the vocabulary file is not valid JSON.
        """
        current_dir = os.path.dirname(__file__)
        full_vocab_path = os.path.join(current_dir, vocab_file)

        if not os.path.exists(full_vocab_path):
            from huggingface_hub import hf_hub_download

            try:
                full_vocab_path = hf_hub_download("saracandu/stlenc", vocab_file)
            except Exception:
                # Best-effort fallback: treat ``vocab_file`` as a usable path.
                # ``Exception`` (not a bare ``except``) so that
                # KeyboardInterrupt / SystemExit still propagate.
                full_vocab_path = vocab_file

        with open(full_vocab_path, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)

        # Reverse mapping used by ``_convert_id_to_token``.
        self.id_to_token = {v: k for k, v in self.vocab.items()}

        # The vocabulary is assigned before the base initialiser runs, as in
        # the original statement order.
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            model_max_length=model_max_length,
            **kwargs
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the loaded vocabulary."""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a shallow copy of the token-to-id mapping."""
        return dict(self.vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into vocabulary tokens by greedy longest match.

        The text is first formatted as ``"<bos> <text> <eos>"`` and every
        space (including those inside ``text``) is replaced by ``'@'``. At
        each position the longest substring of at most ``MAX_TOKEN_LENGTH``
        characters found in the vocabulary is emitted; if none matches,
        ``unk_token`` is emitted and a single character is skipped.
        """
        text = f"{self.bos_token} {text} {self.eos_token}".replace(" ", "@")
        tokens = []
        i = 0
        while i < len(text):
            best_match = None
            # Try candidates from longest to shortest so the first hit wins.
            for j in range(min(i + self.MAX_TOKEN_LENGTH, len(text)), i, -1):
                subtoken = text[i:j]
                if subtoken in self.vocab:
                    best_match = subtoken
                    break
            if best_match:
                tokens.append(best_match)
                i += len(best_match)
            else:
                tokens.append(self.unk_token)
                i += 1
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or the id of ``unk_token`` if unknown."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or ``unk_token`` if unknown."""
        return self.id_to_token.get(index, self.unk_token)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Write the vocabulary as JSON into ``save_directory``.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional string prepended directly (with no
                separator) to ``"vocab.json"``.

        Returns:
            A one-element tuple containing the path of the written file.
        """
        # ``exist_ok=True`` avoids the check-then-create race of a separate
        # ``isdir`` test.
        os.makedirs(save_directory, exist_ok=True)

        prefix = filename_prefix if filename_prefix is not None else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, indent=2, ensure_ascii=False)

        return (vocab_file,)


# Best-effort registration with AutoTokenizer: any failure here is
# deliberately ignored so that importing this module never fails because of it.
try:
    AutoTokenizer.register("stl_encoder", STLTokenizer)
except Exception:
    pass