import json
import os

from transformers import PreTrainedTokenizer

# NOTE(review): the special-token spellings used throughout this module
# (<cls>, <pad>, <eos>, <unk>, <null_1>..<null_4>, <mask>) were restored from
# empty-string placeholders, which made the vocabulary literals collapse onto
# a single "" key. Confirm the spellings against the released vocab.json.
_LEADING_SPECIAL_TOKENS = ("<cls>", "<pad>", "<eos>", "<unk>")
_TRAILING_SPECIAL_TOKENS = ("<null_1>", "<null_2>", "<null_3>", "<null_4>", "<mask>")

_NUCLEOTIDE_TOKENS = (
    "A", "C", "G", "U", "R", "Y", "K", "M",
    "S", "W", "B", "D", "H", "V", "N", "-",
)

_CODON_TOKENS = (
    "GAG", "AAG", "GAA", "CUG", "CAG", "GAU", "AAA", "GUG",
    "GAC", "AUG", "GCC", "AAC", "GCU", "AAU", "AUC", "UUC",
    "GGA", "AUU", "GGC", "UUU", "CCA", "AGC", "GCA", "UCU",
    "CUC", "ACC", "CAA", "CCU", "UCC", "ACA", "UUG", "GUU",
    "CUU", "UAC", "ACU", "CCC", "UCA", "GUC", "GGU", "CAC",
    "AGU", "UAU", "AGA", "CAU", "GGG", "UGG", "UGC", "AGG",
    "UGU", "AUA", "CGC", "UUA", "GCG", "CGG", "CCG", "GUA",
    "CUA", "ACG", "UCG", "CGA", "CGU", "UGA", "UAA", "UAG",
)


def _build_vocab(body_tokens):
    """Map leading specials, ``body_tokens``, then trailing specials to ids 0..N-1."""
    ordered = _LEADING_SPECIAL_TOKENS + tuple(body_tokens) + _TRAILING_SPECIAL_TOKENS
    return {token: index for index, token in enumerate(ordered)}


# Single-nucleotide vocabulary: ids 0-3 specials, 4-19 nucleotides, 20-24 specials.
_RNA_VOCAB = _build_vocab(_NUCLEOTIDE_TOKENS)

# Codon vocabulary: ids 0-3 specials, 4-67 codons, 68-72 specials.
_MRNA_VOCAB = _build_vocab(_CODON_TOKENS)

# Translation table turning DNA thymine into RNA uracil, preserving case.
_DNA_TO_RNA = str.maketrans("Tt", "Uu")


class RnaFmTokenizer(PreTrainedTokenizer):
    """Tokenizer that splits RNA sequences into fixed-size k-mers.

    With ``k_mer=1`` each character is a token and the nucleotide vocabulary
    is used; with ``k_mer=3`` the text is split into consecutive codons and
    the codon vocabulary is used. A ``vocab.json`` file, when given and
    present on disk, overrides the built-in vocabulary.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        k_mer: int = 1,
        cls_token="<cls>",
        pad_token="<pad>",
        eos_token="<eos>",
        unk_token="<unk>",
        mask_token="<mask>",
        **kwargs,
    ):
        """Load the vocabulary and initialise the base tokenizer.

        Args:
            vocab_file: Optional path to a JSON token->id mapping. Ignored
                when it is falsy or does not point to an existing file.
            k_mer: Number of characters per token (3 selects the codon
                vocabulary when no vocab file is loaded).
            cls_token: Token prepended to every sequence.
            pad_token: Padding token.
            eos_token: Token appended to every sequence.
            unk_token: Token used for out-of-vocabulary input.
            mask_token: Masking token.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        self.k_mer = k_mer

        # The vocabulary is set up before calling the base initializer,
        # as in the original ordering.
        if vocab_file and os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as handle:
                self._vocab = json.load(handle)
        else:
            self._vocab = dict(_MRNA_VOCAB if k_mer == 3 else _RNA_VOCAB)
        self._ids_to_tokens = {index: token for token, index in self._vocab.items()}

        super().__init__(
            cls_token=cls_token,
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            mask_token=mask_token,
            k_mer=k_mer,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary."""
        return len(self._vocab)

    def get_vocab(self):
        """Return a copy of the token->id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text):
        """Split ``text`` into consecutive, non-overlapping k-mers.

        A trailing fragment shorter than ``k_mer`` is kept as its own token.
        """
        if self.k_mer == 1:
            return list(text)
        step = self.k_mer
        return [text[start:start + step] for start in range(0, len(text), step)]

    def _convert_token_to_id(self, token):
        """Look up ``token``; unknown tokens map to the ``<unk>`` id."""
        return self._vocab.get(token, self._vocab["<unk>"])

    def _convert_id_to_token(self, index):
        """Look up ``index``; unknown ids map to ``<unk>``."""
        return self._ids_to_tokens.get(index, "<unk>")

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary as JSON and return a 1-tuple with its path.

        Args:
            save_directory: Target directory; created if missing.
            filename_prefix: Optional prefix, joined to ``vocab.json`` with "-".
        """
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix + "-" if filename_prefix else ""
        path = os.path.join(save_directory, prefix + "vocab.json")
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(self._vocab, handle, indent=2)
        return (path,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap each sequence as ``[cls] ids [eos]`` and concatenate pairs."""
        cls = [self.cls_token_id]
        eos = [self.eos_token_id]
        wrapped = cls + token_ids_0 + eos
        if token_ids_1 is None:
            return wrapped
        return wrapped + cls + token_ids_1 + eos

    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
    ):
        """Return a 0/1 list with 1 at special-token positions.

        When the ids already contain special tokens the base-class
        implementation is used; otherwise the mask mirrors the layout
        produced by ``build_inputs_with_special_tokens``.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True
            )
        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask += [1] + [0] * len(token_ids_1) + [1]
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """Return all-zero token type ids covering both sequences and their specials."""
        total = len(token_ids_0) + 2
        if token_ids_1 is not None:
            total += len(token_ids_1) + 2
        return [0] * total

    @staticmethod
    def _extract_cds(sequence, cds):
        """Extract CDS region from a sequence, trimmed to a multiple of 3.

        Args:
            sequence: Nucleotide string.
            cds: Array-like of markers; positions equal to 1 flag the CDS.
                The slice runs from the first marked position through two
                characters past the last marked one, i.e. the last marker is
                treated as the first base of the final codon.

        Returns:
            The extracted region, or the whole sequence when nothing is
            marked, truncated so its length is a multiple of 3.
        """
        import numpy as np

        # Coerce first: comparing a plain list with 1 yields a scalar False,
        # which would make argmax silently report position 0.
        marked = np.asarray(cds) == 1
        if not marked.any():
            return sequence[:len(sequence) - (len(sequence) % 3)]

        first = int(np.argmax(marked))
        last = int(len(marked) - 1 - np.argmax(marked[::-1])) + 2
        region = sequence[first:last + 1]

        remainder = len(region) % 3
        return region[:-remainder] if remainder else region

    def batch_encode_with_cds(self, sequences, cds, max_length=None, **kwargs):
        """Encode sequences with CDS extraction (k_mer=3 / mRNA-FM only).

        Applies T->U, extracts the CDS region, chunks to max_length
        nucleotides (aligned to codon boundaries), and encodes each chunk.
        A sequence whose extracted region is empty contributes the single
        placeholder chunk "AUG".

        Args:
            sequences: List of raw nucleotide strings (T or U).
            cds: List of numpy arrays marking CDS codon start positions.
            max_length: Nucleotide budget per chunk (defaults to
                (model_max_length - 2) * k_mer). Rounded down to a multiple
                of k_mer.
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors,
                padding, add_special_tokens).

        Returns:
            Tuple of (BatchEncoding, chunk_counts) where chunk_counts[i] is
            the number of chunks produced for sequences[i].

        Raises:
            ValueError: If k_mer is not 3, if ``sequences`` and ``cds``
                differ in length, or if the budget is smaller than one codon.
        """
        if self.k_mer != 3:
            raise ValueError("batch_encode_with_cds requires k_mer=3 (mRNA-FM tokenizer)")
        if len(sequences) != len(cds):
            # zip() would silently drop the surplus and break the
            # chunk_counts[i] <-> sequences[i] correspondence.
            raise ValueError(
                f"sequences and cds must have the same length "
                f"(got {len(sequences)} and {len(cds)})"
            )

        budget = max_length if max_length is not None else (self.model_max_length - 2) * self.k_mer
        budget = (budget // self.k_mer) * self.k_mer
        if budget < self.k_mer:
            # A zero step would otherwise surface as an opaque range() error.
            raise ValueError(f"max_length must be at least {self.k_mer} nucleotides")

        all_chunks = []
        chunk_counts = []
        for raw_sequence, markers in zip(sequences, cds):
            region = self._extract_cds(raw_sequence.translate(_DNA_TO_RNA), markers)
            # region and budget are both codon-aligned, so every slice is too.
            chunks = [region[start:start + budget] for start in range(0, len(region), budget)]
            if not chunks:
                chunks = ["AUG"]
            all_chunks.extend(chunks)
            chunk_counts.append(len(chunks))

        enc = self.batch_encode_plus(all_chunks, **kwargs)
        return enc, chunk_counts