# src/sampler.py """ Sampling utilities for CodonTranslator. Conditioning invariants: - Species context: fixed-size [B, Ds] via species_emb or variable-length [B, Ls, Ds] via species_tok_emb - Protein context: raw sequences; the model's Frozen ESM handles tokenization """ from __future__ import annotations from typing import List, Optional, Dict, Union, Tuple from pathlib import Path import logging import json import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from safetensors.torch import load_file from .models import CodonTranslatorModel from .tokenizer import CodonTokenizer logger = logging.getLogger(__name__) # ---------------------------- # Logit filtering # ---------------------------- def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor: return logits if logits.dim() == 2 else logits.unsqueeze(0) def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor: """Top-k filtering; logits is [B,V] or [V].""" x = _ensure_2d_logits(logits) k = max(1, min(int(k), x.size(-1))) values, _ = torch.topk(x, k, dim=-1) min_values = values[:, -1].unsqueeze(-1) x = torch.where(x < min_values, torch.full_like(x, float('-inf')), x) return x if logits.dim() == 2 else x.squeeze(0) def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor: """Top-p (nucleus) filtering; logits is [B,V] or [V].""" if p >= 1.0: return logits if p <= 0.0: # You asked for nothing; enjoy the abyss. 
return torch.full_like(logits, float('-inf')) x = _ensure_2d_logits(logits) sorted_logits, sorted_indices = torch.sort(x, descending=True, dim=-1) probs = F.softmax(sorted_logits, dim=-1) cumprobs = torch.cumsum(probs, dim=-1) to_remove = cumprobs > p to_remove[:, 1:] = to_remove[:, :-1].clone() to_remove[:, 0] = False mask = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_indices, to_remove) x = torch.where(mask, torch.full_like(x, float('-inf')), x) return x if logits.dim() == 2 else x.squeeze(0) # ---------------------------- # Sampler # ---------------------------- class CodonSampler: """ GPT sampler with conditional generation. Requires in model_dir: - vocab.json - model.safetensors (preferred) or pytorch_model.bin (legacy) - trainer_config.json or config.json """ def __init__( self, model_path: str, device: str = "cuda", species_store=None, # SpeciesEmbeddingStore compile_model: bool = False, taxonomy_db_path: Optional[str] = None, qwen_max_length: int = 512, qwen_batch_size: int = 16, **_: dict, ): self.device = torch.device(device) self.model_dir = Path(model_path) # Required files (allow fallback to parent dir for vocab.json) vocab_path = self.model_dir / "vocab.json" if not vocab_path.exists(): parent_vocab = self.model_dir.parent / "vocab.json" if parent_vocab.exists(): vocab_path = parent_vocab else: raise FileNotFoundError(f"Missing {self.model_dir / 'vocab.json'}") trainer_cfg = self.model_dir / "trainer_config.json" cfg_path = trainer_cfg if trainer_cfg.exists() else (self.model_dir / "config.json") if not cfg_path.exists(): raise FileNotFoundError(f"Missing trainer_config.json or config.json in {self.model_dir}") # Load config with open(cfg_path, "r") as f: self.config = json.load(f) # Tokenizer # If vocab was loaded from parent dir, pass that path; else model_dir vocab_dir = vocab_path.parent self.tokenizer = CodonTokenizer.from_pretrained(str(vocab_dir)) self.V = int(self.tokenizer.vocab_size) self._eos_id = 
int(self.tokenizer.eos_token_id) self._pad_id = int(self.tokenizer.pad_token_id) self._num_special = int(self.tokenizer.num_special_tokens) # Species store (optional if you pass species_emb* directly at sample()) self.species_store = species_store self.species_vocab = (self.species_store.vocab if self.species_store is not None else {}) self.taxonomy_db_path = taxonomy_db_path self.qwen_opts = { "max_length": int(qwen_max_length), "batch_size": int(qwen_batch_size), } # Lazy-inited Qwen objects self._qwen_tokenizer = None self._qwen_model = None # Model state = self._load_state_dict() arch = self._infer_arch_from_state_dict(state) self.model = CodonTranslatorModel( vocab_size=self.V, hidden_size=int(arch["hidden_size"]), num_layers=int(arch["num_layers"]), num_heads=int(arch["num_heads"]), mlp_ratio=float(arch["mlp_ratio"]), max_position_embeddings=int(arch["max_position_embeddings"]), dropout=float(self.config.get("dropout", 0.1)), num_special_tokens=self._num_special, special_ids=self.tokenizer.special_ids, esm_model_name=str(arch["esm_model_name"]) if bool(arch["prepend_protein"]) else None, esm_device=str(arch["esm_device"]), esm_dtype=str(arch["esm_dtype"]), max_protein_prefix=int(arch["max_protein_prefix"]) if bool(arch["prepend_protein"]) else 0, max_species_prefix=int(arch["max_species_prefix"]) if bool(arch["prepend_species"]) else 0, prepend_species=bool(arch["prepend_species"]), prepend_protein=bool(arch["prepend_protein"]), species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)), attn_impl=str(arch.get("attn_impl", "gqa")), num_kv_groups=int(arch.get("num_kv_groups", 0)), ) missing, unexpected = self.model.load_state_dict(state, strict=False) if len(unexpected) > 0: logger.warning(f"Unexpected keys in state dict: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}") if len(missing) > 0: logger.warning(f"Missing keys in state dict: {missing[:10]}{'...' 
if len(missing) > 10 else ''}") if compile_model: # If this errors on your PyTorch build, that's on you. No try/except. self.model = torch.compile(self.model) # type: ignore self.model.to(self.device).eval() logger.info(f"Loaded GPT model from {self.model_dir}") try: hs = int(getattr(self.model, "hidden_size", -1)) hh = int(getattr(self.model, "num_heads", -1)) nl = int(getattr(self.model, "num_layers", -1)) logger.info(f"Reconstructed arch: hidden={hs} heads={hh} layers={nl}") except Exception: pass # Static masks self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device) self._allowed_fixed[:self._num_special] = False # no specials in fixed mode self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device) self._allowed_variable[:self._num_special] = False self._allowed_variable[self._eos_id] = True # EOS allowed in variable mode # ---------------------------- # Loading / arch inference # ---------------------------- def _load_state_dict(self) -> Dict[str, torch.Tensor]: st_p = self.model_dir / "model.safetensors" pt_p = self.model_dir / "pytorch_model.bin" if st_p.exists(): return load_file(st_p) if pt_p.exists(): return torch.load(pt_p, map_location="cpu") raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}") def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Union[int, float, bool, str]]: arch: Dict[str, Union[int, float, bool, str]] = {} # hidden size if "lm_head.weight" in state_dict: arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1]) else: for k, v in state_dict.items(): if k.endswith("ln_f.weight"): arch["hidden_size"] = int(v.shape[0]) break # Prefer config when present to avoid guessing errors cfg = self.config or {} if "hidden_size" in cfg: arch["hidden_size"] = int(cfg["hidden_size"]) # type: ignore[index] if "hidden_size" not in arch: arch["hidden_size"] = int(cfg.get("hidden_size", 960)) H = int(arch["hidden_size"]) # layers 
max_block = -1 for k in state_dict.keys(): if k.startswith("blocks."): idx = int(k.split(".")[1]) if idx > max_block: max_block = idx arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12)) if "num_hidden_layers" in cfg: arch["num_layers"] = int(cfg["num_hidden_layers"]) # type: ignore[index] # mlp ratio from w1 w1_key = "blocks.0.ffn.w1.weight" if "blocks.0.ffn.w1.weight" in state_dict else None if w1_key is None: for i in range(1, 3): k = f"blocks.{i}.ffn.w1.weight" if k in state_dict: w1_key = k break if w1_key is not None and H > 0: arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H) else: arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0)) # heads – pick a divisor of H cfg_heads = cfg.get("num_attention_heads") if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0: arch["num_heads"] = int(cfg_heads) else: for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1): if H % h == 0: arch["num_heads"] = h break # conditioning flags from presence of submodules arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys()))) has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys()) arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm))) arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m")) arch["esm_device"] = str(cfg.get("esm_device", "cuda")) arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower() arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0)) arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0)) if "max_length" in cfg: arch["max_position_embeddings"] = int(cfg.get("max_length", 1024)) else: arch["max_position_embeddings"] = int(cfg.get("max_position_embeddings", 1024)) # Attention impl and num_kv_groups (from config or infer from weights) attn_impl = str(cfg.get("attn_impl", "")) num_kv_groups = 
int(cfg.get("num_kv_groups", 0)) if not attn_impl: wk_key = next((k for k in state_dict.keys() if k.endswith("attn.Wk.weight")), None) if wk_key is not None: attn_impl = "gqa" out_ch, _ = state_dict[wk_key].shape num_heads = int(arch.get("num_heads", 1)) head_dim = int(arch["hidden_size"]) // max(1, num_heads) if head_dim > 0: num_kv_groups = max(1, out_ch // head_dim) else: attn_impl = "mha" num_kv_groups = 0 arch["attn_impl"] = attn_impl arch["num_kv_groups"] = num_kv_groups return arch # type: ignore[return-value] # ---------------------------- # Public API # ---------------------------- @torch.no_grad() def sample( self, num_sequences: int = 1, sequence_length: int = 100, # target number of codons (fixed mode); max iterations (variable) species: Optional[Union[str, List[str]]] = None, protein_sequences: Optional[Union[str, List[str]]] = None, control_mode: str = "fixed", # "fixed" or "variable" target_protein_length: Optional[int] = None, # deprecated; alias to sequence_length temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, seed: Optional[int] = None, return_intermediate: bool = False, progress_bar: bool = False, species_emb: Optional[torch.Tensor] = None, # [B, Ds] species_tok_emb: Optional[torch.Tensor] = None, # [B, Ls, Ds] enforce_translation: bool = False, codon_enforcement_weight: float = 10.0, # unused with hard mask; kept for API compatibility ) -> Dict[str, Union[List[str], torch.Tensor, List[bool]]]: if seed is not None: torch.manual_seed(int(seed)) np.random.seed(int(seed)) if control_mode not in ("fixed", "variable"): raise ValueError(f"control_mode must be 'fixed' or 'variable', got {control_mode}") B = int(num_sequences) T_codons = int(sequence_length if target_protein_length is None else target_protein_length) # Prepare conditioning cond: Dict[str, Union[str, List[str], torch.Tensor]] = {"control_mode": control_mode} # Species (priority: provided tensors → names via store) if species_tok_emb is not None: if 
species_tok_emb.ndim != 3 or species_tok_emb.size(0) != B: raise ValueError("species_tok_emb must be [B, Ls, Ds]") st = species_tok_emb.to(self.device) cond["species_tok_emb_src"] = st cond["species_tok_emb_tgt"] = st elif species_emb is not None: if species_emb.ndim != 2 or species_emb.size(0) != B: raise ValueError("species_emb must be [B, Ds]") se = species_emb.to(self.device) cond["species_emb_src"] = se cond["species_emb_tgt"] = se elif species is not None: names = [species] * B if isinstance(species, str) else species if len(names) != B: raise ValueError("Length of species list must match num_sequences") # If we have a store (variable-length), use it for known species and compute Qwen embeddings for unknowns. if self.species_store is not None: ids = [self.species_store.vocab.get(n, -1) for n in names] known_mask = [i for i, sid in enumerate(ids) if sid >= 0] unk_mask = [i for i, sid in enumerate(ids) if sid < 0] # Only variable-length embeddings are supported. If the store is not sequence-based, compute via Qwen for all. 
use_sequence = bool(getattr(self.species_store, "is_legacy", False)) if not use_sequence: # Fall back to Qwen for everything q_tok, q_len = self._qwen_embed_names(names, pooling="sequence") cond["species_tok_emb_src"] = q_tok.to(self.device) cond["species_tok_emb_tgt"] = q_tok.to(self.device) else: # list of per-sample [L,D] tensors to be padded later seq_list: List[torch.Tensor] = [None] * B # type: ignore[list-item] D = int(getattr(self.species_store, "_ds", 1024)) # Known via store if known_mask: sub_ids = [ids[i] for i in known_mask] result = self.species_store.batch_get(sub_ids) assert isinstance(result, tuple) sp_tok, _ = result for j, i in enumerate(known_mask): row = sp_tok[j] nonzero = (row.abs().sum(dim=-1) > 0) L = int(nonzero.sum().item()) if nonzero.any() else int(row.size(0)) seq_list[i] = row[:L].to(self.device) # Unknown via Qwen if unk_mask: unk_names = [names[i] for i in unk_mask] q_tok, q_len = self._qwen_embed_names(unk_names, pooling="sequence") for j, i in enumerate(unk_mask): L = int(q_len[j].item()) seq_list[i] = q_tok[j, :L, :].to(self.device) # Pad to [B,Lmax,D] Lmax = max((t.size(0) for t in seq_list if t is not None), default=0) if Lmax == 0: raise RuntimeError("No species embeddings could be constructed.") padded = torch.zeros(B, Lmax, D, device=self.device, dtype=seq_list[0].dtype) for i, t in enumerate(seq_list): if t is None: continue L = t.size(0) padded[i, :L, :] = t cond["species_tok_emb_src"] = padded cond["species_tok_emb_tgt"] = padded else: # No store: compute everything via Qwen (sequence pooling only) emb, lengths = self._qwen_embed_names(names, pooling="sequence") st = emb.to(self.device, non_blocking=True) cond["species_tok_emb_src"] = st cond["species_tok_emb_tgt"] = st # Protein sequences (raw AA strings; the model handles ESM-C) if protein_sequences is not None: if isinstance(protein_sequences, list): if len(protein_sequences) != B: raise ValueError("Length of protein_sequences must match num_sequences") 
cond["protein_seqs"] = protein_sequences else: cond["protein_seqs"] = [protein_sequences] * B # Start with empty codon context; we'll prefill to build KV cache and get first-step logits input_ids = torch.empty((B, 0), dtype=torch.long, device=self.device) # Capacity probe and fallback: if prefix consumes all budget, cap species/protein prefix temporarily (prefill path) pref = None try: out0 = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True) pref = out0.get("prefix_len") if isinstance(out0, dict) else None if pref is not None: max_pos = int(getattr(self.model, "max_position_embeddings", 1024)) remaining0 = max_pos - (pref + 1) need_cap = (remaining0 <= 0).any() else: need_cap = False if need_cap: prev_sp = int(getattr(self.model, "max_species_prefix", 0)) prev_pp = int(getattr(self.model, "max_protein_prefix", 0)) if prev_sp == 0 or prev_sp > 256: setattr(self.model, "max_species_prefix", 256) if prev_pp == 0 or prev_pp > 256: setattr(self.model, "max_protein_prefix", 256) out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True) pref = out0b.get("prefix_len") if isinstance(out0b, dict) else None if pref is not None: remaining0b = max_pos - (pref + 1) if (remaining0b <= 0).all(): setattr(self.model, "max_species_prefix", 128) setattr(self.model, "max_protein_prefix", 128) out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True) pref = out0b.get("prefix_len") if isinstance(out0b, dict) else pref # Use the prefill output out_prefill = out0 if pref is None else out0 except Exception: # Fallback without cache out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True) pref = out_prefill.get("prefix_len") if isinstance(out_prefill, dict) else None allowed = self._allowed_variable if control_mode == "variable" else self._allowed_fixed finished = torch.zeros(B, dtype=torch.bool, device=self.device) # EOS reached (variable) OR capacity exhausted 
capacity_truncated = torch.zeros(B, dtype=torch.bool, device=self.device) intermediate = [] if return_intermediate else None aa2codons = self.tokenizer.aa2codons_char_map() # If we probed capacity, optionally clamp target codons by available capacity at step 0 try: if pref is not None: max_pos = int(getattr(self.model, "max_position_embeddings", 1024)) remaining = (max_pos - (pref + 1)).clamp(min=0) T_codons = int(min(T_codons, int(remaining.max().item()))) except Exception: pass # KV cache and initial logits from prefill kv = out_prefill.get("present_kv") if isinstance(out_prefill, dict) else None logits = out_prefill.get("next_logits") if isinstance(out_prefill, dict) else None if kv is None or logits is None: # Safety: compute once if not provided out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True) kv = out_prefill.get("present_kv") logits = out_prefill.get("next_logits") assert kv is not None and logits is not None prefix_len = pref if pref is not None else torch.zeros(B, dtype=torch.long, device=self.device) prefill_len = (prefix_len + 1) # prefix + start rng = range(T_codons) if progress_bar: from tqdm import tqdm rng = tqdm(rng, desc="GPT sampling", total=T_codons) for step in rng: # Enforce global capacity per sample using prefix_len and current generated length max_pos = int(getattr(self.model, "max_position_embeddings", 1024)) remaining_now = (max_pos - prefill_len - input_ids.size(1)).clamp(max=10**9) cant_extend = remaining_now <= 0 newly_blocked = (~finished) & cant_extend capacity_truncated = capacity_truncated | newly_blocked finished = finished | cant_extend # Base mask: disallow specials in fixed, allow EOS in variable. logits = logits.masked_fill(~allowed, float("-inf")) # If a sample is finished (EOS or capacity), force PAD to keep shapes stable. # Decoding will drop PAD anyway. 
if finished.any(): logits[finished] = float("-inf") logits[finished, self._pad_id] = 0.0 # Optional: enforce codon ↔ AA mapping at this step (hard mask) if enforce_translation and ("protein_seqs" in cond): aas_now: List[Optional[str]] = [] prot_list = cond["protein_seqs"] # type: ignore[index] assert isinstance(prot_list, list) for i in range(B): seq = prot_list[i] aas_now.append(seq[step] if step < len(seq) else None) mask = torch.zeros_like(logits, dtype=torch.bool) for i, a in enumerate(aas_now): if a is None: mask[i, self._num_special:self.V] = True else: valid = aa2codons.get(a, []) if len(valid) == 0: mask[i, self._num_special:self.V] = True else: mask[i, valid] = True logits = logits.masked_fill(~mask, float("-inf")) # Temperature + filtering if temperature != 1.0: logits = logits / float(temperature) if top_k is not None: logits = _top_k_filtering(logits, int(top_k)) if top_p is not None: logits = _top_p_filtering(logits, float(top_p)) probs = F.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) # [B,1] if control_mode == "variable": # Stop sequences at EOS eos_mask = (next_tok.squeeze(-1) == self._eos_id) finished = finished | eos_mask input_ids = torch.cat([input_ids, next_tok], dim=1) if return_intermediate: intermediate.append(input_ids.clone()) # If all sequences are finished, we're done. 
if finished.all(): break # Incremental decode: compute logits for next step and update KV cache pos_offset = int(prefill_len.max().item()) + input_ids.size(1) - 1 # use max offset for shared RoPE cache out_inc = self.model( codon_ids=next_tok, cond=None, return_dict=True, use_cache=True, past_kv=kv, position_offset=pos_offset, ) kv = out_inc.get("present_kv") logits = out_inc.get("next_logits") assert kv is not None and logits is not None # Build final DNA strings, dropping specials and any PADs we added output_token_rows: List[List[int]] = [] for row in input_ids.tolist(): toks: List[int] = [] for t in row: if t == self._pad_id: continue if t == self._eos_id: break # variable mode terminator if t >= self._num_special and t < self.V: toks.append(int(t)) if control_mode == "fixed": # In fixed mode we *intended* T_codons; if capacity cut us short, it's fine. toks = toks[:T_codons] output_token_rows.append(toks) sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows] # Pad variable-length rows for input_ids to avoid tensor construction errors when # some samples are capacity-truncated in fixed mode. 
max_len = max((len(r) for r in output_token_rows), default=0) if max_len > 0: ids_padded = torch.full( (len(output_token_rows), max_len), self._pad_id, device=self.device, dtype=torch.long, ) for i, row in enumerate(output_token_rows): if len(row) > 0: ids_padded[i, : len(row)] = torch.tensor(row, device=self.device, dtype=torch.long) else: ids_padded = torch.empty((len(output_token_rows), 0), device=self.device, dtype=torch.long) result: Dict[str, Union[List[str], torch.Tensor, List[bool]]] = { "sequences": sequences, "input_ids": ids_padded, "capacity_truncated": capacity_truncated.detach().bool().tolist(), } if return_intermediate: result["intermediate_states"] = intermediate # list[Tensor], length = steps actually taken return result # ---------------------------- # Qwen embedding (inline; no separate module) # ---------------------------- def _ensure_qwen_loaded(self): if self._qwen_tokenizer is not None and self._qwen_model is not None: return from transformers import AutoTokenizer, AutoModel self._qwen_tokenizer = AutoTokenizer.from_pretrained( "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left" ) dtype = torch.float16 if self.device.type == "cuda" else torch.float32 self._qwen_model = AutoModel.from_pretrained( "Qwen/Qwen3-Embedding-0.6B", torch_dtype=dtype, trust_remote_code=True ).to(self.device).eval() @staticmethod def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) if left_padding: return last_hidden_states[:, -1] else: sequence_lengths = attention_mask.sum(dim=1) - 1 batch_size = last_hidden_states.shape[0] return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] @staticmethod def _format_instruct(task: str, query: str) -> str: return f"Instruct: {task}\nQuery: {query}" @torch.no_grad() def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> 
Tuple[torch.Tensor, Optional[torch.Tensor]]: # Load taxonomy DB if provided taxonomy_db = None if self.taxonomy_db_path: try: with open(self.taxonomy_db_path, "r") as f: import json taxonomy_db = json.load(f) except Exception: taxonomy_db = None self._ensure_qwen_loaded() tokenizer = self._qwen_tokenizer model = self._qwen_model assert tokenizer is not None and model is not None task = ( "Given a species taxonomy information, generate a biological embedding " "representing its taxonomic and evolutionary characteristics" ) texts = [self._format_instruct(task, taxonomy_db.get(s, s) if taxonomy_db else s) for s in names] BATCH = int(self.qwen_opts.get("batch_size", 16)) max_len = int(self.qwen_opts.get("max_length", 512)) # sequence pooling only seqs: List[torch.Tensor] = [] lens: List[int] = [] for i in range(0, len(texts), BATCH): chunk = texts[i : i + BATCH] inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_len).to(self.device) out = model(**inputs) h = torch.nn.functional.normalize(out.last_hidden_state, p=2, dim=-1) # [B,L,D] attn = inputs["attention_mask"] for j in range(h.size(0)): L = int(attn[j].sum().item()) seqs.append(h[j, :L, :].float().cpu()) lens.append(L) # Pad to [B,Lmax,D] Lmax = max(lens) if lens else 0 D = seqs[0].size(1) if seqs else 0 padded = torch.zeros(len(seqs), Lmax, D) for i, t in enumerate(seqs): padded[i, : t.size(0), :] = t return padded, torch.tensor(lens, dtype=torch.long) # ---------------------------- # Conditioning helper # ---------------------------- # (Kept minimal. Species embeddings are prepared inline in sample().) 
# ----------------------------
# Convenience function
# ----------------------------
def sample_sequences(
    model_path: str,
    num_sequences: int = 10,
    sequence_length: int = 100,
    species: Optional[Union[str, List[str]]] = None,
    protein_sequence: Optional[Union[str, List[str]]] = None,
    **kwargs
) -> List[str]:
    """One-shot helper: construct a CodonSampler and return decoded sequences.

    Extra keyword arguments are forwarded to CodonSampler.sample().
    """
    result = CodonSampler(model_path).sample(
        num_sequences=num_sequences,
        sequence_length=sequence_length,
        species=species,
        protein_sequences=protein_sequence,
        **kwargs,
    )
    return result["sequences"]  # type: ignore[return-value]