"""FastAPI service returning sentence embeddings from a Hugging Face encoder.

The model named by ``MODEL_ID`` (default: BioBERT) is loaded once at import
time. ``POST /embed`` mean-pools the first element of the model output over
non-padding tokens and L2-normalises each resulting vector.
"""

import logging
import os
from typing import List

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)

MODEL_ID = os.getenv("MODEL_ID", "dmis-lab/biobert-base-cased-v1.2").strip()
# Empty or whitespace-only env values are treated as "no token".
HF_TOKEN = (os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or "").strip() or None
# Maximum number of texts sent through the model in one forward pass; bounds
# peak activation memory for large requests. Always at least 1.
EMBED_BATCH_SIZE = max(1, int(os.getenv("EMBED_BATCH_SIZE", "32")))


def load_model(model_id: str):
    """Load the tokenizer and encoder for *model_id*.

    First tries without an explicit token (``token=None``; note that the Hub
    client may still fall back to a locally cached login). If that fails and
    ``HF_TOKEN`` is configured, retries with that token.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        The original loading error when no token is configured; otherwise the
        error from the token-authenticated retry (chained to the first one).
    """
    try:
        tok = AutoTokenizer.from_pretrained(model_id, token=None, trust_remote_code=False)
        mdl = AutoModel.from_pretrained(model_id, token=None, trust_remote_code=False)
        return tok, mdl
    except Exception:
        if not HF_TOKEN:
            raise  # nothing else to try; surface the original error
        # Any failure (network, missing files, auth) lands here, so record it
        # instead of silently discarding it before the retry.
        logger.warning(
            "Loading %r without an explicit token failed; retrying with HF_TOKEN",
            model_id,
            exc_info=True,
        )
        tok = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN, trust_remote_code=False)
        mdl = AutoModel.from_pretrained(model_id, token=HF_TOKEN, trust_remote_code=False)
        return tok, mdl


def _max_input_length(tok, mdl):
    """Return the longest token sequence *mdl* accepts, or None if unknown.

    Tokenizers whose config omits ``model_max_length`` report a huge sentinel,
    so the model's position-embedding count is consulted too and the smaller
    known value wins.
    """
    candidates = (
        getattr(tok, "model_max_length", None),
        getattr(getattr(mdl, "config", None), "max_position_embeddings", None),
    )
    known = [int(value) for value in candidates if value]
    return min(known) if known else None


tokenizer, model = load_model(MODEL_ID)
model.eval()
MODEL_MAX_LENGTH = _max_input_length(tokenizer, model)


def mean_pooling(model_output, attention_mask):
    """Average token embeddings over positions where *attention_mask* is set.

    Args:
        model_output: Model forward output; element ``[0]`` is used as the
            token embeddings, shaped ``[batch, seq, hidden]``.
        attention_mask: ``[batch, seq]`` mask from the tokenizer (non-zero for
            real tokens, zero for padding).

    Returns:
        A ``[batch, hidden]`` tensor of mean-pooled embeddings.
    """
    token_embeddings = model_output[0]  # [batch, seq, hidden]
    # Match the embedding dtype so half-precision models aren't mixed with fp32;
    # the trailing singleton dim broadcasts across the hidden size.
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-padding rows
    return summed / counts


class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    texts: List[str] = Field(default_factory=list)
    # Requested truncation length in tokens (must be >= 1). It is clamped to
    # the model's maximum at request time, so oversized values cannot crash
    # the forward pass.
    max_length: int = Field(256, ge=1)


class EmbedResponse(BaseModel):
    """Response body: one embedding vector per input text, in input order."""

    embeddings: List[List[float]]


app = FastAPI(title="BioBERT Embeddings", version="1.0")


@app.get("/healthz")
def health():
    """Liveness probe reporting the configured model id."""
    return {"ok": True, "model_id": MODEL_ID}


@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    """Embed ``req.texts`` and return L2-normalised vectors in input order.

    Texts are processed in chunks of ``EMBED_BATCH_SIZE``. Padding is masked
    out of the pooling, so chunking does not change which tokens contribute.
    """
    if not req.texts:
        return {"embeddings": []}

    max_length = req.max_length
    if MODEL_MAX_LENGTH is not None:
        # Sequences longer than the position-embedding table would index out
        # of range inside the model.
        max_length = min(max_length, MODEL_MAX_LENGTH)

    embeddings: List[List[float]] = []
    with torch.no_grad():
        for start in range(0, len(req.texts), EMBED_BATCH_SIZE):
            batch = req.texts[start:start + EMBED_BATCH_SIZE]
            enc = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            out = model(**enc)
            pooled = mean_pooling(out, enc["attention_mask"])
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            embeddings.extend(pooled.cpu().tolist())
    return {"embeddings": embeddings}


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. The string
    # only resolves if this file is named app.py. Under `python app.py` it also
    # makes uvicorn import the module a second time, which loads the model twice.
    # Passing an object is valid with workers=1.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)