# embed_api / app.py
# Synced repo using 'sync_with_huggingface' GitHub Action (commit 0b80183)
import argparse
import asyncio
import functools
import json
import os
from io import BytesIO
import uvicorn
from fastapi import FastAPI, Body, Request
# from fastapi.responses import StreamingResponse
# from starlette.staticfiles import StaticFiles
# from starlette.templating import Jinja2Templates
from utils.utils import add_arguments, print_arguments
from sentence_transformers import SentenceTransformer, models
# def print_arguments(args):
# print("----------- Configuration Arguments -----------")
# for arg, value in vars(args).items():
# print("%s: %s" % (arg, value))
# print("------------------------------------------------")
# def strtobool(val):
# val = val.lower()
# if val in ('y', 'yes', 't', 'true', 'on', '1'):
# return True
# elif val in ('n', 'no', 'f', 'false', 'off', '0'):
# return False
# else:
# raise ValueError("invalid truth value %r" % (val,))
# def str_none(val):
# if val == 'None':
# return None
# else:
# return val
# def add_arguments(argname, type, default, help, argparser, **kwargs):
# type = strtobool if type == bool else type
# type = str_none if type == str else type
# argparser.add_argument(
# "--" + argname,
# default=default,
# type=type,
# help=help + ' Default: %(default)s.',
# **kwargs
# )
# Allow duplicate OpenMP runtimes to coexist (common workaround on macOS/conda
# when PyTorch and other native libs each bundle libomp).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# ---- CLI arguments -------------------------------------------------------
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("host", type=str, default="0.0.0.0", help="")
add_arg("port", type=int, default=5000, help="")
add_arg("model_path", type=str, default="BAAI/bge-small-en-v1.5", help="")
add_arg("use_gpu", type=bool, default=False, help="")
# add_arg("use_int8", type=bool, default=True, help="")
add_arg("beam_size", type=int, default=10, help="")
add_arg("num_workers", type=int, default=2, help="")
add_arg("vad_filter", type=bool, default=True, help="")
add_arg("local_files_only", type=bool, default=True, help="")
args = parser.parse_args()
print_arguments(args)

# ---- Model loading -------------------------------------------------------
# The GPU and CPU code paths were previously duplicated verbatim; only the
# device string (and one extra kwarg) differed, so select them once here.
device = "cuda" if args.use_gpu else "cpu"
# NOTE(review): `compute_type` is not a documented SentenceTransformer
# parameter (it belongs to CTranslate2 / faster-whisper) and will likely raise
# a TypeError on the GPU path — confirm before enabling GPU. Kept behind the
# same condition as before to preserve the original behavior.
extra_kwargs = {"compute_type": "float16"} if args.use_gpu else {}

# Embedding model for the /bge_embed endpoint.
bge_model = SentenceTransformer(
    args.model_path, device=device, cache_folder=".", **extra_kwargs
)

# TSDAE model for the /tsdae_embed endpoint: transformer encoder + CLS pooling.
model_name = 'sam2ai/sbert-tsdae'
word_embedding_model = models.Transformer(model_name)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')
tsdae_model = SentenceTransformer(
    modules=[word_embedding_model, pooling_model],
    device=device,
    cache_folder=".",
    **extra_kwargs
)
app = FastAPI(title="embedding Inference")
def similarity_score(model, textA, textB):
    """Return the cosine similarity between two texts.

    Encodes both texts in a single batch with L2-normalized embeddings, so
    the dot product of the two vectors equals their cosine similarity.

    Args:
        model: a SentenceTransformer-like object exposing ``encode``.
        textA: first input text.
        textB: second input text.

    Returns:
        A scalar (numpy) similarity score.
    """
    vec_a, vec_b = model.encode(
        [textA, textB],
        normalize_embeddings=True
    )
    return vec_a @ vec_b.T
@app.get("/")
async def index(request: Request):
    """Health-check endpoint: confirms the API process is up."""
    payload = {"detail": "API is Active !!"}
    return payload
@app.post("/bge_embed")
async def api_bge_embed(
    text1: str = Body("text1", description="", embed=True),
    text2: str = Body("text2", description="", embed=True),
):
    """Score the similarity of ``text1`` and ``text2`` with the BGE model.

    Returns a JSON object with the similarity score (as a plain Python
    number) and a status code.
    """
    score = similarity_score(bge_model, text1, text2)
    print(score)  # log the raw numpy scalar before JSON conversion
    return {"similarity score": score.tolist(), "status_code": 200}
@app.post("/tsdae_embed")
async def api_tsdae_embed(
    text1: str = Body("text1", description="", embed=True),
    text2: str = Body("text2", description="", embed=True),
):
    """Score the similarity of ``text1`` and ``text2`` with the TSDAE model.

    Returns a JSON object with the similarity score (as a plain Python
    number) and a status code.
    """
    score = similarity_score(tsdae_model, text1, text2)
    print(score)  # log the raw numpy scalar before JSON conversion
    return {"similarity score": score.tolist(), "status_code": 200}
# Script entry point: serve the FastAPI app with uvicorn on the configured
# host/port (defaults: 0.0.0.0:5000).
if __name__ == '__main__':
    uvicorn.run(app, host=args.host, port=args.port)