| |
| """ |
| Sampling utilities for CodonTranslator. |
| |
| Conditioning invariants: |
| - Species context: fixed-size [B, Ds] via species_emb or variable-length [B, Ls, Ds] via species_tok_emb |
| - Protein context: raw sequences; the model's Frozen ESM handles tokenization |
| """ |
|
|
| from __future__ import annotations |
| from typing import List, Optional, Dict, Union, Tuple |
| from pathlib import Path |
| import logging |
| import json |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from safetensors.torch import load_file |
|
|
| from .models import CodonTranslatorModel |
| from .tokenizer import CodonTokenizer |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
| |
| |
|
|
| def _ensure_2d_logits(logits: torch.Tensor) -> torch.Tensor: |
| return logits if logits.dim() == 2 else logits.unsqueeze(0) |
|
|
| def _top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor: |
| """Top-k filtering; logits is [B,V] or [V].""" |
| x = _ensure_2d_logits(logits) |
| k = max(1, min(int(k), x.size(-1))) |
| values, _ = torch.topk(x, k, dim=-1) |
| min_values = values[:, -1].unsqueeze(-1) |
| x = torch.where(x < min_values, torch.full_like(x, float('-inf')), x) |
| return x if logits.dim() == 2 else x.squeeze(0) |
|
|
| def _top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor: |
| """Top-p (nucleus) filtering; logits is [B,V] or [V].""" |
| if p >= 1.0: |
| return logits |
| if p <= 0.0: |
| |
| return torch.full_like(logits, float('-inf')) |
| x = _ensure_2d_logits(logits) |
| sorted_logits, sorted_indices = torch.sort(x, descending=True, dim=-1) |
| probs = F.softmax(sorted_logits, dim=-1) |
| cumprobs = torch.cumsum(probs, dim=-1) |
| to_remove = cumprobs > p |
| to_remove[:, 1:] = to_remove[:, :-1].clone() |
| to_remove[:, 0] = False |
| mask = torch.zeros_like(x, dtype=torch.bool).scatter(-1, sorted_indices, to_remove) |
| x = torch.where(mask, torch.full_like(x, float('-inf')), x) |
| return x if logits.dim() == 2 else x.squeeze(0) |
|
|
|
|
| |
| |
| |
|
|
class CodonSampler:
    """
    GPT sampler with conditional generation.

    Requires in model_dir:
    - vocab.json
    - model.safetensors (preferred)
      or pytorch_model.bin (legacy)
    - trainer_config.json or config.json

    Conditioning paths (see module docstring): species context may arrive as a
    fixed-size embedding [B, Ds], a variable-length token embedding [B, Ls, Ds],
    or a list of species names that are embedded on the fly (via a species
    store and/or a Qwen embedding model). Protein context is passed through as
    raw sequences for the model's frozen ESM encoder.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        species_store=None,
        compile_model: bool = False,
        taxonomy_db_path: Optional[str] = None,
        qwen_max_length: int = 512,
        qwen_batch_size: int = 16,
        **_: dict,
    ):
        """Load tokenizer, config, and weights from ``model_path`` and build the model.

        Args:
            model_path: Directory with vocab.json, weights, and a config JSON.
            device: Torch device string for the sampler and model.
            species_store: Optional store with a ``vocab`` (name -> id) and a
                ``batch_get`` for precomputed species embeddings.
            compile_model: If True, wrap the model in ``torch.compile``.
            taxonomy_db_path: Optional JSON file mapping species name ->
                taxonomy text used when embedding names with Qwen.
            qwen_max_length: Max token length for Qwen embedding inputs.
            qwen_batch_size: Batch size for Qwen embedding inference.
            **_: Ignored extra keyword arguments (forward compatibility).

        Raises:
            FileNotFoundError: If vocab or config/weight files are missing.
        """
        self.device = torch.device(device)
        self.model_dir = Path(model_path)

        # Locate vocab.json: prefer the model dir, fall back to its parent
        # (layout used when several checkpoints share one vocabulary).
        vocab_path = self.model_dir / "vocab.json"
        if not vocab_path.exists():
            parent_vocab = self.model_dir.parent / "vocab.json"
            if parent_vocab.exists():
                vocab_path = parent_vocab
            else:
                raise FileNotFoundError(f"Missing {self.model_dir / 'vocab.json'}")
        # trainer_config.json takes precedence over config.json.
        trainer_cfg = self.model_dir / "trainer_config.json"
        cfg_path = trainer_cfg if trainer_cfg.exists() else (self.model_dir / "config.json")
        if not cfg_path.exists():
            raise FileNotFoundError(f"Missing trainer_config.json or config.json in {self.model_dir}")

        with open(cfg_path, "r") as f:
            self.config = json.load(f)

        # Tokenizer and frequently used special-token ids.
        vocab_dir = vocab_path.parent
        self.tokenizer = CodonTokenizer.from_pretrained(str(vocab_dir))
        self.V = int(self.tokenizer.vocab_size)
        self._eos_id = int(self.tokenizer.eos_token_id)
        self._pad_id = int(self.tokenizer.pad_token_id)
        self._num_special = int(self.tokenizer.num_special_tokens)

        # Species conditioning resources.
        self.species_store = species_store
        self.species_vocab = (self.species_store.vocab if self.species_store is not None else {})
        self.taxonomy_db_path = taxonomy_db_path
        self.qwen_opts = {
            "max_length": int(qwen_max_length),
            "batch_size": int(qwen_batch_size),
        }
        # Qwen embedder is loaded lazily on first use (see _ensure_qwen_loaded).
        self._qwen_tokenizer = None
        self._qwen_model = None

        # Reconstruct the architecture from the checkpoint (shapes win over
        # config defaults where both are available) and load the weights.
        state = self._load_state_dict()
        arch = self._infer_arch_from_state_dict(state)
        self.model = CodonTranslatorModel(
            vocab_size=self.V,
            hidden_size=int(arch["hidden_size"]),
            num_layers=int(arch["num_layers"]),
            num_heads=int(arch["num_heads"]),
            mlp_ratio=float(arch["mlp_ratio"]),
            max_position_embeddings=int(arch["max_position_embeddings"]),
            dropout=float(self.config.get("dropout", 0.1)),
            num_special_tokens=self._num_special,
            special_ids=self.tokenizer.special_ids,
            esm_model_name=str(arch["esm_model_name"]) if bool(arch["prepend_protein"]) else None,
            esm_device=str(arch["esm_device"]),
            esm_dtype=str(arch["esm_dtype"]),
            max_protein_prefix=int(arch["max_protein_prefix"]) if bool(arch["prepend_protein"]) else 0,
            max_species_prefix=int(arch["max_species_prefix"]) if bool(arch["prepend_species"]) else 0,
            prepend_species=bool(arch["prepend_species"]),
            prepend_protein=bool(arch["prepend_protein"]),
            species_embedding_dim=int(self.config.get("species_embedding_dim", 1024)),
            attn_impl=str(arch.get("attn_impl", "gqa")),
            num_kv_groups=int(arch.get("num_kv_groups", 0)),
        )
        # strict=False: tolerate checkpoints with extra/missing buffers, but
        # surface what was skipped so silent mismatches are visible in logs.
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if len(unexpected) > 0:
            logger.warning(f"Unexpected keys in state dict: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}")
        if len(missing) > 0:
            logger.warning(f"Missing keys in state dict: {missing[:10]}{'...' if len(missing) > 10 else ''}")

        if compile_model:
            self.model = torch.compile(self.model)

        self.model.to(self.device).eval()
        logger.info(f"Loaded GPT model from {self.model_dir}")
        try:
            hs = int(getattr(self.model, "hidden_size", -1))
            hh = int(getattr(self.model, "num_heads", -1))
            nl = int(getattr(self.model, "num_layers", -1))
            logger.info(f"Reconstructed arch: hidden={hs} heads={hh} layers={nl}")
        except Exception:
            # Best-effort logging only; never fail construction over it.
            pass

        # Token masks applied during sampling.
        # "fixed" mode: only real codon tokens are allowed (no specials).
        self._allowed_fixed = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_fixed[:self._num_special] = False

        # "variable" mode: codons plus EOS, so the model can stop early.
        self._allowed_variable = torch.ones(self.V, dtype=torch.bool, device=self.device)
        self._allowed_variable[:self._num_special] = False
        self._allowed_variable[self._eos_id] = True

    def _load_state_dict(self) -> Dict[str, torch.Tensor]:
        """Load weights, preferring safetensors over the legacy pickle format."""
        st_p = self.model_dir / "model.safetensors"
        pt_p = self.model_dir / "pytorch_model.bin"
        if st_p.exists():
            return load_file(st_p)
        if pt_p.exists():
            return torch.load(pt_p, map_location="cpu")
        raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {self.model_dir}")

    def _infer_arch_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Union[int, float, bool, str]]:
        """Reconstruct model hyperparameters from checkpoint tensor shapes.

        Config values override shape-derived values where present; shape-derived
        values fill gaps when the config is missing a field. Returns a dict of
        constructor kwargs for CodonTranslatorModel.
        """
        arch: Dict[str, Union[int, float, bool, str]] = {}

        # Hidden size: lm_head is [V, H]; fall back to final layer-norm weight.
        if "lm_head.weight" in state_dict:
            arch["hidden_size"] = int(state_dict["lm_head.weight"].shape[1])
        else:
            for k, v in state_dict.items():
                if k.endswith("ln_f.weight"):
                    arch["hidden_size"] = int(v.shape[0])
                    break
        # Explicit config value wins; last resort is a hard-coded default.
        cfg = self.config or {}
        if "hidden_size" in cfg:
            arch["hidden_size"] = int(cfg["hidden_size"])
        if "hidden_size" not in arch:
            arch["hidden_size"] = int(cfg.get("hidden_size", 960))
        H = int(arch["hidden_size"])

        # Layer count: highest "blocks.<i>." index seen, plus one.
        max_block = -1
        for k in state_dict.keys():
            if k.startswith("blocks."):
                idx = int(k.split(".")[1])
                if idx > max_block:
                    max_block = idx
        arch["num_layers"] = (max_block + 1) if max_block >= 0 else int(cfg.get("num_hidden_layers", 12))
        if "num_hidden_layers" in cfg:
            arch["num_layers"] = int(cfg["num_hidden_layers"])

        # MLP ratio from the first available FFN up-projection: w1 is [H*ratio, H].
        w1_key = "blocks.0.ffn.w1.weight" if "blocks.0.ffn.w1.weight" in state_dict else None
        if w1_key is None:
            for i in range(1, 3):
                k = f"blocks.{i}.ffn.w1.weight"
                if k in state_dict:
                    w1_key = k
                    break
        if w1_key is not None and H > 0:
            arch["mlp_ratio"] = float(int(state_dict[w1_key].shape[0]) / H)
        else:
            arch["mlp_ratio"] = float(cfg.get("mlp_ratio", 4.0))

        # Head count: trust the config when it divides H; otherwise pick the
        # largest "typical" head count that divides H.
        cfg_heads = cfg.get("num_attention_heads")
        if isinstance(cfg_heads, int) and cfg_heads > 0 and H % cfg_heads == 0:
            arch["num_heads"] = int(cfg_heads)
        else:
            for h in (16, 15, 12, 10, 8, 6, 5, 4, 3, 2, 1):
                if H % h == 0:
                    arch["num_heads"] = h
                    break

        # Conditioning flags: presence of the species/ESM layer norms in the
        # checkpoint implies the corresponding prefix path was trained.
        arch["prepend_species"] = bool(cfg.get("prepend_species", any(k.startswith("species_ln.") for k in state_dict.keys())))
        has_esm = any(k.startswith("esm_ln.") for k in state_dict.keys()) or any(k.startswith("esm.") for k in state_dict.keys())
        arch["prepend_protein"] = bool(cfg.get("prepend_protein", bool(has_esm)))
        arch["esm_model_name"] = str(cfg.get("esm_model_name", "esmc_300m"))
        arch["esm_device"] = str(cfg.get("esm_device", "cuda"))
        arch["esm_dtype"] = str(cfg.get("esm_dtype", "bf16")).lower()
        arch["max_protein_prefix"] = int(cfg.get("max_protein_prefix", 0))
        arch["max_species_prefix"] = int(cfg.get("max_species_prefix", 0))

        # Context window: trainer configs call it "max_length".
        if "max_length" in cfg:
            arch["max_position_embeddings"] = int(cfg.get("max_length", 1024))
        else:
            arch["max_position_embeddings"] = int(cfg.get("max_position_embeddings", 1024))

        # Attention implementation: a separate Wk projection indicates GQA;
        # the KV group count is recovered from Wk's output width / head_dim.
        attn_impl = str(cfg.get("attn_impl", ""))
        num_kv_groups = int(cfg.get("num_kv_groups", 0))
        if not attn_impl:
            wk_key = next((k for k in state_dict.keys() if k.endswith("attn.Wk.weight")), None)
            if wk_key is not None:
                attn_impl = "gqa"
                out_ch, _ = state_dict[wk_key].shape
                num_heads = int(arch.get("num_heads", 1))
                head_dim = int(arch["hidden_size"]) // max(1, num_heads)
                if head_dim > 0:
                    num_kv_groups = max(1, out_ch // head_dim)
            else:
                attn_impl = "mha"
                num_kv_groups = 0
        arch["attn_impl"] = attn_impl
        arch["num_kv_groups"] = num_kv_groups

        return arch

    @torch.no_grad()
    def sample(
        self,
        num_sequences: int = 1,
        sequence_length: int = 100,
        species: Optional[Union[str, List[str]]] = None,
        protein_sequences: Optional[Union[str, List[str]]] = None,
        control_mode: str = "fixed",
        target_protein_length: Optional[int] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
        return_intermediate: bool = False,
        progress_bar: bool = False,
        species_emb: Optional[torch.Tensor] = None,
        species_tok_emb: Optional[torch.Tensor] = None,
        enforce_translation: bool = False,
        codon_enforcement_weight: float = 10.0,
    ) -> Dict[str, Union[List[str], torch.Tensor, List[bool]]]:
        """Autoregressively sample codon sequences, optionally conditioned.

        Args:
            num_sequences: Batch size B.
            sequence_length: Number of codons to generate (ignored if
                ``target_protein_length`` is given).
            species: Species name(s); embedded via the species store / Qwen.
            protein_sequences: Raw protein sequence(s) for ESM conditioning.
            control_mode: "fixed" (exact length, no EOS) or "variable"
                (EOS allowed, sequences may stop early).
            target_protein_length: Overrides ``sequence_length`` when set.
            temperature / top_k / top_p: Standard sampling knobs.
            seed: If set, seeds torch and numpy for reproducibility.
            return_intermediate: Also return token ids after every step.
            progress_bar: Show a tqdm bar over decoding steps.
            species_emb: Precomputed fixed-size species embedding [B, Ds];
                takes effect only if ``species_tok_emb`` is absent.
            species_tok_emb: Precomputed variable-length species embedding
                [B, Ls, Ds]; highest-priority species conditioning.
            enforce_translation: Restrict each step to codons encoding the
                corresponding amino acid of the given protein.
            codon_enforcement_weight: Unused here (hard masking is applied
                instead of soft weighting). NOTE(review): confirm whether this
                parameter is intentionally ignored.

        Returns:
            Dict with "sequences" (decoded strings), "input_ids" (padded
            LongTensor of sampled codon tokens), "capacity_truncated"
            (per-sequence bool list), and optionally "intermediate_states".
        """
        if seed is not None:
            torch.manual_seed(int(seed))
            np.random.seed(int(seed))

        if control_mode not in ("fixed", "variable"):
            raise ValueError(f"control_mode must be 'fixed' or 'variable', got {control_mode}")

        B = int(num_sequences)
        T_codons = int(sequence_length if target_protein_length is None else target_protein_length)

        # Conditioning payload forwarded to the model on the prefill call.
        cond: Dict[str, Union[str, List[str], torch.Tensor]] = {"control_mode": control_mode}

        # Species conditioning priority: tok_emb > emb > names.
        if species_tok_emb is not None:
            if species_tok_emb.ndim != 3 or species_tok_emb.size(0) != B:
                raise ValueError("species_tok_emb must be [B, Ls, Ds]")
            st = species_tok_emb.to(self.device)
            cond["species_tok_emb_src"] = st
            cond["species_tok_emb_tgt"] = st
        elif species_emb is not None:
            if species_emb.ndim != 2 or species_emb.size(0) != B:
                raise ValueError("species_emb must be [B, Ds]")
            se = species_emb.to(self.device)
            cond["species_emb_src"] = se
            cond["species_emb_tgt"] = se
        elif species is not None:
            names = [species] * B if isinstance(species, str) else species
            if len(names) != B:
                raise ValueError("Length of species list must match num_sequences")

            if self.species_store is not None:
                # Split the batch into names known to the store vs unknown.
                ids = [self.species_store.vocab.get(n, -1) for n in names]
                known_mask = [i for i, sid in enumerate(ids) if sid >= 0]
                unk_mask = [i for i, sid in enumerate(ids) if sid < 0]

                # Non-legacy stores: ignore precomputed vectors and embed all
                # names with Qwen (token-level embeddings).
                use_sequence = bool(getattr(self.species_store, "is_legacy", False))
                if not use_sequence:
                    q_tok, q_len = self._qwen_embed_names(names, pooling="sequence")
                    cond["species_tok_emb_src"] = q_tok.to(self.device)
                    cond["species_tok_emb_tgt"] = q_tok.to(self.device)
                else:
                    # Legacy path: pull known species from the store, embed the
                    # rest with Qwen, then right-pad to a common length.
                    seq_list: List[torch.Tensor] = [None] * B
                    D = int(getattr(self.species_store, "_ds", 1024))

                    if known_mask:
                        sub_ids = [ids[i] for i in known_mask]
                        result = self.species_store.batch_get(sub_ids)
                        assert isinstance(result, tuple)
                        sp_tok, _ = result
                        for j, i in enumerate(known_mask):
                            row = sp_tok[j]
                            # Store rows are zero-padded; infer the true length
                            # from the first all-zero suffix.
                            nonzero = (row.abs().sum(dim=-1) > 0)
                            L = int(nonzero.sum().item()) if nonzero.any() else int(row.size(0))
                            seq_list[i] = row[:L].to(self.device)

                    if unk_mask:
                        unk_names = [names[i] for i in unk_mask]
                        q_tok, q_len = self._qwen_embed_names(unk_names, pooling="sequence")
                        for j, i in enumerate(unk_mask):
                            L = int(q_len[j].item())
                            seq_list[i] = q_tok[j, :L, :].to(self.device)

                    Lmax = max((t.size(0) for t in seq_list if t is not None), default=0)
                    if Lmax == 0:
                        raise RuntimeError("No species embeddings could be constructed.")
                    # NOTE(review): dtype is taken from seq_list[0]; if index 0
                    # happened to remain None while others are set, this would
                    # raise — confirm every row is populated on this path.
                    padded = torch.zeros(B, Lmax, D, device=self.device, dtype=seq_list[0].dtype)
                    for i, t in enumerate(seq_list):
                        if t is None:
                            continue
                        L = t.size(0)
                        padded[i, :L, :] = t
                    cond["species_tok_emb_src"] = padded
                    cond["species_tok_emb_tgt"] = padded
            else:
                # No store at all: embed every name with Qwen.
                emb, lengths = self._qwen_embed_names(names, pooling="sequence")
                st = emb.to(self.device, non_blocking=True)
                cond["species_tok_emb_src"] = st
                cond["species_tok_emb_tgt"] = st

        # Protein conditioning: raw strings, tokenized by the model's ESM.
        if protein_sequences is not None:
            if isinstance(protein_sequences, list):
                if len(protein_sequences) != B:
                    raise ValueError("Length of protein_sequences must match num_sequences")
                cond["protein_seqs"] = protein_sequences
            else:
                cond["protein_seqs"] = [protein_sequences] * B

        # Start with an empty codon sequence; the first model call builds the
        # conditioning prefix and the KV cache ("prefill").
        input_ids = torch.empty((B, 0), dtype=torch.long, device=self.device)

        # Prefill, with a retry that shrinks the model's prefix caps when the
        # conditioning prefix alone would exhaust the position budget.
        pref = None
        try:
            out0 = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            pref = out0.get("prefix_len") if isinstance(out0, dict) else None
            if pref is not None:
                max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
                remaining0 = max_pos - (pref + 1)
                need_cap = (remaining0 <= 0).any()
            else:
                need_cap = False
            if need_cap:
                # Cap species/protein prefixes at 256 (then 128) and re-prefill.
                prev_sp = int(getattr(self.model, "max_species_prefix", 0))
                prev_pp = int(getattr(self.model, "max_protein_prefix", 0))
                if prev_sp == 0 or prev_sp > 256:
                    setattr(self.model, "max_species_prefix", 256)
                if prev_pp == 0 or prev_pp > 256:
                    setattr(self.model, "max_protein_prefix", 256)
                out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
                pref = out0b.get("prefix_len") if isinstance(out0b, dict) else None
                if pref is not None:
                    remaining0b = max_pos - (pref + 1)
                    if (remaining0b <= 0).all():
                        setattr(self.model, "max_species_prefix", 128)
                        setattr(self.model, "max_protein_prefix", 128)
                        out0b = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
                        pref = out0b.get("prefix_len") if isinstance(out0b, dict) else pref
            # NOTE(review): both branches of this conditional select out0, so
            # the re-prefilled out0b (capped prefix, fresh KV cache) is
            # discarded — the capped case likely intended `out0b` here. Confirm.
            out_prefill = out0 if pref is None else out0
        except Exception:
            # Any failure in the capacity logic: fall back to a plain prefill.
            out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            pref = out_prefill.get("prefix_len") if isinstance(out_prefill, dict) else None

        allowed = self._allowed_variable if control_mode == "variable" else self._allowed_fixed
        finished = torch.zeros(B, dtype=torch.bool, device=self.device)
        capacity_truncated = torch.zeros(B, dtype=torch.bool, device=self.device)

        intermediate = [] if return_intermediate else None
        # Amino-acid character -> list of codon token ids (for enforce_translation).
        aa2codons = self.tokenizer.aa2codons_char_map()

        # Cap the number of decoding steps to the remaining position budget.
        try:
            if pref is not None:
                max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
                remaining = (max_pos - (pref + 1)).clamp(min=0)
                T_codons = int(min(T_codons, int(remaining.max().item())))
        except Exception:
            pass

        # Recover KV cache and first-step logits from prefill; re-run once if
        # the prefill output lacks either (e.g. a fallback path above).
        kv = out_prefill.get("present_kv") if isinstance(out_prefill, dict) else None
        logits = out_prefill.get("next_logits") if isinstance(out_prefill, dict) else None
        if kv is None or logits is None:
            out_prefill = self.model(codon_ids=input_ids, cond=cond, return_dict=True, use_cache=True)
            kv = out_prefill.get("present_kv")
            logits = out_prefill.get("next_logits")
        assert kv is not None and logits is not None
        prefix_len = pref if pref is not None else torch.zeros(B, dtype=torch.long, device=self.device)
        # +1 accounts for the BOS-like position consumed during prefill.
        # NOTE(review): confirm the off-by-one convention against the model.
        prefill_len = (prefix_len + 1)

        rng = range(T_codons)
        if progress_bar:
            from tqdm import tqdm
            rng = tqdm(rng, desc="GPT sampling", total=T_codons)

        for step in rng:
            # Mark sequences that can no longer extend within the context
            # window; record which were truncated purely by capacity.
            max_pos = int(getattr(self.model, "max_position_embeddings", 1024))
            # NOTE(review): clamp(max=10**9) is effectively a no-op upper bound;
            # a min-clamp was more likely intended, though only the <=0 test
            # below actually uses this value.
            remaining_now = (max_pos - prefill_len - input_ids.size(1)).clamp(max=10**9)
            cant_extend = remaining_now <= 0
            newly_blocked = (~finished) & cant_extend
            capacity_truncated = capacity_truncated | newly_blocked
            finished = finished | cant_extend

            # Mask disallowed (special) tokens for the current control mode.
            logits = logits.masked_fill(~allowed, float("-inf"))

            # Finished rows deterministically emit PAD (prob mass on pad only).
            if finished.any():
                logits[finished] = float("-inf")
                logits[finished, self._pad_id] = 0.0

            # Optional hard constraint: at step t, only codons translating to
            # the t-th amino acid of the target protein are allowed.
            if enforce_translation and ("protein_seqs" in cond):
                aas_now: List[Optional[str]] = []
                prot_list = cond["protein_seqs"]
                assert isinstance(prot_list, list)
                for i in range(B):
                    seq = prot_list[i]
                    aas_now.append(seq[step] if step < len(seq) else None)

                mask = torch.zeros_like(logits, dtype=torch.bool)
                for i, a in enumerate(aas_now):
                    if a is None:
                        # Past the protein's end (or unknown AA): any codon.
                        mask[i, self._num_special:self.V] = True
                    else:
                        valid = aa2codons.get(a, [])
                        if len(valid) == 0:
                            mask[i, self._num_special:self.V] = True
                        else:
                            mask[i, valid] = True
                logits = logits.masked_fill(~mask, float("-inf"))

            # Temperature / top-k / top-p shaping, then sample one token.
            if temperature != 1.0:
                logits = logits / float(temperature)
            if top_k is not None:
                logits = _top_k_filtering(logits, int(top_k))
            if top_p is not None:
                logits = _top_p_filtering(logits, float(top_p))

            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)

            if control_mode == "variable":
                # EOS marks a sequence finished; it keeps emitting PAD after.
                eos_mask = (next_tok.squeeze(-1) == self._eos_id)
                finished = finished | eos_mask

            input_ids = torch.cat([input_ids, next_tok], dim=1)

            if return_intermediate:
                intermediate.append(input_ids.clone())

            if finished.all():
                break

            # Incremental decode: feed only the new token with the KV cache.
            # NOTE(review): pos_offset uses the batch-max prefill length, so
            # rows with shorter prefixes share the longest row's positions —
            # confirm the model expects a scalar offset here.
            pos_offset = int(prefill_len.max().item()) + input_ids.size(1) - 1
            out_inc = self.model(
                codon_ids=next_tok,
                cond=None,
                return_dict=True,
                use_cache=True,
                past_kv=kv,
                position_offset=pos_offset,
            )
            kv = out_inc.get("present_kv")
            logits = out_inc.get("next_logits")
            assert kv is not None and logits is not None

        # Post-process: strip PAD, cut at EOS, keep only real codon tokens.
        output_token_rows: List[List[int]] = []
        for row in input_ids.tolist():
            toks: List[int] = []
            for t in row:
                if t == self._pad_id:
                    continue
                if t == self._eos_id:
                    break
                if t >= self._num_special and t < self.V:
                    toks.append(int(t))
            if control_mode == "fixed":
                # Fixed mode guarantees at most T_codons codons per sequence.
                toks = toks[:T_codons]
            output_token_rows.append(toks)

        sequences = [self.tokenizer.decode_codon_seq(row) for row in output_token_rows]

        # Right-pad the cleaned rows into a rectangular LongTensor.
        max_len = max((len(r) for r in output_token_rows), default=0)
        if max_len > 0:
            ids_padded = torch.full(
                (len(output_token_rows), max_len),
                self._pad_id,
                device=self.device,
                dtype=torch.long,
            )
            for i, row in enumerate(output_token_rows):
                if len(row) > 0:
                    ids_padded[i, : len(row)] = torch.tensor(row, device=self.device, dtype=torch.long)
        else:
            ids_padded = torch.empty((len(output_token_rows), 0), device=self.device, dtype=torch.long)

        result: Dict[str, Union[List[str], torch.Tensor, List[bool]]] = {
            "sequences": sequences,
            "input_ids": ids_padded,
            "capacity_truncated": capacity_truncated.detach().bool().tolist(),
        }
        if return_intermediate:
            result["intermediate_states"] = intermediate
        return result

    def _ensure_qwen_loaded(self):
        """Lazily load the Qwen embedding model/tokenizer on first use."""
        if self._qwen_tokenizer is not None and self._qwen_model is not None:
            return
        from transformers import AutoTokenizer, AutoModel
        # Left padding so last-token pooling picks the true final token.
        self._qwen_tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, padding_side="left"
        )
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self._qwen_model = AutoModel.from_pretrained(
            "Qwen/Qwen3-Embedding-0.6B", torch_dtype=dtype, trust_remote_code=True
        ).to(self.device).eval()

    @staticmethod
    def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool the hidden state of each sequence's last real token.

        With left padding every row's last position is real, so column -1 can
        be taken directly; otherwise index each row at its own length - 1.
        """
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    @staticmethod
    def _format_instruct(task: str, query: str) -> str:
        """Format a query in the Qwen instruct-embedding prompt style."""
        return f"Instruct: {task}\nQuery: {query}"

    @torch.no_grad()
    def _qwen_embed_names(self, names: List[str], pooling: str = "sequence") -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed species names with Qwen, returning padded token embeddings.

        Each name is optionally expanded to its taxonomy text (if a taxonomy
        DB was provided), wrapped in an instruct prompt, and encoded. Per-token
        hidden states are L2-normalized and returned right-padded.

        Args:
            names: Species names to embed.
            pooling: Only "sequence" (token-level) output is produced here;
                the argument is currently not branched on.

        Returns:
            (padded, lens): padded is [N, Lmax, D] float32 on CPU; lens is a
            LongTensor of true lengths per name.
        """
        # Best-effort taxonomy lookup: any failure falls back to raw names.
        taxonomy_db = None
        if self.taxonomy_db_path:
            try:
                with open(self.taxonomy_db_path, "r") as f:
                    import json
                    taxonomy_db = json.load(f)
            except Exception:
                taxonomy_db = None

        self._ensure_qwen_loaded()
        tokenizer = self._qwen_tokenizer
        model = self._qwen_model
        assert tokenizer is not None and model is not None

        task = (
            "Given a species taxonomy information, generate a biological embedding "
            "representing its taxonomic and evolutionary characteristics"
        )
        texts = [self._format_instruct(task, taxonomy_db.get(s, s) if taxonomy_db else s) for s in names]

        BATCH = int(self.qwen_opts.get("batch_size", 16))
        max_len = int(self.qwen_opts.get("max_length", 512))

        # Encode in batches; keep only the unpadded prefix of each row.
        seqs: List[torch.Tensor] = []
        lens: List[int] = []
        for i in range(0, len(texts), BATCH):
            chunk = texts[i : i + BATCH]
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_len).to(self.device)
            out = model(**inputs)
            # L2-normalize each token embedding.
            h = torch.nn.functional.normalize(out.last_hidden_state, p=2, dim=-1)
            attn = inputs["attention_mask"]
            for j in range(h.size(0)):
                L = int(attn[j].sum().item())
                # NOTE(review): with padding_side="left" the real tokens are at
                # the END of each row, so h[j, :L] takes pad positions for
                # padded rows — confirm this slice matches the padding side.
                seqs.append(h[j, :L, :].float().cpu())
                lens.append(L)

        # Right-pad all sequences to the batch maximum.
        Lmax = max(lens) if lens else 0
        D = seqs[0].size(1) if seqs else 0
        padded = torch.zeros(len(seqs), Lmax, D)
        for i, t in enumerate(seqs):
            padded[i, : t.size(0), :] = t
        return padded, torch.tensor(lens, dtype=torch.long)
|
|
| |
| |
| |
|
|
| |
|
|
|
|
| |
| |
| |
|
|
def sample_sequences(
    model_path: str,
    num_sequences: int = 10,
    sequence_length: int = 100,
    species: Optional[Union[str, List[str]]] = None,
    protein_sequence: Optional[Union[str, List[str]]] = None,
    **kwargs
) -> List[str]:
    """Convenience one-shot API: build a CodonSampler and return decoded sequences.

    Extra keyword arguments are forwarded to ``CodonSampler.sample`` (e.g.
    temperature, top_k, top_p, seed, control_mode).
    """
    sampler = CodonSampler(model_path)
    result = sampler.sample(
        num_sequences=num_sequences,
        sequence_length=sequence_length,
        species=species,
        protein_sequences=protein_sequence,
        **kwargs
    )
    return result["sequences"]
|
|