Spaces:
Sleeping
Sleeping
add location postal code searching
Browse files
trauma/api/account/db_requests.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import asyncio
|
| 2 |
|
|
|
|
| 3 |
from trauma.api.account.model import AccountModel
|
| 4 |
from trauma.api.account.schemas import CreateAccountRequest
|
| 5 |
from trauma.api.common.db_requests import check_unique_fields_existence
|
|
|
|
| 1 |
import asyncio
|
| 2 |
|
| 3 |
+
from trauma.api.account.dto import AccountType
|
| 4 |
from trauma.api.account.model import AccountModel
|
| 5 |
from trauma.api.account.schemas import CreateAccountRequest
|
| 6 |
from trauma.api.common.db_requests import check_unique_fields_existence
|
trauma/api/chat/dto.py
CHANGED
|
@@ -13,3 +13,5 @@ class EntityData(BaseModel):
|
|
| 13 |
age: int | None = None
|
| 14 |
treatmentArea: str | None = None
|
| 15 |
treatmentMethod: str | None = None
|
|
|
|
|
|
|
|
|
| 13 |
age: int | None = None
|
| 14 |
treatmentArea: str | None = None
|
| 15 |
treatmentMethod: str | None = None
|
| 16 |
+
location: str | None = None
|
| 17 |
+
postalCode: str | None = None
|
trauma/api/data/__init__.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from fastapi.routing import APIRouter
|
| 2 |
|
| 3 |
facility_router = APIRouter(
|
| 4 |
-
prefix="/api/facility", tags=["
|
| 5 |
)
|
| 6 |
|
| 7 |
from . import views
|
|
|
|
| 1 |
from fastapi.routing import APIRouter
|
| 2 |
|
| 3 |
facility_router = APIRouter(
|
| 4 |
+
prefix="/api/facility", tags=["facility"]
|
| 5 |
)
|
| 6 |
|
| 7 |
from . import views
|
trauma/api/data/model.py
CHANGED
|
@@ -1,39 +1,7 @@
|
|
| 1 |
-
import datetime
|
| 2 |
-
|
| 3 |
from trauma.api.data.dto import AgeGroup, ContactDetails
|
| 4 |
from trauma.core.database import MongoBaseModel
|
| 5 |
|
| 6 |
|
| 7 |
-
class DataModel(MongoBaseModel):
|
| 8 |
-
startDate: datetime.datetime
|
| 9 |
-
endDate: datetime.datetime
|
| 10 |
-
email: str | None = None
|
| 11 |
-
name: str | None = None
|
| 12 |
-
datetimeUpdated: datetime.datetime | None = None
|
| 13 |
-
organizationName: str | None = None
|
| 14 |
-
organizationLocation: str
|
| 15 |
-
youthCareRegion: str
|
| 16 |
-
postalCode: str
|
| 17 |
-
hasOrganizationMultipleLocations: bool
|
| 18 |
-
publicEmail: str
|
| 19 |
-
website: str
|
| 20 |
-
isTraumaTreatmentVisible: bool
|
| 21 |
-
injuryTreatmentForms: list[str]
|
| 22 |
-
offers: list[str]
|
| 23 |
-
intensities: list[str]
|
| 24 |
-
injuryTreatmentDuration: list[str]
|
| 25 |
-
ageGroups: list[AgeGroup]
|
| 26 |
-
traumaIndication: list[str]
|
| 27 |
-
supplier: str
|
| 28 |
-
treatmentFinancing: list[str]
|
| 29 |
-
isContractedCare: list[bool]
|
| 30 |
-
traumaTeamCharacteristics: list[str]
|
| 31 |
-
traumaDisciplines: list[str]
|
| 32 |
-
practitionersQualifications: list[str]
|
| 33 |
-
AQAs: list[str]
|
| 34 |
-
additions: list[str]
|
| 35 |
-
|
| 36 |
-
|
| 37 |
class EntityModel(MongoBaseModel):
|
| 38 |
id: str
|
| 39 |
name: str
|
|
|
|
|
|
|
|
|
|
| 1 |
from trauma.api.data.dto import AgeGroup, ContactDetails
|
| 2 |
from trauma.core.database import MongoBaseModel
|
| 3 |
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
class EntityModel(MongoBaseModel):
|
| 6 |
id: str
|
| 7 |
name: str
|
trauma/api/message/ai/engine.py
CHANGED
|
@@ -12,7 +12,7 @@ from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
|
|
| 12 |
choose_closest_treatment_method, choose_closest_treatment_area,
|
| 13 |
check_is_valid_request, generate_invalid_response, set_entity_score)
|
| 14 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
| 15 |
-
|
| 16 |
update_entity_data_obj, get_entity_by_index)
|
| 17 |
from trauma.api.message.schemas import CreateMessageResponse
|
| 18 |
from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
|
|
@@ -41,7 +41,7 @@ async def search_entities(
|
|
| 41 |
else:
|
| 42 |
user_messages_str = prepare_user_messages_str(user_message, messages)
|
| 43 |
possible_entity_indexes, search_request = await asyncio.gather(
|
| 44 |
-
|
| 45 |
generate_search_request(user_messages_str, entity_data)
|
| 46 |
)
|
| 47 |
final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
|
|
@@ -67,12 +67,12 @@ async def search_semantic_entities(
|
|
| 67 |
]
|
| 68 |
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
|
| 69 |
final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
|
| 70 |
-
final_entities_extended = await
|
| 71 |
final_entities_scored = await set_entities_score(final_entities_extended, search_request)
|
| 72 |
return final_entities_scored
|
| 73 |
|
| 74 |
|
| 75 |
-
async def
|
| 76 |
EntityModelExtended]:
|
| 77 |
async def choose_closest(entity_: EntityModel) -> tuple:
|
| 78 |
treatment_area, treatment_method = await asyncio.gather(
|
|
|
|
| 12 |
choose_closest_treatment_method, choose_closest_treatment_area,
|
| 13 |
check_is_valid_request, generate_invalid_response, set_entity_score)
|
| 14 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
| 15 |
+
filter_entities_by_age_location,
|
| 16 |
update_entity_data_obj, get_entity_by_index)
|
| 17 |
from trauma.api.message.schemas import CreateMessageResponse
|
| 18 |
from trauma.api.message.utils import (retrieve_empty_field_from_entity_data,
|
|
|
|
| 41 |
else:
|
| 42 |
user_messages_str = prepare_user_messages_str(user_message, messages)
|
| 43 |
possible_entity_indexes, search_request = await asyncio.gather(
|
| 44 |
+
filter_entities_by_age_location(entity_data),
|
| 45 |
generate_search_request(user_messages_str, entity_data)
|
| 46 |
)
|
| 47 |
final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
|
|
|
|
| 67 |
]
|
| 68 |
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:5]
|
| 69 |
final_entities = await asyncio.gather(*[get_entity_by_index(i['index']) for i in filtered_results])
|
| 70 |
+
final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
|
| 71 |
final_entities_scored = await set_entities_score(final_entities_extended, search_request)
|
| 72 |
return final_entities_scored
|
| 73 |
|
| 74 |
|
| 75 |
+
async def extend_entities_with_highlights(entities: list[EntityModel], entity_data: dict) -> list[
|
| 76 |
EntityModelExtended]:
|
| 77 |
async def choose_closest(entity_: EntityModel) -> tuple:
|
| 78 |
treatment_area, treatment_method = await asyncio.gather(
|
trauma/api/message/ai/prompts.py
CHANGED
|
@@ -30,13 +30,17 @@ Je verzamelt informatie over een patiënt, hun ziekte en de behandelmethode zoda
|
|
| 30 |
{
|
| 31 |
“age”: integer,
|
| 32 |
“treatmentArea”: “string”,
|
| 33 |
-
“treatmentMethod”: “string”
|
|
|
|
|
|
|
| 34 |
}
|
| 35 |
```
|
| 36 |
|
| 37 |
-
- **
|
| 38 |
-
- **
|
| 39 |
-
- **
|
|
|
|
|
|
|
| 40 |
|
| 41 |
## Regels voor het bijwerken van Entity Data
|
| 42 |
|
|
@@ -139,6 +143,7 @@ The field is considered valid (`is_valid = true`) if:
|
|
| 139 |
- The user describes the patient, their data, illness, treatment method, etc.
|
| 140 |
- The user's request relates to a medical topic.
|
| 141 |
- The user's request is a valid response to the assistant's question.
|
|
|
|
| 142 |
|
| 143 |
[/INST]"""
|
| 144 |
generate_invalid_response = """## Taak
|
|
|
|
| 30 |
{
|
| 31 |
“age”: integer,
|
| 32 |
“treatmentArea”: “string”,
|
| 33 |
+
“treatmentMethod”: “string”,
|
| 34 |
+
"location": "string",
|
| 35 |
+
"postalCode": "string",
|
| 36 |
}
|
| 37 |
```
|
| 38 |
|
| 39 |
+
- **age**: leeftijd van de patiënt.
|
| 40 |
+
- **treatmentArea**: Het type mentale of fysieke ziekte/stoornis.
|
| 41 |
+
- **treatmentMethod**: Een methode voor het behandelen van de ziekte of stoornis.
|
| 42 |
+
- **location**: Stad of adres waar de facility zich bevindt
|
| 43 |
+
- **postalCode**: Postcode van de facility..
|
| 44 |
|
| 45 |
## Regels voor het bijwerken van Entity Data
|
| 46 |
|
|
|
|
| 143 |
- The user describes the patient, their data, illness, treatment method, etc.
|
| 144 |
- The user's request relates to a medical topic.
|
| 145 |
- The user's request is a valid response to the assistant's question.
|
| 146 |
+
- The user's request describes desired facility.
|
| 147 |
|
| 148 |
[/INST]"""
|
| 149 |
generate_invalid_response = """## Taak
|
trauma/api/message/db_requests.py
CHANGED
|
@@ -60,15 +60,26 @@ async def save_assistant_user_message(user_message: str, assistant_message: str,
|
|
| 60 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
| 61 |
|
| 62 |
|
| 63 |
-
async def
|
| 64 |
query = {
|
| 65 |
"ageGroups": {
|
| 66 |
"$elemMatch": {
|
| 67 |
"ageMin": {"$lte": entity_data['age']},
|
| 68 |
"ageMax": {"$gte": entity_data['age']}
|
| 69 |
}
|
| 70 |
-
}
|
| 71 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
entities = await settings.DB_CLIENT.entities.find(query).to_list(length=None)
|
| 73 |
return [entity['index'] for entity in entities]
|
| 74 |
|
|
|
|
| 60 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
| 61 |
|
| 62 |
|
| 63 |
+
async def filter_entities_by_age_location(entity_data: dict) -> list[int]:
|
| 64 |
query = {
|
| 65 |
"ageGroups": {
|
| 66 |
"$elemMatch": {
|
| 67 |
"ageMin": {"$lte": entity_data['age']},
|
| 68 |
"ageMax": {"$gte": entity_data['age']}
|
| 69 |
}
|
| 70 |
+
},
|
| 71 |
}
|
| 72 |
+
if entity_data.get('location'):
|
| 73 |
+
query["contactDetails.address"] = {
|
| 74 |
+
"$regex": f".*{entity_data['location']}.*",
|
| 75 |
+
"$options": "i"
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
if entity_data.get('postalCode'):
|
| 79 |
+
query["contactDetails.postalCode"] = {
|
| 80 |
+
"$regex": f".*{entity_data['postalCode']}.*",
|
| 81 |
+
"$options": "i"
|
| 82 |
+
}
|
| 83 |
entities = await settings.DB_CLIENT.entities.find(query).to_list(length=None)
|
| 84 |
return [entity['index'] for entity in entities]
|
| 85 |
|
trauma/api/message/utils.py
CHANGED
|
@@ -20,7 +20,7 @@ def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
|
|
| 20 |
|
| 21 |
def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
|
| 22 |
for k, v in entity_data.items():
|
| 23 |
-
if not v:
|
| 24 |
return k
|
| 25 |
return None
|
| 26 |
|
|
|
|
| 20 |
|
| 21 |
def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
|
| 22 |
for k, v in entity_data.items():
|
| 23 |
+
if k not in ('location', 'postalCode') and not v:
|
| 24 |
return k
|
| 25 |
return None
|
| 26 |
|