miplm-inference / app.py
mgorkemuz's picture
Fix DeepLoc cls10/cls2 label ordering (alphabetical matches training dataset)
0cbb26b verified
Raw
History Blame Contribute Delete
11.3 kB
"""
MiPLM inference API β€” runs on a HuggingFace Space (Docker SDK).
Endpoints:
GET / β€” list available models + device
GET /health β€” liveness probe
POST /predict β€” per-position 20-AA softmax over the sequence
POST /embed β€” mean-pooled last-layer embedding
POST /mutation β€” wildtype-marginal Ξ”log-likelihood matrix [L, 20]
POST /downstream β€” all fine-tuned task heads (classification + regression)
"""
import os
from typing import Dict, List, Tuple
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import EsmForMaskedLM, EsmForSequenceClassification, EsmTokenizer
# Masked-LM backbones: public model name (the request's `model` field)
# → Hugging Face repo id handed to `from_pretrained` in `get_model`.
MODEL_REGISTRY: Dict[str, str] = {
    "miplm-ce": "HUBioDataLab/esm2-8m-ce",
    "miplm-blosum": "HUBioDataLab/esm2-8m-softce",
    "miplm-msa": "HUBioDataLab/esm2-8m-msa",
}
# Downstream heads — one HF repo per backbone, with 8 task subfolders each.
# Map frontend backbone name → HF repo.
# The keys mirror MODEL_REGISTRY: `/downstream` looks the same name up in both
# (here for the task heads, there for the tokenizer).
DOWNSTREAM_REGISTRY: Dict[str, str] = {
    "miplm-ce": "HUBioDataLab/miplm-ce-tasks",
    "miplm-blosum": "HUBioDataLab/miplm-blosum-tasks",
    "miplm-msa": "HUBioDataLab/miplm-msa-tasks",
}
# Task metadata. `kind` drives post-processing:
#   single_label → softmax → top-K classes with probabilities
#   multi_label  → sigmoid → top-K classes with independent probabilities
#   regression   → raw scalar
# `labels` is the human-readable class vocabulary (when known). When unset the
# response falls back to `class_<index>` placeholders.
# The position of a name in a label list IS its class index, so these lists
# must not be reordered.
DEEPLOC10_LABELS = [
    # Alphabetical order — matches the dataset's label-int assignment used during training.
    "Cell membrane", "Cytoplasm", "Endoplasmic reticulum", "Extracellular",
    "Golgi apparatus", "Lysosome/Vacuole", "Mitochondrion", "Nucleus",
    "Peroxisome", "Plastid",
]
DEEPLOC2_LABELS = ["Membrane", "Soluble"]  # alphabetical (matches training)
METAL_LABELS = ["Non-binder", "Metal-ion binder"]  # semantic order (matches training)
# Task id → config. The task id doubles as the `subfolder` loaded from the
# backbone's downstream repo, and the dict's insertion order is the order of
# `predictions` in the `/downstream` response.
# NOTE(review): the `num_classes` entries are not read by the code visible in
# this file — the response reports the head's actual output size instead.
DOWNSTREAM_TASKS: Dict[str, Dict] = {
    "deeploc-cls10": {"kind": "single_label", "labels": DEEPLOC10_LABELS, "title": "Subcellular localization (10-way)"},
    "deeploc-cls2": {"kind": "single_label", "labels": DEEPLOC2_LABELS, "title": "Soluble vs membrane"},
    "metalionbinding": {"kind": "single_label", "labels": METAL_LABELS, "title": "Metal-ion binding"},
    "thermostability": {"kind": "regression", "labels": None, "title": "Thermostability", "unit": "normalised score"},
    "ec": {"kind": "multi_label", "labels": None, "title": "EC enzyme class", "num_classes": 585},
    "go-bp": {"kind": "multi_label", "labels": None, "title": "GO biological process", "num_classes": 1943},
    "go-cc": {"kind": "multi_label", "labels": None, "title": "GO cellular component", "num_classes": 320},
    "go-mf": {"kind": "multi_label", "labels": None, "title": "GO molecular function", "num_classes": 489},
}
# The 20 standard amino acids. This string's order is the column order of every
# per-residue probability/score row the API returns (echoed as `aa_order`).
STANDARD_AAS = "ACDEFGHIKLMNPQRSTVWY"
# Maximum accepted sequence length, enforced by the request models' `max_length`.
MAX_LEN = 1024
# Use the GPU when torch can see one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Cache of loaded backbones: name → (model, tokenizer, vocabulary ids of STANDARD_AAS).
_models: Dict[str, Tuple[EsmForMaskedLM, EsmTokenizer, List[int]]] = {}
# Cache of loaded task heads, keyed by (backbone name, task id).
_downstream: Dict[Tuple[str, str], EsmForSequenceClassification] = {}
def get_model(name: str):
    """Return the cached ``(model, tokenizer, aa_ids)`` triple for a backbone.

    On the first request for ``name`` the tokenizer and masked-LM weights are
    loaded from the registered repo, the model is moved to ``DEVICE`` and put
    in eval mode, and the vocabulary ids of the 20 standard amino acids (in
    ``STANDARD_AAS`` order) are looked up. Later calls reuse the cached triple.

    Raises:
        HTTPException: 400 if ``name`` is not a key of ``MODEL_REGISTRY``.
    """
    hub_repo = MODEL_REGISTRY.get(name)
    if hub_repo is None:
        raise HTTPException(400, f"unknown model {name!r}; available: {list(MODEL_REGISTRY)}")

    entry = _models.get(name)
    if entry is None:
        tok = EsmTokenizer.from_pretrained(hub_repo)
        lm = EsmForMaskedLM.from_pretrained(hub_repo)
        lm = lm.to(DEVICE).eval()
        residue_ids = [tok.convert_tokens_to_ids(aa) for aa in STANDARD_AAS]
        entry = (lm, tok, residue_ids)
        _models[name] = entry
    return entry
class _SeqRequest(BaseModel):
    # Request body shared by /predict, /embed, /embed_per_residue and /mutation.
    # Protein sequence: ASCII letters only, 1..MAX_LEN characters. Either case
    # is accepted here; `_forward` upper-cases it before tokenizing.
    sequence: str = Field(..., min_length=1, max_length=MAX_LEN, pattern=r"^[A-Za-z]+$")
    # Backbone name; validated against MODEL_REGISTRY in `get_model`.
    model: str = "miplm-blosum"
class PredictResponse(BaseModel):
    # Response body of /predict.
    model: str  # backbone name echoed from the request
    sequence: str  # upper-cased sequence that was actually scored
    aa_order: str  # column order of each `probs` row (STANDARD_AAS)
    probs: List[List[float]]  # one row per residue: softmax over the 20 standard AAs
class EmbedResponse(BaseModel):
    # Response body of /embed.
    model: str  # backbone name echoed from the request
    dim: int  # length of `embedding`
    embedding: List[float]  # last-layer hidden states averaged over residue positions
class PerResidueEmbedResponse(BaseModel):
    # Response body of /embed_per_residue.
    model: str  # backbone name echoed from the request
    dim: int  # size of the last axis of the hidden states (length of each row)
    embeddings: List[List[float]]  # one last-layer hidden vector per residue
class MutationResponse(BaseModel):
    # Response body of /mutation.
    model: str  # backbone name echoed from the request
    sequence: str  # upper-cased sequence that was actually scored
    aa_order: str  # column order of each `scores` row (STANDARD_AAS)
    scores: List[List[float]]  # one row per residue: log p(substitute) - log p(wildtype)
app = FastAPI(title="MiPLM inference", version="0.1.0")

# ALLOWED_ORIGINS is a comma-separated list of origins (default "*" = any).
# Each entry is stripped and empty entries are dropped: the CORS middleware
# compares the request's Origin header against these strings verbatim, so an
# un-stripped " https://b.example" (from "a, b") or a padded " *" would never
# match, and a trailing comma would add a meaningless "" entry.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
        if origin.strip()
    ],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
@app.on_event("startup")
def _warmup():
    """Eagerly load every registered backbone when the server starts.

    Loading is best-effort: the outcome for each model is printed, and a
    failure does not prevent the remaining models (or the app) from starting.
    """
    # NOTE(review): FastAPI documents `on_event` as deprecated in favour of
    # lifespan handlers; migrating would also touch the `FastAPI(...)` call.
    for name in MODEL_REGISTRY:
        try:
            get_model(name)
            status = f"[warmup] loaded {name}"
        except Exception as e:
            status = f"[warmup] failed to load {name}: {e}"
        print(status)
@app.get("/")
def root():
    """Describe the service: backbones, downstream tasks, device and limits."""
    task_catalog = [
        {"id": task_id, "title": meta["title"], "kind": meta["kind"]}
        for task_id, meta in DOWNSTREAM_TASKS.items()
    ]
    service_info = {
        "models": list(MODEL_REGISTRY),
        "downstream_tasks": task_catalog,
        "device": DEVICE,
        "max_length": MAX_LEN,
        "aa_order": STANDARD_AAS,
    }
    return service_info
@app.get("/health")
def health():
    """Liveness probe; also reports which backbones are currently cached."""
    loaded_backbones = [*_models]
    return {"ok": True, "loaded": loaded_backbones}
@torch.inference_mode()
def _forward(model, tokenizer, sequence: str):
    """Run one forward pass over a single sequence.

    The sequence is upper-cased, tokenized as a batch of one, and moved to
    ``DEVICE``. The first and last token positions (``<cls>`` and ``<eos>``)
    are dropped so the returned rows correspond to residues only.

    Returns:
        ``(seq, logits, hidden)``: the upper-cased sequence, the per-position
        vocabulary logits ``[L, V]``, and the last hidden layer ``[L, H]``.
    """
    seq = sequence.upper()
    encoded = tokenizer(seq, return_tensors="pt").to(DEVICE)
    outputs = model(**encoded, output_hidden_states=True)

    residues = slice(1, -1)  # skip the leading <cls> and trailing <eos>
    token_logits = outputs.logits[0, residues]
    last_hidden = outputs.hidden_states[-1][0, residues]
    return seq, token_logits, last_hidden
@app.post("/predict", response_model=PredictResponse)
def predict(req: _SeqRequest):
    """Return, for each residue, a softmax over the 20 standard amino acids.

    The distribution is renormalised over the ``STANDARD_AAS`` columns only;
    other vocabulary tokens are excluded before the softmax.
    """
    lm, tok, residue_ids = get_model(req.model)
    upper_seq, token_logits, _hidden = _forward(lm, tok, req.sequence)
    aa_logits = token_logits[:, residue_ids]
    aa_probs = F.softmax(aa_logits, dim=-1)
    return PredictResponse(
        model=req.model,
        sequence=upper_seq,
        aa_order=STANDARD_AAS,
        probs=aa_probs.tolist(),
    )
@app.post("/embed", response_model=EmbedResponse)
def embed(req: _SeqRequest):
    """Return one embedding for the sequence: the mean of the last-layer
    hidden states over all residue positions."""
    lm, tok, _aa_ids = get_model(req.model)
    last_hidden = _forward(lm, tok, req.sequence)[2]
    pooled = last_hidden.mean(dim=0).cpu().tolist()
    return EmbedResponse(model=req.model, dim=len(pooled), embedding=pooled)
@app.post("/embed_per_residue", response_model=PerResidueEmbedResponse)
def embed_per_residue(req: _SeqRequest):
    """Return the last-layer hidden vector of every residue, unpooled."""
    lm, tok, _aa_ids = get_model(req.model)
    *_, last_hidden = _forward(lm, tok, req.sequence)
    hidden_dim = last_hidden.shape[-1]
    per_residue = last_hidden.cpu().tolist()
    return PerResidueEmbedResponse(model=req.model, dim=hidden_dim, embeddings=per_residue)
@app.post("/mutation", response_model=MutationResponse)
def mutation(req: _SeqRequest):
    """Score every single-residue substitution against the wildtype residue.

    A single unmasked forward pass is used (wildtype-marginal scoring). For
    position ``i`` and amino acid ``a`` in ``STANDARD_AAS`` the score is
    ``log p_i(a) - log p_i(wildtype_i)``: the wildtype column is 0 and positive
    values mean the model prefers the substitution.

    The reference is the log-probability of the token actually present at each
    position. For the 20 standard residues this gives the same values as
    normalising over the 20 columns only (the normaliser cancels in the
    difference). For any other letter (e.g. X, B, U, Z) the row is now scored
    against that letter's own token instead of silently being scored as if the
    wildtype were alanine.

    Returns:
        MutationResponse whose ``scores`` is an ``[L, 20]`` matrix.

    Raises:
        HTTPException: 400 for an unknown model (raised by ``get_model``), or
            when the tokenizer did not yield exactly one token per character.
    """
    model, tokenizer, aa_ids = get_model(req.model)
    seq, logits, _ = _forward(model, tokenizer, req.sequence)

    # Rows are matched to `seq` by position, which only holds when the
    # tokenizer emitted one token per character. Fail clearly rather than
    # letting `gather` raise a shape error below.
    if logits.shape[0] != len(seq):
        raise HTTPException(400, "sequence could not be tokenized one residue per character")

    log_probs = F.log_softmax(logits, dim=-1)  # [L, V], over the full vocabulary
    # Vocabulary id of the residue present at each position — same lookup
    # `get_model` uses for STANDARD_AAS.
    wt_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(list(seq)), device=logits.device
    )
    wt_logp = log_probs.gather(1, wt_ids[:, None])  # [L, 1]
    delta = (log_probs[:, aa_ids] - wt_logp).tolist()  # [L, 20]
    return MutationResponse(model=req.model, sequence=seq, aa_order=STANDARD_AAS, scores=delta)
# ─── Downstream task heads ──────────────────────────────────────────────────
def get_downstream(backbone: str, task: str) -> EsmForSequenceClassification:
    """Return the fine-tuned head for ``(backbone, task)``, loading it on first use.

    The head is read from the ``task`` subfolder of the backbone's downstream
    repo, moved to ``DEVICE`` in eval mode, and cached for later calls.

    Raises:
        HTTPException: 400 if ``backbone`` is not in ``DOWNSTREAM_REGISTRY`` or
            ``task`` is not in ``DOWNSTREAM_TASKS``.
    """
    repo = DOWNSTREAM_REGISTRY.get(backbone)
    if repo is None:
        raise HTTPException(400, f"unknown backbone {backbone!r}; available: {list(DOWNSTREAM_REGISTRY)}")
    if task not in DOWNSTREAM_TASKS:
        raise HTTPException(400, f"unknown task {task!r}; available: {list(DOWNSTREAM_TASKS)}")

    cache_key = (backbone, task)
    head = _downstream.get(cache_key)
    if head is not None:
        return head

    print(f"[downstream] loading {repo} :: {task}")
    head = EsmForSequenceClassification.from_pretrained(repo, subfolder=task)
    head = head.to(DEVICE).eval()
    _downstream[cache_key] = head
    return head
class DownstreamRequest(BaseModel):
    # Request body of /downstream.
    # Protein sequence: ASCII letters only, 1..MAX_LEN characters (upper-cased
    # by the endpoint before tokenizing).
    sequence: str = Field(..., min_length=1, max_length=MAX_LEN, pattern=r"^[A-Za-z]+$")
    # Backbone name; must be a key of DOWNSTREAM_REGISTRY.
    backbone: str = "miplm-blosum"
    # Number of classes to report per classification task. Values above a
    # task's class count are clamped by the endpoint. Negative values are
    # rejected here: they used to reach `torch.topk`, which raised and
    # produced a 500 instead of a validation error.
    top_k: int = Field(3, ge=0)
class TaskPrediction(BaseModel):
    # Result of one downstream task head. Regression tasks fill `value`/`unit`;
    # classification tasks fill `top`/`num_classes`; the other fields stay None.
    task: str  # task id (key of DOWNSTREAM_TASKS)
    title: str  # human-readable task name
    kind: str  # single_label | multi_label | regression
    value: float | None = None  # set for regression
    unit: str | None = None  # unit of `value`, when the task config provides one
    top: List[Dict] | None = None  # set for classification — [{label, index, prob}]
    num_classes: int | None = None  # size of the head's output (classification only)
class DownstreamResponse(BaseModel):
    # Response body of /downstream.
    backbone: str  # backbone name echoed from the request
    sequence: str  # upper-cased sequence that was actually scored
    predictions: List[TaskPrediction]  # one entry per task, in DOWNSTREAM_TASKS order
def _top_k_classes(probs: torch.Tensor, labels, k: int) -> List[Dict]:
    """Return the ``k`` most probable classes as ``{label, index, prob}`` dicts.

    ``k`` is clamped to the number of classes. A class without a name — no
    label list, or an index beyond its end — is reported as ``class_<index>``
    instead of raising ``IndexError``.
    """
    best = torch.topk(probs, k=min(k, probs.numel()))
    return [
        {
            "label": labels[idx] if labels and idx < len(labels) else f"class_{idx}",
            "index": int(idx),
            "prob": float(prob),
        }
        for prob, idx in zip(best.values.tolist(), best.indices.tolist())
    ]


@app.post("/downstream", response_model=DownstreamResponse)
@torch.inference_mode()
def downstream(req: DownstreamRequest):
    """Run every downstream task head of a backbone on one sequence.

    The sequence is upper-cased and tokenized once, then passed through each
    head listed in ``DOWNSTREAM_TASKS`` (heads are loaded lazily on first use).
    Per task kind:

    * ``regression``: the single output is returned as ``value``.
    * ``single_label``: softmax over classes, top-``top_k`` reported.
    * anything else (``multi_label``): independent sigmoid per class,
      top-``top_k`` reported.

    Raises:
        HTTPException: 400 for an unknown backbone or a negative ``top_k``.
    """
    seq = req.sequence.upper()
    if req.backbone not in DOWNSTREAM_REGISTRY:
        raise HTTPException(400, f"unknown backbone {req.backbone!r}")
    # Reject before any model is loaded; torch.topk raises on a negative k.
    if req.top_k < 0:
        raise HTTPException(400, f"top_k must be >= 0, got {req.top_k}")

    # All downstream heads share the ESM-2 tokenizer with the base backbone.
    _, tokenizer, _ = get_model(req.backbone)
    batch = tokenizer(seq, return_tensors="pt").to(DEVICE)

    predictions: List[TaskPrediction] = []
    for task, cfg in DOWNSTREAM_TASKS.items():
        model = get_downstream(req.backbone, task)
        logits = model(**batch).logits[0]  # [num_labels] for sequence-level head
        kind = cfg["kind"]

        if kind == "regression":
            predictions.append(TaskPrediction(
                task=task, title=cfg["title"], kind=kind,
                value=float(logits.item()),
                unit=cfg.get("unit"),
            ))
            continue

        # Classes compete in single-label tasks (softmax) and are scored
        # independently in multi-label tasks (sigmoid); the rest is shared.
        if kind == "single_label":
            probs = F.softmax(logits, dim=-1)
        else:
            probs = torch.sigmoid(logits)
        predictions.append(TaskPrediction(
            task=task, title=cfg["title"], kind=kind,
            top=_top_k_classes(probs, cfg["labels"], req.top_k),
            num_classes=int(probs.numel()),
        ))

    return DownstreamResponse(backbone=req.backbone, sequence=seq, predictions=predictions)