felixbet commited on
Commit
4be4ef1
·
verified ·
1 Parent(s): 0835264

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -5,22 +5,20 @@ from typing import Any, List
5
  from transformers import BertTokenizer, BertConfig, TFBertModel
6
 
7
  MODEL_DIR = os.environ.get("MODEL_DIR", "/app/bert_tf")
8
- PORT = int(os.environ.get("PORT", "7860")) # HF Spaces sets PORT=7860
9
 
10
- # --- Load BioBERT (TF checkpoint converted by HF loader) ---
11
  tok = BertTokenizer(vocab_file=f"{MODEL_DIR}/vocab.txt", do_lower_case=True)
12
  cfg = BertConfig.from_json_file(f"{MODEL_DIR}/config.json")
13
  model= TFBertModel.from_pretrained(MODEL_DIR, from_tf=True, config=cfg)
14
 
15
  def encode(texts: List[str]):
16
  ins = tok(texts, padding=True, truncation=True, return_tensors="tf", max_length=512)
17
- outs = model(ins)[0] # [batch, seq, hidden]
18
  mask = tf.cast(tf.expand_dims(ins["attention_mask"], -1), tf.float32)
19
- mean = tf.reduce_sum(outs * mask, axis=1) / tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)
20
  return tf.linalg.l2_normalize(mean, axis=1).numpy().tolist()
21
 
22
- # Warmup to reduce first request latency (still a cold-start, but faster)
23
- _ = encode(["warmup biobert embeddings"])
24
 
25
  app = FastAPI()
26
 
@@ -35,8 +33,5 @@ def health():
35
  def embeddings(req: EmbReq):
36
  texts = req.input if isinstance(req.input, list) else [req.input]
37
  vecs = encode(texts)
38
- return {
39
- "object": "list",
40
- "model" : "biobert-tf-emb",
41
- "data" : [{"object":"embedding","index":i,"embedding":v} for i, v in enumerate(vecs)]
42
- }
 
5
  from transformers import BertTokenizer, BertConfig, TFBertModel
6
 
7
  MODEL_DIR = os.environ.get("MODEL_DIR", "/app/bert_tf")
8
+ PORT = int(os.environ.get("PORT", "7860"))
9
 
 
10
  tok = BertTokenizer(vocab_file=f"{MODEL_DIR}/vocab.txt", do_lower_case=True)
11
  cfg = BertConfig.from_json_file(f"{MODEL_DIR}/config.json")
12
  model= TFBertModel.from_pretrained(MODEL_DIR, from_tf=True, config=cfg)
13
 
14
  def encode(texts: List[str]):
15
  ins = tok(texts, padding=True, truncation=True, return_tensors="tf", max_length=512)
16
+ outs = model(ins)[0]
17
  mask = tf.cast(tf.expand_dims(ins["attention_mask"], -1), tf.float32)
18
+ mean = tf.reduce_sum(outs*mask, axis=1) / tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)
19
  return tf.linalg.l2_normalize(mean, axis=1).numpy().tolist()
20
 
21
+ _ = encode(["warmup"])
 
22
 
23
  app = FastAPI()
24
 
 
33
  def embeddings(req: EmbReq):
34
  texts = req.input if isinstance(req.input, list) else [req.input]
35
  vecs = encode(texts)
36
+ return {"object":"list","model":"biobert-tf-emb",
37
+ "data":[{"object":"embedding","index":i,"embedding":v} for i,v in enumerate(vecs)]}