stlenc / tokenizer_stlenc.py
saracandu's picture
fix?
707b7ba verified
import json
import os
import torch
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer, AutoTokenizer
class STLTokenizer(PreTrainedTokenizer):
    """Greedy longest-match tokenizer over a fixed JSON vocabulary.

    Input text is wrapped as ``"{bos} {text} {eos}"``, every space is replaced
    with ``"@"``, and the result is segmented left to right.  At each position
    the longest vocabulary entry (at most ``_max_match_length`` characters)
    that matches is emitted.  A character that starts no vocabulary entry is
    emitted as the unknown token.
    """

    model_type = "stl_encoder"
    # Lets ``from_pretrained`` resolve ``vocab.json`` from the requested
    # repo/directory and pass the resolved local path as ``vocab_file``.
    vocab_files_names = {"vocab_file": "vocab.json"}
    # Longest substring (in characters) tried when looking for a vocab match.
    _max_match_length = 50

    def __init__(
        self,
        vocab_file="vocab.json",
        unk_token="unk",
        pad_token="pad",
        bos_token="/s",
        eos_token="s",
        model_max_length=512,
        **kwargs
    ):
        """Load the vocabulary and initialise the base tokenizer.

        Args:
            vocab_file: Path or file name of the JSON vocabulary
                (token -> id).  A relative name is first looked up next to
                this module, then downloaded from the ``saracandu/stlenc``
                hub repo, and finally opened as given.
            unk_token, pad_token, bos_token, eos_token: Special-token strings
                forwarded to ``PreTrainedTokenizer``.
            model_max_length: Maximum sequence length forwarded to the base.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer``.

        Raises:
            OSError: If the resolved vocabulary file cannot be opened.
            json.JSONDecodeError: If the vocabulary file is not valid JSON.
        """
        # NOTE(review): the default special-token strings are very short
        # (e.g. eos "s"); they may match inside ordinary text if the base
        # class splits input on special tokens -- confirm these are intended.
        if vocab_file is None:
            # ``from_pretrained`` passes None when the repo has no vocab file.
            vocab_file = "vocab.json"
        full_vocab_path = self._resolve_vocab_path(vocab_file)
        with open(full_vocab_path, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        self.id_to_token = {v: k for k, v in self.vocab.items()}
        # The vocab must exist before super().__init__, which may call
        # get_vocab() while registering special tokens.
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            model_max_length=model_max_length,
            **kwargs
        )

    @staticmethod
    def _resolve_vocab_path(vocab_file: str) -> str:
        """Return the path to open for ``vocab_file``.

        Order: next to this module, then a download from the hub repo, then
        ``vocab_file`` unchanged.  Absolute paths pass the first check as-is.
        """
        local_path = os.path.join(os.path.dirname(__file__), vocab_file)
        if os.path.exists(local_path):
            return local_path
        try:
            from huggingface_hub import hf_hub_download

            return hf_hub_download("saracandu/stlenc", vocab_file)
        except Exception:
            # Best effort: let open() report the problem on the raw path.
            return vocab_file

    @property
    def vocab_size(self) -> int:
        """Number of entries in the loaded vocabulary."""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a shallow copy of the token -> id mapping."""
        return dict(self.vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into vocabulary tokens by greedy longest match.

        The text is wrapped in the BOS/EOS strings and spaces become ``"@"``
        before matching, so returned tokens use ``"@"`` where spaces were.
        """
        text = f'{self.bos_token} {text} {self.eos_token}'.replace(' ', '@')
        vocab = self.vocab
        length = len(text)
        tokens = []
        i = 0
        while i < length:
            # Try the longest candidate first and shrink until one matches.
            for j in range(min(i + self._max_match_length, length), i, -1):
                if text[i:j] in vocab:
                    tokens.append(text[i:j])
                    i = j
                    break
            else:
                # No vocab entry starts here: emit unk and skip one character.
                tokens.append(self.unk_token)
                i += 1
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map ``token`` to its id, falling back to the unknown token's id."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        """Map ``index`` to its token, falling back to the unknown token."""
        return self.id_to_token.get(index, self.unk_token)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary as ``<prefix>vocab.json`` in ``save_directory``.

        The directory is created if needed.  Returns a 1-tuple with the path
        of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix if filename_prefix is not None else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, indent=2, ensure_ascii=False)
        return (vocab_file,)
try:
    # Best effort: register this tokenizer with AutoTokenizer under the
    # "stl_encoder" key.  A failure here is treated as non-fatal, but it is
    # logged at debug level rather than silently discarded.
    AutoTokenizer.register("stl_encoder", STLTokenizer)
except Exception as err:
    import logging

    logging.getLogger(__name__).debug(
        "Could not register STLTokenizer with AutoTokenizer: %s", err
    )