# Source: RiNALMo-mega / tokenization_rinalmo.py
# Hugging Face Hub commit 7b96962 ("Upload folder using huggingface_hub") by Taykhoom
import json
import os
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
# Built-in RiNALMo vocabulary mapping token -> id; ids follow tuple order.
# Ids 0-4 are the special tokens, every later entry is a single character.
# NOTE(review): after A/C/G/T the letters look like inosine (I) and the IUPAC
# nucleotide ambiguity codes (R, Y, K, M, S, W, B, D, H, V, N), with "-" as a
# gap symbol -- confirm against the upstream RiNALMo alphabet.
_VOCAB = {
    token: index
    for index, token in enumerate(
        (
            # special tokens
            "<cls>",
            "<pad>",
            "<eos>",
            "<unk>",
            "<mask>",
            # nucleotides (U is mapped to T by the tokenizer)
            "A",
            "C",
            "G",
            "T",
            # remaining single-character symbols
            "I",
            "R",
            "Y",
            "K",
            "M",
            "S",
            "W",
            "B",
            "D",
            "H",
            "V",
            "N",
            "-",
        )
    )
}
class RiNALMoTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for RiNALMo. Character-level over a 22-token RNA alphabet.
    Converts U->T before tokenizing (the model was trained on T, not U).
    Wraps sequences as <cls> ... <eos>.

    Text is upper-cased and split into single characters, so whitespace and
    any character missing from the vocabulary become ``<unk>``. Because U is
    rewritten to T, decoding returns T where the input had U.

    A sequence pair is laid out as ``<cls> A <eos> <cls> B <eos>``; every
    token type id is 0.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        cls_token: str = "<cls>",
        pad_token: str = "<pad>",
        eos_token: str = "<eos>",
        unk_token: str = "<unk>",
        mask_token: str = "<mask>",
        **kwargs,
    ):
        """Load the vocabulary and register the special tokens.

        Args:
            vocab_file: Path to a JSON ``{token: id}`` mapping. If it is
                ``None`` or does not name an existing file, the built-in
                ``_VOCAB`` is used instead (silently).
            cls_token, pad_token, eos_token, unk_token, mask_token: Special
                token strings forwarded to ``PreTrainedTokenizer``.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer``.
        """
        # The vocabulary is set up before super().__init__() so the base
        # class can already use get_vocab()/vocab_size during its own setup.
        # NOTE(review): recent transformers releases rely on this ordering.
        if vocab_file is not None and os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = dict(_VOCAB)
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        super().__init__(
            cls_token=cls_token,
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Size of the base vocabulary (tokens added later are not counted)."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return the full token -> id mapping, including added tokens.

        ``transformers`` computes ``len(tokenizer)`` from this mapping, and it
        also computes the id for newly added tokens from it. Tokens registered
        through ``add_tokens`` must therefore appear here. Otherwise
        ``len(tokenizer)`` ignores them, and successive ``add_tokens`` calls
        hand out the same id.
        """
        return {**self._vocab, **self.added_tokens_encoder}

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into one token per character after upper-casing and U->T."""
        text = text.upper().replace("U", "T")
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        """Map ``token`` to its id, falling back to the id of ``"<unk>"``."""
        # Resolve "<unk>" only on a miss. An eager dict.get default would
        # raise KeyError for every token when a custom vocab lacks "<unk>".
        index = self._vocab.get(token)
        return self._vocab["<unk>"] if index is None else index

    def _convert_id_to_token(self, index: int) -> str:
        """Map ``index`` back to its token, or ``"<unk>"`` for unknown ids."""
        return self._ids_to_tokens.get(index, "<unk>")

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
        """Write the base vocabulary as JSON into ``save_directory``.

        Args:
            save_directory: Target directory; created if missing.
            filename_prefix: Optional prefix, joined to ``vocab.json`` with "-".

        Returns:
            A one-element tuple holding the written file path.
        """
        os.makedirs(save_directory, exist_ok=True)
        fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        path = os.path.join(save_directory, fname)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2)
        return (path,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Wrap one sequence as ``<cls> A <eos>``, or a pair as ``<cls> A <eos> <cls> B <eos>``."""
        cls = [self.cls_token_id]
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + eos
        return cls + token_ids_0 + eos + cls + token_ids_1 + eos

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return 1 for special-token positions and 0 for sequence tokens.

        When ``already_has_special_tokens`` is False, the mask matches the
        layout produced by ``build_inputs_with_special_tokens``. Otherwise the
        base class decides from the ids themselves.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(token_ids_0, token_ids_1, True)
        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask += [1] + [0] * len(token_ids_1) + [1]
        return mask

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        """Return all-zero token type ids covering the ``<cls>``/``<eos>``-wrapped input(s)."""
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + 2) + [0] * (len(token_ids_1) + 2)