Upload folder using huggingface_hub
Browse files- .gitignore +1 -0
- graph.py +78 -17
- modules/db.py +2 -1
.gitignore
CHANGED
|
@@ -14,3 +14,4 @@ output/
|
|
| 14 |
data/jg_2/
|
| 15 |
app.py
|
| 16 |
data/vasudevan/
|
|
|
|
|
|
| 14 |
data/jg_2/
|
| 15 |
app.py
|
| 16 |
data/vasudevan/
|
| 17 |
+
data/Srinivas/
|
graph.py
CHANGED
|
@@ -309,6 +309,22 @@ async def fn_unit_normalizer_node(state: SheamiState):
|
|
| 309 |
return state
|
| 310 |
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
async def fn_trends_aggregator_node(state: SheamiState):
|
| 313 |
logger.info("%s| Aggregating Trends : started", state["thread_id"])
|
| 314 |
state["messages"].append("Aggregating Trends : started")
|
|
@@ -434,6 +450,8 @@ async def fn_interpreter_node(state: SheamiState):
|
|
| 434 |
indent=1,
|
| 435 |
)
|
| 436 |
|
|
|
|
|
|
|
| 437 |
# 1. LLM narrative
|
| 438 |
messages = [
|
| 439 |
SystemMessage(
|
|
@@ -444,12 +462,63 @@ async def fn_interpreter_node(state: SheamiState):
|
|
| 444 |
"2. The names of the reports used to summarize this information."
|
| 445 |
"3. Patient summary (patient id, name, age, sex if available)"
|
| 446 |
"4. Test window (mention the from and to dates)"
|
| 447 |
-
"
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
)
|
| 454 |
),
|
| 455 |
HumanMessage(content=llm_input),
|
|
@@ -544,16 +613,6 @@ async def fn_final_cleanup_node(state: SheamiState):
|
|
| 544 |
|
| 545 |
await get_db().update_run_stats(run_id=state["run_id"], status="completed")
|
| 546 |
|
| 547 |
-
## add parsed reports
|
| 548 |
-
report_id_list = await get_db().add_report_v2(
|
| 549 |
-
patient_id=state["patient_id"], reports=state["standardized_reports"]
|
| 550 |
-
)
|
| 551 |
-
state["report_id_list"] = report_id_list
|
| 552 |
-
|
| 553 |
-
logger.info("report_id_list = %s", report_id_list)
|
| 554 |
-
for report_id in report_id_list.split(","):
|
| 555 |
-
await get_db().aggregate_trends_from_report(state["patient_id"], report_id)
|
| 556 |
-
|
| 557 |
# add final report
|
| 558 |
# Save PDF along with metadata
|
| 559 |
with open(pdf_path, "rb") as f:
|
|
@@ -642,6 +701,7 @@ def create_graph(user_email: str, patient_id: str, thread_id: str):
|
|
| 642 |
workflow.add_node("increment_index", fn_increment_index_node)
|
| 643 |
workflow.add_node("testname_standardizer", fn_testname_standardizer_node)
|
| 644 |
workflow.add_node("unit_normalizer", fn_unit_normalizer_node)
|
|
|
|
| 645 |
workflow.add_node("trends", fn_trends_aggregator_node)
|
| 646 |
workflow.add_node("interpreter", fn_interpreter_node)
|
| 647 |
|
|
@@ -672,7 +732,8 @@ def create_graph(user_email: str, patient_id: str, thread_id: str):
|
|
| 672 |
workflow.add_edge("testname_standardizer_notifier", "testname_standardizer")
|
| 673 |
workflow.add_edge("testname_standardizer", "unit_normalizer_notifier")
|
| 674 |
workflow.add_edge("unit_normalizer_notifier", "unit_normalizer")
|
| 675 |
-
workflow.add_edge("unit_normalizer", "
|
|
|
|
| 676 |
workflow.add_edge("trends_notifier", "trends")
|
| 677 |
workflow.add_edge("trends", "interpreter_notifier")
|
| 678 |
workflow.add_edge("interpreter_notifier", "interpreter")
|
|
|
|
| 309 |
return state
|
| 310 |
|
| 311 |
|
| 312 |
+
async def fn_db_update_node(state: SheamiState):
|
| 313 |
+
## add parsed reports
|
| 314 |
+
report_id_list = await get_db().add_report_v2(
|
| 315 |
+
patient_id=state["patient_id"],
|
| 316 |
+
reports=state["standardized_reports"],
|
| 317 |
+
run_id=state["run_id"],
|
| 318 |
+
)
|
| 319 |
+
state["report_id_list"] = report_id_list
|
| 320 |
+
|
| 321 |
+
logger.info("report_id_list = %s", report_id_list)
|
| 322 |
+
for report_id in report_id_list.split(","):
|
| 323 |
+
await get_db().aggregate_trends_from_report(state["patient_id"], report_id)
|
| 324 |
+
|
| 325 |
+
return state
|
| 326 |
+
|
| 327 |
+
|
| 328 |
async def fn_trends_aggregator_node(state: SheamiState):
|
| 329 |
logger.info("%s| Aggregating Trends : started", state["thread_id"])
|
| 330 |
state["messages"].append("Aggregating Trends : started")
|
|
|
|
| 450 |
indent=1,
|
| 451 |
)
|
| 452 |
|
| 453 |
+
# logger.info("llm_input = %s", llm_input)
|
| 454 |
+
|
| 455 |
# 1. LLM narrative
|
| 456 |
messages = [
|
| 457 |
SystemMessage(
|
|
|
|
| 462 |
"2. The names of the reports used to summarize this information."
|
| 463 |
"3. Patient summary (patient id, name, age, sex if available)"
|
| 464 |
"4. Test window (mention the from and to dates)"
|
| 465 |
+
"""
|
| 466 |
+
5. Trend summaries
|
| 467 |
+
Generate tables with the following columns:
|
| 468 |
+
|
| 469 |
+
- Test Name
|
| 470 |
+
- Latest Value 1, Latest Value 2, Latest Value 3 (use a hyphen "–" if a value is missing)
|
| 471 |
+
- Unit
|
| 472 |
+
- Reference Range
|
| 473 |
+
- Inference (latest value only): ✅ if within normal range, ▲ if above normal (high), ▼ if below normal (low)
|
| 474 |
+
- Trend Direction (across last 3 values): ⬆️ if values are rising, ⬇️ if values are falling, ➖ (or ✅) if stable/normal
|
| 475 |
+
"""
|
| 476 |
+
"6. Clinical insights. \n"
|
| 477 |
+
"\nImportant Rules:\n"
|
| 478 |
+
"- Format tables in proper <table> with <tr>, <th>, <td>. "
|
| 479 |
+
"- Do not include charts, they will be programmatically added."
|
| 480 |
+
"""
|
| 481 |
+
5. Trend summaries
|
| 482 |
+
Generate HTML tables with the following structure and formatting rules:
|
| 483 |
+
|
| 484 |
+
Columns:
|
| 485 |
+
- Test Name
|
| 486 |
+
- Latest Value 1, Latest Value 2, Latest Value 3 (use a hyphen "–" if a value is missing)
|
| 487 |
+
- Unit
|
| 488 |
+
- Reference Range
|
| 489 |
+
- Inference (latest value only): ✅ if within normal range, ▲ if above normal (high), ▼ if below normal (low)
|
| 490 |
+
- Trend Direction (across last 3 values): ⬆️ if values are rising, ⬇️ if values are falling, ➖ (or ✅) if stable/normal
|
| 491 |
+
|
| 492 |
+
Formatting requirements:
|
| 493 |
+
- The HTML will be shown in a UI (`gr.HTML`) and also rendered to PDF via WeasyPrint.
|
| 494 |
+
- The table must ALWAYS fit within 100% of the container width. Do not allow horizontal scrolling, clipping, or overlapping columns.
|
| 495 |
+
- Use `table-layout: fixed;` and `<colgroup>` with percentage widths that sum to 100%.
|
| 496 |
+
- Allow text wrapping inside cells so narrow columns still display all content.
|
| 497 |
+
- Example CSS to embed at the top of the HTML:
|
| 498 |
+
|
| 499 |
+
<style>
|
| 500 |
+
table { width: 100%; border-collapse: collapse; table-layout: fixed; }
|
| 501 |
+
col { }
|
| 502 |
+
th, td {
|
| 503 |
+
font-size: 11px;
|
| 504 |
+
padding: 4px 6px;
|
| 505 |
+
white-space: normal;
|
| 506 |
+
word-break: break-word;
|
| 507 |
+
}
|
| 508 |
+
</style>
|
| 509 |
+
|
| 510 |
+
- Example `<colgroup>` (adjust if needed):
|
| 511 |
+
<colgroup>
|
| 512 |
+
<col style="width:20%"> <!-- Test Name -->
|
| 513 |
+
<col style="width:8%"> <!-- Latest Value 1 -->
|
| 514 |
+
<col style="width:8%"> <!-- Latest Value 2 -->
|
| 515 |
+
<col style="width:8%"> <!-- Latest Value 3 -->
|
| 516 |
+
<col style="width:8%"> <!-- Unit -->
|
| 517 |
+
<col style="width:16%"> <!-- Reference Range -->
|
| 518 |
+
<col style="width:16%"> <!-- Inference -->
|
| 519 |
+
<col style="width:16%"> <!-- Trend Direction -->
|
| 520 |
+
</colgroup>
|
| 521 |
+
"""
|
| 522 |
)
|
| 523 |
),
|
| 524 |
HumanMessage(content=llm_input),
|
|
|
|
| 613 |
|
| 614 |
await get_db().update_run_stats(run_id=state["run_id"], status="completed")
|
| 615 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
# add final report
|
| 617 |
# Save PDF along with metadata
|
| 618 |
with open(pdf_path, "rb") as f:
|
|
|
|
| 701 |
workflow.add_node("increment_index", fn_increment_index_node)
|
| 702 |
workflow.add_node("testname_standardizer", fn_testname_standardizer_node)
|
| 703 |
workflow.add_node("unit_normalizer", fn_unit_normalizer_node)
|
| 704 |
+
workflow.add_node("db_update_node", fn_db_update_node)
|
| 705 |
workflow.add_node("trends", fn_trends_aggregator_node)
|
| 706 |
workflow.add_node("interpreter", fn_interpreter_node)
|
| 707 |
|
|
|
|
| 732 |
workflow.add_edge("testname_standardizer_notifier", "testname_standardizer")
|
| 733 |
workflow.add_edge("testname_standardizer", "unit_normalizer_notifier")
|
| 734 |
workflow.add_edge("unit_normalizer_notifier", "unit_normalizer")
|
| 735 |
+
workflow.add_edge("unit_normalizer", "db_update_node")
|
| 736 |
+
workflow.add_edge("db_update_node", "trends_notifier")
|
| 737 |
workflow.add_edge("trends_notifier", "trends")
|
| 738 |
workflow.add_edge("trends", "interpreter_notifier")
|
| 739 |
workflow.add_edge("interpreter_notifier", "interpreter")
|
modules/db.py
CHANGED
|
@@ -93,7 +93,7 @@ class SheamiDB:
|
|
| 93 |
# REPORT FUNCTIONS
|
| 94 |
# ---------------------------
|
| 95 |
async def add_report_v2(
|
| 96 |
-
self, patient_id: str, reports: list[StandardizedReport]
|
| 97 |
) -> str:
|
| 98 |
inserted_ids: list[ObjectId] = []
|
| 99 |
for parsed_data in reports:
|
|
@@ -102,6 +102,7 @@ class SheamiDB:
|
|
| 102 |
"uploaded_at": datetime.now(timezone.utc),
|
| 103 |
"file_name": parsed_data.original_report_file_name,
|
| 104 |
"parsed_data_v2": parsed_data.model_dump(),
|
|
|
|
| 105 |
}
|
| 106 |
result = await self.reports.insert_one(report)
|
| 107 |
inserted_ids.append(result.inserted_id)
|
|
|
|
| 93 |
# REPORT FUNCTIONS
|
| 94 |
# ---------------------------
|
| 95 |
async def add_report_v2(
|
| 96 |
+
self, patient_id: str, reports: list[StandardizedReport], run_id: str
|
| 97 |
) -> str:
|
| 98 |
inserted_ids: list[ObjectId] = []
|
| 99 |
for parsed_data in reports:
|
|
|
|
| 102 |
"uploaded_at": datetime.now(timezone.utc),
|
| 103 |
"file_name": parsed_data.original_report_file_name,
|
| 104 |
"parsed_data_v2": parsed_data.model_dump(),
|
| 105 |
+
"run_id" : ObjectId(run_id),
|
| 106 |
}
|
| 107 |
result = await self.reports.insert_one(report)
|
| 108 |
inserted_ids.append(result.inserted_id)
|