Spaces:
Running
Running
| """Per-genome ESM-2 embeddings. | |
| Pipeline: | |
| 1. Predict CDS via pyrodigal (reuses features.genome.predict_genes) | |
| 2. For each protein: ESM-2 -> per-residue 320/480/640/1280-dim (depends on the checkpoint) -> mean-pool over residues | |
| 3. Mean-pool across all proteins in genome -> one fixed-dim vector per genome | |
| Why ESM-2 specifically: | |
| - Reuses existing pyrodigal-predicted proteins (no DNA-window re-design) | |
| - Variants from 8M (laptop) to 3B params (cluster) -> easy to scale | |
| - Industry-standard for protein phenotype tasks | |
| - Mean-pool across residues + across proteome is the dumb-but-effective baseline | |
| Model choices (set via env or argument): | |
| - facebook/esm2_t6_8M_UR50D -> 320-dim, fast (laptop testing) | |
| - facebook/esm2_t12_35M_UR50D -> 480-dim | |
| - facebook/esm2_t30_150M_UR50D -> 640-dim (recommended for GPU) | |
| - facebook/esm2_t33_650M_UR50D -> 1280-dim (best, needs GPU + 8GB+ VRAM) | |
| """ | |
| from __future__ import annotations | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
# Checkpoint name used by load_esm2() when the caller does not pass one.
DEFAULT_MODEL = "facebook/esm2_t12_35M_UR50D"
# Default ``max_length`` handed to the tokenizer by embed_proteins(); inputs that
# tokenize to more than this are truncated.
# NOTE(review): the cap is applied to the tokenized length, so any special tokens the
# tokenizer adds presumably count against it — confirm the usable residue count.
# NOTE(review): "positional embedding limit" is the original author's wording; verify
# against the ESM-2 model card whether 1024 is a hard limit or the training context length.
ESM2_MAX_LEN = 1024 # ESM-2's positional embedding limit; longer proteins are truncated
def pick_device() -> torch.device:
    """Choose the compute device, preferring CUDA, then Apple MPS, then CPU.

    The MPS availability check is only consulted when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def load_esm2(model_name: str = DEFAULT_MODEL, device: torch.device | None = None) -> tuple[Any, Any, torch.device]:
    """Build the tokenizer/model pair for ``model_name`` and place the model on a device.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        device: Target device. When omitted, ``pick_device()`` chooses one.

    Returns:
        ``(tokenizer, model, device)`` where ``device`` is the device actually used.
        Weights are requested as float16 on CUDA and float32 on every other device.
    """
    target = device if device else pick_device()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if target.type == "cuda":
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    model = AutoModel.from_pretrained(model_name, dtype=weights_dtype)
    model.to(target)
    # Leave training mode (same effect as model.eval()). This does not by itself
    # turn off gradient tracking.
    model.train(False)
    return tokenizer, model, target
def embed_proteins(
    proteins: list[str],
    tokenizer: Any,
    model: Any,
    device: torch.device,
    *,
    batch_size: int = 8,
    max_len: int = ESM2_MAX_LEN,
) -> np.ndarray:
    """Mean-pool the per-residue ESM-2 embeddings of each protein.

    Args:
        proteins: Protein sequences, one string per protein.
        tokenizer: Called as ``tokenizer(batch, return_tensors="pt", padding=True,
            truncation=True, max_length=max_len)``; must return a mapping of tensors
            that includes ``"attention_mask"``.
        model: Called with the tokenizer output as keyword arguments; must expose
            ``config.hidden_size`` and return an object with ``last_hidden_state``
            of shape (batch, tokens, dim).
        device: Device the tokenized tensors are moved to before the forward pass.
        batch_size: Number of proteins per forward pass; must be >= 1.
        max_len: Tokenizer ``max_length``; longer inputs are truncated.

    Returns:
        (n_proteins, embed_dim) float32 array, rows in the same order as ``proteins``.
        An empty input yields a (0, hidden_size) array.

    Raises:
        ValueError: If ``proteins`` is non-empty and ``batch_size`` is less than 1.

    Note:
        The mean runs over every position the attention mask marks as real, so any
        special tokens the tokenizer adds are part of the average.
    """
    if not proteins:
        return np.zeros((0, model.config.hidden_size), dtype=np.float32)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    pooled_batches: list[np.ndarray] = []
    # Gradients are never needed here. Without no_grad the forward pass records an
    # autograd graph for every batch, and .numpy() refuses tensors that require grad.
    with torch.no_grad():
        for start in range(0, len(proteins), batch_size):
            batch = proteins[start : start + batch_size]
            enc = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_len,
            )
            enc = {k: v.to(device) for k, v in enc.items()}
            # Upcast before pooling so the sum over tokens is accumulated in float32
            # even when the model runs in float16.
            last_hidden = model(**enc).last_hidden_state.float()  # (B, L, D)
            mask = enc["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
            summed = (last_hidden * mask).sum(dim=1)
            # clamp guards against division by zero for a row with no real tokens
            counts = mask.sum(dim=1).clamp(min=1)
            pooled_batches.append((summed / counts).cpu().numpy())
    return np.concatenate(pooled_batches, axis=0)
def embed_genome(
    proteins: list[str],
    tokenizer: Any,
    model: Any,
    device: torch.device,
    *,
    sample_n: int | None = None,
    batch_size: int = 8,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Return one fixed-dim vector summarizing the whole proteome.

    Each protein is embedded with ``embed_proteins`` and the per-protein vectors are
    averaged with equal weight.

    Args:
        proteins: Protein sequences of one genome.
        tokenizer: Tokenizer forwarded to ``embed_proteins``.
        model: Model forwarded to ``embed_proteins``; ``model.config.hidden_size``
            sets the length of the zero vector returned when nothing is embedded.
        device: Device forwarded to ``embed_proteins``.
        sample_n: If set and smaller than ``len(proteins)``, only that many proteins
            are embedded (uniformly sampled without replacement) to bound runtime.
            ``None`` embeds every protein; ``0`` embeds none.
        batch_size: Batch size forwarded to ``embed_proteins``.
        rng: Generator used for sampling. Defaults to ``np.random.default_rng(0)``,
            so repeated calls without ``rng`` draw the same indices.

    Returns:
        float32 vector of length ``hidden_size``. All zeros when ``proteins`` is
        empty or ``sample_n`` is 0.

    Raises:
        ValueError: If ``proteins`` is non-empty and ``sample_n`` is negative.
    """
    # Nothing to embed: return zeros rather than averaging an empty matrix, which
    # would produce an all-NaN vector.
    if not proteins or sample_n == 0:
        return np.zeros(model.config.hidden_size, dtype=np.float32)

    if sample_n is not None:
        if sample_n < 0:
            raise ValueError(f"sample_n must be non-negative, got {sample_n}")
        if sample_n < len(proteins):
            if rng is None:
                rng = np.random.default_rng(0)
            idx = rng.choice(len(proteins), size=sample_n, replace=False)
            proteins = [proteins[i] for i in idx]

    matrix = embed_proteins(proteins, tokenizer, model, device, batch_size=batch_size)
    return matrix.mean(axis=0).astype(np.float32)