Spaces:
Sleeping
Sleeping
| import asyncio | |
| import numpy as np | |
| from trauma.api.chat.dto import EntityData | |
| from trauma.api.chat.model import ChatModel | |
| from trauma.api.data.model import EntityModel, EntityModelExtended | |
| from trauma.api.message.ai.openai_request import (get_sensitive_words, update_entity_data_with_ai, | |
| generate_next_question, | |
| generate_search_request, | |
| generate_final_response, | |
| convert_value_to_embeddings, | |
| choose_closest_treatment_method, | |
| choose_closest_treatment_area, | |
| check_is_valid_request, | |
| generate_invalid_response, | |
| set_entity_score, generate_empty_final_response) | |
| from trauma.api.message.db_requests import (save_assistant_user_message, | |
| filter_entities_by_age_location, | |
| update_entity_data_obj, get_entities_bulk) | |
| from trauma.api.message.dto import Author | |
| from trauma.api.message.model import MessageModel | |
| from trauma.api.message.schemas import CreateMessageResponse | |
| from trauma.api.message.utils import (decode_treatment_letters, | |
| prepare_message_history_str, | |
| retrieve_empty_field_from_entity_data, | |
| prepare_user_messages_str, | |
| prepare_final_entities_str, | |
| pick_empty_field_instructions, | |
| find_matching_age_group, | |
| search_changed_field_inst, | |
| encrypt_message) | |
| from trauma.core.config import settings | |
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so a task whose handle is dropped may be
# garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


def _run_in_background(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def search_entities(
        user_message: str, messages: list[MessageModel], chat: ChatModel
) -> CreateMessageResponse:
    """Handle one user turn and build the assistant reply.

    The user message is decoded, the chat's entity data is updated by the AI
    helper while the request is validated in parallel, and then one of three
    replies is generated:

    * an "invalid request" reply, re-using the entities of the previous
      message (if any);
    * a follow-up question when the still-empty field is ``'age'``;
    * otherwise a final answer built from the semantic search results, or an
      "empty result" answer when nothing matched.

    Encrypted copies of the user and assistant messages are handed to
    ``save_assistant_user_message`` in a background task; the unencrypted
    assistant message is returned.

    Args:
        user_message: Raw text sent by the user.
        messages: Previous messages of the chat; may be empty.
        chat: Chat the message belongs to (``id`` and ``entityData`` are read).

    Returns:
        The assistant ``MessageModel`` for this turn.
        NOTE(review): the annotation says ``CreateMessageResponse`` although a
        ``MessageModel`` is returned -- confirm which one callers expect.
    """
    user_message = decode_treatment_letters(user_message)
    message_history_str = prepare_message_history_str(messages, user_message)

    # The history can be empty (the invalid branch below already allows for
    # it), so never index ``messages[-1]`` unconditionally.
    last_message = messages[-1] if messages else None
    # NOTE(review): an empty string is passed when there is no previous
    # message -- confirm that update_entity_data_with_ai accepts it.
    previous_text = last_message.text if last_message is not None else ''

    entity_data, is_valid = await asyncio.gather(
        update_entity_data_with_ai(chat.entityData, user_message, previous_text),
        check_is_valid_request(user_message, message_history_str)
    )
    final_entities = None
    fields_changed_inst = search_changed_field_inst(entity_data, chat.entityData)

    if not is_valid:
        # The request was rejected: the stored entity data is used, not the
        # freshly extracted one, and the previous entities are kept.
        empty_field = retrieve_empty_field_from_entity_data(chat.entityData.model_dump(mode='json'))
        response = await generate_invalid_response(user_message, message_history_str, empty_field)
        final_entities = last_message.entities if last_message is not None else None
    else:
        _run_in_background(update_entity_data_obj(entity_data, chat.id))
        empty_field = retrieve_empty_field_from_entity_data(entity_data)
        empty_field_instructions = pick_empty_field_instructions(empty_field)
        if empty_field == 'age':
            response = await generate_next_question(empty_field_instructions, message_history_str)
        else:
            user_messages_str = prepare_user_messages_str(user_message, messages)
            possible_entity_indexes, search_request = await asyncio.gather(
                filter_entities_by_age_location(entity_data),
                generate_search_request(user_messages_str, entity_data)
            )
            if not possible_entity_indexes and fields_changed_inst:
                # Nothing passed the age/location filter, so only changes to
                # those fields are kept for the "empty result" answer.
                fields_changed_inst = {
                    k: v for k, v in fields_changed_inst.items()
                    if k in {'age', 'location', 'postalCode'}
                }
            final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
            if final_entities:
                final_entities_str = prepare_final_entities_str(final_entities)
                response = await generate_final_response(
                    final_entities_str, user_message, message_history_str, empty_field_instructions
                )
            else:
                response = await generate_empty_final_response(
                    user_message, message_history_str, fields_changed_inst
                )

    user_message_model = MessageModel(chatId=chat.id, author=Author.User, text=user_message)
    assistant_message = MessageModel(
        chatId=chat.id, author=Author.Assistant, text=response, entities=final_entities
    )
    user_message_enc, assistant_message_enc = await encrypt_messages([user_message_model, assistant_message])
    _run_in_background(save_assistant_user_message(user_message_enc, assistant_message_enc))
    return assistant_message
async def search_semantic_entities(
        search_request: str, entity_data: EntityData, entities_indexes: list[int],
        *, max_distance: float = 1.3, max_results: int = 50
) -> list[EntityModelExtended]:
    """Run a semantic search restricted to a set of candidate entities.

    The search request is embedded and matched against
    ``settings.SEMANTIC_INDEX``; only hits whose index is in
    ``entities_indexes`` and whose distance does not exceed ``max_distance``
    are kept. The closest ``max_results`` hits are loaded, extended with
    highlights and scored.

    Args:
        search_request: Text to embed and search for.
        entity_data: Entity data forwarded to ``extend_entities_with_highlights``.
            NOTE(review): that function calls ``.get`` on it and annotates it
            as ``dict`` -- confirm the annotation used here.
        entities_indexes: Index ids that are allowed in the result.
        max_distance: Largest accepted distance reported by the index.
        max_results: Maximum number of hits that are loaded and scored.

    Returns:
        The entities returned by ``set_entities_score``; an empty list when
        there are no candidate indexes.
    """
    # Build the lookup set once: membership is tested for every row of the
    # index, and a list scan per row would be quadratic.
    allowed_indexes = set(entities_indexes or ())
    if not allowed_indexes:
        # No candidate can match, so skip the embedding request and the search.
        return []

    embedding = await convert_value_to_embeddings(search_request)
    query_embedding = np.array([embedding], dtype=np.float32)
    distances, indices = settings.SEMANTIC_INDEX.search(query_embedding, k=settings.SEMANTIC_INDEX.ntotal)

    # A single query was submitted, so only the first result row is relevant.
    filtered_results = [
        {"index": int(idx), "distance": float(dist)}
        for idx, dist in zip(indices[0], distances[0])
        if int(idx) in allowed_indexes and dist <= max_distance
    ]
    filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:max_results]

    final_entities = await get_entities_bulk([hit["index"] for hit in filtered_results])
    final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
    return await set_entities_score(final_entities_extended, search_request)
async def extend_entities_with_highlights(
        entities: list[EntityModel], entity_data: dict
) -> list[EntityModelExtended]:
    """Wrap every entity in an ``EntityModelExtended`` carrying highlight fields.

    For each entity the closest treatment area and treatment method are chosen
    concurrently (compared against ``entity_data['treatmentArea']`` and
    ``entity_data['treatmentMethod']``), and the matching age group is looked
    up with ``find_matching_age_group``. Input order is preserved.
    """

    async def pick_highlights(candidate: EntityModel) -> tuple:
        # Both choices for one entity are resolved in parallel.
        closest_area, closest_method = await asyncio.gather(
            choose_closest_treatment_area(candidate.treatmentAreas, entity_data.get('treatmentArea')),
            choose_closest_treatment_method(candidate.treatmentMethods, entity_data.get('treatmentMethod'))
        )
        return closest_area, closest_method

    def with_highlights(candidate: EntityModel, picks: tuple) -> EntityModelExtended:
        matched_age_group = find_matching_age_group(candidate, entity_data)
        closest_area, closest_method = picks
        return EntityModelExtended(
            **candidate.to_mongo(),
            highlightedAgeGroup=matched_age_group,
            highlightedTreatmentArea=closest_area,
            highlightedTreatmentMethod=closest_method
        )

    highlights = await asyncio.gather(*map(pick_highlights, entities))
    return [with_highlights(candidate, picks) for picks, candidate in zip(highlights, entities)]
async def set_entities_score(entities: list[EntityModelExtended], search_request: str) -> list[EntityModelExtended]:
    """Score all entities concurrently and return the sufficiently good ones.

    Every entity gets its ``score`` attribute set, including those that are
    dropped. ``topMatch`` is set to ``True`` for scores above 0.9. Only
    entities scoring above 0.72 are returned, ordered by descending score.
    """
    scoring_jobs = [set_entity_score(candidate, search_request) for candidate in entities]
    computed_scores = await asyncio.gather(*scoring_jobs)

    kept = []
    for candidate, value in zip(entities, computed_scores):
        if value > 0.9:
            candidate.topMatch = True
        candidate.score = value
        if value > 0.72:
            kept.append(candidate)

    kept.sort(key=lambda candidate: candidate.score, reverse=True)
    return kept
async def encrypt_messages(messages: list[MessageModel]) -> list[MessageModel]:
    """Return copies of ``messages`` whose ``text`` has been encrypted.

    ``get_sensitive_words`` is awaited for all message texts concurrently; each
    result is then passed, together with the original text, to
    ``encrypt_message``. The input messages themselves are left unmodified and
    the output keeps the input order.
    """
    detection_jobs = [get_sensitive_words(original.text) for original in messages]
    detected_words = await asyncio.gather(*detection_jobs)

    def encrypted_copy(original: MessageModel, words) -> MessageModel:
        # Rebuild the model from its dump so the caller's instance stays intact.
        clone = MessageModel(**original.model_dump())
        clone.text = encrypt_message(original.text, words)
        return clone

    return [encrypted_copy(original, words) for original, words in zip(messages, detected_words)]