vikramvasudevan committed on
Commit
ea96b05
·
verified ·
1 Parent(s): 52308e4

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. graph.py +78 -17
  3. modules/db.py +2 -1
.gitignore CHANGED
@@ -14,3 +14,4 @@ output/
14
  data/jg_2/
15
  app.py
16
  data/vasudevan/
 
 
14
  data/jg_2/
15
  app.py
16
  data/vasudevan/
17
+ data/Srinivas/
graph.py CHANGED
@@ -309,6 +309,22 @@ async def fn_unit_normalizer_node(state: SheamiState):
309
  return state
310
 
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  async def fn_trends_aggregator_node(state: SheamiState):
313
  logger.info("%s| Aggregating Trends : started", state["thread_id"])
314
  state["messages"].append("Aggregating Trends : started")
@@ -434,6 +450,8 @@ async def fn_interpreter_node(state: SheamiState):
434
  indent=1,
435
  )
436
 
 
 
437
  # 1. LLM narrative
438
  messages = [
439
  SystemMessage(
@@ -444,12 +462,63 @@ async def fn_interpreter_node(state: SheamiState):
444
  "2. The names of the reports used to summarize this information."
445
  "3. Patient summary (patient id, name, age, sex if available)"
446
  "4. Test window (mention the from and to dates)"
447
- "5. Trend summaries (tables with Test Name, Latest Value, Highest Value, Lowest Value, Unit, Reference Range, Trend Direction and Inference) "
448
- "6. Clinical insights. "
449
- "For inference column, use for normal, ▲ for high, and ▼ for low. "
450
- "For trend direction, use appropriate unicode icons like up arrow (improving trend) , down arrow (worsening trend) or checkmark if determined normal"
451
- "Format tables in proper <table> with <tr>, <th>, <td>. "
452
- "Do not include charts, they will be programmatically added."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  )
454
  ),
455
  HumanMessage(content=llm_input),
@@ -544,16 +613,6 @@ async def fn_final_cleanup_node(state: SheamiState):
544
 
545
  await get_db().update_run_stats(run_id=state["run_id"], status="completed")
546
 
547
- ## add parsed reports
548
- report_id_list = await get_db().add_report_v2(
549
- patient_id=state["patient_id"], reports=state["standardized_reports"]
550
- )
551
- state["report_id_list"] = report_id_list
552
-
553
- logger.info("report_id_list = %s", report_id_list)
554
- for report_id in report_id_list.split(","):
555
- await get_db().aggregate_trends_from_report(state["patient_id"], report_id)
556
-
557
  # add final report
558
  # Save PDF along with metadata
559
  with open(pdf_path, "rb") as f:
@@ -642,6 +701,7 @@ def create_graph(user_email: str, patient_id: str, thread_id: str):
642
  workflow.add_node("increment_index", fn_increment_index_node)
643
  workflow.add_node("testname_standardizer", fn_testname_standardizer_node)
644
  workflow.add_node("unit_normalizer", fn_unit_normalizer_node)
 
645
  workflow.add_node("trends", fn_trends_aggregator_node)
646
  workflow.add_node("interpreter", fn_interpreter_node)
647
 
@@ -672,7 +732,8 @@ def create_graph(user_email: str, patient_id: str, thread_id: str):
672
  workflow.add_edge("testname_standardizer_notifier", "testname_standardizer")
673
  workflow.add_edge("testname_standardizer", "unit_normalizer_notifier")
674
  workflow.add_edge("unit_normalizer_notifier", "unit_normalizer")
675
- workflow.add_edge("unit_normalizer", "trends_notifier")
 
676
  workflow.add_edge("trends_notifier", "trends")
677
  workflow.add_edge("trends", "interpreter_notifier")
678
  workflow.add_edge("interpreter_notifier", "interpreter")
 
309
  return state
310
 
311
 
312
+ async def fn_db_update_node(state: SheamiState):
313
+ ## add parsed reports
314
+ report_id_list = await get_db().add_report_v2(
315
+ patient_id=state["patient_id"],
316
+ reports=state["standardized_reports"],
317
+ run_id=state["run_id"],
318
+ )
319
+ state["report_id_list"] = report_id_list
320
+
321
+ logger.info("report_id_list = %s", report_id_list)
322
+ for report_id in report_id_list.split(","):
323
+ await get_db().aggregate_trends_from_report(state["patient_id"], report_id)
324
+
325
+ return state
326
+
327
+
328
  async def fn_trends_aggregator_node(state: SheamiState):
329
  logger.info("%s| Aggregating Trends : started", state["thread_id"])
330
  state["messages"].append("Aggregating Trends : started")
 
450
  indent=1,
451
  )
452
 
453
+ # logger.info("llm_input = %s", llm_input)
454
+
455
  # 1. LLM narrative
456
  messages = [
457
  SystemMessage(
 
462
  "2. The names of the reports used to summarize this information."
463
  "3. Patient summary (patient id, name, age, sex if available)"
464
  "4. Test window (mention the from and to dates)"
465
+ """
466
+ 5. Trend summaries
467
+ Generate tables with the following columns:
468
+
469
+ - Test Name
470
+ - Latest Value 1, Latest Value 2, Latest Value 3 (use a hyphen "–" if a value is missing)
471
+ - Unit
472
+ - Reference Range
473
+ - Inference (latest value only): ✅ if within normal range, ▲ if above normal (high), ▼ if below normal (low)
474
+ - Trend Direction (across last 3 values): ⬆️ if values are rising, ⬇️ if values are falling, ➖ (or ✅) if stable/normal
475
+ """
476
+ "6. Clinical insights. \n"
477
+ "\nImportant Rules:\n"
478
+ "- Format tables in proper <table> with <tr>, <th>, <td>. "
479
+ "- Do not include charts, they will be programmatically added."
480
+ """
481
+ 5. Trend summaries
482
+ Generate HTML tables with the following structure and formatting rules:
483
+
484
+ Columns:
485
+ - Test Name
486
+ - Latest Value 1, Latest Value 2, Latest Value 3 (use a hyphen "–" if a value is missing)
487
+ - Unit
488
+ - Reference Range
489
+ - Inference (latest value only): ✅ if within normal range, ▲ if above normal (high), ▼ if below normal (low)
490
+ - Trend Direction (across last 3 values): ⬆️ if values are rising, ⬇️ if values are falling, ➖ (or ✅) if stable/normal
491
+
492
+ Formatting requirements:
493
+ - The HTML will be shown in a UI (`gr.HTML`) and also rendered to PDF via WeasyPrint.
494
+ - The table must ALWAYS fit within 100% of the container width. Do not allow horizontal scrolling, clipping, or overlapping columns.
495
+ - Use `table-layout: fixed;` and `<colgroup>` with percentage widths that sum to 100%.
496
+ - Allow text wrapping inside cells so narrow columns still display all content.
497
+ - Example CSS to embed at the top of the HTML:
498
+
499
+ <style>
500
+ table { width: 100%; border-collapse: collapse; table-layout: fixed; }
501
+ col { }
502
+ th, td {
503
+ font-size: 11px;
504
+ padding: 4px 6px;
505
+ white-space: normal;
506
+ word-break: break-word;
507
+ }
508
+ </style>
509
+
510
+ - Example `<colgroup>` (adjust if needed):
511
+ <colgroup>
512
+ <col style="width:20%"> <!-- Test Name -->
513
+ <col style="width:8%"> <!-- Latest Value 1 -->
514
+ <col style="width:8%"> <!-- Latest Value 2 -->
515
+ <col style="width:8%"> <!-- Latest Value 3 -->
516
+ <col style="width:8%"> <!-- Unit -->
517
+ <col style="width:16%"> <!-- Reference Range -->
518
+ <col style="width:16%"> <!-- Inference -->
519
+ <col style="width:16%"> <!-- Trend Direction -->
520
+ </colgroup>
521
+ """
522
  )
523
  ),
524
  HumanMessage(content=llm_input),
 
613
 
614
  await get_db().update_run_stats(run_id=state["run_id"], status="completed")
615
 
 
 
 
 
 
 
 
 
 
 
616
  # add final report
617
  # Save PDF along with metadata
618
  with open(pdf_path, "rb") as f:
 
701
  workflow.add_node("increment_index", fn_increment_index_node)
702
  workflow.add_node("testname_standardizer", fn_testname_standardizer_node)
703
  workflow.add_node("unit_normalizer", fn_unit_normalizer_node)
704
+ workflow.add_node("db_update_node", fn_db_update_node)
705
  workflow.add_node("trends", fn_trends_aggregator_node)
706
  workflow.add_node("interpreter", fn_interpreter_node)
707
 
 
732
  workflow.add_edge("testname_standardizer_notifier", "testname_standardizer")
733
  workflow.add_edge("testname_standardizer", "unit_normalizer_notifier")
734
  workflow.add_edge("unit_normalizer_notifier", "unit_normalizer")
735
+ workflow.add_edge("unit_normalizer", "db_update_node")
736
+ workflow.add_edge("db_update_node", "trends_notifier")
737
  workflow.add_edge("trends_notifier", "trends")
738
  workflow.add_edge("trends", "interpreter_notifier")
739
  workflow.add_edge("interpreter_notifier", "interpreter")
modules/db.py CHANGED
@@ -93,7 +93,7 @@ class SheamiDB:
93
  # REPORT FUNCTIONS
94
  # ---------------------------
95
  async def add_report_v2(
96
- self, patient_id: str, reports: list[StandardizedReport]
97
  ) -> str:
98
  inserted_ids: list[ObjectId] = []
99
  for parsed_data in reports:
@@ -102,6 +102,7 @@ class SheamiDB:
102
  "uploaded_at": datetime.now(timezone.utc),
103
  "file_name": parsed_data.original_report_file_name,
104
  "parsed_data_v2": parsed_data.model_dump(),
 
105
  }
106
  result = await self.reports.insert_one(report)
107
  inserted_ids.append(result.inserted_id)
 
93
  # REPORT FUNCTIONS
94
  # ---------------------------
95
  async def add_report_v2(
96
+ self, patient_id: str, reports: list[StandardizedReport], run_id: str
97
  ) -> str:
98
  inserted_ids: list[ObjectId] = []
99
  for parsed_data in reports:
 
102
  "uploaded_at": datetime.now(timezone.utc),
103
  "file_name": parsed_data.original_report_file_name,
104
  "parsed_data_v2": parsed_data.model_dump(),
105
+ "run_id" : ObjectId(run_id),
106
  }
107
  result = await self.reports.insert_one(report)
108
  inserted_ids.append(result.inserted_id)