Spaces:
Sleeping
Sleeping
Add entity handling to message saving and AI response logic
Browse files
trauma/api/message/ai/engine.py
CHANGED
|
@@ -31,16 +31,10 @@ from trauma.core.config import settings
|
|
| 31 |
async def search_entities(
|
| 32 |
user_message: str, messages: list[dict], chat: ChatModel
|
| 33 |
) -> CreateMessageResponse:
|
| 34 |
-
|
| 35 |
-
retrieve_semantic_answer(user_message),
|
| 36 |
update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
|
| 37 |
check_is_valid_request(user_message, "\n".join([f"- [{i['role']}]: {i['content']}." for i in messages]))
|
| 38 |
)
|
| 39 |
-
if related_entity:
|
| 40 |
-
response = await generate_searched_entity_response(user_message, related_entity[0])
|
| 41 |
-
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
| 42 |
-
return CreateMessageResponse(text=response, entities=related_entity)
|
| 43 |
-
|
| 44 |
final_entities = None
|
| 45 |
if not is_valid:
|
| 46 |
response = await generate_invalid_response(user_message, messages)
|
|
@@ -60,7 +54,7 @@ async def search_entities(
|
|
| 60 |
final_entities_str = prepare_final_entities_str(final_entities)
|
| 61 |
response = await generate_final_response(final_entities_str, user_message, messages)
|
| 62 |
|
| 63 |
-
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
| 64 |
return CreateMessageResponse(text=response, entities=final_entities)
|
| 65 |
|
| 66 |
|
|
|
|
| 31 |
async def search_entities(
|
| 32 |
user_message: str, messages: list[dict], chat: ChatModel
|
| 33 |
) -> CreateMessageResponse:
|
| 34 |
+
entity_data, is_valid = await asyncio.gather(
|
|
|
|
| 35 |
update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
|
| 36 |
check_is_valid_request(user_message, "\n".join([f"- [{i['role']}]: {i['content']}." for i in messages]))
|
| 37 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
final_entities = None
|
| 39 |
if not is_valid:
|
| 40 |
response = await generate_invalid_response(user_message, messages)
|
|
|
|
| 54 |
final_entities_str = prepare_final_entities_str(final_entities)
|
| 55 |
response = await generate_final_response(final_entities_str, user_message, messages)
|
| 56 |
|
| 57 |
+
asyncio.create_task(save_assistant_user_message(user_message, final_entities, response, chat.id))
|
| 58 |
return CreateMessageResponse(text=response, entities=final_entities)
|
| 59 |
|
| 60 |
|
trauma/api/message/db_requests.py
CHANGED
|
@@ -5,7 +5,7 @@ from fastapi import HTTPException
|
|
| 5 |
from trauma.api.account.dto import AccountType
|
| 6 |
from trauma.api.account.model import AccountModel
|
| 7 |
from trauma.api.chat.model import ChatModel
|
| 8 |
-
from trauma.api.data.model import EntityModel
|
| 9 |
from trauma.api.message.dto import Author
|
| 10 |
from trauma.api.message.model import MessageModel
|
| 11 |
from trauma.api.message.schemas import CreateMessageRequest
|
|
@@ -53,8 +53,10 @@ async def update_entity_data_obj(entity_data: dict, chat_id: str) -> None:
|
|
| 53 |
|
| 54 |
|
| 55 |
@background_task()
|
| 56 |
-
async def save_assistant_user_message(
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
assistant_message = MessageModel(chatId=chat_id, author=Author.Assistant, text=assistant_message)
|
| 59 |
await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
|
| 60 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
|
|
|
| 5 |
from trauma.api.account.dto import AccountType
|
| 6 |
from trauma.api.account.model import AccountModel
|
| 7 |
from trauma.api.chat.model import ChatModel
|
| 8 |
+
from trauma.api.data.model import EntityModel, EntityModelExtended
|
| 9 |
from trauma.api.message.dto import Author
|
| 10 |
from trauma.api.message.model import MessageModel
|
| 11 |
from trauma.api.message.schemas import CreateMessageRequest
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
@background_task()
|
| 56 |
+
async def save_assistant_user_message(
|
| 57 |
+
user_message: str, final_entities: list[EntityModelExtended], assistant_message: str, chat_id: str
|
| 58 |
+
) -> None:
|
| 59 |
+
user_message = MessageModel(chatId=chat_id, author=Author.User, entities=final_entities, text=user_message)
|
| 60 |
assistant_message = MessageModel(chatId=chat_id, author=Author.Assistant, text=assistant_message)
|
| 61 |
await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
|
| 62 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
trauma/api/message/model.py
CHANGED
|
@@ -2,6 +2,7 @@ from datetime import datetime
|
|
| 2 |
|
| 3 |
from pydantic import Field
|
| 4 |
|
|
|
|
| 5 |
from trauma.api.message.dto import Author
|
| 6 |
from trauma.core.database import MongoBaseModel
|
| 7 |
|
|
@@ -10,5 +11,6 @@ class MessageModel(MongoBaseModel):
|
|
| 10 |
chatId: str
|
| 11 |
author: Author
|
| 12 |
text: str
|
|
|
|
| 13 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
| 14 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
|
|
|
| 2 |
|
| 3 |
from pydantic import Field
|
| 4 |
|
| 5 |
+
from trauma.api.data.model import EntityModelExtended
|
| 6 |
from trauma.api.message.dto import Author
|
| 7 |
from trauma.core.database import MongoBaseModel
|
| 8 |
|
|
|
|
| 11 |
chatId: str
|
| 12 |
author: Author
|
| 13 |
text: str
|
| 14 |
+
entities: list[EntityModelExtended] | None = None
|
| 15 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
| 16 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|