Spaces:
Running
Running
File size: 8,061 Bytes
"""PhenoLoRAModel — LoRA-fine-tuned ESM-2 + per-category mean-pool + multi-task heads.
Architecture:
    For each genome, accept up to K proteins per category × 8 categories.
    For each protein: tokenize → ESM-2(+LoRA) → masked mean-pool over residues → 1 vector.
    For each category: mean-pool over its proteins → 1 category vector.
    Concatenate the 8 category vectors → genome vector.
    Predict the 4 phenotype targets via 4 small heads.
Trainable parameters:
    - LoRA adapters on ESM-2 attention (r=8 on query+value: ~0.18M params for the default t12_35M, ~0.6M for t30_150M).
    - Per-protein projection: none — the heads operate directly on the raw 8×D genome vector.
    - 4 prediction heads (Linear(8×D → 128) → GELU → Dropout → Linear; ~2M params combined for t12_35M, D=480).
Multi-task loss handles per-target missing labels via a binary mask.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import nn
# Protein categories in the fixed order in which their pooled vectors are
# concatenated into the genome vector (the model iterates this list).
CATEGORIES = [
    "temperature",
    "ph",
    "oxygen",
    "salt",
    "vitamin",
    "nitrogen",
    "carbon",
    "special",
]

# Oxygen-tolerance classes; their count sets the width of the oxygen head's
# logit output (presumably label index i corresponds to OXYGEN_CLASSES[i]).
OXYGEN_CLASSES = [
    "aerobe",
    "anaerobe",
    "facultative_anaerobe",
    "microaerobe",
]
N_OXYGEN_CLASSES = len(OXYGEN_CLASSES)
@dataclass
class LoraModelConfig:
    """Hyperparameters for PhenoLoRAModel (ESM-2 backbone, LoRA adapters, heads).

    Raises:
        ValueError: if ``max_proteins_per_cat`` is not ``None`` and is below 1.
    """

    # Model id or local path passed to ``from_pretrained`` for tokenizer and model.
    esm_model_name: str = "facebook/esm2_t12_35M_UR50D"
    # LoRA rank / alpha / dropout, forwarded to ``peft.LoraConfig``.
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # Names of ESM submodules that receive LoRA adapters (``target_modules``).
    lora_target: tuple[str, ...] = ("query", "value")
    # Hidden width and dropout of each prediction head.
    head_hidden_dim: int = 128
    head_dropout: float = 0.1
    max_seq_len: int = 512  # tokenizer truncation length; bounds memory per protein
    # Maximum proteins encoded per category (the first N of each list). Applied on
    # every forward pass, not only during training. ``None`` disables the cap.
    max_proteins_per_cat: int | None = 6
    gradient_checkpointing: bool = True  # trade compute for memory

    def __post_init__(self) -> None:
        # A cap of 0 empties every category (mean over zero proteins -> NaN), and a
        # negative cap silently drops proteins from the *end* of each list.
        cap = self.max_proteins_per_cat
        if cap is not None and cap < 1:
            raise ValueError(f"max_proteins_per_cat must be >= 1 or None, got {cap}")
class PhenoLoRAModel(nn.Module):
    """ESM-2 (+LoRA) protein encoder with per-category pooling and four prediction heads.

    A genome is passed as a mapping from category name to a list of protein
    sequences. Each protein is encoded by the LoRA-wrapped ESM-2 model and
    mean-pooled over its attended tokens. Proteins are then averaged within each
    category, and the per-category vectors are concatenated in ``CATEGORIES`` order
    into one genome vector of size ``len(CATEGORIES) * embed_dim``. That vector feeds
    the "temp", "ph", "salt" (regression) and "oxy" (classification) heads.
    """

    def __init__(self, cfg: LoraModelConfig):
        super().__init__()
        # Imported here rather than at module level, presumably so this module stays
        # importable without peft/transformers installed.
        from peft import LoraConfig, get_peft_model
        from transformers import AutoModel, AutoTokenizer

        self.cfg = cfg
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.esm_model_name)
        base = AutoModel.from_pretrained(cfg.esm_model_name)
        self.embed_dim = base.config.hidden_size

        lora_cfg = LoraConfig(
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            lora_dropout=cfg.lora_dropout,
            target_modules=list(cfg.lora_target),
            bias="none",
        )
        self.esm = get_peft_model(base, lora_cfg)
        if cfg.gradient_checkpointing:
            # Enabled on the base model object that get_peft_model wrapped.
            # enable_input_require_grads is presumably needed so checkpointed layers
            # receive inputs that require grad while the embeddings stay frozen under
            # LoRA. Confirm this against the installed Transformers version.
            base.gradient_checkpointing_enable()
            base.enable_input_require_grads()

        # Every head consumes the concatenation of one embed_dim vector per category.
        genome_dim = self.embed_dim * len(CATEGORIES)

        def mlp_head(out_dim: int) -> nn.Sequential:
            # Linear -> GELU -> Dropout -> Linear. Keep this layout: checkpoint keys
            # such as "heads.temp.0.weight" depend on the Sequential indices.
            return nn.Sequential(
                nn.Linear(genome_dim, cfg.head_hidden_dim),
                nn.GELU(),
                nn.Dropout(cfg.head_dropout),
                nn.Linear(cfg.head_hidden_dim, out_dim),
            )

        # Built in the order temp, ph, salt, oxy. This also fixes the order in which
        # their weights are randomly initialised.
        self.heads = nn.ModuleDict({
            "temp": mlp_head(1),
            "ph": mlp_head(1),
            "salt": mlp_head(1),
            "oxy": mlp_head(N_OXYGEN_CLASSES),
        })

    def _head_dtype(self) -> torch.dtype:
        """Return the dtype of the heads' parameters.

        Synthetic zero tensors use this dtype so that ``torch.cat`` and the heads
        see a single dtype, e.g. after ``model.half()``.
        """
        return self.heads["temp"][0].weight.dtype

    def encode_proteins(self, proteins: list[str], device: torch.device) -> torch.Tensor:
        """Tokenize and ESM-encode proteins, returning one mean-pooled vector per protein.

        Args:
            proteins: amino-acid sequences; each is truncated to ``cfg.max_seq_len`` tokens.
            device: device the token tensors are moved to before the ESM forward pass.

        Returns:
            Tensor of shape ``(len(proteins), embed_dim)``, or an all-zero tensor of
            shape ``(0, embed_dim)`` for an empty list.
        """
        if not proteins:
            return torch.zeros((0, self.embed_dim), device=device, dtype=self._head_dtype())
        enc = self.tokenizer(
            proteins,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.cfg.max_seq_len,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        last = self.esm(**enc).last_hidden_state  # (n_proteins, seq_len, embed_dim)
        # Average only over positions with attention_mask == 1, which excludes padding.
        # For typical HF tokenizers these positions include special tokens such as CLS/EOS.
        mask = enc["attention_mask"].unsqueeze(-1).to(last.dtype)
        return (last * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def encode_genome(self, by_category: dict[str, list[str]], device: torch.device) -> torch.Tensor:
        """Build one genome vector of shape ``(len(CATEGORIES) * embed_dim,)``.

        Each category contributes the mean of its (capped) protein embeddings. A
        category that is missing, empty, or left empty by the cap contributes a zero
        vector instead of a NaN mean.
        """
        cap = self.cfg.max_proteins_per_cat
        cat_vectors: list[torch.Tensor] = []
        for cat in CATEGORIES:
            # The cap is applied on every forward pass, at train and eval time. Lists are
            # expected to arrive sorted shortest-first from the extractor, so the cap keeps
            # the shortest proteins. Slicing with ``None`` keeps them all.
            proteins = (by_category.get(cat) or [])[:cap]
            if proteins:
                cat_vec = self.encode_proteins(proteins, device).mean(dim=0)
            else:
                cat_vec = torch.zeros(self.embed_dim, device=device, dtype=self._head_dtype())
            cat_vectors.append(cat_vec)
        return torch.cat(cat_vectors, dim=0)

    def forward(self, genomes: list[dict[str, list[str]]], device: torch.device) -> dict[str, torch.Tensor]:
        """Run a batched forward pass over genomes with variable protein counts.

        Args:
            genomes: one ``category -> protein sequences`` mapping per batch element.
            device: device the ESM inputs (and synthetic zero vectors) are placed on.

        Returns:
            Dict of per-target predictions:
                temp, ph, salt: ``(B,)`` regression predictions
                oxy: ``(B, N_OXYGEN_CLASSES)`` logits
            An empty ``genomes`` list yields empty tensors of those shapes with ``B = 0``.
        """
        if genomes:
            genome_vecs = torch.stack([self.encode_genome(g, device) for g in genomes], dim=0)
        else:
            # torch.stack rejects an empty list. An empty (0, 8*D) batch passes through
            # the Linear heads and produces correctly shaped empty outputs.
            genome_vecs = torch.zeros(
                (0, self.embed_dim * len(CATEGORIES)), device=device, dtype=self._head_dtype()
            )
        preds = {
            tgt: self.heads[tgt](genome_vecs).squeeze(-1) for tgt in ("temp", "ph", "salt")
        }
        preds["oxy"] = self.heads["oxy"](genome_vecs)
        return preds

    def trainable_param_count(self) -> tuple[int, int]:
        """Return ``(trainable, total)`` parameter counts, where trainable means ``requires_grad``."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return trainable, total
def _masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return ``sum(values * mask) / max(sum(mask), 1)``; graph-connected even for an all-zero mask."""
    return (values * mask).sum() / mask.sum().clamp(min=1.0)


def masked_multitask_loss(
    preds: dict[str, torch.Tensor],
    labels: dict[str, torch.Tensor],
    label_mask: dict[str, torch.Tensor],
    target_weights: dict[str, float] | None = None,
    oxy_class_weights: tuple[float, ...] | None = None,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Return the weighted sum of per-target losses, ignoring rows with missing labels.

    Regression targets ("temp", "ph", "salt") use squared error. "oxy" uses
    cross-entropy over class logits. Each per-target loss is the mask-weighted mean
    ``sum(loss_i * mask_i) / max(sum(mask_i), 1)``. A target whose mask is all-zero
    in this batch therefore contributes 0, still graph-connected to ``preds``, so
    ``backward()`` stays safe.

    Labels at masked-out rows are replaced before use. They may hold placeholders
    such as NaN (regression) or -1 (oxygen class) without producing NaN losses or
    gradients, or errors.

    Args:
        preds: model outputs; "temp"/"ph"/"salt" presumably ``(B,)``, "oxy" ``(B, C)`` logits.
        labels: targets keyed like ``preds``; "oxy" holds class indices.
        label_mask: per-row masks keyed like ``preds``; non-zero marks a valid label.
            Mask values multiply the per-row losses (binary masks expected).
        target_weights: per-target multipliers for the total. Keys not given default to 1.0.
        oxy_class_weights: optional per-class weights for the oxygen cross-entropy.
            Its length must equal the number of logit columns. Weighted rows are still
            normalised by the mask sum, not by the sum of class weights.

    Returns:
        ``(total, per_target)``: the scalar loss tensor, and the per-target loss values
        as Python floats.

    Raises:
        ValueError: if ``oxy_class_weights`` does not match the number of oxygen classes.
    """
    weights = {"temp": 1.0, "ph": 1.0, "salt": 1.0, "oxy": 1.0}
    if target_weights is not None:
        weights.update(target_weights)

    per_target_loss: dict[str, float] = {}
    # Every term added below is graph-connected to ``preds``, so the accumulator does
    # not need to be. A plain zero also avoids the NaN that ``preds * 0.0`` gives for inf.
    total = preds["temp"].new_zeros(())

    for tgt in ("temp", "ph", "salt"):
        mask = label_mask[tgt].float()
        target = labels[tgt]
        # Replace masked-out labels *before* subtracting. NaN * 0 is still NaN, in the
        # forward pass and in the gradient.
        target = torch.where(mask != 0, target, torch.zeros_like(target))
        loss = _masked_mean((preds[tgt] - target) ** 2, mask)
        total = total + weights[tgt] * loss
        per_target_loss[tgt] = loss.item()

    logits = preds["oxy"]
    mask_oxy = label_mask["oxy"].float()
    raw_oxy = labels["oxy"]
    # Placeholders at masked rows (e.g. -1) would make cross_entropy raise. Use class 0
    # there instead; its loss is zeroed by the mask.
    labels_oxy = torch.where(mask_oxy != 0, raw_oxy, torch.zeros_like(raw_oxy)).long()
    class_weight = None
    if oxy_class_weights is not None:
        if len(oxy_class_weights) != logits.shape[-1]:
            raise ValueError("oxy_class_weights must match the number of oxygen classes")
        class_weight = torch.tensor(oxy_class_weights, dtype=logits.dtype, device=logits.device)
    per_row_loss = nn.functional.cross_entropy(
        logits,
        labels_oxy,
        weight=class_weight,
        reduction="none",
    )
    loss_oxy = _masked_mean(per_row_loss, mask_oxy)
    total = total + weights["oxy"] * loss_oxy
    per_target_loss["oxy"] = loss_oxy.item()
    return total, per_target_loss
|