# ml-server / src / main.py
# Added project files from private repo (commit ae2ef1b, author Partha11).
import os
from fastapi import Body, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing_extensions import Annotated
from sentence_transformers import SentenceTransformer
from .processors.dialog.dialog_management import process_input_bn, reset_chatbot
from .processors.downloader.model_downloader import download
from .models.request import Request
from .utils.identifier import Identifier
from .server.routers import users
from .utils.constants import Constants
from definitions import ROOT_DIR
# @asynccontextmanager
# async def lifespan(app: FastAPI):
# try:
# model = SentenceTransformer(Constants.MODEL_PATH)
# except:
# download()
# model = SentenceTransformer(Constants.MODEL_PATH)
# finally:
# if model:
# yield
# reset_chatbot()
# else:
# # Something needs to be done here
# pass
# api = FastAPI(lifespan=lifespan)
# api.include_router(users)
# FastAPI application instance; the embedding model is loaded eagerly at
# import time below (the lifespan-based variant is kept commented out above).
api = FastAPI()

# Origins permitted to make cross-site requests to this API.
allowed_origins = [
    "http://localhost.tiangolo.com",
    "https://localhost.tiangolo.com",
    "http://localhost",
    "http://localhost:3000",
]

# Let the local front-end (e.g. port 3000) call the API with credentials,
# any HTTP method, and any request headers.
api.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the sentence-embedding model eagerly at import time.  If loading
# fails (e.g. the model files are not present locally), download them and
# retry once.
try:
    model = SentenceTransformer(Constants.MODEL_PATH)
except Exception:
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit and
    # make the process hard to stop; catch Exception instead.
    download()
    model = SentenceTransformer(Constants.MODEL_PATH)
finally:
    # Reset conversational state regardless of how model loading went.
    reset_chatbot()
@api.get('/')
async def root():
    """Health-check endpoint confirming that the server is up."""
    payload = {'data': 'Hello World!'}
    return payload
@api.post('/query')
async def query(data: Annotated[Request, Body(embed=True)]):
    """Answer a chatbot query.

    The request body is expected as ``{"query": {...}}`` (``embed=True``).
    Only Bangla input is handled: text flagged as Bangla by the language
    identifier is routed to the Bangla dialog processor; anything else gets
    a fixed fallback message.
    """
    # `model` is module-level and only *read* here, so the original
    # `global model` declaration was unnecessary and has been removed.
    text = data.query.lstrip()
    identifier = Identifier(text)
    # NOTE(review): `is_bangla` is accessed as an attribute/property —
    # assumed to be a bool rather than a method; confirm in Identifier.
    if identifier.is_bangla:
        result = process_input_bn(text, model)
        return {'data': result}
    return {'data': 'I can\'t understand you!'}