"""Local FastAPI test endpoints for the chat-completion pipeline.

Exposes three POST routes on ``local_app``:

* ``/local_sse_test``  - streams chat events as server-sent events.
* ``/local_json_test`` - returns the chat events as a single JSON response.
* ``/exception_test``  - streams chat events as SSE without a custom ping
  message factory.
"""
import json

from fastapi import FastAPI, HTTPException, Request
from pydantic import ValidationError
from sse_starlette import EventSourceResponse

from .activity_helpers import (
    async_sse_response_format,
    get_ignore_ping_comment,
    json_response_format,
)

# Prefer the installed ``infiagent`` package; fall back to package-relative
# imports when it is not installed.
try:
    import infiagent
    from infiagent.schemas import ChatCompleteRequest
    from infiagent.services.complete_local_test import (
        chat_local_event,
        chat_local_event_generator,
    )
    from infiagent.utils import get_logger
except ImportError:
    print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
    from ..schemas import ChatCompleteRequest
    from ..services.complete_local_test import (
        chat_local_event,
        chat_local_event_generator,
    )
    from ..utils import get_logger

logger = get_logger()
local_app = FastAPI()


def _parse_chat_request(body_str: bytes) -> ChatCompleteRequest:
    """Parse a raw request body into a ``ChatCompleteRequest``.

    Shared by every route in this module so that parsing, logging and error
    reporting stay identical across endpoints.

    Args:
        body_str: Raw request body as returned by ``Request.body()``.

    Returns:
        The parsed chat request.

    Raises:
        HTTPException: With status 500 when the body fails validation. The
            original ``ValidationError`` is chained as the cause.
    """
    try:
        chat_request = ChatCompleteRequest.parse_raw(body_str)
    except ValidationError as exc:
        error_msg = f"Invalid input chat_request. Error: {exc}"
        # NOTE(review): a malformed body is a client error, so 400/422 would
        # be more conventional; kept at 500 so existing clients see no change.
        raise HTTPException(status_code=500, detail=error_msg) from exc

    logger.info("Got chat request: {}".format(chat_request))
    return chat_request


@local_app.post("/local_sse_test")
async def complete_sse(request: Request):
    """Stream the chat events for the posted request as server-sent events.

    The ping message factory comes from ``get_ignore_ping_comment()``.

    Raises:
        HTTPException: With status 500 when the body fails validation.
    """
    chat_request = _parse_chat_request(await request.body())
    event_stream = async_sse_response_format(chat_local_event_generator(chat_request))
    return EventSourceResponse(
        event_stream,
        ping_message_factory=get_ignore_ping_comment(),
    )


@local_app.post("/local_json_test")
async def complete_json(request: Request):
    """Run the chat request to completion and return the formatted JSON result.

    Raises:
        HTTPException: With status 500 when the body fails validation.
    """
    chat_request = _parse_chat_request(await request.body())
    response_items = await chat_local_event(chat_request)
    return json_response_format(response_items)


# Previously this handler was also named ``complete_json`` and silently
# redefined the JSON handler's module-level name; the route is unchanged.
@local_app.post("/exception_test")
async def complete_sse_exception_test(request: Request):
    """Stream the chat events as SSE using the default ping behaviour.

    Same flow as ``complete_sse`` but without a custom
    ``ping_message_factory``.

    Raises:
        HTTPException: With status 500 when the body fails validation.
    """
    chat_request = _parse_chat_request(await request.body())
    return EventSourceResponse(
        async_sse_response_format(chat_local_event_generator(chat_request))
    )


async def exception_test(request: Request):
    """Async generator that raises or yields depending on the posted JSON body.

    The body is decoded with ``json.loads``; a truthy ``"exception"`` entry
    makes the generator raise ``ValueError``, otherwise it yields one item.
    Because this is an async generator, nothing runs until it is iterated.

    NOTE(review): this function is not attached to any route in this file, and
    the yielded item is an iterator object over ``["Success"]`` rather than
    the string itself - confirm the intended wiring and payload.

    Raises:
        ValueError: When the body's ``"exception"`` value is truthy.
    """
    body_str = await request.body()
    json_val = json.loads(body_str)
    exception_type = json_val.get("exception", None)
    if exception_type:
        raise ValueError("Error triggerd!")
    yield iter(["Success"])