"""Sheami report-processing pipeline.

LangGraph nodes that ingest uploaded lab reports, standardize them via an
LLM (with OCR fallback), normalize test names and units, persist results,
and aggregate per-test trends for downstream interpretation.
"""

import json
import logging
import os
import re
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Literal, Optional, TypedDict, Union

import dateutil.parser
import matplotlib.pyplot as plt
import pandas as pd
from bson import ObjectId
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.message import StateGraph
from langgraph.graph.state import START, END
from pydantic import BaseModel, Field

from common import get_db
from config import SheamiConfig
from modules.models import (
    HealthReport,
    SheamiMilestone,
    SheamiState,
    StandardizedReport,
    TestResultReferenceRange,
)
from pdf_helper import generate_pdf
from pdf_reader import pdf_bytes_to_text_ocr, pdf_to_text_ocr

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

load_dotenv(override=True)

# Shared LLM instance; model name comes from the environment (.env).
llm = ChatOpenAI(model=os.getenv("MODEL"), temperature=0.3)

# -----------------------------
# SCHEMA DEFINITIONS
# -----------------------------


def safe_filename(name: str) -> str:
    """Sanitize *name* for filesystem use.

    Spaces become underscores, any character outside [A-Za-z0-9_-] becomes an
    underscore, runs of underscores are collapsed, and leading/trailing
    underscores are stripped.
    """
    name = name.replace(" ", "_")
    name = re.sub(r"[^A-Za-z0-9_\-]", "_", name)
    name = re.sub(r"_+", "_", name)
    return name.strip("_")


def parse_any_date(date_str):
    """Best-effort parse of an arbitrary date string.

    Returns a datetime on success, ``pd.NaT`` for empty/NaN input or any
    parse failure (reports carry wildly inconsistent date formats).
    """
    if not date_str or pd.isna(date_str):
        return pd.NaT
    try:
        return dateutil.parser.parse(str(date_str), dayfirst=False, fuzzy=True)
    except Exception:
        return pd.NaT


# prompt template
testname_standardizer_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a medical assistant. Normalize lab test names."
            "All outputs must use **title case** (e.g., 'Hemoglobin', 'Blood Glucose')."
            "Return ONLY valid JSON where keys are original names and values are standardized names. DO NOT return markdown formatting like backquotes etc.",
        ),
        (
            "human",
            """Normalize the following lab test names to their standard medical equivalents.

Test names:
{test_names}
""",
        ),
    ]
)

# chain = prompt → LLM → string
testname_standardizer_chain = testname_standardizer_prompt | llm

# -----------------------------
# GRAPH NODES
# -----------------------------


def send_message(state: SheamiState, msg: str, append: bool = True):
    """Append *msg* to the state's message log, or replace the last entry.

    ``append=False`` is used by progress updates that overwrite themselves.
    Falls back to appending when the log is empty so the first replace-style
    call cannot IndexError.
    """
    if append or not state["messages"]:
        # append message
        state["messages"].append(msg)
    else:
        # replace last message
        state["messages"][-1] = msg


async def fn_init_node(state: SheamiState):
    """Initialize a processing run: output dir, state fields, and DB run row."""
    os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
    if "messages" not in state:
        state["messages"] = []
    send_message(state=state, msg="Initializing ...")
    send_message(state=state, msg="Files received for processing ...", append=False)
    for idx, report in enumerate(state["uploaded_reports"]):
        send_message(
            state=state,
            msg=f"{idx+1}. {report.report_file_name}",
        )
    state["standardized_reports"] = []
    state["trends_json"] = {}
    state["pdf_path"] = ""
    state["current_index"] = -1
    state["units_processed"] = 0
    state["units_total"] = 0
    state["process_desc"] = ""
    state["overall_units_processed"] = 0
    state["overall_units_total"] = 6  # 6 steps totally
    state["milestones"] = []
    run_id = await get_db().start_run(
        user_email=state["user_email"],
        patient_id=state["patient_id"],
        source_file_names=[
            report.report_file_name for report in state["uploaded_reports"]
        ],
        source_file_contents=[
            report.report_contents for report in state["uploaded_reports"]
        ],
    )
    state["run_id"] = run_id
    send_message(state=state, msg=f"Initialized run [{run_id}]")
    return state


async def reset_process_desc(state: SheamiState, process_desc: str):
    """Close the previous milestone (if any) and open a new one named
    *process_desc*, resetting the per-step progress counters."""
    # close previous milestone
    if len(state["milestones"]) > 0:
        state["milestones"][-1].status = "completed"
        state["milestones"][-1].end_time = datetime.now()
        await get_db().add_or_update_milestone(
            run_id=state["run_id"],
            milestone=state["milestones"][-1].step_name,
            status="completed",
            end=True,
        )
    state["process_desc"] = process_desc
    state["milestones"].append(
        SheamiMilestone(
            step_name=state["process_desc"], status="started", start_time=datetime.now()
        )
    )
    state["units_processed"] = 0
    state["units_total"] = 0
    await get_db().add_or_update_milestone(
        run_id=state["run_id"], milestone=state["process_desc"]
    )
    return state


async def fn_increment_index_node(state: SheamiState):
    """Advance to the next uploaded report and refresh the progress text."""
    state["current_index"] += 1
    total_reports = len(state["uploaded_reports"])
    try:
        report_file_name = state["uploaded_reports"][
            state["current_index"]
        ].report_file_name
        # Hoist the expression out of the f-string: nesting double quotes
        # inside a double-quoted f-string is a SyntaxError before Python 3.12.
        current = state["current_index"] + 1
        state["process_desc"] = (
            f"Standardizing {current} of {total_reports} reports - {report_file_name} ..."
        )
    except IndexError:
        # Index ran past the last report; the conditional edge routes to "done".
        pass
    return state


async def call_llm(report: HealthReport, ocr: bool):
    """Parse *report* into a StandardizedReport via structured LLM output.

    When *ocr* is True, extra instructions tell the model the text came from
    an OCR engine (one recognized fragment per line).
    """
    llm_structured = llm.with_structured_output(StandardizedReport)
    ocr_instructions = """
The input is pre-parsed structured text from an OCR engine (output.STRING).
- Each line corresponds to one recognized piece of text.
- Do NOT merge unrelated lines together.
- Use each line to reconstruct tests faithfully without skipping.
- Do not hallucinate results or ranges; only use what is explicitly present.
"""
    system_msg = f"""
You are a medical report parser. Your job is to convert the raw lab report text into the given schema.

Important:
- Do not omit any test mentioned in the report.
- Every test name in the input must appear in the output schema exactly once.
- If a test panel has multiple sub-tests, ensure ALL are included.
- If unsure about a value, still include the test with result = null.
{ocr_instructions if ocr else ""}
- If the report contains a test panel (e.g., 'CUE - COMPLETE URINE ANALYSIS'), break it down into its component sub-tests (e.g., pH, Specific Gravity, Protein, Glucose, Ketones, etc).
- Each sub-test must appear as an individual entry in the schema, with its own name, result, unit, and reference range.
- Do not summarize a panel as just 'positive/negative'. Capture all sub-results explicitly.
- Preserve the hierarchy but ensure sub-tests are separate objects.
"""
    messages = [
        SystemMessage(content=system_msg),
        HumanMessage(
            content=f"""Original report file name: {report.report_file_name}
--- BEGIN REPORT ---
{report.report_contents}
--- END REPORT ---"""
        ),
    ]
    result: StandardizedReport = await llm_structured.ainvoke(messages)
    return result


async def fn_standardize_current_report_node(state: SheamiState):
    """Standardize the report at ``current_index``; falls back to OCR when the
    first LLM pass extracts no lab results. Writes the result to disk and
    appends it to ``standardized_reports``."""
    idx = state["current_index"]
    report = state["uploaded_reports"][idx]
    logger.info(
        "%s| Standardizing report %s", state["thread_id"], report.report_file_name
    )
    send_message(
        state=state,
        msg=f"Standardizing report: {report.report_file_name}",
        append=False,
    )
    result = await call_llm(report=report, ocr=False)
    if not result.lab_results:
        send_message(
            state=state,
            msg=f"⛔ Could not extract any data from PDF : {report.report_file_name}. Trying OCR ... might take a while",
            append=False,
        )
        report.report_contents = pdf_to_text_ocr(
            pdf_path=report.report_file_name_with_path
        )
        # logger.info("Parsed text using OCR: %s", report.report_contents)
        run_stats_details = await get_db().get_run_stats_by_id(id=state["run_id"])
        # Store the OCR text back so the run record reflects what was parsed.
        run_stats_details["source_file_contents"][state["current_index"]] = (
            report.report_contents.replace("\\n", "\n")
        )
        await get_db().update_run_stats(
            run_id=state["run_id"],
            source_file_contents=run_stats_details["source_file_contents"],
        )
        result = await call_llm(report=report, ocr=True)
        if not result.lab_results:
            send_message(
                state=state,
                msg=f"⛔ OCR couldn't extract : {report.report_file_name}.",
                append=False,
            )
        else:
            send_message(
                state=state,
                msg=f"✅ Extracted {len(result.lab_results)} lab results using OCR for report : {report.report_file_name}.",
                append=False,
            )
    else:
        send_message(
            state=state,
            msg=f"✅ Extracted {len(result.lab_results)} lab results from : {report.report_file_name}.",
            append=False,
        )
    state["standardized_reports"].append(result)
    with open(
        os.path.join(
            SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"
        ),
        "w",
        encoding="utf-8",
    ) as f:
        f.write(result.model_dump_json(indent=2))
    state["units_processed"] = idx + 1
    return state


# edge
def fn_is_report_available_to_process(state: SheamiState) -> str:
    """Conditional edge: 'continue' while reports remain, else 'done'."""
    if state["current_index"] < len(state["uploaded_reports"]):
        report = state["uploaded_reports"][state["current_index"]]
        send_message(
            state=state,
            msg=f"⏳ Initiating report standardization for: {report.report_file_name}",
            append=state["current_index"] > 0,
        )
        return "continue"
    else:
        send_message(state=state, msg="Standardizing reports: finished")
        return "done"


def get_unique_test_names(state: SheamiState):
    """Collect the set of test names across all standardized reports,
    descending into composite results' sub_results."""
    test_names = set()
    for report in state["standardized_reports"]:
        for result in report.lab_results:
            if hasattr(result, "test_name"):
                # Normal LabResult
                test_names.add(result.test_name)
            elif hasattr(result, "sub_results"):
                # CompositeLabResult
                for sub in result.sub_results:
                    if hasattr(sub, "test_name"):
                        test_names.add(sub.test_name)
    return list(test_names)


async def fn_testname_standardizer_node(state: SheamiState):
    """Normalize all test names via the LLM chain, mapping originals →
    standardized names and applying the map in place."""
    logger.info("%s| Standardizing Test Names: started", state["thread_id"])
    send_message(state=state, msg="Standardizing Test Names: started", append=False)
    # collect unique names
    unique_names = get_unique_test_names(state)
    # run through LLM
    response = await testname_standardizer_chain.ainvoke({"test_names": unique_names})
    raw_text = response.content
    try:
        normalization_map: Dict[str, str] = json.loads(raw_text)
    except Exception as e:
        # Log (not print) and fall back to the identity map so a malformed
        # LLM response never aborts the run.
        logger.warning("Exception in normalization: %s", e)
        normalization_map = {name: name for name in unique_names}  # fallback
    # apply mapping back
    for report in state["standardized_reports"]:
        for comp_result in report.lab_results:
            # normalize composite-level name if present
            if getattr(comp_result, "test_name", None):
                comp_result.test_name = normalization_map.get(
                    comp_result.test_name, comp_result.test_name
                )
            # normalize sub_results
            if getattr(comp_result, "sub_results", None):
                for sub in comp_result.sub_results:
                    if getattr(sub, "test_name", None):
                        sub.test_name = normalization_map.get(
                            sub.test_name, sub.test_name
                        )
    logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
    send_message(
        state=state,
        msg=f"Identified {len(unique_names)} unique tests",
        append=False,
    )
    # send_message(state=state, msg="Standardizing Test Names: finished")
    return state


async def fn_unit_normalizer_node(state: SheamiState):
    """Normalize units for lab test values across all standardized reports.

    Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'
    """
    logger.info("%s| Standardizing Units : started", state["thread_id"])
    send_message(state=state, msg="Standardizing Units: started", append=False)
    # Keys are lowercased/space-stripped spellings; values are canonical units.
    unit_map = {
        "g/dl": "g/dL",
        "gms/dl": "g/dL",
        "gm%": "g/dL",
        "g/dl.": "g/dL",
    }
    for report in state["standardized_reports"]:
        for lr in report.lab_results:
            # case 1: simple result
            if hasattr(lr, "test_unit") and lr.test_unit:
                normalized = lr.test_unit.lower().replace(" ", "")
                lr.test_unit = unit_map.get(normalized, lr.test_unit)
            # case 2: composite result with sub_results
            if hasattr(lr, "sub_results") and lr.sub_results:
                for sub in lr.sub_results:
                    if sub.test_unit:
                        normalized = sub.test_unit.lower().replace(" ", "")
                        sub.test_unit = unit_map.get(normalized, sub.test_unit)
    logger.info("%s| Standardizing Units : finished", state["thread_id"])
    send_message(state=state, msg="Standardizing Units: finished", append=False)
    return state


async def fn_db_update_node(state: SheamiState):
    """Persist standardized reports to the DB and aggregate per-report trends."""
    ## add parsed reports
    report_id_list = await get_db().add_report_v2(
        patient_id=state["patient_id"],
        reports=state["standardized_reports"],
        run_id=state["run_id"],
    )
    state["report_id_list"] = report_id_list
    logger.info("report_id_list = %s", report_id_list)
    # report_id_list is a comma-separated string of inserted report ids.
    for report_id in report_id_list.split(","):
        await get_db().aggregate_trends_from_report(state["patient_id"], report_id)
    return state


async def fn_trends_aggregator_node(state: SheamiState):
    """Aggregate all lab results into per-test trend series, fetch the
    patient's trends JSON from the DB, and persist it to disk."""
    logger.info("%s| Aggregating Trends : started", state["thread_id"])
    send_message(state=state, msg="Aggregating Trends : started", append=False)

    # Aggregation buckets
    trends: dict[str, list[dict]] = {}
    ref_ranges: dict[str, dict] = {}

    def try_parse_numeric(value) -> float | None:
        """
        Return a float only for clean numeric strings like '75', '75.2', or '12%'.
        Avoids picking '0' out of '0-2 /hpf' etc.
        """
        if value is None:
            return None
        s = str(value).strip()
        # pure number
        if re.fullmatch(r"[-+]?\d+(?:\.\d+)?", s):
            try:
                return float(s)
            except ValueError:
                return None
        # percent like "12%"
        m = re.fullmatch(r"([-+]?\d+(?:\.\d+)?)\s*%", s)
        if m:
            try:
                return float(m.group(1))
            except ValueError:
                return None
        return None

    def add_point(
        key: str,
        date: str | None,
        value: str,
        unit: str | None,
        rr: TestResultReferenceRange | None,
        original_report_file_name: str,
    ):
        """Append one observation to the trend bucket for *key*; record the
        first reference range seen for that key."""
        num = try_parse_numeric(value)
        trends.setdefault(key, []).append(
            {
                "date": date or "unknown",
                "value": num if num is not None else value,
                "is_numeric": num is not None,
                "unit": unit or "",
                "orig_report": original_report_file_name,
            }
        )
        if rr and key not in ref_ranges:
            ref_ranges[key] = {"min": rr.min, "max": rr.max}

    total_reports = len(state["standardized_reports"])
    for idx, report in enumerate(state["standardized_reports"]):
        logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
        send_message(
            state=state,
            msg=f"Aggregating {idx+1}/{total_reports} trends : report-{idx+1}...",
            append=False,
        )
        for item in report.lab_results:
            # Case A: CompositeLabResult (e.g., CUE, LFT, etc.)
            if hasattr(item, "sub_results") and item.sub_results:
                panel = getattr(item, "section_name", "Panel")
                for sub in item.sub_results:
                    key = f"{panel} · {sub.test_name}"
                    add_point(
                        key=key,
                        date=sub.test_date,
                        value=sub.result_value,
                        unit=sub.test_unit,
                        rr=sub.test_reference_range,
                        original_report_file_name=report.original_report_file_name,
                    )
            # Case B: Simple LabResult
            else:
                key = item.test_name
                add_point(
                    key=key,
                    date=item.test_date,
                    value=item.result_value,
                    unit=item.test_unit,
                    rr=item.test_reference_range,
                    original_report_file_name=report.original_report_file_name,
                )

    # Build trends JSON
    state["trends_json"] = await get_db().get_trends_by_patient(
        patient_id=state["patient_id"],
        fields=["test_name", "trend_data", "test_reference_range", "inferred_range"],
        serializable=True,
    )

    # Persist
    output_dir = SheamiConfig.get_output_dir(state["thread_id"])
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "trends.json"), "w", encoding="utf-8") as f:
        json.dump(state["trends_json"], f, indent=1, ensure_ascii=False)

    logger.info("%s| Aggregating Trends : finished", state["thread_id"])
    send_message(state=state, msg="Aggregating Trends : finished", append=False)
    return state


async def fn_interpreter_node(state: SheamiState):
    logger.info("%s| Interpreting Trends : started", state["thread_id"])
    send_message(state=state, msg="Interpreting Trends : started", append=False)
    uploaded_reports = await get_db().get_reports_by_patient(
        patient_id=state["patient_id"]
    )
    llm_input = json.dumps(
        {
            "patient_id": state["patient_id"],
            "patient_info": await get_db().get_patient_by_id(
                patient_id=state["patient_id"],
                fields=["name", "dob", "gender"],
                serializable=True,
            ),
            "uploaded_reports": [report["file_name"] for report in uploaded_reports],
            "trends_json": state["trends_json"],
        },
        indent=1,
    )
    # logger.info("llm_input = %s", llm_input)
    report_date = datetime.now().strftime("%d %B %Y")  # e.g., "22 August 2025"
    # 1.
LLM narrative messages = [ SystemMessage( content=( "Interpret the following medical trends and produce a clean, structured **HTML** report without any markdown formatting like backquotes etc. " "The report should have: " f"1. A header that says report generated on : {report_date}." "2. The names of the reports used to summarize this information." "3. Patient summary (patient id, name, age, sex if available)" "4. Test window (mention the from and to dates)" """ 5. Trend summaries Generate tables with the following columns: - Test Name - Most Recent Value, Previous Value, Older Value (use a hyphen "–" if a value is missing). Use these exact column names (do not call them latest value 1,2,or 3) - Unit - Reference Range - Inference (latest value only): ✅ if within normal range, ▲ if above normal (high), ▼ if below normal (low) - Trend Direction (across last 3 values): ⬆️ if values are rising, ⬇️ if values are falling, ➖ (or ✅) if stable/normal """ "6. Clinical insights. \n" "\nImportant Rules:\n" "- Format tables in proper
| , | . "
"- Do not include charts, they will be programmatically added."
"""
5. Trend summaries
Generate HTML tables with the following structure and formatting rules:
Columns:
- Test Name
- Latest Value 1, Latest Value 2, Latest Value 3 (use a hyphen "–" if a value is missing)
- Unit
- Reference Range
- Inference (latest value only): ✅ if within normal range, ▲ if above normal (high), ▼ if below normal (low)
- Trend Direction (across last 3 values): ⬆️ if values are rising, ⬇️ if values are falling, ➖ (or ✅) if stable/normal
Formatting requirements:
- The HTML will be shown in a UI (`gr.HTML`) and also rendered to PDF via WeasyPrint.
- The table must ALWAYS fit within 100% of the container width. Do not allow horizontal scrolling, clipping, or overlapping columns.
- Use `table-layout: fixed;` and ` |
|---|