Spaces:
Sleeping
Sleeping
| import asyncio | |
| from fastapi import HTTPException | |
| from trauma.api.account.dto import AccountType | |
| from trauma.api.account.model import AccountModel | |
| from trauma.api.chat.model import ChatModel | |
| from trauma.api.data.model import EntityModel | |
| from trauma.api.message.dto import Author, Feedback | |
| from trauma.api.message.model import MessageModel | |
| from trauma.api.message.schemas import CreateMessageRequest | |
| from trauma.core.config import settings | |
| from trauma.core.wrappers import background_task | |
async def create_message_obj(
    chat_id: str, message_data: CreateMessageRequest
) -> tuple[MessageModel, ChatModel]:
    """Store a user-authored message for an existing chat.

    Looks up the chat whose ``id`` equals *chat_id*, builds a
    ``MessageModel`` from the request payload with ``author`` set to
    ``Author.User``, and inserts it into the ``messages`` collection.

    Args:
        chat_id: Value matched against the ``id`` field of the chat.
        message_data: Request payload; its dumped fields are passed to
            ``MessageModel`` as keyword arguments.

    Returns:
        The newly created message and the chat lookup result.

    Raises:
        HTTPException: 404 when the chat lookup yields nothing.
    """
    db = settings.DB_CLIENT

    chat = await db.chats.find_one({"id": chat_id})
    if not chat:
        raise HTTPException(status_code=404, detail="Chat not found.")

    payload = message_data.model_dump()
    message = MessageModel(**payload, chatId=chat_id, author=Author.User)
    await db.messages.insert_one(message.to_mongo())

    # NOTE(review): the chat is handed back exactly as find_one produced it;
    # it is not passed through ChatModel.from_mongo even though the return
    # annotation says ChatModel -- confirm what callers expect.
    return message, chat
async def get_all_chat_messages_obj(chat_id: str, account: AccountModel) -> tuple[list[MessageModel], ChatModel]:
    """Fetch a chat and all of its messages for an authorized account.

    The messages query and the chat lookup are issued concurrently.

    Args:
        chat_id: Matched against ``chatId`` on messages and ``id`` on chats.
        account: The requesting account. Access is granted when it is the
            account attached to the chat or its type is ``AccountType.Admin``.

    Returns:
        The chat's messages as models, and the chat as a ``ChatModel``.

    Raises:
        HTTPException: 404 when the chat lookup yields nothing; 403 when the
            account is neither the chat's account nor an admin.
    """
    raw_messages, raw_chat = await asyncio.gather(
        settings.DB_CLIENT.messages.find({"chatId": chat_id}).to_list(length=None),
        settings.DB_CLIENT.chats.find_one({"id": chat_id}),
    )

    if not raw_chat:
        raise HTTPException(status_code=404, detail="Chat not found")

    chat = ChatModel.from_mongo(raw_chat)
    if chat.account.id != account.id and account.accountType != AccountType.Admin:
        # The chat exists but the caller may not read it: that is a
        # permission failure (403), not a missing resource (404).
        raise HTTPException(status_code=403, detail="Not Authorized.")

    # Convert the message documents only after both checks pass, so a
    # rejected request does not pay for building models it never returns.
    messages = [MessageModel.from_mongo(document) for document in raw_messages]
    return messages, chat
async def update_entity_data_obj(entity_data: dict, chat_id: str) -> None:
    """Set the ``entityData`` field on the chat whose ``id`` is *chat_id*.

    Args:
        entity_data: Value written to ``entityData`` via ``$set``.
        chat_id: Value matched against the ``id`` field of the chat.

    The result of the update call is not inspected, so this function does
    not itself report whether any chat matched.
    """
    selector = {"id": chat_id}
    change = {"$set": {"entityData": entity_data}}
    await settings.DB_CLIENT.chats.update_one(selector, change)
async def save_assistant_user_message(
    user_message: MessageModel, assistant_message: MessageModel
) -> None:
    """Insert a user message and then an assistant message.

    The two inserts are awaited one after the other, user message first;
    each model is serialized with ``to_mongo()`` right before its insert.

    Args:
        user_message: Inserted first.
        assistant_message: Inserted second.
    """
    collection = settings.DB_CLIENT.messages
    for message in (user_message, assistant_message):
        await collection.insert_one(message.to_mongo())
async def filter_entities_by_age_location(entity_data: dict) -> list[int]:
    """Return the ``index`` of every entity matching an age and optional place.

    An entity matches when one element of its ``ageGroups`` has
    ``ageMin <= age <= ageMax``. If ``location`` and/or ``postalCode`` are
    present and truthy in *entity_data*, the entity must additionally contain
    that text (case-insensitively) in ``contactDetails.address`` /
    ``contactDetails.postalCode`` respectively.

    Args:
        entity_data: Must contain ``age``; may contain ``location`` and
            ``postalCode``.

    Returns:
        The ``index`` values of all matching entities.

    Raises:
        KeyError: If *entity_data* has no ``age`` key.
    """
    # Local import keeps this block self-contained; the module's import
    # section does not bring in ``re``.
    import re

    age = entity_data['age']
    query = {
        "ageGroups": {
            "$elemMatch": {
                "ageMin": {"$lte": age},
                "ageMax": {"$gte": age},
            }
        },
    }

    optional_filters = (
        ("location", "contactDetails.address"),
        ("postalCode", "contactDetails.postalCode"),
    )
    for source_key, document_field in optional_filters:
        value = entity_data.get(source_key)
        if not value:
            continue
        # Escape the text so characters such as "(", "+" or "." are matched
        # literally instead of being interpreted as regex syntax (which
        # could make the query invalid or match unintended documents).
        # An unanchored pattern already matches anywhere in the field, so
        # no surrounding ".*" is needed.
        query[document_field] = {
            "$regex": re.escape(str(value)),
            "$options": "i",
        }

    entities = await settings.DB_CLIENT.entities.find(
        query, {"index": 1, "_id": 0}
    ).to_list(length=None)
    return [entity['index'] for entity in entities]
async def get_entity_by_index(index: int) -> EntityModel:
    """Look up a single entity by its ``index`` field and wrap it as a model.

    Args:
        index: Value matched against the ``index`` field.

    Returns:
        Whatever ``EntityModel.from_mongo`` produces for the lookup result.
    """
    lookup = {"index": index}
    document = await settings.DB_CLIENT.entities.find_one(lookup)
    # NOTE(review): there is no not-found guard here, so the lookup result is
    # passed to from_mongo unchecked -- confirm how from_mongo treats a miss.
    return EntityModel.from_mongo(document)
async def get_entities_bulk(indices: list[int]) -> list[EntityModel]:
    """Load every entity whose ``index`` is in *indices*.

    The ``embedding`` field is excluded from the fetched documents.

    Args:
        indices: Values matched against the ``index`` field with ``$in``.

    Returns:
        One ``EntityModel`` per fetched document, in the order the query
        returned them.
    """
    selector = {"index": {"$in": indices}}
    without_embedding = {"embedding": 0}
    cursor = settings.DB_CLIENT.entities.find(selector, without_embedding)
    documents = await cursor.to_list(length=None)
    return list(map(EntityModel.from_mongo, documents))
async def update_message_feedback_obj(message_id: str, feedback_data: Feedback) -> MessageModel:
    """Attach feedback to a stored message and return the updated model.

    The message is first looked up by ``id``; its model gets ``feedback``
    assigned in memory, and the same value (dumped in JSON mode) is written
    to the stored document with ``$set``.

    Args:
        message_id: Value matched against the ``id`` field of the message.
        feedback_data: Feedback to store on the message.

    Returns:
        The message model with ``feedback`` set to *feedback_data*.

    Raises:
        HTTPException: 404 when the message lookup yields nothing.
    """
    collection = settings.DB_CLIENT.messages
    selector = {"id": message_id}

    stored = await collection.find_one(selector)
    if not stored:
        raise HTTPException(status_code=404, detail="Message not found")

    message = MessageModel.from_mongo(stored)
    message.feedback = feedback_data

    change = {"$set": {"feedback": feedback_data.model_dump(mode='json')}}
    await collection.update_one(selector, change)
    return message