File size: 2,847 Bytes
77320e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import json
from fastapi import FastAPI, HTTPException, Request
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from .activity_helpers import (
async_sse_response_format,
get_ignore_ping_comment,
json_response_format,
)
try:
import infiagent
from infiagent.schemas import ChatCompleteRequest
from infiagent.services.complete_local_test import (
chat_local_event,
chat_local_event_generator,
)
from infiagent.utils import get_logger
except ImportError:
print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
from ..schemas import ChatCompleteRequest
from ..services.complete_local_test import (
chat_local_event,
chat_local_event_generator,
)
from ..utils import get_logger
# Module-level logger obtained from the get_logger() helper imported above.
logger = get_logger()
# FastAPI application on which the local test routes below are registered.
local_app = FastAPI()
@local_app.post("/local_sse_test")
async def complete_sse(request: Request):
    """Stream a local chat completion as an ``EventSourceResponse``.

    The raw request body is parsed into a ``ChatCompleteRequest``. The events
    produced by ``chat_local_event_generator`` are passed through
    ``async_sse_response_format`` and returned as the response stream.

    Args:
        request: Incoming HTTP request whose body holds the serialized
            ``ChatCompleteRequest``.

    Returns:
        An ``EventSourceResponse`` wrapping the formatted event stream, using
        the ping message factory returned by ``get_ignore_ping_comment()``.

    Raises:
        HTTPException: With status code 500 when the body fails validation;
            the original ``ValidationError`` is chained as the cause.
    """
    body_str = await request.body()

    # Only the parse can raise ValidationError, so keep the try body minimal.
    try:
        chat_request = ChatCompleteRequest.parse_raw(body_str)
    except ValidationError as e:
        error_msg = "Invalid input chat_request. Error: {}".format(str(e))
        # NOTE(review): a 4xx status (e.g. 422) would describe a client-side
        # validation failure more accurately; 500 is kept so existing callers
        # see the same status code.
        raise HTTPException(status_code=500, detail=error_msg) from e

    logger.info("Got chat request: {}".format(chat_request))

    return EventSourceResponse(
        async_sse_response_format(chat_local_event_generator(chat_request)),
        ping_message_factory=get_ignore_ping_comment(),
    )
@local_app.post("/local_json_test")
async def complete_json(request: Request):
    """Run a local chat completion and return the formatted result.

    The raw request body is parsed into a ``ChatCompleteRequest``, handed to
    ``chat_local_event``, and the awaited result is passed through
    ``json_response_format``.

    Args:
        request: Incoming HTTP request whose body holds the serialized
            ``ChatCompleteRequest``.

    Returns:
        Whatever ``json_response_format`` produces for the response items.

    Raises:
        HTTPException: With status code 500 when the body fails validation;
            the original ``ValidationError`` is chained as the cause.
    """
    body_str = await request.body()

    # Only the parse can raise ValidationError, so keep the try body minimal.
    try:
        chat_request = ChatCompleteRequest.parse_raw(body_str)
    except ValidationError as e:
        error_msg = "Invalid input chat_request. Error: {}".format(str(e))
        # NOTE(review): a 4xx status (e.g. 422) would describe a client-side
        # validation failure more accurately; 500 is kept so existing callers
        # see the same status code.
        raise HTTPException(status_code=500, detail=error_msg) from e

    logger.info("Got chat request: {}".format(chat_request))

    response_items = await chat_local_event(chat_request)
    return json_response_format(response_items)
# NOTE(review): this function reuses the name ``complete_json`` already given
# to the "/local_json_test" handler, so the module-level name ends up bound to
# this function. Both routes are registered by their decorators, but a distinct
# name would be clearer; it is left unchanged here to keep the interface stable.
@local_app.post("/exception_test")
async def complete_json(request: Request):
    """Stream a local chat completion for the ``/exception_test`` route.

    The raw request body is parsed into a ``ChatCompleteRequest``. The events
    produced by ``chat_local_event_generator`` are passed through
    ``async_sse_response_format`` and returned as an ``EventSourceResponse``
    (no custom ping message factory is supplied on this route).

    Args:
        request: Incoming HTTP request whose body holds the serialized
            ``ChatCompleteRequest``.

    Returns:
        An ``EventSourceResponse`` wrapping the formatted event stream.

    Raises:
        HTTPException: With status code 500 when the body fails validation;
            the original ``ValidationError`` is chained as the cause.
    """
    body_str = await request.body()

    # Only the parse can raise ValidationError, so keep the try body minimal.
    try:
        chat_request = ChatCompleteRequest.parse_raw(body_str)
    except ValidationError as e:
        error_msg = "Invalid input chat_request. Error: {}".format(str(e))
        # NOTE(review): a 4xx status (e.g. 422) would describe a client-side
        # validation failure more accurately; 500 is kept so existing callers
        # see the same status code.
        raise HTTPException(status_code=500, detail=error_msg) from e

    logger.info("Got chat request: {}".format(chat_request))

    # NOTE(review): the ``exception_test`` generator defined in this module is
    # never used by this route; presumably it was meant to feed this response.
    # Confirm the intent before changing the stream source.
    return EventSourceResponse(
        async_sse_response_format(chat_local_event_generator(chat_request))
    )
async def exception_test(request: Request):
    """Async generator that fails or succeeds based on the request body.

    The body is decoded as JSON and its ``"exception"`` entry is inspected.

    Args:
        request: Object whose awaitable ``body()`` returns the JSON payload.

    Yields:
        A single item: an iterator over the one-element list ``["Success"]``.

    Raises:
        ValueError: When the payload's ``"exception"`` entry is truthy. As
            this is a generator, the error surfaces on first iteration.
    """
    payload = json.loads(await request.body())

    # Guard clause: any truthy "exception" entry takes the failure path.
    if payload.get("exception"):
        raise ValueError("Error triggerd!")

    # NOTE(review): this yields the iterator object itself rather than the
    # string "Success" -- confirm that is what consumers expect.
    yield iter(["Success"])
|