Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| from trauma.api.chat.dto import EntityData | |
| from trauma.api.data.model import EntityModelExtended | |
| from trauma.api.message.ai.prompts import TraumaPrompts | |
| from trauma.core.config import settings | |
| from trauma.core.wrappers import openai_wrapper | |
async def update_entity_data_with_ai(entity_data: EntityData, user_message: str, assistant_message: str):
    """Build the system message for the ``update_entity_data_with_ai`` prompt.

    The template's ``{entity_data}``, ``{assistant_message}`` and
    ``{user_message}`` placeholders are filled by plain string replacement,
    in that order.

    Args:
        entity_data: Model serialized with ``model_dump_json(indent=2)`` and
            substituted for ``{entity_data}``.
        user_message: Text substituted for ``{user_message}``.
        assistant_message: Text substituted for ``{assistant_message}``.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    content = TraumaPrompts.update_entity_data_with_ai
    content = content.replace("{entity_data}", entity_data.model_dump_json(indent=2))
    content = content.replace("{assistant_message}", assistant_message)
    content = content.replace("{user_message}", user_message)
    return [{"role": "system", "content": content}]
async def generate_next_question(instructions: str, message_history_str: str):
    """Build the system message for the ``generate_next_question`` prompt.

    Args:
        instructions: Text substituted for ``{instructions}``.
        message_history_str: Text substituted for ``{message_history}``.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    content = TraumaPrompts.generate_next_question
    # Replacements are applied one after another, in the listed order.
    for placeholder, value in (
        ("{instructions}", instructions),
        ("{message_history}", message_history_str),
    ):
        content = content.replace(placeholder, value)
    return [{"role": "system", "content": content}]
async def generate_search_request(user_messages_str: str, entity_data: dict):
    """Build the system message for the ``generate_search_request`` prompt.

    Args:
        user_messages_str: Text substituted for ``{user_messages_str}``.
        entity_data: Dict rendered with ``json.dumps(..., indent=2)`` and
            substituted for ``{entity_data}``.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    content = TraumaPrompts.generate_search_request
    content = content.replace("{entity_data}", json.dumps(entity_data, indent=2))
    content = content.replace("{user_messages_str}", user_messages_str)
    return [{"role": "system", "content": content}]
async def generate_final_response(
        final_entities: str, user_message: str, message_history_str: str, empty_field_instructions: str
):
    """Build the system message for the final recommendation response.

    When ``empty_field_instructions`` is truthy, the
    ``generate_not_fully_recommendations`` template is used and its
    ``{instructions}`` placeholder is filled. Otherwise the
    ``generate_recommendation_decision`` template is used and its
    ``{final_entities}`` placeholder is filled. In both cases
    ``{message_history}`` and then ``{user_message}`` are substituted
    afterwards.

    Args:
        final_entities: Text substituted for ``{final_entities}``.
        user_message: Text substituted for ``{user_message}``.
        message_history_str: Text substituted for ``{message_history}``.
        empty_field_instructions: Text substituted for ``{instructions}``;
            its truthiness selects the template.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    if empty_field_instructions:
        template = TraumaPrompts.generate_not_fully_recommendations
        placeholder, value = "{instructions}", empty_field_instructions
    else:
        template = TraumaPrompts.generate_recommendation_decision
        placeholder, value = "{final_entities}", final_entities

    content = template.replace(placeholder, value)
    content = content.replace("{message_history}", message_history_str)
    content = content.replace("{user_message}", user_message)
    return [{"role": "system", "content": content}]
async def generate_empty_final_response(
        user_message: str,
        message_history_str: str,
        empty_field_instructions: dict
):
    """Build the system message for the ``generate_empty_recommendations`` prompt.

    Args:
        user_message: Text substituted for ``{user_message}``.
        message_history_str: Text substituted for ``{message_history}``.
        empty_field_instructions: Dict rendered with
            ``json.dumps(..., indent=2)`` for ``{instructions}``; its keys,
            joined with ``", "``, fill ``{field_changed}``.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    joined_field_names = ", ".join(empty_field_instructions.keys())

    content = TraumaPrompts.generate_empty_recommendations
    content = content.replace("{message_history}", message_history_str)
    content = content.replace("{user_message}", user_message)
    content = content.replace("{instructions}", json.dumps(empty_field_instructions, indent=2))
    content = content.replace("{field_changed}", joined_field_names)
    return [{"role": "system", "content": content}]
async def convert_value_to_embeddings(value: str, dimensions: int = 1536) -> list[float]:
    """Request an embedding for ``value`` and return the first result's vector.

    Calls ``settings.OPENAI_CLIENT.embeddings.create`` with the
    ``text-embedding-3-large`` model.

    Args:
        value: Input text passed as ``input``.
        dimensions: Passed through as the ``dimensions`` argument
            (defaults to 1536).

    Returns:
        ``data[0].embedding`` from the response.
    """
    request_kwargs = {
        "input": value,
        "model": 'text-embedding-3-large',
        "dimensions": dimensions,
    }
    result = await settings.OPENAI_CLIENT.embeddings.create(**request_kwargs)
    first_item = result.data[0]
    return first_item.embedding
| async def choose_closest_treatment_area(treatment_areas: list[str], treatment_area: str | None): | |
| if not treatment_area: | |
| return None | |
| messages = [ | |
| { | |
| "role": "system", | |
| "content": TraumaPrompts.choose_closest_treatment_area | |
| .replace("{treatment_areas}", ", ".join(treatment_areas)) | |
| .replace("{treatment_area}", treatment_area) | |
| } | |
| ] | |
| return messages | |
| async def choose_closest_treatment_method(treatment_methods: list[str], treatment_method: str | None): | |
| if not treatment_method: | |
| return None | |
| messages = [ | |
| { | |
| "role": "system", | |
| "content": TraumaPrompts.choose_closest_treatment_method | |
| .replace("{treatment_methods}", ", ".join(treatment_methods)) | |
| .replace("{treatment_method}", treatment_method) | |
| } | |
| ] | |
| return messages | |
async def check_is_valid_request(user_message: str, message_history: str):
    """Build the system message for the ``decide_is_valid_request`` prompt.

    Args:
        user_message: Text substituted for ``{user_message}``.
        message_history: Text substituted for ``{message_history}``.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    content = TraumaPrompts.decide_is_valid_request
    # Replacements are applied one after another, in the listed order.
    for placeholder, value in (
        ("{user_message}", user_message),
        ("{message_history}", message_history),
    ):
        content = content.replace(placeholder, value)
    return [{"role": "system", "content": content}]
async def generate_invalid_response(user_message: str, message_history_str: str, empty_field: str | None):
    """Build the system message for replying to an invalid request.

    When ``empty_field`` is truthy, the ``generate_invalid_response`` template
    is used and ``{instructions}`` is filled with
    ``pick_empty_field_instructions(empty_field)``. Otherwise the
    ``generate_invalid_response_with_recs`` template is used as is. In both
    cases ``{message_history}`` and then ``{user_message}`` are substituted
    afterwards.

    Args:
        user_message: Text substituted for ``{user_message}``.
        message_history_str: Text substituted for ``{message_history}``.
        empty_field: Field name passed to ``pick_empty_field_instructions``;
            its truthiness selects the template.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    # Function-local import, kept as in the original
    # (presumably avoids a circular import - TODO confirm).
    from trauma.api.message.utils import pick_empty_field_instructions

    if not empty_field:
        template = TraumaPrompts.generate_invalid_response_with_recs
    else:
        field_instructions = pick_empty_field_instructions(empty_field)
        template = TraumaPrompts.generate_invalid_response.replace("{instructions}", field_instructions)

    content = template.replace("{message_history}", message_history_str)
    content = content.replace("{user_message}", user_message)
    return [{"role": "system", "content": content}]
async def set_entity_score(entity: EntityModelExtended, search_request: str):
    """Build the system message for the ``set_entity_score`` prompt.

    Args:
        entity: Model serialized with ``model_dump_json`` (excluding the
            ``ageGroups``, ``treatmentAreas``, ``treatmentMethods`` and
            ``contactDetails`` fields) and substituted for ``{entity}``.
        search_request: Text substituted for ``{search_request}``.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    template = TraumaPrompts.set_entity_score
    omitted_fields = {"ageGroups", "treatmentAreas", "treatmentMethods", "contactDetails"}
    entity_json = entity.model_dump_json(exclude=omitted_fields)

    content = template.replace("{entity}", entity_json)
    content = content.replace("{search_request}", search_request)
    return [{"role": "system", "content": content}]
async def retrieve_semantic_answer(user_query: str) -> list[EntityModelExtended] | None:
    """Find the single closest entity to ``user_query`` by vector search.

    The query is embedded via ``settings.OPENAI_CLIENT.embeddings.create``
    (``text-embedding-3-large``, 384 dimensions). A ``$vectorSearch``
    aggregation is then run on ``settings.DB_CLIENT.entities`` against the
    ``entityVectors`` index, limited to one result.

    Args:
        user_query: Free-text query to embed and search with.

    Returns:
        A one-element list with the matched ``EntityModelExtended`` when the
        search score is greater than 0.83. ``None`` when the search returns
        no document or the score is at or below the threshold.
    """
    min_score = 0.83  # matches must score strictly above this value

    embedding = await settings.OPENAI_CLIENT.embeddings.create(
        input=user_query,
        model='text-embedding-3-large',
        dimensions=384,
    )
    pipeline = [
        {"$vectorSearch": {
            "index": "entityVectors",
            "path": "embedding",
            "queryVector": embedding.data[0].embedding,
            "numCandidates": 20,
            "limit": 1
        }},
        {"$project": {
            # Drop the stored vector from the result and expose the search score.
            "embedding": 0,
            "score": {"$meta": "vectorSearchScore"}
        }}
    ]
    response = await settings.DB_CLIENT.entities.aggregate(pipeline).to_list(length=1)

    # An empty result (e.g. empty collection) used to raise IndexError on
    # response[0]; treat it as "no match", consistent with the return type.
    if not response:
        return None

    best_match = response[0]
    if best_match['score'] > min_score:
        return [EntityModelExtended(**best_match)]
    return None
async def generate_searched_entity_response(user_query: str, facility: EntityModelExtended):
    """Build the system message for the ``generate_searched_entity`` prompt.

    Args:
        user_query: Text substituted for ``{user_query}``.
        facility: Model serialized with ``model_dump_json(indent=2)`` and
            substituted for ``{entity}``.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    content = TraumaPrompts.generate_searched_entity
    content = content.replace("{user_query}", user_query)
    content = content.replace("{entity}", facility.model_dump_json(indent=2))
    return [{"role": "system", "content": content}]
async def get_sensitive_words(text: str):
    """Build the system message for the ``get_sensitive_words`` prompt.

    Args:
        text: Text substituted for the template's ``{text}`` placeholder.

    Returns:
        A one-element list containing a ``{"role": "system", "content": ...}``
        dict.
    """
    content = TraumaPrompts.get_sensitive_words.replace("{text}", text)
    system_message = {"role": "system", "content": content}
    return [system_message]