felixbet commited on
Commit
5d7a5a6
·
verified ·
1 Parent(s): d213edf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -4,11 +4,24 @@ from typing import List
4
  from transformers import AutoTokenizer, AutoModel
5
  import torch, os
6
 
7
- MODEL_ID = "dmis-lab/biobert-base-cased-v1"
8
-
9
- # Load once at startup
10
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
11
- model = AutoModel.from_pretrained(MODEL_ID)
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  model.eval()
13
 
14
  def mean_pooling(model_output, attention_mask):
@@ -29,7 +42,7 @@ app = FastAPI(title="BioBERT Embeddings", version="1.0")
29
 
30
  @app.get("/healthz")
31
  def health():
32
- return {"ok": True}
33
 
34
  @app.post("/embed", response_model=EmbedResponse)
35
  def embed(req: EmbedRequest):
@@ -46,5 +59,5 @@ def embed(req: EmbedRequest):
46
  return {"embeddings": pooled.cpu().tolist()}
47
 
48
  if __name__ == "__main__":
49
- import uvicorn
50
  uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)
 
4
  from transformers import AutoTokenizer, AutoModel
5
  import torch, os
6
 
7
+ MODEL_ID = os.getenv("MODEL_ID", "dmis-lab/biobert-base-cased-v1.2").strip()
8
+ HF_TOKEN = (os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or "").strip() or None
9
+
10
+ def load_model(model_id: str):
11
+ # Try public/anonymous first (works for public models)
12
+ try:
13
+ tok = AutoTokenizer.from_pretrained(model_id, token=None, trust_remote_code=False)
14
+ mdl = AutoModel.from_pretrained(model_id, token=None, trust_remote_code=False)
15
+ return tok, mdl
16
+ except Exception:
17
+ # If you actually use a private/gated model, fall back to an explicit token
18
+ if HF_TOKEN:
19
+ tok = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN, trust_remote_code=False)
20
+ mdl = AutoModel.from_pretrained(model_id, token=HF_TOKEN, trust_remote_code=False)
21
+ return tok, mdl
22
+ raise # bubble up the original error
23
+
24
+ tokenizer, model = load_model(MODEL_ID)
25
  model.eval()
26
 
27
  def mean_pooling(model_output, attention_mask):
 
42
 
43
  @app.get("/healthz")
44
  def health():
45
+ return {"ok": True, "model_id": MODEL_ID}
46
 
47
  @app.post("/embed", response_model=EmbedResponse)
48
  def embed(req: EmbedRequest):
 
59
  return {"embeddings": pooled.cpu().tolist()}
60
 
61
  if __name__ == "__main__":
62
+ import uvicorn, os
63
  uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)