# src/tokenizer.py
"""
Codon tokenizer: 3-mer tokens + 4 special tokens.

No frameworks, no inheritance chains. Just:
  - encode_codon_seq("ATG...") -> [ids...]  (appends EOS outside, not here)
  - decode_codon_seq([ids...]) -> "ATG..."
  - save_vocabulary(dir) / from_pretrained(dir) for reproducible runs

Special IDs are fixed and contiguous from 0:
  pad=0, unk=1, bos=2, eos=3
"""

from __future__ import annotations

import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any


# ------------------------------
# Module-level constants
# ------------------------------

# Nucleotide alphabet; its order fixes the codon order and therefore the ids.
_BASES = ("A", "C", "G", "T")

# Same alphabet as a set, for O(1) membership tests during validation.
_ACGT = frozenset(_BASES)

# One-letter amino-acid alphabet plus "*" for stop codons.
_AMINO_ACID_CHARS = "ACDEFGHIKLMNPQRSTVWY*"

# Codon -> one-letter amino acid ("*" marks a stop codon).
_GENETIC_CODE: Dict[str, str] = {
    "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
    "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
    "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
    "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
    "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
    "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
    "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
    "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
    "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
    "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
    "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
    "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
    "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
    "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
    "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
    "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
}


# ------------------------------
# Special token ids
# ------------------------------

@dataclass(frozen=True)
class SpecialIds:
    """Immutable record of the integer ids of the four special tokens."""

    pad: int = 0
    unk: int = 1
    bos: int = 2
    eos: int = 3

    def to_dict(self) -> Dict[str, int]:
        """Return the ids as a plain ``{"pad", "unk", "bos", "eos"}`` dict."""
        return {"pad": self.pad, "unk": self.unk, "bos": self.bos, "eos": self.eos}


# ------------------------------
# Tokenizer
# ------------------------------

class CodonTokenizer:
    """Minimal tokenizer for codon (DNA 3-mer) sequences.

    The vocabulary is the four special tokens (ids 0..3) followed by the 64
    codons over ``ACGT`` in lexicographic base order (ids 4..67).
    """

    __slots__ = (
        "codons",
        "_special_token_str",
        "vocab",
        "ids_to_tokens",
        "_special_ids",
        "_num_special_tokens",
        "_genetic_code",
        "_codon2aa_char",
        "_aa2codons_char",
    )

    def __init__(
        self,
        pad_token: str = "<pad>",
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<stop>",  # human-readable; id is still 3
        **_: Any,  # ignore junk kwargs – we don't play framework games
    ) -> None:
        """Build the fixed 68-entry vocabulary and the amino-acid helper maps.

        The four special-token strings must be distinct and must not equal a
        codon: identical strings collapse onto a single vocab key, which
        breaks the 0..3 id layout (empty-string defaults made every
        construction fail for exactly that reason).

        Args:
            pad_token: String for the padding token (id 0).
            unk_token: String for the unknown token (id 1).
            bos_token: String for the begin-of-sequence token (id 2).
            eos_token: String for the end-of-sequence token (id 3).
            **_: Extra keyword arguments are accepted and ignored.

        Raises:
            AssertionError: If the special tokens do not end up with ids
                exactly 0..3 (duplicate strings, or a clash with a codon).
        """
        # 64 codons
        self.codons: List[str] = [a + b + c for a in _BASES for b in _BASES for c in _BASES]

        # specials come first, contiguous
        special_tokens = [pad_token, unk_token, bos_token, eos_token]
        self._special_token_str = {
            "pad": pad_token,
            "unk": unk_token,
            "bos": bos_token,
            "eos": eos_token,
        }

        # vocab: specials [0..3], then 64 codons [4..67]. Codon ids come from
        # an explicit offset so they never depend on the current dict size.
        self.vocab: Dict[str, int] = {tok: i for i, tok in enumerate(special_tokens)}
        offset = len(special_tokens)
        for i, codon in enumerate(self.codons):
            self.vocab[codon] = offset + i

        # reverse map
        self.ids_to_tokens: Dict[int, str] = {v: k for k, v in self.vocab.items()}

        # fixed ids
        self._special_ids = SpecialIds(
            pad=self.vocab[pad_token],
            unk=self.vocab[unk_token],
            bos=self.vocab[bos_token],
            eos=self.vocab[eos_token],
        )
        self._num_special_tokens = len(special_tokens)

        # sanity: specials are contiguous 0..3
        ids = list(self._special_ids.to_dict().values())
        if sorted(ids) != list(range(self._num_special_tokens)):
            raise AssertionError("Special token ids must be contiguous starting at 0")

        # genetic code (char); a per-instance copy keeps the shared table
        # safe from accidental mutation through an instance.
        self._genetic_code: Dict[str, str] = dict(_GENETIC_CODE)

        # precompute char helpers (same routine from_pretrained uses)
        self._rebuild_helpers()

    # ---------- properties ----------

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (specials + codons)."""
        return len(self.vocab)

    @property
    def special_ids(self) -> SpecialIds:
        """The ids of the four special tokens."""
        return self._special_ids

    @property
    def num_special_tokens(self) -> int:
        """How many leading ids are reserved for special tokens."""
        return self._num_special_tokens

    @property
    def pad_token_id(self) -> int:
        """Id of the padding token."""
        return self._special_ids.pad

    @property
    def unk_token_id(self) -> int:
        """Id of the unknown token."""
        return self._special_ids.unk

    @property
    def bos_token_id(self) -> int:
        """Id of the begin-of-sequence token."""
        return self._special_ids.bos

    @property
    def eos_token_id(self) -> int:
        """Id of the end-of-sequence token."""
        return self._special_ids.eos

    # ---------- core API ----------

    def encode_codon_seq(self, seq: str, validate: bool = True) -> List[int]:
        """Map DNA (ACGT)^3N to 3-mer ids. We don't append BOS/EOS here.

        Args:
            seq: Nucleotide string; it is upper-cased before tokenizing.
            validate: When true, reject lengths not divisible by 3 and any
                character outside ``ACGT``. When false, every chunk that is
                not a known codon (including a short trailing chunk) maps to
                the unk id.

        Returns:
            One id per consecutive 3-character chunk.

        Raises:
            ValueError: If ``validate`` is true and the sequence is invalid.
        """
        s = seq.upper()
        if validate:
            if len(s) % 3 != 0:
                raise ValueError(f"Sequence length {len(s)} not divisible by 3")
            if not _is_acgt(s):
                raise ValueError("Sequence contains invalid nucleotides (only ACGT supported)")

        # Fast Python slice loop – good enough. NumPy won't help for tiny strings.
        vocab = self.vocab
        unk = self._special_ids.unk
        return [vocab.get(s[i : i + 3], unk) for i in range(0, len(s), 3)]

    def decode_codon_seq(self, token_ids: List[int]) -> str:
        """Convert codon ids (>= num_special_tokens) back to a DNA string.

        Ids below ``num_special_tokens`` and ids missing from the vocabulary
        are skipped silently.
        """
        nst = self._num_special_tokens
        lookup = self.ids_to_tokens
        return "".join(lookup[tid] for tid in token_ids if tid >= nst and tid in lookup)

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **_: Any) -> str:
        """Decode ids to DNA; kept for API parity with older call sites.

        ``skip_special_tokens`` is accepted for signature compatibility only:
        ``decode_codon_seq`` drops special ids regardless, so the output is
        the same for either value.
        """
        if skip_special_tokens:
            token_ids = [t for t in token_ids if t >= self._num_special_tokens]
        return self.decode_codon_seq(token_ids)

    # ---------- misc helpers ----------

    def codon_vocab(self) -> Dict[str, int]:
        """Return a fresh ``codon -> id`` dict without the special tokens."""
        return {c: self.vocab[c] for c in self.codons}

    def codon2aa_char_map(self) -> Dict[int, str]:
        """Return a copy of the ``codon id -> amino-acid char`` map."""
        return dict(self._codon2aa_char)

    def aa2codons_char_map(self) -> Dict[str, List[int]]:
        """Return a copy (lists included) of ``amino-acid char -> codon ids``."""
        return {k: v[:] for k, v in self._aa2codons_char.items()}

    def aa_to_codon_length(self, aa_seq: str) -> int:
        """Return the number of codons for ``aa_seq`` (one per character)."""
        # You don't count stop unless it's explicitly there.
        return len(aa_seq)

    # HF compatibility stubs (your code calls these in a few places)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into consecutive 3-character chunks (no upper-casing).

        Raises:
            ValueError: If ``len(text)`` is not divisible by 3.
        """
        if len(text) % 3 != 0:
            raise ValueError(f"Text length {len(text)} not divisible by 3")
        return [text[i : i + 3] for i in range(0, len(text), 3)]

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or the unk id when it is unknown."""
        return self.vocab.get(token, self._special_ids.unk)

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or the unk string when unknown."""
        return self.ids_to_tokens.get(index, self._special_token_str["unk"])

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate tokens with no separator."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Return ``token_ids_0`` unchanged; ``token_ids_1`` is ignored."""
        return token_ids_0

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Return all-zero type ids for ``token_ids_0``; ``token_ids_1`` is ignored."""
        return [0] * len(token_ids_0)

    # ---------- persistence ----------

    def get_vocab(self) -> Dict[str, int]:
        """Return a shallow copy of the full ``token -> id`` vocabulary."""
        return dict(self.vocab)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Save to JSON with both vocab and special token strings so we can
        reconstruct IDs exactly. Deterministic and stable.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix; the file is then named
                ``"<prefix>-vocab.json"``. Note that ``from_pretrained`` only
                looks for the unprefixed ``vocab.json``.

        Returns:
            A 1-tuple holding the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        payload = {
            "vocab": self.vocab,
            "special_token_str": self._special_token_str,
        }
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "CodonTokenizer":
        """Load from a directory containing vocab.json produced by save_vocabulary().

        We rebuild the SpecialIds from the saved token strings to keep IDs
        stable. If the directory has no ``vocab.json``, a default tokenizer
        built from ``kwargs`` is returned instead.

        Args:
            pretrained_model_name_or_path: Directory expected to hold
                ``vocab.json``.
            **kwargs: Forwarded to the constructor.

        Returns:
            A tokenizer whose vocab and special ids mirror the saved file.

        Raises:
            KeyError: If the saved vocabulary lacks one of the 64 codons.
        """
        vocab_path = Path(pretrained_model_name_or_path) / "vocab.json"
        tok = cls(**kwargs)  # default structure; we'll overwrite below
        if not vocab_path.exists():
            # If nothing to load, return defaults. It keeps the rest of your code happy.
            return tok

        with open(vocab_path, "r", encoding="utf-8") as f:
            save_data = json.load(f)

        if isinstance(save_data, dict) and "vocab" in save_data:
            vocab = save_data["vocab"]
            special_token_str = save_data.get("special_token_str", tok._special_token_str)
        else:
            # Old, dumber format: the whole file was the vocab dict
            vocab = save_data
            special_token_str = tok._special_token_str

        # rebuild maps
        tok.vocab = {str(k): int(v) for k, v in vocab.items()}
        tok.ids_to_tokens = {int(v): str(k) for k, v in tok.vocab.items()}

        # Fail with a readable message instead of a bare KeyError('AAA') from
        # the helper rebuild further down.
        missing = [c for c in tok.codons if c not in tok.vocab]
        if missing:
            raise KeyError(
                f"{vocab_path} is missing {len(missing)} codon token(s), e.g. {missing[0]!r}"
            )

        # reconcile special strings → ids
        if isinstance(special_token_str, dict):
            tok._special_token_str.update(
                {k: v for k, v in special_token_str.items() if k in ("pad", "unk", "bos", "eos")}
            )

        def _id_for(name: str, default_val: int) -> int:
            # Fall back to the canonical id when the saved vocab lacks the string.
            sym = tok._special_token_str[name]
            return int(tok.vocab.get(sym, default_val))

        tok._special_ids = SpecialIds(
            pad=_id_for("pad", 0),
            unk=_id_for("unk", 1),
            bos=_id_for("bos", 2),
            eos=_id_for("eos", 3),
        )

        # Figure out how many specials to reserve. If the saved mapping had extra junk,
        # we still preserve a contiguous prefix if present. Otherwise default to 4.
        ids = [tok._special_ids.pad, tok._special_ids.unk, tok._special_ids.bos, tok._special_ids.eos]
        m = max(ids)
        tok._num_special_tokens = m + 1 if ids == list(range(m + 1)) else 4

        # Rebuild genetic helpers (cheap)
        tok._rebuild_helpers()
        return tok

    # internal: rebuild helper maps from the current vocab
    def _rebuild_helpers(self) -> None:
        """Recompute the codon-id <-> amino-acid maps from ``self.vocab``.

        Raises:
            KeyError: If a codon in ``self.codons`` is absent from the vocab.
        """
        self._codon2aa_char = {}
        self._aa2codons_char = {ch: [] for ch in _AMINO_ACID_CHARS}
        for codon in self.codons:
            cid = self.vocab[codon]
            aa = self._genetic_code.get(codon, "X")  # "X" = not in the table
            self._codon2aa_char[cid] = aa
            if aa in self._aa2codons_char:
                self._aa2codons_char[aa].append(cid)


# ------------------------------
# small helpers
# ------------------------------

def _is_acgt(s: str) -> bool:
    """Return True when every character of ``s`` is A, C, G or T.

    The empty string counts as valid.
    """
    # Faster than regex for short strings.
    return _ACGT.issuperset(s)