# RNA-MSM / tokenization_rnamsm.py
# Uploaded by Taykhoom with huggingface_hub (commit d0628b6).
import json
import os
from typing import Dict, List, Optional, Union
import torch
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding
_VOCAB = {
"<cls>": 0,
"<pad>": 1,
"<eos>": 2,
"<unk>": 3,
"A": 4,
"G": 5,
"C": 6,
"U": 7,
"X": 8,
"N": 9,
"-": 10,
"<mask>": 11,
}
class RNAMSMTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for RNA-MSM.

    Vocabulary: <cls>(0) <pad>(1) <eos>(2) <unk>(3) A(4) G(5) C(6) U(7) X(8) N(9) -(10) <mask>(11)

    RNA-MSM is an MSA Transformer: it always expects 3D input
    (batch, num_alignments, seqlen). This tokenizer treats each input string
    as a single-sequence MSA (1 alignment row), so with the standard __call__ API:

        enc = tokenizer(["AGCU", "GAUC"], return_tensors="pt", padding=True)
        # enc.input_ids: (2, 1, T) -- batch of 2 single-sequence MSAs

    For real MSAs (multiple aligned sequences), use encode_msa():

        enc = tokenizer.encode_msa([["AGCU--", "AGCUUU"]], return_tensors="pt")
        # enc["input_ids"]: (1, 2, T) -- 1 MSA with 2 alignment rows

    Tokenization is strictly character-level. Every character of an input
    string becomes one token, and characters missing from the vocabulary
    (including lowercase letters with the built-in vocabulary) map to <unk>.
    When special tokens are requested, a single <cls> is prepended per
    segment; no <eos> is appended.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        cls_token="<cls>",
        pad_token="<pad>",
        eos_token="<eos>",
        unk_token="<unk>",
        mask_token="<mask>",
        **kwargs,
    ):
        """
        Args:
            vocab_file: optional path to a JSON token->id mapping. If it is
                not given or does not exist, the built-in ``_VOCAB`` is used.
            cls_token, pad_token, eos_token, unk_token, mask_token: special
                token strings passed through to ``PreTrainedTokenizer``.
            **kwargs: forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # The vocabulary is set up before calling the base initializer,
        # presumably because the base class may query get_vocab()/vocab_size
        # while registering special tokens.
        if vocab_file and os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # Best-effort fallback: no usable vocab file -> built-in vocabulary.
            self._vocab = dict(_VOCAB)
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        super().__init__(
            cls_token=cls_token,
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of entries in the base (file or built-in) vocabulary."""
        return len(self._vocab)

    def get_vocab(self):
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text):
        """Split ``text`` into single characters (one token per character)."""
        return list(text)

    def _convert_token_to_id(self, token):
        """Map ``token`` to its id; tokens not in the vocabulary map to the ``<unk>`` id."""
        # Only look up "<unk>" on a miss, so a custom vocab without "<unk>"
        # still resolves known tokens instead of raising KeyError every call.
        token_id = self._vocab.get(token)
        if token_id is None:
            token_id = self._vocab["<unk>"]
        return token_id

    def _convert_id_to_token(self, index):
        """Map ``index`` back to its token; unknown ids decode as ``"<unk>"``."""
        return self._ids_to_tokens.get(index, "<unk>")

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Write the vocabulary as JSON to ``save_directory``.

        The file is named ``vocab.json``, or ``<filename_prefix>-vocab.json``
        when a prefix is given. The directory is created if needed.

        Returns:
            A 1-tuple holding the written file path.
        """
        os.makedirs(save_directory, exist_ok=True)
        fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        path = os.path.join(save_directory, fname)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2)
        return (path,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Add special tokens to already-converted ids.

        Single sequence: ``<cls> A``. Pair: ``<cls> A <cls> B``.
        """
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0
        return cls + token_ids_0 + cls + token_ids_1

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None,
                                already_has_special_tokens=False):
        """
        Return 1 for special-token positions and 0 for sequence tokens.

        When ids are given without special tokens, the mask describes the
        layout produced by ``build_inputs_with_special_tokens``.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True)
        mask = [1] + [0] * len(token_ids_0)
        if token_ids_1 is not None:
            mask += [1] + [0] * len(token_ids_1)
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Return all-zero token type ids with the same length as
        ``build_inputs_with_special_tokens(token_ids_0, token_ids_1)``.
        """
        # Previously this returned ``[0] + token_ids_0``, i.e. the token ids
        # themselves instead of zeros.
        return [0] * len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1))

    @staticmethod
    def _wants_padding(padding):
        """Return True if ``padding`` requests any padding ("do_not_pad" / False do not)."""
        return bool(padding) and padding != "do_not_pad"

    @staticmethod
    def _to_batch_encoding(input_ids, attention_mask, return_tensors):
        """Wrap nested id/mask lists in a BatchEncoding, converting to tensors on request."""
        data = {"input_ids": input_ids, "attention_mask": attention_mask}
        if return_tensors == "pt":
            # Explicit long dtype so ids are directly usable as embedding indices.
            return BatchEncoding(
                {key: torch.tensor(value, dtype=torch.long) for key, value in data.items()}
            )
        # None keeps Python lists; other values (e.g. "np") are converted by BatchEncoding.
        return BatchEncoding(data, tensor_type=return_tensors)

    def _encode_row(self, sequence, pair, add_special_tokens):
        """Encode one input (optionally with its pair) into a single row of ids."""
        if pair is None:
            return self._tokenize_single(sequence, add_special_tokens)
        first = self._tokenize_single(sequence, add_special_tokens=False)
        second = self._tokenize_single(pair, add_special_tokens=False)
        if add_special_tokens:
            return self.build_inputs_with_special_tokens(first, second)
        return first + second

    def __call__(
        self,
        text,
        text_pair=None,
        add_special_tokens=True,
        padding=False,
        truncation=False,
        max_length=None,
        return_tensors=None,
        **kwargs,
    ):
        """
        Tokenize one or more sequences, each treated as a 1-row MSA.

        Args:
            text: str or List[str]; each string is one sequence.
            text_pair: optional str or List[str] with the same count as
                ``text``. Each pair is encoded on one row as ``<cls> a <cls> b``
                (``a b`` without special tokens).
            add_special_tokens: prepend ``<cls>`` to each segment.
            padding: False / "do_not_pad" for no padding; True / "longest" to
                pad to the longest row in the batch; "max_length" to pad every
                row to ``max_length`` (falls back to longest if it is None).
            truncation: when truthy (and not "do_not_truncate") and
                ``max_length`` is set, every row is cut to at most
                ``max_length`` tokens (special tokens included).
            max_length: target length for truncation / "max_length" padding.
            return_tensors: None for nested lists, "pt" for torch.long tensors,
                or another tensor type understood by ``BatchEncoding``.
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            BatchEncoding with ``input_ids`` and ``attention_mask``, each of
            shape (batch, 1, seqlen). The mask is 1 for real tokens and 0 for
            padding.

        Raises:
            ValueError: if ``text_pair`` does not have as many entries as ``text``.
        """
        sequences = [text] if isinstance(text, str) else list(text)
        if text_pair is None:
            pairs = [None] * len(sequences)
        else:
            pairs = [text_pair] if isinstance(text_pair, str) else list(text_pair)
            if len(pairs) != len(sequences):
                raise ValueError(
                    f"text_pair has {len(pairs)} entries but text has {len(sequences)}"
                )

        encoded = [
            self._encode_row(seq, pair, add_special_tokens)
            for seq, pair in zip(sequences, pairs)
        ]

        if truncation and truncation != "do_not_truncate" and max_length is not None:
            encoded = [ids[:max_length] for ids in encoded]

        pad_to = None
        if self._wants_padding(padding):
            if padding == "max_length" and max_length is not None:
                pad_to = max_length
            else:
                pad_to = max((len(ids) for ids in encoded), default=0)

        pad_id = self.pad_token_id
        input_ids = []
        attention_mask = []
        for ids in encoded:
            # Rows already longer than pad_to are left as-is (n_pad == 0).
            n_pad = max(pad_to - len(ids), 0) if pad_to is not None else 0
            # Extra list level = the single alignment row of this MSA.
            input_ids.append([ids + [pad_id] * n_pad])
            # The mask comes from the true row length, not from comparing ids to pad_id.
            attention_mask.append([[1] * len(ids) + [0] * n_pad])
        return self._to_batch_encoding(input_ids, attention_mask, return_tensors)

    def _tokenize_single(self, sequence, add_special_tokens=True):
        """Convert one sequence to ids character by character, optionally prepending <cls>."""
        ids = [self._convert_token_to_id(t) for t in sequence]
        if add_special_tokens:
            ids = [self.cls_token_id] + ids
        return ids

    def encode_msa(
        self,
        msas,
        add_special_tokens=True,
        padding=False,
        return_tensors=None,
    ):
        """
        Tokenize a batch of MSAs.

        Args:
            msas: List[List[str]]. Each inner list is one MSA (aligned
                sequences, expected to share one length; this is not checked).
                A single MSA given as List[str] is wrapped into a batch of one.
            add_special_tokens: prepend ``<cls>`` to every row.
            padding: if truthy (and not "do_not_pad"), pad every row to the
                longest row in the batch and every MSA to the largest row
                count, using ``<pad>`` ids with mask 0.
            return_tensors: None for nested lists, "pt" for torch.long tensors,
                or another tensor type understood by ``BatchEncoding``.

        Returns:
            BatchEncoding with:
                input_ids: (batch, max_alignments, max_seqlen)
                attention_mask: (batch, max_alignments, max_seqlen)

        Raises:
            ValueError: if ``msas`` is empty.
        """
        if not msas:
            raise ValueError("encode_msa() requires at least one MSA")
        if isinstance(msas[0], str):
            msas = [msas]

        # Tokenize every row exactly once; padding sizes are derived from these.
        tokenized = [
            [self._tokenize_single(seq, add_special_tokens) for seq in msa]
            for msa in msas
        ]

        pad_id = self.pad_token_id
        do_pad = self._wants_padding(padding)
        if do_pad:
            max_rows = max(len(rows) for rows in tokenized)
            max_seqlen = max((len(ids) for rows in tokenized for ids in rows), default=0)

        batch_ids = []
        batch_mask = []
        for rows in tokenized:
            msa_ids = []
            msa_mask = []
            for ids in rows:
                pad_len = max_seqlen - len(ids) if do_pad else 0
                msa_ids.append(ids + [pad_id] * pad_len)
                msa_mask.append([1] * len(ids) + [0] * pad_len)
            if do_pad:
                # Fresh lists per padding row so rows never alias each other.
                for _ in range(max_rows - len(rows)):
                    msa_ids.append([pad_id] * max_seqlen)
                    msa_mask.append([0] * max_seqlen)
            batch_ids.append(msa_ids)
            batch_mask.append(msa_mask)
        return self._to_batch_encoding(batch_ids, batch_mask, return_tensors)

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """
        Convert a 1-D sequence of ids back to a string.

        Args:
            token_ids: an int, a list of ints, or an object with ``tolist()``
                (e.g. a torch tensor or numpy array) holding one row of ids.
            skip_special_tokens: drop <cls>/<pad>/<eos>/<unk>/<mask> tokens.
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            The tokens concatenated with no separator.
        """
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # A scalar id (or 0-d tensor after tolist()) decodes as one token.
            token_ids = [token_ids]
        tokens = [self._convert_id_to_token(i) for i in token_ids]
        if skip_special_tokens:
            special = {self.cls_token, self.pad_token, self.eos_token,
                       self.unk_token, self.mask_token}
            tokens = [t for t in tokens if t not in special]
        return "".join(tokens)

    def num_special_tokens_to_add(self, pair=False):
        """Number of special tokens build_inputs_with_special_tokens adds (1, or 2 for a pair)."""
        return 2 if pair else 1