| |
| """ |
| Codon tokenizer: 3-mer tokens + 4 special tokens. |
| |
| No frameworks, no inheritance chains. Just: |
| - encode_codon_seq("ATG...") -> [ids...] (appends EOS outside, not here) |
| - decode_codon_seq([ids...]) -> "ATG..." |
| - save_vocabulary(dir) / from_pretrained(dir) for reproducible runs |
| |
| Special IDs are fixed and contiguous from 0: |
| pad=0, unk=1, bos=2, eos=3 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple, Any |
|
|
|
|
| |
| |
| |
|
|
@dataclass(frozen=True)
class SpecialIds:
    """Immutable bundle of the four special-token ids (contiguous from 0)."""

    pad: int = 0
    unk: int = 1
    bos: int = 2
    eos: int = 3

    def to_dict(self) -> Dict[str, int]:
        """Return the ids keyed by role name, in pad/unk/bos/eos order."""
        return {role: getattr(self, role) for role in ("pad", "unk", "bos", "eos")}
|
|
|
|
| |
| |
| |
|
|
class CodonTokenizer:
    """Minimal tokenizer for codon (DNA 3-mer) sequences.

    Vocabulary layout is fixed: 4 special tokens at ids 0..3 (pad, unk,
    bos, eos) followed by the 64 codons at ids 4..67, enumerated in
    lexicographic order over the bases (A, C, G, T).
    """

    __slots__ = (
        "codons",
        "_special_token_str",
        "vocab",
        "ids_to_tokens",
        "_special_ids",
        "_num_special_tokens",
        "_genetic_code",
        "_codon2aa_char",
        "_aa2codons_char",
    )

    def __init__(
        self,
        pad_token: str = "<pad>",
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<stop>",
        **_: Any,
    ) -> None:
        """Build the fixed 68-token vocabulary (4 specials + 64 codons).

        Extra keyword arguments are accepted and ignored so framework-style
        call sites (e.g. from_pretrained passthrough kwargs) don't break.

        Raises:
            AssertionError: if the special-token ids are not contiguous
                from 0 (cannot happen with the construction below; kept as
                an invariant check).
        """
        bases = ("A", "C", "G", "T")
        # All 64 codons, lexicographic over (A, C, G, T).
        self.codons: List[str] = [a + b + c for a in bases for b in bases for c in bases]

        special_tokens = [pad_token, unk_token, bos_token, eos_token]
        self._special_token_str = {
            "pad": pad_token,
            "unk": unk_token,
            "bos": bos_token,
            "eos": eos_token,
        }

        # Specials first (ids 0..3), then codons (ids 4..67). Each token's
        # id is simply its insertion order, i.e. the current vocab size.
        self.vocab: Dict[str, int] = {}
        for tok in special_tokens:
            self.vocab[tok] = len(self.vocab)
        for codon in self.codons:
            self.vocab[codon] = len(self.vocab)

        self.ids_to_tokens: Dict[int, str] = {v: k for k, v in self.vocab.items()}

        self._special_ids = SpecialIds(
            pad=self.vocab[pad_token],
            unk=self.vocab[unk_token],
            bos=self.vocab[bos_token],
            eos=self.vocab[eos_token],
        )
        self._num_special_tokens = len(special_tokens)

        # Standard genetic code: codon -> one-letter amino acid, '*' = stop.
        self._genetic_code: Dict[str, str] = {
            "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
            "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
            "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
            "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
            "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
            "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
            "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
            "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
            "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
            "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
            "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
            "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
            "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
            "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
            "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
            "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
        }

        # Derived id-level lookup tables (codon id <-> amino-acid char).
        # Delegated to _rebuild_helpers so construction and from_pretrained
        # reload share one implementation instead of duplicating the loop.
        self._rebuild_helpers()

        # Invariant: special ids are exactly 0..num_special_tokens-1.
        ids = list(self._special_ids.to_dict().values())
        if sorted(ids) != list(range(self._num_special_tokens)):
            raise AssertionError("Special token ids must be contiguous starting at 0")

    @property
    def vocab_size(self) -> int:
        """Total number of tokens (specials + codons)."""
        return len(self.vocab)

    @property
    def special_ids(self) -> SpecialIds:
        """The frozen bundle of special-token ids."""
        return self._special_ids

    @property
    def num_special_tokens(self) -> int:
        """Count of special tokens (ids 0..n-1)."""
        return self._num_special_tokens

    @property
    def pad_token_id(self) -> int:
        return self._special_ids.pad

    @property
    def unk_token_id(self) -> int:
        return self._special_ids.unk

    @property
    def bos_token_id(self) -> int:
        return self._special_ids.bos

    @property
    def eos_token_id(self) -> int:
        return self._special_ids.eos

    def encode_codon_seq(self, seq: str, validate: bool = True) -> List[int]:
        """Map a DNA string of length 3N to N codon ids.

        BOS/EOS are NOT appended here (callers do that). Input is
        upper-cased first; unknown 3-mers map to the unk id (only
        reachable when validate=False).

        Args:
            seq: DNA sequence; length must be divisible by 3.
            validate: when True, reject sequences whose length is not a
                multiple of 3 or that contain non-ACGT characters.

        Raises:
            ValueError: on invalid length or characters when validate=True.
        """
        s = seq.upper()
        if validate:
            if len(s) % 3 != 0:
                raise ValueError(f"Sequence length {len(s)} not divisible by 3")
            if not _is_acgt(s):
                raise ValueError("Sequence contains invalid nucleotides (only ACGT supported)")
        unk = self._special_ids.unk
        return [self.vocab.get(s[i : i + 3], unk) for i in range(0, len(s), 3)]

    def decode_codon_seq(self, token_ids: List[int]) -> str:
        """Convert codon ids back to a DNA string.

        Special ids (< num_special_tokens) and unknown ids are silently
        dropped; only real codon tokens contribute characters.
        """
        parts: List[str] = []
        nst = self._num_special_tokens
        for tid in token_ids:
            if tid >= nst:
                tok = self.ids_to_tokens.get(tid)
                if tok is not None:
                    parts.append(tok)
        return "".join(parts)

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **_: Any) -> str:
        """Decode ids to a string, optionally keeping special tokens.

        Bug fix: previously skip_special_tokens=False behaved identically
        to True because decode_codon_seq drops special ids unconditionally.
        Now False emits the special token strings (e.g. "<bos>") inline.
        Unknown ids are skipped in both modes.
        """
        if skip_special_tokens:
            return self.decode_codon_seq(token_ids)
        return "".join(
            self.ids_to_tokens[tid] for tid in token_ids if tid in self.ids_to_tokens
        )

    def codon_vocab(self) -> Dict[str, int]:
        """Codon-only slice of the vocabulary (no special tokens)."""
        return {c: self.vocab[c] for c in self.codons}

    def codon2aa_char_map(self) -> Dict[int, str]:
        """Copy of the codon-id -> amino-acid-char table."""
        return dict(self._codon2aa_char)

    def aa2codons_char_map(self) -> Dict[str, List[int]]:
        """Deep-ish copy of the amino-acid-char -> codon-ids table."""
        return {k: v[:] for k, v in self._aa2codons_char.items()}

    def aa_to_codon_length(self, aa_seq: str) -> int:
        """Number of codon tokens needed for an amino-acid sequence (1:1)."""
        return len(aa_seq)

    def _tokenize(self, text: str) -> List[str]:
        """Split text into 3-mer string tokens; length must be 3N."""
        if len(text) % 3 != 0:
            raise ValueError(f"Text length {len(text)} not divisible by 3")
        return [text[i : i + 3] for i in range(0, len(text), 3)]

    def _convert_token_to_id(self, token: str) -> int:
        """Token string -> id, falling back to the unk id."""
        return self.vocab.get(token, self._special_ids.unk)

    def _convert_id_to_token(self, index: int) -> str:
        """Id -> token string, falling back to the unk token string."""
        return self.ids_to_tokens.get(index, self._special_token_str["unk"])

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate token strings (codons have no separators)."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """No-op: BOS/EOS are managed by the caller, not the tokenizer."""
        return token_ids_0

    def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """Single-segment model: all token type ids are 0."""
        return [0] * len(token_ids_0)

    def get_vocab(self) -> Dict[str, int]:
        """Copy of the full token -> id mapping."""
        return dict(self.vocab)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write vocab + special token strings to one JSON file.

        Both pieces are saved so from_pretrained can reconstruct ids
        exactly. Output is deterministic (sorted keys, fixed indent).

        Returns:
            1-tuple with the path of the written vocab file.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        payload = {
            "vocab": self.vocab,
            "special_token_str": self._special_token_str,
        }
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "CodonTokenizer":
        """Load from a directory containing vocab.json (see save_vocabulary).

        Only the un-prefixed "vocab.json" name is looked up; files saved
        with a filename_prefix are not found and the default-constructed
        tokenizer is returned instead. Legacy files that are a bare
        token->id mapping (no "vocab" key) are also accepted.
        """
        vocab_path = Path(pretrained_model_name_or_path) / "vocab.json"
        tok = cls(**kwargs)
        if not vocab_path.exists():
            # Nothing saved: fall back to the freshly built default vocab.
            return tok

        with open(vocab_path, "r", encoding="utf-8") as f:
            save_data = json.load(f)

        if not isinstance(save_data, dict) or "vocab" not in save_data:
            # Legacy format: the file IS the vocab mapping.
            vocab = save_data
            special_token_str = tok._special_token_str
        else:
            vocab = save_data["vocab"]
            special_token_str = save_data.get("special_token_str", tok._special_token_str)

        tok.vocab = {str(k): int(v) for k, v in vocab.items()}
        tok.ids_to_tokens = {int(v): str(k) for k, v in tok.vocab.items()}

        if isinstance(special_token_str, dict):
            tok._special_token_str.update(
                {k: v for k, v in special_token_str.items() if k in ("pad", "unk", "bos", "eos")}
            )

        def _id_for(name: str, default_val: int) -> int:
            sym = tok._special_token_str[name]
            return int(tok.vocab.get(sym, default_val))

        tok._special_ids = SpecialIds(
            pad=_id_for("pad", 0),
            unk=_id_for("unk", 1),
            bos=_id_for("bos", 2),
            eos=_id_for("eos", 3),
        )

        # Trust a contiguous 0..m layout from the file; otherwise assume
        # the standard 4 specials.
        ids = [tok._special_ids.pad, tok._special_ids.unk, tok._special_ids.bos, tok._special_ids.eos]
        m = max(ids)
        tok._num_special_tokens = m + 1 if ids == list(range(m + 1)) else 4

        tok._rebuild_helpers()
        return tok

    def _rebuild_helpers(self) -> None:
        """Recompute codon-id <-> amino-acid tables from the current vocab.

        Codons absent from a loaded vocab are skipped instead of raising
        KeyError, so partial external vocab.json files still load.
        """
        self._codon2aa_char = {}
        self._aa2codons_char = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"}
        for codon in self.codons:
            cid = self.vocab.get(codon)
            if cid is None:
                continue
            aa = self._genetic_code.get(codon, "X")
            self._codon2aa_char[cid] = aa
            if aa in self._aa2codons_char:
                self._aa2codons_char[aa].append(cid)
|
|
|
|
| |
| |
| |
|
|
| def _is_acgt(s: str) -> bool: |
| |
| for ch in s: |
| if ch not in ("A", "C", "G", "T"): |
| return False |
| return True |
|
|