# biobert-emb / app.py
# NOTE: the original file carried Hugging Face web-view residue here
# ("felixbet's picture / Update app.py / 5d7a5a6 verified / raw history blame / 2.34 kB"),
# which is not valid Python; converted to this comment so the module parses.
from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import List
from transformers import AutoTokenizer, AutoModel
import torch, os
MODEL_ID = os.getenv("MODEL_ID", "dmis-lab/biobert-base-cased-v1.2").strip()
HF_TOKEN = (os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or "").strip() or None
def load_model(model_id: str):
    """Load tokenizer and model for *model_id*.

    Tries an anonymous (token-less) download first, which succeeds for
    public models. Only if that fails — e.g. a private/gated repo — and
    an HF_TOKEN is configured, retries with the explicit token;
    otherwise the original error is re-raised.
    """
    def _fetch(token):
        # Shared loader so both attempts use identical settings.
        tok = AutoTokenizer.from_pretrained(model_id, token=token, trust_remote_code=False)
        mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=False)
        return tok, mdl

    try:
        return _fetch(None)
    except Exception:
        if HF_TOKEN:
            return _fetch(HF_TOKEN)
        raise  # bubble up the original error
# Load once at import time so every request is served from a warm model.
tokenizer, model = load_model(MODEL_ID)
model.eval()  # inference mode: disables dropout/batch-norm updates
def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings, weighting by the attention mask.

    model_output: HF model output; element 0 is the last hidden state,
        assumed shaped [batch, seq, hidden].
    attention_mask: [batch, seq] tensor of 0/1 token validity flags.
    Returns a [batch, hidden] tensor of masked mean embeddings.
    """
    hidden = model_output[0]
    # Broadcast the mask over the hidden dimension; float for the math.
    weights = attention_mask.unsqueeze(-1).float()
    total = (hidden * weights).sum(dim=1)
    # Clamp avoids division by zero for an all-padding row.
    denom = weights.sum(dim=1).clamp(min=1e-9)
    return total / denom
class EmbedRequest(BaseModel):
    """Request body for POST /embed."""
    # Sentences to embed; defaults to an empty list (endpoint returns []).
    texts: List[str] = Field(default_factory=list)
    # Tokenizer truncation length per text.
    max_length: int = 256
class EmbedResponse(BaseModel):
    """Response body for POST /embed: one embedding vector per input text."""
    embeddings: List[List[float]]
app = FastAPI(title="BioBERT Embeddings", version="1.0")

@app.get("/healthz")
def health():
    """Liveness probe: reports readiness and the configured model id."""
    return {"ok": True, "model_id": MODEL_ID}
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    """Embed req.texts and return L2-normalized mean-pooled vectors."""
    if not req.texts:
        # Nothing to encode; short-circuit before touching the tokenizer.
        return {"embeddings": []}

    batch = tokenizer(
        req.texts,
        padding=True,
        truncation=True,
        max_length=req.max_length,
        return_tensors="pt",
    )

    with torch.no_grad():  # inference only: no autograd bookkeeping
        outputs = model(**batch)

    vectors = mean_pooling(outputs, batch["attention_mask"])
    # Unit-length vectors so downstream cosine similarity is a dot product.
    vectors = torch.nn.functional.normalize(vectors, p=2, dim=1)
    return {"embeddings": vectors.cpu().tolist()}
if __name__ == "__main__":
    # Local/dev entry point; Hugging Face Spaces default port is 7860.
    # Fix: dropped the redundant `import os` (os is already imported at module top).
    import uvicorn  # lazy import: only needed when running the script directly

    uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)