|
|
import asyncio |
|
|
|
|
|
from ocr.api.message import message_router |
|
|
from ocr.api.message.db_requests import get_all_chat_messages_obj, save_assistant_user_message |
|
|
from ocr.api.message.models import MessageModel |
|
|
from ocr.api.message.schemas import AllMessageWrapper, AllMessageResponse, CreateMessageRequest |
|
|
from ocr.api.openai_requests import generate_agent_response |
|
|
from ocr.api.report.db_requests import get_report_obj_by_id |
|
|
from ocr.api.report.dto import Paging |
|
|
from ocr.api.utils import transform_messages_to_openai |
|
|
from ocr.core.wrappers import OcrResponseWrapper |
|
|
|
|
|
|
|
|
@message_router.get('/{reportId}/all')
async def get_all_chat_messages(reportId: str) -> AllMessageWrapper:
    """Return all chat messages stored for a report as one page.

    The messages are loaded with ``get_all_chat_messages_obj`` and returned
    together with paging metadata describing a single page that holds the
    whole result: ``pageIndex`` is always 0, and both ``pageSize`` and
    ``totalCount`` equal the number of messages loaded.

    Args:
        reportId: Identifier taken from the URL path and passed to the
            message lookup.

    Returns:
        An ``AllMessageWrapper`` whose ``data`` is the ``AllMessageResponse``
        built from the loaded messages and the paging metadata.
    """
    chat_messages = await get_all_chat_messages_obj(reportId)
    message_count = len(chat_messages)

    # Everything is returned at once, so the page size is the total count.
    single_page = Paging(
        pageSize=message_count,
        pageIndex=0,
        totalCount=message_count,
    )
    payload = AllMessageResponse(paging=single_page, data=chat_messages)
    return AllMessageWrapper(data=payload)
|
|
|
|
|
|
|
|
@message_router.post('/{reportId}')
async def create_message(
        reportId: str,
        message_data: CreateMessageRequest,
) -> OcrResponseWrapper[MessageModel]:
    """Send a user message to the agent and store the exchange.

    Steps, in order:
      1. Load the report's existing chat messages and the report itself
         concurrently.
      2. Build the message history from the existing messages plus the new
         user text via ``transform_messages_to_openai``.
      3. Ask ``generate_agent_response`` for a reply, given that history and
         the report.
      4. Pass the user text, the reply and the report id to
         ``save_assistant_user_message`` and return its result wrapped.

    Args:
        reportId: Identifier taken from the URL path; used for both lookups
            and when saving.
        message_data: Request body; only its ``text`` field is read here.

    Returns:
        An ``OcrResponseWrapper`` whose ``data`` is whatever
        ``save_assistant_user_message`` returned (annotated as a
        ``MessageModel``).
    """
    user_text = message_data.text

    # The two lookups do not depend on each other, so run them together.
    fetched = await asyncio.gather(
        get_all_chat_messages_obj(reportId),
        get_report_obj_by_id(reportId),
    )
    previous_messages, report_obj = fetched

    openai_history = transform_messages_to_openai(previous_messages, user_text)
    agent_reply = await generate_agent_response(openai_history, report_obj)
    saved_message = await save_assistant_user_message(user_text, agent_reply, reportId)
    return OcrResponseWrapper(data=saved_message)
|
|
|