g3eIL's picture
Upload 80 files
77320e4 verified
import asyncio
import uuid
import uvloop
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from starlette.responses import JSONResponse, Response
from .activity_helpers import DONE
from .complete_chat import complete_chat_router
from .predict import predict_router
try:
import infiagent
from infiagent.schemas import FailedResponseBaseData
from infiagent.utils import get_logger, init_logging, log_id_var
except ImportError:
print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
from ..schemas import FailedResponseBaseData
from ..utils import get_logger, init_logging, log_id_var
# Install uvloop's event loop policy so event loops created from here on are
# uvloop loops.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Request paths whose failures are answered with a server-sent event instead
# of a JSON body (checked by the exception handlers in this module).
SSE_API_PATHS = ["/complete_sse"]
# Header used to read the log id from the request and to return it on the
# response.
LOG_ID_HEADER_NAME = "X-Tt-Logid"

# Load environment variables from a .env file before logging is initialised.
# NOTE(review): presumably init_logging() reads some of these variables --
# confirm before reordering these calls.
load_dotenv()
init_logging()
logger = get_logger()

app = FastAPI()

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended outside of development.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the routers imported from the sibling modules.
app.include_router(complete_chat_router)
app.include_router(predict_router)
@app.middleware("http")
async def log_id_middleware(request: Request, call_next):
    """Bind a log id to the current request and echo it on the response.

    The id is read from the ``LOG_ID_HEADER_NAME`` request header. When that
    header is missing or empty, a random UUID4 string is generated instead.
    The id is stored in ``log_id_var`` before the downstream handler runs,
    and the value held by ``log_id_var`` afterwards is written to the same
    header on the outgoing response.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request down the stack and
            returns the response.

    Returns:
        The downstream response with the log id header set.
    """
    incoming_id = request.headers.get(LOG_ID_HEADER_NAME)
    # Fall back to a freshly generated id when the caller did not supply one.
    log_id_var.set(incoming_id if incoming_id else str(uuid.uuid4()))

    response: Response = await call_next(request)
    response.headers[LOG_ID_HEADER_NAME] = log_id_var.get()
    return response
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Turn any otherwise-unhandled exception into an error response.

    The failure is logged, then the response shape depends on the request
    path:

    * paths listed in ``SSE_API_PATHS`` get an event stream that carries a
      single ``DONE`` event;
    * every other path gets an HTTP 500 JSON body containing the error
      message and a serialised ``FailedResponseBaseData``.

    Args:
        request: The request whose handling raised.
        exc: The exception that was raised.

    Returns:
        An ``EventSourceResponse`` for SSE paths, otherwise a
        ``JSONResponse`` with status code 500.
    """
    error_msg = "Failed to handle request. Internal Server error: {}".format(str(exc))
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        # EventSourceResponse consumes an iterable of events; a bare
        # ServerSentEvent is not iterable, so it is wrapped in an iterator to
        # make the stream actually deliver the DONE event.
        return EventSourceResponse(iter([ServerSentEvent(data=DONE)]))

    return JSONResponse(
        status_code=500,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict(),
        },
    )
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Turn an ``HTTPException`` into an error response.

    The failure is logged, then the response shape depends on the request
    path:

    * paths listed in ``SSE_API_PATHS`` get an event stream that carries a
      single ``DONE`` event;
    * every other path gets a JSON body with the exception's status code,
      the error message and a serialised ``FailedResponseBaseData``. Any
      headers attached to the exception are forwarded on that response.

    NOTE(review): this is registered for the ``HTTPException`` class imported
    from ``fastapi``; exceptions raised as Starlette's own ``HTTPException``
    may not be routed here -- confirm whether that is intended.

    Args:
        request: The request whose handling raised.
        exc: The raised ``HTTPException``.

    Returns:
        An ``EventSourceResponse`` for SSE paths, otherwise a
        ``JSONResponse`` carrying ``exc.status_code``.
    """
    error_msg = "Failed to handle request. Error: {}".format(exc.detail)
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        # EventSourceResponse consumes an iterable of events; a bare
        # ServerSentEvent is not iterable, so it is wrapped in an iterator to
        # make the stream actually deliver the DONE event.
        return EventSourceResponse(iter([ServerSentEvent(data=DONE)]))

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict(),
        },
        # Keep headers such as WWW-Authenticate that the raiser attached.
        headers=getattr(exc, "headers", None),
    )