# app.py - FastAPI embeddings service using PyTorch BioBERT
# Works on Hugging Face Spaces (CPU Basic, free)
import os
from typing import List, Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModel

HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "monologg/biobert_v1.1_pubmed").strip()
MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
TORCH_THREADS = int(os.environ.get("TORCH_THREADS", "1"))
torch.set_num_threads(TORCH_THREADS)
# --------- Load model & tokenizer (PyTorch) ----------
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
model = AutoModel.from_pretrained(HF_MODEL_ID)
model.eval()  # inference mode
DEVICE = "cpu"
model.to(DEVICE)

# --------- FastAPI ----------
app = FastAPI(title="BioBERT (PyTorch) Embeddings API", version="1.0")

# CORS (relaxed here; tighten allow_origins in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
class EmbReq(BaseModel):
    input: str
    max_len: Optional[int] = None
    pooling: Optional[str] = "cls"  # "cls" or "mean"

class BatchEmbReq(BaseModel):
    inputs: List[str]
    max_len: Optional[int] = None
    pooling: Optional[str] = "cls"  # "cls" or "mean"
@app.get("/")
def root():
    return {
        "name": "BioBERT Embeddings (PyTorch)",
        "model": HF_MODEL_ID,
        "device": DEVICE,
        "endpoints": ["/health", "/v1/embeddings", "/v1/embeddings/batch"],
        "hint": "POST to /v1/embeddings with {'input': 'your text'}",
    }

@app.get("/health")
def health():
    return {"ok": True, "model": HF_MODEL_ID, "device": DEVICE}
def _pool(outputs, inputs, pooling: str):
    """
    pooling="cls":  use the CLS representation (pooler_output if present, else last_hidden_state[:, 0])
    pooling="mean": mask-aware mean of the token embeddings
    """
    if pooling == "mean":
        last = outputs.last_hidden_state                             # [B, T, H]
        mask = inputs["attention_mask"].unsqueeze(-1).type_as(last)  # [B, T, 1]
        summed = (last * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts
    # "cls" (default)
    if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
        return outputs.pooler_output
    return outputs.last_hidden_state[:, 0, :]  # CLS token
def _embed(texts: List[str], max_len: int, pooling: str) -> List[List[float]]:
    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    with torch.no_grad():
        outputs = model(**enc)
        vecs = _pool(outputs, enc, pooling=pooling)
    return vecs.cpu().numpy().tolist()
@app.post("/v1/embeddings")
def embeddings(req: EmbReq):
    text = (req.input or "").strip()
    if not text:
        return {"embedding": [], "dim": 0}
    L = int(req.max_len or MAX_LEN)
    pooling = (req.pooling or "cls").lower()
    vec = _embed([text], L, pooling)[0]
    return {"embedding": vec, "dim": len(vec), "pooling": pooling}
@app.post("/v1/embeddings/batch")
def embeddings_batch(req: BatchEmbReq):
    items = [str(t).strip() for t in (req.inputs or []) if str(t).strip()]
    if not items:
        return {"embeddings": [], "dim": 0}
    L = int(req.max_len or MAX_LEN)
    pooling = (req.pooling or "cls").lower()
    vecs = _embed(items, L, pooling)
    return {"embeddings": vecs, "dim": len(vecs[0]), "pooling": pooling}