# RNA-FM / tokenization_rnafm.py
# Uploaded by Taykhoom with huggingface_hub (commit 5f987d0, verified).
import json
import os
from transformers import PreTrainedTokenizer
_RNA_VOCAB = {
"<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
"A": 4, "C": 5, "G": 6, "U": 7,
"R": 8, "Y": 9, "K": 10, "M": 11,
"S": 12, "W": 13, "B": 14, "D": 15,
"H": 16, "V": 17, "N": 18, "-": 19,
"<null_1>": 20, "<null_2>": 21, "<null_3>": 22, "<null_4>": 23,
"<mask>": 24,
}
_MRNA_VOCAB = {
"<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
"GAG": 4, "AAG": 5, "GAA": 6, "CUG": 7, "CAG": 8, "GAU": 9,
"AAA": 10, "GUG": 11, "GAC": 12, "AUG": 13, "GCC": 14, "AAC": 15,
"GCU": 16, "AAU": 17, "AUC": 18, "UUC": 19, "GGA": 20, "AUU": 21,
"GGC": 22, "UUU": 23, "CCA": 24, "AGC": 25, "GCA": 26, "UCU": 27,
"CUC": 28, "ACC": 29, "CAA": 30, "CCU": 31, "UCC": 32, "ACA": 33,
"UUG": 34, "GUU": 35, "CUU": 36, "UAC": 37, "ACU": 38, "CCC": 39,
"UCA": 40, "GUC": 41, "GGU": 42, "CAC": 43, "AGU": 44, "UAU": 45,
"AGA": 46, "CAU": 47, "GGG": 48, "UGG": 49, "UGC": 50, "AGG": 51,
"UGU": 52, "AUA": 53, "CGC": 54, "UUA": 55, "GCG": 56, "CGG": 57,
"CCG": 58, "GUA": 59, "CUA": 60, "ACG": 61, "UCG": 62, "CGA": 63,
"CGU": 64, "UGA": 65, "UAA": 66, "UAG": 67,
"<null_1>": 68, "<null_2>": 69, "<null_3>": 70, "<null_4>": 71,
"<mask>": 72,
}
class RnaFmTokenizer(PreTrainedTokenizer):
    """Slow tokenizer for RNA-FM (nucleotide-level) and mRNA-FM (codon-level) models.

    Text is split into non-overlapping ``k_mer``-length pieces (single characters
    when ``k_mer == 1``), mapped through a ``{token: id}`` vocabulary, and wrapped
    as ``<cls> ... <eos>``. Without a vocab file, ``k_mer == 3`` selects the
    built-in codon vocabulary and any other ``k_mer`` the nucleotide vocabulary.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        k_mer: int = 1,
        cls_token="<cls>",
        pad_token="<pad>",
        eos_token="<eos>",
        unk_token="<unk>",
        mask_token="<mask>",
        **kwargs,
    ):
        """Load the vocabulary and register the special tokens.

        Args:
            vocab_file: Optional path to a JSON ``{token: id}`` file. If it is
                not given, or does not point to an existing file, a built-in
                vocabulary is used instead (no error is raised).
            k_mer: Tokenization unit length. When no vocab file is loaded,
                ``3`` selects the built-in codon vocabulary and any other value
                selects the built-in single-nucleotide vocabulary.
            cls_token, pad_token, eos_token, unk_token, mask_token: Special
                tokens forwarded to ``PreTrainedTokenizer.__init__``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        self.k_mer = k_mer
        # The vocabulary is set up before super().__init__() -- presumably
        # because the base initializer queries it while registering the
        # special tokens.
        if vocab_file and os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = dict(_MRNA_VOCAB if k_mer == 3 else _RNA_VOCAB)
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        super().__init__(
            cls_token=cls_token,
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            mask_token=mask_token,
            # Forwarded so the base class keeps it among its init kwargs
            # (presumably so it is saved with the tokenizer config).
            k_mer=k_mer,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary (tokens added later excluded)."""
        return len(self._vocab)

    def get_vocab(self):
        """Return a copy of the ``{token: id}`` mapping, including added tokens.

        Added tokens are merged in because recent ``transformers`` versions use
        ``get_vocab()`` to size the tokenizer and to pick ids for newly added
        tokens. Without them, two successive ``add_tokens`` calls could be
        handed the same id (TODO confirm against the pinned transformers
        version).
        """
        vocab = dict(self._vocab)
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Split ``text`` into consecutive, non-overlapping ``k_mer``-length pieces.

        A trailing piece shorter than ``k_mer`` is kept as-is. Unless it happens
        to be in the vocabulary, it is later mapped to ``<unk>``.
        """
        if self.k_mer == 1:
            return list(text)
        return [text[i:i + self.k_mer] for i in range(0, len(text), self.k_mer)]

    def _convert_token_to_id(self, token):
        """Return the id of ``token``, or the id of ``"<unk>"`` if it is not in the vocabulary."""
        token_id = self._vocab.get(token)
        # Look up "<unk>" only when needed, so a custom vocab without "<unk>"
        # still encodes in-vocabulary tokens instead of raising KeyError.
        return self._vocab["<unk>"] if token_id is None else token_id

    def _convert_id_to_token(self, index):
        """Return the token for ``index``, or ``"<unk>"`` for an unknown id."""
        return self._ids_to_tokens.get(index, "<unk>")

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the base vocabulary as indented JSON and return ``(path,)``.

        The file is named ``vocab.json``, prefixed with ``"<filename_prefix>-"``
        when a prefix is given. ``save_directory`` is created if it is missing.
        """
        os.makedirs(save_directory, exist_ok=True)
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        path = os.path.join(save_directory, prefix + self.vocab_files_names["vocab_file"])
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2)
        return (path,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap ids as ``<cls> A <eos>`` or, for a pair, ``<cls> A <eos> <cls> B <eos>``."""
        cls = [self.cls_token_id]
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + eos
        return cls + token_ids_0 + eos + cls + token_ids_1 + eos

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """Return 1 for special-token positions and 0 for sequence tokens.

        The layout mirrors ``build_inputs_with_special_tokens``. When the ids
        already contain special tokens, the base-class implementation is used.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(token_ids_0, token_ids_1, already_has_special_tokens=True)
        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask += [1] + [0] * len(token_ids_1) + [1]
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """Return all-zero token type ids matching ``build_inputs_with_special_tokens``' length."""
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + 2) + [0] * (len(token_ids_1) + 2)

    @staticmethod
    def _extract_cds(sequence, cds):
        """Return the CDS region of ``sequence``, trimmed to a whole number of codons.

        Args:
            sequence: Nucleotide string.
            cds: Array-like (numpy array or plain list) where positions equal
                to 1 mark codon start positions.

        Returns:
            The substring from the first marked position through the last
            nucleotide of the last marked codon (last mark + 2), truncated to a
            multiple of 3. If nothing is marked, the whole sequence truncated to
            a multiple of 3.
        """
        import numpy as np

        # np.asarray lets plain lists work too: `some_list == 1` is a single
        # False, which would silently make the region start at index 0.
        starts = np.flatnonzero(np.asarray(cds) == 1)
        if starts.size == 0:
            region = sequence
        else:
            # The last mark is a codon *start*; its codon ends two positions later.
            region = sequence[int(starts[0]):int(starts[-1]) + 3]
        return region[:len(region) - len(region) % 3]

    def batch_encode_with_cds(self, sequences, cds, max_length=None, **kwargs):
        """Encode sequences with CDS extraction (k_mer=3 / mRNA-FM only).

        Each sequence is upper-cased and T is converted to U. The CDS region is
        then extracted with ``_extract_cds`` and split into chunks of at most
        ``max_length`` nucleotides, aligned to codon boundaries. All chunks from
        all sequences are encoded in a single ``batch_encode_plus`` call.

        Args:
            sequences: Iterable of raw nucleotide strings (T or U, any case).
            cds: Iterable parallel to ``sequences``. Each item is an array-like
                marking CDS codon start positions with 1.
            max_length: Nucleotide budget per chunk, rounded down to a multiple
                of ``k_mer``. Defaults to ``(model_max_length - 2) * k_mer``,
                leaving room for ``<cls>`` and ``<eos>``. It is not forwarded to
                ``batch_encode_plus``.
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors,
                padding, add_special_tokens).

        Returns:
            Tuple of (BatchEncoding, chunk_counts), where chunk_counts[i] is the
            number of chunks produced for sequences[i]. Chunks appear in the
            encoding sequence by sequence, in input order.

        Raises:
            ValueError: If ``k_mer != 3``, if ``sequences`` and ``cds`` differ
                in length, or if the chunk budget is smaller than one codon.
        """
        if self.k_mer != 3:
            raise ValueError("batch_encode_with_cds requires k_mer=3 (mRNA-FM tokenizer)")
        sequences = list(sequences)
        cds = list(cds)
        if len(sequences) != len(cds):
            raise ValueError(
                f"sequences and cds must have the same length, got {len(sequences)} and {len(cds)}"
            )
        budget = max_length if max_length is not None else (self.model_max_length - 2) * self.k_mer
        # Align the budget to whole codons so no chunk splits a codon.
        budget = (budget // self.k_mer) * self.k_mer
        if budget <= 0:
            raise ValueError(
                f"chunk budget must cover at least one codon ({self.k_mer} nucleotides); "
                f"got max_length={max_length!r}"
            )
        all_chunks = []
        chunk_counts = []
        for seq, c in zip(sequences, cds):
            # Upper-case before T->U: the codon vocabulary is upper-case only,
            # so lower-case codons would otherwise all map to <unk>.
            region = self._extract_cds(seq.upper().replace("T", "U"), c)
            # The region length and the budget are both multiples of 3, so every
            # chunk is a whole number of codons and none is empty.
            chunks = [region[i:i + budget] for i in range(0, len(region), budget)]
            if not chunks:
                # Keep at least one chunk per sequence (an empty CDS region is
                # replaced by the placeholder "AUG") so chunk_counts has no zeros.
                chunks = ["AUG"]
            all_chunks.extend(chunks)
            chunk_counts.append(len(chunks))
        enc = self.batch_encode_plus(all_chunks, **kwargs)
        return enc, chunk_counts