felixbet committed on
Commit 7bec0b4 · verified · 1 Parent(s): 9b01576

Update app.py

Files changed (1): app.py (+57 -35)
app.py CHANGED
@@ -1,5 +1,10 @@
-# app.py — FastAPI TF-BioBERT embeddings service (handles TF1 checkpoints)
-# Requires: transformers==4.43.4, tensorflow-cpu==2.16.1, tf-keras, fastapi, uvicorn[standard]
+# app.py — FastAPI TF-BioBERT embeddings service (TF1 checkpoint loader)
+# Pin these (requirements.txt):
+#   fastapi
+#   uvicorn[standard]
+#   transformers==4.43.4
+#   tensorflow-cpu==2.16.1
+#   tf-keras
 
 import os, tarfile, glob, json, shutil, urllib.request
 from typing import List, Optional
@@ -8,42 +13,42 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 
-# Import TensorFlow before Transformers TF models to avoid odd init order issues
+# Import TF first
 import tensorflow as tf  # noqa: F401
 from transformers import BertTokenizer, BertConfig, TFBertModel
 
-# ---------------------------- Config ----------------------------
+# For TF1 checkpoint loading
+try:
+    # Present in transformers TF BERT module
+    from transformers.models.bert.modeling_tf_bert import load_tf_weights_in_bert as _hf_load_tf_ckpt
+except Exception:
+    _hf_load_tf_ckpt = None
 
 MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/bert_tf").rstrip("/")
-WEIGHTS_URL = os.environ.get("WEIGHTS_URL_TAR_GZ", "").strip()  # direct .tar.gz link (Dropbox must end with dl=1)
+WEIGHTS_URL = os.environ.get("WEIGHTS_URL_TAR_GZ", "").strip()  # direct .tar.gz (Dropbox must end with dl=1)
 FALLBACK_VOCAB_URL = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
 MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
 
 os.makedirs(MODEL_ROOT, exist_ok=True)
 
-# ---------------------- Utils: safe extract ---------------------
-
 def _safe_extract_tar_gz(src: str, dest: str) -> None:
     with tarfile.open(src, "r:gz") as tar:
         def _is_within(directory, target):
             abs_directory = os.path.abspath(directory)
             abs_target = os.path.abspath(target)
             return os.path.commonpath([abs_directory]) == os.path.commonpath([abs_directory, abs_target])
-        for member in tar.getmembers():
-            target_path = os.path.join(dest, member.name)
-            if not _is_within(dest, target_path):
+        for m in tar.getmembers():
+            tp = os.path.join(dest, m.name)
+            if not _is_within(dest, tp):
                 raise RuntimeError("Blocked path traversal in tar")
         tar.extractall(dest)
 
-# ---------------------- Bootstrap weights ----------------------
-
 def ensure_weights_and_locate() -> (str, str):
     """
     Returns:
-        model_dir: directory containing vocab.txt/config.json/checkpoint + ckpt files
+        model_dir: folder containing vocab/config/checkpoint + ckpt files
         ckpt_prefix: full path WITHOUT extension, e.g. /app/bert_tf/bert_min/model.ckpt-150000
     """
-    # Already present?
     maybe_idx = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
     if not maybe_idx and WEIGHTS_URL:
        print("[app] downloading weights:", WEIGHTS_URL)
@@ -56,20 +61,20 @@ def ensure_weights_and_locate() -> (str, str):
     if not maybe_idx:
         raise RuntimeError(f"No TensorFlow checkpoint *.index found under {MODEL_ROOT}")
 
-    # Prefer shortest path depth (avoids weird nested dirs)
+    # Prefer shallowest
     maybe_idx.sort(key=lambda p: len(os.path.relpath(p, MODEL_ROOT).split(os.sep)))
     ckpt_index = maybe_idx[0]
     model_dir = os.path.dirname(ckpt_index)
     ckpt_prefix = ckpt_index.replace(".index", "")
 
-    # Ensure checkpoint meta file points to the basename
+    # checkpoint meta
     basename = os.path.basename(ckpt_prefix)
     ckpt_meta = os.path.join(model_dir, "checkpoint")
     if not os.path.isfile(ckpt_meta):
         with open(ckpt_meta, "w") as f:
             f.write(f'model_checkpoint_path: "{basename}"\n')
 
-    # Ensure config.json (copy from bert_config.json if present, else write default BERT base config)
+    # config.json (copy bert_config.json if present)
     cfg_json = os.path.join(model_dir, "config.json")
     bcfg = os.path.join(model_dir, "bert_config.json")
     if not os.path.isfile(cfg_json):
@@ -90,13 +95,13 @@ def ensure_weights_and_locate() -> (str, str):
             "vocab_size": 30522
         }, f)
 
-    # Ensure vocab.txt (BioBERT uses BERT base uncased vocab)
+    # vocab.txt (BioBERT uses BERT base uncased vocab)
     vocab = os.path.join(model_dir, "vocab.txt")
     if not os.path.isfile(vocab):
         print("[app] vocab.txt missing; fetching BERT base uncased vocab…")
         urllib.request.urlretrieve(FALLBACK_VOCAB_URL, vocab)
 
-    # Sanity: ensure data shard exists
+    # data shard sanity
     data_glob = glob.glob(os.path.join(model_dir, "model.ckpt-*.data-00000-of-00001"))
     if not data_glob:
         raise RuntimeError(f"Checkpoint data file missing in {model_dir} (model.ckpt-*.data-00000-of-00001)")
@@ -107,27 +112,45 @@ def ensure_weights_and_locate() -> (str, str):
 
 MODEL_DIR, CKPT_PREFIX = ensure_weights_and_locate()
 
-# ---------------------- Load tokenizer & model ------------------
-
+# Tokenizer + config
 tok = BertTokenizer(vocab_file=os.path.join(MODEL_DIR, "vocab.txt"), do_lower_case=True)
 cfg = BertConfig.from_json_file(os.path.join(MODEL_DIR, "config.json"))
 
-# IMPORTANT: load from TF1 checkpoint using the PREFIX (not folder)
-model = TFBertModel.from_pretrained(
-    CKPT_PREFIX,
-    from_tf=True,  # TF1 .ckpt import
-    from_pt=False,
-    config=cfg
-)
-
-# ---------------------------- API ------------------------------
-
-app = FastAPI(title="BioBERT-TF Embeddings API", version="1.0")
+# Build model skeleton
+model = TFBertModel(cfg)
+
+# Load TF1 checkpoint (no from_tf kwarg!)
+loaded = False
+err_stack = []
+
+if _hf_load_tf_ckpt is not None:
+    try:
+        # Some transformers versions: (model, ckpt_prefix)
+        _hf_load_tf_ckpt(model, CKPT_PREFIX)
+        loaded = True
+        print("[app] Loaded TF1 checkpoint via load_tf_weights_in_bert(model, ckpt_prefix)")
+    except TypeError as e1:
+        err_stack.append(str(e1))
+        try:
+            # Other versions: (model, config, ckpt_prefix)
+            _hf_load_tf_ckpt(model, cfg, CKPT_PREFIX)
+            loaded = True
+            print("[app] Loaded TF1 checkpoint via load_tf_weights_in_bert(model, config, ckpt_prefix)")
+        except Exception as e2:
+            err_stack.append(str(e2))
+
+if not loaded:
+    raise RuntimeError(
+        "Could not load TF1 checkpoint with transformers' loader. "
+        f"ckpt={CKPT_PREFIX}\nErrors: {err_stack or 'no loader available'}"
+    )
+
+# ---------- API ----------
+app = FastAPI(title="BioBERT-TF Embeddings API", version="1.1")
 
-# Optional: allow your website to call this API directly
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # tighten in production
+    allow_origins=["*"],
     allow_credentials=False,
     allow_methods=["GET", "POST", "OPTIONS"],
     allow_headers=["*"],
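If the loader above fails, the usual first debugging step is to confirm what the TF1 checkpoint actually contains. A small sketch using TF2's `tf.train.list_variables`; the `describe_ckpt` helper is illustrative, not part of the commit:

```python
import tensorflow as tf

def describe_ckpt(ckpt_prefix: str, limit: int = 10) -> None:
    """Print variable names/shapes stored under a TF checkpoint prefix."""
    variables = tf.train.list_variables(ckpt_prefix)  # [(name, shape), ...]
    print(f"[debug] {len(variables)} variables in {ckpt_prefix}")
    for name, shape in variables[:limit]:
        print(f"[debug]   {name}: {shape}")

# e.g. describe_ckpt(CKPT_PREFIX) before calling the loader
```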
@@ -148,7 +171,6 @@ def health():
 def _embed(texts: List[str], max_len: int) -> List[List[float]]:
     enc = tok(texts, return_tensors="tf", truncation=True, padding=True, max_length=max_len)
     out = model(**enc, training=False)
-    # Prefer pooled output if available; fallback to mean of last_hidden_state
     if hasattr(out, "pooler_output") and out.pooler_output is not None:
         vecs = out.pooler_output.numpy()
     else:
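The hunk ends inside `_embed`'s `else:` branch. Per the removed comment ("fallback to mean of last_hidden_state"), the fallback presumably averages token embeddings; a masked mean-pooling sketch of what that branch typically looks like (assumed, since the actual body falls outside the diff):

```python
# Inside the else: branch, mean-pool last_hidden_state, excluding padding.
import numpy as np

hidden = out.last_hidden_state.numpy()            # (batch, seq_len, hidden)
mask = enc["attention_mask"].numpy()[:, :, None]  # (batch, seq_len, 1)
vecs = (hidden * mask).sum(axis=1) / np.clip(mask.sum(axis=1), 1e-9, None)
```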
 
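For completeness, a client-side sketch of calling the service. The `/embed` route and its JSON schema are assumptions (only the `_embed` helper and the `health` handler appear in the hunks above), and port 7860 is just the Hugging Face Spaces default:

```python
import requests

resp = requests.post(
    "http://localhost:7860/embed",  # hypothetical route/port
    json={"texts": ["aspirin inhibits platelet aggregation"]},
)
resp.raise_for_status()
print(resp.json())  # expected: one embedding vector per input text
```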