vikramvasudevan committed on
Commit
0cfb077
·
verified ·
1 Parent(s): ed2dc57

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. .github/workflows/update_space.yml +28 -28
  2. app.py +58 -58
  3. config.py +8 -8
  4. graph.py +448 -448
  5. models.py +14 -14
  6. pdf_reader.py +16 -16
.github/workflows/update_space.yml CHANGED
@@ -1,28 +1,28 @@
1
- name: Run Python script
2
-
3
- on:
4
- push:
5
- branches:
6
- - main
7
-
8
- jobs:
9
- build:
10
- runs-on: ubuntu-latest
11
-
12
- steps:
13
- - name: Checkout
14
- uses: actions/checkout@v2
15
-
16
- - name: Set up Python
17
- uses: actions/setup-python@v2
18
- with:
19
- python-version: '3.9'
20
-
21
- - name: Install Gradio
22
- run: python -m pip install gradio
23
-
24
- - name: Log in to Hugging Face
25
- run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
26
-
27
- - name: Deploy to Spaces
28
- run: gradio deploy
 
1
name: Run Python script

on:
  push:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        # v2 runs on a deprecated Node runtime; v4 is the supported release.
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'

      - name: Install Gradio
        run: python -m pip install gradio

      - name: Log in to Hugging Face
        # Pass the secret through the environment instead of interpolating it
        # into the command line, so it never becomes part of the shell string.
        env:
          HF_TOKEN: ${{ secrets.hf_token }}
        run: python -c 'import os, huggingface_hub; huggingface_hub.login(token=os.environ["HF_TOKEN"])'

      - name: Deploy to Spaces
        run: gradio deploy
app.py CHANGED
@@ -1,58 +1,58 @@
1
- import gradio as gr
2
- import uuid
3
- import os
4
- from graph import create_graph, SheamiState, HealthReport
5
- from pdf_reader import read_pdf
6
-
7
-
8
- def process_reports(files):
9
- if not files:
10
- return "Please upload at least one PDF file."
11
-
12
- thread_id = str(uuid.uuid4())
13
- # Create workflow
14
- workflow = create_graph(thread_id=thread_id)
15
-
16
- # Convert uploaded PDFs into HealthReport objects
17
- uploaded_reports = []
18
- for file in files:
19
- file_path = file.name
20
- contents = read_pdf(file_path)
21
- uploaded_reports.append(
22
- HealthReport(
23
- report_file_name=os.path.basename(file_path), report_contents=contents
24
- )
25
- )
26
-
27
- # Run workflow
28
- # Create initial state
29
- state = SheamiState(uploaded_reports=uploaded_reports, thread_id=thread_id)
30
-
31
- config = {"configurable": {"thread_id": thread_id}}
32
- response = workflow.invoke(state, config=config)
33
- return (
34
- f"✅ Processed {len(files)} reports.\n"
35
- "Please download the output file from below within 5 min."
36
- ), response["interpreted_report"]
37
-
38
-
39
- # Build Gradio UI
40
- with gr.Blocks() as demo:
41
- gr.Markdown("## 🩺 Sheami - Smart Healthcare Engine for Artificial Medical Intelligence")
42
-
43
- with gr.Row():
44
- file_input = gr.File(
45
- file_types=[".pdf"],
46
- type="filepath",
47
- file_count="multiple",
48
- label="Upload your Lab Reports (PDF)",
49
- )
50
-
51
- run_btn = gr.Button("Process Reports", variant="primary")
52
- output_box = gr.Textbox(label="Processing Output", lines=2)
53
- pdf_output = gr.File(label="Generated Report")
54
-
55
- run_btn.click(process_reports, inputs=file_input, outputs=[output_box, pdf_output])
56
-
57
- if __name__ == "__main__":
58
- demo.launch()
 
1
+ import gradio as gr
2
+ import uuid
3
+ import os
4
+ from graph import create_graph, SheamiState, HealthReport
5
+ from pdf_reader import read_pdf
6
+
7
+
8
def process_reports(files):
    """Run the Sheami pipeline over uploaded PDF lab reports.

    Args:
        files: Uploads from gr.File — plain path strings when the component
            uses type="filepath", or file-like wrappers exposing .name.

    Returns:
        A 2-tuple of (status message, path to the generated PDF or None),
        matching the two Gradio output components bound to the click handler.
    """
    if not files:
        # Must still be a 2-tuple: the click handler binds two outputs
        # (output_box, pdf_output); a bare string would make Gradio error out.
        return "Please upload at least one PDF file.", None

    thread_id = str(uuid.uuid4())
    # Create workflow (checkpointed per thread_id)
    workflow = create_graph(thread_id=thread_id)

    # Convert uploaded PDFs into HealthReport objects
    uploaded_reports = []
    for file in files:
        # gr.File(type="filepath") yields str paths; other configurations
        # yield tempfile wrappers exposing .name — accept both.
        file_path = file if isinstance(file, str) else file.name
        contents = read_pdf(file_path)
        uploaded_reports.append(
            HealthReport(
                report_file_name=os.path.basename(file_path), report_contents=contents
            )
        )

    # Build the initial state and run the workflow
    state = SheamiState(uploaded_reports=uploaded_reports, thread_id=thread_id)
    config = {"configurable": {"thread_id": thread_id}}
    response = workflow.invoke(state, config=config)
    return (
        f"✅ Processed {len(files)} reports.\n"
        "Please download the output file from below within 5 min."
    ), response["interpreted_report"]
37
+
38
+
39
# Build Gradio UI.
# NOTE: component creation order inside the Blocks context defines the layout,
# so statements here must stay in this order.
with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Sheami - Smart Healthcare Engine for Artificial Medical Intelligence")

    with gr.Row():
        # Multiple PDF uploads; type="filepath" hands the handler file paths.
        file_input = gr.File(
            file_types=[".pdf"],
            type="filepath",
            file_count="multiple",
            label="Upload your Lab Reports (PDF)",
        )

    run_btn = gr.Button("Process Reports", variant="primary")
    output_box = gr.Textbox(label="Processing Output", lines=2)
    pdf_output = gr.File(label="Generated Report")

    # process_reports is expected to return (status message, path to PDF),
    # matching the two output components bound here.
    run_btn.click(process_reports, inputs=file_input, outputs=[output_box, pdf_output])

if __name__ == "__main__":
    demo.launch()
config.py CHANGED
@@ -1,9 +1,9 @@
1
- import os
2
-
3
-
4
- class SheamiConfig:
5
- _output_dir = "./output"
6
- data_dir = "./data"
7
-
8
- def get_output_dir(thread_id:str):
9
  return os.path.join(SheamiConfig._output_dir, thread_id)
 
1
+ import os
2
+
3
+
4
class SheamiConfig:
    """Central location for the filesystem paths used by the app."""

    _output_dir = "./output"  # root for per-thread generated artifacts
    data_dir = "./data"

    @staticmethod
    def get_output_dir(thread_id: str) -> str:
        """Return the output directory for one pipeline run (thread).

        Declared @staticmethod: the original bare function only worked when
        accessed via the class; this also makes instance access safe.
        """
        return os.path.join(SheamiConfig._output_dir, thread_id)
graph.py CHANGED
@@ -1,448 +1,448 @@
1
- import threading
2
- import time
3
- from numpy import number
4
- import pandas as pd
5
- from langchain_core.prompts import ChatPromptTemplate
6
- from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
7
- from reportlab.lib.pagesizes import A4
8
- from reportlab.lib.styles import getSampleStyleSheet
9
- from reportlab.lib.units import inch
10
- import matplotlib.pyplot as plt
11
- from dataclasses import dataclass
12
- from typing import Dict, List, Literal, Optional, TypedDict
13
- import os, json
14
- from pydantic import BaseModel
15
- from langchain_core.messages import HumanMessage, SystemMessage
16
- from langgraph.checkpoint.memory import InMemorySaver
17
- from langgraph.graph.message import StateGraph
18
- from langgraph.graph.state import START, END
19
- from langchain_openai import ChatOpenAI
20
- from dotenv import load_dotenv
21
- from config import SheamiConfig
22
- import logging
23
-
24
# Root logging config; module logger at INFO so per-thread progress is visible.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# override=True: values from .env win over pre-existing environment variables.
load_dotenv(override=True)
# NOTE(review): os.getenv("MODEL") is None when MODEL is unset, which would
# make ChatOpenAI fail at import time — consider a default or explicit check.
llm = ChatOpenAI(model=os.getenv("MODEL"), temperature=0.3)
30
-
31
- # -----------------------------
32
- # SCHEMA DEFINITIONS
33
- # -----------------------------
34
-
35
-
36
- from typing import Optional, List
37
- from pydantic import BaseModel, Field
38
-
39
-
40
class PatientInfo(BaseModel):
    """Patient demographics extracted from a report; every field is optional."""

    name: Optional[str] = Field(None, description="Patient's full name")
    age: Optional[int] = Field(None, description="Patient's age in years")
    sex: Optional[str] = Field(None, description="Male/Female/Other")
    medical_record_number: Optional[str] = None

    class Config:
        # extra="forbid" keeps the generated JSON schema closed, which
        # OpenAI's strict structured-output mode requires.
        extra = "forbid"
48
-
49
-
50
class TestResultReferenceRange(BaseModel):
    """Numeric reference interval for a lab test; either bound may be absent."""

    min: Optional[float] = None
    max: Optional[float] = None
53
-
54
-
55
class LabResult(BaseModel):
    """One lab measurement extracted from a single report."""

    test_name: str
    # Kept as a string: raw reports can contain non-numeric results.
    result_value: str
    test_unit: str
    test_reference_range: Optional[TestResultReferenceRange] = None
    test_date: Optional[str] = None  # raw date text; parsed later by parse_any_date
    # The LLM's judgement of where the value falls relative to the range.
    inferred_range: Literal["low", "normal", "high"]

    class Config:
        extra = "forbid"  # closed schema for OpenAI strict structured output
65
-
66
-
67
class StandardizedReport(BaseModel):
    """Canonical, schema-validated form of one uploaded health report."""

    patient_info: PatientInfo  # typed model rather than a raw dict
    lab_results: List[LabResult]
    diagnosis: List[str]
    recommendations: List[str]

    class Config:
        extra = "forbid"  # closed schema for OpenAI strict structured output
75
-
76
-
77
@dataclass
class HealthReport:
    """Raw uploaded report: original file name plus extracted text contents."""

    report_file_name: str
    report_contents: str
81
-
82
-
83
class SheamiState(TypedDict):
    """Mutable pipeline state passed between the LangGraph nodes."""

    thread_id: str  # unique per run; keys the output directory and checkpoints
    uploaded_reports: List[HealthReport]
    standardized_reports: List[StandardizedReport]  # filled by fn_standardizer_node
    trends_json: dict  # built by fn_trends_aggregator_node
    interpreted_report: str  # path of the generated final_report.pdf
89
-
90
-
91
- import re
92
-
93
-
94
def safe_filename(name: str) -> str:
    """Turn an arbitrary test name into a filesystem-safe identifier.

    Spaces and any character outside [A-Za-z0-9_-] become underscores,
    runs of underscores collapse to one, and edge underscores are trimmed.
    """
    candidate = name.replace(" ", "_")
    candidate = re.sub(r"[^A-Za-z0-9_\-]", "_", candidate)
    candidate = re.sub(r"_{2,}", "_", candidate)
    return candidate.strip("_")
102
-
103
-
104
- import dateutil.parser
105
-
106
-
107
def parse_any_date(date_str):
    """Best-effort parse of an arbitrary date string into a datetime.

    Returns pd.NaT for empty/NaN input or anything dateutil cannot parse,
    so callers can rely on pandas' missing-date semantics downstream.
    dayfirst=False biases ambiguous dates like 01/02/2024 toward MM/DD.
    """
    if not date_str or pd.isna(date_str):
        return pd.NaT
    try:
        # fuzzy=True lets dateutil skip surrounding noise text in the string
        return dateutil.parser.parse(str(date_str), dayfirst=False, fuzzy=True)
    except Exception:
        return pd.NaT
114
-
115
-
116
# Prompt template for normalizing lab test names. The system message is built
# from adjacent string literals; each fragment now ends with a space so the
# concatenated prompt does not run sentences together ("names.All outputs").
testname_standardizer_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a medical assistant. Normalize lab test names. "
            "All outputs must use **title case** (e.g., 'Hemoglobin', 'Blood Glucose'). "
            "Return ONLY valid JSON where keys are original names and values are standardized names. DO NOT return markdown formatting like backquotes etc.",
        ),
        (
            "human",
            """Normalize the following lab test names to their standard medical equivalents.
    Test names: {test_names}
    """,
        ),
    ]
)

# chain = prompt → LLM (the response is a message; .content holds the JSON text)
testname_standardizer_chain = testname_standardizer_prompt | llm
136
-
137
- # -----------------------------
138
- # GRAPH NODES
139
- # -----------------------------
140
-
141
-
142
def fn_init_node(state: SheamiState):
    """Ensure the per-thread output directory exists and reset derived fields."""
    output_dir = SheamiConfig.get_output_dir(state["thread_id"])
    os.makedirs(output_dir, exist_ok=True)
    state.update(standardized_reports=[], trends_json={}, interpreted_report="")
    return state
148
-
149
-
150
def fn_standardizer_node(state: SheamiState):
    """Convert each uploaded raw report into a StandardizedReport via the LLM.

    Each result is appended to state["standardized_reports"] and also written
    to the thread's output directory as report_<idx>.json for inspection.
    """
    logger.info("%s| Standardizing reports", state["thread_id"])
    # The structured-output wrapper forces replies into the pydantic schema.
    llm_structured = llm.with_structured_output(StandardizedReport)
    for idx, report in enumerate(state["uploaded_reports"]):
        logger.info("%s| Standardizing report %s", state["thread_id"], report.report_file_name)

        messages = [
            SystemMessage(content="Standardize this medical report into the schema."),
            # SystemMessage(
            #     content="Populate the `inferred_range` field as 'low', 'normal', or 'high' by comparing the result value with the reference range. If both min and max are missing, set 'normal' unless the value is clearly out of usual medical ranges."
            # ),
            HumanMessage(content=report.report_contents),
        ]
        result: StandardizedReport = llm_structured.invoke(messages)
        state["standardized_reports"].append(result)
        # save each standardized report to disk for debugging/traceability
        with open(
            os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"), "w"
        ) as f:
            f.write(result.model_dump_json(indent=2))
    logger.info("%s| Standardizing Reports: finished", state["thread_id"])
    return state
172
-
173
-
174
def fn_testname_standardizer_node(state: SheamiState):
    """Normalize lab-test names across reports to canonical, title-cased forms.

    Collects the unique test names, asks the LLM for an {original: standardized}
    JSON mapping, then rewrites every LabResult.test_name through that mapping.
    Falls back to an identity mapping if the reply is not valid JSON.
    """
    logger.info("%s| Standardizing Test Names: started", state["thread_id"])
    # collect unique names (set comprehension de-duplicates across reports)
    unique_names = list(
        {
            result.test_name
            for report in state["standardized_reports"]
            for result in report.lab_results
        }
    )

    # run through LLM
    response = testname_standardizer_chain.invoke({"test_names": unique_names})
    raw_text = response.content

    try:
        normalization_map: Dict[str, str] = json.loads(raw_text)
    except (ValueError, TypeError) as e:
        # Use the module logger (not print) so failures land in the same
        # per-thread log stream as every other node; keep original names.
        logger.warning(
            "%s| Test-name normalization failed (%s); keeping original names",
            state["thread_id"],
            e,
        )
        normalization_map = {name: name for name in unique_names}

    # apply mapping back; unknown names pass through unchanged
    for report in state["standardized_reports"]:
        for result in report.lab_results:
            result.test_name = normalization_map.get(result.test_name, result.test_name)
    logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
    return state
203
-
204
-
205
def fn_unit_normalizer_node(state: SheamiState):
    """Normalize units for lab test values across all standardized reports.

    Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'. Unknown units are kept as-is.
    (In the original code this text sat after the first statement, making it a
    no-op string expression instead of the function docstring.)
    """
    logger.info("%s| Standardizing Units : started", state["thread_id"])

    # Keys are lower-cased, space-stripped spellings; values are canonical units.
    unit_map = {
        "g/dl": "g/dL",
        "gms/dl": "g/dL",
        "gm%": "g/dL",
        "g/dl.": "g/dL",
    }

    for report in state["standardized_reports"]:
        for lr in report.lab_results:
            if not lr.test_unit:
                continue  # nothing to normalize
            normalized = lr.test_unit.lower().replace(" ", "")
            lr.test_unit = unit_map.get(
                normalized, lr.test_unit
            )  # fallback: keep original

    logger.info("%s| Standardizing Units : finished", state["thread_id"])
    return state
229
-
230
-
231
def fn_trends_aggregator_node(state: SheamiState):
    """Group lab results from all standardized reports into per-test time series.

    Builds state["trends_json"] = {"parameter_trends": [...]} where each entry
    holds the test name, its dated values, and (when seen) a reference range,
    then persists the structure to <output_dir>/trends.json.
    """
    logger.info("%s| Aggregating Trends : started", state["thread_id"])
    import re  # NOTE(review): unused in this function — candidate for removal

    # group results by test_name
    trends = {}
    ref_ranges = {}

    def try_parse_float(value: str):
        # Return None (rather than raising) for non-numeric result values.
        try:
            return float(value)
        except (ValueError, TypeError):
            return None

    for idx, report in enumerate(state["standardized_reports"]):
        logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
        for lr in report.lab_results:
            numeric_value = try_parse_float(lr.result_value)

            trends.setdefault(lr.test_name, []).append(
                {
                    "date": lr.test_date or "unknown",
                    "value": (
                        numeric_value if numeric_value is not None else lr.result_value
                    ),
                    "is_numeric": numeric_value is not None,
                    "unit": lr.test_unit,
                }
            )

            # Capture reference range if available (assumed the same for all
            # entries of a test_name; the last report seen wins).
            if lr.test_reference_range:
                ref_ranges[lr.test_name] = {
                    "min": lr.test_reference_range.min,
                    "max": lr.test_reference_range.max,
                }

    # combine into parameter_trends
    state["trends_json"] = {
        "parameter_trends": [
            {
                "test_name": k,
                "values": v,
                "reference_range": ref_ranges.get(k),  # thresholds, may be None
            }
            for k, v in trends.items()
        ]
    }

    with open(os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "trends.json"), "w") as f:
        json.dump(state["trends_json"], f, indent=2)
    logger.info("%s| Aggregating Trends : finished", state["thread_id"])
    return state
285
-
286
-
287
def fn_interpreter_node(state: SheamiState):
    """Produce the final PDF: LLM narrative plus one trend chart per test.

    Steps: (1) ask the LLM for a textual interpretation of trends_json,
    (2) render a matplotlib chart per numeric parameter into <output>/plots,
    (3) assemble text + charts into final_report.pdf with reportlab, store
    its path in state["interpreted_report"], and schedule output cleanup.
    """
    logger.info("%s| Interpreting Trends : started", state["thread_id"])
    # 1. LLM narrative
    messages = [
        SystemMessage(
            content="Interpret the following medical trends and produce a report with patient summary, trend summaries, and clinical insights. "
            "Do not include charts, they will be programmatically added."
        ),
        HumanMessage(content=json.dumps(state["trends_json"], indent=2)),
    ]
    response = llm.invoke(messages)
    interpretation_text = response.content

    # 2. Generate plots for each parameter (sorted for stable chart order)
    plots_dir = os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "plots")
    os.makedirs(plots_dir, exist_ok=True)
    plot_files = []

    for param in sorted(
        state["trends_json"].get("parameter_trends", []), key=lambda x: x["test_name"]
    ):
        test_name = param["test_name"]
        values = param["values"]

        # Parse free-form date strings; unparseable dates become NaT.
        x = [v["date"] for v in values]
        # print("original dates for ", test_name, "= ", x)
        x = [parse_any_date(d) for d in x]
        x = pd.to_datetime(x, errors="coerce")
        # print("formatted dates for ", test_name, "= ", x)

        try:
            y = [float(v["value"]) for v in values]
        except ValueError:
            continue  # skip non-numeric parameters entirely

        ## sort the data by date
        # Zip into a DataFrame for easy sorting
        df_plot = pd.DataFrame({"x": x, "y": y})

        # Drop invalid dates if any
        df_plot = df_plot.dropna(subset=["x"])

        # Sort by date
        df_plot = df_plot.sort_values("x")

        # Extract sorted arrays
        x = df_plot["x"].to_numpy()
        y = df_plot["y"].to_numpy()

        # print("formatted + sorted dates for", test_name, "=", x)

        plt.figure(figsize=(6, 4))
        plt.plot(x, y, marker="o", linestyle="-", label="Observed values")

        # Overlay thresholds if available: a shaded band when both bounds
        # exist, otherwise a single dashed line for the one known bound.
        ref = param.get("reference_range")
        if ref:
            ymin, ymax = ref.get("min"), ref.get("max")
            if ymin is not None and ymax is not None:
                plt.axhspan(
                    ymin, ymax, color="green", alpha=0.2, label="Reference range"
                )
            elif ymax is not None:
                plt.axhline(
                    y=ymax, color="red", linestyle="--", label="Upper threshold"
                )
            elif ymin is not None:
                plt.axhline(
                    y=ymin, color="blue", linestyle="--", label="Lower threshold"
                )

        plt.title(f"{test_name} Trend")
        plt.xlabel("Date")
        plt.ylabel(values[0]["unit"] if values and "unit" in values[0] else "")
        plt.grid(True)
        plt.xticks(rotation=45)
        plt.legend()
        plt.tight_layout()

        filename = f"{safe_filename(test_name).replace(' ', '_')}_trend.png"
        filepath = os.path.join(plots_dir, filename)
        plt.savefig(filepath)
        plt.close()
        plot_files.append((test_name, filepath))

    # 3. Build PDF
    pdf_path = os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "final_report.pdf")
    doc = SimpleDocTemplate(pdf_path, pagesize=A4)
    styles = getSampleStyleSheet()
    story = []

    # Add title
    story.append(Paragraph("<b>Medical Report Interpretation</b>", styles["Title"]))
    story.append(Spacer(1, 0.3 * inch))

    # Add interpretation text (LLM output), one paragraph per non-blank line
    for line in interpretation_text.split("\n"):
        if line.strip():
            story.append(Paragraph(line.strip(), styles["Normal"]))
            story.append(Spacer(1, 0.15 * inch))

    # Add charts
    story.append(Spacer(1, 0.5 * inch))
    story.append(Paragraph("<b>Trends</b>", styles["Heading2"]))
    story.append(Spacer(1, 0.2 * inch))

    for test_name, plotfile in plot_files:
        story.append(Paragraph(f"<b>{test_name}</b>", styles["Heading3"]))
        story.append(Image(plotfile, width=5 * inch, height=3 * inch))
        story.append(Spacer(1, 0.3 * inch))

    doc.build(story)

    state["interpreted_report"] = pdf_path
    ###### Schedule Cleanup of output dir after 5 min.
    schedule_cleanup(file_path=SheamiConfig.get_output_dir(state["thread_id"]))
    logger.info("%s| Interpreting Trends : finished", state["thread_id"])
    return state
406
-
407
-
408
def schedule_cleanup(file_path, delay=300):  # 300 sec = 5 min
    """Delete *file_path* (file or directory tree) after *delay* seconds.

    Runs in a daemon thread so it never blocks shutdown; failures are logged
    and swallowed — the cleanup is strictly best-effort.

    Args:
        file_path: Path to a file or directory to remove.
        delay: Seconds to wait before removal (default 300 = 5 minutes).
    """
    def cleanup():
        import shutil  # hoisted out of the branch so it runs once per thread

        time.sleep(delay)
        if os.path.exists(file_path):
            try:
                if os.path.isdir(file_path):
                    shutil.rmtree(file_path)
                else:
                    os.remove(file_path)
                # Log (not print) so output joins the app's logging stream.
                logging.getLogger(__name__).info("Cleaned up: %s", file_path)
            except Exception as e:
                logging.getLogger(__name__).warning(
                    "Cleanup failed for %s: %s", file_path, e
                )

    threading.Thread(target=cleanup, daemon=True).start()
422
-
423
- # -----------------------------
424
- # GRAPH CREATION
425
- # -----------------------------
426
-
427
-
428
def create_graph(thread_id : str):
    """Compile the Sheami processing pipeline as a linear LangGraph workflow.

    Order: init -> standardizer -> testname_standardizer -> unit_normalizer
    -> trends -> interpreter, checkpointed with an in-memory saver.
    """
    logger.info("%s| Creating Graph : started", thread_id)
    memory = InMemorySaver()
    workflow = StateGraph(SheamiState)

    # Declare the pipeline once; register nodes and chain edges from it.
    pipeline = [
        ("init", fn_init_node),
        ("standardizer", fn_standardizer_node),
        ("testname_standardizer", fn_testname_standardizer_node),
        ("unit_normalizer", fn_unit_normalizer_node),
        ("trends", fn_trends_aggregator_node),
        ("interpreter", fn_interpreter_node),
    ]
    for node_name, node_fn in pipeline:
        workflow.add_node(node_name, node_fn)

    stages = [START] + [node_name for node_name, _ in pipeline] + [END]
    for src, dst in zip(stages, stages[1:]):
        workflow.add_edge(src, dst)

    logger.info("%s| Creating Graph : finished", thread_id)
    return workflow.compile(checkpointer=memory)
 
1
+ import threading
2
+ import time
3
+ from numpy import number
4
+ import pandas as pd
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
7
+ from reportlab.lib.pagesizes import A4
8
+ from reportlab.lib.styles import getSampleStyleSheet
9
+ from reportlab.lib.units import inch
10
+ import matplotlib.pyplot as plt
11
+ from dataclasses import dataclass
12
+ from typing import Dict, List, Literal, Optional, TypedDict
13
+ import os, json
14
+ from pydantic import BaseModel
15
+ from langchain_core.messages import HumanMessage, SystemMessage
16
+ from langgraph.checkpoint.memory import InMemorySaver
17
+ from langgraph.graph.message import StateGraph
18
+ from langgraph.graph.state import START, END
19
+ from langchain_openai import ChatOpenAI
20
+ from dotenv import load_dotenv
21
+ from config import SheamiConfig
22
+ import logging
23
+
24
+ logging.basicConfig()
25
+ logger = logging.getLogger(__name__)
26
+ logger.setLevel(logging.INFO)
27
+
28
+ load_dotenv(override=True)
29
+ llm = ChatOpenAI(model=os.getenv("MODEL"), temperature=0.3)
30
+
31
+ # -----------------------------
32
+ # SCHEMA DEFINITIONS
33
+ # -----------------------------
34
+
35
+
36
+ from typing import Optional, List
37
+ from pydantic import BaseModel, Field
38
+
39
+
40
+ class PatientInfo(BaseModel):
41
+ name: Optional[str] = Field(None, description="Patient's full name")
42
+ age: Optional[int] = Field(None, description="Patient's age in years")
43
+ sex: Optional[str] = Field(None, description="Male/Female/Other")
44
+ medical_record_number: Optional[str] = None
45
+
46
+ class Config:
47
+ extra = "forbid" # 🚨 ensures schema matches OpenAI’s strict rules
48
+
49
+
50
+ class TestResultReferenceRange(BaseModel):
51
+ min: Optional[float] = None
52
+ max: Optional[float] = None
53
+
54
+
55
+ class LabResult(BaseModel):
56
+ test_name: str
57
+ result_value: str
58
+ test_unit: str
59
+ test_reference_range: Optional[TestResultReferenceRange] = None
60
+ test_date: Optional[str] = None
61
+ inferred_range: Literal["low", "normal", "high"]
62
+
63
+ class Config:
64
+ extra = "forbid"
65
+
66
+
67
+ class StandardizedReport(BaseModel):
68
+ patient_info: PatientInfo # 🚨 no longer a raw dict
69
+ lab_results: List[LabResult]
70
+ diagnosis: List[str]
71
+ recommendations: List[str]
72
+
73
+ class Config:
74
+ extra = "forbid"
75
+
76
+
77
+ @dataclass
78
+ class HealthReport:
79
+ report_file_name: str
80
+ report_contents: str
81
+
82
+
83
+ class SheamiState(TypedDict):
84
+ thread_id: str
85
+ uploaded_reports: List[HealthReport]
86
+ standardized_reports: List[StandardizedReport]
87
+ trends_json: dict
88
+ interpreted_report: str
89
+
90
+
91
+ import re
92
+
93
+
94
+ def safe_filename(name: str) -> str:
95
+ # Replace spaces with underscores
96
+ name = name.replace(" ", "_")
97
+ # Replace any non-alphanumeric / dash / underscore with "_"
98
+ name = re.sub(r"[^A-Za-z0-9_\-]", "_", name)
99
+ # Collapse multiple underscores
100
+ name = re.sub(r"_+", "_", name)
101
+ return name.strip("_")
102
+
103
+
104
+ import dateutil.parser
105
+
106
+
107
+ def parse_any_date(date_str):
108
+ if not date_str or pd.isna(date_str):
109
+ return pd.NaT
110
+ try:
111
+ return dateutil.parser.parse(str(date_str), dayfirst=False, fuzzy=True)
112
+ except Exception:
113
+ return pd.NaT
114
+
115
+
116
+ # prompt template
117
+ testname_standardizer_prompt = ChatPromptTemplate.from_messages(
118
+ [
119
+ (
120
+ "system",
121
+ "You are a medical assistant. Normalize lab test names."
122
+ "All outputs must use **title case** (e.g., 'Hemoglobin', 'Blood Glucose')."
123
+ "Return ONLY valid JSON where keys are original names and values are standardized names. DO NOT return markdown formatting like backquotes etc.",
124
+ ),
125
+ (
126
+ "human",
127
+ """Normalize the following lab test names to their standard medical equivalents.
128
+ Test names: {test_names}
129
+ """,
130
+ ),
131
+ ]
132
+ )
133
+
134
+ # chain = prompt → LLM → string
135
+ testname_standardizer_chain = testname_standardizer_prompt | llm
136
+
137
+ # -----------------------------
138
+ # GRAPH NODES
139
+ # -----------------------------
140
+
141
+
142
+ def fn_init_node(state: SheamiState):
143
+ os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
144
+ state["standardized_reports"] = []
145
+ state["trends_json"] = {}
146
+ state["interpreted_report"] = ""
147
+ return state
148
+
149
+
150
+ def fn_standardizer_node(state: SheamiState):
151
+ logger.info("%s| Standardizing reports", state["thread_id"])
152
+ llm_structured = llm.with_structured_output(StandardizedReport)
153
+ for idx, report in enumerate(state["uploaded_reports"]):
154
+ logger.info("%s| Standardizing report %s", state["thread_id"], report.report_file_name)
155
+
156
+ messages = [
157
+ SystemMessage(content="Standardize this medical report into the schema."),
158
+ # SystemMessage(
159
+ # content="Populate the `inferred_range` field as 'low', 'normal', or 'high' by comparing the result value with the reference range. If both min and max are missing, set 'normal' unless the value is clearly out of usual medical ranges."
160
+ # ),
161
+ HumanMessage(content=report.report_contents),
162
+ ]
163
+ result: StandardizedReport = llm_structured.invoke(messages)
164
+ state["standardized_reports"].append(result)
165
+ # save to disk
166
+ with open(
167
+ os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"), "w"
168
+ ) as f:
169
+ f.write(result.model_dump_json(indent=2))
170
+ logger.info("%s| Standardizing Reports: finished", state["thread_id"])
171
+ return state
172
+
173
+
174
+ def fn_testname_standardizer_node(state: SheamiState):
175
+ logger.info("%s| Standardizing Test Names: started", state["thread_id"])
176
+ # collect unique names
177
+ unique_names = list(
178
+ {
179
+ result.test_name
180
+ for report in state["standardized_reports"]
181
+ for result in report.lab_results
182
+ }
183
+ )
184
+
185
+ # run through LLM
186
+ response = testname_standardizer_chain.invoke({"test_names": unique_names})
187
+ raw_text = response.content
188
+
189
+ try:
190
+ normalization_map: Dict[str, str] = json.loads(raw_text)
191
+ except Exception as e:
192
+ print("Exception in normalization: ", e)
193
+ normalization_map = {
194
+ name: name for name in unique_names
195
+ } # fallback: identity mapping
196
+
197
+ # apply mapping back
198
+ for report in state["standardized_reports"]:
199
+ for result in report.lab_results:
200
+ result.test_name = normalization_map.get(result.test_name, result.test_name)
201
+ logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
202
+ return state
203
+
204
+
205
+ def fn_unit_normalizer_node(state: SheamiState):
206
+ logger.info("%s| Standardizing Units : started", state["thread_id"])
207
+ """
208
+ Normalize units for lab test values across all standardized reports.
209
+ Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'
210
+ """
211
+ unit_map = {
212
+ "g/dl": "g/dL",
213
+ "gms/dl": "g/dL",
214
+ "gm%": "g/dL",
215
+ "g/dl.": "g/dL",
216
+ }
217
+
218
+ for report in state["standardized_reports"]:
219
+ for lr in report.lab_results:
220
+ if not lr.test_unit:
221
+ continue
222
+ normalized = lr.test_unit.lower().replace(" ", "")
223
+ lr.test_unit = unit_map.get(
224
+ normalized, lr.test_unit
225
+ ) # fallback: keep original
226
+
227
+ logger.info("%s| Standardizing Units : finished", state["thread_id"])
228
+ return state
229
+
230
+
231
+ def fn_trends_aggregator_node(state: SheamiState):
232
+ logger.info("%s| Aggregating Trends : started", state["thread_id"])
233
+ import re
234
+
235
+ # group results by test_name
236
+ trends = {}
237
+ ref_ranges = {}
238
+
239
+ def try_parse_float(value: str):
240
+ try:
241
+ return float(value)
242
+ except (ValueError, TypeError):
243
+ # return None if not numeric
244
+ return None
245
+
246
+ for idx, report in enumerate(state["standardized_reports"]):
247
+ logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
248
+ for lr in report.lab_results:
249
+ numeric_value = try_parse_float(lr.result_value)
250
+
251
+ trends.setdefault(lr.test_name, []).append(
252
+ {
253
+ "date": lr.test_date or "unknown",
254
+ "value": (
255
+ numeric_value if numeric_value is not None else lr.result_value
256
+ ),
257
+ "is_numeric": numeric_value is not None,
258
+ "unit": lr.test_unit,
259
+ }
260
+ )
261
+
262
+ # Capture reference range if available (assuming same for all entries of a test_name)
263
+ if lr.test_reference_range:
264
+ ref_ranges[lr.test_name] = {
265
+ "min": lr.test_reference_range.min,
266
+ "max": lr.test_reference_range.max,
267
+ }
268
+
269
+ # combine into parameter_trends
270
+ state["trends_json"] = {
271
+ "parameter_trends": [
272
+ {
273
+ "test_name": k,
274
+ "values": v,
275
+ "reference_range": ref_ranges.get(k), # attach thresholds
276
+ }
277
+ for k, v in trends.items()
278
+ ]
279
+ }
280
+
281
+ with open(os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "trends.json"), "w") as f:
282
+ json.dump(state["trends_json"], f, indent=2)
283
+ logger.info("%s| Aggregating Trends : finished", state["thread_id"])
284
+ return state
285
+
286
+
287
def fn_interpreter_node(state: SheamiState):
    """Turn aggregated trends into the final PDF report.

    Steps: (1) ask the LLM for a narrative interpretation of
    state["trends_json"], (2) render one time-series plot per numeric
    parameter, (3) assemble narrative + charts into final_report.pdf.

    Writes the PDF path to state["interpreted_report"] and schedules
    cleanup of the per-thread output directory after 5 minutes.
    """
    logger.info("%s| Interpreting Trends : started", state["thread_id"])

    # 1. LLM narrative — charts are appended programmatically below, so the
    # prompt explicitly forbids the model from emitting its own.
    messages = [
        SystemMessage(
            content="Interpret the following medical trends and produce a report with patient summary, trend summaries, and clinical insights. "
            "Do not include charts, they will be programmatically added."
        ),
        HumanMessage(content=json.dumps(state["trends_json"], indent=2)),
    ]
    response = llm.invoke(messages)
    interpretation_text = response.content

    # 2. Generate one plot per parameter, sorted by test name for a
    # deterministic report layout.
    plots_dir = os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "plots")
    os.makedirs(plots_dir, exist_ok=True)
    plot_files = []

    for param in sorted(
        state["trends_json"].get("parameter_trends", []), key=lambda x: x["test_name"]
    ):
        test_name = param["test_name"]
        values = param["values"]

        # Normalize heterogeneous date strings; unparseable ones become NaT
        # and are dropped before plotting.
        x = [parse_any_date(v["date"]) for v in values]
        x = pd.to_datetime(x, errors="coerce")

        try:
            y = [float(v["value"]) for v in values]
        except (ValueError, TypeError):
            # skip non-numeric parameters; TypeError covers None/odd types,
            # which float() raises instead of ValueError
            continue

        # Sort observations chronologically so the line plot reads correctly.
        df_plot = pd.DataFrame({"x": x, "y": y})
        df_plot = df_plot.dropna(subset=["x"])
        df_plot = df_plot.sort_values("x")
        x = df_plot["x"].to_numpy()
        y = df_plot["y"].to_numpy()

        plt.figure(figsize=(6, 4))
        plt.plot(x, y, marker="o", linestyle="-", label="Observed values")

        # Overlay the reference range (band) or a single threshold (line)
        # when the trend data provides one.
        ref = param.get("reference_range")
        if ref:
            ymin, ymax = ref.get("min"), ref.get("max")
            if ymin is not None and ymax is not None:
                plt.axhspan(
                    ymin, ymax, color="green", alpha=0.2, label="Reference range"
                )
            elif ymax is not None:
                plt.axhline(
                    y=ymax, color="red", linestyle="--", label="Upper threshold"
                )
            elif ymin is not None:
                plt.axhline(
                    y=ymin, color="blue", linestyle="--", label="Lower threshold"
                )

        plt.title(f"{test_name} Trend")
        plt.xlabel("Date")
        plt.ylabel(values[0]["unit"] if values and "unit" in values[0] else "")
        plt.grid(True)
        plt.xticks(rotation=45)
        plt.legend()
        plt.tight_layout()

        filename = f"{safe_filename(test_name).replace(' ', '_')}_trend.png"
        filepath = os.path.join(plots_dir, filename)
        plt.savefig(filepath)
        plt.close()
        plot_files.append((test_name, filepath))

    # 3. Build the PDF: title, LLM narrative paragraphs, then one chart
    # per test name.
    pdf_path = os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "final_report.pdf")
    doc = SimpleDocTemplate(pdf_path, pagesize=A4)
    styles = getSampleStyleSheet()
    story = []

    story.append(Paragraph("<b>Medical Report Interpretation</b>", styles["Title"]))
    story.append(Spacer(1, 0.3 * inch))

    # Add interpretation text (LLM output), one paragraph per non-empty line.
    for line in interpretation_text.split("\n"):
        if line.strip():
            story.append(Paragraph(line.strip(), styles["Normal"]))
            story.append(Spacer(1, 0.15 * inch))

    # Add charts.
    story.append(Spacer(1, 0.5 * inch))
    story.append(Paragraph("<b>Trends</b>", styles["Heading2"]))
    story.append(Spacer(1, 0.2 * inch))

    for test_name, plotfile in plot_files:
        story.append(Paragraph(f"<b>{test_name}</b>", styles["Heading3"]))
        story.append(Image(plotfile, width=5 * inch, height=3 * inch))
        story.append(Spacer(1, 0.3 * inch))

    doc.build(story)

    state["interpreted_report"] = pdf_path
    # Best-effort housekeeping: output dir is removed 5 minutes from now.
    schedule_cleanup(file_path=SheamiConfig.get_output_dir(state["thread_id"]))
    logger.info("%s| Interpreting Trends : finished", state["thread_id"])
    return state
406
+
407
+
408
def schedule_cleanup(file_path, delay=300):  # 300 sec = 5 min
    """Delete *file_path* (a file or a directory tree) after *delay* seconds.

    The wait and removal run in a daemon thread so they never block the
    caller or process shutdown. Cleanup is best-effort: failures are
    reported to stdout, never raised.
    """

    def cleanup():
        import shutil  # deferred, as in module style: only needed for dir removal

        time.sleep(delay)
        try:
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
        except FileNotFoundError:
            # Already gone (or lost a race with another cleanup) — fine.
            # No pre-flight exists() check: EAFP avoids the TOCTOU window.
            return
        except Exception as e:
            print(f"Cleanup failed for {file_path}: {e}")
            return
        print(f"Cleaned up: {file_path}")

    threading.Thread(target=cleanup, daemon=True).start()
422
+
423
+ # -----------------------------
424
+ # GRAPH CREATION
425
+ # -----------------------------
426
+
427
+
428
def create_graph(thread_id : str):
    """Compile the Sheami pipeline as a linear LangGraph workflow.

    Order: init -> standardizer -> testname_standardizer ->
    unit_normalizer -> trends -> interpreter, checkpointed in memory.
    """
    logger.info("%s| Creating Graph : started", thread_id)

    checkpointer = InMemorySaver()
    builder = StateGraph(SheamiState)

    # Nodes listed once, in execution order; edges are derived from it.
    pipeline = [
        ("init", fn_init_node),
        ("standardizer", fn_standardizer_node),
        ("testname_standardizer", fn_testname_standardizer_node),
        ("unit_normalizer", fn_unit_normalizer_node),
        ("trends", fn_trends_aggregator_node),
        ("interpreter", fn_interpreter_node),
    ]
    for node_name, node_fn in pipeline:
        builder.add_node(node_name, node_fn)

    # Wire a straight line: START -> first node -> ... -> last node -> END.
    stages = [START] + [name for name, _ in pipeline] + [END]
    for src, dst in zip(stages, stages[1:]):
        builder.add_edge(src, dst)

    logger.info("%s| Creating Graph : finished", thread_id)
    return builder.compile(checkpointer=checkpointer)
models.py CHANGED
@@ -1,14 +1,14 @@
1
from pydantic import BaseModel
from typing import List, Optional


class SheamiLabResult(BaseModel):
    """One lab test result as extracted from a medical report."""

    test_name: str
    result_value: str
    unit: str
    # Default to None: in Pydantic v2 an Optional annotation without a
    # default is still *required*, which would reject reports that omit
    # the printed reference range.
    reference_range: Optional[str] = None


class SheamiStandardizedReport(BaseModel):
    """Standardized representation of a single medical report."""

    patient_info: dict
    lab_results: List[SheamiLabResult]
    diagnosis: List[str]
    recommendations: List[str]
pdf_reader.py CHANGED
@@ -1,17 +1,17 @@
1
from pypdf import PdfReader


def read_pdf(file_name: str) -> str:
    """Extract text from every page of a PDF.

    Returns the concatenation of per-page blocks, each prefixed with a
    "--- Page N ---" marker (1-based). Pages where pypdf extracts nothing
    contribute an empty body: extract_text() may return None, which the
    original code would have crashed on when concatenating.
    """
    reader = PdfReader(file_name)

    parts = []
    for page_num, page in enumerate(reader.pages, start=1):
        text = page.extract_text() or ""  # guard: extract_text() can be None
        parts.append(f"--- Page {page_num} ---\n\n{text}")
    # join() builds the result in one pass instead of quadratic +=
    return "".join(parts)