Upload folder using huggingface_hub
Browse files
graph.py
CHANGED
|
@@ -94,14 +94,23 @@ testname_standardizer_chain = testname_standardizer_prompt | llm
|
|
| 94 |
# -----------------------------
|
| 95 |
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
async def fn_init_node(state: SheamiState):
|
| 98 |
os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
|
| 99 |
if "messages" not in state:
|
| 100 |
state["messages"] = []
|
| 101 |
-
state
|
| 102 |
-
state
|
| 103 |
for idx, report in enumerate(state["uploaded_reports"]):
|
| 104 |
-
|
| 105 |
state["standardized_reports"] = []
|
| 106 |
state["trends_json"] = {}
|
| 107 |
state["pdf_path"] = ""
|
|
@@ -124,7 +133,7 @@ async def fn_init_node(state: SheamiState):
|
|
| 124 |
],
|
| 125 |
)
|
| 126 |
state["run_id"] = run_id
|
| 127 |
-
|
| 128 |
return state
|
| 129 |
|
| 130 |
|
|
@@ -215,12 +224,14 @@ async def fn_standardize_current_report_node(state: SheamiState):
|
|
| 215 |
logger.info(
|
| 216 |
"%s| Standardizing report %s", state["thread_id"], report.report_file_name
|
| 217 |
)
|
| 218 |
-
|
| 219 |
|
| 220 |
result = await call_llm(report=report, ocr=False)
|
| 221 |
if not result.lab_results:
|
| 222 |
-
|
| 223 |
-
|
|
|
|
|
|
|
| 224 |
)
|
| 225 |
report.report_contents = pdf_to_text_ocr(
|
| 226 |
pdf_path=report.report_file_name_with_path
|
|
@@ -236,13 +247,23 @@ async def fn_standardize_current_report_node(state: SheamiState):
|
|
| 236 |
)
|
| 237 |
result = await call_llm(report=report, ocr=True)
|
| 238 |
if not result.lab_results:
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
| 241 |
)
|
| 242 |
else:
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
| 245 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
state["standardized_reports"].append(result)
|
| 248 |
|
|
@@ -263,12 +284,13 @@ async def fn_standardize_current_report_node(state: SheamiState):
|
|
| 263 |
def fn_is_report_available_to_process(state: SheamiState) -> str:
|
| 264 |
if state["current_index"] < len(state["uploaded_reports"]):
|
| 265 |
report = state["uploaded_reports"][state["current_index"]]
|
| 266 |
-
|
| 267 |
-
|
|
|
|
| 268 |
)
|
| 269 |
return "continue"
|
| 270 |
else:
|
| 271 |
-
state
|
| 272 |
return "done"
|
| 273 |
|
| 274 |
|
|
@@ -289,7 +311,7 @@ def get_unique_test_names(state: SheamiState):
|
|
| 289 |
|
| 290 |
async def fn_testname_standardizer_node(state: SheamiState):
|
| 291 |
logger.info("%s| Standardizing Test Names: started", state["thread_id"])
|
| 292 |
-
state
|
| 293 |
|
| 294 |
# collect unique names
|
| 295 |
unique_names = get_unique_test_names(state)
|
|
@@ -321,14 +343,16 @@ async def fn_testname_standardizer_node(state: SheamiState):
|
|
| 321 |
)
|
| 322 |
|
| 323 |
logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
| 326 |
return state
|
| 327 |
|
| 328 |
|
| 329 |
async def fn_unit_normalizer_node(state: SheamiState):
|
| 330 |
logger.info("%s| Standardizing Units : started", state["thread_id"])
|
| 331 |
-
state
|
| 332 |
"""
|
| 333 |
Normalize units for lab test values across all standardized reports.
|
| 334 |
Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'
|
|
@@ -355,7 +379,7 @@ async def fn_unit_normalizer_node(state: SheamiState):
|
|
| 355 |
sub.test_unit = unit_map.get(normalized, sub.test_unit)
|
| 356 |
|
| 357 |
logger.info("%s| Standardizing Units : finished", state["thread_id"])
|
| 358 |
-
state
|
| 359 |
return state
|
| 360 |
|
| 361 |
|
|
@@ -377,7 +401,7 @@ async def fn_db_update_node(state: SheamiState):
|
|
| 377 |
|
| 378 |
async def fn_trends_aggregator_node(state: SheamiState):
|
| 379 |
logger.info("%s| Aggregating Trends : started", state["thread_id"])
|
| 380 |
-
state
|
| 381 |
|
| 382 |
import re
|
| 383 |
import os
|
|
@@ -431,9 +455,14 @@ async def fn_trends_aggregator_node(state: SheamiState):
|
|
| 431 |
if rr and key not in ref_ranges:
|
| 432 |
ref_ranges[key] = {"min": rr.min, "max": rr.max}
|
| 433 |
|
|
|
|
| 434 |
for idx, report in enumerate(state["standardized_reports"]):
|
| 435 |
logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
for item in report.lab_results:
|
| 439 |
# Case A: CompositeLabResult (e.g., CUE, LFT, etc.)
|
|
@@ -475,13 +504,13 @@ async def fn_trends_aggregator_node(state: SheamiState):
|
|
| 475 |
json.dump(state["trends_json"], f, indent=1, ensure_ascii=False)
|
| 476 |
|
| 477 |
logger.info("%s| Aggregating Trends : finished", state["thread_id"])
|
| 478 |
-
state
|
| 479 |
return state
|
| 480 |
|
| 481 |
|
| 482 |
async def fn_interpreter_node(state: SheamiState):
|
| 483 |
logger.info("%s| Interpreting Trends : started", state["thread_id"])
|
| 484 |
-
state
|
| 485 |
|
| 486 |
uploaded_reports = await get_db().get_reports_by_patient(
|
| 487 |
patient_id=state["patient_id"]
|
|
@@ -645,7 +674,7 @@ Formatting requirements:
|
|
| 645 |
state["pdf_path"] = pdf_path
|
| 646 |
state["interpretation_html"] = interpretation_html
|
| 647 |
logger.info("%s| Interpreting Trends : finished", state["thread_id"])
|
| 648 |
-
state
|
| 649 |
|
| 650 |
return state
|
| 651 |
|
|
@@ -703,8 +732,8 @@ def schedule_cleanup(file_path, delay=300): # 300 sec = 5 min
|
|
| 703 |
async def fn_standardizer_node_notifier(state: SheamiState):
|
| 704 |
state = await reset_process_desc(state, process_desc="Standardizing reports ...")
|
| 705 |
state["units_total"] = len(state["uploaded_reports"])
|
| 706 |
-
|
| 707 |
-
"Standardizing reports now ... this might take a while ..."
|
| 708 |
)
|
| 709 |
state["overall_units_processed"] += 1
|
| 710 |
return state
|
|
@@ -712,28 +741,28 @@ async def fn_standardizer_node_notifier(state: SheamiState):
|
|
| 712 |
|
| 713 |
async def fn_testname_standardizer_node_notifier(state: SheamiState):
|
| 714 |
state = await reset_process_desc(state, process_desc="Standardizing test names ...")
|
| 715 |
-
state
|
| 716 |
state["overall_units_processed"] += 1
|
| 717 |
return state
|
| 718 |
|
| 719 |
|
| 720 |
async def fn_unit_normalizer_node_notifier(state: SheamiState):
|
| 721 |
state = await reset_process_desc(state, process_desc="Standardizing units ...")
|
| 722 |
-
state
|
| 723 |
state["overall_units_processed"] += 1
|
| 724 |
return state
|
| 725 |
|
| 726 |
|
| 727 |
async def fn_trends_aggregator_node_notifier(state: SheamiState):
|
| 728 |
state = await reset_process_desc(state, process_desc="Aggregating trends ...")
|
| 729 |
-
state
|
| 730 |
state["overall_units_processed"] += 1
|
| 731 |
return state
|
| 732 |
|
| 733 |
|
| 734 |
async def fn_interpreter_node_notifier(state: SheamiState):
|
| 735 |
state = await reset_process_desc(state, process_desc="Plotting trends ...")
|
| 736 |
-
state
|
| 737 |
state["overall_units_processed"] += 1
|
| 738 |
return state
|
| 739 |
|
|
|
|
| 94 |
# -----------------------------
|
| 95 |
|
| 96 |
|
| 97 |
+
def send_message(state: SheamiState, msg: str, append: bool = True):
    """Record a progress message on ``state["messages"]``.

    The list is mutated in place and created if the key is missing.

    Args:
        state: Workflow state mapping that carries the ``"messages"`` list.
        msg: Message text to record.
        append: When true, add ``msg`` as a new entry. When false, overwrite
            the most recent entry instead. If there is no entry to overwrite
            yet, ``msg`` is appended so the call never raises ``IndexError``.
    """
    messages = state.setdefault("messages", [])
    if append or not messages:
        # Nothing to replace on an empty list, so fall back to appending.
        messages.append(msg)
    else:
        # Replace the last message (used for in-place progress updates).
        messages[-1] = msg
|
| 104 |
+
|
| 105 |
+
|
| 106 |
async def fn_init_node(state: SheamiState):
|
| 107 |
os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
|
| 108 |
if "messages" not in state:
|
| 109 |
state["messages"] = []
|
| 110 |
+
send_message(state=state, msg="Initializing ...")
|
| 111 |
+
send_message(state=state, msg="Files received for processing ...", append=False)
|
| 112 |
for idx, report in enumerate(state["uploaded_reports"]):
|
| 113 |
+
send_message(state=state, msg=f"{idx+1}. <span class='highlighted-text'>{report.report_file_name}</span>")
|
| 114 |
state["standardized_reports"] = []
|
| 115 |
state["trends_json"] = {}
|
| 116 |
state["pdf_path"] = ""
|
|
|
|
| 133 |
],
|
| 134 |
)
|
| 135 |
state["run_id"] = run_id
|
| 136 |
+
send_message(state=state, msg=f"Initialized run [{run_id}]")
|
| 137 |
return state
|
| 138 |
|
| 139 |
|
|
|
|
| 224 |
logger.info(
|
| 225 |
"%s| Standardizing report %s", state["thread_id"], report.report_file_name
|
| 226 |
)
|
| 227 |
+
send_message(state=state, msg=f"Standardizing report: {report.report_file_name}", append=False)
|
| 228 |
|
| 229 |
result = await call_llm(report=report, ocr=False)
|
| 230 |
if not result.lab_results:
|
| 231 |
+
send_message(
|
| 232 |
+
state=state,
|
| 233 |
+
msg=f"⛔ Could not extract any data from PDF : {report.report_file_name}. Trying OCR ... might take a while",
|
| 234 |
+
append=False,
|
| 235 |
)
|
| 236 |
report.report_contents = pdf_to_text_ocr(
|
| 237 |
pdf_path=report.report_file_name_with_path
|
|
|
|
| 247 |
)
|
| 248 |
result = await call_llm(report=report, ocr=True)
|
| 249 |
if not result.lab_results:
|
| 250 |
+
send_message(
|
| 251 |
+
state=state,
|
| 252 |
+
msg=f"⛔ OCR couldn't extract : {report.report_file_name}.",
|
| 253 |
+
append=False,
|
| 254 |
)
|
| 255 |
else:
|
| 256 |
+
send_message(
|
| 257 |
+
state=state,
|
| 258 |
+
msg=f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results using OCR for report : <span class='highlighted-text'>{report.report_file_name}</span>.",
|
| 259 |
+
append=False,
|
| 260 |
)
|
| 261 |
+
else:
|
| 262 |
+
send_message(
|
| 263 |
+
state=state,
|
| 264 |
+
msg=f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results from : <span class='highlighted-text'>{report.report_file_name}</span>.",
|
| 265 |
+
append=False,
|
| 266 |
+
)
|
| 267 |
|
| 268 |
state["standardized_reports"].append(result)
|
| 269 |
|
|
|
|
| 284 |
def fn_is_report_available_to_process(state: SheamiState) -> str:
    """Tell the caller whether another uploaded report is waiting.

    Returns ``"continue"`` while ``state["current_index"]`` is below the
    number of entries in ``state["uploaded_reports"]``, otherwise ``"done"``.
    In both cases a progress message is pushed through ``send_message``.
    """
    position = state["current_index"]
    pending = state["uploaded_reports"]

    # Guard clause: every uploaded report has already been handled.
    if position >= len(pending):
        send_message(state=state, msg="Standardizing reports: finished")
        return "done"

    file_name = pending[position].report_file_name
    announcement = (
        "Initiating report standardization for: "
        f"<span class='highlighted-text'>{file_name}</span>"
    )
    send_message(state=state, msg=announcement)
    return "continue"
|
| 295 |
|
| 296 |
|
|
|
|
| 311 |
|
| 312 |
async def fn_testname_standardizer_node(state: SheamiState):
|
| 313 |
logger.info("%s| Standardizing Test Names: started", state["thread_id"])
|
| 314 |
+
send_message(state=state, msg="Standardizing Test Names: started", append=False)
|
| 315 |
|
| 316 |
# collect unique names
|
| 317 |
unique_names = get_unique_test_names(state)
|
|
|
|
| 343 |
)
|
| 344 |
|
| 345 |
logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
|
| 346 |
+
send_message(
|
| 347 |
+
state=state, msg=f"Identified <span class='highlighted-text'>{len(unique_names)}</span> unique tests", append=False
|
| 348 |
+
)
|
| 349 |
+
# send_message(state=state, msg="Standardizing Test Names: finished")
|
| 350 |
return state
|
| 351 |
|
| 352 |
|
| 353 |
async def fn_unit_normalizer_node(state: SheamiState):
|
| 354 |
logger.info("%s| Standardizing Units : started", state["thread_id"])
|
| 355 |
+
send_message(state=state, msg="Standardizing Units: started", append=False)
|
| 356 |
"""
|
| 357 |
Normalize units for lab test values across all standardized reports.
|
| 358 |
Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'
|
|
|
|
| 379 |
sub.test_unit = unit_map.get(normalized, sub.test_unit)
|
| 380 |
|
| 381 |
logger.info("%s| Standardizing Units : finished", state["thread_id"])
|
| 382 |
+
send_message(state=state, msg="Standardizing Units: finished", append=False)
|
| 383 |
return state
|
| 384 |
|
| 385 |
|
|
|
|
| 401 |
|
| 402 |
async def fn_trends_aggregator_node(state: SheamiState):
|
| 403 |
logger.info("%s| Aggregating Trends : started", state["thread_id"])
|
| 404 |
+
send_message(state=state, msg="Aggregating Trends : started", append=False)
|
| 405 |
|
| 406 |
import re
|
| 407 |
import os
|
|
|
|
| 455 |
if rr and key not in ref_ranges:
|
| 456 |
ref_ranges[key] = {"min": rr.min, "max": rr.max}
|
| 457 |
|
| 458 |
+
total_reports = len(state["standardized_reports"])
|
| 459 |
for idx, report in enumerate(state["standardized_reports"]):
|
| 460 |
logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
|
| 461 |
+
send_message(
|
| 462 |
+
state=state,
|
| 463 |
+
msg=f"Aggregating {idx+1}/{total_reports} trends : report-{idx+1}...",
|
| 464 |
+
append=False,
|
| 465 |
+
)
|
| 466 |
|
| 467 |
for item in report.lab_results:
|
| 468 |
# Case A: CompositeLabResult (e.g., CUE, LFT, etc.)
|
|
|
|
| 504 |
json.dump(state["trends_json"], f, indent=1, ensure_ascii=False)
|
| 505 |
|
| 506 |
logger.info("%s| Aggregating Trends : finished", state["thread_id"])
|
| 507 |
+
send_message(state=state, msg="Aggregating Trends : finished", append=False)
|
| 508 |
return state
|
| 509 |
|
| 510 |
|
| 511 |
async def fn_interpreter_node(state: SheamiState):
|
| 512 |
logger.info("%s| Interpreting Trends : started", state["thread_id"])
|
| 513 |
+
send_message(state=state, msg="Interpreting Trends : started", append=False)
|
| 514 |
|
| 515 |
uploaded_reports = await get_db().get_reports_by_patient(
|
| 516 |
patient_id=state["patient_id"]
|
|
|
|
| 674 |
state["pdf_path"] = pdf_path
|
| 675 |
state["interpretation_html"] = interpretation_html
|
| 676 |
logger.info("%s| Interpreting Trends : finished", state["thread_id"])
|
| 677 |
+
send_message(state=state, msg="Interpreting Trends : finished", append=False)
|
| 678 |
|
| 679 |
return state
|
| 680 |
|
|
|
|
| 732 |
async def fn_standardizer_node_notifier(state: SheamiState):
    """Announce the report-standardization stage and bump overall progress.

    Resets the process description through ``reset_process_desc``, sets
    ``units_total`` to the number of uploaded reports, posts a progress
    message and increments ``overall_units_processed``.
    """
    refreshed = await reset_process_desc(
        state, process_desc="Standardizing reports ..."
    )
    refreshed["units_total"] = len(refreshed["uploaded_reports"])
    send_message(
        state=refreshed,
        msg="Standardizing reports now ... this might take a while ...",
    )
    refreshed["overall_units_processed"] += 1
    return refreshed
|
|
|
|
| 741 |
|
| 742 |
async def fn_testname_standardizer_node_notifier(state: SheamiState):
    """Announce the test-name standardization stage and bump overall progress.

    Resets the process description through ``reset_process_desc``, posts a
    progress message and increments ``overall_units_processed``.
    """
    refreshed = await reset_process_desc(
        state, process_desc="Standardizing test names ..."
    )
    send_message(state=refreshed, msg="Standardizing test names now ...")
    refreshed["overall_units_processed"] += 1
    return refreshed
|
| 747 |
|
| 748 |
|
| 749 |
async def fn_unit_normalizer_node_notifier(state: SheamiState):
    """Announce the unit-normalization stage and bump overall progress.

    Resets the process description through ``reset_process_desc``, posts a
    progress message and increments ``overall_units_processed``.
    """
    refreshed = await reset_process_desc(
        state, process_desc="Standardizing units ..."
    )
    send_message(state=refreshed, msg="Standardizing measurement units now ...")
    refreshed["overall_units_processed"] += 1
    return refreshed
|
| 754 |
|
| 755 |
|
| 756 |
async def fn_trends_aggregator_node_notifier(state: SheamiState):
    """Announce the trend-aggregation stage and bump overall progress.

    Resets the process description through ``reset_process_desc``, posts a
    progress message and increments ``overall_units_processed``.
    """
    refreshed = await reset_process_desc(
        state, process_desc="Aggregating trends ..."
    )
    send_message(state=refreshed, msg="Aggregating trends now ...")
    refreshed["overall_units_processed"] += 1
    return refreshed
|
| 761 |
|
| 762 |
|
| 763 |
async def fn_interpreter_node_notifier(state: SheamiState):
    """Announce the interpretation/plotting stage and bump overall progress.

    Resets the process description through ``reset_process_desc``, posts a
    progress message and increments ``overall_units_processed``.
    """
    refreshed = await reset_process_desc(
        state, process_desc="Plotting trends ..."
    )
    send_message(state=refreshed, msg="Interpreting and plotting trends now ...")
    refreshed["overall_units_processed"] += 1
    return refreshed
|
| 768 |
|
ui.py
CHANGED
|
@@ -121,7 +121,7 @@ async def process_reports(user_email: str, patient_id: str, files: list):
|
|
| 121 |
|
| 122 |
buffer += (
|
| 123 |
"\n\n"
|
| 124 |
-
f"✅ Processed {len(files)} reports.\n"
|
| 125 |
"Please download the output file from below within 5 min."
|
| 126 |
)
|
| 127 |
except Exception as e:
|
|
@@ -361,6 +361,10 @@ def handle_file_input_change(files):
|
|
| 361 |
|
| 362 |
def get_css():
|
| 363 |
return """
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
#patient-card{
|
| 365 |
border: 1px solid rgba(0,0,0,0.06);
|
| 366 |
background: #fafafa;
|
|
|
|
| 121 |
|
| 122 |
buffer += (
|
| 123 |
"\n\n"
|
| 124 |
+
f"✅ Processed <span class='highlighted-text'>{len(files)}</span> reports.\n"
|
| 125 |
"Please download the output file from below within 5 min."
|
| 126 |
)
|
| 127 |
except Exception as e:
|
|
|
|
| 361 |
|
| 362 |
def get_css():
|
| 363 |
return """
|
| 364 |
+
.highlighted-text {
|
| 365 |
+
color : lightgray;
|
| 366 |
+
font-style: italic;
|
| 367 |
+
}
|
| 368 |
#patient-card{
|
| 369 |
border: 1px solid rgba(0,0,0,0.06);
|
| 370 |
background: #fafafa;
|