"""Entity search pipeline for incoming chat messages.

The entry point is :func:`search_entities`; the remaining coroutines are the
individual stages it chains together (semantic lookup, highlighting, scoring
and message encryption).
"""
import asyncio
import logging

import numpy as np

from trauma.api.chat.dto import EntityData
from trauma.api.chat.model import ChatModel
from trauma.api.data.model import EntityModel, EntityModelExtended
from trauma.api.message.ai.openai_request import (
    get_sensitive_words,
    update_entity_data_with_ai,
    generate_next_question,
    generate_search_request,
    generate_final_response,
    convert_value_to_embeddings,
    choose_closest_treatment_method,
    choose_closest_treatment_area,
    check_is_valid_request,
    generate_invalid_response,
    set_entity_score,
    generate_empty_final_response,
)
from trauma.api.message.db_requests import (
    save_assistant_user_message,
    filter_entities_by_age_location,
    update_entity_data_obj,
    get_entities_bulk,
)
from trauma.api.message.dto import Author
from trauma.api.message.model import MessageModel
from trauma.api.message.schemas import CreateMessageResponse
from trauma.api.message.utils import (
    decode_treatment_letters,
    prepare_message_history_str,
    retrieve_empty_field_from_entity_data,
    prepare_user_messages_str,
    prepare_final_entities_str,
    pick_empty_field_instructions,
    find_matching_age_group,
    search_changed_field_inst,
    encrypt_message,
)
from trauma.core.config import settings

logger = logging.getLogger(__name__)

# Changed-field instructions kept when the age/location pre-filter returns no
# candidates; every other changed field is dropped in that case.
_PREFILTER_FIELDS = frozenset({'age', 'location', 'postalCode'})

# Semantic-search cut-offs: hits with a larger distance are discarded and at
# most this many of the closest hits are loaded from the database.
_MAX_SEMANTIC_DISTANCE = 1.3
_MAX_SEMANTIC_RESULTS = 50

# Score thresholds applied to the values returned by ``set_entity_score``.
_TOP_MATCH_SCORE = 0.9
_MIN_ENTITY_SCORE = 0.72

# The event loop only keeps weak references to tasks, so fire-and-forget tasks
# are parked here until they finish to stop them being garbage-collected
# mid-execution.
_background_tasks: set[asyncio.Task] = set()


def _on_background_task_done(task: asyncio.Task) -> None:
    """Release a finished background task and log its failure, if any."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        logger.error('Background task %s failed', task.get_name(), exc_info=error)


def _run_in_background(coro) -> asyncio.Task:
    """Schedule ``coro`` without awaiting it, keeping the task alive until done."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


async def search_entities(
        user_message: str,
        messages: list[MessageModel],
        chat: ChatModel
) -> CreateMessageResponse:
    """Handle one user message and produce the assistant's reply.

    The message is decoded, the chat's entity data is updated by the AI and
    the request is validated (both concurrently). Depending on the outcome the
    reply is an "invalid request" answer, a follow-up question about the age,
    or the result of a semantic entity search. The user/assistant message pair
    is encrypted and persisted in the background.

    Args:
        user_message: Raw text sent by the user.
        messages: Previous messages of the chat; may be empty.
        chat: Chat the message belongs to.

    Returns:
        The (unencrypted) assistant message.
        NOTE(review): the object returned is a ``MessageModel`` although the
        annotation says ``CreateMessageResponse`` - confirm they are compatible.
    """
    user_message = decode_treatment_letters(user_message)
    message_history_str = prepare_message_history_str(messages, user_message)

    # An empty history used to raise IndexError on ``messages[-1]``.
    last_message = messages[-1] if messages else None
    previous_text = last_message.text if last_message is not None else ''

    entity_data, is_valid = await asyncio.gather(
        update_entity_data_with_ai(chat.entityData, user_message, previous_text),
        check_is_valid_request(user_message, message_history_str)
    )
    final_entities = None
    fields_changed_inst = search_changed_field_inst(entity_data, chat.entityData)

    if not is_valid:
        empty_field = retrieve_empty_field_from_entity_data(chat.entityData.model_dump(mode='json'))
        response = await generate_invalid_response(user_message, message_history_str, empty_field)
        # Keep showing the entities attached to the previous message.
        final_entities = last_message.entities if last_message is not None else None
    else:
        _run_in_background(update_entity_data_obj(entity_data, chat.id))
        empty_field = retrieve_empty_field_from_entity_data(entity_data)
        empty_field_instructions = pick_empty_field_instructions(empty_field)

        if empty_field == 'age':
            response = await generate_next_question(empty_field_instructions, message_history_str)
        else:
            user_messages_str = prepare_user_messages_str(user_message, messages)
            possible_entity_indexes, search_request = await asyncio.gather(
                filter_entities_by_age_location(entity_data),
                generate_search_request(user_messages_str, entity_data)
            )
            if not possible_entity_indexes and fields_changed_inst:
                fields_changed_inst = {
                    field: instruction
                    for field, instruction in fields_changed_inst.items()
                    if field in _PREFILTER_FIELDS
                }

            final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
            final_entities_str = prepare_final_entities_str(final_entities)
            if final_entities:
                response = await generate_final_response(
                    final_entities_str, user_message, message_history_str, empty_field_instructions
                )
            else:
                response = await generate_empty_final_response(
                    user_message, message_history_str, fields_changed_inst
                )

    user_message_model = MessageModel(chatId=chat.id, author=Author.User, text=user_message)
    assistant_message = MessageModel(
        chatId=chat.id, author=Author.Assistant, text=response, entities=final_entities
    )
    user_message_enc, assistant_message_enc = await encrypt_messages([user_message_model, assistant_message])
    _run_in_background(save_assistant_user_message(user_message_enc, assistant_message_enc))
    return assistant_message


async def search_semantic_entities(
        search_request: str,
        entity_data: EntityData,
        entities_indexes: list[int]
) -> list[EntityModelExtended]:
    """Find the entities semantically closest to ``search_request``.

    The request is embedded and searched against ``settings.SEMANTIC_INDEX``.
    Hits are restricted to ``entities_indexes`` and to a maximum distance, the
    closest ones are loaded, extended with highlights and finally scored.

    Args:
        search_request: Text to embed and search for.
        entity_data: Collected entity data, forwarded to the highlighting step.
            NOTE(review): that step calls ``.get`` on it, so a dict is expected
            here despite the ``EntityData`` annotation - confirm.
        entities_indexes: Index positions allowed in the result; ``None`` or
            an empty list yields no hits.

    Returns:
        Scored entities as produced by :func:`set_entities_score`.
    """
    embedding = await convert_value_to_embeddings(search_request)
    query_embedding = np.array([embedding], dtype=np.float32)
    distances, indices = settings.SEMANTIC_INDEX.search(query_embedding, k=settings.SEMANTIC_INDEX.ntotal)

    # Build the lookup set once: the loop below visits every vector in the
    # index, so testing membership against the list was quadratic.
    allowed_indexes = set(entities_indexes or ())
    hits = [
        (int(idx), float(dist))
        for idx, dist in zip(indices[0], distances[0])
        if int(idx) in allowed_indexes and dist <= _MAX_SEMANTIC_DISTANCE
    ]
    hits.sort(key=lambda hit: hit[1])
    closest_indexes = [idx for idx, _ in hits[:_MAX_SEMANTIC_RESULTS]]

    final_entities = await get_entities_bulk(closest_indexes)
    final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
    return await set_entities_score(final_entities_extended, search_request)


async def extend_entities_with_highlights(
        entities: list[EntityModel],
        entity_data: dict
) -> list[EntityModelExtended]:
    """Attach the best-matching age group, treatment area and method to each entity.

    For every entity the closest treatment area and treatment method (relative
    to ``entity_data['treatmentArea']`` / ``entity_data['treatmentMethod']``)
    are resolved concurrently; the matching age group is looked up with
    ``find_matching_age_group``.

    Args:
        entities: Entities to extend.
        entity_data: Collected entity data.

    Returns:
        One ``EntityModelExtended`` per input entity, in the same order.
    """
    requested_area = entity_data.get('treatmentArea')
    requested_method = entity_data.get('treatmentMethod')

    async def choose_closest(entity_: EntityModel) -> tuple:
        # Returns (treatment_area, treatment_method) for a single entity.
        return tuple(await asyncio.gather(
            choose_closest_treatment_area(entity_.treatmentAreas, requested_area),
            choose_closest_treatment_method(entity_.treatmentMethods, requested_method)
        ))

    highlights = await asyncio.gather(*[choose_closest(entity) for entity in entities])

    return [
        EntityModelExtended(
            **entity.to_mongo(),
            highlightedAgeGroup=find_matching_age_group(entity, entity_data),
            highlightedTreatmentArea=treatment_area,
            highlightedTreatmentMethod=treatment_method
        )
        for (treatment_area, treatment_method), entity in zip(highlights, entities)
    ]


async def set_entities_score(entities: list[EntityModelExtended], search_request: str) -> list[EntityModelExtended]:
    """Score entities against ``search_request`` and keep the good matches.

    Every entity gets its ``score`` set; entities scoring above the top-match
    threshold are additionally flagged with ``topMatch``. Only entities above
    the minimum score are returned.

    Args:
        entities: Entities to score (mutated in place).
        search_request: Text the entities are scored against.

    Returns:
        The retained entities ordered by descending score.
    """
    scores = await asyncio.gather(*[set_entity_score(entity, search_request) for entity in entities])

    kept_entities = []
    for score, entity in zip(scores, entities):
        if score > _TOP_MATCH_SCORE:
            entity.topMatch = True
        entity.score = score
        if score > _MIN_ENTITY_SCORE:
            kept_entities.append(entity)

    return sorted(kept_entities, key=lambda scored: scored.score, reverse=True)


async def encrypt_messages(messages: list[MessageModel]) -> list[MessageModel]:
    """Return encrypted copies of ``messages``; the originals are left untouched.

    The sensitive words of all messages are fetched concurrently, then each
    message is copied and its text replaced by the output of
    ``encrypt_message``.

    Args:
        messages: Messages to encrypt.

    Returns:
        Encrypted copies, in the same order as ``messages``.
    """
    sensitive_words = await asyncio.gather(*[get_sensitive_words(message.text) for message in messages])

    encrypted_messages = []
    for message, words in zip(messages, sensitive_words):
        encrypted_copy = MessageModel(**message.model_dump())
        encrypted_copy.text = encrypt_message(message.text, words)
        encrypted_messages.append(encrypted_copy)
    return encrypted_messages