import asyncio
import uuid

import uvloop
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from starlette.responses import JSONResponse, Response

from .activity_helpers import DONE
from .complete_chat import complete_chat_router
from .predict import predict_router

try:
    import infiagent
    from infiagent.schemas import FailedResponseBaseData
    from infiagent.utils import get_logger, init_logging, log_id_var
except ImportError:
    print(
        "Failed to import infiagent; falling back to relative imports. "
        "To install it, run 'pip install .' in the pipeline directory of ADA-Agent."
    )
    from ..schemas import FailedResponseBaseData
    from ..utils import get_logger, init_logging, log_id_var

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

SSE_API_PATHS = ["/complete_sse"]
LOG_ID_HEADER_NAME = "X-Tt-Logid"


load_dotenv()
init_logging()
logger = get_logger()

app = FastAPI()
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(complete_chat_router)
app.include_router(predict_router)


@app.middleware("http")
async def log_id_middleware(request: Request, call_next):
    # Get X-Tt-Logid from request headers
    log_id = request.headers.get(LOG_ID_HEADER_NAME)
    if not log_id:
        # Generate a log_id if not present in headers
        log_id = str(uuid.uuid4())

    log_id_var.set(log_id)

    response: Response = await call_next(request)
    response.headers[LOG_ID_HEADER_NAME] = log_id_var.get()
    return response
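

# Illustrative sketch (hypothetical helper, NOT called by the app): the
# header-or-UUID fallback that log_id_middleware performs, pulled out as a pure
# function so it could be unit-tested without a running server. The name
# _resolve_log_id, its signature, and the literal default header name (which
# mirrors LOG_ID_HEADER_NAME) are assumptions, not part of infiagent.
# Note: Starlette's request.headers is case-insensitive; a plain dict is not.
import uuid
from typing import Mapping, Optional


def _resolve_log_id(headers: Mapping[str, str], header_name: str = "X-Tt-Logid") -> str:
    """Return the caller-supplied log id, or a fresh UUID4 string if it is missing or empty."""
    log_id: Optional[str] = headers.get(header_name)
    # An empty header value is treated like a missing one, matching the
    # `if not log_id` check in the middleware above.
    return log_id if log_id else str(uuid.uuid4())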


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log any unhandled error and return a 500 JSON body, or a DONE event on SSE paths."""
    error_msg = f"Failed to handle request. Internal server error: {exc}"
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        # EventSourceResponse expects an iterable of events, so wrap the single DONE event in a list.
        return EventSourceResponse([ServerSentEvent(data=DONE)])
    else:
        return JSONResponse(
            status_code=500,
            content={
                "response": error_msg,
                "ResponseBase": FailedResponseBaseData().dict()
            }
        )
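

# Illustrative sketch (hypothetical helper, NOT called by the app): on paths in
# SSE_API_PATHS the handlers reply with a single server-sent event carrying
# DONE (defined in activity_helpers, value not shown here) instead of a JSON
# error, so streaming clients see an ordinary end-of-stream marker. In the SSE
# wire format, a data-only event is each payload line prefixed with "data: ",
# followed by a blank line. The helper name and the "\r\n" default separator
# (assumed to match sse_starlette's default; check your installed version) are
# assumptions.
def _format_sse_data_event(data: str, sep: str = "\r\n") -> str:
    """Encode `data` as a data-only SSE event terminated by a blank line."""
    lines = data.splitlines() or [""]
    return "".join(f"data: {line}{sep}" for line in lines) + sep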


@app.exception_handler(HTTPException)
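async def http_exception_handler(request: Request, exc: HTTPException):
    """Log an HTTPException and return its status code as JSON, or a DONE event on SSE paths."""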
    error_msg = "Failed to handle request. Error: {}".format(exc.detail)
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        # EventSourceResponse expects an iterable of events, so wrap the single DONE event in a list.
        return EventSourceResponse([ServerSentEvent(data=DONE)])
    else:
        return JSONResponse(
            status_code=exc.status_code,
            content={
                "response": error_msg,
                "ResponseBase": FailedResponseBaseData().dict()
            }
        )
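

# Illustrative sketch (hypothetical helper, NOT called by the app): a simplified
# model of how the two handlers above coexist. Starlette picks a handler by
# walking the exception's MRO and using the first class with a registered
# handler, so an HTTPException reaches http_exception_handler while any other
# error falls back to the Exception handler. (Starlette actually routes the
# Exception handler through its ServerErrorMiddleware, but the observable
# precedence is the same.) Because the lookup is by MRO, a Starlette
# HTTPException raised by the framework itself, e.g. for an unmatched route,
# does not match a handler registered for fastapi.HTTPException, its subclass.
# The name _lookup_handler is an assumption, not a Starlette API.
from typing import Callable, Dict, Optional, Type


def _lookup_handler(
    exc: BaseException, handlers: Dict[Type[BaseException], Callable]
) -> Optional[Callable]:
    """Return the handler for the most specific class in type(exc).__mro__, or None."""
    for cls in type(exc).__mro__:
        if cls in handlers:
            return handlers[cls]
    return None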