Spaces:
Sleeping
Sleeping
"""
MiPLM inference API — runs on a HuggingFace Space (Docker SDK).
Endpoints:
  GET  /            — list available models + device
  GET  /health      — liveness probe
  POST /predict     — per-position 20-AA softmax over the sequence
  POST /embed       — mean-pooled last-layer embedding
  POST /mutation    — wildtype-marginal Δlog-likelihood matrix [L, 20]
  POST /downstream  — all fine-tuned task heads (classification + regression)
"""
| import os | |
| from typing import Dict, List, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from transformers import EsmForMaskedLM, EsmForSequenceClassification, EsmTokenizer | |
# Masked-LM backbones: public model name -> HuggingFace repo id.
# Built from (public suffix, repo suffix) pairs; insertion order is the order
# in which the names are listed to clients.
MODEL_REGISTRY: Dict[str, str] = {
    f"miplm-{public}": f"HUBioDataLab/esm2-8m-{repo_suffix}"
    for public, repo_suffix in (
        ("ce", "ce"),
        ("blosum", "softce"),
        ("msa", "msa"),
    )
}

# Downstream heads — one HF repo per backbone, with 8 task subfolders each.
# Maps the frontend backbone name -> HF repo.
DOWNSTREAM_REGISTRY: Dict[str, str] = {
    f"miplm-{variant}": f"HUBioDataLab/miplm-{variant}-tasks"
    for variant in ("ce", "blosum", "msa")
}
# Task metadata. `kind` drives post-processing:
#   single_label → softmax → top-K classes with probabilities
#   multi_label  → sigmoid → top-K classes with independent probabilities
#   regression   → raw scalar
# `labels` is the human-readable class vocabulary (when known). When unset the
# response falls back to numeric class indices.

# Alphabetical order — matches the dataset's label-int assignment used during training.
DEEPLOC10_LABELS = [
    "Cell membrane",
    "Cytoplasm",
    "Endoplasmic reticulum",
    "Extracellular",
    "Golgi apparatus",
    "Lysosome/Vacuole",
    "Mitochondrion",
    "Nucleus",
    "Peroxisome",
    "Plastid",
]
# Alphabetical (matches training).
DEEPLOC2_LABELS = ["Membrane", "Soluble"]
# Semantic order (matches training).
METAL_LABELS = ["Non-binder", "Metal-ion binder"]

# Task id -> metadata. Iteration order is the order predictions are emitted in.
DOWNSTREAM_TASKS: Dict[str, Dict] = {
    "deeploc-cls10": {
        "kind": "single_label",
        "labels": DEEPLOC10_LABELS,
        "title": "Subcellular localization (10-way)",
    },
    "deeploc-cls2": {
        "kind": "single_label",
        "labels": DEEPLOC2_LABELS,
        "title": "Soluble vs membrane",
    },
    "metalionbinding": {
        "kind": "single_label",
        "labels": METAL_LABELS,
        "title": "Metal-ion binding",
    },
    "thermostability": {
        "kind": "regression",
        "labels": None,
        "title": "Thermostability",
        "unit": "normalised score",
    },
    "ec": {
        "kind": "multi_label",
        "labels": None,
        "title": "EC enzyme class",
        "num_classes": 585,
    },
    "go-bp": {
        "kind": "multi_label",
        "labels": None,
        "title": "GO biological process",
        "num_classes": 1943,
    },
    "go-cc": {
        "kind": "multi_label",
        "labels": None,
        "title": "GO cellular component",
        "num_classes": 320,
    },
    "go-mf": {
        "kind": "multi_label",
        "labels": None,
        "title": "GO molecular function",
        "num_classes": 489,
    },
}
# One-letter codes whose order fixes the column order of every 20-wide
# per-residue output (it is returned to clients as `aa_order`).
STANDARD_AAS = "ACDEFGHIKLMNPQRSTVWY"

# Upper bound on the length of a request sequence (used by the request models).
MAX_LEN = 1024

# Run on the GPU whenever torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Process-wide caches filled lazily by get_model / get_downstream:
#   _models:     model name       -> (masked-LM model, tokenizer, token ids of STANDARD_AAS)
#   _downstream: (backbone, task) -> fine-tuned sequence-level head
_models: Dict[str, Tuple[EsmForMaskedLM, EsmTokenizer, List[int]]] = {}
_downstream: Dict[Tuple[str, str], EsmForSequenceClassification] = {}
def get_model(name: str):
    """Return the ``(model, tokenizer, aa_ids)`` triple for a registered backbone.

    The triple is built on first use and kept in ``_models`` afterwards.
    ``aa_ids`` holds the tokenizer ids of the residues in ``STANDARD_AAS``,
    in that order.

    Raises:
        HTTPException: status 400 when *name* is not a key of ``MODEL_REGISTRY``.
    """
    repo = MODEL_REGISTRY.get(name)
    if repo is None:
        raise HTTPException(400, f"unknown model {name!r}; available: {list(MODEL_REGISTRY)}")

    entry = _models.get(name)
    if entry is None:
        tokenizer = EsmTokenizer.from_pretrained(repo)
        model = EsmForMaskedLM.from_pretrained(repo)
        model = model.to(DEVICE).eval()
        aa_ids = list(map(tokenizer.convert_tokens_to_ids, STANDARD_AAS))
        entry = (model, tokenizer, aa_ids)
        _models[name] = entry
    return entry
class _SeqRequest(BaseModel):
    # Request body taken by the predict / embed / embed_per_residue / mutation handlers.

    # Letters only (either case), between 1 and MAX_LEN characters.
    sequence: str = Field(
        ...,
        min_length=1,
        max_length=MAX_LEN,
        pattern=r"^[A-Za-z]+$",
    )
    # Key into MODEL_REGISTRY.
    model: str = "miplm-blosum"
class PredictResponse(BaseModel):
    # Response built by `predict`.
    model: str
    # Upper-cased input sequence.
    sequence: str
    # Column order of each row in `probs` (STANDARD_AAS).
    aa_order: str
    # One row per residue; each row is a softmax over the residues in `aa_order`.
    probs: list[list[float]]
class EmbedResponse(BaseModel):
    # Response built by `embed`.
    model: str
    # Length of `embedding`.
    dim: int
    # Last-layer hidden states averaged over sequence positions.
    embedding: list[float]
class PerResidueEmbedResponse(BaseModel):
    # Response built by `embed_per_residue`.
    model: str
    # Width of each row in `embeddings`.
    dim: int
    # One last-layer hidden-state vector per residue.
    embeddings: list[list[float]]
class MutationResponse(BaseModel):
    # Response built by `mutation`.
    model: str
    # Upper-cased input sequence.
    sequence: str
    # Column order of each row in `scores` (STANDARD_AAS).
    aa_order: str
    # One row per residue: log-prob of each residue in `aa_order` minus the
    # log-prob of the wildtype residue at that position.
    scores: list[list[float]]
app = FastAPI(title="MiPLM inference", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    # ALLOWED_ORIGINS is a comma-separated list; when unset, "*" is used.
    # Each entry is stripped and empty entries are dropped, so values such as
    # "https://a.example, https://b.example" or a trailing comma do not produce
    # origins with stray whitespace or an empty-string origin.
    allow_origins=[
        origin.strip()
        for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
        if origin.strip()
    ],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
def _warmup():
    """Load every backbone in ``MODEL_REGISTRY`` into the cache, best-effort.

    A failure for one backbone is printed and does not stop the others.
    """
    # NOTE(review): nothing visible in this file calls or registers _warmup —
    # presumably it is meant to run at application startup; confirm.
    for backbone in MODEL_REGISTRY:
        try:
            get_model(backbone)
            print(f"[warmup] loaded {backbone}")
        except Exception as exc:
            # Deliberately broad: warm-up must never take the service down.
            print(f"[warmup] failed to load {backbone}: {exc}")
@app.get("/")
def root():
    """Describe the service: models, downstream tasks, device and limits.

    Registered as ``GET /`` so the handler is actually reachable.

    Returns:
        A dict with the registered model names, one ``{id, title, kind}``
        entry per downstream task, the torch device in use, the maximum
        sequence length, and the amino-acid column order.
    """
    task_summaries = [
        {"id": task_id, "title": cfg["title"], "kind": cfg["kind"]}
        for task_id, cfg in DOWNSTREAM_TASKS.items()
    ]
    return {
        "models": list(MODEL_REGISTRY),
        "downstream_tasks": task_summaries,
        "device": DEVICE,
        "max_length": MAX_LEN,
        "aa_order": STANDARD_AAS,
    }
@app.get("/health")
def health():
    """Liveness probe, registered as ``GET /health``.

    Returns:
        ``{"ok": True, "loaded": [...]}`` where ``loaded`` lists the backbone
        names currently held in the model cache.
    """
    return {"ok": True, "loaded": list(_models)}
def _forward(model, tokenizer, sequence: str):
    """Run one sequence through the masked-LM and strip the special tokens.

    Args:
        model: callable returning an object with ``logits`` and
            ``hidden_states`` when called with the tokenized batch and
            ``output_hidden_states=True``.
        tokenizer: callable that tokenizes a string into a batch supporting
            ``.to(device)``.
        sequence: amino-acid sequence; it is upper-cased before tokenization.

    Returns:
        ``(seq, logits, hidden)``: the upper-cased sequence, the per-position
        logits ``[L, V]`` and the last-layer hidden states ``[L, H]``, both
        with the first and last (``<cls>`` / ``<eos>``) positions removed.
    """
    seq = sequence.upper()
    batch = tokenizer(seq, return_tensors="pt").to(DEVICE)
    # Pure inference: skip autograd bookkeeping so no graph is built or
    # retained for the returned tensors.
    with torch.no_grad():
        out = model(**batch, output_hidden_states=True)
    logits = out.logits[0, 1:-1]  # drop <cls>, <eos> -> [L, V]
    hidden = out.hidden_states[-1][0, 1:-1]  # [L, H]
    return seq, logits, hidden
@app.post("/predict", response_model=PredictResponse)
def predict(req: _SeqRequest):
    """Per-position softmax over the 20 residues in ``STANDARD_AAS``.

    Registered as ``POST /predict`` so the handler is actually reachable.
    The logits are restricted to the 20 standard-residue tokens before the
    softmax, so each row sums to 1 over ``aa_order``.

    Raises:
        HTTPException: status 400 (from ``get_model``) for an unknown model.
    """
    model, tokenizer, aa_ids = get_model(req.model)
    seq, logits, _ = _forward(model, tokenizer, req.sequence)
    probs = F.softmax(logits[:, aa_ids], dim=-1).tolist()
    return PredictResponse(model=req.model, sequence=seq, aa_order=STANDARD_AAS, probs=probs)
@app.post("/embed", response_model=EmbedResponse)
def embed(req: _SeqRequest):
    """Mean-pooled last-layer embedding of the sequence.

    Registered as ``POST /embed`` so the handler is actually reachable.
    The hidden states (special tokens already removed by ``_forward``) are
    averaged over sequence positions.

    Raises:
        HTTPException: status 400 (from ``get_model``) for an unknown model.
    """
    model, tokenizer, _ = get_model(req.model)
    _, _, hidden = _forward(model, tokenizer, req.sequence)
    emb = hidden.mean(dim=0).cpu().tolist()
    return EmbedResponse(model=req.model, dim=len(emb), embedding=emb)
def embed_per_residue(req: _SeqRequest):
    """Return the last-layer hidden state of every residue (no pooling).

    Raises:
        HTTPException: status 400 (from ``get_model``) for an unknown model.
    """
    # NOTE(review): no route registration is visible for this handler and the
    # module docstring lists no per-residue endpoint — confirm the intended path.
    backbone, tokenizer, _ = get_model(req.model)
    _, _, per_residue = _forward(backbone, tokenizer, req.sequence)
    rows = per_residue.cpu().tolist()
    width = per_residue.shape[-1]
    return PerResidueEmbedResponse(model=req.model, dim=width, embeddings=rows)
@app.post("/mutation", response_model=MutationResponse)
def mutation(req: _SeqRequest):
    """Score every substitution relative to the wildtype residue.

    Registered as ``POST /mutation`` so the handler is actually reachable.
    For each position the log-softmax over the 20 standard residues is
    computed from a single unmasked forward pass, and the wildtype residue's
    log-probability is subtracted, giving an ``[L, 20]`` matrix whose
    wildtype entries are 0.

    Raises:
        HTTPException: status 400 (from ``get_model``) for an unknown model.
    """
    model, tokenizer, aa_ids = get_model(req.model)
    seq, logits, _ = _forward(model, tokenizer, req.sequence)
    log_probs = F.log_softmax(logits[:, aa_ids], dim=-1)  # [L, 20]
    aa_to_idx = {aa: i for i, aa in enumerate(STANDARD_AAS)}
    # NOTE(review): the request pattern admits any letter, so residues outside
    # STANDARD_AAS (e.g. "X", "B") fall back to index 0 here and are scored as
    # if the wildtype were "A". Kept as-is; confirm whether such positions
    # should be rejected or reported differently.
    wt_idx = torch.tensor(
        [aa_to_idx.get(c, 0) for c in seq],
        device=log_probs.device,  # keep the index on the same device as the scores
    )
    wt_logp = log_probs.gather(1, wt_idx[:, None])  # [L, 1]
    delta = (log_probs - wt_logp).tolist()
    return MutationResponse(model=req.model, sequence=seq, aa_order=STANDARD_AAS, scores=delta)
# ─── Downstream task heads ──────────────────────────────────────────────────
def get_downstream(backbone: str, task: str) -> EsmForSequenceClassification:
    """Return the fine-tuned head for ``(backbone, task)``, loading it on first use.

    The head is read from the ``task`` subfolder of the backbone's repo in
    ``DOWNSTREAM_REGISTRY`` and kept in ``_downstream`` afterwards.

    Raises:
        HTTPException: status 400 when *backbone* or *task* is not registered.
    """
    if backbone not in DOWNSTREAM_REGISTRY:
        raise HTTPException(400, f"unknown backbone {backbone!r}; available: {list(DOWNSTREAM_REGISTRY)}")
    if task not in DOWNSTREAM_TASKS:
        raise HTTPException(400, f"unknown task {task!r}; available: {list(DOWNSTREAM_TASKS)}")

    cache_key = (backbone, task)
    head = _downstream.get(cache_key)
    if head is None:
        repo = DOWNSTREAM_REGISTRY[backbone]
        print(f"[downstream] loading {repo} :: {task}")
        head = EsmForSequenceClassification.from_pretrained(repo, subfolder=task)
        head = head.to(DEVICE).eval()
        _downstream[cache_key] = head
    return head
class DownstreamRequest(BaseModel):
    # Request body taken by the `downstream` handler.

    # Letters only (either case), between 1 and MAX_LEN characters.
    sequence: str = Field(..., min_length=1, max_length=MAX_LEN, pattern=r"^[A-Za-z]+$")
    # Key into DOWNSTREAM_REGISTRY.
    backbone: str = "miplm-blosum"
    # Number of classes reported per classification task. Negative values are
    # rejected at validation time instead of failing later inside torch.topk;
    # values above a task's class count are clamped by the handler.
    top_k: int = Field(3, ge=0)
class TaskPrediction(BaseModel):
    # Result of one downstream task for one sequence.
    task: str
    title: str
    # One of: single_label | multi_label | regression
    kind: str
    # Set for regression tasks only.
    value: float | None = None
    unit: str | None = None
    # Set for classification tasks — a list of {label, index, prob} dicts.
    top: list[dict] | None = None
    num_classes: int | None = None
class DownstreamResponse(BaseModel):
    # Response built by `downstream`.
    backbone: str
    # Upper-cased input sequence.
    sequence: str
    # One entry per task in DOWNSTREAM_TASKS, in that dict's order.
    predictions: list[TaskPrediction]
def _top_predictions(probs: torch.Tensor, k: int, labels) -> List[Dict]:
    """Return the *k* most probable classes as ``{label, index, prob}`` dicts.

    *k* is clamped to the number of classes. When *labels* is falsy the label
    falls back to ``class_<index>``.
    """
    best = torch.topk(probs, k=min(k, probs.numel()))
    return [
        {
            "label": labels[idx] if labels else f"class_{idx}",
            "index": int(idx),
            "prob": float(prob),
        }
        for idx, prob in zip(best.indices.tolist(), best.values.tolist())
    ]


@app.post("/downstream", response_model=DownstreamResponse)
def downstream(req: DownstreamRequest):
    """Run every fine-tuned task head of a backbone on one sequence.

    Registered as ``POST /downstream`` so the handler is actually reachable.
    Post-processing depends on the task's ``kind``: regression returns the raw
    scalar, single_label applies a softmax and multi_label a sigmoid, and both
    classification kinds report the ``top_k`` most probable classes.

    Raises:
        HTTPException: status 400 for an unknown backbone or a negative
            ``top_k`` (which ``torch.topk`` would otherwise reject with an
            unhandled error).
    """
    seq = req.sequence.upper()
    if req.backbone not in DOWNSTREAM_REGISTRY:
        raise HTTPException(400, f"unknown backbone {req.backbone!r}")
    if req.top_k < 0:
        raise HTTPException(400, f"top_k must be >= 0, got {req.top_k}")

    # All downstream heads share the ESM-2 tokenizer with the base backbone.
    _, tokenizer, _ = get_model(req.backbone)
    batch = tokenizer(seq, return_tensors="pt").to(DEVICE)

    predictions: List[TaskPrediction] = []
    for task, cfg in DOWNSTREAM_TASKS.items():
        model = get_downstream(req.backbone, task)
        # Pure inference: no autograd graph is needed for the head outputs.
        with torch.no_grad():
            logits = model(**batch).logits[0]  # [num_labels] for sequence-level head
        kind = cfg["kind"]

        if kind == "regression":
            predictions.append(TaskPrediction(
                task=task, title=cfg["title"], kind=kind,
                value=float(logits.item()),
                unit=cfg.get("unit"),
            ))
            continue

        # Classification: the two kinds differ only in the activation.
        if kind == "single_label":
            probs = F.softmax(logits, dim=-1)
        else:  # multi_label
            probs = torch.sigmoid(logits)
        predictions.append(TaskPrediction(
            task=task, title=cfg["title"], kind=kind,
            top=_top_predictions(probs, req.top_k, cfg["labels"]),
            num_classes=int(probs.numel()),
        ))

    return DownstreamResponse(backbone=req.backbone, sequence=seq, predictions=predictions)