biobert-emb / app.py
felixbet's picture
Update app.py
4be4ef1 verified
raw
history blame
1.31 kB
import os, tensorflow as tf
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Any, List
from transformers import BertTokenizer, BertConfig, TFBertModel
# Runtime configuration; both values can be overridden through the environment.
MODEL_DIR = os.environ.get("MODEL_DIR", "/app/bert_tf")
# Not referenced by the code shown here; presumably consumed by the launcher — TODO confirm.
PORT = int(os.environ.get("PORT", "7860"))

# Tokenizer, config and weights are all loaded once, at import time, from MODEL_DIR.
# NOTE(review): do_lower_case=True must match how the vocab was built (BioBERT
# checkpoints are often cased) — confirm against the checkpoint in MODEL_DIR.
tok = BertTokenizer(
    vocab_file=f"{MODEL_DIR}/vocab.txt",
    do_lower_case=True,
)
cfg = BertConfig.from_json_file(f"{MODEL_DIR}/config.json")
# NOTE(review): `from_tf` is usually the PyTorch-side flag for reading TF
# checkpoints; confirm the pinned transformers version accepts it here.
model = TFBertModel.from_pretrained(
    MODEL_DIR,
    config=cfg,
    from_tf=True,
)
def encode(texts: List[str]) -> List[List[float]]:
    """Embed *texts* as L2-normalised, attention-masked means of BERT token states.

    Args:
        texts: Strings to embed. A bare ``str`` is treated as a one-item batch
            (matching what the tokenizer did with it before); an empty
            sequence returns ``[]`` without touching the tokenizer or model.

    Returns:
        One embedding per input, in input order, each a plain list of Python
        floats normalised to unit L2 length.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    # The tokenizer has nothing to pad for an empty batch and raises, so
    # short-circuit rather than turning a trivial request into an error.
    if not texts:
        return []

    # Inputs longer than 512 tokens are truncated; shorter ones are padded
    # to the longest item in the batch.
    ins = tok(
        texts,
        padding=True,
        truncation=True,
        return_tensors="tf",
        max_length=512,
    )
    # First model output: per-token hidden states, (batch, seq, hidden).
    hidden = model(ins)[0]
    # (batch, seq, 1) float mask so padding positions add nothing to the sum.
    mask = tf.cast(tf.expand_dims(ins["attention_mask"], -1), tf.float32)
    summed = tf.reduce_sum(hidden * mask, axis=1)
    # Clamp the real-token count at 1 to avoid division by zero.
    counts = tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)
    return tf.linalg.l2_normalize(summed / counts, axis=1).numpy().tolist()
app = FastAPI()

# Push one throwaway string through the model at import time, presumably so
# one-off start-up cost is paid before the first real request. Result discarded.
_ = encode(["warmup"])
class EmbReq(BaseModel):
    """Request body for ``POST /v1/embeddings``.

    Only ``input`` is declared. The name presumably mirrors the OpenAI
    embeddings request field; other request keys (e.g. ``model``) are not
    modelled here.
    """

    # Payload to embed. Typed as Any, so pydantic performs no type/shape
    # validation on it. NOTE(review): under pydantic v1 an ``Any`` field with
    # no default is optional (defaults to None); under v2 it is required.
    input: Any
@app.get("/health")
def health():
    """Liveness check that always answers ``{"ok": True}``.

    The model is loaded at module import, before ``app`` exists, so any
    successful response also implies loading finished.
    """
    return dict(ok=True)
@app.post("/v1/embeddings")
def embeddings(req: EmbReq):
    """Embed ``req.input`` and return an OpenAI-style list response.

    ``req.input`` may be a single string or a list of strings. Anything else
    (token-id arrays, nested lists, numbers, null) is rejected with HTTP 422
    instead of failing inside the tokenizer with an HTTP 500.

    Returns:
        ``{"object": "list", "model": "biobert-tf-emb", "data": [...]}`` with
        one ``{"object": "embedding", "index": i, "embedding": [...]}`` entry
        per input, in input order.
    """
    from fastapi import HTTPException  # local import keeps this block self-contained

    raw = req.input
    texts = raw if isinstance(raw, list) else [raw]
    if not all(isinstance(t, str) for t in texts):
        raise HTTPException(
            status_code=422,
            detail="input must be a string or a list of strings",
        )
    # An empty list is answered directly; there is nothing to run the model on.
    vecs = encode(texts) if texts else []
    data = [
        {"object": "embedding", "index": i, "embedding": v}
        for i, v in enumerate(vecs)
    ]
    return {"object": "list", "model": "biobert-tf-emb", "data": data}