""" MiPLM inference API — runs on a HuggingFace Space (Docker SDK). Endpoints: GET / — list available models + device GET /health — liveness probe POST /predict — per-position 20-AA softmax over the sequence POST /embed — mean-pooled last-layer embedding POST /mutation — wildtype-marginal Δlog-likelihood matrix [L, 20] POST /downstream — all fine-tuned task heads (classification + regression) """ import os from typing import Dict, List, Tuple import torch import torch.nn.functional as F from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from transformers import EsmForMaskedLM, EsmForSequenceClassification, EsmTokenizer MODEL_REGISTRY: Dict[str, str] = { "miplm-ce": "HUBioDataLab/esm2-8m-ce", "miplm-blosum": "HUBioDataLab/esm2-8m-softce", "miplm-msa": "HUBioDataLab/esm2-8m-msa", } # Downstream heads — one HF repo per backbone, with 8 task subfolders each. # Map frontend backbone name → HF repo. DOWNSTREAM_REGISTRY: Dict[str, str] = { "miplm-ce": "HUBioDataLab/miplm-ce-tasks", "miplm-blosum": "HUBioDataLab/miplm-blosum-tasks", "miplm-msa": "HUBioDataLab/miplm-msa-tasks", } # Task metadata. `kind` drives post-processing: # single_label → softmax → top-K classes with probabilities # multi_label → sigmoid → top-K classes with independent probabilities # regression → raw scalar # `labels` is the human-readable class vocabulary (when known). When unset the # response falls back to numeric class indices. DEEPLOC10_LABELS = [ # Alphabetical order — matches the dataset's label-int assignment used during training. "Cell membrane", "Cytoplasm", "Endoplasmic reticulum", "Extracellular", "Golgi apparatus", "Lysosome/Vacuole", "Mitochondrion", "Nucleus", "Peroxisome", "Plastid", ] DEEPLOC2_LABELS = ["Membrane", "Soluble"] # alphabetical (matches training) METAL_LABELS = ["Non-binder", "Metal-ion binder"] # semantic order (matches training) DOWNSTREAM_TASKS: Dict[str, Dict] = { "deeploc-cls10": {"kind": "single_label", "labels": DEEPLOC10_LABELS, "title": "Subcellular localization (10-way)"}, "deeploc-cls2": {"kind": "single_label", "labels": DEEPLOC2_LABELS, "title": "Soluble vs membrane"}, "metalionbinding": {"kind": "single_label", "labels": METAL_LABELS, "title": "Metal-ion binding"}, "thermostability": {"kind": "regression", "labels": None, "title": "Thermostability", "unit": "normalised score"}, "ec": {"kind": "multi_label", "labels": None, "title": "EC enzyme class", "num_classes": 585}, "go-bp": {"kind": "multi_label", "labels": None, "title": "GO biological process", "num_classes": 1943}, "go-cc": {"kind": "multi_label", "labels": None, "title": "GO cellular component", "num_classes": 320}, "go-mf": {"kind": "multi_label", "labels": None, "title": "GO molecular function", "num_classes": 489}, } STANDARD_AAS = "ACDEFGHIKLMNPQRSTVWY" MAX_LEN = 1024 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" _models: Dict[str, Tuple[EsmForMaskedLM, EsmTokenizer, List[int]]] = {} _downstream: Dict[Tuple[str, str], EsmForSequenceClassification] = {} def get_model(name: str): if name not in MODEL_REGISTRY: raise HTTPException(400, f"unknown model {name!r}; available: {list(MODEL_REGISTRY)}") if name not in _models: repo = MODEL_REGISTRY[name] tokenizer = EsmTokenizer.from_pretrained(repo) model = EsmForMaskedLM.from_pretrained(repo).to(DEVICE).eval() aa_ids = [tokenizer.convert_tokens_to_ids(aa) for aa in STANDARD_AAS] _models[name] = (model, tokenizer, aa_ids) return _models[name] class _SeqRequest(BaseModel): sequence: str = Field(..., min_length=1, max_length=MAX_LEN, pattern=r"^[A-Za-z]+$") model: str = "miplm-blosum" class PredictResponse(BaseModel): model: str sequence: str aa_order: str probs: List[List[float]] class EmbedResponse(BaseModel): model: str dim: int embedding: List[float] class PerResidueEmbedResponse(BaseModel): model: str dim: int embeddings: List[List[float]] class MutationResponse(BaseModel): model: str sequence: str aa_order: str scores: List[List[float]] app = FastAPI(title="MiPLM inference", version="0.1.0") app.add_middleware( CORSMiddleware, allow_origins=os.environ.get("ALLOWED_ORIGINS", "*").split(","), allow_methods=["GET", "POST"], allow_headers=["*"], ) @app.on_event("startup") def _warmup(): for name in MODEL_REGISTRY: try: get_model(name) print(f"[warmup] loaded {name}") except Exception as e: print(f"[warmup] failed to load {name}: {e}") @app.get("/") def root(): return { "models": list(MODEL_REGISTRY.keys()), "downstream_tasks": [ {"id": t, "title": cfg["title"], "kind": cfg["kind"]} for t, cfg in DOWNSTREAM_TASKS.items() ], "device": DEVICE, "max_length": MAX_LEN, "aa_order": STANDARD_AAS, } @app.get("/health") def health(): return {"ok": True, "loaded": list(_models.keys())} @torch.inference_mode() def _forward(model, tokenizer, sequence: str): seq = sequence.upper() batch = tokenizer(seq, return_tensors="pt").to(DEVICE) out = model(**batch, output_hidden_states=True) logits = out.logits[0, 1:-1] # drop , -> [L, V] hidden = out.hidden_states[-1][0, 1:-1] # [L, H] return seq, logits, hidden @app.post("/predict", response_model=PredictResponse) def predict(req: _SeqRequest): model, tokenizer, aa_ids = get_model(req.model) seq, logits, _ = _forward(model, tokenizer, req.sequence) probs = F.softmax(logits[:, aa_ids], dim=-1).tolist() return PredictResponse(model=req.model, sequence=seq, aa_order=STANDARD_AAS, probs=probs) @app.post("/embed", response_model=EmbedResponse) def embed(req: _SeqRequest): model, tokenizer, _ = get_model(req.model) _, _, hidden = _forward(model, tokenizer, req.sequence) emb = hidden.mean(dim=0).cpu().tolist() return EmbedResponse(model=req.model, dim=len(emb), embedding=emb) @app.post("/embed_per_residue", response_model=PerResidueEmbedResponse) def embed_per_residue(req: _SeqRequest): model, tokenizer, _ = get_model(req.model) _, _, hidden = _forward(model, tokenizer, req.sequence) embs = hidden.cpu().tolist() return PerResidueEmbedResponse( model=req.model, dim=hidden.shape[-1], embeddings=embs ) @app.post("/mutation", response_model=MutationResponse) def mutation(req: _SeqRequest): model, tokenizer, aa_ids = get_model(req.model) seq, logits, _ = _forward(model, tokenizer, req.sequence) log_probs = F.log_softmax(logits[:, aa_ids], dim=-1) # [L, 20] aa_to_idx = {a: i for i, a in enumerate(STANDARD_AAS)} wt_idx = torch.tensor([aa_to_idx.get(c, 0) for c in seq], device=DEVICE) wt_logp = log_probs.gather(1, wt_idx[:, None]) # [L, 1] delta = (log_probs - wt_logp).tolist() return MutationResponse(model=req.model, sequence=seq, aa_order=STANDARD_AAS, scores=delta) # ─── Downstream task heads ────────────────────────────────────────────────── def get_downstream(backbone: str, task: str) -> EsmForSequenceClassification: """Lazy-load and cache fine-tuned classification/regression heads.""" if backbone not in DOWNSTREAM_REGISTRY: raise HTTPException(400, f"unknown backbone {backbone!r}; available: {list(DOWNSTREAM_REGISTRY)}") if task not in DOWNSTREAM_TASKS: raise HTTPException(400, f"unknown task {task!r}; available: {list(DOWNSTREAM_TASKS)}") key = (backbone, task) if key not in _downstream: repo = DOWNSTREAM_REGISTRY[backbone] print(f"[downstream] loading {repo} :: {task}") m = EsmForSequenceClassification.from_pretrained(repo, subfolder=task).to(DEVICE).eval() _downstream[key] = m return _downstream[key] class DownstreamRequest(BaseModel): sequence: str = Field(..., min_length=1, max_length=MAX_LEN, pattern=r"^[A-Za-z]+$") backbone: str = "miplm-blosum" top_k: int = 3 class TaskPrediction(BaseModel): task: str title: str kind: str # single_label | multi_label | regression value: float | None = None # set for regression unit: str | None = None top: List[Dict] | None = None # set for classification — [{label/index, prob}] num_classes: int | None = None class DownstreamResponse(BaseModel): backbone: str sequence: str predictions: List[TaskPrediction] @app.post("/downstream", response_model=DownstreamResponse) @torch.inference_mode() def downstream(req: DownstreamRequest): seq = req.sequence.upper() if req.backbone not in DOWNSTREAM_REGISTRY: raise HTTPException(400, f"unknown backbone {req.backbone!r}") # All downstream heads share the ESM-2 tokenizer with the base backbone. _, tokenizer, _ = get_model(req.backbone) batch = tokenizer(seq, return_tensors="pt").to(DEVICE) predictions: List[TaskPrediction] = [] for task, cfg in DOWNSTREAM_TASKS.items(): model = get_downstream(req.backbone, task) logits = model(**batch).logits[0] # [num_labels] for sequence-level head kind = cfg["kind"] if kind == "regression": predictions.append(TaskPrediction( task=task, title=cfg["title"], kind=kind, value=float(logits.item()), unit=cfg.get("unit"), )) elif kind == "single_label": probs = F.softmax(logits, dim=-1) top_idx = torch.topk(probs, k=min(req.top_k, probs.numel())).indices.tolist() top = [ {"label": cfg["labels"][i] if cfg["labels"] else f"class_{i}", "index": int(i), "prob": float(probs[i].item())} for i in top_idx ] predictions.append(TaskPrediction( task=task, title=cfg["title"], kind=kind, top=top, num_classes=int(probs.numel()), )) else: # multi_label probs = torch.sigmoid(logits) top_idx = torch.topk(probs, k=min(req.top_k, probs.numel())).indices.tolist() top = [ {"label": (cfg["labels"][i] if cfg["labels"] else f"class_{i}"), "index": int(i), "prob": float(probs[i].item())} for i in top_idx ] predictions.append(TaskPrediction( task=task, title=cfg["title"], kind=kind, top=top, num_classes=int(probs.numel()), )) return DownstreamResponse(backbone=req.backbone, sequence=seq, predictions=predictions)