Spaces:
Sleeping
Sleeping
File size: 7,848 Bytes
9150f8e e754e5a 9150f8e e754e5a 40749a1 3ec35ef 3e8fd5d 7cfbfed 3ec35ef 211d915 030cdaf 198c726 9150f8e 7cfbfed 198c726 3ec35ef 5768ce5 7cfbfed 40749a1 e754e5a 9150f8e 198c726 9150f8e 40749a1 198c726 3a56394 40749a1 2849c14 7cfbfed 198c726 2849c14 198c726 40749a1 198c726 9150f8e 2849c14 4a1bb29 1003460 2849c14 40749a1 2849c14 211d915 2849c14 73f4a8a 2849c14 7cfbfed 40749a1 7cfbfed 40749a1 7cfbfed 9150f8e 40749a1 198c726 40749a1 198c726 e754e5a fc55356 e754e5a 030cdaf 211d915 97743c1 e754e5a 211d915 97743c1 4a1bb29 97743c1 fc55356 97743c1 40749a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import asyncio
import numpy as np
from trauma.api.chat.dto import EntityData
from trauma.api.chat.model import ChatModel
from trauma.api.data.model import EntityModel, EntityModelExtended
from trauma.api.message.ai.openai_request import (get_sensitive_words, update_entity_data_with_ai,
generate_next_question,
generate_search_request,
generate_final_response,
convert_value_to_embeddings,
choose_closest_treatment_method,
choose_closest_treatment_area,
check_is_valid_request,
generate_invalid_response,
set_entity_score, generate_empty_final_response)
from trauma.api.message.db_requests import (save_assistant_user_message,
filter_entities_by_age_location,
update_entity_data_obj, get_entities_bulk)
from trauma.api.message.dto import Author
from trauma.api.message.model import MessageModel
from trauma.api.message.schemas import CreateMessageResponse
from trauma.api.message.utils import (decode_treatment_letters,
prepare_message_history_str,
retrieve_empty_field_from_entity_data,
prepare_user_messages_str,
prepare_final_entities_str,
pick_empty_field_instructions,
find_matching_age_group,
search_changed_field_inst,
encrypt_message)
from trauma.core.config import settings
# Strong references to fire-and-forget tasks. The event loop itself only keeps
# weak references to tasks, so a task whose result is discarded may be
# garbage-collected before it finishes.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


def _run_in_background(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return task


async def search_entities(
        user_message: str, messages: list[MessageModel], chat: ChatModel
) -> CreateMessageResponse:
    """Process one user message and produce the assistant reply for the chat.

    Steps performed:
      1. Decode treatment letters in the user text and build the history string.
      2. Concurrently update the entity data with the AI helper and validate
         the request.
      3. Invalid request: generate an "invalid" response and reuse the
         entities attached to the previous message (if any).
      4. Valid request: persist the updated entity data in the background,
         then either ask the next question (when the empty field is ``'age'``)
         or run the filtered semantic search and generate a final response
         (or an "empty" final response when nothing matched).
      5. Build user/assistant messages, encrypt copies of them, and save the
         encrypted copies in the background.

    Args:
        user_message: Raw text of the incoming user message.
        messages: Previous messages of the chat; may be empty.
        chat: Chat the message belongs to (provides ``id`` and ``entityData``).

    Returns:
        The (unencrypted) assistant message.
        NOTE(review): the annotation says ``CreateMessageResponse`` but the
        object returned is a ``MessageModel`` — confirm they are compatible.
    """
    user_message = decode_treatment_letters(user_message)
    message_history_str = prepare_message_history_str(messages, user_message)

    # Guard against an empty history up front; previously ``messages[-1]`` was
    # indexed unconditionally here even though the invalid-request branch
    # below already allowed for ``messages`` being empty.
    last_message = messages[-1] if messages else None
    # NOTE(review): an empty string stands in for the previous message text
    # when there is no history — confirm update_entity_data_with_ai accepts it.
    previous_text = last_message.text if last_message is not None else ""

    entity_data, is_valid = await asyncio.gather(
        update_entity_data_with_ai(chat.entityData, user_message, previous_text),
        check_is_valid_request(user_message, message_history_str)
    )
    final_entities = None
    fields_changed_inst = search_changed_field_inst(entity_data, chat.entityData)

    if not is_valid:
        # The stored chat entity data (not the AI-updated one) drives the
        # invalid response; the updated data is not persisted in this branch.
        empty_field = retrieve_empty_field_from_entity_data(chat.entityData.model_dump(mode='json'))
        response = await generate_invalid_response(user_message, message_history_str, empty_field)
        final_entities = last_message.entities if last_message is not None else None
    else:
        _run_in_background(update_entity_data_obj(entity_data, chat.id))
        empty_field = retrieve_empty_field_from_entity_data(entity_data)
        empty_field_instructions = pick_empty_field_instructions(empty_field)
        if empty_field == 'age':
            response = await generate_next_question(empty_field_instructions, message_history_str)
        else:
            user_messages_str = prepare_user_messages_str(user_message, messages)
            possible_entity_indexes, search_request = await asyncio.gather(
                filter_entities_by_age_location(entity_data),
                generate_search_request(user_messages_str, entity_data)
            )
            if not possible_entity_indexes and fields_changed_inst:
                # Nothing passed the age/location filter: keep only the changed
                # fields that belong to that filter.
                fields_changed_inst = {
                    k: v for k, v in fields_changed_inst.items()
                    if k in {'age', 'location', 'postalCode'}
                }
            final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
            final_entities_str = prepare_final_entities_str(final_entities)
            if final_entities:
                response = await generate_final_response(
                    final_entities_str, user_message, message_history_str, empty_field_instructions
                )
            else:
                response = await generate_empty_final_response(
                    user_message, message_history_str, fields_changed_inst
                )

    user_message_model = MessageModel(chatId=chat.id, author=Author.User, text=user_message)
    assistant_message = MessageModel(chatId=chat.id, author=Author.Assistant, text=response, entities=final_entities)
    user_message_enc, assistant_message_enc = await encrypt_messages([user_message_model, assistant_message])
    # Only the encrypted copies are stored; the caller receives the plain one.
    _run_in_background(save_assistant_user_message(user_message_enc, assistant_message_enc))
    return assistant_message
async def search_semantic_entities(
        search_request: str, entity_data: EntityData, entities_indexes: list[int]
) -> list[EntityModelExtended]:
    """Find entities semantically close to *search_request* among allowed ids.

    The search request is embedded and queried against
    ``settings.SEMANTIC_INDEX`` (presumably a FAISS index — confirm) for every
    stored vector. Hits are kept only if their id is in *entities_indexes* and
    their distance does not exceed the cut-off; the closest ones are then
    loaded, extended with highlights and scored.

    Args:
        search_request: Text to embed and search for.
        entity_data: Collected user data forwarded to the highlight step.
            NOTE(review): annotated as ``EntityData`` here, but the highlight
            step calls ``.get`` on it like a dict — confirm the actual type.
        entities_indexes: Ids of entities allowed to appear in the result.

    Returns:
        Scored, extended entities as returned by ``set_entities_score``.
    """
    max_distance = 1.3   # hits farther than this are discarded
    max_candidates = 50  # cap on hits loaded from storage

    embedding = await convert_value_to_embeddings(search_request)
    query_embedding = np.array([embedding], dtype=np.float32)
    index = settings.SEMANTIC_INDEX
    distances, indices = index.search(query_embedding, k=index.ntotal)

    # Build the allowed ids once as a set: the comprehension below runs over
    # every vector in the index, so a list membership test per hit would make
    # this step quadratic.
    allowed_indexes = set(entities_indexes)

    # Single query vector, so only the first result row is relevant.
    candidates = [
        {"index": int(idx), "distance": float(dist)}
        for idx, dist in zip(indices[0], distances[0])
        if int(idx) in allowed_indexes and dist <= max_distance
    ]
    candidates = sorted(candidates, key=lambda item: item["distance"])[:max_candidates]

    final_entities = await get_entities_bulk([item["index"] for item in candidates])
    final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
    final_entities_scored = await set_entities_score(final_entities_extended, search_request)
    return final_entities_scored
async def extend_entities_with_highlights(entities: list[EntityModel], entity_data: dict) -> list[
    EntityModelExtended]:
    """Wrap each entity in an ``EntityModelExtended`` carrying highlight fields.

    For every entity the closest treatment area and treatment method are
    resolved concurrently (relative to ``entity_data['treatmentArea']`` and
    ``entity_data['treatmentMethod']``), and the matching age group is looked
    up. The output preserves the order of *entities*.

    Args:
        entities: Entities to extend.
        entity_data: Collected user data; read via ``.get``.

    Returns:
        One ``EntityModelExtended`` per input entity.
    """

    async def pick_highlights(candidate: EntityModel) -> tuple:
        # Both lookups for one entity run concurrently.
        area_lookup = choose_closest_treatment_area(
            candidate.treatmentAreas, entity_data.get('treatmentArea')
        )
        method_lookup = choose_closest_treatment_method(
            candidate.treatmentMethods, entity_data.get('treatmentMethod')
        )
        closest_area, closest_method = await asyncio.gather(area_lookup, method_lookup)
        return closest_area, closest_method

    def with_highlights(candidate: EntityModel, closest_area, closest_method) -> EntityModelExtended:
        matched_age_group = find_matching_age_group(candidate, entity_data)
        return EntityModelExtended(
            **candidate.to_mongo(),
            highlightedAgeGroup=matched_age_group,
            highlightedTreatmentArea=closest_area,
            highlightedTreatmentMethod=closest_method
        )

    # gather() returns results in the order the awaitables were passed, so the
    # highlights line up with ``entities`` position by position.
    highlights = await asyncio.gather(*(pick_highlights(candidate) for candidate in entities))
    return [
        with_highlights(candidate, closest_area, closest_method)
        for (closest_area, closest_method), candidate in zip(highlights, entities)
    ]
async def set_entities_score(entities: list[EntityModelExtended], search_request: str) -> list[EntityModelExtended]:
    """Score entities against *search_request* and keep the good matches.

    All scores are requested concurrently via ``set_entity_score``. Every
    entity gets its ``score`` set; those scoring above 0.9 are additionally
    flagged with ``topMatch = True``. Only entities scoring above 0.72 are
    returned, ordered from highest to lowest score.

    Args:
        entities: Entities to score; they are mutated in place.
        search_request: Text the entities are scored against.

    Returns:
        The retained entities sorted by descending score.
    """
    pending_scores = [set_entity_score(candidate, search_request) for candidate in entities]
    computed_scores = await asyncio.gather(*pending_scores)

    kept = []
    for value, candidate in zip(computed_scores, entities):
        if value > 0.9:
            candidate.topMatch = True
        candidate.score = value
        if value > 0.72:
            kept.append(candidate)

    kept.sort(key=lambda candidate: candidate.score, reverse=True)
    return kept
async def encrypt_messages(messages: list[MessageModel]) -> list[MessageModel]:
    """Return copies of *messages* whose text went through ``encrypt_message``.

    Sensitive words are fetched concurrently for every message via
    ``get_sensitive_words``. Each message is then rebuilt from its
    ``model_dump()`` and the copy's ``text`` is replaced with the result of
    ``encrypt_message(original_text, sensitive_words)``. The input messages
    are not modified, and the output keeps the input order.

    Args:
        messages: Messages to process.

    Returns:
        New ``MessageModel`` instances, one per input message.
    """

    def encrypted_copy(source: MessageModel, words) -> MessageModel:
        duplicate = MessageModel(**source.model_dump())
        duplicate.text = encrypt_message(source.text, words)
        return duplicate

    detected = await asyncio.gather(*(get_sensitive_words(item.text) for item in messages))
    return [encrypted_copy(item, words) for item, words in zip(messages, detected)]
|