Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,11 +4,24 @@ from typing import List
|
|
| 4 |
from transformers import AutoTokenizer, AutoModel
|
| 5 |
import torch, os
|
| 6 |
|
| 7 |
-
MODEL_ID = "dmis-lab/biobert-base-cased-v1"
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
model.eval()
|
| 13 |
|
| 14 |
def mean_pooling(model_output, attention_mask):
|
|
@@ -29,7 +42,7 @@ app = FastAPI(title="BioBERT Embeddings", version="1.0")
|
|
| 29 |
|
| 30 |
@app.get("/healthz")
|
| 31 |
def health():
|
| 32 |
-
return {"ok": True}
|
| 33 |
|
| 34 |
@app.post("/embed", response_model=EmbedResponse)
|
| 35 |
def embed(req: EmbedRequest):
|
|
@@ -46,5 +59,5 @@ def embed(req: EmbedRequest):
|
|
| 46 |
return {"embeddings": pooled.cpu().tolist()}
|
| 47 |
|
| 48 |
if __name__ == "__main__":
|
| 49 |
-
import uvicorn
|
| 50 |
uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)
|
|
|
|
| 4 |
from transformers import AutoTokenizer, AutoModel
|
| 5 |
import torch, os
|
| 6 |
|
| 7 |
+
# Hub repository to serve. An unset, empty, or whitespace-only MODEL_ID falls
# back to the default instead of handing an empty id to the loader.
MODEL_ID = (os.getenv("MODEL_ID") or "").strip() or "dmis-lab/biobert-base-cased-v1.2"
# Optional Hub access token: HF_TOKEN is consulted first, then
# HUGGINGFACE_HUB_TOKEN. The value is stripped and a blank result becomes None.
HF_TOKEN = (os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or "").strip() or None
|
| 9 |
+
|
| 10 |
+
def load_model(model_id: str):
    """Load the tokenizer and model for ``model_id`` via ``from_pretrained``.

    The first attempt is made without credentials, which is enough for public
    repositories. If that attempt fails and ``HF_TOKEN`` is set, the load is
    retried once with the token so private or gated repositories still work.

    Args:
        model_id: Repository id passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Exception: The anonymous attempt's error when no token is configured,
            otherwise whatever the authenticated retry raises.
    """

    def _load(token):
        # Remote code execution stays disabled for both attempts.
        tok = AutoTokenizer.from_pretrained(model_id, token=token, trust_remote_code=False)
        mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=False)
        return tok, mdl

    try:
        # token=False (not None) is what forces an anonymous request: None lets
        # huggingface_hub fall back to a token from the environment or the
        # local login cache, so a stale or invalid token would still be sent.
        return _load(False)
    except Exception:
        if not HF_TOKEN:
            raise  # nothing to retry with; bubble up the original error
        # Private/gated model: retry once with the explicit token.
        return _load(HF_TOKEN)
|
| 23 |
+
|
| 24 |
+
# Load once at import time; the resulting objects are module-level globals.
tokenizer, model = load_model(model_id=MODEL_ID)
# Put the loaded model into evaluation mode before it serves any request.
model.eval()
|
| 26 |
|
| 27 |
def mean_pooling(model_output, attention_mask):
|
|
|
|
| 42 |
|
| 43 |
@app.get("/healthz")
|
| 44 |
def health():
    """Health-check payload: an ``ok`` flag plus the configured model id."""
    status = {"ok": True, "model_id": MODEL_ID}
    return status
|
| 46 |
|
| 47 |
@app.post("/embed", response_model=EmbedResponse)
|
| 48 |
def embed(req: EmbedRequest):
|
|
|
|
| 59 |
return {"embeddings": pooled.cpu().tolist()}
|
| 60 |
|
| 61 |
if __name__ == "__main__":
    # Imported here so that importing this module does not itself need uvicorn.
    import os
    import uvicorn

    # PORT overrides the listening port; 7860 is used when it is unset.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, workers=1)
|