Instructions to use Taykhoom/mRNABERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Taykhoom/mRNABERT with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="Taykhoom/mRNABERT", trust_remote_code=True)

# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
```
- Notebooks
- Google Colab
- Kaggle
| import numpy as np | |
| from transformers import BertTokenizer | |
class MRNABERTTokenizer(BertTokenizer):
    """BertTokenizer extended with CDS-aware preprocessing for mRNABERT.

    mRNABERT expects space-separated tokens where UTR regions use single
    nucleotides and CDS regions use three-letter codons. This tokenizer adds
    batch_encode_with_cds() to handle that preprocessing automatically.

    Standard usage (pre-formatted strings) still works as before:

        tokenizer(["A T C G ATG TTT CCC"], return_tensors="pt")

    CDS-aware usage (raw sequences + CDS track). The track below marks codon
    starts at positions 4, 7 and 10, which yields the same string as the
    standard-usage example above:

        enc, chunk_counts = tokenizer.batch_encode_with_cds(
            ["ATCGATGTTTCCC"],
            cds=[np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0])],
            return_tensors="pt",
        )
    """

    # Bracketed markers that are kept as single positions when a raw sequence
    # is split, and that are never merged into a codon token.
    _SPECIAL_TOKENS = frozenset({"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"})

    @classmethod
    def _parse_tokens(cls, sequence):
        """Split a sequence string into per-position tokens.

        Bracket-enclosed special tokens ([CLS], [MASK], etc.) are treated as
        atomic units so the token list length matches the CDS array length.
        Every other character -- including a '[' that does not open one of
        the known special tokens -- becomes its own token.
        """
        specials = cls._SPECIAL_TOKENS
        tokens = []
        i = 0
        n = len(sequence)
        while i < n:
            if sequence[i] == "[":
                # Compare against the few known markers directly instead of
                # scanning ahead for the next ']' (which is O(n) per '[').
                match = next(
                    (tok for tok in specials if sequence.startswith(tok, i)), None
                )
                if match is not None:
                    tokens.append(match)
                    i += len(match)
                    continue
            tokens.append(sequence[i])
            i += 1
        return tokens

    @classmethod
    def _format_tokens(cls, tokens, cds=None):
        """Join a per-position token list into a tokenizer string.

        Positions before the first CDS mark and after the CDS region stay
        single tokens; the CDS region is grouped into 3-position codons.
        Shared by preprocess() and batch_encode_with_cds() so chunked token
        lists do not have to be re-joined and re-parsed.

        Args:
            tokens: List of per-position tokens (see _parse_tokens).
            cds: Array-like with one entry per token; non-zero marks codon
                starts. None or all-zero means no CDS.

        Returns:
            Space-separated string.
        """
        if cds is None:
            return " ".join(tokens)
        # np.asarray so that a plain Python list is handled element-wise;
        # `list != 0` would otherwise collapse to the scalar True.
        starts = np.flatnonzero(np.asarray(cds))
        if starts.size == 0:
            return " ".join(tokens)
        start = int(starts[0])
        # Codons are taken in steps of 3 from the first mark through the
        # codon beginning at the last mark; marks in between are not
        # consulted individually.
        end = min(int(starts[-1]) + 3, len(tokens))
        items = list(tokens[:start])
        for j in range(start, end, 3):
            codon = tokens[j:j + 3]
            if len(codon) < 3 or any(t in cls._SPECIAL_TOKENS for t in codon):
                # Incomplete codons and codons containing a special token are
                # emitted position by position.
                items.extend(codon)
            else:
                items.append("".join(codon))
        items.extend(tokens[end:])
        # Joining one flat list avoids stray spaces when a segment is empty.
        return " ".join(items)

    @classmethod
    def preprocess(cls, sequence, cds=None):
        """Convert a raw sequence and optional CDS track to a tokenizer string.

        Args:
            sequence: Raw nucleotide string, e.g. "ATCGATGTTTCCC".
                Uses T (not U). Special tokens such as "[MASK]" count as a
                single position.
            cds: Integer array with one entry per position of `sequence`.
                Non-zero values mark the start of each codon in the CDS
                region. If None (or all zero), every position is treated as
                UTR (single-character separated).

        Returns:
            Space-separated string ready for the tokenizer, e.g.
            "A T C G ATG TTT CCC" for "ATCGATGTTTCCC" with codon starts
            marked at positions 4, 7 and 10.
        """
        return cls._format_tokens(cls._parse_tokens(sequence), cds)

    @classmethod
    def chunk_sequence_cds_aware(cls, sequence, cds, chunk_length):
        """Chunk a sequence while respecting codon boundaries.

        Args:
            sequence: Raw nucleotide string (special tokens count as one
                position).
            cds: Integer array, one entry per position. Non-zero marks
                codon start positions.
            chunk_length: Maximum number of positions per chunk (>= 1).

        Returns:
            List of (token_list_chunk, cds_chunk) tuples. Each cut is moved
            left so that a codon (a mark plus the next two positions) is not
            split, unless not even one codon fits in chunk_length. An empty
            sequence with no CDS marks yields a single empty chunk.

        Raises:
            ValueError: If chunk_length < 1 (the loop could not advance).
        """
        if chunk_length < 1:
            raise ValueError(f"chunk_length must be >= 1, got {chunk_length}")
        tokens = cls._parse_tokens(sequence)
        cds = np.asarray(cds)
        n = len(tokens)
        codon_starts = set(np.flatnonzero(cds).tolist())
        if not codon_starts:
            return [
                (tokens[i:i + chunk_length], cds[i:i + chunk_length])
                for i in range(0, max(n, 1), chunk_length)
            ]
        chunks = []
        i = 0
        while i < n:
            end = min(i + chunk_length, n)
            # A cut at n never splits a codon, so only interior cuts move.
            if end < n:
                # A codon starting at end-1 or end-2 would straddle the cut.
                while end > i and ((end - 1) in codon_starts or (end - 2) in codon_starts):
                    end -= 1
                if end == i:
                    # No codon boundary fits within the budget; cut anyway.
                    end = i + chunk_length
            chunks.append((tokens[i:end], cds[i:end]))
            i = end
        return chunks

    def batch_encode_with_cds(self, sequences, cds, max_length=None, **kwargs):
        """Encode a batch of raw sequences using CDS-aware preprocessing.

        Every sequence is split into CDS-boundary-aligned chunks of at most
        max_length - 2 positions (two slots are reserved for [CLS]/[SEP]).
        Each chunk is encoded as its own row; the caller is responsible for
        aggregating across chunks.

        Args:
            sequences: Iterable of raw nucleotide strings (use T, not U).
            cds: Iterable of integer numpy arrays, one per sequence.
                Non-zero values mark codon start positions.
            max_length: Per-chunk length including the two special tokens;
                the position budget per chunk is max_length - 2. Defaults to
                self.model_max_length. It is consumed here and not forwarded
                to batch_encode_plus.
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors,
                padding, add_special_tokens).

        Returns:
            Tuple (encoding, chunk_counts). `encoding` is the BatchEncoding
            from batch_encode_plus with one row per (sequence, chunk) pair,
            in input order; chunk_counts[i] is the number of rows that belong
            to sequences[i]. To regroup:

                ptr = 0
                for i, n_chunks in enumerate(chunk_counts):
                    rows = slice(ptr, ptr + n_chunks)
                    seq_input_ids = encoding["input_ids"][rows]
                    ptr += n_chunks

        Raises:
            ValueError: If sequences and cds have different lengths, or if
                max_length leaves no room for nucleotides (max_length <= 2).
        """
        sequences = list(sequences)
        cds = list(cds)
        if len(sequences) != len(cds):
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                f"got {len(sequences)} sequences but {len(cds)} CDS tracks"
            )
        limit = self.model_max_length if max_length is None else max_length
        budget = limit - 2
        if budget < 1:
            raise ValueError(
                f"max_length must be greater than 2 to leave room for "
                f"nucleotides, got {limit}"
            )
        all_strings = []
        chunk_counts = []
        for seq, track in zip(sequences, cds):
            chunks = self.chunk_sequence_cds_aware(seq, track, budget)
            # Format the already-parsed token lists directly (no re-parse).
            all_strings.extend(
                self._format_tokens(tokens, c_chunk) for tokens, c_chunk in chunks
            )
            chunk_counts.append(len(chunks))
        encoding = self.batch_encode_plus(all_strings, **kwargs)
        return encoding, chunk_counts