# biobert-emb / app.py
# felixbet's picture
# Update app.py
# 5d7a5a6 verified
from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import List
from transformers import AutoTokenizer, AutoModel
import torch, os
# Hub id of the model to serve. A missing, empty, or whitespace-only MODEL_ID
# falls back to the default instead of producing an empty model id.
MODEL_ID = (os.getenv("MODEL_ID") or "").strip() or "dmis-lab/biobert-base-cased-v1.2"

# Optional access token: the first of HF_TOKEN / HUGGINGFACE_HUB_TOKEN that is
# non-empty *after* stripping whitespace, else None. Each variable is stripped
# before it is tested, so a blank HF_TOKEN no longer hides a usable
# HUGGINGFACE_HUB_TOKEN.
HF_TOKEN = next(
    (
        candidate
        for candidate in (
            (os.getenv(var_name) or "").strip()
            for var_name in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN")
        )
        if candidate
    ),
    None,
)
def load_model(model_id: str):
    """Load the tokenizer and model for *model_id*.

    The first attempt passes ``token=None``. If that attempt raises and
    ``HF_TOKEN`` is set, the load is retried once with the explicit token
    (needed for private or gated models). Without a token the original
    exception propagates unchanged.

    Args:
        model_id: Model identifier handed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """

    def _fetch(auth_token):
        # Remote code execution stays disabled for both artefacts.
        tokenizer_obj = AutoTokenizer.from_pretrained(
            model_id, token=auth_token, trust_remote_code=False
        )
        model_obj = AutoModel.from_pretrained(
            model_id, token=auth_token, trust_remote_code=False
        )
        return tokenizer_obj, model_obj

    try:
        return _fetch(None)
    except Exception:
        if not HF_TOKEN:
            # Nothing else to try: surface the original error.
            raise
        return _fetch(HF_TOKEN)
# Load once at import time so every request reuses the same objects.
tokenizer, model = load_model(MODEL_ID)
# eval() returns the module itself; rebinding keeps the same object while
# switching it to inference mode.
model = model.eval()
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions the mask keeps.

    Args:
        model_output: Indexable model output whose element 0 holds the token
            embeddings, shaped ``[batch, seq, hidden]``.
        attention_mask: ``[batch, seq]`` tensor; non-zero marks real tokens.

    Returns:
        A ``[batch, hidden]`` tensor of masked means. A row whose mask is all
        zeros yields zeros, because the divisor is floored at ``1e-9``.
    """
    hidden_states = model_output[0]
    # Broadcast the mask across the hidden dimension as floats.
    weights = attention_mask.unsqueeze(-1).expand_as(hidden_states).float()
    weighted_total = torch.sum(hidden_states * weights, dim=1)
    # Floor the token count so fully masked rows do not divide by zero.
    token_counts = torch.clamp(weights.sum(dim=1), min=1e-9)
    return weighted_total / token_counts
class EmbedRequest(BaseModel):
    # Request payload: the strings to embed plus the tokenizer length cap.
    # Comments are used instead of a docstring so the generated schema is
    # unchanged.

    # Strings to embed; defaults to an empty list when omitted.
    texts: List[str] = Field(default_factory=list)
    # Token budget passed to the tokenizer as its truncation length.
    max_length: int = Field(default=256)
class EmbedResponse(BaseModel):
    # Response payload: one embedding vector (list of floats) per input text.
    # Comments are used instead of a docstring so the generated schema is
    # unchanged.

    # Required field; `...` marks it as having no default.
    embeddings: List[List[float]] = Field(...)
# ASGI application object; the routes below are registered on it.
app = FastAPI(
    title="BioBERT Embeddings",
    version="1.0",
)
@app.get("/healthz")
def health():
    # Health probe: reports a fixed ok flag plus the configured model id.
    status = {"ok": True, "model_id": MODEL_ID}
    return status
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    """Return one L2-normalised, mean-pooled embedding per input text.

    An empty ``texts`` list yields an empty ``embeddings`` list. The requested
    ``max_length`` is capped at the model's position-embedding limit when the
    model config exposes one.
    """
    if not req.texts:
        return {"embeddings": []}

    # Cap the truncation length at the model's position-embedding table size.
    # A larger request combined with long text would otherwise produce a
    # sequence the model cannot index and fail in the forward pass.
    max_length = req.max_length
    position_limit = getattr(
        getattr(model, "config", None), "max_position_embeddings", None
    )
    if isinstance(position_limit, int) and position_limit > 0:
        max_length = min(max_length, position_limit)

    enc = tokenizer(
        req.texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        out = model(**enc)
        pooled = mean_pooling(out, enc["attention_mask"])
        # Unit-length rows, so a dot product between two embeddings equals
        # their cosine similarity.
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
    return {"embeddings": pooled.cpu().tolist()}
if __name__ == "__main__":
    # Script entry point: serve the app with a single uvicorn worker.
    import os
    import uvicorn

    # PORT may override the default listening port of 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, workers=1)