felixbet committed
Commit 9b01576 · verified · 1 Parent(s): 9bd55b2

Update app.py

Files changed (1):
  app.py (+117 -45)
app.py CHANGED
@@ -1,72 +1,82 @@
-# app.py — self-bootstrapping TF BioBERT embeddings API (HF Spaces-friendly)
+# app.py — FastAPI TF-BioBERT embeddings service (handles TF1 checkpoints)
+# Requires: transformers==4.43.4, tensorflow-cpu==2.16.1, tf-keras, fastapi, uvicorn[standard]
 
 import os, tarfile, glob, json, shutil, urllib.request
+from typing import List, Optional, Tuple
+
 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from typing import List
+
+# Import TensorFlow before Transformers TF models to avoid odd init order issues
+import tensorflow as tf  # noqa: F401
 from transformers import BertTokenizer, BertConfig, TFBertModel
-import tensorflow as tf  # noqa
 
-app = FastAPI()
+# ---------------------------- Config ----------------------------
 
-# --- Config
-MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/bert_tf")
+MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/bert_tf").rstrip("/")
 WEIGHTS_URL = os.environ.get("WEIGHTS_URL_TAR_GZ", "").strip()  # direct .tar.gz link (Dropbox must end with dl=1)
 FALLBACK_VOCAB_URL = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
+MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
 
 os.makedirs(MODEL_ROOT, exist_ok=True)
 
-def _extract_tar_gz(src: str, dest: str) -> None:
+# ---------------------- Utils: safe extract ---------------------
+
+def _safe_extract_tar_gz(src: str, dest: str) -> None:
     with tarfile.open(src, "r:gz") as tar:
-        def is_within(directory, target):
+        def _is_within(directory, target):
             abs_directory = os.path.abspath(directory)
             abs_target = os.path.abspath(target)
             return os.path.commonpath([abs_directory]) == os.path.commonpath([abs_directory, abs_target])
         for member in tar.getmembers():
             target_path = os.path.join(dest, member.name)
-            if not is_within(dest, target_path):
+            if not _is_within(dest, target_path):
                 raise RuntimeError("Blocked path traversal in tar")
         tar.extractall(dest)
 
-def ensure_weights_and_get_model_dir() -> str:
-    # If already prepared (vocab + any ckpt index) → reuse
-    maybe_vocab = glob.glob(os.path.join(MODEL_ROOT, "**", "vocab.txt"), recursive=True)
-    maybe_idx = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
-    if maybe_vocab and maybe_idx:
-        # choose dir containing the first ckpt index
-        return os.path.dirname(maybe_idx[0])
-
-    # Otherwise download and extract the archive
-    if not WEIGHTS_URL:
-        print("[app] WEIGHTS_URL_TAR_GZ not set; will still try to run with fallback vocab if files exist.")
-    else:
+# ---------------------- Bootstrap weights ----------------------
+
+def ensure_weights_and_locate() -> Tuple[str, str]:
+    """
+    Returns:
+        model_dir: directory containing vocab.txt/config.json/checkpoint + ckpt files
+        ckpt_prefix: full path WITHOUT extension, e.g. /app/bert_tf/bert_min/model.ckpt-150000
+    """
+    # Already present?
+    maybe_idx = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
+    if not maybe_idx and WEIGHTS_URL:
         print("[app] downloading weights:", WEIGHTS_URL)
         local_tar = "/tmp/model.tar.gz"
         urllib.request.urlretrieve(WEIGHTS_URL, local_tar)
         print("[app] extracting:", local_tar, "->", MODEL_ROOT)
-        _extract_tar_gz(local_tar, MODEL_ROOT)
+        _safe_extract_tar_gz(local_tar, MODEL_ROOT)
+        maybe_idx = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
 
-    # Pick the folder that has a ckpt index
-    idx_files = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
-    if not idx_files:
-        raise RuntimeError("No TensorFlow checkpoint index found under " + MODEL_ROOT)
-    model_dir = os.path.dirname(idx_files[0])
+    if not maybe_idx:
+        raise RuntimeError(f"No TensorFlow checkpoint *.index found under {MODEL_ROOT}")
 
-    # Ensure checkpoint file points at the basename
-    basename = os.path.basename(idx_files[0]).replace(".index", "")
+    # Prefer shortest path depth (avoids weird nested dirs)
+    maybe_idx.sort(key=lambda p: len(os.path.relpath(p, MODEL_ROOT).split(os.sep)))
+    ckpt_index = maybe_idx[0]
+    model_dir = os.path.dirname(ckpt_index)
+    ckpt_prefix = ckpt_index.replace(".index", "")
+
+    # Ensure checkpoint meta file points to the basename
+    basename = os.path.basename(ckpt_prefix)
     ckpt_meta = os.path.join(model_dir, "checkpoint")
     if not os.path.isfile(ckpt_meta):
         with open(ckpt_meta, "w") as f:
             f.write(f'model_checkpoint_path: "{basename}"\n')
 
-    # Ensure config.json
-    cfg = os.path.join(model_dir, "config.json")
+    # Ensure config.json (copy from bert_config.json if present, else write default BERT base config)
+    cfg_json = os.path.join(model_dir, "config.json")
     bcfg = os.path.join(model_dir, "bert_config.json")
-    if not os.path.isfile(cfg):
+    if not os.path.isfile(cfg_json):
         if os.path.isfile(bcfg):
-            shutil.copy(bcfg, cfg)
+            shutil.copy(bcfg, cfg_json)
         else:
-            with open(cfg, "w") as f:
+            with open(cfg_json, "w") as f:
                 json.dump({
                     "hidden_size": 768,
                     "num_attention_heads": 12,
@@ -86,26 +96,88 @@ def ensure_weights_and_get_model_dir() -> str:
         print("[app] vocab.txt missing; fetching BERT base uncased vocab…")
         urllib.request.urlretrieve(FALLBACK_VOCAB_URL, vocab)
 
-    return model_dir
+    # Sanity: ensure data shard exists
+    data_glob = glob.glob(os.path.join(model_dir, "model.ckpt-*.data-00000-of-00001"))
+    if not data_glob:
+        raise RuntimeError(f"Checkpoint data file missing in {model_dir} (model.ckpt-*.data-00000-of-00001)")
+
+    print("[app] Using MODEL_DIR:", model_dir)
+    print("[app] Using CKPT_PREFIX:", ckpt_prefix)
+    return model_dir, ckpt_prefix
+
+MODEL_DIR, CKPT_PREFIX = ensure_weights_and_locate()
 
-# Prepare weights (download/extract if needed), then load model
-MODEL_DIR = ensure_weights_and_get_model_dir()
-print("[app] Using MODEL_DIR:", MODEL_DIR)
+# ---------------------- Load tokenizer & model ------------------
 
-tok = BertTokenizer(vocab_file=os.path.join(MODEL_DIR, "vocab.txt"), do_lower_case=True)
-cfg = BertConfig.from_json_file(os.path.join(MODEL_DIR, "config.json"))
-model= TFBertModel.from_pretrained(MODEL_DIR, from_tf=True, config=cfg)
+tok = BertTokenizer(vocab_file=os.path.join(MODEL_DIR, "vocab.txt"), do_lower_case=True)
+cfg = BertConfig.from_json_file(os.path.join(MODEL_DIR, "config.json"))
+
+# IMPORTANT: load from TF1 checkpoint using the PREFIX (not folder)
+model = TFBertModel.from_pretrained(
+    CKPT_PREFIX,
+    from_tf=True,   # TF1 .ckpt import
+    from_pt=False,
+    config=cfg
+)
+
+# ---------------------------- API ------------------------------
+
+app = FastAPI(title="BioBERT-TF Embeddings API", version="1.0")
+
+# Optional: allow your website to call this API directly
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # tighten in production
+    allow_credentials=False,
+    allow_methods=["GET", "POST", "OPTIONS"],
+    allow_headers=["*"],
+)
 
 class EmbReq(BaseModel):
     input: str
+    max_len: Optional[int] = None
+
+class BatchEmbReq(BaseModel):
+    inputs: List[str]
+    max_len: Optional[int] = None
 
 @app.get("/health")
 def health():
-    return {"ok": True}
+    return {"ok": True, "model_dir": MODEL_DIR, "ckpt_prefix": CKPT_PREFIX}
+
+def _embed(texts: List[str], max_len: int) -> List[List[float]]:
+    enc = tok(texts, return_tensors="tf", truncation=True, padding=True, max_length=max_len)
+    out = model(**enc, training=False)
+    # Prefer pooled output if available; fall back to mean of last_hidden_state
+    if hasattr(out, "pooler_output") and out.pooler_output is not None:
+        vecs = out.pooler_output.numpy()
+    else:
+        last = out.last_hidden_state.numpy()
+        vecs = last.mean(axis=1)
+    return [v.tolist() for v in vecs]
 
 @app.post("/v1/embeddings")
 def embeddings(req: EmbReq):
-    enc = tok(req.input, return_tensors="tf", truncation=True, max_length=128)
-    out = model(**enc)
-    vec = out.pooler_output[0].numpy().tolist()
+    text = req.input.strip()
+    if not text:
+        return {"embedding": [], "dim": 0}
+    L = int(req.max_len or MAX_LEN)
+    vec = _embed([text], L)[0]
     return {"embedding": vec, "dim": len(vec)}
+
+@app.post("/v1/embeddings/batch")
+def embeddings_batch(req: BatchEmbReq):
+    items = [t.strip() for t in req.inputs if str(t).strip()]
+    if not items:
+        return {"embeddings": [], "dim": 0}
+    L = int(req.max_len or MAX_LEN)
+    vecs = _embed(items, L)
+    return {"embeddings": vecs, "dim": len(vecs[0])}
+
+@app.get("/")
+def root():
+    return {
+        "name": "BioBERT-TF Embeddings",
+        "endpoints": ["/health", "/v1/embeddings", "/v1/embeddings/batch"],
+        "hint": "POST to /v1/embeddings with {'input': 'your text'}"
+    }