File size: 30,701 Bytes
b7bc475 0cfb077 52308e4 0cfb077 93bdf8b 0cfb077 8675fd6 0cfb077 b1447d4 0cfb077 d76f9b8 2cade03 d76f9b8 2cade03 93bdf8b cc61cf6 0cfb077 cc61cf6 0cfb077 094e0f4 469de1d 0cfb077 275f4fa 094e0f4 7987601 78d30a7 0cfb077 52308e4 fc44b9f 7987601 7d0bec7 b7bc475 7987601 469de1d b1447d4 52308e4 7f0cd9b b1447d4 c4b6dca fc44b9f 469de1d cc61cf6 b7bc475 cc61cf6 469de1d b1447d4 cc61cf6 7987601 469de1d b1447d4 7987601 7d0bec7 fc44b9f 7d0bec7 0cfb077 7f0cd9b 0cfb077 7f0cd9b 1f69723 7f0cd9b fc44b9f 7f0cd9b b1447d4 7f0cd9b b1447d4 fc44b9f 2cade03 7f0cd9b 2cade03 78d30a7 2cade03 7f0cd9b 094e0f4 7f0cd9b 56629cd 7f0cd9b 9e2bd6f 7f0cd9b 094e0f4 7f0cd9b 2cade03 094e0f4 7f0cd9b 094e0f4 93bdf8b fc44b9f 7987601 52308e4 fc44b9f 7987601 0cfb077 7987601 fc44b9f 7987601 fc44b9f 094e0f4 78d30a7 7987601 fc44b9f 094e0f4 fc44b9f 0cfb077 7987601 7d0bec7 93bdf8b 0cfb077 7d0bec7 094e0f4 93bdf8b 7d0bec7 0cfb077 469de1d 0cfb077 93bdf8b 0cfb077 93bdf8b f2bf8a5 93bdf8b 0cfb077 094e0f4 78d30a7 094e0f4 0cfb077 275f4fa 0cfb077 094e0f4 0cfb077 93bdf8b 0cfb077 094e0f4 0cfb077 ea96b05 275f4fa 0cfb077 094e0f4 275f4fa 0cfb077 93bdf8b 0cfb077 93bdf8b b1447d4 93bdf8b b1447d4 93bdf8b b1447d4 93bdf8b b1447d4 93bdf8b 0cfb077 094e0f4 0cfb077 094e0f4 275f4fa 93bdf8b b1447d4 93bdf8b b1447d4 93bdf8b 0cfb077 93bdf8b 52308e4 2cade03 52308e4 0cfb077 93bdf8b 52308e4 93bdf8b 0cfb077 094e0f4 0cfb077 275f4fa 0cfb077 094e0f4 cc61cf6 52308e4 ea96b05 4187049 ea96b05 0cfb077 93bdf8b 987e752 93bdf8b 4187049 52308e4 ea96b05 7f0cd9b ea96b05 93bdf8b 0cfb077 52308e4 0cfb077 469de1d 93bdf8b 0cfb077 52308e4 0cfb077 52308e4 0cfb077 cc61cf6 0cfb077 cc61cf6 0cfb077 52308e4 0cfb077 3b181ff 0cfb077 cc61cf6 0cfb077 7987601 cc61cf6 8675fd6 93bdf8b 8675fd6 93bdf8b cc61cf6 52308e4 0cfb077 094e0f4 b7bc475 52308e4 cc61cf6 469de1d b1447d4 0cfb077 469de1d b1447d4 d76f9b8 469de1d d76f9b8 52308e4 d76f9b8 93bdf8b 0cfb077 7987601 0cfb077 7987601 0cfb077 275f4fa 7987601 0cfb077 7987601 469de1d 7987601 094e0f4 7987601 275f4fa 7987601 469de1d 094e0f4 7987601 
275f4fa 7987601 469de1d 094e0f4 7987601 275f4fa 7987601 469de1d 094e0f4 7987601 275f4fa 7987601 469de1d 094e0f4 7987601 275f4fa 0cfb077 7987601 b1447d4 0cfb077 7987601 fc44b9f 0cfb077 ea96b05 0cfb077 275f4fa 7987601 275f4fa 52308e4 275f4fa 0cfb077 275f4fa 7987601 fc44b9f 7987601 fc44b9f 7987601 fc44b9f 7987601 fc44b9f 7987601 fc44b9f 7987601 275f4fa ea96b05 275f4fa 52308e4 0cfb077 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 
431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 | from datetime import datetime
import threading
import time
from bson import ObjectId
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, TypedDict, Union
import os, json
from pydantic import BaseModel
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.message import StateGraph
from langgraph.graph.state import START, END
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from common import get_db
from config import SheamiConfig
import logging
from modules.models import (
HealthReport,
SheamiMilestone,
SheamiState,
StandardizedReport,
TestResultReferenceRange,
)
from pdf_reader import pdf_bytes_to_text_ocr, pdf_to_text_ocr
from pdf_helper import generate_pdf
# basicConfig() installs a default stderr handler on the root logger if none
# is configured yet; this module's own logger emits INFO and above.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Load variables from a .env file; override=True makes .env values replace
# variables already present in the process environment.
load_dotenv(override=True)
# Shared chat model used by every LLM call in this module. The model name is
# read from the MODEL environment variable (presumably provided via .env — verify).
llm = ChatOpenAI(model=os.getenv("MODEL"), temperature=0.3)
# -----------------------------
# HELPERS (data schemas are imported from modules.models)
# -----------------------------
from typing import Optional, List
from pydantic import BaseModel, Field
import re
# Anything outside letters, digits, underscore and dash is unsafe in a filename.
_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9_\-]")
# Runs of underscores produced by the substitution are squeezed to one.
_UNDERSCORE_RUNS = re.compile(r"_+")


def safe_filename(name: str) -> str:
    """Turn an arbitrary label into a filesystem-safe token.

    Spaces and every character other than ASCII letters, digits, ``_`` and
    ``-`` become ``_``. Consecutive underscores are collapsed, and leading or
    trailing underscores are removed. The result may be an empty string.
    """
    cleaned = _UNSAFE_FILENAME_CHARS.sub("_", name.replace(" ", "_"))
    return _UNDERSCORE_RUNS.sub("_", cleaned).strip("_")
import dateutil.parser
def parse_any_date(date_str):
    """Best-effort conversion of a loosely formatted date into a datetime.

    Empty or missing input (``None``, ``""``, NaN/NaT) yields ``pd.NaT``.
    Otherwise the value is stringified and handed to ``dateutil`` in fuzzy,
    month-first mode. Any parsing failure also yields ``pd.NaT`` instead of
    raising.
    """
    is_missing = not date_str or pd.isna(date_str)
    if is_missing:
        return pd.NaT
    try:
        parsed = dateutil.parser.parse(str(date_str), dayfirst=False, fuzzy=True)
    except Exception:
        parsed = pd.NaT
    return parsed
# Prompt that maps raw lab test names to standardized, title-cased names.
# The model is asked for a bare JSON object: {original_name: standardized_name}.
testname_standardizer_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            # Each fragment ends with a space: adjacent string literals are joined
            # verbatim, so without it the sentences ran together.
            "You are a medical assistant. Normalize lab test names. "
            "All outputs must use **title case** (e.g., 'Hemoglobin', 'Blood Glucose'). "
            "Return ONLY valid JSON where keys are original names and values are standardized names. DO NOT return markdown formatting like backquotes etc.",
        ),
        (
            "human",
            """Normalize the following lab test names to their standard medical equivalents.
Test names: {test_names}
""",
        ),
    ]
)
# chain = prompt → LLM; the caller reads the reply's ``.content`` and parses it as JSON
testname_standardizer_chain = testname_standardizer_prompt | llm
# -----------------------------
# GRAPH NODES
# -----------------------------
def send_message(state: SheamiState, msg: str, append: bool = True):
    """Push a progress message onto ``state["messages"]`` (mutated in place).

    Args:
        state: Graph state holding the ``"messages"`` list.
        msg: Message text; may contain HTML spans for the UI.
        append: When True, add ``msg`` as a new entry. When False, overwrite the
            most recent entry so a status line can be updated in place. If the
            list is still empty, ``msg`` is appended instead of raising
            IndexError.
    """
    messages = state["messages"]
    if append or not messages:
        messages.append(msg)
    else:
        messages[-1] = msg
async def fn_init_node(state: SheamiState):
    """Entry node: prepare the run's output folder, reset progress state, open a DB run.

    Creates the thread's output directory and ensures ``state["messages"]``
    exists. It lists the uploaded files as UI messages and resets every
    per-run counter and container. It then records a new run via
    ``get_db().start_run`` and stores the returned id in ``state["run_id"]``.
    """
    os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
    if "messages" not in state:
        state["messages"] = []

    send_message(state=state, msg="Initializing ...")
    send_message(state=state, msg="Files received for processing ...", append=False)
    uploads = state["uploaded_reports"]
    for position, upload in enumerate(uploads, start=1):
        send_message(
            state=state,
            msg=f"{position}. <span class='highlighted-text'>{upload.report_file_name}</span>",
        )

    fresh_values = {
        "standardized_reports": [],
        "trends_json": {},
        "pdf_path": "",
        "current_index": -1,
        "units_processed": 0,
        "units_total": 0,
        "process_desc": "",
        "overall_units_processed": 0,
        "overall_units_total": 6,  # 6 steps totally
        "milestones": [],
    }
    for key, value in fresh_values.items():
        state[key] = value

    file_names = [upload.report_file_name for upload in uploads]
    file_contents = [upload.report_contents for upload in uploads]
    run_id = await get_db().start_run(
        user_email=state["user_email"],
        patient_id=state["patient_id"],
        source_file_names=file_names,
        source_file_contents=file_contents,
    )
    state["run_id"] = run_id
    send_message(state=state, msg=f"Initialized run [<span class='highlighted-text'>{run_id}</span>]")
    return state
async def reset_process_desc(state: SheamiState, process_desc: str):
    """Close the open milestone (if any) and start a new one named ``process_desc``.

    The previous milestone is marked completed, both in memory and in the DB.
    A new ``SheamiMilestone`` with status "started" is appended. The per-step
    unit counters are reset to 0, and the new milestone is registered in the
    DB under ``state["run_id"]``.
    """
    milestones = state["milestones"]
    if milestones:
        previous = milestones[-1]
        previous.status = "completed"
        previous.end_time = datetime.now()
        await get_db().add_or_update_milestone(
            run_id=state["run_id"],
            milestone=previous.step_name,
            status="completed",
            end=True,
        )

    state["process_desc"] = process_desc
    milestones.append(
        SheamiMilestone(
            step_name=process_desc, status="started", start_time=datetime.now()
        )
    )
    state["units_processed"] = 0
    state["units_total"] = 0
    await get_db().add_or_update_milestone(
        run_id=state["run_id"], milestone=process_desc
    )
    return state
async def fn_increment_index_node(state: SheamiState):
    """Advance ``state["current_index"]`` to the next uploaded report.

    While the new index still points at an uploaded report, ``process_desc``
    is set to a progress string naming that report. When the index moves past
    the last report (the loop-exit case), ``process_desc`` is left unchanged.
    """
    state["current_index"] += 1
    idx = state["current_index"]
    reports = state["uploaded_reports"]
    total_reports = len(reports)
    # Explicit bounds check instead of a bare ``except: pass``, which also hid
    # unrelated errors such as a report without ``report_file_name``.
    if 0 <= idx < total_reports:
        report_file_name = reports[idx].report_file_name
        state["process_desc"] = (
            f"Standardizing {idx + 1} of {total_reports} reports - {report_file_name} ..."
        )
    return state
async def call_llm(report: HealthReport, ocr: bool):
    """Parse one report's raw text into a ``StandardizedReport`` using the shared LLM.

    Args:
        report: The report whose ``report_file_name`` and ``report_contents``
            are sent to the model.
        ocr: When True, extra guidance about OCR-produced input is inserted
            into the system prompt.

    Returns:
        The structured output of ``llm.with_structured_output(StandardizedReport)``.
    """
    ocr_instructions = """
The input is pre-parsed structured text from an OCR engine (output.STRING).
- Each line corresponds to one recognized piece of text.
- Do NOT merge unrelated lines together.
- Use each line to reconstruct tests faithfully without skipping.
- Do not hallucinate results or ranges; only use what is explicitly present.
"""
    extra_rules = ocr_instructions if ocr else ""
    system_prompt = f"""
You are a medical report parser.
Your job is to convert the raw lab report text into the given schema.
Important:
- Do not omit any test mentioned in the report.
- Every test name in the input must appear in the output schema exactly once.
- If a test panel has multiple sub-tests, ensure ALL are included.
- If unsure about a value, still include the test with result = null.
{extra_rules}
- If the report contains a test panel (e.g., 'CUE - COMPLETE URINE ANALYSIS'),
break it down into its component sub-tests (e.g., pH, Specific Gravity, Protein, Glucose, Ketones, etc).
- Each sub-test must appear as an individual entry in the schema, with its own name, result, unit, and reference range.
- Do not summarize a panel as just 'positive/negative'. Capture all sub-results explicitly.
- Preserve the hierarchy but ensure sub-tests are separate objects.
"""
    user_prompt = f"""Original report file name: {report.report_file_name}
--- BEGIN REPORT ---
{report.report_contents}
--- END REPORT ---"""

    structured_llm = llm.with_structured_output(StandardizedReport)
    conversation = [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
    return await structured_llm.ainvoke(conversation)
async def fn_standardize_current_report_node(state: SheamiState):
    """Parse the report at ``state["current_index"]`` into a ``StandardizedReport``.

    The report's existing text is tried first. If the LLM extracts no lab
    results, the PDF is re-read with OCR and the OCR text is written back to
    the run record in the DB. The LLM is then asked again with OCR-specific
    instructions. The result is appended to ``state["standardized_reports"]``
    and also written to ``report_<idx>.json`` in the run's output directory.
    """
    import asyncio  # local import: only needed to offload the blocking OCR call

    idx = state["current_index"]
    report = state["uploaded_reports"][idx]
    logger.info(
        "%s| Standardizing report %s", state["thread_id"], report.report_file_name
    )
    send_message(
        state=state,
        msg=f"Standardizing report: {report.report_file_name}",
        append=False,
    )

    result = await call_llm(report=report, ocr=False)
    if not result.lab_results:
        send_message(
            state=state,
            msg=f"⛔ Could not extract any data from PDF : {report.report_file_name}. Trying OCR ... might take a while",
            append=False,
        )
        # pdf_to_text_ocr is synchronous; running it in a worker thread keeps the
        # event loop responsive (other awaiting tasks keep running) during OCR.
        report.report_contents = await asyncio.to_thread(
            pdf_to_text_ocr, pdf_path=report.report_file_name_with_path
        )
        # Keep the run record's copy of this source file in sync with the OCR text.
        run_stats_details = await get_db().get_run_stats_by_id(id=state["run_id"])
        run_stats_details["source_file_contents"][idx] = (
            report.report_contents.replace("\\n", "\n")
        )
        await get_db().update_run_stats(
            run_id=state["run_id"],
            source_file_contents=run_stats_details["source_file_contents"],
        )
        result = await call_llm(report=report, ocr=True)
        if result.lab_results:
            outcome = f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results using OCR for report : <span class='highlighted-text'>{report.report_file_name}</span>."
        else:
            outcome = f"⛔ OCR couldn't extract : {report.report_file_name}."
        send_message(state=state, msg=outcome, append=False)
    else:
        send_message(
            state=state,
            msg=f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results from : <span class='highlighted-text'>{report.report_file_name}</span>.",
            append=False,
        )

    state["standardized_reports"].append(result)
    output_path = os.path.join(
        SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"
    )
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result.model_dump_json(indent=2))
    state["units_processed"] = idx + 1
    return state
# edge
def fn_is_report_available_to_process(state: SheamiState) -> str:
    """Routing function for the per-report loop.

    Returns "continue" while ``current_index`` points at an uploaded report.
    It first announces that report, replacing the last message for the first
    report and appending for later ones. Once every report has been handled,
    it appends a completion message and returns "done".
    """
    position = state["current_index"]
    reports = state["uploaded_reports"]
    if position >= len(reports):
        send_message(state=state, msg="Standardizing reports: finished")
        return "done"

    current = reports[position]
    send_message(
        state=state,
        msg=f"⏳ Initiating report standardization for: <span class='highlighted-text'>{current.report_file_name}</span>",
        append=position > 0,
    )
    return "continue"
def get_unique_test_names(state: SheamiState):
    """Return the distinct test names found across all standardized reports.

    An entry that has a ``test_name`` attribute contributes that name. Only
    entries without one are inspected for ``sub_results``, whose items
    contribute their own ``test_name``. The result is an unordered list.
    """
    seen = set()
    for std_report in state["standardized_reports"]:
        for entry in std_report.lab_results:
            if hasattr(entry, "test_name"):  # Normal LabResult
                seen.add(entry.test_name)
                continue
            if not hasattr(entry, "sub_results"):
                continue
            for sub in entry.sub_results:  # CompositeLabResult
                if hasattr(sub, "test_name"):
                    seen.add(sub.test_name)
    return list(seen)
def _parse_normalization_map(raw_text, unique_names) -> Dict[str, str]:
    """Parse the LLM reply into ``{original_name: standardized_name}``.

    Tolerates a reply wrapped in markdown code fences. If the reply is not a
    JSON object, an identity map over ``unique_names`` is returned instead.
    Entries whose standardized value is not a non-empty string are dropped, so
    callers fall back to the original name for them.
    """
    identity = {name: name for name in unique_names}
    if not isinstance(raw_text, str):
        logger.warning("Normalization reply is not text; keeping original test names")
        return identity
    text = raw_text.strip()
    # The prompt forbids markdown, but models still sometimes wrap JSON in ``` fences.
    fenced = re.fullmatch(
        r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL | re.IGNORECASE
    )
    if fenced:
        text = fenced.group(1)
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as e:
        logger.warning("Exception in normalization: %s", e)
        return identity
    if not isinstance(parsed, dict):
        logger.warning("Normalization reply is not a JSON object; keeping original test names")
        return identity
    return {
        key: value
        for key, value in parsed.items()
        if isinstance(value, str) and value.strip()
    }


async def fn_testname_standardizer_node(state: SheamiState):
    """Rename lab tests to standardized names suggested by the LLM.

    Collects the unique test names and asks ``testname_standardizer_chain``
    for a name mapping. It then rewrites ``test_name`` in place on every
    top-level result and every sub-result that has one. Names missing from
    the mapping, or any names when the reply can't be parsed, are kept as-is.
    """
    logger.info("%s| Standardizing Test Names: started", state["thread_id"])
    send_message(state=state, msg="Standardizing Test Names: started", append=False)

    unique_names = get_unique_test_names(state)
    if unique_names:
        response = await testname_standardizer_chain.ainvoke({"test_names": unique_names})
        normalization_map = _parse_normalization_map(response.content, unique_names)
    else:
        normalization_map = {}  # nothing to rename; skip the LLM round-trip

    # apply mapping back
    for report in state["standardized_reports"]:
        for comp_result in report.lab_results:
            # normalize composite-level name if present
            if getattr(comp_result, "test_name", None):
                comp_result.test_name = normalization_map.get(
                    comp_result.test_name, comp_result.test_name
                )
            # normalize sub_results
            for sub in getattr(comp_result, "sub_results", None) or []:
                if getattr(sub, "test_name", None):
                    sub.test_name = normalization_map.get(sub.test_name, sub.test_name)

    logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
    send_message(
        state=state,
        msg=f"Identified <span class='highlighted-text'>{len(unique_names)}</span> unique tests",
        append=False,
    )
    return state
# Canonical unit spellings, keyed by the lower-cased, space-free form of the unit.
_UNIT_MAP = {
    "g/dl": "g/dL",
    "gms/dl": "g/dL",
    "gm/dl": "g/dL",
    "gm%": "g/dL",
    "g/dl.": "g/dL",
}


def _normalize_unit(unit):
    """Return the canonical spelling of ``unit``, or ``unit`` unchanged if unknown or not a string."""
    if not isinstance(unit, str):
        return unit
    return _UNIT_MAP.get(unit.lower().replace(" ", ""), unit)


async def fn_unit_normalizer_node(state: SheamiState):
    """Normalize units for lab test values across all standardized reports.

    Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'. Both simple results and the
    sub-results of composite results are updated in place. Empty units and
    units not listed in ``_UNIT_MAP`` are left untouched.
    """
    logger.info("%s| Standardizing Units : started", state["thread_id"])
    send_message(state=state, msg="Standardizing Units: started", append=False)

    for report in state["standardized_reports"]:
        for lr in report.lab_results:
            # case 1: simple result
            if getattr(lr, "test_unit", None):
                lr.test_unit = _normalize_unit(lr.test_unit)
            # case 2: composite result with sub_results
            for sub in getattr(lr, "sub_results", None) or []:
                if getattr(sub, "test_unit", None):
                    sub.test_unit = _normalize_unit(sub.test_unit)

    logger.info("%s| Standardizing Units : finished", state["thread_id"])
    send_message(state=state, msg="Standardizing Units: finished", append=False)
    return state
async def fn_db_update_node(state: SheamiState):
    """Store the standardized reports and fold each stored report into the patient's trends.

    The value returned by ``add_report_v2`` is kept in ``state["report_id_list"]``.
    It is treated as a comma-separated string of report ids. Each non-blank id
    is passed to ``aggregate_trends_from_report``.
    """
    ## add parsed reports
    report_id_list = await get_db().add_report_v2(
        patient_id=state["patient_id"],
        reports=state["standardized_reports"],
        run_id=state["run_id"],
    )
    state["report_id_list"] = report_id_list
    logger.info("report_id_list = %s", report_id_list)
    # Skip blank ids so an empty result (or a stray comma) never triggers an
    # aggregation call for the id "".
    report_ids = [rid.strip() for rid in (report_id_list or "").split(",")]
    for report_id in report_ids:
        if report_id:
            await get_db().aggregate_trends_from_report(state["patient_id"], report_id)
    return state
async def fn_trends_aggregator_node(state: SheamiState):
    """Load the patient's trend series from the DB into state and save them as ``trends.json``.

    The trend series are built in the database, where ``fn_db_update_node``
    calls ``aggregate_trends_from_report`` for each stored report. This node
    reads them back with ``get_trends_by_patient`` into ``state["trends_json"]``
    and writes a copy to the run's output directory. It no longer re-aggregates
    ``standardized_reports`` into local buckets that were never read.
    """
    thread_id = state["thread_id"]
    logger.info("%s| Aggregating Trends : started", thread_id)
    send_message(state=state, msg="Aggregating Trends : started", append=False)

    # Per-report progress reporting (messages replace each other in place).
    total_reports = len(state["standardized_reports"])
    for idx, _report in enumerate(state["standardized_reports"]):
        logger.info("%s| Aggregating Trends for report-%d", thread_id, idx)
        send_message(
            state=state,
            msg=f"Aggregating {idx+1}/{total_reports} trends : report-{idx+1}...",
            append=False,
        )

    # Build trends JSON
    state["trends_json"] = await get_db().get_trends_by_patient(
        patient_id=state["patient_id"],
        fields=["test_name", "trend_data", "test_reference_range", "inferred_range"],
        serializable=True,
    )

    # Persist
    output_dir = SheamiConfig.get_output_dir(thread_id)
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "trends.json"), "w", encoding="utf-8") as f:
        json.dump(state["trends_json"], f, indent=1, ensure_ascii=False)

    logger.info("%s| Aggregating Trends : finished", thread_id)
    send_message(state=state, msg="Aggregating Trends : finished", append=False)
    return state
def _build_interpretation_prompt(report_date: str) -> str:
    """Return the system prompt requesting the HTML trend-interpretation report.

    Numbered items are newline-separated. Section 5 appears once, so the
    column names requested and the example ``<colgroup>`` agree.
    """
    header = (
        "Interpret the following medical trends and produce a clean, structured **HTML** report without any markdown formatting like backquotes etc.\n"
        "The report should have:\n"
        f"1. A header that says report generated on : {report_date}.\n"
        "2. The names of the reports used to summarize this information.\n"
        "3. Patient summary (patient id, name, age, sex if available)\n"
        "4. Test window (mention the from and to dates)\n"
    )
    trend_section = """5. Trend summaries
Generate HTML tables with the following columns:
- Test Name
- Most Recent Value, Previous Value, Older Value (use a hyphen "–" if a value is missing). Use these exact column names (do not call them latest value 1,2,or 3)
- Unit
- Reference Range
- Inference (latest value only): ✅ if within normal range, ▲ if above normal (high), ▼ if below normal (low)
- Trend Direction (across last 3 values): ⬆️ if values are rising, ⬇️ if values are falling, ➖ (or ✅) if stable/normal
Table formatting requirements:
- The HTML will be shown in a UI (`gr.HTML`) and also rendered to PDF via WeasyPrint.
- The table must ALWAYS fit within 100% of the container width. Do not allow horizontal scrolling, clipping, or overlapping columns.
- Use `table-layout: fixed;` and `<colgroup>` with percentage widths that sum to 100%.
- Allow text wrapping inside cells so narrow columns still display all content.
- Example CSS to embed at the top of the HTML:
<style>
table { width: 100%; border-collapse: collapse; table-layout: fixed; }
th, td {
font-size: 11px;
padding: 4px 6px;
white-space: normal;
word-break: break-word;
}
</style>
- Example `<colgroup>` (adjust if needed):
<colgroup>
<col style="width:20%"> <!-- Test Name -->
<col style="width:8%"> <!-- Most Recent Value -->
<col style="width:8%"> <!-- Previous Value -->
<col style="width:8%"> <!-- Older Value -->
<col style="width:8%"> <!-- Unit -->
<col style="width:16%"> <!-- Reference Range -->
<col style="width:16%"> <!-- Inference -->
<col style="width:16%"> <!-- Trend Direction -->
</colgroup>
"""
    footer = (
        "6. Clinical insights.\n"
        "\nImportant Rules:\n"
        "- Format tables in proper <table> with <tr>, <th>, <td>.\n"
        "- Do not include charts, they will be programmatically added.\n"
    )
    return header + trend_section + footer


def _to_float_or_none(value):
    """Return ``value`` as a float, or None if it cannot be converted (e.g. None, 'Nil')."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _plot_trend(param: dict, filepath: str) -> bool:
    """Plot one test's trend series to ``filepath`` as a PNG.

    ``param`` holds a ``trend_data`` list of ``{"date", "value", "unit"}``
    points and an optional ``test_reference_range`` with ``min``/``max``.
    Returns False without writing anything when any value is non-numeric
    (the whole series is skipped) or when no point has a parseable date.
    Returns True once the image is saved.
    """
    values = param["trend_data"]
    y = [_to_float_or_none(v["value"]) for v in values]
    if any(val is None for val in y):
        return False  # a single non-numeric value skips the whole series
    x = pd.to_datetime([parse_any_date(v["date"]) for v in values], errors="coerce")
    df_plot = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x"]).sort_values("x")
    if df_plot.empty:
        return False

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.plot(
            df_plot["x"].to_numpy(),
            df_plot["y"].to_numpy(),
            marker="o",
            linestyle="-",
            label="Observed values",
        )
        ref = param.get("test_reference_range")
        if ref:
            ymin = _to_float_or_none(ref.get("min"))
            ymax = _to_float_or_none(ref.get("max"))
            if ymin is not None and ymax is not None:
                ax.axhspan(ymin, ymax, color="green", alpha=0.2, label="Reference range")
            elif ymax is not None:
                ax.axhline(y=ymax, color="red", linestyle="--", label="Upper threshold")
            elif ymin is not None:
                ax.axhline(y=ymin, color="blue", linestyle="--", label="Lower threshold")
        ax.set_xlabel("Date")
        ax.set_ylabel(values[0].get("unit", "") if values else "")
        ax.grid(True)
        ax.tick_params(axis="x", labelrotation=45)
        ax.legend()
        fig.tight_layout()
        fig.savefig(filepath)
    finally:
        # Always release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(fig)
    return True


async def fn_interpreter_node(state: SheamiState):
    """Produce the final interpreted report for the patient.

    Sends the patient's info, report file names and ``state["trends_json"]``
    to the LLM, which returns an HTML interpretation. Each numeric trend is
    plotted into ``<output_dir>/plots``. The HTML and the plots are rendered
    into ``final_report.pdf``. ``state["pdf_path"]`` and
    ``state["interpretation_html"]`` are set accordingly.
    """
    thread_id = state["thread_id"]
    logger.info("%s| Interpreting Trends : started", thread_id)
    send_message(state=state, msg="Interpreting Trends : started", append=False)

    uploaded_reports = await get_db().get_reports_by_patient(
        patient_id=state["patient_id"]
    )
    patient_info = await get_db().get_patient_by_id(
        patient_id=state["patient_id"],
        fields=["name", "dob", "gender"],
        serializable=True,
    )
    llm_input = json.dumps(
        {
            "patient_id": state["patient_id"],
            "patient_info": patient_info,
            "uploaded_reports": [report["file_name"] for report in uploaded_reports],
            "trends_json": state["trends_json"],
        },
        indent=1,
    )
    report_date = datetime.now().strftime("%d %B %Y")  # e.g., "22 August 2025"

    # 1. LLM narrative (expected to be HTML, per the prompt)
    messages = [
        SystemMessage(content=_build_interpretation_prompt(report_date)),
        HumanMessage(content=llm_input),
    ]
    response = await llm.ainvoke(messages)
    interpretation_html = response.content

    # 2. Generate plots for each parameter
    output_dir = SheamiConfig.get_output_dir(thread_id)
    plots_dir = os.path.join(output_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    plot_files = []
    used_stems: set[str] = set()
    for param in sorted(state["trends_json"], key=lambda p: p["test_name"]):
        test_name = param["test_name"]
        # Distinct test names can sanitize to the same filename (e.g.
        # "Glucose (Fasting)" and "Glucose Fasting"); suffix duplicates so one
        # chart never overwrites another.
        base = safe_filename(test_name) or "test"
        stem, n = base, 1
        while stem in used_stems:
            n += 1
            stem = f"{base}_{n}"
        used_stems.add(stem)
        filepath = os.path.join(plots_dir, f"{stem}_trend.png")
        if _plot_trend(param, filepath):
            plot_files.append((test_name, filepath))

    # 3. Build PDF
    pdf_path = os.path.join(output_dir, "final_report.pdf")
    generate_pdf(
        pdf_path=pdf_path,
        interpretation_html=interpretation_html,
        plot_files=plot_files,
    )

    state["pdf_path"] = pdf_path
    state["interpretation_html"] = interpretation_html
    logger.info("%s| Interpreting Trends : finished", thread_id)
    send_message(state=state, msg="Interpreting Trends : finished", append=False)
    return state
async def fn_final_cleanup_node(state: SheamiState):
    """Finish the run.

    Closes the last milestone, marks the run completed, stores the final PDF
    and HTML summary via ``add_final_report_v2``, and schedules deletion of
    the run's output directory. Cleanup is scheduled in ``finally``, after
    the PDF has been read. It therefore can never race the read, and it is
    still scheduled if a DB step fails.
    """
    pdf_path = state["pdf_path"]
    output_dir = SheamiConfig.get_output_dir(state["thread_id"])
    try:
        last_milestone = state["milestones"][-1]
        last_milestone.status = "completed"
        last_milestone.end_time = datetime.now()
        await get_db().add_or_update_milestone(
            run_id=state["run_id"],
            milestone=state["process_desc"],
            status="completed",
            end=True,
        )
        await get_db().update_run_stats(run_id=state["run_id"], status="completed")

        # Save PDF along with metadata as the final report
        with open(pdf_path, "rb") as f:
            pdf_bytes = f.read()
        final_report_id = await get_db().add_final_report_v2(
            patient_id=state["patient_id"],
            summary=state["interpretation_html"],
            pdf_bytes=pdf_bytes,
            file_name=f"health_trends_report_{state['patient_id']}.pdf",
        )
        logger.info("final_report_id = %s", final_report_id)
    finally:
        schedule_cleanup(file_path=output_dir)
def schedule_cleanup(file_path, delay=300):  # 300 sec = 5 min
    """Delete ``file_path`` (a file or a whole directory tree) after ``delay`` seconds.

    The deletion runs on a daemon ``threading.Timer``, so it never blocks the
    caller or process exit; a cleanup still pending at exit is dropped. A
    path that no longer exists is ignored, and OS errors are logged rather
    than raised.

    Returns:
        The started timer; callers may ``join()`` or ``cancel()`` it.
    """
    import shutil

    # Same logger object as the module-level one (same name).
    cleanup_logger = logging.getLogger(__name__)

    def cleanup():
        if not os.path.exists(file_path):
            return
        try:
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
            cleanup_logger.info("Cleaned up: %s", file_path)
        except OSError as e:
            cleanup_logger.warning("Cleanup failed for %s: %s", file_path, e)

    timer = threading.Timer(delay, cleanup)
    timer.daemon = True
    timer.start()
    return timer
# -----------------------------
# GRAPH CREATION
# -----------------------------
async def _start_step(state: SheamiState, process_desc: str, msg: str):
    """Begin a pipeline step: open its milestone, post ``msg``, advance the step counter.

    ``reset_process_desc`` closes the previous milestone and opens one named
    ``process_desc``. ``msg`` is then appended to the messages, and
    ``overall_units_processed`` is incremented by one.
    """
    state = await reset_process_desc(state, process_desc=process_desc)
    send_message(state=state, msg=msg)
    state["overall_units_processed"] += 1
    return state


async def fn_standardizer_node_notifier(state: SheamiState):
    """Start the report-standardization step; one unit per uploaded report."""
    state = await _start_step(
        state,
        process_desc="Standardizing reports ...",
        msg="Standardizing reports now ... this might take a while ...",
    )
    state["units_total"] = len(state["uploaded_reports"])
    return state


async def fn_testname_standardizer_node_notifier(state: SheamiState):
    """Start the test-name standardization step."""
    return await _start_step(
        state,
        process_desc="Standardizing test names ...",
        msg="Standardizing test names now ...",
    )


async def fn_unit_normalizer_node_notifier(state: SheamiState):
    """Start the unit-normalization step."""
    return await _start_step(
        state,
        process_desc="Standardizing units ...",
        msg="Standardizing measurement units now ...",
    )


async def fn_trends_aggregator_node_notifier(state: SheamiState):
    """Start the trend-aggregation step."""
    return await _start_step(
        state,
        process_desc="Aggregating trends ...",
        msg="Aggregating trends now ...",
    )


async def fn_interpreter_node_notifier(state: SheamiState):
    """Start the interpretation/plotting step."""
    return await _start_step(
        state,
        process_desc="Plotting trends ...",
        msg="Interpreting and plotting trends now ...",
    )
def create_graph(user_email: str, patient_id: str, thread_id: str):
    """Assemble and compile the report-processing workflow graph.

    Flow:
    - init → standardizer_notifier → increment_index.
    - A loop (standardize_current_report → increment_index), which
      ``fn_is_report_available_to_process`` exits with "done".
    - A linear chain of notifier/worker nodes ending in
      final_cleanup_node → END.

    ``user_email`` and ``patient_id`` are only logged here. The graph is
    compiled with an ``InMemorySaver`` checkpointer.
    """
    logger.info(
        "%s| Creating Graph : started for user:%s | patient:%s",
        thread_id,
        user_email,
        patient_id,
    )
    checkpointer = InMemorySaver()
    graph = StateGraph(SheamiState)

    node_table = [
        ("init", fn_init_node),
        ("standardize_current_report", fn_standardize_current_report_node),
        ("increment_index", fn_increment_index_node),
        ("testname_standardizer", fn_testname_standardizer_node),
        ("unit_normalizer", fn_unit_normalizer_node),
        ("db_update_node", fn_db_update_node),
        ("trends", fn_trends_aggregator_node),
        ("interpreter", fn_interpreter_node),
        ("standardizer_notifier", fn_standardizer_node_notifier),
        ("testname_standardizer_notifier", fn_testname_standardizer_node_notifier),
        ("unit_normalizer_notifier", fn_unit_normalizer_node_notifier),
        ("trends_notifier", fn_trends_aggregator_node_notifier),
        ("interpreter_notifier", fn_interpreter_node_notifier),
        ("final_cleanup_node", fn_final_cleanup_node),
    ]
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # Entry: initialise, announce the step, then point at the first report.
    graph.add_edge(START, "init")
    graph.add_edge("init", "standardizer_notifier")
    graph.add_edge("standardizer_notifier", "increment_index")

    # Per-report loop: keep standardizing while reports remain.
    graph.add_conditional_edges(
        "increment_index",
        fn_is_report_available_to_process,
        {
            "continue": "standardize_current_report",
            "done": "testname_standardizer_notifier",
        },
    )
    graph.add_edge("standardize_current_report", "increment_index")

    # Post-processing pipeline, wired pairwise in order.
    pipeline = [
        "testname_standardizer_notifier",
        "testname_standardizer",
        "unit_normalizer_notifier",
        "unit_normalizer",
        "db_update_node",
        "trends_notifier",
        "trends",
        "interpreter_notifier",
        "interpreter",
        "final_cleanup_node",
        END,
    ]
    for source, target in zip(pipeline, pipeline[1:]):
        graph.add_edge(source, target)

    logger.info("%s| Creating Graph : finished", thread_id)
    return graph.compile(checkpointer=checkpointer)
|