# app.py - FastAPI embeddings service using PyTorch BioBERT
# Works on Hugging Face Spaces (CPU Basic, free)
import os
from typing import List, Optional

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

# --------- Configuration (overridable via environment variables) ----------
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "monologg/biobert_v1.1_pubmed").strip()
MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
TORCH_THREADS = int(os.environ.get("TORCH_THREADS", "1"))
torch.set_num_threads(TORCH_THREADS)
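# Note: a single torch thread is usually a safe default on the free CPU Basic
# tier; raise TORCH_THREADS only if the Space has spare cores.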

# --------- Load model & tokenizer (PyTorch) ----------
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
model = AutoModel.from_pretrained(HF_MODEL_ID)
model.eval()  # inference mode
DEVICE = "cpu"
model.to(DEVICE)
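
# Optional warm-up (a sketch): one dummy forward pass here would pay the lazy
# initialization cost at startup rather than on the first real request, e.g.:
#   with torch.no_grad():
#       _ = model(**tokenizer("warm-up", return_tensors="pt"))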

# --------- FastAPI ----------
app = FastAPI(title="BioBERT (PyTorch) Embeddings API", version="1.0")

# CORS: wide open for demo use; tighten allow_origins in production
# (e.g. ["https://your-frontend.example.com"]).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)

class EmbReq(BaseModel):
    input: str
    max_len: Optional[int] = None
    pooling: Optional[str] = "cls"  # "cls" or "mean"


class BatchEmbReq(BaseModel):
    inputs: List[str]
    max_len: Optional[int] = None
    pooling: Optional[str] = "cls"  # "cls" or "mean"
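
# Illustrative request bodies for the endpoints below (field values are made
# up; "pooling" accepts "cls" or "mean"):
#   EmbReq:      {"input": "TP53 mutations in breast cancer", "pooling": "mean"}
#   BatchEmbReq: {"inputs": ["text one", "text two"], "max_len": 64}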
@app.get("/")
def root():
return {
"name": "BioBERT Embeddings (PyTorch)",
"model": HF_MODEL_ID,
"device": DEVICE,
"endpoints": ["/health", "/v1/embeddings", "/v1/embeddings/batch"],
"hint": "POST to /v1/embeddings with {'input': 'your text'}",
}
@app.get("/health")
def health():
return {"ok": True, "model": HF_MODEL_ID, "device": DEVICE}

def _pool(outputs, inputs, pooling: str):
    """
    pooling="cls":  CLS representation (pooler_output if present, else
                    last_hidden_state[:, 0]).
    pooling="mean": mask-aware mean of the token embeddings.
    """
    if pooling == "mean":
        last = outputs.last_hidden_state  # [B, T, H]
        mask = inputs["attention_mask"].unsqueeze(-1).type_as(last)  # [B, T, 1]
        summed = (last * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts
    # default: "cls"
    if getattr(outputs, "pooler_output", None) is not None:
        return outputs.pooler_output
    return outputs.last_hidden_state[:, 0, :]  # CLS token
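
# Illustration: with mean pooling, a sequence whose attention mask covers 3 of
# 8 padded positions divides the summed token vectors by 3, not 8, so padding
# never dilutes the embedding.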

def _embed(texts: List[str], max_len: int, pooling: str) -> List[List[float]]:
    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    with torch.no_grad():
        outputs = model(**enc)
    vecs = _pool(outputs, enc, pooling=pooling)
    return vecs.cpu().numpy().tolist()
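
# Quick sanity check (illustrative values; BioBERT v1.1 is BERT-base, so the
# vectors are 768-dimensional):
#   _embed(["aspirin inhibits COX-2"], max_len=128, pooling="mean")
#   -> [[0.12, -0.03, ...]]  (one 768-float list per input text)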
@app.post("/v1/embeddings")
def embeddings(req: EmbReq):
text = (req.input or "").strip()
if not text:
return {"embedding": [], "dim": 0}
L = int(req.max_len or MAX_LEN)
pooling = (req.pooling or "cls").lower()
vec = _embed([text], L, pooling)[0]
return {"embedding": vec, "dim": len(vec), "pooling": pooling}
@app.post("/v1/embeddings/batch")
def embeddings_batch(req: BatchEmbReq):
items = [str(t).strip() for t in (req.inputs or []) if str(t).strip()]
if not items:
return {"embeddings": [], "dim": 0}
L = int(req.max_len or MAX_LEN)
pooling = (req.pooling or "cls").lower()
vecs = _embed(items, L, pooling)
return {"embeddings": vecs, "dim": len(vecs[0]), "pooling": pooling}