File size: 2,847 Bytes
77320e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import json

from fastapi import FastAPI, HTTPException, Request
from pydantic import ValidationError
from sse_starlette import EventSourceResponse

from .activity_helpers import (
    async_sse_response_format,
    get_ignore_ping_comment,
    json_response_format,
)


try:
    import infiagent
    from infiagent.schemas import ChatCompleteRequest
    from infiagent.services.complete_local_test import (
        chat_local_event,
        chat_local_event_generator,
    )
    from infiagent.utils import get_logger
except ImportError:
    print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
    from ..schemas import ChatCompleteRequest
    from ..services.complete_local_test import (
        chat_local_event,
        chat_local_event_generator,
    )
    from ..utils import get_logger

# Shared logger from the project's get_logger() helper (imported above, from
# infiagent or the relative-import fallback).
logger = get_logger()
# FastAPI application on which the local test endpoints below are registered.
local_app = FastAPI()


@local_app.post("/local_sse_test")
async def complete_sse(request: Request):
    """Stream a local chat completion back to the client as Server-Sent Events.

    The raw request body is parsed into a ``ChatCompleteRequest``. The events
    produced by ``chat_local_event_generator`` are formatted with
    ``async_sse_response_format`` and returned as an ``EventSourceResponse``.
    The value of ``get_ignore_ping_comment()`` is passed as that response's
    ``ping_message_factory``.

    Args:
        request: the incoming HTTP request; only its body is read.

    Returns:
        An ``EventSourceResponse`` streaming the formatted events.

    Raises:
        HTTPException: status 422 if the body does not validate as a
            ``ChatCompleteRequest``.
    """
    # request.body() yields the raw bytes of the payload.
    raw_body = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(raw_body)
    except ValidationError as e:
        error_msg = "Invalid input chat_request. Error: {}".format(str(e))
        # Malformed input is a client error, not a server failure: report 422,
        # the status FastAPI itself uses for request validation errors.
        raise HTTPException(status_code=422, detail=error_msg) from e
    logger.info("Got chat request: {}".format(chat_request))

    return EventSourceResponse(
        async_sse_response_format(chat_local_event_generator(chat_request)),
        ping_message_factory=get_ignore_ping_comment(),
    )


@local_app.post("/local_json_test")
async def complete_json(request: Request):
    """Run a local chat completion and return its items as a single JSON response.

    The raw request body is parsed into a ``ChatCompleteRequest``. The awaited
    result of ``chat_local_event`` is rendered with ``json_response_format``.

    Args:
        request: the incoming HTTP request; only its body is read.

    Returns:
        Whatever ``json_response_format`` builds from the response items.

    Raises:
        HTTPException: status 422 if the body does not validate as a
            ``ChatCompleteRequest``.
    """
    # request.body() yields the raw bytes of the payload.
    raw_body = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(raw_body)
    except ValidationError as e:
        error_msg = "Invalid input chat_request. Error: {}".format(str(e))
        # Malformed input is a client error, not a server failure: report 422,
        # the status FastAPI itself uses for request validation errors.
        raise HTTPException(status_code=422, detail=error_msg) from e
    logger.info("Got chat request: {}".format(chat_request))

    response_items = await chat_local_event(chat_request)
    return json_response_format(response_items)


@local_app.post("/exception_test")
async def complete_json(request: Request):
    """Stream a local chat completion as Server-Sent Events for ``/exception_test``.

    The raw request body is parsed into a ``ChatCompleteRequest``. The events
    produced by ``chat_local_event_generator`` are formatted with
    ``async_sse_response_format`` and returned as an ``EventSourceResponse``.
    No ``ping_message_factory`` is passed, so the response uses
    ``EventSourceResponse``'s default ping behaviour.

    NOTE(review): this function reuses the name ``complete_json``, so it
    rebinds the module attribute of the ``/local_json_test`` handler defined
    earlier in this file. Both routes are still registered because the
    decorator runs at definition time. Consider renaming it (for example
    ``complete_exception_test``). Despite the route name, it does not call the
    ``exception_test`` generator in this file; confirm that this is intended.

    Args:
        request: the incoming HTTP request; only its body is read.

    Returns:
        An ``EventSourceResponse`` streaming the formatted events.

    Raises:
        HTTPException: status 422 if the body does not validate as a
            ``ChatCompleteRequest``.
    """
    # request.body() yields the raw bytes of the payload.
    raw_body = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(raw_body)
    except ValidationError as e:
        error_msg = "Invalid input chat_request. Error: {}".format(str(e))
        # Malformed input is a client error, not a server failure: report 422,
        # the status FastAPI itself uses for request validation errors.
        raise HTTPException(status_code=422, detail=error_msg) from e
    logger.info("Got chat request: {}".format(chat_request))

    return EventSourceResponse(async_sse_response_format(chat_local_event_generator(chat_request)))


async def exception_test(request: "Request"):
    """Async generator that either raises on demand or yields a single item.

    The request body is decoded as JSON. If it is an object whose
    ``"exception"`` value is truthy, a ``ValueError`` is raised. Otherwise one
    item is yielded.

    NOTE(review): this generator is not decorated as a route in this file.
    Confirm whether it is still used.

    Args:
        request: the incoming request; only its awaited ``body()`` is read.

    Yields:
        ``iter(["Success"])``, which is an iterator object, not the string
        ``"Success"``. NOTE(review): this was possibly meant to be
        ``yield "Success"``; confirm against the consumer.

    Raises:
        ValueError: when an exception is requested. Also raised as
            ``json.JSONDecodeError`` (a ``ValueError`` subclass) when the body
            is not valid JSON.
    """
    raw_body = await request.body()
    payload = json.loads(raw_body)
    # Only a JSON object can carry the "exception" flag. Any other JSON value
    # (list, string, number, null) means "no exception requested", instead of
    # crashing with AttributeError on `.get`.
    exception_type = payload.get("exception") if isinstance(payload, dict) else None

    if exception_type:
        raise ValueError("Error triggered!")
    yield iter(["Success"])