microbe-model / src /microbe_model /train /lora_model.py
Miyu Horiuchi
Deploy app from main@a3254bf (no paper/ binaries)
0ed74db
"""PhenoLoRAModel — LoRA-fine-tuned ESM-2 + per-category mean-pool + multi-task heads.

Architecture:
    For each genome, accept up to K proteins per category × 8 categories
    (K = ``LoraModelConfig.max_proteins_per_cat``).
    For each protein: tokenize → ESM-2(+LoRA) → masked mean-pool over residues → 1 vector.
    For each category: mean-pool over its proteins → 1 category vector
    (a category with no proteins contributes a zero vector).
    Concatenate the 8 category vectors → genome vector.
    Predict the 4 phenotype targets via 4 small heads.

Trainable parameters:
    - LoRA adapters on ESM-2 attention (~1.5M params for t30, r=8 on q+v).
    - Per-protein projection (optional, default skipped — heads operate on the raw 8×D vector;
      no projection layer is defined in this module).
    - 4 prediction heads (~10-50K params combined).

The multi-task loss handles per-target missing labels via a binary mask.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import nn
# Protein categories, in the order their pooled vectors are concatenated into the
# genome vector (PhenoLoRAModel.encode_genome iterates this list). The heads' input
# width is len(CATEGORIES) * embed_dim, so changing the length or order changes the
# layout the heads are trained on.
CATEGORIES = ["temperature", "ph", "oxygen", "salt", "vitamin", "nitrogen", "carbon", "special"]
# Oxygen-requirement class names.
# NOTE(review): presumably integer oxygen labels index into this list — confirm
# against the data pipeline.
OXYGEN_CLASSES = ["aerobe", "anaerobe", "facultative_anaerobe", "microaerobe"]
# Number of logits produced by the "oxy" head.
N_OXYGEN_CLASSES = len(OXYGEN_CLASSES)
@dataclass
class LoraModelConfig:
    """Hyperparameters consumed by ``PhenoLoRAModel``.

    The ``lora_*`` fields are forwarded to ``peft.LoraConfig``; the ``head_*``
    fields shape the prediction heads; the remaining fields bound memory use
    during encoding.
    """

    # Checkpoint name passed to AutoTokenizer/AutoModel.from_pretrained.
    esm_model_name: str = "facebook/esm2_t12_35M_UR50D"
    # Forwarded to LoraConfig(r=...).
    lora_r: int = 8
    # Forwarded to LoraConfig(lora_alpha=...).
    lora_alpha: int = 16
    # Forwarded to LoraConfig(lora_dropout=...).
    lora_dropout: float = 0.05
    # Module names that receive LoRA adapters; forwarded as LoraConfig(target_modules=[...]).
    lora_target: tuple[str, ...] = ("query", "value")
    # Width of the single hidden layer in every prediction head.
    head_hidden_dim: int = 128
    # Dropout probability applied after the hidden layer of every head.
    head_dropout: float = 0.1
    # Tokenizer max_length; longer proteins are truncated to fit memory.
    max_seq_len: int = 512  # truncate long proteins to fit memory
    # Only the first N proteins of each category are encoded (see encode_genome).
    max_proteins_per_cat: int = 6  # cap protein count per category at train time
    # When True, gradient checkpointing is enabled on the base encoder.
    gradient_checkpointing: bool = True  # trade compute for memory
class PhenoLoRAModel(nn.Module):
    """LoRA-adapted ESM-2 encoder with per-category pooling and four prediction heads.

    A genome is a mapping from category name (see ``CATEGORIES``) to a list of
    protein sequences. Every protein is encoded to one vector, the vectors of a
    category are averaged, and the category vectors are concatenated in
    ``CATEGORIES`` order into a single genome vector that feeds the heads
    ``temp``, ``ph``, ``salt`` (one scalar each) and ``oxy``
    (``N_OXYGEN_CLASSES`` logits).
    """

    def __init__(self, cfg: LoraModelConfig):
        """Load tokenizer and base encoder, attach LoRA adapters, build the heads.

        Args:
            cfg: Model hyperparameters.
        """
        super().__init__()
        # Imported here rather than at module level, so importing this module
        # does not itself import peft/transformers.
        from peft import LoraConfig, get_peft_model
        from transformers import AutoModel, AutoTokenizer

        self.cfg = cfg
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.esm_model_name)
        base = AutoModel.from_pretrained(cfg.esm_model_name)
        self.embed_dim = base.config.hidden_size

        lora_cfg = LoraConfig(
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            lora_dropout=cfg.lora_dropout,
            target_modules=list(cfg.lora_target),
            bias="none",
        )
        self.esm = get_peft_model(base, lora_cfg)

        if cfg.gradient_checkpointing:
            base.gradient_checkpointing_enable()
            # Needed so checkpointed blocks receive an input that requires grad
            # even though the base embedding weights are frozen (see the
            # Transformers docs for enable_input_require_grads).
            base.enable_input_require_grads()

        # len(CATEGORIES) × embed_dim → per-target heads.
        genome_dim = self.embed_dim * len(CATEGORIES)
        hidden = cfg.head_hidden_dim

        def make_head(out_dim: int) -> nn.Module:
            # Linear → GELU → Dropout → Linear; identical layout for every target.
            return nn.Sequential(
                nn.Linear(genome_dim, hidden),
                nn.GELU(),
                nn.Dropout(cfg.head_dropout),
                nn.Linear(hidden, out_dim),
            )

        # Creation order (temp, ph, salt, oxy) is kept so seeded initialisation
        # of the heads is unchanged.
        self.heads = nn.ModuleDict({
            "temp": make_head(1),
            "ph": make_head(1),
            "salt": make_head(1),
            "oxy": make_head(N_OXYGEN_CLASSES),
        })

    def encode_proteins(self, proteins: list[str], device: torch.device) -> torch.Tensor:
        """Tokenize and encode proteins into a tensor of shape (n_proteins, embed_dim).

        Each protein vector is the mean of the encoder's last hidden state over
        the positions where the tokenizer's attention mask is 1.

        NOTE(review): the attention mask also covers any special tokens the
        tokenizer adds, so those positions are part of the mean — confirm this
        is intended.

        Args:
            proteins: Amino-acid sequences; longer ones are truncated to
                ``cfg.max_seq_len`` tokens.
            device: Device the token tensors are moved to before encoding.

        Returns:
            Pooled embeddings, or a zero tensor of shape (0, embed_dim) for an
            empty list.
        """
        if not proteins:
            return torch.zeros((0, self.embed_dim), device=device)

        enc = self.tokenizer(
            proteins,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.cfg.max_seq_len,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        outputs = self.esm(**enc)
        last = outputs.last_hidden_state  # (B, L, D)
        mask = enc["attention_mask"].unsqueeze(-1).to(last.dtype)
        # clamp guards against division by zero for a row with no unmasked tokens.
        pooled = (last * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return pooled

    def encode_genome(self, by_category: dict[str, list[str]], device: torch.device) -> torch.Tensor:
        """Build one genome vector of shape (len(CATEGORIES) × embed_dim,).

        For every category the first ``cfg.max_proteins_per_cat`` proteins are
        encoded and averaged. A category that is missing, empty, or left empty
        by the cap contributes a zero vector.

        Args:
            by_category: Mapping from category name to protein sequences; keys
                not in ``CATEGORIES`` are ignored.
            device: Device used for encoding and for the zero placeholders.

        Returns:
            The concatenated category vectors, in ``CATEGORIES`` order.
        """
        cap = self.cfg.max_proteins_per_cat
        cat_vectors: list[torch.Tensor] = []
        for cat in CATEGORIES:
            # Cap per-category protein count to control memory. The original
            # author notes the extractor already sorts proteins shortest-first
            # (not verifiable from this file), so the cap keeps the shortest.
            # Slicing BEFORE the emptiness test matters: averaging a (0, D)
            # tensor would yield NaN when the cap removes every protein.
            proteins = (by_category.get(cat) or [])[:cap]
            if proteins:
                per_protein = self.encode_proteins(proteins, device)  # (n, D)
                cat_vec = per_protein.mean(dim=0)
            else:
                cat_vec = torch.zeros(self.embed_dim, device=device)
            cat_vectors.append(cat_vec)
        return torch.cat(cat_vectors, dim=0)

    def forward(self, genomes: list[dict[str, list[str]]], device: torch.device) -> dict[str, torch.Tensor]:
        """Batched forward over a list of genomes (variable protein counts per genome).

        Args:
            genomes: One category → proteins mapping per genome; must be non-empty
                because the genome vectors are stacked.
            device: Device used for encoding.

        Returns:
            Dict of per-target predictions:
                temp, ph, salt: (B,) regression predictions
                oxy: (B, N_OXYGEN_CLASSES) logits
        """
        genome_vecs = torch.stack(
            [self.encode_genome(g, device) for g in genomes],
            dim=0,
        )  # (B, len(CATEGORIES) * D)
        return {
            "temp": self.heads["temp"](genome_vecs).squeeze(-1),
            "ph": self.heads["ph"](genome_vecs).squeeze(-1),
            "salt": self.heads["salt"](genome_vecs).squeeze(-1),
            "oxy": self.heads["oxy"](genome_vecs),
        }

    def trainable_param_count(self) -> tuple[int, int]:
        """Return (trainable, total) parameter counts."""
        trainable = 0
        total = 0
        for p in self.parameters():
            n = p.numel()
            total += n
            if p.requires_grad:
                trainable += n
        return trainable, total
def masked_multitask_loss(
    preds: dict[str, torch.Tensor],
    labels: dict[str, torch.Tensor],
    label_mask: dict[str, torch.Tensor],
    target_weights: dict[str, float] | None = None,
    oxy_class_weights: tuple[float, ...] | None = None,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Sum of per-target losses, weighted by the binary label mask.

    Regression targets (``temp``, ``ph``, ``salt``): MSE.
    Oxygen (``oxy``): cross-entropy.

    Each per-target loss is averaged over rows with mask=1; if the mask is
    all-zero for a target in this batch, that target contributes 0 to the total.
    Label values in masked-out rows are never used: they are replaced by 0
    before the per-row loss is computed, so placeholders such as NaN or -1
    cannot leak into the loss or its gradient.

    Args:
        preds: Model outputs keyed ``temp``/``ph``/``salt`` (one value per row)
            and ``oxy`` (one row of class logits per example).
        labels: Ground-truth values under the same keys; ``oxy`` holds class
            indices.
        label_mask: Per-target masks under the same keys; non-zero marks a row
            whose label is present.
        target_weights: Multiplier per target; defaults to 1.0 for all four.
        oxy_class_weights: Optional per-class weights for the cross-entropy.

    Returns:
        ``(total, per_target_loss)`` — the differentiable weighted sum and the
        detached per-target losses as Python floats.

    Raises:
        ValueError: If ``oxy_class_weights`` does not have one entry per
            oxygen logit.
    """
    if target_weights is None:
        target_weights = {"temp": 1.0, "ph": 1.0, "salt": 1.0, "oxy": 1.0}
    per_target_loss: dict[str, float] = {}

    # Graph-connected zero — keeps backward() safe when every mask is empty.
    total = (preds["temp"] * 0.0).sum()

    for tgt in ("temp", "ph", "salt"):
        mask = label_mask[tgt].float()
        raw = labels[tgt]
        # NaN * 0 is still NaN (in the loss AND in the gradient), so masked-out
        # labels are neutralised before the subtraction instead of relying on
        # the multiplication by the mask below.
        target = torch.where(mask != 0, raw, torch.zeros_like(raw))
        sq = (preds[tgt] - target) ** 2
        # clamp(min=1) → loss=0 (graph-connected via `sq`) when mask is all-zero.
        loss = (sq * mask).sum() / mask.sum().clamp(min=1.0)
        total = total + target_weights[tgt] * loss
        per_target_loss[tgt] = float(loss.detach().cpu())

    mask_oxy = label_mask["oxy"].float()
    logits = preds["oxy"]
    raw_oxy = labels["oxy"]
    # cross_entropy rejects class indices outside [0, C); give masked-out rows
    # a valid dummy class (0). Their per-row loss is zeroed by the mask anyway.
    labels_oxy = torch.where(mask_oxy != 0, raw_oxy, torch.zeros_like(raw_oxy)).long()

    class_weight = None
    if oxy_class_weights is not None:
        if len(oxy_class_weights) != logits.shape[-1]:
            raise ValueError("oxy_class_weights must match the number of oxygen classes")
        class_weight = torch.tensor(oxy_class_weights, dtype=logits.dtype, device=logits.device)

    per_row_loss = nn.functional.cross_entropy(
        logits,
        labels_oxy,
        weight=class_weight,
        reduction="none",
    )
    loss_oxy = (per_row_loss * mask_oxy).sum() / mask_oxy.sum().clamp(min=1.0)
    total = total + target_weights["oxy"] * loss_oxy
    per_target_loss["oxy"] = float(loss_oxy.detach().cpu())
    return total, per_target_loss