felixbet committed
Commit 512acbc · verified · 1 Parent(s): 7bec0b4

Update app.py

Files changed (1): app.py +98 -130
app.py CHANGED
@@ -1,10 +1,5 @@
- # app.py — FastAPI TF-BioBERT embeddings service (TF1 checkpoint loader)
- # Pin these (requirements.txt):
- # fastapi
- # uvicorn[standard]
- # transformers==4.43.4
- # tensorflow-cpu==2.16.1
- # tf-keras

  import os, tarfile, glob, json, shutil, urllib.request
  from typing import List, Optional
@@ -13,140 +8,112 @@ from fastapi import FastAPI
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel

- # Import TF first
  import tensorflow as tf  # noqa: F401
- from transformers import BertTokenizer, BertConfig, TFBertModel
-
- # For TF1 checkpoint loading
- try:
-     # Present in transformers TF BERT module
-     from transformers.models.bert.modeling_tf_bert import load_tf_weights_in_bert as _hf_load_tf_ckpt
- except Exception:
-     _hf_load_tf_ckpt = None
-
- MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/bert_tf").rstrip("/")
- WEIGHTS_URL = os.environ.get("WEIGHTS_URL_TAR_GZ", "").strip()  # direct .tar.gz (Dropbox must end with dl=1)
- FALLBACK_VOCAB_URL = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
- MAX_LEN = int(os.environ.get("MAX_LEN", "128"))

  os.makedirs(MODEL_ROOT, exist_ok=True)

  def _safe_extract_tar_gz(src: str, dest: str) -> None:
      with tarfile.open(src, "r:gz") as tar:
          def _is_within(directory, target):
-             abs_directory = os.path.abspath(directory)
-             abs_target = os.path.abspath(target)
-             return os.path.commonpath([abs_directory]) == os.path.commonpath([abs_directory, abs_target])
          for m in tar.getmembers():
              tp = os.path.join(dest, m.name)
              if not _is_within(dest, tp):
                  raise RuntimeError("Blocked path traversal in tar")
          tar.extractall(dest)

- def ensure_weights_and_locate() -> (str, str):
      """
-     Returns:
-       model_dir:   folder containing vocab/config/checkpoint + ckpt files
-       ckpt_prefix: full path WITHOUT extension, e.g. /app/bert_tf/bert_min/model.ckpt-150000
      """
-     maybe_idx = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
-     if not maybe_idx and WEIGHTS_URL:
-         print("[app] downloading weights:", WEIGHTS_URL)
-         local_tar = "/tmp/model.tar.gz"
-         urllib.request.urlretrieve(WEIGHTS_URL, local_tar)
-         print("[app] extracting:", local_tar, "->", MODEL_ROOT)
-         _safe_extract_tar_gz(local_tar, MODEL_ROOT)
-         maybe_idx = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
-
-     if not maybe_idx:
-         raise RuntimeError(f"No TensorFlow checkpoint *.index found under {MODEL_ROOT}")
-
-     # Prefer shallowest
-     maybe_idx.sort(key=lambda p: len(os.path.relpath(p, MODEL_ROOT).split(os.sep)))
-     ckpt_index = maybe_idx[0]
-     model_dir = os.path.dirname(ckpt_index)
-     ckpt_prefix = ckpt_index.replace(".index", "")
-
-     # checkpoint meta
-     basename = os.path.basename(ckpt_prefix)
-     ckpt_meta = os.path.join(model_dir, "checkpoint")
-     if not os.path.isfile(ckpt_meta):
-         with open(ckpt_meta, "w") as f:
-             f.write(f'model_checkpoint_path: "{basename}"\n')
-
-     # config.json (copy bert_config.json if present)
-     cfg_json = os.path.join(model_dir, "config.json")
-     bcfg = os.path.join(model_dir, "bert_config.json")
-     if not os.path.isfile(cfg_json):
-         if os.path.isfile(bcfg):
-             shutil.copy(bcfg, cfg_json)
-         else:
-             with open(cfg_json, "w") as f:
-                 json.dump({
-                     "hidden_size": 768,
-                     "num_attention_heads": 12,
-                     "num_hidden_layers": 12,
-                     "intermediate_size": 3072,
-                     "hidden_act": "gelu",
-                     "hidden_dropout_prob": 0.1,
-                     "attention_probs_dropout_prob": 0.1,
-                     "max_position_embeddings": 512,
-                     "type_vocab_size": 2,
-                     "vocab_size": 30522
-                 }, f)
-
-     # vocab.txt (BioBERT uses BERT base uncased vocab)
-     vocab = os.path.join(model_dir, "vocab.txt")
-     if not os.path.isfile(vocab):
-         print("[app] vocab.txt missing; fetching BERT base uncased vocab…")
-         urllib.request.urlretrieve(FALLBACK_VOCAB_URL, vocab)
-
-     # data shard sanity
-     data_glob = glob.glob(os.path.join(model_dir, "model.ckpt-*.data-00000-of-00001"))
-     if not data_glob:
-         raise RuntimeError(f"Checkpoint data file missing in {model_dir} (model.ckpt-*.data-00000-of-00001)")
-
-     print("[app] Using MODEL_DIR:", model_dir)
-     print("[app] Using CKPT_PREFIX:", ckpt_prefix)
-     return model_dir, ckpt_prefix
-
- MODEL_DIR, CKPT_PREFIX = ensure_weights_and_locate()
-
- # Tokenizer + config
- tok = BertTokenizer(vocab_file=os.path.join(MODEL_DIR, "vocab.txt"), do_lower_case=True)
- cfg = BertConfig.from_json_file(os.path.join(MODEL_DIR, "config.json"))
-
- # Build model skeleton
- model = TFBertModel(cfg)
-
- # Load TF1 checkpoint (no from_tf kwarg!)
- loaded = False
- err_stack = []
-
- if _hf_load_tf_ckpt is not None:
-     try:
-         # Some transformers versions: (model, ckpt_prefix)
-         _hf_load_tf_ckpt(model, CKPT_PREFIX)
-         loaded = True
-         print("[app] Loaded TF1 checkpoint via load_tf_weights_in_bert(model, ckpt_prefix)")
-     except TypeError as e1:
-         err_stack.append(str(e1))
-         try:
-             # Other versions: (model, config, ckpt_prefix)
-             _hf_load_tf_ckpt(model, cfg, CKPT_PREFIX)
-             loaded = True
-             print("[app] Loaded TF1 checkpoint via load_tf_weights_in_bert(model, config, ckpt_prefix)")
-         except Exception as e2:
-             err_stack.append(str(e2))
-
- if not loaded:
-     raise RuntimeError(
-         "Could not load TF1 checkpoint with transformers' loader. "
-         f"ckpt={CKPT_PREFIX}\nErrors: {err_stack or 'no loader available'}"
-     )
-
- # ---------- API ----------
- app = FastAPI(title="BioBERT-TF Embeddings API", version="1.1")

  app.add_middleware(
      CORSMiddleware,
@@ -166,10 +133,10 @@ class BatchEmbReq(BaseModel):

  @app.get("/health")
  def health():
-     return {"ok": True, "model_dir": MODEL_DIR, "ckpt_prefix": CKPT_PREFIX}

  def _embed(texts: List[str], max_len: int) -> List[List[float]]:
-     enc = tok(texts, return_tensors="tf", truncation=True, padding=True, max_length=max_len)
      out = model(**enc, training=False)
      if hasattr(out, "pooler_output") and out.pooler_output is not None:
          vecs = out.pooler_output.numpy()
@@ -199,7 +166,8 @@ def embeddings_batch(req: BatchEmbReq):
  @app.get("/")
  def root():
      return {
-         "name": "BioBERT-TF Embeddings",
          "endpoints": ["/health", "/v1/embeddings", "/v1/embeddings/batch"],
-         "hint": "POST to /v1/embeddings with {'input': 'your text'}"
      }

+ # FastAPI BioBERT embeddings (Hub-first, no TF1 ckpt dependency)
+ # Works free on Hugging Face Spaces (CPU). Auto-converts PyTorch -> TF.

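Note: the rewrite drops the old header's dependency pins. Since `from_pt=True` converts PyTorch weights at load time, `torch` now has to be installed alongside TF. A plausible pin set, carrying over the removed header's versions (the `torch` line is an assumption, not part of this commit):

```python
# requirements.txt (sketch; versions copied from the removed header, torch added)
# fastapi
# uvicorn[standard]
# transformers==4.43.4
# tensorflow-cpu==2.16.1
# tf-keras
# torch  # required for from_pt=True (PyTorch -> TF conversion)
```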
  import os, tarfile, glob, json, shutil, urllib.request
  from typing import List, Optional

  from fastapi import FastAPI
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel

+ # Load TF before transformers' TF models
  import tensorflow as tf  # noqa: F401
+ from transformers import AutoTokenizer, TFAutoModel, BertConfig

+ # ------------------- Config -------------------
+ HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "monologg/biobert_v1.1_pubmed").strip()
+ MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/bert_tf").rstrip("/")
+ WEIGHTS_URL = os.environ.get("WEIGHTS_URL_TAR_GZ", "").strip()  # optional direct .tar.gz (Dropbox must end with dl=1)
+ MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
  os.makedirs(MODEL_ROOT, exist_ok=True)

+ # ------------------- Utils --------------------
  def _safe_extract_tar_gz(src: str, dest: str) -> None:
      with tarfile.open(src, "r:gz") as tar:
          def _is_within(directory, target):
+             ad = os.path.abspath(directory); at = os.path.abspath(target)
+             return os.path.commonpath([ad]) == os.path.commonpath([ad, at])
          for m in tar.getmembers():
              tp = os.path.join(dest, m.name)
              if not _is_within(dest, tp):
                  raise RuntimeError("Blocked path traversal in tar")
          tar.extractall(dest)

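For context, the `commonpath` guard works because `abspath()` collapses `..` segments before the comparison; a quick illustration with made-up paths:

```python
import os

dest = os.path.abspath("/app/bert_tf")
inside = os.path.abspath("/app/bert_tf/model/ok.bin")       # stays under dest
outside = os.path.abspath("/app/bert_tf/../../etc/passwd")  # collapses to /etc/passwd

# The check passes only when the shared prefix is dest itself:
print(os.path.commonpath([dest, inside]) == dest)   # True  -> allowed
print(os.path.commonpath([dest, outside]) == dest)  # False -> blocked
```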
+ def _maybe_download_tar_into_model_root() -> Optional[str]:
+     """If WEIGHTS_URL is set, download + extract it into MODEL_ROOT. Return extracted dir if any."""
+     if not WEIGHTS_URL:
+         return None
+     print("[app] downloading weights:", WEIGHTS_URL)
+     local_tar = "/tmp/model.tar.gz"
+     urllib.request.urlretrieve(WEIGHTS_URL, local_tar)
+     print("[app] extracting:", local_tar, "->", MODEL_ROOT)
+     _safe_extract_tar_gz(local_tar, MODEL_ROOT)
+     # return shallowest dir inside MODEL_ROOT
+     candidates = [d for d in glob.glob(os.path.join(MODEL_ROOT, "*")) if os.path.isdir(d)]
+     if not candidates:
+         return MODEL_ROOT
+     candidates.sort(key=lambda p: len(os.path.relpath(p, MODEL_ROOT).split(os.sep)))
+     return candidates[0]
+
+ def _detect_local_hf_dir(root: str) -> Optional[str]:
      """
+     Return a directory under root that looks like a modern HF model folder:
+       - pytorch_model.bin / model.safetensors (for from_pt=True)
+       - OR tf_model.h5 (native TF)
      """
+     # one recursive search is enough: "**" already matches every depth, root included
+     pattern = os.path.join(root, "**")
+     # prefer native TF weights if present anywhere
+     tf_h5 = glob.glob(os.path.join(pattern, "tf_model.h5"), recursive=True)
+     if tf_h5:
+         tf_h5.sort(key=lambda p: len(os.path.relpath(p, root).split(os.sep)))
+         return os.path.dirname(tf_h5[0])
+
+     # else look for PT/safetensors; collect both and take the shallowest
+     pt = glob.glob(os.path.join(pattern, "pytorch_model.bin"), recursive=True)
+     st = glob.glob(os.path.join(pattern, "model.safetensors"), recursive=True)
+     have = pt + st
+     if have:
+         have.sort(key=lambda p: len(os.path.relpath(p, root).split(os.sep)))
+         return os.path.dirname(have[0])
+     return None
+
+ def _looks_like_tf1_ckpt_dir(path: str) -> bool:
+     return bool(glob.glob(os.path.join(path, "model.ckpt-*.index")))
+
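One subtlety worth flagging: because TF weights are searched first across the whole tree, a `tf_model.h5` anywhere under `root` wins even when PT weights sit shallower. A throwaway illustration (layout hypothetical):

```python
import os, tempfile

root = tempfile.mkdtemp()
open(os.path.join(root, "pytorch_model.bin"), "w").close()
os.makedirs(os.path.join(root, "sub"))
open(os.path.join(root, "sub", "tf_model.h5"), "w").close()

# Native TF weights are preferred despite the shallower PT file:
assert _detect_local_hf_dir(root) == os.path.join(root, "sub")
```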
+ # ------------------- Load strategy -------------------
+ # 1) If a tar URL is provided, unpack it (optional convenience)
+ extracted = _maybe_download_tar_into_model_root()
+
+ # 2) If after extraction we have a local HF-style folder, use it
+ LOCAL_DIR = _detect_local_hf_dir(MODEL_ROOT)
+
+ # 3) If only a TF1 ckpt is found, refuse with a clear message (no fragile loaders)
+ if not LOCAL_DIR:
+     # If any directory in MODEL_ROOT holds TF-1 ckpts, refuse with guidance
+     for d in [MODEL_ROOT] + [p for p in glob.glob(os.path.join(MODEL_ROOT, "*")) if os.path.isdir(p)]:
+         if _looks_like_tf1_ckpt_dir(d):
+             raise RuntimeError(
+                 "Found TF-1 checkpoint files (model.ckpt-*) but this app purposely avoids "
+                 "runtime TF-1 → TF-2 weight mapping. Either:\n"
+                 " Set HF_MODEL_ID to a BioBERT model on the Hub (recommended), e.g. 'monologg/biobert_v1.1_pubmed'\n"
+                 " Or package modern HF weights (pytorch_model.bin/model.safetensors or tf_model.h5) in your tar."
+             )
+
+ # 4) Tokenizer + model
+ if LOCAL_DIR:
+     print(f"[app] Using LOCAL_DIR: {LOCAL_DIR}")
+     # Prefer native TF if available, else convert from PT
+     if os.path.isfile(os.path.join(LOCAL_DIR, "tf_model.h5")):
+         tokenizer = AutoTokenizer.from_pretrained(LOCAL_DIR)
+         model = TFAutoModel.from_pretrained(LOCAL_DIR)
+         USED = {"source": "local", "format": "tf_h5", "path": LOCAL_DIR}
+     else:
+         tokenizer = AutoTokenizer.from_pretrained(LOCAL_DIR)
+         model = TFAutoModel.from_pretrained(LOCAL_DIR, from_pt=True)
+         USED = {"source": "local", "format": "pt/safetensors->tf", "path": LOCAL_DIR}
+ else:
+     print(f"[app] Using HF_MODEL_ID: {HF_MODEL_ID}")
+     tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
+     # Most BioBERT repos are PyTorch; allow auto-conversion
+     model = TFAutoModel.from_pretrained(HF_MODEL_ID, from_pt=True)
+     USED = {"source": "hub", "model_id": HF_MODEL_ID, "format": "pt->tf"}
+
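A one-line forward pass at startup would surface a bad PT→TF conversion before the first request; a minimal sketch, assuming a BERT-base config (hidden size 768), which BioBERT uses:

```python
# Optional startup smoke test (not in this commit): fail fast if the
# freshly loaded/converted weights can't run a forward pass.
_probe = tokenizer(["sanity check"], return_tensors="tf")
_out = model(**_probe, training=False)
assert _out.last_hidden_state.shape[-1] == 768  # BERT-base hidden size (assumption)
print("[app] smoke test ok:", _out.last_hidden_state.shape)
```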
+ # ------------------- API -------------------
+ app = FastAPI(title="BioBERT Embeddings API (Hub-first)", version="2.0")

  app.add_middleware(
      CORSMiddleware,

  @app.get("/health")
  def health():
+     return {"ok": True, "strategy": USED, "max_len_default": MAX_LEN}

  def _embed(texts: List[str], max_len: int) -> List[List[float]]:
+     enc = tokenizer(texts, return_tensors="tf", truncation=True, padding=True, max_length=max_len)
      out = model(**enc, training=False)
      if hasattr(out, "pooler_output") and out.pooler_output is not None:
          vecs = out.pooler_output.numpy()

  @app.get("/")
  def root():
      return {
+         "name": "BioBERT Embeddings (Hub-first)",
          "endpoints": ["/health", "/v1/embeddings", "/v1/embeddings/batch"],
+         "hint": "POST /v1/embeddings with {'input': 'your text'}",
+         "strategy": USED
      }
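For reference, exercising the service once it's up. The single-text schema follows the `hint` in `root()`; the port and the response shape are assumptions, since neither appears in this diff:

```python
import requests

BASE = "http://localhost:7860"  # typical HF Spaces port; adjust to your deployment

# /health reports which load strategy (local vs Hub, native TF vs converted PT) was used
print(requests.get(f"{BASE}/health").json())

# Single-text embedding, per the root() hint: {'input': 'your text'}
resp = requests.post(f"{BASE}/v1/embeddings", json={"input": "aspirin inhibits COX-1"})
print(resp.status_code)
```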