# src/sampler.py
"""
Sampling utilities for CodonTranslator.
Conditioning invariants:
- Species context: fixed-size [B, Ds] via species_emb or variable-length [B, Ls, Ds] via species_tok_emb
- Protein context: raw sequences; the model's Frozen ESM handles tokenization
"""
from __future__ import annotations
from typing import List, Optional, Dict, Union, Tuple
from pathlib import Path
import logging
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from safetensors.torch import load_file
from .models import CodonTranslatorModel
from .tokenizer import CodonTokenizer
logger = logging.getLogger(__name__)
# ----------------------------
# Logit filtering
# ----------------------------
def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor:
return logits if logits.dim() == 2 else logits.unsqueeze(0)
def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
"""Top-k filtering; logits is [B,V] or [V]."""
x = _ensure_2d_logits(logits)
k = max(1, min(int(k), x.size(-1)))
values, _ = torch.topk(x, k, dim=-1)
min_values = values[:, -1].unsqueeze(-1)
x = torch.where(x < min_values, torch.full_like(x, float('-inf')), x)
return x if logits.dim() == 2 else x.squeeze(0)
def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
"""Top-p (nucleus) filtering; logits is [B,V] or [V]."""
if p >= 1.0:
return logits
if p <= 0.0:
# You asked for nothing; enjoy the abyss.
return torch.full_like(logits, float('-inf'))
x = _ensure_2d_logits(logits)
sorted_logits, sorted_indices = torch.sort(x, descending=True, dim=-1)
probs = F.softmax(sorted_logits, dim=-1)
cumprobs = torch.cumsum(probs, dim=-1)
to_remove = cumprobs > p
to_remove[:, 1:] = to_remove[:, :-1].clone()
to_remove[:, 0] = False
mask = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_indices, to_remove)
x = torch.where(mask, torch.full_like(x, float('-inf')), x)
return x if logits.dim() == 2 else x.squeeze(0)
# ----------------------------
# Sampler
# ----------------------------
class CodonSampler:
    """
    GPT sampler with conditional generation.
    Requires in model_dir:
    - vocab.json
    - model.safetensors (preferred)
      or pytorch_model.bin (legacy)
    - trainer_config.json or config.json
    """
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        species_store=None,  # SpeciesEmbeddingStore
        compile_model: bool = False,
        taxonomy_db_path: Optional[str] = None,
        qwen_max_length: int = 512,
        qwen_batch_size: int = 16,
        **_: dict,
    ):
        """Load tokenizer, config, and weights from ``model_path`` and build the model.

        Args:
            model_path: Directory with vocab.json (or parent-dir fallback), weights,
                and trainer_config.json/config.json.
            device: Torch device string used for sampling.
            species_store: Optional SpeciesEmbeddingStore for species-name lookups.
            compile_model: Wrap the model in ``torch.compile`` when True.
            taxonomy_db_path: Optional JSON file mapping species name -> taxonomy text
                used to enrich Qwen embedding prompts.
            qwen_max_length: Tokenizer truncation length for the lazy Qwen embedder.
            qwen_batch_size: Batch size for Qwen embedding forward passes.

        Raises:
            FileNotFoundError: If vocab.json, config, or weight files are absent.
        """
        self.device = torch.device(device)
        self.model_dir = Path(model_path)
        # Required files (allow fallback to parent dir for vocab.json)
        vocab_path = self.model_dir / "vocab.json"
        if not vocab_path.exists():
            parent_vocab = self.model_dir.parent / "vocab.json"
            if parent_vocab.exists():
                vocab_path = parent_vocab
            else:
                raise FileNotFoundError(f"Missing {self.model_dir / 'vocab.json'}")
        trainer_cfg = self.model_dir / "trainer_config.json"
        cfg_path = trainer_cfg if trainer_cfg.exists() else (self.model_dir / "config.json")
        if not cfg_path.exists():
            raise FileNotFoundError(f"Missing trainer_config.json or config.json in {self.model_dir}")
        # Load config
        with open(cfg_path, "r") as f:
            self.config = json.load(f)
        # Tokenizer
        # If vocab was loaded from parent dir, pass that path; else model_dir
        vocab_dir = vocab_path.parent
        self.tokenizer = CodonTokenizer.from_pretrained(str(vocab_dir))
        self.V = int(self.tokenizer.vocab_size)
        self._eos_id = int(self.tokenizer.eos_token_id)
        self._pad_id = int(self.tokenizer.pad_token_id)
        self._num_special = int(self.tokenizer.num_special_tokens)
        # Species store (optional if you pass species_emb* directly at sample())
        self.species_store = species_store
        self.species_vocab = (self.species_store.vocab if self.species_store is not None else {})
        self.taxonomy_db_path = taxonomy_db_path
        self.qwen_opts = {
            "max_length": int(qwen_max_length),
            "batch_size": int(qwen_batch_size),
        }
        # Lazy-inited Qwen objects
        self._qwen_tokenizer = None
        self._qwen_model = None
        # Model: architecture is reconstructed from the checkpoint itself (see
        # _infer_arch_from_state_dict) so the sampler works without a full config.
        state = self._load_state_dict()
        arch = self._infer_arch_from_state_dict(state)
        self.model = CodonTranslatorModel(
            vocab_size=self.V,
            hidden_size=int(arch["hidden_size"]),
            num_layers=int(arch["num_layers"]),
            num_heads=int(arch["num_heads"]),
            mlp_ratio=float(arch["mlp_ratio"]),
            max_position_embeddings=int(arch["max_position_embeddings"]),
            dropout=float(self.config.get("dropout", 0.1)),
            num_special_tokens=self._num_special,
            special_ids=self.tokenizer.special_ids,
            esm_model_name=str(arch["esm_model_name"]) if bool(arch["prepend_protein"]) else None,
            esm_device=str(arch["esm_device"]),
            esm_dtype=str(arch["esm_dtype"]),
            max_protein_prefix=int(arch["max_protein_prefix"]) if bool(arch["prepend_protein"]) else 0,
            max_species_prefix=int(arch["max_species_prefix"]) if bool(arch["prepend_species"]) else 0,
            prepend_species=bool(arch["prepend_species"]),
            prepend_protein=bool(arch["prepend_protein"]),
            species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)),
            attn_impl=str(arch.get("attn_impl", "gqa")),
            num_kv_groups=int(arch.get("num_kv_groups", 0)),
        )
        # strict=False: ESM weights live in a separate frozen module and may be absent.
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if len(unexpected) > 0:
            logger.warning(f"Unexpected keys in state dict: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}")
        if len(missing) > 0:
            logger.warning(f"Missing keys in state dict: {missing[:10]}{'...' if len(missing) > 10 else ''}")
        if compile_model:
            # If this errors on your PyTorch build, that's on you. No try/except.
            self.model = torch.compile(self.model)  # type: ignore
        self.model.to(self.device).eval()
        logger.info(f"Loaded GPT model from {self.model_dir}")
        try:
            hs = int(getattr(self.model, "hidden_size", -1))
            hh = int(getattr(self.model, "num_heads", -1))
            nl = int(getattr(self.model, "num_layers", -1))
            logger.info(f"Reconstructed arch: hidden={hs} heads={hh} layers={nl}")
        except Exception:
            pass
        # Static masks over the vocabulary, reused every sampling step.
        self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_fixed[:self._num_special] = False  # no specials in fixed mode
        self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_variable[:self._num_special] = False
        self._allowed_variable[self._eos_id] = True  # EOS allowed in variable mode

    # ----------------------------
    # Loading / arch inference
    # ----------------------------
    def _load_state_dict(self) -> Dict[str, torch.Tensor]:
        """Load weights, preferring model.safetensors over pytorch_model.bin."""
        st_p = self.model_dir / "model.safetensors"
        pt_p = self.model_dir / "pytorch_model.bin"
        if st_p.exists():
            return load_file(st_p)
        if pt_p.exists():
            return torch.load(pt_p, map_location="cpu")
        raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}")

    def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Union[int, float, bool, str]]:
        """Reconstruct architecture hyperparameters from weight shapes/key names.

        Values present in self.config take precedence over inferred ones; inferred
        values fill the gaps so older checkpoints without a full config still load.
        """
        arch: Dict[str, Union[int, float, bool, str]] = {}
        # hidden size: lm_head is [V, H]; fall back to the final layer norm vector.
        if "lm_head.weight" in state_dict:
            arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1])
        else:
            for k, v in state_dict.items():
                if k.endswith("ln_f.weight"):
                    arch["hidden_size"] = int(v.shape[0])
                    break
        # Prefer config when present to avoid guessing errors
        cfg = self.config or {}
        if "hidden_size" in cfg:
            arch["hidden_size"] = int(cfg["hidden_size"])  # type: ignore[index]
        if "hidden_size" not in arch:
            arch["hidden_size"] = int(cfg.get("hidden_size", 960))
        H = int(arch["hidden_size"])
        # layers: count from the highest "blocks.<i>." index present.
        max_block = -1
        for k in state_dict.keys():
            if k.startswith("blocks."):
                idx = int(k.split(".")[1])
                if idx > max_block:
                    max_block = idx
        arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12))
        if "num_hidden_layers" in cfg:
            arch["num_layers"] = int(cfg["num_hidden_layers"])  # type: ignore[index]
        # mlp ratio from the first FFN up-projection (w1 is [ratio*H, H]).
        w1_key = "blocks.0.ffn.w1.weight" if "blocks.0.ffn.w1.weight" in state_dict else None
        if w1_key is None:
            for i in range(1, 3):
                k = f"blocks.{i}.ffn.w1.weight"
                if k in state_dict:
                    w1_key = k
                    break
        if w1_key is not None and H > 0:
            arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H)
        else:
            arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0))
        # heads – pick the largest candidate divisor of H when not configured.
        cfg_heads = cfg.get("num_attention_heads")
        if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0:
            arch["num_heads"] = int(cfg_heads)
        else:
            for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1):
                if H % h == 0:
                    arch["num_heads"] = h
                    break
        # conditioning flags from presence of submodules
        arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys())))
        has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys())
        arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm)))
        arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m"))
        arch["esm_device"] = str(cfg.get("esm_device", "cuda"))
        arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower()
        arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0))
        arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0))
        if "max_length" in cfg:
            arch["max_position_embeddings"] = int(cfg.get("max_length", 1024))
        else:
            arch["max_position_embeddings"] = int(cfg.get("max_position_embeddings", 1024))
        # Attention impl and num_kv_groups (from config or infer from weights)
        attn_impl = str(cfg.get("attn_impl", ""))
        num_kv_groups = int(cfg.get("num_kv_groups", 0))
        if not attn_impl:
            # A separate Wk projection indicates GQA; its output width gives the KV group count.
            wk_key = next((k for k in state_dict.keys() if k.endswith("attn.Wk.weight")), None)
            if wk_key is not None:
                attn_impl = "gqa"
                out_ch, _ = state_dict[wk_key].shape
                num_heads = int(arch.get("num_heads", 1))
                head_dim = int(arch["hidden_size"]) // max(1, num_heads)
                if head_dim > 0:
                    num_kv_groups = max(1, out_ch // head_dim)
            else:
                attn_impl = "mha"
                num_kv_groups = 0
        arch["attn_impl"] = attn_impl
        arch["num_kv_groups"] = num_kv_groups
        return arch  # type: ignore[return-value]

    # ----------------------------
    # Public API
    # ----------------------------
    @torch.no_grad()
    def sample(
        self,
        num_sequences: int = 1,
        sequence_length: int = 100,  # target number of codons (fixed mode); max iterations (variable)
        species: Optional[Union[str, List[str]]] = None,
        protein_sequences: Optional[Union[str, List[str]]] = None,
        control_mode: str = "fixed",  # "fixed" or "variable"
        target_protein_length: Optional[int] = None,  # deprecated; alias to sequence_length
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
        return_intermediate: bool = False,
        progress_bar: bool = False,
        species_emb: Optional[torch.Tensor] = None,  # [B, Ds]
        species_tok_emb: Optional[torch.Tensor] = None,  # [B, Ls, Ds]
        enforce_translation: bool = False,
        codon_enforcement_weight: float = 10.0,  # unused with hard mask; kept for API compatibility
    ) -> Dict[str, Union[List[str], torch.Tensor, List[bool]]]:
        """Autoregressively sample codon sequences with optional conditioning.

        Species conditioning priority: species_tok_emb > species_emb > species names
        (resolved via the store when possible, else embedded with Qwen). Protein
        conditioning passes raw AA strings through; the model's frozen ESM handles them.

        Returns:
            Dict with "sequences" (decoded DNA strings), "input_ids" (PAD-padded
            LongTensor of sampled codon ids), "capacity_truncated" (per-sample bools),
            and, when return_intermediate, "intermediate_states" (token tensor per step).

        Raises:
            ValueError: On bad control_mode or mismatched conditioning batch sizes.
        """
        if seed is not None:
            torch.manual_seed(int(seed))
            np.random.seed(int(seed))
        if control_mode not in ("fixed", "variable"):
            raise ValueError(f"control_mode must be 'fixed' or 'variable', got {control_mode}")
        B = int(num_sequences)
        T_codons = int(sequence_length if target_protein_length is None else target_protein_length)
        # Prepare conditioning
        cond: Dict[str, Union[str, List[str], torch.Tensor]] = {"control_mode": control_mode}
        # Species (priority: provided tensors → names via store)
        if species_tok_emb is not None:
            if species_tok_emb.ndim != 3 or species_tok_emb.size(0) != B:
                raise ValueError("species_tok_emb must be [B, Ls, Ds]")
            st = species_tok_emb.to(self.device)
            cond["species_tok_emb_src"] = st
            cond["species_tok_emb_tgt"] = st
        elif species_emb is not None:
            if species_emb.ndim != 2 or species_emb.size(0) != B:
                raise ValueError("species_emb must be [B, Ds]")
            se = species_emb.to(self.device)
            cond["species_emb_src"] = se
            cond["species_emb_tgt"] = se
        elif species is not None:
            names = [species] * B if isinstance(species, str) else species
            if len(names) != B:
                raise ValueError("Length of species list must match num_sequences")
            # If we have a store (variable-length), use it for known species and compute Qwen embeddings for unknowns.
            if self.species_store is not None:
                ids = [self.species_store.vocab.get(n, -1) for n in names]
                known_mask = [i for i, sid in enumerate(ids) if sid >= 0]
                unk_mask = [i for i, sid in enumerate(ids) if sid < 0]
                # Only variable-length embeddings are supported. If the store is not sequence-based, compute via Qwen for all.
                # NOTE(review): gating sequence mode on "is_legacy" looks inverted
                # relative to the attribute name — confirm the store's semantics.
                use_sequence = bool(getattr(self.species_store, "is_legacy", False))
                if not use_sequence:
                    # Fall back to Qwen for everything
                    q_tok, q_len = self._qwen_embed_names(names, pooling="sequence")
                    cond["species_tok_emb_src"] = q_tok.to(self.device)
                    cond["species_tok_emb_tgt"] = q_tok.to(self.device)
                else:
                    # list of per-sample [L,D] tensors to be padded later
                    seq_list: List[torch.Tensor] = [None] * B  # type: ignore[list-item]
                    D = int(getattr(self.species_store, "_ds", 1024))
                    # Known via store
                    if known_mask:
                        sub_ids = [ids[i] for i in known_mask]
                        result = self.species_store.batch_get(sub_ids)
                        assert isinstance(result, tuple)
                        sp_tok, _ = result
                        for j, i in enumerate(known_mask):
                            row = sp_tok[j]
                            # Infer true length from trailing all-zero rows (store pads with zeros).
                            nonzero = (row.abs().sum(dim=-1) > 0)
                            L = int(nonzero.sum().item()) if nonzero.any() else int(row.size(0))
                            seq_list[i] = row[:L].to(self.device)
                    # Unknown via Qwen
                    if unk_mask:
                        unk_names = [names[i] for i in unk_mask]
                        q_tok, q_len = self._qwen_embed_names(unk_names, pooling="sequence")
                        for j, i in enumerate(unk_mask):
                            L = int(q_len[j].item())
                            seq_list[i] = q_tok[j, :L, :].to(self.device)
                    # Pad to [B,Lmax,D]
                    Lmax = max((t.size(0) for t in seq_list if t is not None), default=0)
                    if Lmax == 0:
                        raise RuntimeError("No species embeddings could be constructed.")
                    padded = torch.zeros(B, Lmax, D, device=self.device, dtype=seq_list[0].dtype)
                    for i, t in enumerate(seq_list):
                        if t is None:
                            continue
                        L = t.size(0)
                        padded[i, :L, :] = t
                    cond["species_tok_emb_src"] = padded
                    cond["species_tok_emb_tgt"] = padded
            else:
                # No store: compute everything via Qwen (sequence pooling only)
                emb, _ = self._qwen_embed_names(names, pooling="sequence")
                st = emb.to(self.device, non_blocking=True)
                cond["species_tok_emb_src"] = st
                cond["species_tok_emb_tgt"] = st
        # Protein sequences (raw AA strings; the model handles ESM-C)
        if protein_sequences is not None:
            if isinstance(protein_sequences, list):
                if len(protein_sequences) != B:
                    raise ValueError("Length of protein_sequences must match num_sequences")
                cond["protein_seqs"] = protein_sequences
            else:
                cond["protein_seqs"] = [protein_sequences] * B
        # Start with empty codon context; we'll prefill to build KV cache and get first-step logits
        input_ids = torch.empty((B, 0), dtype=torch.long, device=self.device)
        # Capacity probe and fallback: if prefix consumes all budget, cap species/protein prefix temporarily (prefill path)
        pref = None
        out0b = None  # re-prefill output when the prefix had to be capped
        try:
            out0 = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            pref = out0.get("prefix_len") if isinstance(out0, dict) else None
            if pref is not None:
                max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
                remaining0 = max_pos - (pref + 1)
                need_cap = (remaining0 <= 0).any()
            else:
                need_cap = False
            if need_cap:
                prev_sp = int(getattr(self.model, "max_species_prefix", 0))
                prev_pp = int(getattr(self.model, "max_protein_prefix", 0))
                if prev_sp == 0 or prev_sp > 256:
                    setattr(self.model, "max_species_prefix", 256)
                if prev_pp == 0 or prev_pp > 256:
                    setattr(self.model, "max_protein_prefix", 256)
                out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
                pref = out0b.get("prefix_len") if isinstance(out0b, dict) else None
                if pref is not None:
                    remaining0b = max_pos - (pref + 1)
                    if (remaining0b <= 0).all():
                        # Still no room: halve the caps and re-prefill once more.
                        setattr(self.model, "max_species_prefix", 128)
                        setattr(self.model, "max_protein_prefix", 128)
                        out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
                        pref = out0b.get("prefix_len") if isinstance(out0b, dict) else pref
            # Bug fix: when the prefix was capped and re-prefilled, use that output
            # (its KV cache/logits match the updated pref). The previous expression
            # `out0 if pref is None else out0` always discarded the capped prefill.
            out_prefill = out0b if out0b is not None else out0
        except Exception:
            # Fallback without cache
            out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            pref = out_prefill.get("prefix_len") if isinstance(out_prefill, dict) else None
        allowed = self._allowed_variable if control_mode == "variable" else self._allowed_fixed
        finished = torch.zeros(B, dtype=torch.bool, device=self.device)  # EOS reached (variable) OR capacity exhausted
        capacity_truncated = torch.zeros(B, dtype=torch.bool, device=self.device)
        intermediate = [] if return_intermediate else None
        aa2codons = self.tokenizer.aa2codons_char_map()
        # If we probed capacity, optionally clamp target codons by available capacity at step 0
        try:
            if pref is not None:
                max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
                remaining = (max_pos - (pref + 1)).clamp(min=0)
                T_codons = int(min(T_codons, int(remaining.max().item())))
        except Exception:
            pass
        # KV cache and initial logits from prefill
        kv = out_prefill.get("present_kv") if isinstance(out_prefill, dict) else None
        logits = out_prefill.get("next_logits") if isinstance(out_prefill, dict) else None
        if kv is None or logits is None:
            # Safety: compute once if not provided
            out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            kv = out_prefill.get("present_kv")
            logits = out_prefill.get("next_logits")
        assert kv is not None and logits is not None
        prefix_len = pref if pref is not None else torch.zeros(B, dtype=torch.long, device=self.device)
        prefill_len = (prefix_len + 1)  # prefix + start
        rng = range(T_codons)
        if progress_bar:
            from tqdm import tqdm
            rng = tqdm(rng, desc="GPT sampling", total=T_codons)
        for step in rng:
            # Enforce global capacity per sample using prefix_len and current generated length
            max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
            # NOTE(review): clamp(max=10**9) does not floor negatives at 0; the
            # <= 0 test below is what actually detects exhaustion.
            remaining_now = (max_pos - prefill_len - input_ids.size(1)).clamp(max=10**9)
            cant_extend = remaining_now <= 0
            newly_blocked = (~finished) & cant_extend
            capacity_truncated = capacity_truncated | newly_blocked
            finished = finished | cant_extend
            # Base mask: disallow specials in fixed, allow EOS in variable.
            logits = logits.masked_fill(~allowed, float("-inf"))
            # If a sample is finished (EOS or capacity), force PAD to keep shapes stable.
            # Decoding will drop PAD anyway.
            if finished.any():
                logits[finished] = float("-inf")
                logits[finished, self._pad_id] = 0.0
            # Optional: enforce codon ↔ AA mapping at this step (hard mask)
            if enforce_translation and ("protein_seqs" in cond):
                aas_now: List[Optional[str]] = []
                prot_list = cond["protein_seqs"]  # type: ignore[index]
                assert isinstance(prot_list, list)
                for i in range(B):
                    seq = prot_list[i]
                    aas_now.append(seq[step] if step < len(seq) else None)
                mask = torch.zeros_like(logits, dtype=torch.bool)
                for i, a in enumerate(aas_now):
                    if a is None:
                        # Past the protein's end: allow any non-special codon.
                        mask[i, self._num_special:self.V] = True
                    else:
                        # assumes aa2codons maps an AA char to its codon token ids — TODO confirm
                        valid = aa2codons.get(a, [])
                        if len(valid) == 0:
                            mask[i, self._num_special:self.V] = True
                        else:
                            mask[i, valid] = True
                logits = logits.masked_fill(~mask, float("-inf"))
            # Temperature + filtering
            if temperature != 1.0:
                logits = logits / float(temperature)
            if top_k is not None:
                logits = _top_k_filtering(logits, int(top_k))
            if top_p is not None:
                logits = _top_p_filtering(logits, float(top_p))
            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)  # [B,1]
            if control_mode == "variable":
                # Stop sequences at EOS
                eos_mask = (next_tok.squeeze(-1) == self._eos_id)
                finished = finished | eos_mask
            input_ids = torch.cat([input_ids, next_tok], dim=1)
            if return_intermediate:
                intermediate.append(input_ids.clone())
            # If all sequences are finished, we're done.
            if finished.all():
                break
            # Incremental decode: compute logits for next step and update KV cache
            pos_offset = int(prefill_len.max().item()) + input_ids.size(1) - 1  # use max offset for shared RoPE cache
            out_inc = self.model(
                codon_ids=next_tok,
                cond=None,
                return_dict=True,
                use_cache=True,
                past_kv=kv,
                position_offset=pos_offset,
            )
            kv = out_inc.get("present_kv")
            logits = out_inc.get("next_logits")
            assert kv is not None and logits is not None
        # Build final DNA strings, dropping specials and any PADs we added
        output_token_rows: List[List[int]] = []
        for row in input_ids.tolist():
            toks: List[int] = []
            for t in row:
                if t == self._pad_id:
                    continue
                if t == self._eos_id:
                    break  # variable mode terminator
                if t >= self._num_special and t < self.V:
                    toks.append(int(t))
            if control_mode == "fixed":
                # In fixed mode we *intended* T_codons; if capacity cut us short, it's fine.
                toks = toks[:T_codons]
            output_token_rows.append(toks)
        sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows]
        # Pad variable-length rows for input_ids to avoid tensor construction errors when
        # some samples are capacity-truncated in fixed mode.
        max_len = max((len(r) for r in output_token_rows), default=0)
        if max_len > 0:
            ids_padded = torch.full(
                (len(output_token_rows), max_len),
                self._pad_id,
                device=self.device,
                dtype=torch.long,
            )
            for i, row in enumerate(output_token_rows):
                if len(row) > 0:
                    ids_padded[i, : len(row)] = torch.tensor(row, device=self.device, dtype=torch.long)
        else:
            ids_padded = torch.empty((len(output_token_rows), 0), device=self.device, dtype=torch.long)
        result: Dict[str, Union[List[str], torch.Tensor, List[bool]]] = {
            "sequences": sequences,
            "input_ids": ids_padded,
            "capacity_truncated": capacity_truncated.detach().bool().tolist(),
        }
        if return_intermediate:
            result["intermediate_states"] = intermediate  # list[Tensor], length = steps actually taken
        return result

    # ----------------------------
    # Qwen embedding (inline; no separate module)
    # ----------------------------
    def _ensure_qwen_loaded(self):
        """Lazily load the Qwen3 embedding tokenizer/model onto self.device (idempotent)."""
        if self._qwen_tokenizer is not None and self._qwen_model is not None:
            return
        from transformers import AutoTokenizer, AutoModel
        self._qwen_tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left"
        )
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self._qwen_model = AutoModel.from_pretrained(
            "Qwen/Qwen3-Embedding-0.6B", torch_dtype=dtype, trust_remote_code=True
        ).to(self.device).eval()

    @staticmethod
    def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool each sequence's last real token from [B,L,D] hidden states.

        With left padding the final column is always a real token, so take
        position -1; otherwise index each row at its own (mask sum - 1).
        """
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    @staticmethod
    def _format_instruct(task: str, query: str) -> str:
        """Format a query in the Qwen embedding 'Instruct:/Query:' prompt style."""
        return f"Instruct: {task}\nQuery: {query}"

    @torch.no_grad()
    def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed species names with Qwen.

        Returns a zero-padded [B, Lmax, D] tensor of L2-normalized per-token
        embeddings (on CPU, float32) and a [B] LongTensor of true lengths.
        Only pooling="sequence" is implemented.
        """
        # Load taxonomy DB if provided (best-effort; fall back to raw names)
        taxonomy_db = None
        if self.taxonomy_db_path:
            try:
                with open(self.taxonomy_db_path, "r") as f:
                    taxonomy_db = json.load(f)
            except Exception:
                taxonomy_db = None
        self._ensure_qwen_loaded()
        tokenizer = self._qwen_tokenizer
        model = self._qwen_model
        assert tokenizer is not None and model is not None
        task = (
            "Given a species taxonomy information, generate a biological embedding "
            "representing its taxonomic and evolutionary characteristics"
        )
        texts = [self._format_instruct(task, taxonomy_db.get(s, s) if taxonomy_db else s) for s in names]
        BATCH = int(self.qwen_opts.get("batch_size", 16))
        max_len = int(self.qwen_opts.get("max_length", 512))
        # sequence pooling only
        seqs: List[torch.Tensor] = []
        lens: List[int] = []
        for i in range(0, len(texts), BATCH):
            chunk = texts[i : i + BATCH]
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_len).to(self.device)
            out = model(**inputs)
            h = torch.nn.functional.normalize(out.last_hidden_state, p=2, dim=-1)  # [B,L,D]
            attn = inputs["attention_mask"]
            for j in range(h.size(0)):
                L = int(attn[j].sum().item())
                seqs.append(h[j, :L, :].float().cpu())
                lens.append(L)
        # Pad to [B,Lmax,D]
        Lmax = max(lens) if lens else 0
        D = seqs[0].size(1) if seqs else 0
        padded = torch.zeros(len(seqs), Lmax, D)
        for i, t in enumerate(seqs):
            padded[i, : t.size(0), :] = t
        return padded, torch.tensor(lens, dtype=torch.long)

    # ----------------------------
    # Conditioning helper
    # ----------------------------
    # (Kept minimal. Species embeddings are prepared inline in sample().)
# ----------------------------
# Convenience function
# ----------------------------
def sample_sequences(
    model_path: str,
    num_sequences: int = 10,
    sequence_length: int = 100,
    species: Optional[Union[str, List[str]]] = None,
    protein_sequence: Optional[Union[str, List[str]]] = None,
    **kwargs
) -> List[str]:
    """One-shot convenience wrapper: build a CodonSampler and return decoded DNA strings."""
    result = CodonSampler(model_path).sample(
        num_sequences=num_sequences,
        sequence_length=sequence_length,
        species=species,
        protein_sequences=protein_sequence,
        **kwargs,
    )
    return result["sequences"]  # type: ignore[return-value]
|