# CodonTranslator — src/sampler.py
# (Hugging Face page header removed; provenance: commit af19adc by alegendaryfish,
#  "Rename internal model references to CodonTranslatorModel")
# src/sampler.py
"""
Sampling utilities for CodonTranslator.
Conditioning invariants:
- Species context: fixed-size [B, Ds] via species_emb or variable-length [B, Ls, Ds] via species_tok_emb
- Protein context: raw sequences; the model's Frozen ESM handles tokenization
"""
from __future__ import annotations
from typing import List, Optional, Dict, Union, Tuple
from pathlib import Path
import logging
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from safetensors.torch import load_file
from .models import CodonTranslatorModel
from .tokenizer import CodonTokenizer
logger = logging.getLogger(__name__)
# ----------------------------
# Logit filtering
# ----------------------------
def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor:
return logits if logits.dim() == 2 else logits.unsqueeze(0)
def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep only the k largest logits per row; everything else becomes -inf.

    Accepts logits shaped [B, V] or [V]; the output shape matches the input.
    """
    x = _ensure_2d_logits(logits)
    # Clamp k into the valid range [1, V].
    keep = max(1, min(int(k), x.size(-1)))
    # The k-th largest value in each row is the admission threshold.
    threshold = torch.topk(x, keep, dim=-1).values[:, -1].unsqueeze(-1)
    neg_inf = torch.full_like(x, float('-inf'))
    x = torch.where(x < threshold, neg_inf, x)
    return x.squeeze(0) if logits.dim() != 2 else x
def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering; accepts [B, V] or [V], shape-preserving.

    Keeps the smallest set of tokens whose cumulative probability reaches p
    (the token that crosses the threshold is retained); the rest become -inf.
    """
    if p >= 1.0:
        # Whole distribution admitted: nothing to filter.
        return logits
    if p <= 0.0:
        # No probability mass admitted: mask out every token.
        return torch.full_like(logits, float('-inf'))
    x = _ensure_2d_logits(logits)
    sorted_vals, sorted_idx = torch.sort(x, descending=True, dim=-1)
    cumulative = torch.cumsum(F.softmax(sorted_vals, dim=-1), dim=-1)
    drop_sorted = cumulative > p
    # Shift right by one so the first token past the threshold survives,
    # and the single most likely token is never removed.
    drop_sorted[:, 1:] = drop_sorted[:, :-1].clone()
    drop_sorted[:, 0] = False
    # Map the drop decisions back from sorted order to vocabulary order.
    drop = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_idx, drop_sorted)
    x = x.masked_fill(drop, float('-inf'))
    return x.squeeze(0) if logits.dim() != 2 else x
# ----------------------------
# Sampler
# ----------------------------
class CodonSampler:
    """
    Autoregressive GPT sampler with conditional generation for codon sequences.

    Expected files in ``model_path``:
      - ``vocab.json`` (falls back to the parent directory)
      - ``model.safetensors`` (preferred) or ``pytorch_model.bin`` (legacy)
      - ``trainer_config.json`` or ``config.json``

    Conditioning invariants (see module docstring): species context is either a
    fixed-size ``[B, Ds]`` embedding or a variable-length ``[B, Ls, Ds]`` token
    embedding; protein context is raw AA strings handled by the model's frozen ESM.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        species_store=None,  # SpeciesEmbeddingStore (optional)
        compile_model: bool = False,
        taxonomy_db_path: Optional[str] = None,
        qwen_max_length: int = 512,
        qwen_batch_size: int = 16,
        **_: dict,  # absorb unknown kwargs for forward compatibility
    ):
        self.device = torch.device(device)
        self.model_dir = Path(model_path)
        # Required files (allow fallback to parent dir for vocab.json).
        vocab_path = self.model_dir / "vocab.json"
        if not vocab_path.exists():
            parent_vocab = self.model_dir.parent / "vocab.json"
            if parent_vocab.exists():
                vocab_path = parent_vocab
            else:
                raise FileNotFoundError(f"Missing {self.model_dir / 'vocab.json'}")
        trainer_cfg = self.model_dir / "trainer_config.json"
        cfg_path = trainer_cfg if trainer_cfg.exists() else (self.model_dir / "config.json")
        if not cfg_path.exists():
            raise FileNotFoundError(f"Missing trainer_config.json or config.json in {self.model_dir}")
        # Load config.
        with open(cfg_path, "r") as f:
            self.config = json.load(f)
        # Tokenizer: if vocab.json was found in the parent dir, load from there.
        vocab_dir = vocab_path.parent
        self.tokenizer = CodonTokenizer.from_pretrained(str(vocab_dir))
        self.V = int(self.tokenizer.vocab_size)
        self._eos_id = int(self.tokenizer.eos_token_id)
        self._pad_id = int(self.tokenizer.pad_token_id)
        self._num_special = int(self.tokenizer.num_special_tokens)
        # Species store (optional when species_emb* tensors are passed directly to sample()).
        self.species_store = species_store
        self.species_vocab = (self.species_store.vocab if self.species_store is not None else {})
        self.taxonomy_db_path = taxonomy_db_path
        self.qwen_opts = {
            "max_length": int(qwen_max_length),
            "batch_size": int(qwen_batch_size),
        }
        # Lazily-initialized Qwen embedding objects (see _ensure_qwen_loaded).
        self._qwen_tokenizer = None
        self._qwen_model = None
        # Model: architecture is reconstructed from checkpoint tensor shapes + config.
        state = self._load_state_dict()
        arch = self._infer_arch_from_state_dict(state)
        self.model = CodonTranslatorModel(
            vocab_size=self.V,
            hidden_size=int(arch["hidden_size"]),
            num_layers=int(arch["num_layers"]),
            num_heads=int(arch["num_heads"]),
            mlp_ratio=float(arch["mlp_ratio"]),
            max_position_embeddings=int(arch["max_position_embeddings"]),
            dropout=float(self.config.get("dropout", 0.1)),
            num_special_tokens=self._num_special,
            special_ids=self.tokenizer.special_ids,
            esm_model_name=str(arch["esm_model_name"]) if bool(arch["prepend_protein"]) else None,
            esm_device=str(arch["esm_device"]),
            esm_dtype=str(arch["esm_dtype"]),
            max_protein_prefix=int(arch["max_protein_prefix"]) if bool(arch["prepend_protein"]) else 0,
            max_species_prefix=int(arch["max_species_prefix"]) if bool(arch["prepend_species"]) else 0,
            prepend_species=bool(arch["prepend_species"]),
            prepend_protein=bool(arch["prepend_protein"]),
            species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)),
            attn_impl=str(arch.get("attn_impl", "gqa")),
            num_kv_groups=int(arch.get("num_kv_groups", 0)),
        )
        # strict=False tolerates auxiliary/frozen submodule keys; surface them in logs.
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if len(unexpected) > 0:
            logger.warning(f"Unexpected keys in state dict: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}")
        if len(missing) > 0:
            logger.warning(f"Missing keys in state dict: {missing[:10]}{'...' if len(missing) > 10 else ''}")
        if compile_model:
            # Intentionally no try/except: a failing torch.compile should be visible.
            self.model = torch.compile(self.model)  # type: ignore
        self.model.to(self.device).eval()
        logger.info(f"Loaded GPT model from {self.model_dir}")
        try:
            hs = int(getattr(self.model, "hidden_size", -1))
            hh = int(getattr(self.model, "num_heads", -1))
            nl = int(getattr(self.model, "num_layers", -1))
            logger.info(f"Reconstructed arch: hidden={hs} heads={hh} layers={nl}")
        except Exception:
            pass
        # Static vocabulary masks for the two control modes.
        self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_fixed[:self._num_special] = False  # no specials in fixed mode
        self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_variable[:self._num_special] = False
        self._allowed_variable[self._eos_id] = True  # EOS allowed in variable mode

    # ----------------------------
    # Loading / arch inference
    # ----------------------------
    def _load_state_dict(self) -> Dict[str, torch.Tensor]:
        """Load checkpoint weights, preferring safetensors over the legacy .bin."""
        st_p = self.model_dir / "model.safetensors"
        pt_p = self.model_dir / "pytorch_model.bin"
        if st_p.exists():
            return load_file(st_p)
        if pt_p.exists():
            return torch.load(pt_p, map_location="cpu")
        raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}")

    def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Union[int, float, bool, str]]:
        """Reconstruct model hyperparameters from checkpoint tensor shapes.

        Config values take precedence where present; tensor shapes fill the gaps.
        """
        arch: Dict[str, Union[int, float, bool, str]] = {}
        # Hidden size: from lm_head [V, H], else the final layer norm [H].
        if "lm_head.weight" in state_dict:
            arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1])
        else:
            for k, v in state_dict.items():
                if k.endswith("ln_f.weight"):
                    arch["hidden_size"] = int(v.shape[0])
                    break
        # Prefer config when present to avoid guessing errors.
        cfg = self.config or {}
        if "hidden_size" in cfg:
            arch["hidden_size"] = int(cfg["hidden_size"])  # type: ignore[index]
        if "hidden_size" not in arch:
            arch["hidden_size"] = int(cfg.get("hidden_size", 960))
        H = int(arch["hidden_size"])
        # Layer count: highest blocks.<i> index seen in the checkpoint.
        max_block = -1
        for k in state_dict.keys():
            if k.startswith("blocks."):
                idx = int(k.split(".")[1])
                if idx > max_block:
                    max_block = idx
        arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12))
        if "num_hidden_layers" in cfg:
            arch["num_layers"] = int(cfg["num_hidden_layers"])  # type: ignore[index]
        # MLP ratio: FFN first-projection rows divided by hidden size.
        w1_key = "blocks.0.ffn.w1.weight" if "blocks.0.ffn.w1.weight" in state_dict else None
        if w1_key is None:
            for i in range(1, 3):
                k = f"blocks.{i}.ffn.w1.weight"
                if k in state_dict:
                    w1_key = k
                    break
        if w1_key is not None and H > 0:
            arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H)
        else:
            arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0))
        # Heads: take the config value if it divides H, else the largest candidate
        # divisor (h=1 always divides, so num_heads is always set).
        cfg_heads = cfg.get("num_attention_heads")
        if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0:
            arch["num_heads"] = int(cfg_heads)
        else:
            for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1):
                if H % h == 0:
                    arch["num_heads"] = h
                    break
        # Conditioning flags from presence of submodules in the checkpoint.
        arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys())))
        has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys())
        arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm)))
        arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m"))
        arch["esm_device"] = str(cfg.get("esm_device", "cuda"))
        arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower()
        arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0))
        arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0))
        if "max_length" in cfg:
            arch["max_position_embeddings"] = int(cfg.get("max_length", 1024))
        else:
            arch["max_position_embeddings"] = int(cfg.get("max_position_embeddings", 1024))
        # Attention impl and num_kv_groups (from config, else inferred from Wk shape).
        attn_impl = str(cfg.get("attn_impl", ""))
        num_kv_groups = int(cfg.get("num_kv_groups", 0))
        if not attn_impl:
            wk_key = next((k for k in state_dict.keys() if k.endswith("attn.Wk.weight")), None)
            if wk_key is not None:
                attn_impl = "gqa"
                out_ch, _ = state_dict[wk_key].shape
                num_heads = int(arch.get("num_heads", 1))
                head_dim = int(arch["hidden_size"]) // max(1, num_heads)
                if head_dim > 0:
                    num_kv_groups = max(1, out_ch // head_dim)
            else:
                attn_impl = "mha"
                num_kv_groups = 0
        arch["attn_impl"] = attn_impl
        arch["num_kv_groups"] = num_kv_groups
        return arch  # type: ignore[return-value]

    # ----------------------------
    # Public API
    # ----------------------------
    @torch.no_grad()
    def sample(
        self,
        num_sequences: int = 1,
        sequence_length: int = 100,  # target number of codons (fixed mode); max iterations (variable)
        species: Optional[Union[str, List[str]]] = None,
        protein_sequences: Optional[Union[str, List[str]]] = None,
        control_mode: str = "fixed",  # "fixed" or "variable"
        target_protein_length: Optional[int] = None,  # deprecated; alias to sequence_length
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
        return_intermediate: bool = False,
        progress_bar: bool = False,
        species_emb: Optional[torch.Tensor] = None,  # [B, Ds]
        species_tok_emb: Optional[torch.Tensor] = None,  # [B, Ls, Ds]
        enforce_translation: bool = False,
        codon_enforcement_weight: float = 10.0,  # unused with hard mask; kept for API compatibility
    ) -> Dict[str, Union[List[str], torch.Tensor, List[bool]]]:
        """Generate codon sequences autoregressively with optional conditioning.

        Returns a dict with:
          - "sequences": decoded DNA strings (specials/PAD stripped)
          - "input_ids": [B, Lmax] generated token ids, PAD-padded per row
          - "capacity_truncated": per-sample flag, True if the position budget ran out
          - "intermediate_states" (when return_intermediate): per-step input_ids snapshots
        """
        if seed is not None:
            torch.manual_seed(int(seed))
            np.random.seed(int(seed))
        if control_mode not in ("fixed", "variable"):
            raise ValueError(f"control_mode must be 'fixed' or 'variable', got {control_mode}")
        B = int(num_sequences)
        T_codons = int(sequence_length if target_protein_length is None else target_protein_length)
        # Prepare conditioning.
        cond: Dict[str, Union[str, List[str], torch.Tensor]] = {"control_mode": control_mode}
        # Species (priority: provided tensors → names via store/Qwen).
        if species_tok_emb is not None:
            if species_tok_emb.ndim != 3 or species_tok_emb.size(0) != B:
                raise ValueError("species_tok_emb must be [B, Ls, Ds]")
            st = species_tok_emb.to(self.device)
            cond["species_tok_emb_src"] = st
            cond["species_tok_emb_tgt"] = st
        elif species_emb is not None:
            if species_emb.ndim != 2 or species_emb.size(0) != B:
                raise ValueError("species_emb must be [B, Ds]")
            se = species_emb.to(self.device)
            cond["species_emb_src"] = se
            cond["species_emb_tgt"] = se
        elif species is not None:
            names = [species] * B if isinstance(species, str) else species
            if len(names) != B:
                raise ValueError("Length of species list must match num_sequences")
            # With a store: use it for known species, compute Qwen embeddings for unknowns.
            if self.species_store is not None:
                ids = [self.species_store.vocab.get(n, -1) for n in names]
                known_mask = [i for i, sid in enumerate(ids) if sid >= 0]
                unk_mask = [i for i, sid in enumerate(ids) if sid < 0]
                # Only variable-length embeddings are supported. If the store is not
                # sequence-based, compute everything via Qwen instead.
                use_sequence = bool(getattr(self.species_store, "is_legacy", False))
                if not use_sequence:
                    q_tok, q_len = self._qwen_embed_names(names, pooling="sequence")
                    cond["species_tok_emb_src"] = q_tok.to(self.device)
                    cond["species_tok_emb_tgt"] = q_tok.to(self.device)
                else:
                    # Per-sample [L, D] tensors, padded into a batch below.
                    seq_list: List[Optional[torch.Tensor]] = [None] * B
                    D = int(getattr(self.species_store, "_ds", 1024))
                    # Known species via the store.
                    if known_mask:
                        sub_ids = [ids[i] for i in known_mask]
                        result = self.species_store.batch_get(sub_ids)
                        assert isinstance(result, tuple)
                        sp_tok, _ = result
                        for j, i in enumerate(known_mask):
                            row = sp_tok[j]
                            # Trim trailing all-zero rows (store-side padding).
                            nonzero = (row.abs().sum(dim=-1) > 0)
                            L = int(nonzero.sum().item()) if nonzero.any() else int(row.size(0))
                            seq_list[i] = row[:L].to(self.device)
                    # Unknown species via Qwen.
                    if unk_mask:
                        unk_names = [names[i] for i in unk_mask]
                        q_tok, q_len = self._qwen_embed_names(unk_names, pooling="sequence")
                        for j, i in enumerate(unk_mask):
                            L = int(q_len[j].item())
                            seq_list[i] = q_tok[j, :L, :].to(self.device)
                    # Pad to [B, Lmax, D].
                    Lmax = max((t.size(0) for t in seq_list if t is not None), default=0)
                    if Lmax == 0:
                        raise RuntimeError("No species embeddings could be constructed.")
                    # Take the dtype from the first populated entry (index 0 is not
                    # guaranteed to be filled when the batch mixes known/unknown names).
                    ref_dtype = next(t for t in seq_list if t is not None).dtype
                    padded = torch.zeros(B, Lmax, D, device=self.device, dtype=ref_dtype)
                    for i, t in enumerate(seq_list):
                        if t is None:
                            continue
                        L = t.size(0)
                        padded[i, :L, :] = t
                    cond["species_tok_emb_src"] = padded
                    cond["species_tok_emb_tgt"] = padded
            else:
                # No store: compute everything via Qwen (sequence pooling only).
                emb, lengths = self._qwen_embed_names(names, pooling="sequence")
                st = emb.to(self.device, non_blocking=True)
                cond["species_tok_emb_src"] = st
                cond["species_tok_emb_tgt"] = st
        # Protein sequences (raw AA strings; the model handles ESM-C internally).
        if protein_sequences is not None:
            if isinstance(protein_sequences, list):
                if len(protein_sequences) != B:
                    raise ValueError("Length of protein_sequences must match num_sequences")
                cond["protein_seqs"] = protein_sequences
            else:
                cond["protein_seqs"] = [protein_sequences] * B
        # Start with an empty codon context; prefill builds the KV cache and
        # yields the first-step logits.
        input_ids = torch.empty((B, 0), dtype=torch.long, device=self.device)
        # Capacity probe + fallback: if the conditioning prefix consumes the whole
        # position budget, temporarily cap the species/protein prefix lengths.
        pref = None
        try:
            out0 = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            pref = out0.get("prefix_len") if isinstance(out0, dict) else None
            need_cap = False
            if pref is not None:
                max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
                remaining0 = max_pos - (pref + 1)
                need_cap = (remaining0 <= 0).any()
            if need_cap:
                prev_sp = int(getattr(self.model, "max_species_prefix", 0))
                prev_pp = int(getattr(self.model, "max_protein_prefix", 0))
                if prev_sp == 0 or prev_sp > 256:
                    setattr(self.model, "max_species_prefix", 256)
                if prev_pp == 0 or prev_pp > 256:
                    setattr(self.model, "max_protein_prefix", 256)
                out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
                pref = out0b.get("prefix_len") if isinstance(out0b, dict) else None
                if pref is not None:
                    remaining0b = max_pos - (pref + 1)
                    if (remaining0b <= 0).all():
                        setattr(self.model, "max_species_prefix", 128)
                        setattr(self.model, "max_protein_prefix", 128)
                        out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
                        pref = out0b.get("prefix_len") if isinstance(out0b, dict) else pref
            # Use the re-capped prefill when capping happened. (Bug fix: the
            # previous code selected `out0` on both ternary branches, discarding
            # `out0b` and its shorter prefix/KV cache.)
            out_prefill = out0b if need_cap else out0
        except Exception:
            # Fallback: single prefill without the capacity probe.
            out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            pref = out_prefill.get("prefix_len") if isinstance(out_prefill, dict) else None
        allowed = self._allowed_variable if control_mode == "variable" else self._allowed_fixed
        finished = torch.zeros(B, dtype=torch.bool, device=self.device)  # EOS reached (variable) OR capacity exhausted
        capacity_truncated = torch.zeros(B, dtype=torch.bool, device=self.device)
        intermediate = [] if return_intermediate else None
        aa2codons = self.tokenizer.aa2codons_char_map()
        # If capacity was probed, clamp the codon target by the remaining budget.
        try:
            if pref is not None:
                max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
                remaining = (max_pos - (pref + 1)).clamp(min=0)
                T_codons = int(min(T_codons, int(remaining.max().item())))
        except Exception:
            pass
        # KV cache and initial logits from prefill.
        kv = out_prefill.get("present_kv") if isinstance(out_prefill, dict) else None
        logits = out_prefill.get("next_logits") if isinstance(out_prefill, dict) else None
        if kv is None or logits is None:
            # Safety: recompute once if the prefill output lacked cache/logits.
            out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            kv = out_prefill.get("present_kv")
            logits = out_prefill.get("next_logits")
        assert kv is not None and logits is not None
        prefix_len = pref if pref is not None else torch.zeros(B, dtype=torch.long, device=self.device)
        prefill_len = (prefix_len + 1)  # prefix + start token
        rng = range(T_codons)
        if progress_bar:
            from tqdm import tqdm
            rng = tqdm(rng, desc="GPT sampling", total=T_codons)
        for step in rng:
            # Enforce the global position budget per sample.
            max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
            remaining_now = max_pos - prefill_len - input_ids.size(1)
            cant_extend = remaining_now <= 0
            newly_blocked = (~finished) & cant_extend
            capacity_truncated = capacity_truncated | newly_blocked
            finished = finished | cant_extend
            # Base mask: disallow specials in fixed mode, allow EOS in variable.
            logits = logits.masked_fill(~allowed, float("-inf"))
            # Finished samples (EOS or capacity) are forced to PAD so tensor
            # shapes stay stable; decoding drops PAD anyway.
            if finished.any():
                logits[finished] = float("-inf")
                logits[finished, self._pad_id] = 0.0
            # Optional hard constraint: only codons encoding this step's target AA.
            if enforce_translation and ("protein_seqs" in cond):
                aas_now: List[Optional[str]] = []
                prot_list = cond["protein_seqs"]  # type: ignore[index]
                assert isinstance(prot_list, list)
                for i in range(B):
                    seq = prot_list[i]
                    aas_now.append(seq[step] if step < len(seq) else None)
                mask = torch.zeros_like(logits, dtype=torch.bool)
                for i, a in enumerate(aas_now):
                    if a is None:
                        # Past the protein's end: any non-special codon is allowed.
                        mask[i, self._num_special:self.V] = True
                    else:
                        valid = aa2codons.get(a, [])
                        if len(valid) == 0:
                            mask[i, self._num_special:self.V] = True
                        else:
                            mask[i, valid] = True
                logits = logits.masked_fill(~mask, float("-inf"))
            # Temperature + top-k/top-p filtering, then sample one token per row.
            if temperature != 1.0:
                logits = logits / float(temperature)
            if top_k is not None:
                logits = _top_k_filtering(logits, int(top_k))
            if top_p is not None:
                logits = _top_p_filtering(logits, float(top_p))
            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)  # [B,1]
            if control_mode == "variable":
                # Stop sequences at EOS.
                eos_mask = (next_tok.squeeze(-1) == self._eos_id)
                finished = finished | eos_mask
            input_ids = torch.cat([input_ids, next_tok], dim=1)
            if return_intermediate:
                intermediate.append(input_ids.clone())
            # If all sequences are finished, we're done.
            if finished.all():
                break
            # Incremental decode: next-step logits and updated KV cache.
            pos_offset = int(prefill_len.max().item()) + input_ids.size(1) - 1  # max offset for shared RoPE cache
            out_inc = self.model(
                codon_ids=next_tok,
                cond=None,
                return_dict=True,
                use_cache=True,
                past_kv=kv,
                position_offset=pos_offset,
            )
            kv = out_inc.get("present_kv")
            logits = out_inc.get("next_logits")
            assert kv is not None and logits is not None
        # Build final DNA strings, dropping specials and any PADs we added.
        output_token_rows: List[List[int]] = []
        for row in input_ids.tolist():
            toks: List[int] = []
            for t in row:
                if t == self._pad_id:
                    continue
                if t == self._eos_id:
                    break  # variable-mode terminator
                if t >= self._num_special and t < self.V:
                    toks.append(int(t))
            if control_mode == "fixed":
                # In fixed mode we *intended* T_codons; capacity may cut us short.
                toks = toks[:T_codons]
            output_token_rows.append(toks)
        sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows]
        # Pad variable-length rows so input_ids is rectangular even when some
        # samples were capacity-truncated in fixed mode.
        max_len = max((len(r) for r in output_token_rows), default=0)
        if max_len > 0:
            ids_padded = torch.full(
                (len(output_token_rows), max_len),
                self._pad_id,
                device=self.device,
                dtype=torch.long,
            )
            for i, row in enumerate(output_token_rows):
                if len(row) > 0:
                    ids_padded[i, : len(row)] = torch.tensor(row, device=self.device, dtype=torch.long)
        else:
            ids_padded = torch.empty((len(output_token_rows), 0), device=self.device, dtype=torch.long)
        result: Dict[str, Union[List[str], torch.Tensor, List[bool]]] = {
            "sequences": sequences,
            "input_ids": ids_padded,
            "capacity_truncated": capacity_truncated.detach().bool().tolist(),
        }
        if return_intermediate:
            result["intermediate_states"] = intermediate  # list[Tensor], length = steps actually taken
        return result

    # ----------------------------
    # Qwen embedding (inline; no separate module)
    # ----------------------------
    def _ensure_qwen_loaded(self):
        """Lazily load the Qwen3 embedding tokenizer/model on first use."""
        if self._qwen_tokenizer is not None and self._qwen_model is not None:
            return
        from transformers import AutoTokenizer, AutoModel
        self._qwen_tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left"
        )
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self._qwen_model = AutoModel.from_pretrained(
            "Qwen/Qwen3-Embedding-0.6B", torch_dtype=dtype, trust_remote_code=True
        ).to(self.device).eval()

    @staticmethod
    def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool by taking each sequence's final non-padding hidden state."""
        # With left padding, the last column is a real token for every row.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    @staticmethod
    def _format_instruct(task: str, query: str) -> str:
        """Format a query in the Qwen embedding instruct template."""
        return f"Instruct: {task}\nQuery: {query}"

    @torch.no_grad()
    def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed species names with Qwen3-Embedding.

        Returns (padded [B, Lmax, D] float CPU tensor, lengths [B]).
        Only sequence (per-token) pooling is implemented; ``pooling`` is kept
        for API compatibility.
        """
        # Optional taxonomy lookup: map a bare species name to richer text.
        taxonomy_db = None
        if self.taxonomy_db_path:
            try:
                with open(self.taxonomy_db_path, "r") as f:
                    taxonomy_db = json.load(f)
            except Exception:
                taxonomy_db = None
        self._ensure_qwen_loaded()
        tokenizer = self._qwen_tokenizer
        model = self._qwen_model
        assert tokenizer is not None and model is not None
        task = (
            "Given a species taxonomy information, generate a biological embedding "
            "representing its taxonomic and evolutionary characteristics"
        )
        texts = [self._format_instruct(task, taxonomy_db.get(s, s) if taxonomy_db else s) for s in names]
        BATCH = int(self.qwen_opts.get("batch_size", 16))
        max_len = int(self.qwen_opts.get("max_length", 512))
        # Sequence pooling only: keep per-token states trimmed to real length.
        seqs: List[torch.Tensor] = []
        lens: List[int] = []
        for i in range(0, len(texts), BATCH):
            chunk = texts[i : i + BATCH]
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_len).to(self.device)
            out = model(**inputs)
            h = F.normalize(out.last_hidden_state, p=2, dim=-1)  # [B,L,D], L2-normalized
            attn = inputs["attention_mask"]
            for j in range(h.size(0)):
                L = int(attn[j].sum().item())
                seqs.append(h[j, :L, :].float().cpu())
                lens.append(L)
        # Pad to [B, Lmax, D].
        Lmax = max(lens) if lens else 0
        D = seqs[0].size(1) if seqs else 0
        padded = torch.zeros(len(seqs), Lmax, D)
        for i, t in enumerate(seqs):
            padded[i, : t.size(0), :] = t
        return padded, torch.tensor(lens, dtype=torch.long)

    # ----------------------------
    # Conditioning helper
    # ----------------------------
    # (Kept minimal. Species embeddings are prepared inline in sample().)
# ----------------------------
# Convenience function
# ----------------------------
def sample_sequences(
    model_path: str,
    num_sequences: int = 10,
    sequence_length: int = 100,
    species: Optional[Union[str, List[str]]] = None,
    protein_sequence: Optional[Union[str, List[str]]] = None,
    **kwargs
) -> List[str]:
    """One-shot convenience wrapper: build a CodonSampler and return decoded sequences.

    Extra keyword arguments are forwarded verbatim to CodonSampler.sample().
    """
    sampler = CodonSampler(model_path)
    result = sampler.sample(
        num_sequences=num_sequences,
        sequence_length=sequence_length,
        species=species,
        protein_sequences=protein_sequence,
        **kwargs,
    )
    return result["sequences"]  # type: ignore[return-value]