import json
import os
from typing import Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding

# Default vocabulary. The special-token spellings follow the ESM /
# MSA-Transformer alphabet layout (four leading specials, one trailing mask).
# NOTE(review): confirm these spellings against the checkpoint's vocab.json.
_VOCAB = {
    "<cls>": 0,
    "<pad>": 1,
    "<eos>": 2,
    "<unk>": 3,
    "A": 4,
    "G": 5,
    "C": 6,
    "U": 7,
    "X": 8,
    "N": 9,
    "-": 10,
    "<mask>": 11,
}


class RNAMSMTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for RNA-MSM.

    Vocabulary:
        <cls>(0) <pad>(1) <eos>(2) <unk>(3)
        A(4) G(5) C(6) U(7) X(8) N(9) -(10) <mask>(11)

    RNA-MSM is an MSA Transformer: it always expects 3D input
    (batch, num_alignments, seqlen). This tokenizer treats each input
    string as a single-sequence MSA (1 alignment row), so the standard
    __call__ API:

        enc = tokenizer(["AGCU", "GAUC"], return_tensors="pt", padding=True)
        # enc.input_ids: (2, 1, T) -- batch of 2 single-sequence MSAs

    For real MSAs (multiple aligned sequences), use encode_msa():

        enc = tokenizer.encode_msa([["AGCU--", "AGCUUU"]], return_tensors="pt")
        # enc["input_ids"]: (1, 2, T) -- 1 MSA with 2 alignment rows
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        cls_token="<cls>",
        pad_token="<pad>",
        eos_token="<eos>",
        unk_token="<unk>",
        mask_token="<mask>",
        **kwargs,
    ):
        """
        Build the tokenizer.

        vocab_file: optional path to a JSON token->id mapping. When it is
            missing or not a file, the built-in ``_VOCAB`` is used.
        cls_token / pad_token / eos_token / unk_token / mask_token:
            special-token strings forwarded to ``PreTrainedTokenizer``.
        """
        if vocab_file and os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = dict(_VOCAB)
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        # The vocabulary must exist before the base initializer runs: the
        # base class may query get_vocab() while registering special tokens.
        super().__init__(
            cls_token=cls_token,
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the base vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token->id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text):
        """Split text into single-character tokens."""
        return list(text)

    def _convert_token_to_id(self, token):
        """Map a token to its id, falling back to the <unk> id."""
        return self._vocab.get(token, self._vocab["<unk>"])

    def _convert_id_to_token(self, index):
        """Map an id to its token, falling back to "<unk>" for unknown ids."""
        return self._ids_to_tokens.get(index, "<unk>")

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Write the vocabulary as JSON into save_directory.

        Returns a 1-tuple holding the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix + "-" if filename_prefix else ""
        path = os.path.join(save_directory, prefix + "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2)
        return (path,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Prepend the cls id to each sequence (no eos is appended)."""
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0
        return cls + token_ids_0 + cls + token_ids_1

    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
    ):
        """Return 1 for special-token positions and 0 for sequence tokens."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True
            )
        mask = [1] + [0] * len(token_ids_0)
        if token_ids_1 is not None:
            mask += [1] + [0] * len(token_ids_1)
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Return all-zero token type ids.

        The length matches build_inputs_with_special_tokens(): one leading
        cls position per sequence plus the sequence tokens themselves.
        """
        type_ids = [0] * (1 + len(token_ids_0))
        if token_ids_1 is not None:
            type_ids += [0] * (1 + len(token_ids_1))
        return type_ids

    def _as_batch_encoding(self, input_ids, attention_mask, return_tensors):
        """Wrap nested id/mask lists in a BatchEncoding, as tensors for "pt"."""
        if return_tensors == "pt":
            try:
                input_ids = torch.tensor(input_ids, dtype=torch.long)
                attention_mask = torch.tensor(attention_mask, dtype=torch.long)
            except ValueError as err:
                # torch.tensor rejects ragged nested lists.
                raise ValueError(
                    "Cannot build a tensor from inputs of unequal size; "
                    "pass padding=True."
                ) from err
        return BatchEncoding(
            {"input_ids": input_ids, "attention_mask": attention_mask}
        )

    def __call__(
        self,
        text,
        text_pair=None,
        add_special_tokens=True,
        padding=False,
        truncation=False,
        max_length=None,
        return_tensors=None,
        **kwargs,
    ):
        """
        Tokenize one or more sequences, each treated as a 1-row MSA.

        text: str or List[str]
        padding: when truthy and more than one sequence is given, pad every
            sequence to the longest one in the batch.
        return_tensors: "pt" returns torch.long tensors; any other value
            returns nested Python lists.

        text_pair, truncation, max_length and extra keyword arguments are
        accepted for API compatibility but are not used.

        Returns dict with input_ids of shape (batch, 1, seqlen) and
        attention_mask of shape (batch, 1, seqlen).

        Raises ValueError when return_tensors == "pt" and the sequences
        have different lengths without padding.
        """
        sequences = [text] if isinstance(text, str) else list(text)
        encoded = [
            self._tokenize_single(seq, add_special_tokens) for seq in sequences
        ]
        # Real lengths are recorded before padding so that the mask does not
        # depend on comparing token ids with the pad id.
        lengths = [len(ids) for ids in encoded]

        if padding and len(encoded) > 1:
            max_len = max(lengths)
            pad_id = self.pad_token_id
            encoded = [ids + [pad_id] * (max_len - len(ids)) for ids in encoded]

        input_ids = [[ids] for ids in encoded]
        attention_mask = [
            [[1] * n + [0] * (len(ids) - n)] for ids, n in zip(encoded, lengths)
        ]
        return self._as_batch_encoding(input_ids, attention_mask, return_tensors)

    def _tokenize_single(self, sequence, add_special_tokens=True):
        """Convert one sequence to ids, optionally prefixed with the cls id."""
        ids = [self._convert_token_to_id(t) for t in sequence]
        if add_special_tokens:
            ids = [self.cls_token_id] + ids
        return ids

    def encode_msa(
        self,
        msas,
        add_special_tokens=True,
        padding=False,
        return_tensors=None,
    ):
        """
        Tokenize a batch of MSAs.

        msas: List[List[str]]
            Each inner list is one MSA (multiple aligned sequences of
            equal length). A flat List[str] is treated as a single MSA.
            Equal length within an MSA is expected but not enforced here;
            with padding=True shorter rows are padded.
        padding: pad every row to the longest sequence in the batch and
            every MSA to the largest number of rows.

        Returns dict with:
            input_ids: (batch, max_alignments, max_seqlen)
            attention_mask: (batch, max_alignments, max_seqlen)

        An empty batch yields empty input_ids / attention_mask.
        Raises ValueError when return_tensors == "pt" and the result is
        ragged (unequal sizes without padding).
        """
        if not msas:
            return self._as_batch_encoding([], [], return_tensors)
        if isinstance(msas[0], str):
            msas = [msas]

        # Tokenize each sequence exactly once.
        tokenized = [
            [self._tokenize_single(seq, add_special_tokens) for seq in msa]
            for msa in msas
        ]
        max_rows = max(len(msa) for msa in tokenized)
        max_seqlen = max(
            (len(ids) for msa in tokenized for ids in msa), default=0
        )
        pad_id = self.pad_token_id

        batch_ids = []
        batch_mask = []
        for msa in tokenized:
            msa_ids = []
            msa_mask = []
            for ids in msa:
                pad_len = max_seqlen - len(ids) if padding else 0
                msa_mask.append([1] * len(ids) + [0] * pad_len)
                msa_ids.append(ids + [pad_id] * pad_len)
            if padding:
                # A fresh list per padding row avoids aliasing one shared
                # list object across rows and MSAs.
                for _ in range(max_rows - len(msa_ids)):
                    msa_ids.append([pad_id] * max_seqlen)
                    msa_mask.append([0] * max_seqlen)
            batch_ids.append(msa_ids)
            batch_mask.append(msa_mask)

        return self._as_batch_encoding(batch_ids, batch_mask, return_tensors)

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """
        Convert a flat sequence of ids back into a string.

        token_ids: a 1-D torch.Tensor, a list of ints, or a single int
            (including a 0-d tensor).
        skip_special_tokens: drop the cls/pad/eos/unk/mask tokens.
        Extra keyword arguments are accepted but not used.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # A bare id (or a 0-d tensor after tolist()) is a 1-token sequence.
            token_ids = [token_ids]
        tokens = [self._convert_id_to_token(i) for i in token_ids]
        if skip_special_tokens:
            special = {
                self.cls_token,
                self.pad_token,
                self.eos_token,
                self.unk_token,
                self.mask_token,
            }
            tokens = [t for t in tokens if t not in special]
        return "".join(tokens)

    def num_special_tokens_to_add(self, pair=False):
        """Number of tokens build_inputs_with_special_tokens() adds."""
        # One cls per sequence: a pair therefore receives two.
        return 2 if pair else 1