from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import List
from transformers import AutoTokenizer, AutoModel
import torch
import os

MODEL_ID = os.getenv("MODEL_ID", "dmis-lab/biobert-base-cased-v1.2").strip()
HF_TOKEN = (os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or "").strip() or None
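# On a Hugging Face Space, HF_TOKEN can be stored as a repository secret;
# the runtime exposes secrets to the app as environment variables.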

def load_model(model_id: str):
    # Try public/anonymous access first (works for public models)
    try:
        tok = AutoTokenizer.from_pretrained(model_id, token=None, trust_remote_code=False)
        mdl = AutoModel.from_pretrained(model_id, token=None, trust_remote_code=False)
        return tok, mdl
    except Exception:
        # If you actually use a private/gated model, fall back to an explicit token
        if HF_TOKEN:
            tok = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN, trust_remote_code=False)
            mdl = AutoModel.from_pretrained(model_id, token=HF_TOKEN, trust_remote_code=False)
            return tok, mdl
        raise  # bubble up the original error

tokenizer, model = load_model(MODEL_ID)
model.eval()  # inference mode: disables dropout

def mean_pooling(model_output, attention_mask):
    # Average token embeddings, masking out padding positions
    token_embeddings = model_output[0]  # [batch, seq, hidden]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(1)
    counts = mask.sum(1).clamp(min=1e-9)  # avoid division by zero on empty masks
    return summed / counts

class EmbedRequest(BaseModel):
    texts: List[str] = Field(default_factory=list)
    max_length: int = 256

class EmbedResponse(BaseModel):
    embeddings: List[List[float]]

app = FastAPI(title="BioBERT Embeddings", version="1.0")

@app.get("/health")  # route decorator restored; path assumed from the handler name
def health():
    return {"ok": True, "model_id": MODEL_ID}

@app.post("/embed", response_model=EmbedResponse)  # route decorator restored; path assumed
def embed(req: EmbedRequest):
    if not req.texts:
        return {"embeddings": []}
    enc = tokenizer(
        req.texts, padding=True, truncation=True,
        max_length=req.max_length, return_tensors="pt"
    )
    with torch.no_grad():
        out = model(**enc)
    # Mean-pool over tokens, then L2-normalize so cosine similarity reduces to a dot product
    pooled = mean_pooling(out, enc["attention_mask"])
    pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
    return {"embeddings": pooled.cpu().tolist()}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)
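
# Usage sketch (assumes this file is app.py and the server listens on the
# default port 7860; route paths are the ones declared above):
#
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/embed \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["aspirin inhibits platelet aggregation"]}'
#
# The /embed response is {"embeddings": [[...]]} with one 768-dim vector per
# input text; vectors are L2-normalized, so cosine similarity is a dot product.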