felixbet commited on
Commit
8a7967b
·
verified ·
1 Parent(s): a2d010b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -26
app.py CHANGED
@@ -4,34 +4,16 @@ from pydantic import BaseModel
4
  from typing import Any, List
5
  from transformers import BertTokenizer, BertConfig, TFBertModel
6
 
 
7
  MODEL_DIR = os.environ.get("MODEL_DIR", "/app/bert_tf")
8
- PORT = int(os.environ.get("PORT", "7860"))
 
 
 
 
 
 
9
 
10
  tok = BertTokenizer(vocab_file=f"{MODEL_DIR}/vocab.txt", do_lower_case=True)
11
  cfg = BertConfig.from_json_file(f"{MODEL_DIR}/config.json")
12
  model= TFBertModel.from_pretrained(MODEL_DIR, from_tf=True, config=cfg)
13
-
14
- def encode(texts: List[str]):
15
- ins = tok(texts, padding=True, truncation=True, return_tensors="tf", max_length=512)
16
- outs = model(ins)[0]
17
- mask = tf.cast(tf.expand_dims(ins["attention_mask"], -1), tf.float32)
18
- mean = tf.reduce_sum(outs*mask, axis=1) / tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)
19
- return tf.linalg.l2_normalize(mean, axis=1).numpy().tolist()
20
-
21
- _ = encode(["warmup"])
22
-
23
- app = FastAPI()
24
-
25
- class EmbReq(BaseModel):
26
- input: Any
27
-
28
- @app.get("/health")
29
- def health():
30
- return {"ok": True}
31
-
32
- @app.post("/v1/embeddings")
33
- def embeddings(req: EmbReq):
34
- texts = req.input if isinstance(req.input, list) else [req.input]
35
- vecs = encode(texts)
36
- return {"object":"list","model":"biobert-tf-emb",
37
- "data":[{"object":"embedding","index":i,"embedding":v} for i,v in enumerate(vecs)]}
 
4
  from typing import Any, List
5
  from transformers import BertTokenizer, BertConfig, TFBertModel
6
 
7
+ # start.sh exports MODEL_DIR after normalization
8
  MODEL_DIR = os.environ.get("MODEL_DIR", "/app/bert_tf")
9
+
10
+ # Fallback: if still wrong, probe one level deeper
11
+ if not os.path.isfile(os.path.join(MODEL_DIR, "vocab.txt")):
12
+ for d in [MODEL_DIR] + [os.path.join(MODEL_DIR, x) for x in os.listdir(MODEL_DIR) if os.path.isdir(os.path.join(MODEL_DIR, x))]:
13
+ if os.path.isfile(os.path.join(d, "vocab.txt")) and os.path.isfile(os.path.join(d, "config.json")):
14
+ MODEL_DIR = d
15
+ break
16
 
17
  tok = BertTokenizer(vocab_file=f"{MODEL_DIR}/vocab.txt", do_lower_case=True)
18
  cfg = BertConfig.from_json_file(f"{MODEL_DIR}/config.json")
19
  model= TFBertModel.from_pretrained(MODEL_DIR, from_tf=True, config=cfg)