File size: 4,061 Bytes
0ed74db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Per-genome ESM-2 embeddings.

Pipeline:
  1. Predict CDS via pyrodigal (reuses features.genome.predict_genes)
  2. For each protein: ESM-2 -> per-residue 320/480/640/1280-dim -> mean-pool over residues
  3. Mean-pool across all proteins in genome -> one fixed-dim vector per genome

Why ESM-2 specifically:
  - Reuses existing pyrodigal-predicted proteins (no DNA-window re-design)
  - Variants from 8M (laptop) to 3B params (cluster) -> easy to scale
  - Industry-standard for protein phenotype tasks
  - Mean-pool across residues + across proteome is the dumb-but-effective baseline

Model choices (set via env or argument):
  - facebook/esm2_t6_8M_UR50D    ->  320-dim, fast (laptop testing)
  - facebook/esm2_t12_35M_UR50D  ->  480-dim
  - facebook/esm2_t30_150M_UR50D ->  640-dim (recommended for GPU)
  - facebook/esm2_t33_650M_UR50D -> 1280-dim (best, needs GPU + 8GB+ VRAM)
"""
from __future__ import annotations

from typing import Any

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

# Checkpoint used when load_esm2() is called without an explicit model name
# (the 480-dim variant in the module docstring's model list).
DEFAULT_MODEL = "facebook/esm2_t12_35M_UR50D"
# Default ``max_len`` of embed_proteims's sibling API embed_proteins(): it is handed to the
# tokenizer as ``max_length`` with truncation enabled.
# NOTE(review): the budget presumably counts any special tokens the tokenizer
# adds, so slightly fewer than 1024 residues may be kept -- confirm.
ESM2_MAX_LEN = 1024  # ESM-2's positional embedding limit; longer proteins are truncated


def pick_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then Apple MPS, then CPU."""
    # Probe accelerators in priority order; the first available backend wins.
    if torch.cuda.is_available():
        name = "cuda"
    elif torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)


def load_esm2(model_name: str = DEFAULT_MODEL, device: torch.device | None = None) -> tuple[Any, Any, torch.device]:
    """Load an ESM-2 tokenizer and model, ready for inference.

    Args:
        model_name: Hugging Face model identifier passed to ``from_pretrained``.
        device: Target device; when omitted, ``pick_device()`` chooses one.

    Returns:
        ``(tokenizer, model, device)``. The model has been moved to ``device``
        and taken out of training mode. Weights are requested as float16 on
        CUDA and float32 on every other device.
    """
    target = device or pick_device()

    # Half precision only where it is requested for CUDA; everything else stays fp32.
    weights_dtype = torch.float32
    if target.type == "cuda":
        weights_dtype = torch.float16

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, dtype=weights_dtype)
    model.to(target)
    model.train(False)  # evaluation mode (same effect as model.eval())
    return tokenizer, model, target


@torch.inference_mode()
def embed_proteins(
    proteins: list[str],
    tokenizer: Any,
    model: Any,
    device: torch.device,
    *,
    batch_size: int = 8,
    max_len: int = ESM2_MAX_LEN,
) -> np.ndarray:
    """Mean-pool the per-token ESM-2 embeddings of each protein.

    Every position whose ``attention_mask`` is 1 contributes to the mean, which
    includes any special tokens the tokenizer adds; padding is excluded.

    Args:
        proteins: Amino-acid sequences, one string per protein.
        tokenizer: Callable tokenizer; invoked with ``return_tensors="pt"``,
            padding enabled and truncation to ``max_len`` tokens.
        model: Model whose output exposes ``last_hidden_state`` and whose
            ``config.hidden_size`` gives the embedding width.
        device: Device the tokenized batch is moved to before the forward pass.
        batch_size: Number of proteins per forward pass; must be >= 1.
        max_len: Token budget per protein handed to the tokenizer.

    Returns:
        ``(n_proteins, embed_dim)`` float32 array, rows in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (and ``proteins`` is
            non-empty).
    """
    if not proteins:
        return np.zeros((0, model.config.hidden_size), dtype=np.float32)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    pooled_batches: list[np.ndarray] = []
    for start in range(0, len(proteins), batch_size):
        enc = tokenizer(
            proteins[start : start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_len,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        outputs = model(**enc)
        # Pool in float32: load_esm2 uses fp16 weights on CUDA, and summing up
        # to max_len fp16 positions loses precision and can overflow to inf.
        hidden = outputs.last_hidden_state.float()  # (B, L, D)
        mask = enc["attention_mask"].unsqueeze(-1).to(hidden.dtype)  # (B, L, 1)
        summed = (hidden * mask).sum(dim=1)
        # clamp guards against division by zero for an all-padding row.
        counts = mask.sum(dim=1).clamp(min=1)
        pooled_batches.append((summed / counts).cpu().numpy())
    return np.concatenate(pooled_batches, axis=0)


def embed_genome(
    proteins: list[str],
    tokenizer: Any,
    model: Any,
    device: torch.device,
    *,
    sample_n: int | None = None,
    batch_size: int = 8,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Return one fixed-dim vector summarizing the whole proteome.

    The vector is the unweighted mean of the per-protein embeddings produced
    by ``embed_proteins``.

    Args:
        proteins: Protein sequences of one genome.
        tokenizer: Tokenizer forwarded to ``embed_proteins``.
        model: Model forwarded to ``embed_proteins``; ``config.hidden_size``
            sets the length of the zero vector returned for an empty proteome.
        device: Device forwarded to ``embed_proteins``.
        sample_n: If set and smaller than ``len(proteins)``, only that many
            proteins are embedded (uniformly sampled without replacement) to
            bound runtime. ``None`` embeds every protein. Must be >= 1.
        batch_size: Batch size forwarded to ``embed_proteins``.
        rng: Generator used for sampling; defaults to a fresh
            ``np.random.default_rng(0)`` so results are reproducible.

    Returns:
        1-D float32 array of length ``embed_dim`` (all zeros when ``proteins``
        is empty).

    Raises:
        ValueError: If ``sample_n`` is given but smaller than 1 for a
            non-empty proteome.
    """
    if not proteins:
        return np.zeros(model.config.hidden_size, dtype=np.float32)

    chosen = proteins
    if sample_n is not None:
        if sample_n < 1:
            # Sampling zero proteins would average an empty matrix -> all-NaN vector.
            raise ValueError(f"sample_n must be >= 1 or None, got {sample_n}")
        if sample_n < len(proteins):
            if rng is None:
                rng = np.random.default_rng(0)
            idx = rng.choice(len(proteins), size=sample_n, replace=False)
            chosen = [proteins[i] for i in idx]

    matrix = embed_proteins(chosen, tokenizer, model, device, batch_size=batch_size)
    return matrix.mean(axis=0).astype(np.float32)