# Miyu Horiuchi
# Deploy app from main@a3254bf (no paper/ binaries)
# 0ed74db
"""Per-genome ESM-2 embeddings.
Pipeline:
1. Predict CDS via pyrodigal (reuses features.genome.predict_genes)
2. For each protein: ESM-2 -> per-residue 320/480/640/1280-dim (depends on model) -> mean-pool over residues
3. Mean-pool across all proteins in genome -> one fixed-dim vector per genome
Why ESM-2 specifically:
- Reuses existing pyrodigal-predicted proteins (no DNA-window re-design)
- Variants from 8M (laptop) to 3B params (cluster) -> easy to scale
- Industry-standard for protein phenotype tasks
- Mean-pool across residues + across proteome is the dumb-but-effective baseline
Model choices (set via env or argument):
- facebook/esm2_t6_8M_UR50D -> 320-dim, fast (laptop testing)
- facebook/esm2_t12_35M_UR50D -> 480-dim
- facebook/esm2_t30_150M_UR50D -> 640-dim (recommended for GPU)
- facebook/esm2_t33_650M_UR50D -> 1280-dim (best, needs GPU + 8GB+ VRAM)
"""
from __future__ import annotations
from typing import Any
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
# Checkpoint used by ``load_esm2`` when the caller does not pass ``model_name``
# (listed as the 480-dim variant in the module docstring).
DEFAULT_MODEL = "facebook/esm2_t12_35M_UR50D"
# Default ``max_length`` handed to the tokenizer by ``embed_proteins``.
ESM2_MAX_LEN = 1024 # ESM-2's positional embedding limit; longer proteins are truncated
def pick_device() -> torch.device:
    """Choose the torch device to run on.

    Preference order: CUDA if available, otherwise Apple MPS if available,
    otherwise CPU. The MPS check is only made when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def load_esm2(model_name: str = DEFAULT_MODEL, device: torch.device | None = None) -> tuple[Any, Any, torch.device]:
    """Load an ESM-2 checkpoint and its tokenizer, ready for inference.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        device: Device to place the model on; when omitted, ``pick_device()``
            chooses one.

    Returns:
        ``(tokenizer, model, device)``. The model weights are requested in
        float16 when the device is CUDA and float32 otherwise, and the model
        is switched out of training mode before being returned.
    """
    target = device or pick_device()
    # Half precision only on CUDA; every other device gets float32.
    weights_dtype = torch.float32 if target.type != "cuda" else torch.float16
    tok = AutoTokenizer.from_pretrained(model_name)
    net = AutoModel.from_pretrained(model_name, dtype=weights_dtype)
    net.to(target)
    net.train(False)  # inference mode (equivalent to model.eval())
    return tok, net, target
@torch.inference_mode()
def embed_proteins(
    proteins: list[str],
    tokenizer: Any,
    model: Any,
    device: torch.device,
    *,
    batch_size: int = 8,
    max_len: int = ESM2_MAX_LEN,
) -> np.ndarray:
    """Mean-pool the per-token ESM-2 embeddings of each protein.

    Proteins are tokenized in batches (padded to the longest sequence in the
    batch, truncated to ``max_len`` tokens), run through ``model``, and the
    last hidden state is averaged over the positions the tokenizer's
    attention mask marks as attended.

    Args:
        proteins: Amino-acid sequences, one string per protein.
        tokenizer: Callable returning a mapping of tensors that includes
            ``"attention_mask"`` (e.g. the tokenizer from ``load_esm2``).
        model: Callable accepting the tokenizer output as keyword arguments
            and returning an object with ``last_hidden_state``; must expose
            ``config.hidden_size`` for the empty-input case.
        device: Device the tokenized batch is moved to before the forward pass.
        batch_size: Number of proteins per forward pass; must be >= 1.
        max_len: ``max_length`` passed to the tokenizer.

    Returns:
        ``(n_proteins, embed_dim)`` float32 array, rows in input order.
        An empty input yields a ``(0, hidden_size)`` array.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (and ``proteins`` is
            non-empty).
    """
    if not proteins:
        return np.zeros((0, model.config.hidden_size), dtype=np.float32)
    if batch_size < 1:
        # Previously surfaced as an opaque ValueError from range()/np.concatenate.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    out: list[np.ndarray] = []
    for start in range(0, len(proteins), batch_size):
        batch = proteins[start : start + batch_size]
        enc = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_len,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        outputs = model(**enc)
        # Upcast before pooling: when the model runs in float16, summing many
        # token vectors in half precision loses accuracy and can overflow to
        # inf. For a float32 model this is a no-op.
        last_hidden = outputs.last_hidden_state.float()  # (B, L, D)
        # NOTE(review): the mask comes straight from the tokenizer, so any
        # special tokens it marks as attended are presumably part of the
        # mean - confirm this is intended.
        attention_mask = enc["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
        summed = (last_hidden * attention_mask).sum(dim=1)
        # clamp guards against division by zero for an all-masked row.
        counts = attention_mask.sum(dim=1).clamp(min=1)
        pooled = summed / counts
        out.append(pooled.cpu().numpy())
    return np.concatenate(out, axis=0)
def embed_genome(
    proteins: list[str],
    tokenizer: Any,
    model: Any,
    device: torch.device,
    *,
    sample_n: int | None = None,
    batch_size: int = 8,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Return one fixed-dim vector summarizing the whole proteome.

    The per-protein embeddings from ``embed_proteins`` are averaged into a
    single vector.

    Args:
        proteins: Amino-acid sequences of the genome's predicted proteins.
        tokenizer: Tokenizer forwarded to ``embed_proteins``.
        model: Model forwarded to ``embed_proteins``; ``config.hidden_size``
            sets the length of the zero vector returned for an empty proteome.
        device: Device forwarded to ``embed_proteins``.
        sample_n: If set and smaller than ``len(proteins)``, only that many
            proteins are embedded (uniformly sampled without replacement) to
            bound runtime. ``None`` = embed every protein. Must be >= 1.
        batch_size: Batch size forwarded to ``embed_proteins``.
        rng: Random generator used for sampling; defaults to a fresh
            ``np.random.default_rng(0)`` so sampling is reproducible.

    Returns:
        1-D float32 array of length ``hidden_size``; all zeros when
        ``proteins`` is empty.

    Raises:
        ValueError: If ``sample_n`` is given and smaller than 1 (and
            ``proteins`` is non-empty).
    """
    # Empty proteome is checked first so it keeps returning zeros regardless
    # of the other arguments.
    if not proteins:
        return np.zeros(model.config.hidden_size, dtype=np.float32)

    if sample_n is not None:
        if sample_n < 1:
            # sample_n == 0 used to embed nothing and return an all-NaN mean.
            raise ValueError(
                f"sample_n must be a positive integer or None, got {sample_n}"
            )
        if sample_n < len(proteins):
            if rng is None:
                rng = np.random.default_rng(0)
            idx = rng.choice(len(proteins), size=sample_n, replace=False)
            proteins = [proteins[i] for i in idx]

    matrix = embed_proteins(proteins, tokenizer, model, device, batch_size=batch_size)
    return matrix.mean(axis=0).astype(np.float32)