File size: 7,848 Bytes
9150f8e
 
e754e5a
 
 
9150f8e
e754e5a
40749a1
3ec35ef
 
3e8fd5d
 
 
 
 
 
7cfbfed
3ec35ef
211d915
030cdaf
198c726
 
9150f8e
7cfbfed
198c726
 
3ec35ef
5768ce5
7cfbfed
 
40749a1
 
e754e5a
9150f8e
 
 
198c726
9150f8e
40749a1
 
198c726
3a56394
40749a1
 
2849c14
7cfbfed
198c726
2849c14
198c726
40749a1
198c726
 
9150f8e
2849c14
 
4a1bb29
 
 
1003460
2849c14
40749a1
2849c14
211d915
2849c14
 
73f4a8a
 
2849c14
 
7cfbfed
 
40749a1
7cfbfed
 
 
40749a1
7cfbfed
9150f8e
40749a1
198c726
40749a1
 
198c726
e754e5a
 
 
 
 
fc55356
e754e5a
 
 
 
 
 
 
 
 
 
030cdaf
 
211d915
97743c1
 
e754e5a
 
211d915
97743c1
 
 
4a1bb29
 
97743c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc55356
97743c1
 
40749a1
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import asyncio

import numpy as np

from trauma.api.chat.dto import EntityData
from trauma.api.chat.model import ChatModel
from trauma.api.data.model import EntityModel, EntityModelExtended
from trauma.api.message.ai.openai_request import (get_sensitive_words, update_entity_data_with_ai,
                                                  generate_next_question,
                                                  generate_search_request,
                                                  generate_final_response,
                                                  convert_value_to_embeddings,
                                                  choose_closest_treatment_method,
                                                  choose_closest_treatment_area,
                                                  check_is_valid_request,
                                                  generate_invalid_response,
                                                  set_entity_score, generate_empty_final_response)
from trauma.api.message.db_requests import (save_assistant_user_message,
                                            filter_entities_by_age_location,
                                            update_entity_data_obj, get_entities_bulk)
from trauma.api.message.dto import Author
from trauma.api.message.model import MessageModel
from trauma.api.message.schemas import CreateMessageResponse
from trauma.api.message.utils import (decode_treatment_letters,
                                      prepare_message_history_str,
                                      retrieve_empty_field_from_entity_data,
                                      prepare_user_messages_str,
                                      prepare_final_entities_str,
                                      pick_empty_field_instructions,
                                      find_matching_age_group,
                                      search_changed_field_inst,
                                      encrypt_message)
from trauma.core.config import settings


# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so a task nobody holds on to can be garbage-collected
# before it finishes.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


def _run_in_background(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it is done."""
    task = asyncio.create_task(coro)
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return task


async def search_entities(
        user_message: str, messages: list[MessageModel], chat: ChatModel
) -> CreateMessageResponse:
    """Process one user turn and build the assistant's reply.

    The user text is decoded, the chat's entity data is updated by the AI
    helper and the request is validated (both concurrently). Depending on the
    outcome the reply is an "invalid request" answer, a follow-up question
    (when the first empty field is ``'age'``), or the result of a semantic
    entity search.

    Args:
        user_message: Raw user text; run through ``decode_treatment_letters``
            before any further use.
        messages: Existing chat messages. The last element's text is handed
            to the AI update step, and its ``entities`` are reused when the
            request is judged invalid. May be empty.
        chat: Chat whose ``id`` and ``entityData`` are read.

    Returns:
        The unencrypted assistant ``MessageModel``. Encrypted copies of the
        user and assistant messages are persisted by a background task.
    """
    user_message = decode_treatment_letters(user_message)
    message_history_str = prepare_message_history_str(messages, user_message)

    last_message = messages[-1] if messages else None
    # NOTE(review): with an empty history there is no previous message; an
    # empty string is passed instead of raising IndexError. Assumes
    # update_entity_data_with_ai tolerates an empty text -- TODO confirm.
    previous_text = last_message.text if last_message is not None else ''

    entity_data, is_valid = await asyncio.gather(
        update_entity_data_with_ai(chat.entityData, user_message, previous_text),
        check_is_valid_request(user_message, message_history_str)
    )
    fields_changed_inst = search_changed_field_inst(entity_data, chat.entityData)
    final_entities = None

    if not is_valid:
        # Invalid request: the stored entity data stays as it is and the
        # entities of the previous message (if any) are shown again.
        empty_field = retrieve_empty_field_from_entity_data(chat.entityData.model_dump(mode='json'))
        response = await generate_invalid_response(user_message, message_history_str, empty_field)
        final_entities = last_message.entities if last_message is not None else None

    else:
        # Persist the updated entity data without blocking the reply.
        _run_in_background(update_entity_data_obj(entity_data, chat.id))
        empty_field = retrieve_empty_field_from_entity_data(entity_data)
        empty_field_instructions = pick_empty_field_instructions(empty_field)

        if empty_field == 'age':
            response = await generate_next_question(empty_field_instructions, message_history_str)
        else:
            user_messages_str = prepare_user_messages_str(user_message, messages)
            possible_entity_indexes, search_request = await asyncio.gather(
                filter_entities_by_age_location(entity_data),
                generate_search_request(user_messages_str, entity_data)
            )
            if not possible_entity_indexes and fields_changed_inst:
                # The age/location filter matched nothing, so only changes to
                # those fields are kept for the "no results" explanation.
                fields_changed_inst = {
                    k: v for k, v in fields_changed_inst.items()
                    if k in {'age', 'location', 'postalCode'}
                }
            final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
            final_entities_str = prepare_final_entities_str(final_entities)
            if final_entities:
                response = await generate_final_response(
                    final_entities_str, user_message, message_history_str, empty_field_instructions
                )
            else:
                response = await generate_empty_final_response(
                    user_message, message_history_str, fields_changed_inst
                )

    user_message_model = MessageModel(chatId=chat.id, author=Author.User, text=user_message)
    assistant_message = MessageModel(chatId=chat.id, author=Author.Assistant, text=response, entities=final_entities)
    user_message_enc, assistant_message_enc = await encrypt_messages([user_message_model, assistant_message])
    # Save the encrypted copies in the background; the caller gets the
    # plain-text assistant message.
    _run_in_background(save_assistant_user_message(user_message_enc, assistant_message_enc))
    return assistant_message


async def search_semantic_entities(
        search_request: str, entity_data: EntityData, entities_indexes: list[int],
        *, max_distance: float = 1.3, max_candidates: int = 50
) -> list[EntityModelExtended]:
    """Find entities semantically close to *search_request*.

    The request is embedded, the whole ``settings.SEMANTIC_INDEX`` is
    searched, and only hits whose index is in *entities_indexes* and whose
    distance does not exceed *max_distance* are kept. The closest
    *max_candidates* hits are loaded, extended with highlight fields and
    scored.

    Args:
        search_request: Text to embed and search for.
        entity_data: Collected user data, forwarded to
            ``extend_entities_with_highlights``.
        entities_indexes: Index positions allowed in the result. When empty
            or ``None`` an empty list is returned without any search.
        max_distance: Largest accepted distance reported by the index.
        max_candidates: Maximum number of nearest hits that are loaded and
            scored.

    Returns:
        The list produced by ``set_entities_score`` for the candidates.
    """
    # Set membership is O(1); the original list scan ran once per index row.
    allowed_indexes = set(entities_indexes) if entities_indexes is not None else set()
    if not allowed_indexes:
        # Nothing can pass the membership filter below, so skip the embedding
        # request and the index scan.
        # NOTE(review): assumes get_entities_bulk([]) yields no entities -- confirm.
        return []

    embedding = await convert_value_to_embeddings(search_request)
    query_embedding = np.array([embedding], dtype=np.float32)
    semantic_index = settings.SEMANTIC_INDEX
    # k == ntotal: ask for every stored vector so the filter sees all of them.
    distances, indices = semantic_index.search(query_embedding, k=semantic_index.ntotal)

    # Row 0 belongs to the single query vector.
    filtered_results = [
        {"index": int(idx), "distance": float(dist)}
        for idx, dist in zip(indices[0], distances[0])
        if idx in allowed_indexes and dist <= max_distance
    ]
    filtered_results.sort(key=lambda hit: hit["distance"])
    nearest = filtered_results[:max_candidates]

    final_entities = await get_entities_bulk([hit["index"] for hit in nearest])
    final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
    return await set_entities_score(final_entities_extended, search_request)


async def extend_entities_with_highlights(
        entities: list[EntityModel], entity_data: dict
) -> list[EntityModelExtended]:
    """Wrap each entity in an ``EntityModelExtended`` carrying highlight fields.

    For every entity the closest treatment area and treatment method are
    chosen (concurrently, across all entities) against the ``'treatmentArea'``
    and ``'treatmentMethod'`` values of *entity_data*; the matching age group
    comes from ``find_matching_age_group``.

    Args:
        entities: Entities to extend.
        entity_data: Mapping read with ``.get('treatmentArea')`` and
            ``.get('treatmentMethod')`` and passed to
            ``find_matching_age_group``.

    Returns:
        One extended entity per input entity, in input order.
    """

    async def pick_highlights(candidate: EntityModel) -> tuple:
        # Area and method lookups for one entity run side by side.
        area, method = await asyncio.gather(
            choose_closest_treatment_area(candidate.treatmentAreas, entity_data.get('treatmentArea')),
            choose_closest_treatment_method(candidate.treatmentMethods, entity_data.get('treatmentMethod')),
        )
        return area, method

    def build_extended(candidate: EntityModel, highlights: tuple) -> EntityModelExtended:
        area, method = highlights
        matched_age_group = find_matching_age_group(candidate, entity_data)
        return EntityModelExtended(
            **candidate.to_mongo(),
            highlightedAgeGroup=matched_age_group,
            highlightedTreatmentArea=area,
            highlightedTreatmentMethod=method,
        )

    all_highlights = await asyncio.gather(*(pick_highlights(item) for item in entities))
    return [
        build_extended(item, highlights)
        for highlights, item in zip(all_highlights, entities)
    ]


async def set_entities_score(entities: list[EntityModelExtended], search_request: str) -> list[EntityModelExtended]:
    """Score every entity against *search_request* and keep the good ones.

    Scores are obtained concurrently from ``set_entity_score``. Each entity
    gets its ``score`` set (and ``topMatch = True`` when the score is above
    0.9); only entities scoring above 0.72 are returned, highest score first.
    The input entities are modified in place.
    """
    pending = [set_entity_score(candidate, search_request) for candidate in entities]
    scores = await asyncio.gather(*pending)

    for candidate, relevance in zip(entities, scores):
        if relevance > 0.9:
            candidate.topMatch = True
        candidate.score = relevance

    kept = [
        candidate
        for candidate, relevance in zip(entities, scores)
        if relevance > 0.72
    ]
    # Stable descending sort: equal scores keep their input order.
    kept.sort(key=lambda candidate: candidate.score, reverse=True)
    return kept


async def encrypt_messages(messages: list[MessageModel]) -> list[MessageModel]:
    """Return copies of *messages* whose ``text`` has been encrypted.

    The sensitive words of every message are fetched concurrently via
    ``get_sensitive_words``; each copy's text is then replaced with the
    result of ``encrypt_message``. The input messages are left unchanged.
    """
    words_per_message = await asyncio.gather(
        *(get_sensitive_words(original.text) for original in messages)
    )

    def encrypted_copy(original: MessageModel, words) -> MessageModel:
        # Rebuild from the dumped fields so the original keeps its plain text.
        clone = MessageModel(**original.model_dump())
        clone.text = encrypt_message(original.text, words)
        return clone

    return [
        encrypted_copy(original, words)
        for original, words in zip(messages, words_per_message)
    ]