File size: 1,896 Bytes
ae2ef1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from fastapi import Body, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing_extensions import Annotated
from sentence_transformers import SentenceTransformer
from .processors.dialog.dialog_management import process_input_bn, reset_chatbot
from .processors.downloader.model_downloader import download
from .models.request import Request
from .utils.identifier import Identifier
from .server.routers import users
from .utils.constants import Constants
from definitions import ROOT_DIR

# @asynccontextmanager
# async def lifespan(app: FastAPI):
#     try:
#         model = SentenceTransformer(Constants.MODEL_PATH)
#     except:
#         download()
#         model = SentenceTransformer(Constants.MODEL_PATH)
#     finally:
#         if model:
#             yield
#             reset_chatbot()
#         else:
#             # Something needs to be done here
#             pass

# api = FastAPI(lifespan=lifespan)
# api.include_router(users)

api = FastAPI()

# Browser origins allowed to call this API cross-origin (see CORSMiddleware below).
# NOTE(review): the localhost.tiangolo.com entries look like leftovers from the
# FastAPI tutorial — confirm they are still wanted.
# NOTE(review): the `users` router is imported but only mounted in the
# commented-out setup above — confirm whether `api.include_router(users)` is intended.
origins = [f"{scheme}://localhost.tiangolo.com" for scheme in ("http", "https")]
origins += ["http://localhost", "http://localhost:3000"]

# Only the origins listed above are accepted; any method and any header are
# permitted, and credentials (cookies / auth headers) may be sent.
api.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)

def _load_model() -> SentenceTransformer:
    """Load the SentenceTransformer model from ``Constants.MODEL_PATH``.

    If the first load attempt raises, ``download()`` is called and the load is
    retried once; an error from the download or the retry propagates.
    """
    try:
        return SentenceTransformer(Constants.MODEL_PATH)
    except Exception:
        # Catch ``Exception`` rather than using a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit during startup are not mistaken for a
        # missing model and turned into a download attempt.
        # Presumably the failure means the model is not on disk yet — fetch it
        # and try again.
        download()
        return SentenceTransformer(Constants.MODEL_PATH)


# Module-level embedding model shared by the request handlers.
try:
    model = _load_model()
finally:
    # Runs whether or not the model loaded, matching the original startup order.
    reset_chatbot()

@api.get('/')
async def root():
    """Return a fixed greeting wrapped in the ``{'data': ...}`` response shape."""
    payload = {'data': 'Hello World!'}
    return payload

@api.post('/query')
async def query(data: Annotated[Request, Body(embed=True)]):
    """Answer a user query through the Bangla dialog pipeline.

    Because of ``Body(embed=True)``, the JSON body must nest the request model
    under a ``"data"`` key, e.g. ``{"data": {"query": "..."}}``.

    Returns ``{'data': <result of process_input_bn>}`` when the query is
    identified as Bangla. Otherwise returns a fixed "can't understand" message
    in the same shape.
    """
    # Only leading whitespace is removed; trailing whitespace is passed through.
    # (Named `text` to avoid shadowing this handler's own name.)
    text = data.query.lstrip()
    # Assumes `is_bangla` is a boolean attribute/property rather than a method —
    # verify against Identifier.
    if not Identifier(text).is_bangla:
        return {'data': 'I can\'t understand you!'}
    # `model` is the module-level global. It is only read here, so no `global`
    # declaration is needed.
    # NOTE(review): process_input_bn is called synchronously inside an async
    # handler; if it does heavy work with the model, it blocks the event loop.
    # Consider a plain `def` handler or run_in_threadpool once the dialog state
    # is confirmed safe to access from worker threads.
    return {'data': process_input_bn(text, model)}