vikramvasudevan committed on
Commit
094e0f4
·
verified ·
1 Parent(s): af357c2

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. graph.py +59 -30
  2. ui.py +5 -1
graph.py CHANGED
@@ -94,14 +94,23 @@ testname_standardizer_chain = testname_standardizer_prompt | llm
94
  # -----------------------------
95
 
96
 
 
 
 
 
 
 
 
 
 
97
  async def fn_init_node(state: SheamiState):
98
  os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
99
  if "messages" not in state:
100
  state["messages"] = []
101
- state["messages"].append("Initializing ...")
102
- state["messages"].append("Files received for processing ...")
103
  for idx, report in enumerate(state["uploaded_reports"]):
104
- state["messages"].append(f"{idx+1}. {report.report_file_name}")
105
  state["standardized_reports"] = []
106
  state["trends_json"] = {}
107
  state["pdf_path"] = ""
@@ -124,7 +133,7 @@ async def fn_init_node(state: SheamiState):
124
  ],
125
  )
126
  state["run_id"] = run_id
127
-
128
  return state
129
 
130
 
@@ -215,12 +224,14 @@ async def fn_standardize_current_report_node(state: SheamiState):
215
  logger.info(
216
  "%s| Standardizing report %s", state["thread_id"], report.report_file_name
217
  )
218
- state["messages"].append(f"Standardizing report: {report.report_file_name}")
219
 
220
  result = await call_llm(report=report, ocr=False)
221
  if not result.lab_results:
222
- state["messages"].append(
223
- f"⛔ Could not extract any data from PDF : {report.report_file_name}. Trying OCR ... might take a while"
 
 
224
  )
225
  report.report_contents = pdf_to_text_ocr(
226
  pdf_path=report.report_file_name_with_path
@@ -236,13 +247,23 @@ async def fn_standardize_current_report_node(state: SheamiState):
236
  )
237
  result = await call_llm(report=report, ocr=True)
238
  if not result.lab_results:
239
- state["messages"].append(
240
- f"⛔ OCR couldn't extract : {report.report_file_name}."
 
 
241
  )
242
  else:
243
- state["messages"].append(
244
- f"✅ Extracted data for report : {report.report_file_name}."
 
 
245
  )
 
 
 
 
 
 
246
 
247
  state["standardized_reports"].append(result)
248
 
@@ -263,12 +284,13 @@ async def fn_standardize_current_report_node(state: SheamiState):
263
  def fn_is_report_available_to_process(state: SheamiState) -> str:
264
  if state["current_index"] < len(state["uploaded_reports"]):
265
  report = state["uploaded_reports"][state["current_index"]]
266
- state["messages"].append(
267
- f"Initiating report standardization for: {report.report_file_name}"
 
268
  )
269
  return "continue"
270
  else:
271
- state["messages"].append("Standardizing reports: finished")
272
  return "done"
273
 
274
 
@@ -289,7 +311,7 @@ def get_unique_test_names(state: SheamiState):
289
 
290
  async def fn_testname_standardizer_node(state: SheamiState):
291
  logger.info("%s| Standardizing Test Names: started", state["thread_id"])
292
- state["messages"].append("Standardizing Test Names: started")
293
 
294
  # collect unique names
295
  unique_names = get_unique_test_names(state)
@@ -321,14 +343,16 @@ async def fn_testname_standardizer_node(state: SheamiState):
321
  )
322
 
323
  logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
324
- state["messages"].append(f"Processed {len(unique_names)} tests")
325
- state["messages"].append("Standardizing Test Names: finished")
 
 
326
  return state
327
 
328
 
329
  async def fn_unit_normalizer_node(state: SheamiState):
330
  logger.info("%s| Standardizing Units : started", state["thread_id"])
331
- state["messages"].append("Standardizing Units: started")
332
  """
333
  Normalize units for lab test values across all standardized reports.
334
  Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'
@@ -355,7 +379,7 @@ async def fn_unit_normalizer_node(state: SheamiState):
355
  sub.test_unit = unit_map.get(normalized, sub.test_unit)
356
 
357
  logger.info("%s| Standardizing Units : finished", state["thread_id"])
358
- state["messages"].append("Standardizing Units: finished")
359
  return state
360
 
361
 
@@ -377,7 +401,7 @@ async def fn_db_update_node(state: SheamiState):
377
 
378
  async def fn_trends_aggregator_node(state: SheamiState):
379
  logger.info("%s| Aggregating Trends : started", state["thread_id"])
380
- state["messages"].append("Aggregating Trends : started")
381
 
382
  import re
383
  import os
@@ -431,9 +455,14 @@ async def fn_trends_aggregator_node(state: SheamiState):
431
  if rr and key not in ref_ranges:
432
  ref_ranges[key] = {"min": rr.min, "max": rr.max}
433
 
 
434
  for idx, report in enumerate(state["standardized_reports"]):
435
  logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
436
- state["messages"].append(f"Aggregating Trends for report-{idx+1}...")
 
 
 
 
437
 
438
  for item in report.lab_results:
439
  # Case A: CompositeLabResult (e.g., CUE, LFT, etc.)
@@ -475,13 +504,13 @@ async def fn_trends_aggregator_node(state: SheamiState):
475
  json.dump(state["trends_json"], f, indent=1, ensure_ascii=False)
476
 
477
  logger.info("%s| Aggregating Trends : finished", state["thread_id"])
478
- state["messages"].append("Aggregating Trends : finished")
479
  return state
480
 
481
 
482
  async def fn_interpreter_node(state: SheamiState):
483
  logger.info("%s| Interpreting Trends : started", state["thread_id"])
484
- state["messages"].append("Interpreting Trends : started")
485
 
486
  uploaded_reports = await get_db().get_reports_by_patient(
487
  patient_id=state["patient_id"]
@@ -645,7 +674,7 @@ Formatting requirements:
645
  state["pdf_path"] = pdf_path
646
  state["interpretation_html"] = interpretation_html
647
  logger.info("%s| Interpreting Trends : finished", state["thread_id"])
648
- state["messages"].append("Interpreting Trends : finished")
649
 
650
  return state
651
 
@@ -703,8 +732,8 @@ def schedule_cleanup(file_path, delay=300): # 300 sec = 5 min
703
  async def fn_standardizer_node_notifier(state: SheamiState):
704
  state = await reset_process_desc(state, process_desc="Standardizing reports ...")
705
  state["units_total"] = len(state["uploaded_reports"])
706
- state["messages"].append(
707
- "Standardizing reports now ... this might take a while ..."
708
  )
709
  state["overall_units_processed"] += 1
710
  return state
@@ -712,28 +741,28 @@ async def fn_standardizer_node_notifier(state: SheamiState):
712
 
713
  async def fn_testname_standardizer_node_notifier(state: SheamiState):
714
  state = await reset_process_desc(state, process_desc="Standardizing test names ...")
715
- state["messages"].append("Standardizing test names now ...")
716
  state["overall_units_processed"] += 1
717
  return state
718
 
719
 
720
  async def fn_unit_normalizer_node_notifier(state: SheamiState):
721
  state = await reset_process_desc(state, process_desc="Standardizing units ...")
722
- state["messages"].append("Standardizing measurement units now ...")
723
  state["overall_units_processed"] += 1
724
  return state
725
 
726
 
727
  async def fn_trends_aggregator_node_notifier(state: SheamiState):
728
  state = await reset_process_desc(state, process_desc="Aggregating trends ...")
729
- state["messages"].append("Aggregating trends now ...")
730
  state["overall_units_processed"] += 1
731
  return state
732
 
733
 
734
  async def fn_interpreter_node_notifier(state: SheamiState):
735
  state = await reset_process_desc(state, process_desc="Plotting trends ...")
736
- state["messages"].append("Interpreting and plotting trends now ...")
737
  state["overall_units_processed"] += 1
738
  return state
739
 
 
94
  # -----------------------------
95
 
96
 
97
def send_message(state: SheamiState, msg: str, append: bool = True):
    """Record a progress message on the shared workflow state.

    Args:
        state: Workflow state; messages live in ``state["messages"]`` and are
            rendered by the UI (messages may contain HTML markup).
        msg: The message text to record.
        append: When True, add ``msg`` as a new entry. When False, overwrite
            the most recent entry in place (used for progress-style updates).
    """
    messages = state["messages"]
    if append or not messages:
        # Append a new message. Also falls back to appending when asked to
        # replace but no message exists yet — otherwise `messages[-1] = msg`
        # would raise IndexError on an empty list.
        messages.append(msg)
    else:
        # Replace the last message in place (in-place progress update).
        messages[-1] = msg
105
+
106
  async def fn_init_node(state: SheamiState):
107
  os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
108
  if "messages" not in state:
109
  state["messages"] = []
110
+ send_message(state=state, msg="Initializing ...")
111
+ send_message(state=state, msg="Files received for processing ...", append=False)
112
  for idx, report in enumerate(state["uploaded_reports"]):
113
+ send_message(state=state, msg=f"{idx+1}. <span class='highlighted-text'>{report.report_file_name}</span>")
114
  state["standardized_reports"] = []
115
  state["trends_json"] = {}
116
  state["pdf_path"] = ""
 
133
  ],
134
  )
135
  state["run_id"] = run_id
136
+ send_message(state=state, msg=f"Initialized run [{run_id}]")
137
  return state
138
 
139
 
 
224
  logger.info(
225
  "%s| Standardizing report %s", state["thread_id"], report.report_file_name
226
  )
227
+ send_message(state=state, msg=f"Standardizing report: {report.report_file_name}", append=False)
228
 
229
  result = await call_llm(report=report, ocr=False)
230
  if not result.lab_results:
231
+ send_message(
232
+ state=state,
233
+ msg=f"⛔ Could not extract any data from PDF : {report.report_file_name}. Trying OCR ... might take a while",
234
+ append=False,
235
  )
236
  report.report_contents = pdf_to_text_ocr(
237
  pdf_path=report.report_file_name_with_path
 
247
  )
248
  result = await call_llm(report=report, ocr=True)
249
  if not result.lab_results:
250
+ send_message(
251
+ state=state,
252
+ msg=f"⛔ OCR couldn't extract : {report.report_file_name}.",
253
+ append=False,
254
  )
255
  else:
256
+ send_message(
257
+ state=state,
258
+ msg=f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results using OCR for report : <span class='highlighted-text'>{report.report_file_name}</span>.",
259
+ append=False,
260
  )
261
+ else:
262
+ send_message(
263
+ state=state,
264
+ msg=f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results from : <span class='highlighted-text'>{report.report_file_name}</span>.",
265
+ append=False,
266
+ )
267
 
268
  state["standardized_reports"].append(result)
269
 
 
284
def fn_is_report_available_to_process(state: SheamiState) -> str:
    """Routing condition for the standardization loop.

    Returns "continue" while there are uploaded reports left to process
    (announcing the next one), or "done" once all have been handled.
    """
    idx = state["current_index"]
    pending = state["uploaded_reports"]
    # Guard clause: nothing left to process — finish the loop.
    if idx >= len(pending):
        send_message(state=state, msg="Standardizing reports: finished")
        return "done"
    report = pending[idx]
    send_message(
        state=state,
        msg=f"Initiating report standardization for: <span class='highlighted-text'>{report.report_file_name}</span>",
    )
    return "continue"
295
 
296
 
 
311
 
312
  async def fn_testname_standardizer_node(state: SheamiState):
313
  logger.info("%s| Standardizing Test Names: started", state["thread_id"])
314
+ send_message(state=state, msg="Standardizing Test Names: started", append=False)
315
 
316
  # collect unique names
317
  unique_names = get_unique_test_names(state)
 
343
  )
344
 
345
  logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
346
+ send_message(
347
+ state=state, msg=f"Identified <span class='highlighted-text'>{len(unique_names)}</span> unique tests", append=False
348
+ )
349
+ # send_message(state=state, msg="Standardizing Test Names: finished")
350
  return state
351
 
352
 
353
  async def fn_unit_normalizer_node(state: SheamiState):
354
  logger.info("%s| Standardizing Units : started", state["thread_id"])
355
+ send_message(state=state, msg="Standardizing Units: started", append=False)
356
  """
357
  Normalize units for lab test values across all standardized reports.
358
  Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'
 
379
  sub.test_unit = unit_map.get(normalized, sub.test_unit)
380
 
381
  logger.info("%s| Standardizing Units : finished", state["thread_id"])
382
+ send_message(state=state, msg="Standardizing Units: finished", append=False)
383
  return state
384
 
385
 
 
401
 
402
  async def fn_trends_aggregator_node(state: SheamiState):
403
  logger.info("%s| Aggregating Trends : started", state["thread_id"])
404
+ send_message(state=state, msg="Aggregating Trends : started", append=False)
405
 
406
  import re
407
  import os
 
455
  if rr and key not in ref_ranges:
456
  ref_ranges[key] = {"min": rr.min, "max": rr.max}
457
 
458
+ total_reports = len(state["standardized_reports"])
459
  for idx, report in enumerate(state["standardized_reports"]):
460
  logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
461
+ send_message(
462
+ state=state,
463
+ msg=f"Aggregating {idx+1}/{total_reports} trends : report-{idx+1}...",
464
+ append=False,
465
+ )
466
 
467
  for item in report.lab_results:
468
  # Case A: CompositeLabResult (e.g., CUE, LFT, etc.)
 
504
  json.dump(state["trends_json"], f, indent=1, ensure_ascii=False)
505
 
506
  logger.info("%s| Aggregating Trends : finished", state["thread_id"])
507
+ send_message(state=state, msg="Aggregating Trends : finished", append=False)
508
  return state
509
 
510
 
511
  async def fn_interpreter_node(state: SheamiState):
512
  logger.info("%s| Interpreting Trends : started", state["thread_id"])
513
+ send_message(state=state, msg="Interpreting Trends : started", append=False)
514
 
515
  uploaded_reports = await get_db().get_reports_by_patient(
516
  patient_id=state["patient_id"]
 
674
  state["pdf_path"] = pdf_path
675
  state["interpretation_html"] = interpretation_html
676
  logger.info("%s| Interpreting Trends : finished", state["thread_id"])
677
+ send_message(state=state, msg="Interpreting Trends : finished", append=False)
678
 
679
  return state
680
 
 
732
async def fn_standardizer_node_notifier(state: SheamiState):
    """Announce the report-standardization phase and advance overall progress.

    Resets the process description, sets units_total to the number of
    uploaded reports, posts a user-facing message, and increments the
    overall progress counter.
    """
    state = await reset_process_desc(state, process_desc="Standardizing reports ...")
    # One "unit" per uploaded report for this phase's progress tracking.
    state["units_total"] = len(state["uploaded_reports"])
    send_message(
        state=state, msg="Standardizing reports now ... this might take a while ..."
    )
    state["overall_units_processed"] += 1
    return state
 
741
 
742
  async def fn_testname_standardizer_node_notifier(state: SheamiState):
743
  state = await reset_process_desc(state, process_desc="Standardizing test names ...")
744
+ send_message(state=state, msg="Standardizing test names now ...")
745
  state["overall_units_processed"] += 1
746
  return state
747
 
748
 
749
  async def fn_unit_normalizer_node_notifier(state: SheamiState):
750
  state = await reset_process_desc(state, process_desc="Standardizing units ...")
751
+ send_message(state=state, msg="Standardizing measurement units now ...")
752
  state["overall_units_processed"] += 1
753
  return state
754
 
755
 
756
  async def fn_trends_aggregator_node_notifier(state: SheamiState):
757
  state = await reset_process_desc(state, process_desc="Aggregating trends ...")
758
+ send_message(state=state, msg="Aggregating trends now ...")
759
  state["overall_units_processed"] += 1
760
  return state
761
 
762
 
763
  async def fn_interpreter_node_notifier(state: SheamiState):
764
  state = await reset_process_desc(state, process_desc="Plotting trends ...")
765
+ send_message(state=state, msg="Interpreting and plotting trends now ...")
766
  state["overall_units_processed"] += 1
767
  return state
768
 
ui.py CHANGED
@@ -121,7 +121,7 @@ async def process_reports(user_email: str, patient_id: str, files: list):
121
 
122
  buffer += (
123
  "\n\n"
124
- f"✅ Processed {len(files)} reports.\n"
125
  "Please download the output file from below within 5 min."
126
  )
127
  except Exception as e:
@@ -361,6 +361,10 @@ def handle_file_input_change(files):
361
 
362
  def get_css():
363
  return """
 
 
 
 
364
  #patient-card{
365
  border: 1px solid rgba(0,0,0,0.06);
366
  background: #fafafa;
 
121
 
122
  buffer += (
123
  "\n\n"
124
+ f"✅ Processed <span class='highlighted-text'>{len(files)}</span> reports.\n"
125
  "Please download the output file from below within 5 min."
126
  )
127
  except Exception as e:
 
361
 
362
  def get_css():
363
  return """
364
+ .highlighted-text {
365
+ color : lightgray;
366
+ font-style: italic;
367
+ }
368
  #patient-card{
369
  border: 1px solid rgba(0,0,0,0.06);
370
  background: #fafafa;