Upload folder using huggingface_hub
Browse files- .github/workflows/update_space.yml +28 -28
- app.py +58 -58
- config.py +8 -8
- graph.py +448 -448
- models.py +14 -14
- pdf_reader.py +16 -16
.github/workflows/update_space.yml
CHANGED
|
@@ -1,28 +1,28 @@
|
|
| 1 |
-
name: Run Python script
|
| 2 |
-
|
| 3 |
-
on:
|
| 4 |
-
push:
|
| 5 |
-
branches:
|
| 6 |
-
- main
|
| 7 |
-
|
| 8 |
-
jobs:
|
| 9 |
-
build:
|
| 10 |
-
runs-on: ubuntu-latest
|
| 11 |
-
|
| 12 |
-
steps:
|
| 13 |
-
- name: Checkout
|
| 14 |
-
uses: actions/checkout@v2
|
| 15 |
-
|
| 16 |
-
- name: Set up Python
|
| 17 |
-
uses: actions/setup-python@v2
|
| 18 |
-
with:
|
| 19 |
-
python-version: '3.9'
|
| 20 |
-
|
| 21 |
-
- name: Install Gradio
|
| 22 |
-
run: python -m pip install gradio
|
| 23 |
-
|
| 24 |
-
- name: Log in to Hugging Face
|
| 25 |
-
run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
|
| 26 |
-
|
| 27 |
-
- name: Deploy to Spaces
|
| 28 |
-
run: gradio deploy
|
|
|
|
| 1 |
+
name: Run Python script
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
build:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
|
| 12 |
+
steps:
|
| 13 |
+
- name: Checkout
|
| 14 |
+
uses: actions/checkout@v2
|
| 15 |
+
|
| 16 |
+
- name: Set up Python
|
| 17 |
+
uses: actions/setup-python@v2
|
| 18 |
+
with:
|
| 19 |
+
python-version: '3.9'
|
| 20 |
+
|
| 21 |
+
- name: Install Gradio
|
| 22 |
+
run: python -m pip install gradio
|
| 23 |
+
|
| 24 |
+
- name: Log in to Hugging Face
|
| 25 |
+
run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
|
| 26 |
+
|
| 27 |
+
- name: Deploy to Spaces
|
| 28 |
+
run: gradio deploy
|
app.py
CHANGED
|
@@ -1,58 +1,58 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
import uuid
|
| 3 |
-
import os
|
| 4 |
-
from graph import create_graph, SheamiState, HealthReport
|
| 5 |
-
from pdf_reader import read_pdf
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def process_reports(files):
|
| 9 |
-
if not files:
|
| 10 |
-
return "Please upload at least one PDF file."
|
| 11 |
-
|
| 12 |
-
thread_id = str(uuid.uuid4())
|
| 13 |
-
# Create workflow
|
| 14 |
-
workflow = create_graph(thread_id=thread_id)
|
| 15 |
-
|
| 16 |
-
# Convert uploaded PDFs into HealthReport objects
|
| 17 |
-
uploaded_reports = []
|
| 18 |
-
for file in files:
|
| 19 |
-
file_path = file.name
|
| 20 |
-
contents = read_pdf(file_path)
|
| 21 |
-
uploaded_reports.append(
|
| 22 |
-
HealthReport(
|
| 23 |
-
report_file_name=os.path.basename(file_path), report_contents=contents
|
| 24 |
-
)
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
# Run workflow
|
| 28 |
-
# Create initial state
|
| 29 |
-
state = SheamiState(uploaded_reports=uploaded_reports, thread_id=thread_id)
|
| 30 |
-
|
| 31 |
-
config = {"configurable": {"thread_id": thread_id}}
|
| 32 |
-
response = workflow.invoke(state, config=config)
|
| 33 |
-
return (
|
| 34 |
-
f"✅ Processed {len(files)} reports.\n"
|
| 35 |
-
"Please download the output file from below within 5 min."
|
| 36 |
-
), response["interpreted_report"]
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
# Build Gradio UI
|
| 40 |
-
with gr.Blocks() as demo:
|
| 41 |
-
gr.Markdown("## 🩺 Sheami - Smart Healthcare Engine for Artificial Medical Intelligence")
|
| 42 |
-
|
| 43 |
-
with gr.Row():
|
| 44 |
-
file_input = gr.File(
|
| 45 |
-
file_types=[".pdf"],
|
| 46 |
-
type="filepath",
|
| 47 |
-
file_count="multiple",
|
| 48 |
-
label="Upload your Lab Reports (PDF)",
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
run_btn = gr.Button("Process Reports", variant="primary")
|
| 52 |
-
output_box = gr.Textbox(label="Processing Output", lines=2)
|
| 53 |
-
pdf_output = gr.File(label="Generated Report")
|
| 54 |
-
|
| 55 |
-
run_btn.click(process_reports, inputs=file_input, outputs=[output_box, pdf_output])
|
| 56 |
-
|
| 57 |
-
if __name__ == "__main__":
|
| 58 |
-
demo.launch()
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import uuid
|
| 3 |
+
import os
|
| 4 |
+
from graph import create_graph, SheamiState, HealthReport
|
| 5 |
+
from pdf_reader import read_pdf
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def process_reports(files):
|
| 9 |
+
if not files:
|
| 10 |
+
return "Please upload at least one PDF file."
|
| 11 |
+
|
| 12 |
+
thread_id = str(uuid.uuid4())
|
| 13 |
+
# Create workflow
|
| 14 |
+
workflow = create_graph(thread_id=thread_id)
|
| 15 |
+
|
| 16 |
+
# Convert uploaded PDFs into HealthReport objects
|
| 17 |
+
uploaded_reports = []
|
| 18 |
+
for file in files:
|
| 19 |
+
file_path = file.name
|
| 20 |
+
contents = read_pdf(file_path)
|
| 21 |
+
uploaded_reports.append(
|
| 22 |
+
HealthReport(
|
| 23 |
+
report_file_name=os.path.basename(file_path), report_contents=contents
|
| 24 |
+
)
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Run workflow
|
| 28 |
+
# Create initial state
|
| 29 |
+
state = SheamiState(uploaded_reports=uploaded_reports, thread_id=thread_id)
|
| 30 |
+
|
| 31 |
+
config = {"configurable": {"thread_id": thread_id}}
|
| 32 |
+
response = workflow.invoke(state, config=config)
|
| 33 |
+
return (
|
| 34 |
+
f"✅ Processed {len(files)} reports.\n"
|
| 35 |
+
"Please download the output file from below within 5 min."
|
| 36 |
+
), response["interpreted_report"]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Build Gradio UI
|
| 40 |
+
with gr.Blocks() as demo:
|
| 41 |
+
gr.Markdown("## 🩺 Sheami - Smart Healthcare Engine for Artificial Medical Intelligence")
|
| 42 |
+
|
| 43 |
+
with gr.Row():
|
| 44 |
+
file_input = gr.File(
|
| 45 |
+
file_types=[".pdf"],
|
| 46 |
+
type="filepath",
|
| 47 |
+
file_count="multiple",
|
| 48 |
+
label="Upload your Lab Reports (PDF)",
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
run_btn = gr.Button("Process Reports", variant="primary")
|
| 52 |
+
output_box = gr.Textbox(label="Processing Output", lines=2)
|
| 53 |
+
pdf_output = gr.File(label="Generated Report")
|
| 54 |
+
|
| 55 |
+
run_btn.click(process_reports, inputs=file_input, outputs=[output_box, pdf_output])
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
demo.launch()
|
config.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
-
import os
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
class SheamiConfig:
|
| 5 |
-
_output_dir = "./output"
|
| 6 |
-
data_dir = "./data"
|
| 7 |
-
|
| 8 |
-
def get_output_dir(thread_id:str):
|
| 9 |
return os.path.join(SheamiConfig._output_dir, thread_id)
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class SheamiConfig:
|
| 5 |
+
_output_dir = "./output"
|
| 6 |
+
data_dir = "./data"
|
| 7 |
+
|
| 8 |
+
def get_output_dir(thread_id:str):
|
| 9 |
return os.path.join(SheamiConfig._output_dir, thread_id)
|
graph.py
CHANGED
|
@@ -1,448 +1,448 @@
|
|
| 1 |
-
import threading
|
| 2 |
-
import time
|
| 3 |
-
from numpy import number
|
| 4 |
-
import pandas as pd
|
| 5 |
-
from langchain_core.prompts import ChatPromptTemplate
|
| 6 |
-
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
|
| 7 |
-
from reportlab.lib.pagesizes import A4
|
| 8 |
-
from reportlab.lib.styles import getSampleStyleSheet
|
| 9 |
-
from reportlab.lib.units import inch
|
| 10 |
-
import matplotlib.pyplot as plt
|
| 11 |
-
from dataclasses import dataclass
|
| 12 |
-
from typing import Dict, List, Literal, Optional, TypedDict
|
| 13 |
-
import os, json
|
| 14 |
-
from pydantic import BaseModel
|
| 15 |
-
from langchain_core.messages import HumanMessage, SystemMessage
|
| 16 |
-
from langgraph.checkpoint.memory import InMemorySaver
|
| 17 |
-
from langgraph.graph.message import StateGraph
|
| 18 |
-
from langgraph.graph.state import START, END
|
| 19 |
-
from langchain_openai import ChatOpenAI
|
| 20 |
-
from dotenv import load_dotenv
|
| 21 |
-
from config import SheamiConfig
|
| 22 |
-
import logging
|
| 23 |
-
|
| 24 |
-
logging.basicConfig()
|
| 25 |
-
logger = logging.getLogger(__name__)
|
| 26 |
-
logger.setLevel(logging.INFO)
|
| 27 |
-
|
| 28 |
-
load_dotenv(override=True)
|
| 29 |
-
llm = ChatOpenAI(model=os.getenv("MODEL"), temperature=0.3)
|
| 30 |
-
|
| 31 |
-
# -----------------------------
|
| 32 |
-
# SCHEMA DEFINITIONS
|
| 33 |
-
# -----------------------------
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
from typing import Optional, List
|
| 37 |
-
from pydantic import BaseModel, Field
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class PatientInfo(BaseModel):
|
| 41 |
-
name: Optional[str] = Field(None, description="Patient's full name")
|
| 42 |
-
age: Optional[int] = Field(None, description="Patient's age in years")
|
| 43 |
-
sex: Optional[str] = Field(None, description="Male/Female/Other")
|
| 44 |
-
medical_record_number: Optional[str] = None
|
| 45 |
-
|
| 46 |
-
class Config:
|
| 47 |
-
extra = "forbid" # 🚨 ensures schema matches OpenAI’s strict rules
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
class TestResultReferenceRange(BaseModel):
|
| 51 |
-
min: Optional[float] = None
|
| 52 |
-
max: Optional[float] = None
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class LabResult(BaseModel):
|
| 56 |
-
test_name: str
|
| 57 |
-
result_value: str
|
| 58 |
-
test_unit: str
|
| 59 |
-
test_reference_range: Optional[TestResultReferenceRange] = None
|
| 60 |
-
test_date: Optional[str] = None
|
| 61 |
-
inferred_range: Literal["low", "normal", "high"]
|
| 62 |
-
|
| 63 |
-
class Config:
|
| 64 |
-
extra = "forbid"
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
class StandardizedReport(BaseModel):
|
| 68 |
-
patient_info: PatientInfo # 🚨 no longer a raw dict
|
| 69 |
-
lab_results: List[LabResult]
|
| 70 |
-
diagnosis: List[str]
|
| 71 |
-
recommendations: List[str]
|
| 72 |
-
|
| 73 |
-
class Config:
|
| 74 |
-
extra = "forbid"
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
@dataclass
|
| 78 |
-
class HealthReport:
|
| 79 |
-
report_file_name: str
|
| 80 |
-
report_contents: str
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
class SheamiState(TypedDict):
|
| 84 |
-
thread_id: str
|
| 85 |
-
uploaded_reports: List[HealthReport]
|
| 86 |
-
standardized_reports: List[StandardizedReport]
|
| 87 |
-
trends_json: dict
|
| 88 |
-
interpreted_report: str
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
import re
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def safe_filename(name: str) -> str:
|
| 95 |
-
# Replace spaces with underscores
|
| 96 |
-
name = name.replace(" ", "_")
|
| 97 |
-
# Replace any non-alphanumeric / dash / underscore with "_"
|
| 98 |
-
name = re.sub(r"[^A-Za-z0-9_\-]", "_", name)
|
| 99 |
-
# Collapse multiple underscores
|
| 100 |
-
name = re.sub(r"_+", "_", name)
|
| 101 |
-
return name.strip("_")
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
import dateutil.parser
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def parse_any_date(date_str):
|
| 108 |
-
if not date_str or pd.isna(date_str):
|
| 109 |
-
return pd.NaT
|
| 110 |
-
try:
|
| 111 |
-
return dateutil.parser.parse(str(date_str), dayfirst=False, fuzzy=True)
|
| 112 |
-
except Exception:
|
| 113 |
-
return pd.NaT
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
# prompt template
|
| 117 |
-
testname_standardizer_prompt = ChatPromptTemplate.from_messages(
|
| 118 |
-
[
|
| 119 |
-
(
|
| 120 |
-
"system",
|
| 121 |
-
"You are a medical assistant. Normalize lab test names."
|
| 122 |
-
"All outputs must use **title case** (e.g., 'Hemoglobin', 'Blood Glucose')."
|
| 123 |
-
"Return ONLY valid JSON where keys are original names and values are standardized names. DO NOT return markdown formatting like backquotes etc.",
|
| 124 |
-
),
|
| 125 |
-
(
|
| 126 |
-
"human",
|
| 127 |
-
"""Normalize the following lab test names to their standard medical equivalents.
|
| 128 |
-
Test names: {test_names}
|
| 129 |
-
""",
|
| 130 |
-
),
|
| 131 |
-
]
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
# chain = prompt → LLM → string
|
| 135 |
-
testname_standardizer_chain = testname_standardizer_prompt | llm
|
| 136 |
-
|
| 137 |
-
# -----------------------------
|
| 138 |
-
# GRAPH NODES
|
| 139 |
-
# -----------------------------
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def fn_init_node(state: SheamiState):
|
| 143 |
-
os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
|
| 144 |
-
state["standardized_reports"] = []
|
| 145 |
-
state["trends_json"] = {}
|
| 146 |
-
state["interpreted_report"] = ""
|
| 147 |
-
return state
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
def fn_standardizer_node(state: SheamiState):
|
| 151 |
-
logger.info("%s| Standardizing reports", state["thread_id"])
|
| 152 |
-
llm_structured = llm.with_structured_output(StandardizedReport)
|
| 153 |
-
for idx, report in enumerate(state["uploaded_reports"]):
|
| 154 |
-
logger.info("%s| Standardizing report %s", state["thread_id"], report.report_file_name)
|
| 155 |
-
|
| 156 |
-
messages = [
|
| 157 |
-
SystemMessage(content="Standardize this medical report into the schema."),
|
| 158 |
-
# SystemMessage(
|
| 159 |
-
# content="Populate the `inferred_range` field as 'low', 'normal', or 'high' by comparing the result value with the reference range. If both min and max are missing, set 'normal' unless the value is clearly out of usual medical ranges."
|
| 160 |
-
# ),
|
| 161 |
-
HumanMessage(content=report.report_contents),
|
| 162 |
-
]
|
| 163 |
-
result: StandardizedReport = llm_structured.invoke(messages)
|
| 164 |
-
state["standardized_reports"].append(result)
|
| 165 |
-
# save to disk
|
| 166 |
-
with open(
|
| 167 |
-
os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"), "w"
|
| 168 |
-
) as f:
|
| 169 |
-
f.write(result.model_dump_json(indent=2))
|
| 170 |
-
logger.info("%s| Standardizing Reports: finished", state["thread_id"])
|
| 171 |
-
return state
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
def fn_testname_standardizer_node(state: SheamiState):
|
| 175 |
-
logger.info("%s| Standardizing Test Names: started", state["thread_id"])
|
| 176 |
-
# collect unique names
|
| 177 |
-
unique_names = list(
|
| 178 |
-
{
|
| 179 |
-
result.test_name
|
| 180 |
-
for report in state["standardized_reports"]
|
| 181 |
-
for result in report.lab_results
|
| 182 |
-
}
|
| 183 |
-
)
|
| 184 |
-
|
| 185 |
-
# run through LLM
|
| 186 |
-
response = testname_standardizer_chain.invoke({"test_names": unique_names})
|
| 187 |
-
raw_text = response.content
|
| 188 |
-
|
| 189 |
-
try:
|
| 190 |
-
normalization_map: Dict[str, str] = json.loads(raw_text)
|
| 191 |
-
except Exception as e:
|
| 192 |
-
print("Exception in normalization: ", e)
|
| 193 |
-
normalization_map = {
|
| 194 |
-
name: name for name in unique_names
|
| 195 |
-
} # fallback: identity mapping
|
| 196 |
-
|
| 197 |
-
# apply mapping back
|
| 198 |
-
for report in state["standardized_reports"]:
|
| 199 |
-
for result in report.lab_results:
|
| 200 |
-
result.test_name = normalization_map.get(result.test_name, result.test_name)
|
| 201 |
-
logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
|
| 202 |
-
return state
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def fn_unit_normalizer_node(state: SheamiState):
|
| 206 |
-
logger.info("%s| Standardizing Units : started", state["thread_id"])
|
| 207 |
-
"""
|
| 208 |
-
Normalize units for lab test values across all standardized reports.
|
| 209 |
-
Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'
|
| 210 |
-
"""
|
| 211 |
-
unit_map = {
|
| 212 |
-
"g/dl": "g/dL",
|
| 213 |
-
"gms/dl": "g/dL",
|
| 214 |
-
"gm%": "g/dL",
|
| 215 |
-
"g/dl.": "g/dL",
|
| 216 |
-
}
|
| 217 |
-
|
| 218 |
-
for report in state["standardized_reports"]:
|
| 219 |
-
for lr in report.lab_results:
|
| 220 |
-
if not lr.test_unit:
|
| 221 |
-
continue
|
| 222 |
-
normalized = lr.test_unit.lower().replace(" ", "")
|
| 223 |
-
lr.test_unit = unit_map.get(
|
| 224 |
-
normalized, lr.test_unit
|
| 225 |
-
) # fallback: keep original
|
| 226 |
-
|
| 227 |
-
logger.info("%s| Standardizing Units : finished", state["thread_id"])
|
| 228 |
-
return state
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
def fn_trends_aggregator_node(state: SheamiState):
|
| 232 |
-
logger.info("%s| Aggregating Trends : started", state["thread_id"])
|
| 233 |
-
import re
|
| 234 |
-
|
| 235 |
-
# group results by test_name
|
| 236 |
-
trends = {}
|
| 237 |
-
ref_ranges = {}
|
| 238 |
-
|
| 239 |
-
def try_parse_float(value: str):
|
| 240 |
-
try:
|
| 241 |
-
return float(value)
|
| 242 |
-
except (ValueError, TypeError):
|
| 243 |
-
# return None if not numeric
|
| 244 |
-
return None
|
| 245 |
-
|
| 246 |
-
for idx, report in enumerate(state["standardized_reports"]):
|
| 247 |
-
logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
|
| 248 |
-
for lr in report.lab_results:
|
| 249 |
-
numeric_value = try_parse_float(lr.result_value)
|
| 250 |
-
|
| 251 |
-
trends.setdefault(lr.test_name, []).append(
|
| 252 |
-
{
|
| 253 |
-
"date": lr.test_date or "unknown",
|
| 254 |
-
"value": (
|
| 255 |
-
numeric_value if numeric_value is not None else lr.result_value
|
| 256 |
-
),
|
| 257 |
-
"is_numeric": numeric_value is not None,
|
| 258 |
-
"unit": lr.test_unit,
|
| 259 |
-
}
|
| 260 |
-
)
|
| 261 |
-
|
| 262 |
-
# Capture reference range if available (assuming same for all entries of a test_name)
|
| 263 |
-
if lr.test_reference_range:
|
| 264 |
-
ref_ranges[lr.test_name] = {
|
| 265 |
-
"min": lr.test_reference_range.min,
|
| 266 |
-
"max": lr.test_reference_range.max,
|
| 267 |
-
}
|
| 268 |
-
|
| 269 |
-
# combine into parameter_trends
|
| 270 |
-
state["trends_json"] = {
|
| 271 |
-
"parameter_trends": [
|
| 272 |
-
{
|
| 273 |
-
"test_name": k,
|
| 274 |
-
"values": v,
|
| 275 |
-
"reference_range": ref_ranges.get(k), # attach thresholds
|
| 276 |
-
}
|
| 277 |
-
for k, v in trends.items()
|
| 278 |
-
]
|
| 279 |
-
}
|
| 280 |
-
|
| 281 |
-
with open(os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "trends.json"), "w") as f:
|
| 282 |
-
json.dump(state["trends_json"], f, indent=2)
|
| 283 |
-
logger.info("%s| Aggregating Trends : finished", state["thread_id"])
|
| 284 |
-
return state
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
def fn_interpreter_node(state: SheamiState):
|
| 288 |
-
logger.info("%s| Interpreting Trends : started", state["thread_id"])
|
| 289 |
-
# 1. LLM narrative
|
| 290 |
-
messages = [
|
| 291 |
-
SystemMessage(
|
| 292 |
-
content="Interpret the following medical trends and produce a report with patient summary, trend summaries, and clinical insights. "
|
| 293 |
-
"Do not include charts, they will be programmatically added."
|
| 294 |
-
),
|
| 295 |
-
HumanMessage(content=json.dumps(state["trends_json"], indent=2)),
|
| 296 |
-
]
|
| 297 |
-
response = llm.invoke(messages)
|
| 298 |
-
interpretation_text = response.content
|
| 299 |
-
|
| 300 |
-
# 2. Generate plots for each parameter
|
| 301 |
-
plots_dir = os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "plots")
|
| 302 |
-
os.makedirs(plots_dir, exist_ok=True)
|
| 303 |
-
plot_files = []
|
| 304 |
-
|
| 305 |
-
for param in sorted(
|
| 306 |
-
state["trends_json"].get("parameter_trends", []), key=lambda x: x["test_name"]
|
| 307 |
-
):
|
| 308 |
-
test_name = param["test_name"]
|
| 309 |
-
values = param["values"]
|
| 310 |
-
|
| 311 |
-
# plotting + PDF writing logic here
|
| 312 |
-
x = [v["date"] for v in values]
|
| 313 |
-
# print("original dates for ", test_name, "= ", x)
|
| 314 |
-
x = [parse_any_date(d) for d in x]
|
| 315 |
-
x = pd.to_datetime(x, errors="coerce")
|
| 316 |
-
# print("formatted dates for ", test_name, "= ", x)
|
| 317 |
-
|
| 318 |
-
try:
|
| 319 |
-
y = [float(v["value"]) for v in values]
|
| 320 |
-
except ValueError:
|
| 321 |
-
continue # skip non-numeric
|
| 322 |
-
|
| 323 |
-
## sort the data by date
|
| 324 |
-
# Zip into a DataFrame for easy sorting
|
| 325 |
-
df_plot = pd.DataFrame({"x": x, "y": y})
|
| 326 |
-
|
| 327 |
-
# Drop invalid dates if any
|
| 328 |
-
df_plot = df_plot.dropna(subset=["x"])
|
| 329 |
-
|
| 330 |
-
# Sort by date
|
| 331 |
-
df_plot = df_plot.sort_values("x")
|
| 332 |
-
|
| 333 |
-
# Extract sorted arrays
|
| 334 |
-
x = df_plot["x"].to_numpy()
|
| 335 |
-
y = df_plot["y"].to_numpy()
|
| 336 |
-
|
| 337 |
-
# print("formatted + sorted dates for", test_name, "=", x)
|
| 338 |
-
|
| 339 |
-
plt.figure(figsize=(6, 4))
|
| 340 |
-
plt.plot(x, y, marker="o", linestyle="-", label="Observed values")
|
| 341 |
-
|
| 342 |
-
# add thresholds if available
|
| 343 |
-
ref = param.get("reference_range")
|
| 344 |
-
if ref:
|
| 345 |
-
ymin, ymax = ref.get("min"), ref.get("max")
|
| 346 |
-
if ymin is not None and ymax is not None:
|
| 347 |
-
plt.axhspan(
|
| 348 |
-
ymin, ymax, color="green", alpha=0.2, label="Reference range"
|
| 349 |
-
)
|
| 350 |
-
elif ymax is not None:
|
| 351 |
-
plt.axhline(
|
| 352 |
-
y=ymax, color="red", linestyle="--", label="Upper threshold"
|
| 353 |
-
)
|
| 354 |
-
elif ymin is not None:
|
| 355 |
-
plt.axhline(
|
| 356 |
-
y=ymin, color="blue", linestyle="--", label="Lower threshold"
|
| 357 |
-
)
|
| 358 |
-
|
| 359 |
-
plt.title(f"{test_name} Trend")
|
| 360 |
-
plt.xlabel("Date")
|
| 361 |
-
plt.ylabel(values[0]["unit"] if values and "unit" in values[0] else "")
|
| 362 |
-
plt.grid(True)
|
| 363 |
-
plt.xticks(rotation=45)
|
| 364 |
-
plt.legend()
|
| 365 |
-
plt.tight_layout()
|
| 366 |
-
|
| 367 |
-
filename = f"{safe_filename(test_name).replace(' ', '_')}_trend.png"
|
| 368 |
-
filepath = os.path.join(plots_dir, filename)
|
| 369 |
-
plt.savefig(filepath)
|
| 370 |
-
plt.close()
|
| 371 |
-
plot_files.append((test_name, filepath))
|
| 372 |
-
|
| 373 |
-
# 3. Build PDF
|
| 374 |
-
pdf_path = os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "final_report.pdf")
|
| 375 |
-
doc = SimpleDocTemplate(pdf_path, pagesize=A4)
|
| 376 |
-
styles = getSampleStyleSheet()
|
| 377 |
-
story = []
|
| 378 |
-
|
| 379 |
-
# Add title
|
| 380 |
-
story.append(Paragraph("<b>Medical Report Interpretation</b>", styles["Title"]))
|
| 381 |
-
story.append(Spacer(1, 0.3 * inch))
|
| 382 |
-
|
| 383 |
-
# Add interpretation text (LLM output)
|
| 384 |
-
for line in interpretation_text.split("\n"):
|
| 385 |
-
if line.strip():
|
| 386 |
-
story.append(Paragraph(line.strip(), styles["Normal"]))
|
| 387 |
-
story.append(Spacer(1, 0.15 * inch))
|
| 388 |
-
|
| 389 |
-
# Add charts
|
| 390 |
-
story.append(Spacer(1, 0.5 * inch))
|
| 391 |
-
story.append(Paragraph("<b>Trends</b>", styles["Heading2"]))
|
| 392 |
-
story.append(Spacer(1, 0.2 * inch))
|
| 393 |
-
|
| 394 |
-
for test_name, plotfile in plot_files:
|
| 395 |
-
story.append(Paragraph(f"<b>{test_name}</b>", styles["Heading3"]))
|
| 396 |
-
story.append(Image(plotfile, width=5 * inch, height=3 * inch))
|
| 397 |
-
story.append(Spacer(1, 0.3 * inch))
|
| 398 |
-
|
| 399 |
-
doc.build(story)
|
| 400 |
-
|
| 401 |
-
state["interpreted_report"] = pdf_path
|
| 402 |
-
###### Schedule Cleanup of output dir after 5 min.
|
| 403 |
-
schedule_cleanup(file_path=SheamiConfig.get_output_dir(state["thread_id"]))
|
| 404 |
-
logger.info("%s| Interpreting Trends : finished", state["thread_id"])
|
| 405 |
-
return state
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
def schedule_cleanup(file_path, delay=300): # 300 sec = 5 min
|
| 409 |
-
def cleanup():
|
| 410 |
-
time.sleep(delay)
|
| 411 |
-
if os.path.exists(file_path):
|
| 412 |
-
try:
|
| 413 |
-
if os.path.isdir(file_path):
|
| 414 |
-
import shutil
|
| 415 |
-
shutil.rmtree(file_path)
|
| 416 |
-
else:
|
| 417 |
-
os.remove(file_path)
|
| 418 |
-
print(f"Cleaned up: {file_path}")
|
| 419 |
-
except Exception as e:
|
| 420 |
-
print(f"Cleanup failed for {file_path}: {e}")
|
| 421 |
-
threading.Thread(target=cleanup, daemon=True).start()
|
| 422 |
-
|
| 423 |
-
# -----------------------------
|
| 424 |
-
# GRAPH CREATION
|
| 425 |
-
# -----------------------------
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
def create_graph(thread_id : str):
|
| 429 |
-
logger.info("%s| Creating Graph : started", thread_id)
|
| 430 |
-
memory = InMemorySaver()
|
| 431 |
-
workflow = StateGraph(SheamiState)
|
| 432 |
-
workflow.add_node("init", fn_init_node)
|
| 433 |
-
workflow.add_node("standardizer", fn_standardizer_node)
|
| 434 |
-
workflow.add_node("testname_standardizer", fn_testname_standardizer_node)
|
| 435 |
-
workflow.add_node("unit_normalizer", fn_unit_normalizer_node)
|
| 436 |
-
workflow.add_node("trends", fn_trends_aggregator_node)
|
| 437 |
-
workflow.add_node("interpreter", fn_interpreter_node)
|
| 438 |
-
|
| 439 |
-
workflow.add_edge(START, "init")
|
| 440 |
-
workflow.add_edge("init", "standardizer")
|
| 441 |
-
workflow.add_edge("standardizer", "testname_standardizer")
|
| 442 |
-
workflow.add_edge("testname_standardizer", "unit_normalizer")
|
| 443 |
-
workflow.add_edge("unit_normalizer", "trends")
|
| 444 |
-
workflow.add_edge("trends", "interpreter")
|
| 445 |
-
workflow.add_edge("interpreter", END)
|
| 446 |
-
|
| 447 |
-
logger.info("%s| Creating Graph : finished", thread_id)
|
| 448 |
-
return workflow.compile(checkpointer=memory)
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import time
|
| 3 |
+
from numpy import number
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 6 |
+
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
|
| 7 |
+
from reportlab.lib.pagesizes import A4
|
| 8 |
+
from reportlab.lib.styles import getSampleStyleSheet
|
| 9 |
+
from reportlab.lib.units import inch
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Dict, List, Literal, Optional, TypedDict
|
| 13 |
+
import os, json
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
from langchain_core.messages import HumanMessage, SystemMessage
|
| 16 |
+
from langgraph.checkpoint.memory import InMemorySaver
|
| 17 |
+
from langgraph.graph.message import StateGraph
|
| 18 |
+
from langgraph.graph.state import START, END
|
| 19 |
+
from langchain_openai import ChatOpenAI
|
| 20 |
+
from dotenv import load_dotenv
|
| 21 |
+
from config import SheamiConfig
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
logging.basicConfig()
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
logger.setLevel(logging.INFO)
|
| 27 |
+
|
| 28 |
+
load_dotenv(override=True)
|
| 29 |
+
llm = ChatOpenAI(model=os.getenv("MODEL"), temperature=0.3)
|
| 30 |
+
|
| 31 |
+
# -----------------------------
|
| 32 |
+
# SCHEMA DEFINITIONS
|
| 33 |
+
# -----------------------------
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
from typing import Optional, List
|
| 37 |
+
from pydantic import BaseModel, Field
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class PatientInfo(BaseModel):
|
| 41 |
+
name: Optional[str] = Field(None, description="Patient's full name")
|
| 42 |
+
age: Optional[int] = Field(None, description="Patient's age in years")
|
| 43 |
+
sex: Optional[str] = Field(None, description="Male/Female/Other")
|
| 44 |
+
medical_record_number: Optional[str] = None
|
| 45 |
+
|
| 46 |
+
class Config:
|
| 47 |
+
extra = "forbid" # 🚨 ensures schema matches OpenAI’s strict rules
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TestResultReferenceRange(BaseModel):
|
| 51 |
+
min: Optional[float] = None
|
| 52 |
+
max: Optional[float] = None
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class LabResult(BaseModel):
|
| 56 |
+
test_name: str
|
| 57 |
+
result_value: str
|
| 58 |
+
test_unit: str
|
| 59 |
+
test_reference_range: Optional[TestResultReferenceRange] = None
|
| 60 |
+
test_date: Optional[str] = None
|
| 61 |
+
inferred_range: Literal["low", "normal", "high"]
|
| 62 |
+
|
| 63 |
+
class Config:
|
| 64 |
+
extra = "forbid"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class StandardizedReport(BaseModel):
|
| 68 |
+
patient_info: PatientInfo # 🚨 no longer a raw dict
|
| 69 |
+
lab_results: List[LabResult]
|
| 70 |
+
diagnosis: List[str]
|
| 71 |
+
recommendations: List[str]
|
| 72 |
+
|
| 73 |
+
class Config:
|
| 74 |
+
extra = "forbid"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class HealthReport:
|
| 79 |
+
report_file_name: str
|
| 80 |
+
report_contents: str
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class SheamiState(TypedDict):
|
| 84 |
+
thread_id: str
|
| 85 |
+
uploaded_reports: List[HealthReport]
|
| 86 |
+
standardized_reports: List[StandardizedReport]
|
| 87 |
+
trends_json: dict
|
| 88 |
+
interpreted_report: str
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
import re
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def safe_filename(name: str) -> str:
|
| 95 |
+
# Replace spaces with underscores
|
| 96 |
+
name = name.replace(" ", "_")
|
| 97 |
+
# Replace any non-alphanumeric / dash / underscore with "_"
|
| 98 |
+
name = re.sub(r"[^A-Za-z0-9_\-]", "_", name)
|
| 99 |
+
# Collapse multiple underscores
|
| 100 |
+
name = re.sub(r"_+", "_", name)
|
| 101 |
+
return name.strip("_")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
import dateutil.parser
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def parse_any_date(date_str):
|
| 108 |
+
if not date_str or pd.isna(date_str):
|
| 109 |
+
return pd.NaT
|
| 110 |
+
try:
|
| 111 |
+
return dateutil.parser.parse(str(date_str), dayfirst=False, fuzzy=True)
|
| 112 |
+
except Exception:
|
| 113 |
+
return pd.NaT
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# prompt template
|
| 117 |
+
testname_standardizer_prompt = ChatPromptTemplate.from_messages(
|
| 118 |
+
[
|
| 119 |
+
(
|
| 120 |
+
"system",
|
| 121 |
+
"You are a medical assistant. Normalize lab test names."
|
| 122 |
+
"All outputs must use **title case** (e.g., 'Hemoglobin', 'Blood Glucose')."
|
| 123 |
+
"Return ONLY valid JSON where keys are original names and values are standardized names. DO NOT return markdown formatting like backquotes etc.",
|
| 124 |
+
),
|
| 125 |
+
(
|
| 126 |
+
"human",
|
| 127 |
+
"""Normalize the following lab test names to their standard medical equivalents.
|
| 128 |
+
Test names: {test_names}
|
| 129 |
+
""",
|
| 130 |
+
),
|
| 131 |
+
]
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# chain = prompt → LLM → string
|
| 135 |
+
testname_standardizer_chain = testname_standardizer_prompt | llm
|
| 136 |
+
|
| 137 |
+
# -----------------------------
|
| 138 |
+
# GRAPH NODES
|
| 139 |
+
# -----------------------------
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def fn_init_node(state: SheamiState):
|
| 143 |
+
os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
|
| 144 |
+
state["standardized_reports"] = []
|
| 145 |
+
state["trends_json"] = {}
|
| 146 |
+
state["interpreted_report"] = ""
|
| 147 |
+
return state
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def fn_standardizer_node(state: SheamiState):
|
| 151 |
+
logger.info("%s| Standardizing reports", state["thread_id"])
|
| 152 |
+
llm_structured = llm.with_structured_output(StandardizedReport)
|
| 153 |
+
for idx, report in enumerate(state["uploaded_reports"]):
|
| 154 |
+
logger.info("%s| Standardizing report %s", state["thread_id"], report.report_file_name)
|
| 155 |
+
|
| 156 |
+
messages = [
|
| 157 |
+
SystemMessage(content="Standardize this medical report into the schema."),
|
| 158 |
+
# SystemMessage(
|
| 159 |
+
# content="Populate the `inferred_range` field as 'low', 'normal', or 'high' by comparing the result value with the reference range. If both min and max are missing, set 'normal' unless the value is clearly out of usual medical ranges."
|
| 160 |
+
# ),
|
| 161 |
+
HumanMessage(content=report.report_contents),
|
| 162 |
+
]
|
| 163 |
+
result: StandardizedReport = llm_structured.invoke(messages)
|
| 164 |
+
state["standardized_reports"].append(result)
|
| 165 |
+
# save to disk
|
| 166 |
+
with open(
|
| 167 |
+
os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"), "w"
|
| 168 |
+
) as f:
|
| 169 |
+
f.write(result.model_dump_json(indent=2))
|
| 170 |
+
logger.info("%s| Standardizing Reports: finished", state["thread_id"])
|
| 171 |
+
return state
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def fn_testname_standardizer_node(state: SheamiState):
|
| 175 |
+
logger.info("%s| Standardizing Test Names: started", state["thread_id"])
|
| 176 |
+
# collect unique names
|
| 177 |
+
unique_names = list(
|
| 178 |
+
{
|
| 179 |
+
result.test_name
|
| 180 |
+
for report in state["standardized_reports"]
|
| 181 |
+
for result in report.lab_results
|
| 182 |
+
}
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# run through LLM
|
| 186 |
+
response = testname_standardizer_chain.invoke({"test_names": unique_names})
|
| 187 |
+
raw_text = response.content
|
| 188 |
+
|
| 189 |
+
try:
|
| 190 |
+
normalization_map: Dict[str, str] = json.loads(raw_text)
|
| 191 |
+
except Exception as e:
|
| 192 |
+
print("Exception in normalization: ", e)
|
| 193 |
+
normalization_map = {
|
| 194 |
+
name: name for name in unique_names
|
| 195 |
+
} # fallback: identity mapping
|
| 196 |
+
|
| 197 |
+
# apply mapping back
|
| 198 |
+
for report in state["standardized_reports"]:
|
| 199 |
+
for result in report.lab_results:
|
| 200 |
+
result.test_name = normalization_map.get(result.test_name, result.test_name)
|
| 201 |
+
logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
|
| 202 |
+
return state
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def fn_unit_normalizer_node(state: SheamiState):
|
| 206 |
+
logger.info("%s| Standardizing Units : started", state["thread_id"])
|
| 207 |
+
"""
|
| 208 |
+
Normalize units for lab test values across all standardized reports.
|
| 209 |
+
Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'
|
| 210 |
+
"""
|
| 211 |
+
unit_map = {
|
| 212 |
+
"g/dl": "g/dL",
|
| 213 |
+
"gms/dl": "g/dL",
|
| 214 |
+
"gm%": "g/dL",
|
| 215 |
+
"g/dl.": "g/dL",
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
for report in state["standardized_reports"]:
|
| 219 |
+
for lr in report.lab_results:
|
| 220 |
+
if not lr.test_unit:
|
| 221 |
+
continue
|
| 222 |
+
normalized = lr.test_unit.lower().replace(" ", "")
|
| 223 |
+
lr.test_unit = unit_map.get(
|
| 224 |
+
normalized, lr.test_unit
|
| 225 |
+
) # fallback: keep original
|
| 226 |
+
|
| 227 |
+
logger.info("%s| Standardizing Units : finished", state["thread_id"])
|
| 228 |
+
return state
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def fn_trends_aggregator_node(state: SheamiState):
|
| 232 |
+
logger.info("%s| Aggregating Trends : started", state["thread_id"])
|
| 233 |
+
import re
|
| 234 |
+
|
| 235 |
+
# group results by test_name
|
| 236 |
+
trends = {}
|
| 237 |
+
ref_ranges = {}
|
| 238 |
+
|
| 239 |
+
def try_parse_float(value: str):
|
| 240 |
+
try:
|
| 241 |
+
return float(value)
|
| 242 |
+
except (ValueError, TypeError):
|
| 243 |
+
# return None if not numeric
|
| 244 |
+
return None
|
| 245 |
+
|
| 246 |
+
for idx, report in enumerate(state["standardized_reports"]):
|
| 247 |
+
logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
|
| 248 |
+
for lr in report.lab_results:
|
| 249 |
+
numeric_value = try_parse_float(lr.result_value)
|
| 250 |
+
|
| 251 |
+
trends.setdefault(lr.test_name, []).append(
|
| 252 |
+
{
|
| 253 |
+
"date": lr.test_date or "unknown",
|
| 254 |
+
"value": (
|
| 255 |
+
numeric_value if numeric_value is not None else lr.result_value
|
| 256 |
+
),
|
| 257 |
+
"is_numeric": numeric_value is not None,
|
| 258 |
+
"unit": lr.test_unit,
|
| 259 |
+
}
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# Capture reference range if available (assuming same for all entries of a test_name)
|
| 263 |
+
if lr.test_reference_range:
|
| 264 |
+
ref_ranges[lr.test_name] = {
|
| 265 |
+
"min": lr.test_reference_range.min,
|
| 266 |
+
"max": lr.test_reference_range.max,
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
# combine into parameter_trends
|
| 270 |
+
state["trends_json"] = {
|
| 271 |
+
"parameter_trends": [
|
| 272 |
+
{
|
| 273 |
+
"test_name": k,
|
| 274 |
+
"values": v,
|
| 275 |
+
"reference_range": ref_ranges.get(k), # attach thresholds
|
| 276 |
+
}
|
| 277 |
+
for k, v in trends.items()
|
| 278 |
+
]
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
with open(os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "trends.json"), "w") as f:
|
| 282 |
+
json.dump(state["trends_json"], f, indent=2)
|
| 283 |
+
logger.info("%s| Aggregating Trends : finished", state["thread_id"])
|
| 284 |
+
return state
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def fn_interpreter_node(state: SheamiState):
|
| 288 |
+
logger.info("%s| Interpreting Trends : started", state["thread_id"])
|
| 289 |
+
# 1. LLM narrative
|
| 290 |
+
messages = [
|
| 291 |
+
SystemMessage(
|
| 292 |
+
content="Interpret the following medical trends and produce a report with patient summary, trend summaries, and clinical insights. "
|
| 293 |
+
"Do not include charts, they will be programmatically added."
|
| 294 |
+
),
|
| 295 |
+
HumanMessage(content=json.dumps(state["trends_json"], indent=2)),
|
| 296 |
+
]
|
| 297 |
+
response = llm.invoke(messages)
|
| 298 |
+
interpretation_text = response.content
|
| 299 |
+
|
| 300 |
+
# 2. Generate plots for each parameter
|
| 301 |
+
plots_dir = os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "plots")
|
| 302 |
+
os.makedirs(plots_dir, exist_ok=True)
|
| 303 |
+
plot_files = []
|
| 304 |
+
|
| 305 |
+
for param in sorted(
|
| 306 |
+
state["trends_json"].get("parameter_trends", []), key=lambda x: x["test_name"]
|
| 307 |
+
):
|
| 308 |
+
test_name = param["test_name"]
|
| 309 |
+
values = param["values"]
|
| 310 |
+
|
| 311 |
+
# plotting + PDF writing logic here
|
| 312 |
+
x = [v["date"] for v in values]
|
| 313 |
+
# print("original dates for ", test_name, "= ", x)
|
| 314 |
+
x = [parse_any_date(d) for d in x]
|
| 315 |
+
x = pd.to_datetime(x, errors="coerce")
|
| 316 |
+
# print("formatted dates for ", test_name, "= ", x)
|
| 317 |
+
|
| 318 |
+
try:
|
| 319 |
+
y = [float(v["value"]) for v in values]
|
| 320 |
+
except ValueError:
|
| 321 |
+
continue # skip non-numeric
|
| 322 |
+
|
| 323 |
+
## sort the data by date
|
| 324 |
+
# Zip into a DataFrame for easy sorting
|
| 325 |
+
df_plot = pd.DataFrame({"x": x, "y": y})
|
| 326 |
+
|
| 327 |
+
# Drop invalid dates if any
|
| 328 |
+
df_plot = df_plot.dropna(subset=["x"])
|
| 329 |
+
|
| 330 |
+
# Sort by date
|
| 331 |
+
df_plot = df_plot.sort_values("x")
|
| 332 |
+
|
| 333 |
+
# Extract sorted arrays
|
| 334 |
+
x = df_plot["x"].to_numpy()
|
| 335 |
+
y = df_plot["y"].to_numpy()
|
| 336 |
+
|
| 337 |
+
# print("formatted + sorted dates for", test_name, "=", x)
|
| 338 |
+
|
| 339 |
+
plt.figure(figsize=(6, 4))
|
| 340 |
+
plt.plot(x, y, marker="o", linestyle="-", label="Observed values")
|
| 341 |
+
|
| 342 |
+
# add thresholds if available
|
| 343 |
+
ref = param.get("reference_range")
|
| 344 |
+
if ref:
|
| 345 |
+
ymin, ymax = ref.get("min"), ref.get("max")
|
| 346 |
+
if ymin is not None and ymax is not None:
|
| 347 |
+
plt.axhspan(
|
| 348 |
+
ymin, ymax, color="green", alpha=0.2, label="Reference range"
|
| 349 |
+
)
|
| 350 |
+
elif ymax is not None:
|
| 351 |
+
plt.axhline(
|
| 352 |
+
y=ymax, color="red", linestyle="--", label="Upper threshold"
|
| 353 |
+
)
|
| 354 |
+
elif ymin is not None:
|
| 355 |
+
plt.axhline(
|
| 356 |
+
y=ymin, color="blue", linestyle="--", label="Lower threshold"
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
plt.title(f"{test_name} Trend")
|
| 360 |
+
plt.xlabel("Date")
|
| 361 |
+
plt.ylabel(values[0]["unit"] if values and "unit" in values[0] else "")
|
| 362 |
+
plt.grid(True)
|
| 363 |
+
plt.xticks(rotation=45)
|
| 364 |
+
plt.legend()
|
| 365 |
+
plt.tight_layout()
|
| 366 |
+
|
| 367 |
+
filename = f"{safe_filename(test_name).replace(' ', '_')}_trend.png"
|
| 368 |
+
filepath = os.path.join(plots_dir, filename)
|
| 369 |
+
plt.savefig(filepath)
|
| 370 |
+
plt.close()
|
| 371 |
+
plot_files.append((test_name, filepath))
|
| 372 |
+
|
| 373 |
+
# 3. Build PDF
|
| 374 |
+
pdf_path = os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "final_report.pdf")
|
| 375 |
+
doc = SimpleDocTemplate(pdf_path, pagesize=A4)
|
| 376 |
+
styles = getSampleStyleSheet()
|
| 377 |
+
story = []
|
| 378 |
+
|
| 379 |
+
# Add title
|
| 380 |
+
story.append(Paragraph("<b>Medical Report Interpretation</b>", styles["Title"]))
|
| 381 |
+
story.append(Spacer(1, 0.3 * inch))
|
| 382 |
+
|
| 383 |
+
# Add interpretation text (LLM output)
|
| 384 |
+
for line in interpretation_text.split("\n"):
|
| 385 |
+
if line.strip():
|
| 386 |
+
story.append(Paragraph(line.strip(), styles["Normal"]))
|
| 387 |
+
story.append(Spacer(1, 0.15 * inch))
|
| 388 |
+
|
| 389 |
+
# Add charts
|
| 390 |
+
story.append(Spacer(1, 0.5 * inch))
|
| 391 |
+
story.append(Paragraph("<b>Trends</b>", styles["Heading2"]))
|
| 392 |
+
story.append(Spacer(1, 0.2 * inch))
|
| 393 |
+
|
| 394 |
+
for test_name, plotfile in plot_files:
|
| 395 |
+
story.append(Paragraph(f"<b>{test_name}</b>", styles["Heading3"]))
|
| 396 |
+
story.append(Image(plotfile, width=5 * inch, height=3 * inch))
|
| 397 |
+
story.append(Spacer(1, 0.3 * inch))
|
| 398 |
+
|
| 399 |
+
doc.build(story)
|
| 400 |
+
|
| 401 |
+
state["interpreted_report"] = pdf_path
|
| 402 |
+
###### Schedule Cleanup of output dir after 5 min.
|
| 403 |
+
schedule_cleanup(file_path=SheamiConfig.get_output_dir(state["thread_id"]))
|
| 404 |
+
logger.info("%s| Interpreting Trends : finished", state["thread_id"])
|
| 405 |
+
return state
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def schedule_cleanup(file_path, delay=300): # 300 sec = 5 min
|
| 409 |
+
def cleanup():
|
| 410 |
+
time.sleep(delay)
|
| 411 |
+
if os.path.exists(file_path):
|
| 412 |
+
try:
|
| 413 |
+
if os.path.isdir(file_path):
|
| 414 |
+
import shutil
|
| 415 |
+
shutil.rmtree(file_path)
|
| 416 |
+
else:
|
| 417 |
+
os.remove(file_path)
|
| 418 |
+
print(f"Cleaned up: {file_path}")
|
| 419 |
+
except Exception as e:
|
| 420 |
+
print(f"Cleanup failed for {file_path}: {e}")
|
| 421 |
+
threading.Thread(target=cleanup, daemon=True).start()
|
| 422 |
+
|
| 423 |
+
# -----------------------------
|
| 424 |
+
# GRAPH CREATION
|
| 425 |
+
# -----------------------------
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def create_graph(thread_id : str):
|
| 429 |
+
logger.info("%s| Creating Graph : started", thread_id)
|
| 430 |
+
memory = InMemorySaver()
|
| 431 |
+
workflow = StateGraph(SheamiState)
|
| 432 |
+
workflow.add_node("init", fn_init_node)
|
| 433 |
+
workflow.add_node("standardizer", fn_standardizer_node)
|
| 434 |
+
workflow.add_node("testname_standardizer", fn_testname_standardizer_node)
|
| 435 |
+
workflow.add_node("unit_normalizer", fn_unit_normalizer_node)
|
| 436 |
+
workflow.add_node("trends", fn_trends_aggregator_node)
|
| 437 |
+
workflow.add_node("interpreter", fn_interpreter_node)
|
| 438 |
+
|
| 439 |
+
workflow.add_edge(START, "init")
|
| 440 |
+
workflow.add_edge("init", "standardizer")
|
| 441 |
+
workflow.add_edge("standardizer", "testname_standardizer")
|
| 442 |
+
workflow.add_edge("testname_standardizer", "unit_normalizer")
|
| 443 |
+
workflow.add_edge("unit_normalizer", "trends")
|
| 444 |
+
workflow.add_edge("trends", "interpreter")
|
| 445 |
+
workflow.add_edge("interpreter", END)
|
| 446 |
+
|
| 447 |
+
logger.info("%s| Creating Graph : finished", thread_id)
|
| 448 |
+
return workflow.compile(checkpointer=memory)
|
models.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
-
from pydantic import BaseModel
|
| 2 |
-
from typing import List, Optional
|
| 3 |
-
|
| 4 |
-
class SheamiLabResult(BaseModel):
|
| 5 |
-
test_name: str
|
| 6 |
-
result_value: str
|
| 7 |
-
unit: str
|
| 8 |
-
reference_range: Optional[str]
|
| 9 |
-
|
| 10 |
-
class SheamiStandardizedReport(BaseModel):
|
| 11 |
-
patient_info: dict
|
| 12 |
-
lab_results: List[SheamiLabResult]
|
| 13 |
-
diagnosis: List[str]
|
| 14 |
-
recommendations: List[str]
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
class SheamiLabResult(BaseModel):
|
| 5 |
+
test_name: str
|
| 6 |
+
result_value: str
|
| 7 |
+
unit: str
|
| 8 |
+
reference_range: Optional[str]
|
| 9 |
+
|
| 10 |
+
class SheamiStandardizedReport(BaseModel):
|
| 11 |
+
patient_info: dict
|
| 12 |
+
lab_results: List[SheamiLabResult]
|
| 13 |
+
diagnosis: List[str]
|
| 14 |
+
recommendations: List[str]
|
pdf_reader.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
| 1 |
-
from pypdf import PdfReader
|
| 2 |
-
|
| 3 |
-
def read_pdf(file_name:str):
|
| 4 |
-
reader = PdfReader(file_name)
|
| 5 |
-
# Get the number of pages
|
| 6 |
-
number_of_pages = len(reader.pages)
|
| 7 |
-
# print(f"Number of pages: {number_of_pages}")
|
| 8 |
-
|
| 9 |
-
content = ""
|
| 10 |
-
for page_num in range(len(reader.pages)):
|
| 11 |
-
page = reader.pages[page_num]
|
| 12 |
-
text = page.extract_text()
|
| 13 |
-
# print(f"--- Page {page_num + 1} ---")
|
| 14 |
-
# print(text)
|
| 15 |
-
content += f"--- Page {page_num + 1} ---" + "\n\n" + text
|
| 16 |
-
|
| 17 |
return content
|
|
|
|
| 1 |
+
from pypdf import PdfReader
|
| 2 |
+
|
| 3 |
+
def read_pdf(file_name:str):
|
| 4 |
+
reader = PdfReader(file_name)
|
| 5 |
+
# Get the number of pages
|
| 6 |
+
number_of_pages = len(reader.pages)
|
| 7 |
+
# print(f"Number of pages: {number_of_pages}")
|
| 8 |
+
|
| 9 |
+
content = ""
|
| 10 |
+
for page_num in range(len(reader.pages)):
|
| 11 |
+
page = reader.pages[page_num]
|
| 12 |
+
text = page.extract_text()
|
| 13 |
+
# print(f"--- Page {page_num + 1} ---")
|
| 14 |
+
# print(text)
|
| 15 |
+
content += f"--- Page {page_num + 1} ---" + "\n\n" + text
|
| 16 |
+
|
| 17 |
return content
|