|
|
import asyncio |
|
|
import uuid |
|
|
|
|
|
import uvloop |
|
|
from dotenv import load_dotenv |
|
|
from fastapi import FastAPI, HTTPException, Request |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from sse_starlette.sse import EventSourceResponse, ServerSentEvent |
|
|
from starlette.responses import JSONResponse, Response |
|
|
|
|
|
from .activity_helpers import DONE |
|
|
from .complete_chat import complete_chat_router |
|
|
from .predict import predict_router |
|
|
|
|
|
try: |
|
|
import infiagent |
|
|
from infiagent.schemas import FailedResponseBaseData |
|
|
from infiagent.utils import get_logger, init_logging, log_id_var |
|
|
except ImportError: |
|
|
print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent") |
|
|
from ..schemas import FailedResponseBaseData |
|
|
from ..utils import get_logger, init_logging, log_id_var |
|
|
|
|
|
# Install uvloop's event-loop policy so asyncio loops obtained through the
# policy use uvloop. NOTE(review): this only affects loops created after this
# line runs — confirm the server creates its loop after importing this module.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Request paths that are served as Server-Sent Events. The exception handlers
# below answer errors on these paths with an event stream instead of JSON.
SSE_API_PATHS = ["/complete_sse"]
# Header used both to read an incoming request's log id and to echo it back
# on the response (see log_id_middleware).
LOG_ID_HEADER_NAME = "X-Tt-Logid"

# Load a .env file into the process environment before logging is set up
# (presumably so init_logging can pick up settings from it — verify).
load_dotenv()
init_logging()
logger = get_logger()

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True. In this
# configuration Starlette's CORSMiddleware reflects the caller's Origin for
# credentialed requests, effectively allowing any site to send credentialed
# requests. Confirm this is intended before exposing the service publicly.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register the routers imported from .complete_chat and .predict.
app.include_router(complete_chat_router)
app.include_router(predict_router)
|
|
|
|
|
|
|
|
@app.middleware("http")
async def log_id_middleware(request: Request, call_next):
    """Bind a per-request log id and echo it on the response.

    The id comes from the ``LOG_ID_HEADER_NAME`` request header. When that
    header is absent or empty, a fresh UUID4 string is generated instead.
    The id is stored in ``log_id_var`` before the downstream app runs.
    The value read back from ``log_id_var`` is then copied onto the response
    under the same header name.

    If ``call_next`` raises, the exception propagates and no header is added
    here.
    """
    incoming_id = request.headers.get(LOG_ID_HEADER_NAME)
    # Fall back to a generated id only when the client did not supply one.
    log_id_var.set(incoming_id or str(uuid.uuid4()))

    downstream: Response = await call_next(request)
    downstream.headers[LOG_ID_HEADER_NAME] = log_id_var.get()
    return downstream
|
|
|
|
|
|
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Convert an otherwise-unhandled exception into an error response.

    The error message is logged with traceback information.
    - Paths listed in ``SSE_API_PATHS`` get an event stream containing a
      single ``DONE`` event (the sentinel imported from ``activity_helpers``,
      presumably what SSE clients wait for to end the stream).
    - All other paths get a 500 JSON body with the error message and a failed
      ``ResponseBase`` payload.

    Args:
        request: The request whose handling raised ``exc``.
        exc: The exception that was raised.

    Returns:
        An ``EventSourceResponse`` for SSE paths, otherwise a ``JSONResponse``.
    """
    error_msg = "Failed to handle request. Internal Server error: {}".format(str(exc))
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        # EventSourceResponse streams its content by iterating it. A bare
        # ServerSentEvent is not iterable and would fail with TypeError once
        # streaming starts, so DONE would never reach the client. Wrap it in
        # a one-item *iterator* so this works whether the library calls
        # iter() or next() on the content.
        return EventSourceResponse(iter([ServerSentEvent(data=DONE)]))

    # NOTE(review): str(exc) is sent back to the client, which may expose
    # internal details; consider returning a generic message instead.
    return JSONResponse(
        status_code=500,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict(),
        },
    )
|
|
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Convert an ``HTTPException`` into this service's error response format.

    The error detail is logged.
    - Paths listed in ``SSE_API_PATHS`` get an event stream containing a
      single ``DONE`` event.
    - All other paths get a JSON body with the detail and a failed
      ``ResponseBase`` payload. That response uses the exception's status code
      and any headers attached to the exception.

    NOTE(review): this is registered for ``fastapi.HTTPException``. Exceptions
    raised as Starlette's own ``HTTPException`` (for example the 404 for an
    unmatched route) are not subclasses of it, so they are presumably handled
    by FastAPI's default handler instead. Confirm whether that is intended.

    Args:
        request: The request whose handling raised ``exc``.
        exc: The raised ``HTTPException``.

    Returns:
        An ``EventSourceResponse`` for SSE paths, otherwise a ``JSONResponse``.
    """
    error_msg = "Failed to handle request. Error: {}".format(exc.detail)
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        # A bare ServerSentEvent is not iterable, so EventSourceResponse would
        # fail when it starts streaming. Wrap the event in a one-item iterator.
        return EventSourceResponse(iter([ServerSentEvent(data=DONE)]))

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict(),
        },
        # Forward headers attached to the exception (e.g. WWW-Authenticate on
        # 401, Allow on 405). getattr keeps this safe if the attribute is absent.
        headers=getattr(exc, "headers", None),
    )
|
|
|