import json
import os
from typing import Dict, List, Optional, Tuple

from transformers import AutoTokenizer, PreTrainedTokenizer

class STLTokenizer(PreTrainedTokenizer):
    """Greedy longest-match tokenizer backed by a JSON vocabulary for STL strings."""

    model_type = "stl_encoder"

    def __init__(
        self,
        vocab_file="vocab.json",
        unk_token="unk",
        pad_token="pad",
        bos_token="/s",
        eos_token="s",
        model_max_length=512,
        **kwargs,
    ):
        # Look for the vocabulary next to this file first, then fall back to
        # downloading it from the Hub, and finally to treating it as a raw path.
        current_dir = os.path.dirname(__file__)
        full_vocab_path = os.path.join(current_dir, vocab_file)
        if not os.path.exists(full_vocab_path):
            from huggingface_hub import hf_hub_download  # lazy import

            try:
                full_vocab_path = hf_hub_download("saracandu/stlenc", vocab_file)
            except Exception:
                full_vocab_path = vocab_file
        with open(full_vocab_path, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        self.id_to_token = {v: k for k, v in self.vocab.items()}
        # The base class queries vocab_size/get_vocab during initialization,
        # so the vocabulary must be loaded before calling super().__init__().
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            model_max_length=model_max_length,
            **kwargs,
        )
    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        return dict(self.vocab)
    def _tokenize(self, text: str) -> List[str]:
        # Wrap the text in bos/eos markers and encode spaces as '@',
        # matching the convention used by the vocabulary.
        text = f"{self.bos_token} {text} {self.eos_token}".replace(" ", "@")
        tokens = []
        i = 0
        while i < len(text):
            # Greedy longest-match-first: try the longest candidate substring
            # (capped at 50 characters) and shrink until a vocab entry matches.
            best_match = None
            for j in range(min(i + 50, len(text)), i, -1):
                subtoken = text[i:j]
                if subtoken in self.vocab:
                    best_match = subtoken
                    break
            if best_match:
                tokens.append(best_match)
                i += len(best_match)
            else:
                # No vocab entry starts here; emit unk and skip one character.
                tokens.append(self.unk_token)
                i += 1
        return tokens
    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        return self.id_to_token.get(index, self.unk_token)
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix if filename_prefix is not None else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, indent=2, ensure_ascii=False)
        return (vocab_file,)


# Register the tokenizer so AutoTokenizer can resolve the "stl_encoder" model
# type; ignore failures such as the type already being registered.
try:
    AutoTokenizer.register("stl_encoder", STLTokenizer)
except Exception:
    pass
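

# --- Usage sketch (illustrative addition, not part of the original module) ---
# Exercises the greedy longest-match tokenizer directly via _tokenize, assuming
# vocab.json is available locally or from the saracandu/stlenc Hub repo; the
# formula below is a hypothetical example input.
if __name__ == "__main__":
    tok = STLTokenizer()
    tokens = tok._tokenize("G ( x <= 0.5 )")
    print(tokens)                             # greedy longest-match pieces
    print(tok.convert_tokens_to_ids(tokens))  # ids mapped through vocab.json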