# Author: brestok
# Commit 40749a1: Add encryption and decryption for sensitive message data
import asyncio
import numpy as np
from trauma.api.chat.dto import EntityData
from trauma.api.chat.model import ChatModel
from trauma.api.data.model import EntityModel, EntityModelExtended
from trauma.api.message.ai.openai_request import (get_sensitive_words, update_entity_data_with_ai,
generate_next_question,
generate_search_request,
generate_final_response,
convert_value_to_embeddings,
choose_closest_treatment_method,
choose_closest_treatment_area,
check_is_valid_request,
generate_invalid_response,
set_entity_score, generate_empty_final_response)
from trauma.api.message.db_requests import (save_assistant_user_message,
filter_entities_by_age_location,
update_entity_data_obj, get_entities_bulk)
from trauma.api.message.dto import Author
from trauma.api.message.model import MessageModel
from trauma.api.message.schemas import CreateMessageResponse
from trauma.api.message.utils import (decode_treatment_letters,
prepare_message_history_str,
retrieve_empty_field_from_entity_data,
prepare_user_messages_str,
prepare_final_entities_str,
pick_empty_field_instructions,
find_matching_age_group,
search_changed_field_inst,
encrypt_message)
from trauma.core.config import settings
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected before
# it finishes; each task removes itself from this set when done.
_background_tasks: set = set()


def _run_in_background(coro) -> asyncio.Task:
    """Schedule ``coro`` without awaiting it, keeping the task alive until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def search_entities(
    user_message: str, messages: list[MessageModel], chat: ChatModel
) -> CreateMessageResponse:
    """Handle one user turn of the entity-search chat and return the assistant reply.

    Flow:
    1. Decode treatment letters in ``user_message`` and build the history string.
    2. Concurrently update the chat's entity data via AI and check whether the
       request is valid.
    3. Invalid request: generate an "invalid" response and carry over the
       entities attached to the last message (if there is one).
    4. Valid request: persist the updated entity data in the background. If the
       first empty field is ``'age'``, ask the next question. Otherwise filter
       entities by age/location, run a semantic search, and generate a final
       response (matches found) or an empty-result response.
    5. Encrypt copies of the user and assistant messages and save them in the
       background.

    Args:
        user_message: Raw text sent by the user.
        messages: Previous messages of the chat. The last element is treated
            as the most recent one (presumably ordered oldest first — confirm).
        chat: Chat whose ``entityData`` and ``id`` are used.

    Returns:
        The unencrypted assistant ``MessageModel``.
        NOTE(review): the annotation says ``CreateMessageResponse`` but a
        ``MessageModel`` is returned — confirm which one is intended.
    """
    user_message = decode_treatment_letters(user_message)
    message_history_str = prepare_message_history_str(messages, user_message)
    # ``messages`` may be empty (the invalid branch below guards for it), so do
    # not index it unconditionally.
    # NOTE(review): '' is assumed to be an acceptable "no previous message" value.
    previous_text = messages[-1].text if messages else ''
    entity_data, is_valid = await asyncio.gather(
        update_entity_data_with_ai(chat.entityData, user_message, previous_text),
        check_is_valid_request(user_message, message_history_str),
    )
    final_entities = None
    fields_changed_inst = search_changed_field_inst(entity_data, chat.entityData)
    if not is_valid:
        empty_field = retrieve_empty_field_from_entity_data(chat.entityData.model_dump(mode='json'))
        response = await generate_invalid_response(user_message, message_history_str, empty_field)
        final_entities = messages[-1].entities if messages else None
    else:
        _run_in_background(update_entity_data_obj(entity_data, chat.id))
        empty_field = retrieve_empty_field_from_entity_data(entity_data)
        empty_field_instructions = pick_empty_field_instructions(empty_field)
        if empty_field == 'age':
            response = await generate_next_question(empty_field_instructions, message_history_str)
        else:
            user_messages_str = prepare_user_messages_str(user_message, messages)
            possible_entity_indexes, search_request = await asyncio.gather(
                filter_entities_by_age_location(entity_data),
                generate_search_request(user_messages_str, entity_data),
            )
            if not possible_entity_indexes and fields_changed_inst:
                # No candidates passed the age/location filter: keep only the
                # changed-field instructions for those filter fields.
                fields_changed_inst = {
                    k: v for k, v in fields_changed_inst.items()
                    if k in {'age', 'location', 'postalCode'}
                }
            final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
            final_entities_str = prepare_final_entities_str(final_entities)
            if final_entities:
                response = await generate_final_response(
                    final_entities_str, user_message, message_history_str, empty_field_instructions
                )
            else:
                response = await generate_empty_final_response(
                    user_message, message_history_str, fields_changed_inst
                )
    user_msg = MessageModel(chatId=chat.id, author=Author.User, text=user_message)
    assistant_message = MessageModel(chatId=chat.id, author=Author.Assistant, text=response, entities=final_entities)
    # Only the encrypted copies are persisted; the caller receives plaintext.
    user_message_enc, assistant_message_enc = await encrypt_messages([user_msg, assistant_message])
    _run_in_background(save_assistant_user_message(user_message_enc, assistant_message_enc))
    return assistant_message
async def search_semantic_entities(
    search_request: str,
    entity_data: EntityData,
    entities_indexes: list[int],
    max_distance: float = 1.3,
    limit: int = 50,
) -> list[EntityModelExtended]:
    """Rank candidate entities by semantic closeness to ``search_request``.

    Steps:
    1. Embed ``search_request`` and search every vector of
       ``settings.SEMANTIC_INDEX``.
    2. Keep hits whose index is in ``entities_indexes`` and whose distance is at
       most ``max_distance``.
    3. Take the ``limit`` closest (smallest distance first) and load them with
       ``get_entities_bulk``.
    4. Attach highlights, then score and filter them with ``set_entities_score``.

    Args:
        search_request: Free-text query to embed.
        entity_data: Collected user data, forwarded to
            ``extend_entities_with_highlights`` (which reads it with ``.get``).
        entities_indexes: Allowed semantic-index ids (e.g. from the
            age/location pre-filter). ``None`` or empty means no candidates.
        max_distance: Inclusive upper bound on the index distance.
        limit: Maximum number of hits loaded from the database.

    Returns:
        Scored entities as returned by ``set_entities_score``, or ``[]`` when
        there are no candidates or the index is empty.
    """
    index = settings.SEMANTIC_INDEX
    # Build the membership set once. Testing ``in`` against a list for every hit
    # of a full-index search is O(hits * candidates).
    allowed = {int(i) for i in (entities_indexes or ())}
    if not allowed or index.ntotal == 0:
        # Nothing can pass the filter. Skip the embedding request and the
        # (k=0) search entirely.
        return []
    embedding = await convert_value_to_embeddings(search_request)
    query_embedding = np.array([embedding], dtype=np.float32)
    distances, indices = index.search(query_embedding, k=index.ntotal)
    hits = [
        (int(idx), float(dist))
        for idx, dist in zip(indices[0], distances[0])
        if int(idx) in allowed and dist <= max_distance
    ]
    hits.sort(key=lambda hit: hit[1])
    final_entities = await get_entities_bulk([idx for idx, _ in hits[:limit]])
    final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
    final_entities_scored = await set_entities_score(final_entities_extended, search_request)
    return final_entities_scored
async def extend_entities_with_highlights(
    entities: list[EntityModel], entity_data: dict
) -> list[EntityModelExtended]:
    """Wrap each entity in an ``EntityModelExtended`` carrying highlight fields.

    For every entity, two lookups run concurrently:
    - the closest treatment area to ``entity_data.get('treatmentArea')``;
    - the closest treatment method to ``entity_data.get('treatmentMethod')``.

    All entities are processed concurrently. The matching age group comes from
    ``find_matching_age_group``. Output order follows ``entities``.

    NOTE(review): annotated as ``dict``, but the caller's own annotation says
    ``EntityData``. Only ``.get`` is used here — confirm the actual type.
    """

    async def _closest_highlights(entity_: EntityModel) -> tuple:
        # Area and method are chosen independently, so await both together.
        pair = await asyncio.gather(
            choose_closest_treatment_area(entity_.treatmentAreas, entity_data.get('treatmentArea')),
            choose_closest_treatment_method(entity_.treatmentMethods, entity_data.get('treatmentMethod')),
        )
        return tuple(pair)

    def _extend(entity: EntityModel, area, method) -> EntityModelExtended:
        age_group = find_matching_age_group(entity, entity_data)
        return EntityModelExtended(
            **entity.to_mongo(),
            highlightedAgeGroup=age_group,
            highlightedTreatmentArea=area,
            highlightedTreatmentMethod=method,
        )

    highlights = await asyncio.gather(*map(_closest_highlights, entities))
    return [
        _extend(entity, area, method)
        for entity, (area, method) in zip(entities, highlights)
    ]
async def set_entities_score(entities: list[EntityModelExtended], search_request: str) -> list[EntityModelExtended]:
    """Score entities against ``search_request`` and return the relevant ones.

    Each entity is scored concurrently via ``set_entity_score``. Every entity is
    mutated in place, including those that are later dropped:
    - ``score`` is always set;
    - ``topMatch`` is set to ``True`` when the score is above 0.9.

    Only entities scoring above 0.72 are returned, sorted by score, highest first.
    """
    scores = await asyncio.gather(
        *(set_entity_score(candidate, search_request) for candidate in entities)
    )
    for candidate, value in zip(entities, scores):
        if value > 0.9:
            candidate.topMatch = True
        candidate.score = value
    relevant = [candidate for candidate, value in zip(entities, scores) if value > 0.72]
    relevant.sort(key=lambda candidate: candidate.score, reverse=True)
    return relevant
async def encrypt_messages(messages: list[MessageModel]) -> list[MessageModel]:
    """Return encrypted copies of ``messages``, leaving the originals unchanged.

    Sensitive words are looked up concurrently for every message via
    ``get_sensitive_words``. Each copy then has its ``text`` replaced by
    ``encrypt_message(original_text, words)``. Output order matches input order.
    """
    words_per_message = await asyncio.gather(
        *(get_sensitive_words(msg.text) for msg in messages)
    )
    result: list[MessageModel] = []
    for msg, words in zip(messages, words_per_message):
        # Rebuild via the constructor so the returned object is a separate
        # instance from ``msg``.
        copy_ = MessageModel(**msg.model_dump())
        copy_.text = encrypt_message(msg.text, words)
        result.append(copy_)
    return result