felixbet committed on
Commit
1c0323e
·
verified ·
1 Parent(s): 512acbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -133
app.py CHANGED
@@ -1,120 +1,33 @@
1
- # FastAPI BioBERT embeddings (Hub-first, no TF1 ckpt dependency)
2
- # Works free on Hugging Face Spaces (CPU). Auto-converts PyTorch -> TF.
3
 
4
- import os, tarfile, glob, json, shutil, urllib.request
5
  from typing import List, Optional
6
 
7
  from fastapi import FastAPI
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
10
 
11
- # Load TF before transformers' TF models
12
- import tensorflow as tf # noqa: F401
13
- from transformers import AutoTokenizer, TFAutoModel, BertConfig
14
 
15
- # ------------------- Config -------------------
16
  HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "monologg/biobert_v1.1_pubmed").strip()
17
- MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/bert_tf").rstrip("/")
18
- WEIGHTS_URL = os.environ.get("WEIGHTS_URL_TAR_GZ", "").strip() # optional direct .tar.gz (Dropbox must end with dl=1)
19
- MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
20
- os.makedirs(MODEL_ROOT, exist_ok=True)
21
-
22
- # ------------------- Utils --------------------
23
- def _safe_extract_tar_gz(src: str, dest: str) -> None:
24
- with tarfile.open(src, "r:gz") as tar:
25
- def _is_within(directory, target):
26
- ad = os.path.abspath(directory); at = os.path.abspath(target)
27
- return os.path.commonpath([ad]) == os.path.commonpath([ad, at])
28
- for m in tar.getmembers():
29
- tp = os.path.join(dest, m.name)
30
- if not _is_within(dest, tp):
31
- raise RuntimeError("Blocked path traversal in tar")
32
- tar.extractall(dest)
33
-
34
- def _maybe_download_tar_into_model_root() -> Optional[str]:
35
- """If WEIGHTS_URL is set, download + extract it into MODEL_ROOT. Return extracted dir if any."""
36
- if not WEIGHTS_URL:
37
- return None
38
- print("[app] downloading weights:", WEIGHTS_URL)
39
- local_tar = "/tmp/model.tar.gz"
40
- urllib.request.urlretrieve(WEIGHTS_URL, local_tar)
41
- print("[app] extracting:", local_tar, "->", MODEL_ROOT)
42
- _safe_extract_tar_gz(local_tar, MODEL_ROOT)
43
- # return shallowest dir inside MODEL_ROOT
44
- candidates = [d for d in glob.glob(os.path.join(MODEL_ROOT, "*")) if os.path.isdir(d)]
45
- if not candidates:
46
- return MODEL_ROOT
47
- candidates.sort(key=lambda p: len(os.path.relpath(p, MODEL_ROOT).split(os.sep)))
48
- return candidates[0]
49
-
50
- def _detect_local_hf_dir(root: str) -> Optional[str]:
51
- """
52
- Return a directory under root that looks like a modern HF model folder:
53
- - pytorch_model.bin / model.safetensors (for from_pt=True)
54
- - OR tf_model.h5 (native TF)
55
- """
56
- # search at depth 0/1/2
57
- for depth in range(3):
58
- pattern = os.path.join(root, *(["**"] if depth else []))
59
- # prefer TF weights first if present
60
- tf_h5 = glob.glob(os.path.join(pattern, "tf_model.h5"), recursive=True)
61
- if tf_h5:
62
- tf_h5.sort(key=lambda p: len(os.path.relpath(p, root).split(os.sep)))
63
- return os.path.dirname(tf_h5[0])
64
-
65
- # else look for PT/safetensors
66
- pt = glob.glob(os.path.join(pattern, "pytorch_model.bin"), recursive=True)
67
- st = glob.glob(os.path.join(pattern, "model.safetensors"), recursive=True)
68
- have = (pt or st)
69
- if have:
70
- have.sort(key=lambda p: len(os.path.relpath(p, root).split(os.sep)))
71
- return os.path.dirname(have[0])
72
- return None
73
-
74
- def _looks_like_tf1_ckpt_dir(path: str) -> bool:
75
- return bool(glob.glob(os.path.join(path, "model.ckpt-*.index")))
76
-
77
- # ------------------- Load strategy -------------------
78
- # 1) If a tar URL is provided, unpack it (optional convenience)
79
- extracted = _maybe_download_tar_into_model_root()
80
-
81
- # 2) If after extraction we have a local HF-style folder, use it
82
- LOCAL_DIR = _detect_local_hf_dir(MODEL_ROOT)
83
-
84
- # 3) If only TF1 ckpt found, refuse with a clear message (no fragile loaders)
85
- if not LOCAL_DIR:
86
- # If there is any directory in MODEL_ROOT with TF1 ckpts, warn
87
- for d in [MODEL_ROOT] + [p for p in glob.glob(os.path.join(MODEL_ROOT, "*")) if os.path.isdir(p)]:
88
- if _looks_like_tf1_ckpt_dir(d):
89
- raise RuntimeError(
90
- "Found TF-1 checkpoint files (model.ckpt-*) but this app purposely avoids "
91
- "runtime TF-1 → TF-2 weight mapping. Either:\n"
92
- " • Set HF_MODEL_ID to a BioBERT model on the Hub (recommended), e.g. 'monologg/biobert_v1.1_pubmed'\n"
93
- " • Or package modern HF weights (pytorch_model.bin/model.safetensors or tf_model.h5) in your tar."
94
- )
95
-
96
- # 4) Tokenizer+Model
97
- if LOCAL_DIR:
98
- print(f"[app] Using LOCAL_DIR: {LOCAL_DIR}")
99
- # Prefer native TF if available, else convert from PT
100
- if os.path.isfile(os.path.join(LOCAL_DIR, "tf_model.h5")):
101
- tokenizer = AutoTokenizer.from_pretrained(LOCAL_DIR)
102
- model = TFAutoModel.from_pretrained(LOCAL_DIR)
103
- USED = {"source": "local", "format": "tf_h5", "path": LOCAL_DIR}
104
- else:
105
- tokenizer = AutoTokenizer.from_pretrained(LOCAL_DIR)
106
- model = TFAutoModel.from_pretrained(LOCAL_DIR, from_pt=True)
107
- USED = {"source": "local", "format": "pt/safetensors->tf", "path": LOCAL_DIR}
108
- else:
109
- print(f"[app] Using HF_MODEL_ID: {HF_MODEL_ID}")
110
- tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
111
- # Most BioBERT repos are PyTorch; allow auto-conversion
112
- model = TFAutoModel.from_pretrained(HF_MODEL_ID, from_pt=True)
113
- USED = {"source": "hub", "model_id": HF_MODEL_ID, "format": "pt->tf"}
114
-
115
- # ------------------- API -------------------
116
- app = FastAPI(title="BioBERT Embeddings API (Hub-first)", version="2.0")
117
 
 
118
  app.add_middleware(
119
  CORSMiddleware,
120
  allow_origins=["*"],
@@ -126,48 +39,73 @@ app.add_middleware(
126
  class EmbReq(BaseModel):
127
  input: str
128
  max_len: Optional[int] = None
 
129
 
130
  class BatchEmbReq(BaseModel):
131
  inputs: List[str]
132
  max_len: Optional[int] = None
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  @app.get("/health")
135
  def health():
136
- return {"ok": True, "strategy": USED, "max_len_default": MAX_LEN}
137
-
138
- def _embed(texts: List[str], max_len: int) -> List[List[float]]:
139
- enc = tokenizer(texts, return_tensors="tf", truncation=True, padding=True, max_length=max_len)
140
- out = model(**enc, training=False)
141
- if hasattr(out, "pooler_output") and out.pooler_output is not None:
142
- vecs = out.pooler_output.numpy()
143
- else:
144
- last = out.last_hidden_state.numpy()
145
- vecs = last.mean(axis=1)
146
- return [v.tolist() for v in vecs]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  @app.post("/v1/embeddings")
149
  def embeddings(req: EmbReq):
150
- text = req.input.strip()
151
  if not text:
152
  return {"embedding": [], "dim": 0}
153
  L = int(req.max_len or MAX_LEN)
154
- vec = _embed([text], L)[0]
155
- return {"embedding": vec, "dim": len(vec)}
 
156
 
157
  @app.post("/v1/embeddings/batch")
158
  def embeddings_batch(req: BatchEmbReq):
159
- items = [t.strip() for t in req.inputs if str(t).strip()]
160
  if not items:
161
  return {"embeddings": [], "dim": 0}
162
  L = int(req.max_len or MAX_LEN)
163
- vecs = _embed(items, L)
164
- return {"embeddings": vecs, "dim": len(vecs[0])}
165
-
166
- @app.get("/")
167
- def root():
168
- return {
169
- "name": "BioBERT Embeddings (Hub-first)",
170
- "endpoints": ["/health", "/v1/embeddings", "/v1/embeddings/batch"],
171
- "hint": "POST /v1/embeddings with {'input': 'your text'}",
172
- "strategy": USED
173
- }
 
1
+ # app.py — FastAPI embeddings service using PyTorch BioBERT
2
+ # Works on Hugging Face Spaces (CPU Basic, free)
3
 
4
+ import os
5
  from typing import List, Optional
6
 
7
  from fastapi import FastAPI
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
10
 
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModel
 
13
 
 
14
  HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "monologg/biobert_v1.1_pubmed").strip()
15
+ MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
16
+ TORCH_THREADS = int(os.environ.get("TORCH_THREADS", "1"))
17
+
18
+ torch.set_num_threads(TORCH_THREADS)
19
+
20
+ # --------- Load model & tokenizer (PyTorch) ----------
21
+ tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
22
+ model = AutoModel.from_pretrained(HF_MODEL_ID)
23
+ model.eval() # inference mode
24
+ DEVICE = "cpu"
25
+ model.to(DEVICE)
26
+
27
+ # --------- FastAPI ----------
28
+ app = FastAPI(title="BioBERT (PyTorch) Embeddings API", version="1.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ # CORS (relax; tighten in production)
31
  app.add_middleware(
32
  CORSMiddleware,
33
  allow_origins=["*"],
 
39
  class EmbReq(BaseModel):
40
  input: str
41
  max_len: Optional[int] = None
42
+ pooling: Optional[str] = "cls" # "cls" or "mean"
43
 
44
  class BatchEmbReq(BaseModel):
45
  inputs: List[str]
46
  max_len: Optional[int] = None
47
+ pooling: Optional[str] = "cls" # "cls" or "mean"
48
+
49
+ @app.get("/")
50
+ def root():
51
+ return {
52
+ "name": "BioBERT Embeddings (PyTorch)",
53
+ "model": HF_MODEL_ID,
54
+ "device": DEVICE,
55
+ "endpoints": ["/health", "/v1/embeddings", "/v1/embeddings/batch"],
56
+ "hint": "POST to /v1/embeddings with {'input': 'your text'}",
57
+ }
58
 
59
  @app.get("/health")
60
  def health():
61
+ return {"ok": True, "model": HF_MODEL_ID, "device": DEVICE}
62
+
63
+ def _pool(outputs, inputs, pooling: str):
64
+ """
65
+ pooling="cls": use CLS (pooler_output if present, else hidden_state[:,0])
66
+ pooling="mean": mean of token embeddings (mask-aware)
67
+ """
68
+ if pooling == "mean":
69
+ last = outputs.last_hidden_state # [B,T,H]
70
+ mask = inputs["attention_mask"].unsqueeze(-1).type_as(last) # [B,T,1]
71
+ summed = (last * mask).sum(dim=1)
72
+ counts = mask.sum(dim=1).clamp(min=1e-9)
73
+ return summed / counts
74
+ # cls
75
+ if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
76
+ return outputs.pooler_output
77
+ return outputs.last_hidden_state[:, 0, :] # CLS token
78
+
79
+ def _embed(texts: List[str], max_len: int, pooling: str) -> List[List[float]]:
80
+ enc = tokenizer(
81
+ texts,
82
+ return_tensors="pt",
83
+ padding=True,
84
+ truncation=True,
85
+ max_length=max_len,
86
+ )
87
+ enc = {k: v.to(DEVICE) for k, v in enc.items()}
88
+ with torch.no_grad():
89
+ outputs = model(**enc)
90
+ vecs = _pool(outputs, enc, pooling=pooling)
91
+ return vecs.cpu().numpy().tolist()
92
 
93
  @app.post("/v1/embeddings")
94
  def embeddings(req: EmbReq):
95
+ text = (req.input or "").strip()
96
  if not text:
97
  return {"embedding": [], "dim": 0}
98
  L = int(req.max_len or MAX_LEN)
99
+ pooling = (req.pooling or "cls").lower()
100
+ vec = _embed([text], L, pooling)[0]
101
+ return {"embedding": vec, "dim": len(vec), "pooling": pooling}
102
 
103
  @app.post("/v1/embeddings/batch")
104
  def embeddings_batch(req: BatchEmbReq):
105
+ items = [str(t).strip() for t in (req.inputs or []) if str(t).strip()]
106
  if not items:
107
  return {"embeddings": [], "dim": 0}
108
  L = int(req.max_len or MAX_LEN)
109
+ pooling = (req.pooling or "cls").lower()
110
+ vecs = _embed(items, L, pooling)
111
+ return {"embeddings": vecs, "dim": len(vecs[0]), "pooling": pooling}