import numpy as np
from transformers import BertTokenizer


class MRNABERTTokenizer(BertTokenizer):
    """BertTokenizer extended with CDS-aware preprocessing for mRNABERT.

    mRNABERT expects space-separated tokens where UTR regions use single
    nucleotides and CDS regions use three-letter codons. This tokenizer adds
    batch_encode_with_cds() to handle that preprocessing automatically.

    Standard usage (pre-formatted strings) still works as before:

        tokenizer(["A T C G ATG TTT CCC"], return_tensors="pt")

    CDS-aware usage (raw sequences + CDS track). The track below marks codon
    starts at positions 4, 7 and 10, so the sequence is preprocessed into the
    same "A T C G ATG TTT CCC" string as the example above:

        encoding, chunk_counts = tokenizer.batch_encode_with_cds(
            ["ATCGATGTTTCCC"],
            cds=[np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0])],
            return_tensors="pt",
        )
    """

    _SPECIAL_TOKENS = frozenset({"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"})
    # Nucleotides per codon; CDS regions are grouped into tokens of this size.
    _CODON_LENGTH = 3

    @classmethod
    def _parse_tokens(cls, sequence):
        """Split a sequence string into per-position tokens.

        Bracket-enclosed special tokens ([CLS], [MASK], etc.) are treated as
        atomic units so the token list length matches the CDS array length.
        Any other character -- including a "[" that does not open one of
        _SPECIAL_TOKENS -- becomes its own single-character token.
        """
        tokens = []
        i = 0
        n = len(sequence)
        while i < n:
            if sequence[i] == "[":
                close = sequence.find("]", i + 1)
                if close != -1:
                    candidate = sequence[i:close + 1]
                    if candidate in cls._SPECIAL_TOKENS:
                        tokens.append(candidate)
                        i = close + 1
                        continue
            tokens.append(sequence[i])
            i += 1
        return tokens

    @classmethod
    def preprocess(cls, sequence, cds=None):
        """Convert a raw sequence and optional CDS track to a tokenizer string.

        Args:
            sequence: Raw nucleotide string, e.g. "ATCGATGTTTCCC". Uses T (not U).
                Bracketed special tokens such as "[MASK]" count as one position.
            cds: Integer array-like with one entry per position of `sequence`
                (as split by _parse_tokens). Non-zero values mark the start of
                each codon in the CDS region. If None or all zero, every
                position is treated as UTR (single-character separated).

        Returns:
            Space-separated string ready for the tokenizer. For example,
            "ATCATGTTTCCC" with codon starts at positions 3, 6 and 9 becomes
            "A T C ATG TTT CCC".
        """
        tokens = cls._parse_tokens(sequence)
        if cds is None:
            return " ".join(tokens)
        # Accept plain lists too: `list != 0` would be a single bool, not a mask.
        cds = np.asarray(cds)
        starts = np.where(cds != 0)[0]
        if starts.size == 0:
            return " ".join(tokens)

        codon_len = cls._CODON_LENGTH
        start = int(starts[0])
        end = min(int(starts[-1]) + codon_len, len(tokens))

        # NOTE: only the first and last marks define the CDS span; codons are
        # then taken every `codon_len` positions from the first mark.
        cds_items = []
        for j in range(start, end, codon_len):
            codon = tokens[j:j + codon_len]
            # Truncated codons and codons containing special tokens cannot be
            # merged into a single codon token, so emit them position by position.
            if len(codon) < codon_len or any(t in cls._SPECIAL_TOKENS for t in codon):
                cds_items.extend(codon)
            else:
                cds_items.append("".join(codon))

        # 5' UTR, CDS and 3' UTR; empty segments are skipped so no stray
        # spaces appear (e.g. when the CDS track starts past the sequence end).
        segments = (tokens[:start], cds_items, tokens[end:])
        return " ".join(" ".join(seg) for seg in segments if seg)

    @classmethod
    def chunk_sequence_cds_aware(cls, sequence, cds, chunk_length):
        """Chunk a sequence while respecting codon boundaries.

        Args:
            sequence: Raw nucleotide string.
            cds: Integer array-like, one entry per position of `sequence`.
                Non-zero marks codon start positions.
            chunk_length: Maximum number of positions per chunk (>= 1).

        Returns:
            List of (token_list_chunk, cds_chunk) tuples. Interior cut points
            are moved back so no codon straddles two chunks; if a window has
            no codon-safe cut point, it is cut at `chunk_length` anyway.

        Raises:
            ValueError: If `chunk_length` is less than 1.
        """
        if chunk_length < 1:
            # A non-positive length would never advance the loop below.
            raise ValueError(f"chunk_length must be >= 1, got {chunk_length}")
        cds = np.asarray(cds)
        tokens = cls._parse_tokens(sequence)
        n = len(tokens)
        codon_starts = set(np.where(cds != 0)[0].tolist())
        if not codon_starts:
            return [
                (tokens[i:i + chunk_length], cds[i:i + chunk_length])
                for i in range(0, max(n, 1), chunk_length)
            ]

        # Offsets k such that a codon starting at end - k would be split by a
        # cut at `end`.
        straddle_offsets = range(1, cls._CODON_LENGTH)
        chunks = []
        i = 0
        while i < n:
            end = min(i + chunk_length, n)
            # Only interior cuts can split a codon; the final chunk always ends
            # at n, even when the last codon is truncated.
            if end < n:
                while end > i and any((end - k) in codon_starts for k in straddle_offsets):
                    end -= 1
                if end == i:
                    # No codon-safe cut in this window: fall back to a hard cut.
                    end = min(i + chunk_length, n)
            chunks.append((tokens[i:end], cds[i:end]))
            i = end
        return chunks

    def batch_encode_with_cds(self, sequences, cds, max_length=None, **kwargs):
        """Encode a batch of raw sequences using CDS-aware preprocessing.

        Sequences with more than max_length - 2 positions are split into
        codon-boundary-aligned chunks. Every chunk is encoded as its own row;
        the caller is responsible for aggregating across chunks.

        Args:
            sequences: Sequence of raw nucleotide strings (use T, not U).
            cds: Sequence of integer numpy arrays, one per sequence. Non-zero
                values mark codon start positions.
            max_length: Maximum chunk length including the two special tokens
                ([CLS]/[SEP]) the tokenizer adds by default; each chunk holds
                at most max_length - 2 positions. Defaults to
                model_max_length. It is consumed here and NOT forwarded to
                batch_encode_plus.
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors,
                padding, add_special_tokens).

        Returns:
            A tuple (encoding, chunk_counts). `encoding` is a single
            BatchEncoding with one row per chunk; chunks of the same sequence
            are contiguous and in order. `chunk_counts[i]` is the number of
            chunks produced for `sequences[i]` (1 when it fits in one chunk).
            To regroup rows by source sequence:

                chunk_ptr = 0
                for i, n_chunks in enumerate(chunk_counts):
                    seq_ids = encoding["input_ids"][chunk_ptr:chunk_ptr + n_chunks]
                    chunk_ptr += n_chunks

        Raises:
            ValueError: If `sequences` and `cds` differ in length, or if
                max_length leaves no room for nucleotides (max_length <= 2).
        """
        sequences = list(sequences)
        cds = list(cds)
        if len(sequences) != len(cds):
            # zip() would otherwise silently drop the unmatched tail.
            raise ValueError(
                f"sequences and cds must have the same length, "
                f"got {len(sequences)} and {len(cds)}"
            )

        if max_length is None:
            max_length = self.model_max_length
        # Reserve two positions for the special tokens added to each chunk.
        budget = max_length - 2

        all_strings = []
        chunk_counts = []
        for seq, seq_cds in zip(sequences, cds):
            chunks = self.chunk_sequence_cds_aware(seq, seq_cds, budget)
            # Re-join the token list; preprocess re-parses special tokens atomically.
            all_strings.extend(
                self.preprocess("".join(token_list), cds_chunk)
                for token_list, cds_chunk in chunks
            )
            chunk_counts.append(len(chunks))

        encoding = self.batch_encode_plus(all_strings, **kwargs)
        return encoding, chunk_counts