"""FastAPI entry point for the Bangla chatbot service.

Importing this module builds the ``api`` application, loads the
SentenceTransformer model from ``Constants.MODEL_PATH`` (downloading it
first if loading fails) and resets the chatbot state.
"""
import logging
import os
from contextlib import asynccontextmanager

from fastapi import Body, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
from typing_extensions import Annotated

from definitions import ROOT_DIR

from .models.request import Request
from .processors.dialog.dialog_management import process_input_bn, reset_chatbot
from .processors.downloader.model_downloader import download
from .server.routers import users
from .utils.constants import Constants
from .utils.identifier import Identifier

logger = logging.getLogger(__name__)

# TODO: move model loading and the chatbot reset into a FastAPI ``lifespan``
# handler (``FastAPI(lifespan=...)``) and register the ``users`` router once
# ready. A commented-out draft of this previously lived here.

api = FastAPI()

# NOTE(review): the localhost.tiangolo.com entries look copied from the
# FastAPI CORS tutorial; confirm whether they are actually needed.
origins = [
    "http://localhost.tiangolo.com",
    "https://localhost.tiangolo.com",
    "http://localhost",
    "http://localhost:3000",
]

api.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _load_model() -> SentenceTransformer:
    """Load the model from ``Constants.MODEL_PATH``, downloading it on failure.

    Raises:
        Whatever ``SentenceTransformer`` or ``download`` raise if the model
        still cannot be loaded after the download attempt.
    """
    try:
        return SentenceTransformer(Constants.MODEL_PATH)
    except Exception:
        # Narrower than a bare ``except:``. KeyboardInterrupt/SystemExit
        # must not be mistaken for "model missing" and trigger a download.
        logger.warning(
            "Could not load model from %s; downloading it",
            Constants.MODEL_PATH,
            exc_info=True,
        )
        download()
        return SentenceTransformer(Constants.MODEL_PATH)


try:
    model = _load_model()
finally:
    # Always reset the chatbot state, even if loading the model failed.
    reset_chatbot()


@api.get('/')
async def root():
    """Return a static greeting."""
    return {'data': 'Hello World!'}


@api.post('/query')
async def query(data: Annotated[Request, Body(embed=True)]):
    """Answer a chat query.

    Because of ``Body(embed=True)``, the JSON body is expected to be
    ``{"data": {...Request fields...}}``.

    Returns:
        ``{'data': <chatbot reply>}`` when the (left-stripped) query is
        identified as Bangla, otherwise ``{'data': "I can't understand you!"}``.
    """
    text = data.query.lstrip()
    if not Identifier(text).is_bangla:
        return {'data': "I can't understand you!"}
    # NOTE(review): process_input_bn is called synchronously inside this
    # async handler, so it blocks the event loop while it runs. If it is slow,
    # consider offloading it, after checking that the chatbot state it uses
    # is safe to touch from worker threads.
    return {'data': process_input_bn(text, model)}