felixbet committed on
Commit
d213edf
·
verified ·
1 Parent(s): c7a1efc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -97
app.py CHANGED
@@ -1,111 +1,50 @@
1
- # app.py — FastAPI embeddings service using PyTorch BioBERT
2
- # Works on Hugging Face Spaces (CPU Basic, free)
3
-
4
- import os
5
- from typing import List, Optional
6
-
7
  from fastapi import FastAPI
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from pydantic import BaseModel
10
-
11
- import torch
12
  from transformers import AutoTokenizer, AutoModel
 
13
 
14
- HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "monologg/biobert_v1.1_pubmed").strip()
15
- MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
16
- TORCH_THREADS = int(os.environ.get("TORCH_THREADS", "1"))
17
-
18
- torch.set_num_threads(TORCH_THREADS)
19
 
20
- # --------- Load model & tokenizer (PyTorch) ----------
21
- tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
22
- model = AutoModel.from_pretrained(HF_MODEL_ID)
23
- model.eval() # inference mode
24
- DEVICE = "cpu"
25
- model.to(DEVICE)
26
 
27
- # --------- FastAPI ----------
28
- app = FastAPI(title="BioBERT (PyTorch) Embeddings API", version="1.0")
 
 
 
 
29
 
30
- # CORS (relax; tighten in production)
31
- app.add_middleware(
32
- CORSMiddleware,
33
- allow_origins=["*"],
34
- allow_credentials=False,
35
- allow_methods=["GET", "POST", "OPTIONS"],
36
- allow_headers=["*"],
37
- )
38
 
39
- class EmbReq(BaseModel):
40
- input: str
41
- max_len: Optional[int] = None
42
- pooling: Optional[str] = "cls" # "cls" or "mean"
43
 
44
- class BatchEmbReq(BaseModel):
45
- inputs: List[str]
46
- max_len: Optional[int] = None
47
- pooling: Optional[str] = "cls" # "cls" or "mean"
48
 
49
- @app.get("/")
50
- def root():
51
- return {
52
- "name": "BioBERT Embeddings (PyTorch)",
53
- "model": HF_MODEL_ID,
54
- "device": DEVICE,
55
- "endpoints": ["/health", "/v1/embeddings", "/v1/embeddings/batch"],
56
- "hint": "POST to /v1/embeddings with {'input': 'your text'}",
57
- }
58
-
59
- @app.get("/health")
60
  def health():
61
- return {"ok": True, "model": HF_MODEL_ID, "device": DEVICE}
62
-
63
- def _pool(outputs, inputs, pooling: str):
64
- """
65
- pooling="cls": use CLS (pooler_output if present, else hidden_state[:,0])
66
- pooling="mean": mean of token embeddings (mask-aware)
67
- """
68
- if pooling == "mean":
69
- last = outputs.last_hidden_state # [B,T,H]
70
- mask = inputs["attention_mask"].unsqueeze(-1).type_as(last) # [B,T,1]
71
- summed = (last * mask).sum(dim=1)
72
- counts = mask.sum(dim=1).clamp(min=1e-9)
73
- return summed / counts
74
- # cls
75
- if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
76
- return outputs.pooler_output
77
- return outputs.last_hidden_state[:, 0, :] # CLS token
78
 
79
- def _embed(texts: List[str], max_len: int, pooling: str) -> List[List[float]]:
 
 
 
80
  enc = tokenizer(
81
- texts,
82
- return_tensors="pt",
83
- padding=True,
84
- truncation=True,
85
- max_length=max_len,
86
  )
87
- enc = {k: v.to(DEVICE) for k, v in enc.items()}
88
  with torch.no_grad():
89
- outputs = model(**enc)
90
- vecs = _pool(outputs, enc, pooling=pooling)
91
- return vecs.cpu().numpy().tolist()
92
-
93
- @app.post("/v1/embeddings")
94
- def embeddings(req: EmbReq):
95
- text = (req.input or "").strip()
96
- if not text:
97
- return {"embedding": [], "dim": 0}
98
- L = int(req.max_len or MAX_LEN)
99
- pooling = (req.pooling or "cls").lower()
100
- vec = _embed([text], L, pooling)[0]
101
- return {"embedding": vec, "dim": len(vec), "pooling": pooling}
102
-
103
- @app.post("/v1/embeddings/batch")
104
- def embeddings_batch(req: BatchEmbReq):
105
- items = [str(t).strip() for t in (req.inputs or []) if str(t).strip()]
106
- if not items:
107
- return {"embeddings": [], "dim": 0}
108
- L = int(req.max_len or MAX_LEN)
109
- pooling = (req.pooling or "cls").lower()
110
- vecs = _embed(items, L, pooling)
111
- return {"embeddings": vecs, "dim": len(vecs[0]), "pooling": pooling}
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from pydantic import BaseModel, Field
3
+ from typing import List
 
 
4
  from transformers import AutoTokenizer, AutoModel
5
+ import torch, os
6
 
7
# Hugging Face Hub identifier of the checkpoint this service serves.
MODEL_ID = "dmis-lab/biobert-base-cased-v1"

# Load once at startup: the tokenizer and model are created at import time
# and shared by every request handler below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Module.eval() returns the module itself, so loading and switching to
# evaluation mode can be chained into a single assignment.
model = AutoModel.from_pretrained(MODEL_ID).eval()
 
 
13
 
14
def mean_pooling(model_output, attention_mask):
    """Average token embeddings per sequence, counting only unmasked tokens.

    Args:
        model_output: Indexable model result whose first element holds the
            token-level embeddings, shaped [batch, seq, hidden].
        attention_mask: Tensor shaped [batch, seq]; non-zero entries mark the
            tokens that contribute to the average.

    Returns:
        Tensor shaped [batch, hidden] with one pooled vector per sequence.
    """
    hidden_states = model_output[0]  # [batch, seq, hidden]

    # Broadcast the mask over the hidden dimension so it can weight each
    # embedding component.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    totals = torch.sum(hidden_states * weights, dim=1)
    # The clamp keeps the division finite for a fully masked sequence.
    token_counts = torch.clamp(weights.sum(dim=1), min=1e-9)
    return totals / token_counts
20
 
21
# Request body for POST /embed.
# (Documented with comments rather than a docstring so the generated API
# schema is left unchanged.)
class EmbedRequest(BaseModel):
    # Strings to embed; when the field is omitted it defaults to an empty list.
    texts: List[str] = Field(default_factory=list)
    # Truncation length handed to the tokenizer as `max_length`.
    max_length: int = 256
 
 
 
 
 
24
 
25
# Response body for POST /embed.
# (Documented with comments rather than a docstring so the generated API
# schema is left unchanged.)
class EmbedResponse(BaseModel):
    # One vector of floats per input text.
    embeddings: List[List[float]]
 
 
27
 
28
# ASGI application object; the route decorators below register on it.
app = FastAPI(
    title="BioBERT Embeddings",
    version="1.0",
)
 
 
 
29
 
30
# Liveness probe: answers with a constant payload and never calls the
# tokenizer or the model.
@app.get("/healthz")
def health():
    status = {"ok": True}
    return status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
# POST /embed: tokenize the requested texts, run the model without gradient
# tracking, mean-pool the token embeddings using the attention mask, and
# L2-normalize each pooled vector. An empty `texts` list returns an empty
# result without invoking the tokenizer or the model.
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    if not req.texts:
        return {"embeddings": []}

    # Cap the client-supplied truncation length at the model's position
    # capacity. Without the cap, a request with a large `max_length` lets
    # over-long token sequences through to the model, which then fails at
    # inference instead of truncating the input.
    max_length = req.max_length
    position_limit = getattr(model.config, "max_position_embeddings", None)
    if position_limit is not None:
        max_length = min(max_length, position_limit)

    enc = tokenizer(
        req.texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        out = model(**enc)
        pooled = mean_pooling(out, enc["attention_mask"])
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
    return {"embeddings": pooled.cpu().tolist()}
47
+
48
# Script entry point: serve the app with uvicorn on all interfaces, reading
# the port from the PORT environment variable (7860 when unset).
if __name__ == "__main__":
    import uvicorn

    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, workers=1)