Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
| 1 |
-
# app.py — FastAPI TF-BioBERT embeddings service (
|
| 2 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
import os, tarfile, glob, json, shutil, urllib.request
|
| 5 |
from typing import List, Optional
|
|
@@ -8,42 +13,42 @@ from fastapi import FastAPI
|
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from pydantic import BaseModel
|
| 10 |
|
| 11 |
-
# Import
|
| 12 |
import tensorflow as tf # noqa: F401
|
| 13 |
from transformers import BertTokenizer, BertConfig, TFBertModel
|
| 14 |
|
| 15 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/bert_tf").rstrip("/")
|
| 18 |
-
WEIGHTS_URL = os.environ.get("WEIGHTS_URL_TAR_GZ", "").strip() # direct .tar.gz
|
| 19 |
FALLBACK_VOCAB_URL = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
|
| 20 |
MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
|
| 21 |
|
| 22 |
os.makedirs(MODEL_ROOT, exist_ok=True)
|
| 23 |
|
| 24 |
-
# ---------------------- Utils: safe extract ---------------------
|
| 25 |
-
|
| 26 |
def _safe_extract_tar_gz(src: str, dest: str) -> None:
|
| 27 |
with tarfile.open(src, "r:gz") as tar:
|
| 28 |
def _is_within(directory, target):
|
| 29 |
abs_directory = os.path.abspath(directory)
|
| 30 |
abs_target = os.path.abspath(target)
|
| 31 |
return os.path.commonpath([abs_directory]) == os.path.commonpath([abs_directory, abs_target])
|
| 32 |
-
for
|
| 33 |
-
|
| 34 |
-
if not _is_within(dest,
|
| 35 |
raise RuntimeError("Blocked path traversal in tar")
|
| 36 |
tar.extractall(dest)
|
| 37 |
|
| 38 |
-
# ---------------------- Bootstrap weights ----------------------
|
| 39 |
-
|
| 40 |
def ensure_weights_and_locate() -> (str, str):
|
| 41 |
"""
|
| 42 |
Returns:
|
| 43 |
-
model_dir:
|
| 44 |
ckpt_prefix: full path WITHOUT extension, e.g. /app/bert_tf/bert_min/model.ckpt-150000
|
| 45 |
"""
|
| 46 |
-
# Already present?
|
| 47 |
maybe_idx = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
|
| 48 |
if not maybe_idx and WEIGHTS_URL:
|
| 49 |
print("[app] downloading weights:", WEIGHTS_URL)
|
|
@@ -56,20 +61,20 @@ def ensure_weights_and_locate() -> (str, str):
|
|
| 56 |
if not maybe_idx:
|
| 57 |
raise RuntimeError(f"No TensorFlow checkpoint *.index found under {MODEL_ROOT}")
|
| 58 |
|
| 59 |
-
# Prefer
|
| 60 |
maybe_idx.sort(key=lambda p: len(os.path.relpath(p, MODEL_ROOT).split(os.sep)))
|
| 61 |
ckpt_index = maybe_idx[0]
|
| 62 |
model_dir = os.path.dirname(ckpt_index)
|
| 63 |
ckpt_prefix = ckpt_index.replace(".index", "")
|
| 64 |
|
| 65 |
-
#
|
| 66 |
basename = os.path.basename(ckpt_prefix)
|
| 67 |
ckpt_meta = os.path.join(model_dir, "checkpoint")
|
| 68 |
if not os.path.isfile(ckpt_meta):
|
| 69 |
with open(ckpt_meta, "w") as f:
|
| 70 |
f.write(f'model_checkpoint_path: "{basename}"\n')
|
| 71 |
|
| 72 |
-
#
|
| 73 |
cfg_json = os.path.join(model_dir, "config.json")
|
| 74 |
bcfg = os.path.join(model_dir, "bert_config.json")
|
| 75 |
if not os.path.isfile(cfg_json):
|
|
@@ -90,13 +95,13 @@ def ensure_weights_and_locate() -> (str, str):
|
|
| 90 |
"vocab_size": 30522
|
| 91 |
}, f)
|
| 92 |
|
| 93 |
-
#
|
| 94 |
vocab = os.path.join(model_dir, "vocab.txt")
|
| 95 |
if not os.path.isfile(vocab):
|
| 96 |
print("[app] vocab.txt missing; fetching BERT base uncased vocab…")
|
| 97 |
urllib.request.urlretrieve(FALLBACK_VOCAB_URL, vocab)
|
| 98 |
|
| 99 |
-
#
|
| 100 |
data_glob = glob.glob(os.path.join(model_dir, "model.ckpt-*.data-00000-of-00001"))
|
| 101 |
if not data_glob:
|
| 102 |
raise RuntimeError(f"Checkpoint data file missing in {model_dir} (model.ckpt-*.data-00000-of-00001)")
|
|
@@ -107,27 +112,45 @@ def ensure_weights_and_locate() -> (str, str):
|
|
| 107 |
|
| 108 |
MODEL_DIR, CKPT_PREFIX = ensure_weights_and_locate()
|
| 109 |
|
| 110 |
-
#
|
| 111 |
-
|
| 112 |
tok = BertTokenizer(vocab_file=os.path.join(MODEL_DIR, "vocab.txt"), do_lower_case=True)
|
| 113 |
cfg = BertConfig.from_json_file(os.path.join(MODEL_DIR, "config.json"))
|
| 114 |
|
| 115 |
-
#
|
| 116 |
-
model = TFBertModel
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
-
# Optional: allow your website to call this API directly
|
| 128 |
app.add_middleware(
|
| 129 |
CORSMiddleware,
|
| 130 |
-
allow_origins=["*"],
|
| 131 |
allow_credentials=False,
|
| 132 |
allow_methods=["GET", "POST", "OPTIONS"],
|
| 133 |
allow_headers=["*"],
|
|
@@ -148,7 +171,6 @@ def health():
|
|
| 148 |
def _embed(texts: List[str], max_len: int) -> List[List[float]]:
|
| 149 |
enc = tok(texts, return_tensors="tf", truncation=True, padding=True, max_length=max_len)
|
| 150 |
out = model(**enc, training=False)
|
| 151 |
-
# Prefer pooled output if available; fallback to mean of last_hidden_state
|
| 152 |
if hasattr(out, "pooler_output") and out.pooler_output is not None:
|
| 153 |
vecs = out.pooler_output.numpy()
|
| 154 |
else:
|
|
|
|
| 1 |
+
# app.py — FastAPI TF-BioBERT embeddings service (TF1 checkpoint loader)
|
| 2 |
+
# Pin these (requirements.txt):
|
| 3 |
+
# fastapi
|
| 4 |
+
# uvicorn[standard]
|
| 5 |
+
# transformers==4.43.4
|
| 6 |
+
# tensorflow-cpu==2.16.1
|
| 7 |
+
# tf-keras
|
| 8 |
|
| 9 |
import os, tarfile, glob, json, shutil, urllib.request
|
| 10 |
from typing import List, Optional
|
|
|
|
| 13 |
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
from pydantic import BaseModel
|
| 15 |
|
| 16 |
+
# Import TF first
|
| 17 |
import tensorflow as tf # noqa: F401
|
| 18 |
from transformers import BertTokenizer, BertConfig, TFBertModel
|
| 19 |
|
| 20 |
+
# For TF1 checkpoint loading
|
| 21 |
+
try:
|
| 22 |
+
# Present in transformers TF BERT module
|
| 23 |
+
from transformers.models.bert.modeling_tf_bert import load_tf_weights_in_bert as _hf_load_tf_ckpt
|
| 24 |
+
except Exception:
|
| 25 |
+
_hf_load_tf_ckpt = None
|
| 26 |
|
| 27 |
MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/bert_tf").rstrip("/")
|
| 28 |
+
WEIGHTS_URL = os.environ.get("WEIGHTS_URL_TAR_GZ", "").strip() # direct .tar.gz (Dropbox must end with dl=1)
|
| 29 |
FALLBACK_VOCAB_URL = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
|
| 30 |
MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
|
| 31 |
|
| 32 |
os.makedirs(MODEL_ROOT, exist_ok=True)
|
| 33 |
|
|
|
|
|
|
|
| 34 |
def _safe_extract_tar_gz(src: str, dest: str) -> None:
|
| 35 |
with tarfile.open(src, "r:gz") as tar:
|
| 36 |
def _is_within(directory, target):
|
| 37 |
abs_directory = os.path.abspath(directory)
|
| 38 |
abs_target = os.path.abspath(target)
|
| 39 |
return os.path.commonpath([abs_directory]) == os.path.commonpath([abs_directory, abs_target])
|
| 40 |
+
for m in tar.getmembers():
|
| 41 |
+
tp = os.path.join(dest, m.name)
|
| 42 |
+
if not _is_within(dest, tp):
|
| 43 |
raise RuntimeError("Blocked path traversal in tar")
|
| 44 |
tar.extractall(dest)
|
| 45 |
|
|
|
|
|
|
|
| 46 |
def ensure_weights_and_locate() -> (str, str):
|
| 47 |
"""
|
| 48 |
Returns:
|
| 49 |
+
model_dir: folder containing vocab/config/checkpoint + ckpt files
|
| 50 |
ckpt_prefix: full path WITHOUT extension, e.g. /app/bert_tf/bert_min/model.ckpt-150000
|
| 51 |
"""
|
|
|
|
| 52 |
maybe_idx = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
|
| 53 |
if not maybe_idx and WEIGHTS_URL:
|
| 54 |
print("[app] downloading weights:", WEIGHTS_URL)
|
|
|
|
| 61 |
if not maybe_idx:
|
| 62 |
raise RuntimeError(f"No TensorFlow checkpoint *.index found under {MODEL_ROOT}")
|
| 63 |
|
| 64 |
+
# Prefer shallowest
|
| 65 |
maybe_idx.sort(key=lambda p: len(os.path.relpath(p, MODEL_ROOT).split(os.sep)))
|
| 66 |
ckpt_index = maybe_idx[0]
|
| 67 |
model_dir = os.path.dirname(ckpt_index)
|
| 68 |
ckpt_prefix = ckpt_index.replace(".index", "")
|
| 69 |
|
| 70 |
+
# checkpoint meta
|
| 71 |
basename = os.path.basename(ckpt_prefix)
|
| 72 |
ckpt_meta = os.path.join(model_dir, "checkpoint")
|
| 73 |
if not os.path.isfile(ckpt_meta):
|
| 74 |
with open(ckpt_meta, "w") as f:
|
| 75 |
f.write(f'model_checkpoint_path: "{basename}"\n')
|
| 76 |
|
| 77 |
+
# config.json (copy bert_config.json if present)
|
| 78 |
cfg_json = os.path.join(model_dir, "config.json")
|
| 79 |
bcfg = os.path.join(model_dir, "bert_config.json")
|
| 80 |
if not os.path.isfile(cfg_json):
|
|
|
|
| 95 |
"vocab_size": 30522
|
| 96 |
}, f)
|
| 97 |
|
| 98 |
+
# vocab.txt (BioBERT uses BERT base uncased vocab)
|
| 99 |
vocab = os.path.join(model_dir, "vocab.txt")
|
| 100 |
if not os.path.isfile(vocab):
|
| 101 |
print("[app] vocab.txt missing; fetching BERT base uncased vocab…")
|
| 102 |
urllib.request.urlretrieve(FALLBACK_VOCAB_URL, vocab)
|
| 103 |
|
| 104 |
+
# data shard sanity
|
| 105 |
data_glob = glob.glob(os.path.join(model_dir, "model.ckpt-*.data-00000-of-00001"))
|
| 106 |
if not data_glob:
|
| 107 |
raise RuntimeError(f"Checkpoint data file missing in {model_dir} (model.ckpt-*.data-00000-of-00001)")
|
|
|
|
| 112 |
|
| 113 |
MODEL_DIR, CKPT_PREFIX = ensure_weights_and_locate()
|
| 114 |
|
| 115 |
+
# Tokenizer + config
|
|
|
|
| 116 |
tok = BertTokenizer(vocab_file=os.path.join(MODEL_DIR, "vocab.txt"), do_lower_case=True)
|
| 117 |
cfg = BertConfig.from_json_file(os.path.join(MODEL_DIR, "config.json"))
|
| 118 |
|
| 119 |
+
# Build model skeleton
|
| 120 |
+
model = TFBertModel(cfg)
|
| 121 |
+
|
| 122 |
+
# Load TF1 checkpoint (no from_tf kwarg!)
|
| 123 |
+
loaded = False
|
| 124 |
+
err_stack = []
|
| 125 |
+
|
| 126 |
+
if _hf_load_tf_ckpt is not None:
|
| 127 |
+
try:
|
| 128 |
+
# Some transformer versions: (model, ckpt_prefix)
|
| 129 |
+
_hf_load_tf_ckpt(model, CKPT_PREFIX)
|
| 130 |
+
loaded = True
|
| 131 |
+
print("[app] Loaded TF1 checkpoint via load_tf_weights_in_bert(model, ckpt_prefix)")
|
| 132 |
+
except TypeError as e1:
|
| 133 |
+
err_stack.append(str(e1))
|
| 134 |
+
try:
|
| 135 |
+
# Other versions: (model, config, ckpt_prefix)
|
| 136 |
+
_hf_load_tf_ckpt(model, cfg, CKPT_PREFIX)
|
| 137 |
+
loaded = True
|
| 138 |
+
print("[app] Loaded TF1 checkpoint via load_tf_weights_in_bert(model, config, ckpt_prefix)")
|
| 139 |
+
except Exception as e2:
|
| 140 |
+
err_stack.append(str(e2))
|
| 141 |
+
|
| 142 |
+
if not loaded:
|
| 143 |
+
raise RuntimeError(
|
| 144 |
+
"Could not load TF1 checkpoint with transformers' loader. "
|
| 145 |
+
f"ckpt={CKPT_PREFIX}\nErrors: {err_stack or 'no loader available'}"
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# ---------- API ----------
|
| 149 |
+
app = FastAPI(title="BioBERT-TF Embeddings API", version="1.1")
|
| 150 |
|
|
|
|
| 151 |
app.add_middleware(
|
| 152 |
CORSMiddleware,
|
| 153 |
+
allow_origins=["*"],
|
| 154 |
allow_credentials=False,
|
| 155 |
allow_methods=["GET", "POST", "OPTIONS"],
|
| 156 |
allow_headers=["*"],
|
|
|
|
| 171 |
def _embed(texts: List[str], max_len: int) -> List[List[float]]:
|
| 172 |
enc = tok(texts, return_tensors="tf", truncation=True, padding=True, max_length=max_len)
|
| 173 |
out = model(**enc, training=False)
|
|
|
|
| 174 |
if hasattr(out, "pooler_output") and out.pooler_output is not None:
|
| 175 |
vecs = out.pooler_output.numpy()
|
| 176 |
else:
|