Upload folder using huggingface_hub
Browse files- graph.py +53 -30
- modules/db.py +95 -22
- modules/models.py +2 -1
- tests/test_trends.py +16 -11
- ui.py +1 -1
graph.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from datetime import datetime
|
| 2 |
import threading
|
| 3 |
import time
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
from langchain_core.prompts import ChatPromptTemplate
|
| 6 |
import matplotlib.pyplot as plt
|
|
@@ -101,7 +102,7 @@ async def fn_init_node(state: SheamiState):
|
|
| 101 |
state["messages"].append(f"{idx+1}. {report.report_file_name}")
|
| 102 |
state["standardized_reports"] = []
|
| 103 |
state["trends_json"] = {}
|
| 104 |
-
state["
|
| 105 |
state["current_index"] = -1
|
| 106 |
state["units_processed"] = 0
|
| 107 |
state["units_total"] = 0
|
|
@@ -113,7 +114,9 @@ async def fn_init_node(state: SheamiState):
|
|
| 113 |
run_id = await get_db().start_run(
|
| 114 |
user_email=state["user_email"],
|
| 115 |
patient_id=state["patient_id"],
|
| 116 |
-
source_file_names=[
|
|
|
|
|
|
|
| 117 |
)
|
| 118 |
state["run_id"] = run_id
|
| 119 |
|
|
@@ -198,7 +201,7 @@ async def fn_standardize_current_report_node(state: SheamiState):
|
|
| 198 |
SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"
|
| 199 |
),
|
| 200 |
"w",
|
| 201 |
-
encoding="utf-8"
|
| 202 |
) as f:
|
| 203 |
f.write(result.model_dump_json(indent=2))
|
| 204 |
|
|
@@ -393,22 +396,17 @@ async def fn_trends_aggregator_node(state: SheamiState):
|
|
| 393 |
)
|
| 394 |
|
| 395 |
# Build trends JSON
|
| 396 |
-
state["trends_json"] =
|
| 397 |
-
"
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
"reference_range": ref_ranges.get(k),
|
| 402 |
-
}
|
| 403 |
-
for k, v in sorted(trends.items(), key=lambda kv: kv[0].lower())
|
| 404 |
-
]
|
| 405 |
-
}
|
| 406 |
|
| 407 |
# Persist
|
| 408 |
output_dir = SheamiConfig.get_output_dir(state["thread_id"])
|
| 409 |
os.makedirs(output_dir, exist_ok=True)
|
| 410 |
with open(os.path.join(output_dir, "trends.json"), "w", encoding="utf-8") as f:
|
| 411 |
-
json.dump(state["trends_json"], f, indent=
|
| 412 |
|
| 413 |
logger.info("%s| Aggregating Trends : finished", state["thread_id"])
|
| 414 |
state["messages"].append("Aggregating Trends : finished")
|
|
@@ -419,21 +417,42 @@ async def fn_interpreter_node(state: SheamiState):
|
|
| 419 |
logger.info("%s| Interpreting Trends : started", state["thread_id"])
|
| 420 |
state["messages"].append("Interpreting Trends : started")
|
| 421 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
# 1. LLM narrative
|
| 423 |
messages = [
|
| 424 |
SystemMessage(
|
| 425 |
content=(
|
| 426 |
"Interpret the following medical trends and produce a clean, structured **HTML** report without any markdown formatting like backquotes etc. "
|
| 427 |
"The report should have: "
|
| 428 |
-
"1.
|
| 429 |
-
"2.
|
| 430 |
-
"3.
|
| 431 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
"Format tables in proper <table> with <tr>, <th>, <td>. "
|
| 433 |
"Do not include charts, they will be programmatically added."
|
| 434 |
)
|
| 435 |
),
|
| 436 |
-
HumanMessage(content=
|
| 437 |
]
|
| 438 |
response = await llm.ainvoke(messages)
|
| 439 |
interpretation_html = response.content # ✅ already HTML now
|
|
@@ -443,11 +462,9 @@ async def fn_interpreter_node(state: SheamiState):
|
|
| 443 |
os.makedirs(plots_dir, exist_ok=True)
|
| 444 |
plot_files = []
|
| 445 |
|
| 446 |
-
for param in sorted(
|
| 447 |
-
state["trends_json"].get("parameter_trends", []), key=lambda x: x["test_name"]
|
| 448 |
-
):
|
| 449 |
test_name = param["test_name"]
|
| 450 |
-
values = param["
|
| 451 |
|
| 452 |
x = [parse_any_date(v["date"]) for v in values]
|
| 453 |
x = pd.to_datetime(x, errors="coerce")
|
|
@@ -463,7 +480,7 @@ async def fn_interpreter_node(state: SheamiState):
|
|
| 463 |
plt.figure(figsize=(6, 4))
|
| 464 |
plt.plot(x, y, marker="o", linestyle="-", label="Observed values")
|
| 465 |
|
| 466 |
-
ref = param.get("
|
| 467 |
if ref:
|
| 468 |
ymin, ymax = ref.get("min"), ref.get("max")
|
| 469 |
if ymin is not None and ymax is not None:
|
|
@@ -505,11 +522,17 @@ async def fn_interpreter_node(state: SheamiState):
|
|
| 505 |
)
|
| 506 |
|
| 507 |
# Save state
|
| 508 |
-
state["
|
| 509 |
-
|
| 510 |
logger.info("%s| Interpreting Trends : finished", state["thread_id"])
|
| 511 |
state["messages"].append("Interpreting Trends : finished")
|
| 512 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
state["milestones"][-1].status = "completed"
|
| 514 |
state["milestones"][-1].end_time = datetime.now()
|
| 515 |
await get_db().add_or_update_milestone(
|
|
@@ -537,14 +560,12 @@ async def fn_interpreter_node(state: SheamiState):
|
|
| 537 |
pdf_bytes = f.read()
|
| 538 |
final_report_id = await get_db().add_final_report_v2(
|
| 539 |
patient_id=state["patient_id"],
|
| 540 |
-
summary=interpretation_html,
|
| 541 |
pdf_bytes=pdf_bytes,
|
| 542 |
file_name=f"health_trends_report_{state["patient_id"]}.pdf",
|
| 543 |
)
|
| 544 |
logger.info("final_report_id = %s", final_report_id)
|
| 545 |
|
| 546 |
-
return state
|
| 547 |
-
|
| 548 |
|
| 549 |
def schedule_cleanup(file_path, delay=300): # 300 sec = 5 min
|
| 550 |
def cleanup():
|
|
@@ -631,6 +652,7 @@ def create_graph(user_email: str, patient_id: str, thread_id: str):
|
|
| 631 |
workflow.add_node("unit_normalizer_notifier", fn_unit_normalizer_node_notifier)
|
| 632 |
workflow.add_node("trends_notifier", fn_trends_aggregator_node_notifier)
|
| 633 |
workflow.add_node("interpreter_notifier", fn_interpreter_node_notifier)
|
|
|
|
| 634 |
|
| 635 |
workflow.add_edge(START, "init")
|
| 636 |
workflow.add_edge("init", "standardizer_notifier")
|
|
@@ -654,7 +676,8 @@ def create_graph(user_email: str, patient_id: str, thread_id: str):
|
|
| 654 |
workflow.add_edge("trends_notifier", "trends")
|
| 655 |
workflow.add_edge("trends", "interpreter_notifier")
|
| 656 |
workflow.add_edge("interpreter_notifier", "interpreter")
|
| 657 |
-
workflow.add_edge("interpreter",
|
|
|
|
| 658 |
|
| 659 |
logger.info("%s| Creating Graph : finished", thread_id)
|
| 660 |
return workflow.compile(checkpointer=memory)
|
|
|
|
| 1 |
from datetime import datetime
|
| 2 |
import threading
|
| 3 |
import time
|
| 4 |
+
from bson import ObjectId
|
| 5 |
import pandas as pd
|
| 6 |
from langchain_core.prompts import ChatPromptTemplate
|
| 7 |
import matplotlib.pyplot as plt
|
|
|
|
| 102 |
state["messages"].append(f"{idx+1}. {report.report_file_name}")
|
| 103 |
state["standardized_reports"] = []
|
| 104 |
state["trends_json"] = {}
|
| 105 |
+
state["pdf_path"] = ""
|
| 106 |
state["current_index"] = -1
|
| 107 |
state["units_processed"] = 0
|
| 108 |
state["units_total"] = 0
|
|
|
|
| 114 |
run_id = await get_db().start_run(
|
| 115 |
user_email=state["user_email"],
|
| 116 |
patient_id=state["patient_id"],
|
| 117 |
+
source_file_names=[
|
| 118 |
+
report.report_file_name for report in state["uploaded_reports"]
|
| 119 |
+
],
|
| 120 |
)
|
| 121 |
state["run_id"] = run_id
|
| 122 |
|
|
|
|
| 201 |
SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"
|
| 202 |
),
|
| 203 |
"w",
|
| 204 |
+
encoding="utf-8",
|
| 205 |
) as f:
|
| 206 |
f.write(result.model_dump_json(indent=2))
|
| 207 |
|
|
|
|
| 396 |
)
|
| 397 |
|
| 398 |
# Build trends JSON
|
| 399 |
+
state["trends_json"] = await get_db().get_trends_by_patient(
|
| 400 |
+
patient_id=state["patient_id"],
|
| 401 |
+
fields=["test_name", "trend_data"],
|
| 402 |
+
serializable=True,
|
| 403 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
# Persist
|
| 406 |
output_dir = SheamiConfig.get_output_dir(state["thread_id"])
|
| 407 |
os.makedirs(output_dir, exist_ok=True)
|
| 408 |
with open(os.path.join(output_dir, "trends.json"), "w", encoding="utf-8") as f:
|
| 409 |
+
json.dump(state["trends_json"], f, indent=1, ensure_ascii=False)
|
| 410 |
|
| 411 |
logger.info("%s| Aggregating Trends : finished", state["thread_id"])
|
| 412 |
state["messages"].append("Aggregating Trends : finished")
|
|
|
|
| 417 |
logger.info("%s| Interpreting Trends : started", state["thread_id"])
|
| 418 |
state["messages"].append("Interpreting Trends : started")
|
| 419 |
|
| 420 |
+
uploaded_reports = await get_db().get_reports_by_patient(
|
| 421 |
+
patient_id=state["patient_id"]
|
| 422 |
+
)
|
| 423 |
+
llm_input = json.dumps(
|
| 424 |
+
{
|
| 425 |
+
"patient_id": state["patient_id"],
|
| 426 |
+
"patient_info": await get_db().get_patient_by_id(
|
| 427 |
+
patient_id=state["patient_id"],
|
| 428 |
+
fields=["name", "dob", "gender"],
|
| 429 |
+
serializable=True,
|
| 430 |
+
),
|
| 431 |
+
"uploaded_reports": [report["file_name"] for report in uploaded_reports],
|
| 432 |
+
"trends_json": state["trends_json"],
|
| 433 |
+
},
|
| 434 |
+
indent=1,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
# 1. LLM narrative
|
| 438 |
messages = [
|
| 439 |
SystemMessage(
|
| 440 |
content=(
|
| 441 |
"Interpret the following medical trends and produce a clean, structured **HTML** report without any markdown formatting like backquotes etc. "
|
| 442 |
"The report should have: "
|
| 443 |
+
"1. A header with the report generation date."
|
| 444 |
+
"2. The names of the reports used to summarize this information."
|
| 445 |
+
"3. Patient summary (patient id, name, age, sex if available)"
|
| 446 |
+
"4. Test window (mention the from and to dates)"
|
| 447 |
+
"5. Trend summaries (tables with Test Name, Latest Value, Highest Value, Lowest Value, Unit, Reference Range, Trend Direction and Inference) "
|
| 448 |
+
"6. Clinical insights. "
|
| 449 |
+
"For inference column, use ✅ for normal, ▲ for high, and ▼ for low. "
|
| 450 |
+
"For trend direction, use appropriate unicode icons like up arrow (improving trend) , down arrow (worsening trend) or checkmark if determined normal"
|
| 451 |
"Format tables in proper <table> with <tr>, <th>, <td>. "
|
| 452 |
"Do not include charts, they will be programmatically added."
|
| 453 |
)
|
| 454 |
),
|
| 455 |
+
HumanMessage(content=llm_input),
|
| 456 |
]
|
| 457 |
response = await llm.ainvoke(messages)
|
| 458 |
interpretation_html = response.content # ✅ already HTML now
|
|
|
|
| 462 |
os.makedirs(plots_dir, exist_ok=True)
|
| 463 |
plot_files = []
|
| 464 |
|
| 465 |
+
for param in sorted(state["trends_json"], key=lambda x: x["test_name"]):
|
|
|
|
|
|
|
| 466 |
test_name = param["test_name"]
|
| 467 |
+
values = param["trend_data"]
|
| 468 |
|
| 469 |
x = [parse_any_date(v["date"]) for v in values]
|
| 470 |
x = pd.to_datetime(x, errors="coerce")
|
|
|
|
| 480 |
plt.figure(figsize=(6, 4))
|
| 481 |
plt.plot(x, y, marker="o", linestyle="-", label="Observed values")
|
| 482 |
|
| 483 |
+
ref = param.get("test_reference_range")
|
| 484 |
if ref:
|
| 485 |
ymin, ymax = ref.get("min"), ref.get("max")
|
| 486 |
if ymin is not None and ymax is not None:
|
|
|
|
| 522 |
)
|
| 523 |
|
| 524 |
# Save state
|
| 525 |
+
state["pdf_path"] = pdf_path
|
| 526 |
+
state["interpretation_html"] = interpretation_html
|
| 527 |
logger.info("%s| Interpreting Trends : finished", state["thread_id"])
|
| 528 |
state["messages"].append("Interpreting Trends : finished")
|
| 529 |
|
| 530 |
+
return state
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
async def fn_final_cleanup_node(state: SheamiState):
|
| 534 |
+
pdf_path = state["pdf_path"]
|
| 535 |
+
schedule_cleanup(file_path=SheamiConfig.get_output_dir(state["thread_id"]))
|
| 536 |
state["milestones"][-1].status = "completed"
|
| 537 |
state["milestones"][-1].end_time = datetime.now()
|
| 538 |
await get_db().add_or_update_milestone(
|
|
|
|
| 560 |
pdf_bytes = f.read()
|
| 561 |
final_report_id = await get_db().add_final_report_v2(
|
| 562 |
patient_id=state["patient_id"],
|
| 563 |
+
summary=state["interpretation_html"],
|
| 564 |
pdf_bytes=pdf_bytes,
|
| 565 |
file_name=f"health_trends_report_{state["patient_id"]}.pdf",
|
| 566 |
)
|
| 567 |
logger.info("final_report_id = %s", final_report_id)
|
| 568 |
|
|
|
|
|
|
|
| 569 |
|
| 570 |
def schedule_cleanup(file_path, delay=300): # 300 sec = 5 min
|
| 571 |
def cleanup():
|
|
|
|
| 652 |
workflow.add_node("unit_normalizer_notifier", fn_unit_normalizer_node_notifier)
|
| 653 |
workflow.add_node("trends_notifier", fn_trends_aggregator_node_notifier)
|
| 654 |
workflow.add_node("interpreter_notifier", fn_interpreter_node_notifier)
|
| 655 |
+
workflow.add_node("final_cleanup_node", fn_final_cleanup_node)
|
| 656 |
|
| 657 |
workflow.add_edge(START, "init")
|
| 658 |
workflow.add_edge("init", "standardizer_notifier")
|
|
|
|
| 676 |
workflow.add_edge("trends_notifier", "trends")
|
| 677 |
workflow.add_edge("trends", "interpreter_notifier")
|
| 678 |
workflow.add_edge("interpreter_notifier", "interpreter")
|
| 679 |
+
workflow.add_edge("interpreter", "final_cleanup_node")
|
| 680 |
+
workflow.add_edge("final_cleanup_node", END)
|
| 681 |
|
| 682 |
logger.info("%s| Creating Graph : finished", thread_id)
|
| 683 |
return workflow.compile(checkpointer=memory)
|
modules/db.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
from typing import Any
|
| 3 |
from datetime import datetime, timezone
|
|
@@ -70,8 +71,14 @@ class SheamiDB:
|
|
| 70 |
result = await self.patients.insert_one(patient)
|
| 71 |
return str(result.inserted_id)
|
| 72 |
|
| 73 |
-
async def get_patient_by_id(
|
|
|
|
|
|
|
| 74 |
patient = await self.patients.find_one({"_id": ObjectId(patient_id)})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
return patient
|
| 76 |
|
| 77 |
async def get_patients_by_user(self, user_id: str) -> list:
|
|
@@ -139,9 +146,19 @@ class SheamiDB:
|
|
| 139 |
upsert=True,
|
| 140 |
)
|
| 141 |
|
| 142 |
-
async def get_trends_by_patient(
|
|
|
|
|
|
|
| 143 |
cursor = self.trends.find({"patient_id": ObjectId(patient_id)})
|
| 144 |
trends = await cursor.to_list(length=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
return trends
|
| 146 |
|
| 147 |
# ---------------------------
|
|
@@ -380,19 +397,16 @@ class SheamiDB:
|
|
| 380 |
|
| 381 |
updated = 0
|
| 382 |
|
| 383 |
-
async def
|
| 384 |
test_name = test.get("test_name")
|
| 385 |
value = test.get("result_value")
|
| 386 |
unit = test.get("test_unit")
|
| 387 |
-
test_date = (
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
or datetime.now(timezone.utc)
|
| 391 |
-
)
|
| 392 |
|
| 393 |
-
# Normalize
|
| 394 |
if isinstance(test_date, (int, float)):
|
| 395 |
-
# handle timestamp
|
| 396 |
test_date = datetime.fromtimestamp(test_date, tz=timezone.utc)
|
| 397 |
elif isinstance(test_date, str):
|
| 398 |
try:
|
|
@@ -407,20 +421,61 @@ class SheamiDB:
|
|
| 407 |
"report_id": ObjectId(report_id),
|
| 408 |
}
|
| 409 |
|
| 410 |
-
#
|
| 411 |
-
|
| 412 |
-
{"patient_id": ObjectId(patient_id), "test_name": test_name},
|
| 413 |
{
|
| 414 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
"patient_id": ObjectId(patient_id),
|
| 416 |
"test_name": test_name,
|
| 417 |
-
"
|
| 418 |
},
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
return result
|
| 425 |
|
| 426 |
for test in tests:
|
|
@@ -431,11 +486,11 @@ class SheamiDB:
|
|
| 431 |
continue
|
| 432 |
for sub_result in sub_results:
|
| 433 |
test_name = sub_result.get("test_name")
|
| 434 |
-
db_output = await
|
| 435 |
updated += db_output.modified_count
|
| 436 |
continue
|
| 437 |
else:
|
| 438 |
-
db_output = await
|
| 439 |
updated += db_output.modified_count
|
| 440 |
|
| 441 |
# print("updated/inserted", updated, "trends")
|
|
@@ -556,3 +611,21 @@ class SheamiDB:
|
|
| 556 |
await self.fs.delete(file_id)
|
| 557 |
deleted_count += 1
|
| 558 |
return deleted_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import timezone
|
| 2 |
import os
|
| 3 |
from typing import Any
|
| 4 |
from datetime import datetime, timezone
|
|
|
|
| 71 |
result = await self.patients.insert_one(patient)
|
| 72 |
return str(result.inserted_id)
|
| 73 |
|
| 74 |
+
async def get_patient_by_id(
|
| 75 |
+
self, patient_id: str, fields: list[str] = [], serializable: bool = False
|
| 76 |
+
) -> Any | None:
|
| 77 |
patient = await self.patients.find_one({"_id": ObjectId(patient_id)})
|
| 78 |
+
if fields:
|
| 79 |
+
patient = {key: patient[key] for key in fields if key in patient}
|
| 80 |
+
if serializable:
|
| 81 |
+
patient = self.convert_to_serializable_data(data=patient)
|
| 82 |
return patient
|
| 83 |
|
| 84 |
async def get_patients_by_user(self, user_id: str) -> list:
|
|
|
|
| 146 |
upsert=True,
|
| 147 |
)
|
| 148 |
|
| 149 |
+
async def get_trends_by_patient(
|
| 150 |
+
self, patient_id: str, fields: list[str] = None, serializable=False
|
| 151 |
+
) -> list:
|
| 152 |
cursor = self.trends.find({"patient_id": ObjectId(patient_id)})
|
| 153 |
trends = await cursor.to_list(length=None)
|
| 154 |
+
if fields:
|
| 155 |
+
trends = [
|
| 156 |
+
{field: trend[field] for field in fields if field in trend}
|
| 157 |
+
for trend in trends
|
| 158 |
+
]
|
| 159 |
+
if serializable:
|
| 160 |
+
trends = self.convert_to_serializable_data(data=trends)
|
| 161 |
+
|
| 162 |
return trends
|
| 163 |
|
| 164 |
# ---------------------------
|
|
|
|
| 397 |
|
| 398 |
updated = 0
|
| 399 |
|
| 400 |
+
async def add_or_update_trend_data_point(test):
|
| 401 |
test_name = test.get("test_name")
|
| 402 |
value = test.get("result_value")
|
| 403 |
unit = test.get("test_unit")
|
| 404 |
+
test_date = test.get("test_date") or datetime.now(timezone.utc)
|
| 405 |
+
test_reference_range = test.get("test_reference_range")
|
| 406 |
+
inferred_range = test.get("inferred_range")
|
|
|
|
|
|
|
| 407 |
|
| 408 |
+
# Normalize test_date (keep your existing normalization here)...
|
| 409 |
if isinstance(test_date, (int, float)):
|
|
|
|
| 410 |
test_date = datetime.fromtimestamp(test_date, tz=timezone.utc)
|
| 411 |
elif isinstance(test_date, str):
|
| 412 |
try:
|
|
|
|
| 421 |
"report_id": ObjectId(report_id),
|
| 422 |
}
|
| 423 |
|
| 424 |
+
# Step 1: Check if trend_data with same date exists
|
| 425 |
+
existing_doc = await self.trends.find_one(
|
|
|
|
| 426 |
{
|
| 427 |
+
"patient_id": ObjectId(patient_id),
|
| 428 |
+
"test_name": test_name,
|
| 429 |
+
"trend_data.date": test_date,
|
| 430 |
+
},
|
| 431 |
+
projection={"trend_data.$": 1}, # Project only matched array element
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if existing_doc:
|
| 435 |
+
# Step 2: Update the existing trend_data array element with new data
|
| 436 |
+
result = await self.trends.update_one(
|
| 437 |
+
{
|
| 438 |
"patient_id": ObjectId(patient_id),
|
| 439 |
"test_name": test_name,
|
| 440 |
+
"trend_data.date": test_date,
|
| 441 |
},
|
| 442 |
+
{
|
| 443 |
+
"$set": {
|
| 444 |
+
"trend_data.$.value": value,
|
| 445 |
+
"trend_data.$.unit": unit,
|
| 446 |
+
"trend_data.$.report_id": ObjectId(report_id),
|
| 447 |
+
"last_updated": datetime.now(timezone.utc),
|
| 448 |
+
"test_reference_range": test_reference_range,
|
| 449 |
+
"inferred_range": inferred_range,
|
| 450 |
+
},
|
| 451 |
+
"$setOnInsert": {
|
| 452 |
+
"patient_id": ObjectId(patient_id),
|
| 453 |
+
"test_name": test_name,
|
| 454 |
+
"created_at": datetime.now(timezone.utc),
|
| 455 |
+
},
|
| 456 |
+
},
|
| 457 |
+
)
|
| 458 |
+
else:
|
| 459 |
+
# Step 3: Insert new point as it does not exist yet
|
| 460 |
+
result = await self.trends.update_one(
|
| 461 |
+
{"patient_id": ObjectId(patient_id), "test_name": test_name},
|
| 462 |
+
{
|
| 463 |
+
"$setOnInsert": {
|
| 464 |
+
"patient_id": ObjectId(patient_id),
|
| 465 |
+
"test_name": test_name,
|
| 466 |
+
"created_at": datetime.now(timezone.utc),
|
| 467 |
+
},
|
| 468 |
+
"$push": {"trend_data": point},
|
| 469 |
+
"$set": {
|
| 470 |
+
"last_updated": datetime.now(timezone.utc),
|
| 471 |
+
"test_reference_range": test_reference_range,
|
| 472 |
+
"inferred_range": inferred_range,
|
| 473 |
+
"test_reference_range": test_reference_range,
|
| 474 |
+
"inferred_range": inferred_range,
|
| 475 |
+
},
|
| 476 |
+
},
|
| 477 |
+
upsert=True,
|
| 478 |
+
)
|
| 479 |
return result
|
| 480 |
|
| 481 |
for test in tests:
|
|
|
|
| 486 |
continue
|
| 487 |
for sub_result in sub_results:
|
| 488 |
test_name = sub_result.get("test_name")
|
| 489 |
+
db_output = await add_or_update_trend_data_point(sub_result)
|
| 490 |
updated += db_output.modified_count
|
| 491 |
continue
|
| 492 |
else:
|
| 493 |
+
db_output = await add_or_update_trend_data_point(test)
|
| 494 |
updated += db_output.modified_count
|
| 495 |
|
| 496 |
# print("updated/inserted", updated, "trends")
|
|
|
|
| 611 |
await self.fs.delete(file_id)
|
| 612 |
deleted_count += 1
|
| 613 |
return deleted_count
|
| 614 |
+
|
| 615 |
+
def convert_to_serializable_data(self, data):
|
| 616 |
+
"""
|
| 617 |
+
Recursively converts MongoDB-specific types to JSON serializable formats.
|
| 618 |
+
- ObjectId to string
|
| 619 |
+
- datetime to ISO 8601 string
|
| 620 |
+
Handles dict, list, and basic types.
|
| 621 |
+
"""
|
| 622 |
+
if isinstance(data, dict):
|
| 623 |
+
return {k: self.convert_to_serializable_data(v) for k, v in data.items()}
|
| 624 |
+
elif isinstance(data, list):
|
| 625 |
+
return [self.convert_to_serializable_data(i) for i in data]
|
| 626 |
+
elif isinstance(data, ObjectId):
|
| 627 |
+
return str(data)
|
| 628 |
+
elif isinstance(data, datetime):
|
| 629 |
+
return data.isoformat()
|
| 630 |
+
else:
|
| 631 |
+
return data
|
modules/models.py
CHANGED
|
@@ -95,7 +95,7 @@ class SheamiState(TypedDict):
|
|
| 95 |
uploaded_reports: List[HealthReport]
|
| 96 |
standardized_reports: List[StandardizedReport]
|
| 97 |
trends_json: dict
|
| 98 |
-
|
| 99 |
current_index: int
|
| 100 |
process_desc: str
|
| 101 |
units_processed: int
|
|
@@ -103,3 +103,4 @@ class SheamiState(TypedDict):
|
|
| 103 |
overall_units_processed: int
|
| 104 |
overall_units_total: int
|
| 105 |
milestones: list[SheamiMilestone]
|
|
|
|
|
|
| 95 |
uploaded_reports: List[HealthReport]
|
| 96 |
standardized_reports: List[StandardizedReport]
|
| 97 |
trends_json: dict
|
| 98 |
+
pdf_path: str
|
| 99 |
current_index: int
|
| 100 |
process_desc: str
|
| 101 |
units_processed: int
|
|
|
|
| 103 |
overall_units_processed: int
|
| 104 |
overall_units_total: int
|
| 105 |
milestones: list[SheamiMilestone]
|
| 106 |
+
interpretation_html : str
|
tests/test_trends.py
CHANGED
|
@@ -1,14 +1,19 @@
|
|
| 1 |
from modules.db import SheamiDB
|
| 2 |
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from modules.db import SheamiDB
|
| 2 |
|
| 3 |
|
| 4 |
+
async def test():
|
| 5 |
+
db = SheamiDB()
|
| 6 |
+
patient_id = "68a67a92fa6a3a741b0c5c74"
|
| 7 |
+
reports = await db.get_reports_by_patient(patient_id=patient_id)
|
| 8 |
+
total_updated = 0
|
| 9 |
+
for report in reports:
|
| 10 |
+
# print(report)
|
| 11 |
+
num_updated = await db.aggregate_trends_from_report(
|
| 12 |
+
patient_id=patient_id, report_id=str(report["_id"])
|
| 13 |
+
)
|
| 14 |
+
total_updated += num_updated
|
| 15 |
+
print("total_updated = ", total_updated)
|
| 16 |
+
|
| 17 |
+
if __name__ == "__main__":
|
| 18 |
+
import asyncio
|
| 19 |
+
asyncio.run(test())
|
ui.py
CHANGED
|
@@ -117,7 +117,7 @@ async def process_reports(user_email: str, patient_id: str, files: list):
|
|
| 117 |
)
|
| 118 |
yield construct_process_message(
|
| 119 |
message=buffer,
|
| 120 |
-
final_output=gr.update(value=final_state["
|
| 121 |
milestones=final_state["milestones"],
|
| 122 |
reports_output=msg_packet["standardized_reports"],
|
| 123 |
trends_output=msg_packet["trends_json"],
|
|
|
|
| 117 |
)
|
| 118 |
yield construct_process_message(
|
| 119 |
message=buffer,
|
| 120 |
+
final_output=gr.update(value=final_state["pdf_path"], visible=True),
|
| 121 |
milestones=final_state["milestones"],
|
| 122 |
reports_output=msg_packet["standardized_reports"],
|
| 123 |
trends_output=msg_packet["trends_json"],
|