CodonTranslator / src /tokenizer.py
alegendaryfish's picture
Public CodonTranslator model and training code release
2d8da02 verified
raw
history blame
12 kB
# src/tokenizer.py
"""
Codon tokenizer: 3-mer tokens + 4 special tokens.
No frameworks, no inheritance chains. Just:
- encode_codon_seq("ATG...") -> [ids...] (appends EOS outside, not here)
- decode_codon_seq([ids...]) -> "ATG..."
- save_vocabulary(dir) / from_pretrained(dir) for reproducible runs
Special IDs are fixed and contiguous from 0:
pad=0, unk=1, bos=2, eos=3
"""
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
# ------------------------------
# Special token ids
# ------------------------------
@dataclass(frozen=True)
class SpecialIds:
    """Immutable record of the four reserved token ids.

    Defaults match the contiguous layout used by the tokenizer:
    pad=0, unk=1, bos=2, eos=3.
    """

    pad: int = 0
    unk: int = 1
    bos: int = 2
    eos: int = 3

    def to_dict(self) -> Dict[str, int]:
        """Return the ids keyed by role name, ordered pad, unk, bos, eos."""
        roles = ("pad", "unk", "bos", "eos")
        return {role: getattr(self, role) for role in roles}
# ------------------------------
# Tokenizer
# ------------------------------
class CodonTokenizer:
    """Minimal tokenizer for codon (DNA 3-mer) sequences.

    Vocabulary built by ``__init__``: the four special tokens take ids 0..3
    (pad, unk, bos, eos), followed by the 64 codons at ids 4..67, ordered
    with A < C < G < T at each position (AAA=4, AAC=5, ..., TTT=67).
    """

    __slots__ = (
        "codons",
        "_special_token_str",
        "vocab",
        "ids_to_tokens",
        "_special_ids",
        "_num_special_tokens",
        "_genetic_code",
        "_codon2aa_char",
        "_aa2codons_char",
    )

    def __init__(
        self,
        pad_token: str = "<pad>",
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<stop>",  # human-readable; id is still 3
        **_: Any,  # extra keyword arguments are accepted and ignored
    ) -> None:
        """Build the fixed 68-entry vocabulary and genetic-code lookup tables.

        Raises:
            AssertionError: if two special-token strings are equal or one of
                them equals a codon string; either would make two roles share
                an id, so the specials could not occupy ids 0..3.
        """
        # All 64 codons, ordered A < C < G < T at each position.
        bases = ("A", "C", "G", "T")
        self.codons: List[str] = [a + b + c for a in bases for b in bases for c in bases]

        special_tokens = [pad_token, unk_token, bos_token, eos_token]
        # Validate up front: a collision would silently overwrite a vocab entry
        # and break the "specials are contiguous from 0" invariant.
        if len(set(special_tokens)) != len(special_tokens) or not set(special_tokens).isdisjoint(self.codons):
            raise AssertionError(
                "Special token ids must be contiguous starting at 0; special token "
                "strings must be distinct and must not equal a codon"
            )
        self._special_token_str = {"pad": pad_token, "unk": unk_token, "bos": bos_token, "eos": eos_token}

        # vocab: specials [0..3], then the 64 codons [4..67]; each codon gets the
        # next free id, i.e. the current vocab size.
        self.vocab: Dict[str, int] = {}
        for i, tok in enumerate(special_tokens):
            self.vocab[tok] = i
        for codon in self.codons:
            self.vocab[codon] = len(self.vocab)
        # reverse map id -> token
        self.ids_to_tokens: Dict[int, str] = {v: k for k, v in self.vocab.items()}

        self._special_ids = SpecialIds(
            pad=self.vocab[pad_token],
            unk=self.vocab[unk_token],
            bos=self.vocab[bos_token],
            eos=self.vocab[eos_token],
        )
        self._num_special_tokens = len(special_tokens)

        # Standard genetic code, codon -> one-letter amino acid ('*' = stop).
        self._genetic_code: Dict[str, str] = {
            "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
            "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
            "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
            "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
            "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
            "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
            "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
            "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
            "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
            "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
            "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
            "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
            "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
            "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
            "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
            "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
        }
        # Derived codon-id <-> amino-acid tables (shared with from_pretrained).
        self._rebuild_helpers()

    # ---------- properties ----------
    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (68 for the default layout)."""
        return len(self.vocab)

    @property
    def special_ids(self) -> "SpecialIds":
        """The reserved pad/unk/bos/eos ids."""
        return self._special_ids

    @property
    def num_special_tokens(self) -> int:
        """Count of reserved ids; ids below this are treated as specials when decoding."""
        return self._num_special_tokens

    @property
    def pad_token_id(self) -> int:
        return self._special_ids.pad

    @property
    def unk_token_id(self) -> int:
        return self._special_ids.unk

    @property
    def bos_token_id(self) -> int:
        return self._special_ids.bos

    @property
    def eos_token_id(self) -> int:
        return self._special_ids.eos

    # ---------- core API ----------
    def encode_codon_seq(self, seq: str, validate: bool = True) -> List[int]:
        """Map a DNA string to codon ids, three bases per id.

        The input is upper-cased first. BOS/EOS are NOT added here.

        Args:
            seq: DNA sequence.
            validate: if True, reject sequences whose length is not a multiple
                of 3 or that contain characters other than A/C/G/T. If False,
                any chunk not in the vocab (including a trailing 1-2 base
                fragment) is mapped to the unk id.

        Raises:
            ValueError: when ``validate`` is True and the sequence is malformed.
        """
        s = seq.upper()
        if validate:
            if len(s) % 3 != 0:
                raise ValueError(f"Sequence length {len(s)} not divisible by 3")
            if not _is_acgt(s):
                raise ValueError("Sequence contains invalid nucleotides (only ACGT supported)")
        unk = self._special_ids.unk
        return [self.vocab.get(s[i : i + 3], unk) for i in range(0, len(s), 3)]

    def decode_codon_seq(self, token_ids: List[int]) -> str:
        """Convert codon ids back to a DNA string.

        Ids below ``num_special_tokens`` are dropped, as are ids absent from
        the vocab. Each id is coerced with ``int()`` before lookup so that
        int-like scalars whose hash differs from the plain int (e.g. 0-d
        torch tensors, which hash by identity) still resolve to their codon
        instead of being silently dropped.
        """
        nst = self._num_special_tokens
        lookup = self.ids_to_tokens
        parts: List[str] = []
        for raw_id in token_ids:
            tid = int(raw_id)
            if tid < nst:
                continue  # special token
            tok = lookup.get(tid)
            if tok is not None:  # should always be a codon
                parts.append(tok)
        return "".join(parts)

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **_: Any) -> str:
        """HF-style alias for :meth:`decode_codon_seq`.

        ``skip_special_tokens`` is accepted for API parity only: special ids
        are always dropped, so passing False does not emit special-token strings.
        """
        return self.decode_codon_seq(token_ids)

    # ---------- misc helpers ----------
    def codon_vocab(self) -> Dict[str, int]:
        """Codon -> id mapping restricted to the 64 codons."""
        return {c: self.vocab[c] for c in self.codons}

    def codon2aa_char_map(self) -> Dict[int, str]:
        """Copy of the codon-id -> one-letter amino acid map."""
        return dict(self._codon2aa_char)

    def aa2codons_char_map(self) -> Dict[str, List[int]]:
        """Copy of the amino acid -> list of codon ids map (lists copied too)."""
        return {k: v[:] for k, v in self._aa2codons_char.items()}

    def aa_to_codon_length(self, aa_seq: str) -> int:
        """One codon per character of ``aa_seq``; no implicit stop codon is added."""
        return len(aa_seq)

    # HF compatibility stubs
    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into 3-character chunks (no case normalization)."""
        if len(text) % 3 != 0:
            raise ValueError(f"Text length {len(text)} not divisible by 3")
        return [text[i : i + 3] for i in range(0, len(text), 3)]

    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self._special_ids.unk)

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, self._special_token_str["unk"])

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """Return ``token_ids_0`` unchanged; no specials are added and ``token_ids_1`` is ignored."""
        return token_ids_0

    def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """All-zero type ids, one per id in ``token_ids_0``; ``token_ids_1`` is ignored."""
        return [0] * len(token_ids_0)

    # ---------- persistence ----------
    def get_vocab(self) -> Dict[str, int]:
        """Shallow copy of the token -> id vocabulary."""
        return dict(self.vocab)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write ``[prefix-]vocab.json`` containing the vocab and special-token strings.

        Keys are sorted so the output is deterministic. Creates
        ``save_directory`` if needed. Returns a 1-tuple with the file path.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        payload = {
            "vocab": self.vocab,
            "special_token_str": self._special_token_str,
        }
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "CodonTokenizer":
        """Load a tokenizer from a directory containing ``vocab.json``.

        Accepts the payload written by :meth:`save_vocabulary`
        (``{"vocab": ..., "special_token_str": ...}``) as well as a legacy
        format where the whole file is the vocab dict. Special ids are
        re-derived from the saved special-token strings. If ``vocab.json`` is
        absent, a default ``cls(**kwargs)`` instance is returned.

        Raises:
            KeyError: if the loaded vocab lacks one of the 64 codons.
        """
        vocab_path = Path(pretrained_model_name_or_path) / "vocab.json"
        tok = cls(**kwargs)  # default structure; overwritten below
        if not vocab_path.exists():
            # Best-effort: nothing to load, keep the defaults.
            return tok
        with open(vocab_path, "r", encoding="utf-8") as f:
            save_data = json.load(f)
        if isinstance(save_data, dict) and "vocab" in save_data:
            vocab = save_data["vocab"]
            special_token_str = save_data.get("special_token_str", tok._special_token_str)
        else:
            # Legacy format: the whole file was the vocab dict.
            vocab = save_data
            special_token_str = tok._special_token_str

        # rebuild maps
        tok.vocab = {str(k): int(v) for k, v in vocab.items()}
        tok.ids_to_tokens = {v: k for k, v in tok.vocab.items()}

        # reconcile special strings -> ids (only the four known roles)
        if isinstance(special_token_str, dict):
            tok._special_token_str.update(
                {k: v for k, v in special_token_str.items() if k in ("pad", "unk", "bos", "eos")}
            )

        def _id_for(name: str, default_val: int) -> int:
            return int(tok.vocab.get(tok._special_token_str[name], default_val))

        tok._special_ids = SpecialIds(
            pad=_id_for("pad", 0),
            unk=_id_for("unk", 1),
            bos=_id_for("bos", 2),
            eos=_id_for("eos", 3),
        )
        # _num_special_tokens keeps its constructor value (4); the decode
        # methods treat every id below it as a special token.
        tok._rebuild_helpers()
        return tok

    # internal: (re)build the codon-id <-> amino-acid tables from vocab
    def _rebuild_helpers(self) -> None:
        """Populate ``_codon2aa_char`` and ``_aa2codons_char`` from ``vocab``.

        Codons missing from ``_genetic_code`` map to 'X' and are not listed
        under any amino acid. Raises KeyError if a codon is missing from vocab.
        """
        self._codon2aa_char = {}
        self._aa2codons_char = {ch: [] for ch in "ACDEFGHIKLMNPQRSTVWY*"}
        for codon in self.codons:
            cid = self.vocab[codon]
            aa = self._genetic_code.get(codon, "X")
            self._codon2aa_char[cid] = aa
            if aa in self._aa2codons_char:
                self._aa2codons_char[aa].append(cid)
# ------------------------------
# small helpers
# ------------------------------
def _is_acgt(s: str) -> bool:
# Faster than regex for short strings.
for ch in s:
if ch not in ("A", "C", "G", "T"):
return False
return True