vikramvasudevan commited on
Commit
35d82ad
·
verified ·
1 Parent(s): d76bf6d

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. graph.py +6 -1
  2. home.py +61 -45
  3. modules/db.py +69 -2
  4. modules/models.py +2 -0
graph.py CHANGED
@@ -519,9 +519,14 @@ async def fn_interpreter_node(state: SheamiState):
519
  get_db().update_run_stats(run_id=state["run_id"], status="completed")
520
 
521
  ## add parsed reports
522
- get_db().add_report_v2(
523
  patient_id=state["patient_id"], reports=state["standardized_reports"]
524
  )
 
 
 
 
 
525
 
526
  return state
527
 
 
519
  get_db().update_run_stats(run_id=state["run_id"], status="completed")
520
 
521
  ## add parsed reports
522
+ report_id_list = get_db().add_report_v2(
523
  patient_id=state["patient_id"], reports=state["standardized_reports"]
524
  )
525
+ state["report_id_list"] = report_id_list
526
+
527
+ logger.info("report_id_list = %s", report_id_list)
528
+ for report_id in report_id_list.split(","):
529
+ get_db().aggregate_trends_from_report(state["patient_id"], report_id)
530
 
531
  return state
532
 
home.py CHANGED
@@ -53,10 +53,45 @@ def flatten_reports_v2(reports: List[Dict[str, Any]]) -> pd.DataFrame:
53
  lab_results = parsed_v2.get("lab_results", [])
54
 
55
  if not lab_results:
56
- rows.append({
57
- "report_id": rid,
58
- "uploaded_at": uploaded_at,
59
- "file_name": file_name,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  "test_name": "",
61
  "result_value": "",
62
  "test_unit": "",
@@ -65,39 +100,10 @@ def flatten_reports_v2(reports: List[Dict[str, Any]]) -> pd.DataFrame:
65
  "ref_raw": "",
66
  "test_date": "",
67
  "inferred_range": "",
68
- })
69
- else:
70
- for t in lab_results:
71
- ref = t.get("test_reference_range", {})
72
- rows.append({
73
- "report_id": rid,
74
- "uploaded_at": uploaded_at,
75
- "file_name": file_name,
76
- "test_name": t.get("test_name", ""),
77
- "result_value": t.get("result_value", ""),
78
- "test_unit": t.get("test_unit", ""),
79
- "ref_min": ref.get("min","-"),
80
- "ref_max": ref.get("max","-"),
81
- "ref_raw": ref.get("raw", "-"),
82
- "test_date": t.get("test_date", ""),
83
- "inferred_range": t.get("inferred_range", ""),
84
- })
85
- if not rows:
86
- rows = [{
87
- "report_id": "",
88
- "uploaded_at": "",
89
- "file_name": "",
90
- "test_name": "",
91
- "result_value": "",
92
- "test_unit": "",
93
- "ref_min": "",
94
- "ref_max": "",
95
- "ref_raw": "",
96
- "test_date": "",
97
- "inferred_range": "",
98
- }]
99
  df = pd.DataFrame(rows)
100
- df = df.fillna("-") # 👈 normalize missing values
101
  return df
102
 
103
 
@@ -370,9 +376,13 @@ with gr.Blocks(
370
  user_email_state = gr.State("")
371
  patient_id_state = gr.State("")
372
  # show them if you want
373
- shown_email = gr.Textbox(label="User Email", interactive=False)
374
- shown_patient = gr.Textbox(label="Patient ID", interactive=False)
375
- get_gradio_block(container=upload_reports_modal,user_email_state=user_email_state, patient_id_state=patient_id_state)
 
 
 
 
376
 
377
  with gr.Sidebar() as sheami_sidebar: # Sidebar
378
  # gr.Markdown("### Sidebar")
@@ -514,18 +524,24 @@ with gr.Blocks(
514
  # open modal and set states
515
  def show_upload_reports_modal(user_email, patient_id):
516
  return [
517
- user_email,
518
- patient_id,
519
- gr.update(value=user_email),
520
- gr.update(value=patient_id),
521
  gr.update(visible=True),
522
  ]
523
 
524
  upload_reports_btn.click(close_side_bar, outputs=[sheami_sidebar]).then(
525
  show_upload_reports_modal,
526
  inputs=[email_in, patient_list],
527
- outputs=[user_email_state, patient_id_state, shown_email, shown_patient, upload_reports_modal],
528
- )
 
 
 
 
 
 
529
 
530
  if __name__ == "__main__":
531
  demo.launch()
 
53
  lab_results = parsed_v2.get("lab_results", [])
54
 
55
  if not lab_results:
56
+ rows.append(
57
+ {
58
+ "report_id": rid,
59
+ "uploaded_at": uploaded_at,
60
+ "file_name": file_name,
61
+ "test_name": "",
62
+ "result_value": "",
63
+ "test_unit": "",
64
+ "ref_min": "",
65
+ "ref_max": "",
66
+ "ref_raw": "",
67
+ "test_date": "",
68
+ "inferred_range": "",
69
+ }
70
+ )
71
+ else:
72
+ for t in lab_results:
73
+ ref = t.get("test_reference_range", {})
74
+ rows.append(
75
+ {
76
+ "report_id": rid,
77
+ "uploaded_at": uploaded_at,
78
+ "file_name": file_name,
79
+ "test_name": t.get("test_name", ""),
80
+ "result_value": t.get("result_value", ""),
81
+ "test_unit": t.get("test_unit", ""),
82
+ "ref_min": ref.get("min", "-"),
83
+ "ref_max": ref.get("max", "-"),
84
+ "ref_raw": ref.get("raw", "-"),
85
+ "test_date": t.get("test_date", ""),
86
+ "inferred_range": t.get("inferred_range", ""),
87
+ }
88
+ )
89
+ if not rows:
90
+ rows = [
91
+ {
92
+ "report_id": "",
93
+ "uploaded_at": "",
94
+ "file_name": "",
95
  "test_name": "",
96
  "result_value": "",
97
  "test_unit": "",
 
100
  "ref_raw": "",
101
  "test_date": "",
102
  "inferred_range": "",
103
+ }
104
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  df = pd.DataFrame(rows)
106
+ df = df.fillna("-") # 👈 normalize missing values
107
  return df
108
 
109
 
 
376
  user_email_state = gr.State("")
377
  patient_id_state = gr.State("")
378
  # show them if you want
379
+ shown_email = gr.Textbox(label="User Email", interactive=False, visible=False)
380
+ shown_patient = gr.Textbox(label="Patient ID", interactive=False, visible=False)
381
+ get_gradio_block(
382
+ container=upload_reports_modal,
383
+ user_email_state=user_email_state,
384
+ patient_id_state=patient_id_state,
385
+ )
386
 
387
  with gr.Sidebar() as sheami_sidebar: # Sidebar
388
  # gr.Markdown("### Sidebar")
 
524
  # open modal and set states
525
  def show_upload_reports_modal(user_email, patient_id):
526
  return [
527
+ user_email,
528
+ patient_id,
529
+ gr.update(value=user_email),
530
+ gr.update(value=patient_id),
531
  gr.update(visible=True),
532
  ]
533
 
534
  upload_reports_btn.click(close_side_bar, outputs=[sheami_sidebar]).then(
535
  show_upload_reports_modal,
536
  inputs=[email_in, patient_list],
537
+ outputs=[
538
+ user_email_state,
539
+ patient_id_state,
540
+ shown_email,
541
+ shown_patient,
542
+ upload_reports_modal,
543
+ ],
544
+ )
545
 
546
  if __name__ == "__main__":
547
  demo.launch()
modules/db.py CHANGED
@@ -7,6 +7,7 @@ from dotenv import load_dotenv
7
 
8
  from modules.models import StandardizedReport
9
 
 
10
  class SheamiDB:
11
  def __init__(self, uri: str, db_name: str = "sheami"):
12
  """Initialize connection to MongoDB Atlas (or local Mongo)."""
@@ -70,7 +71,7 @@ class SheamiDB:
70
  # REPORT FUNCTIONS
71
  # ---------------------------
72
  def add_report_v2(self, patient_id: str, reports: list[StandardizedReport]) -> str:
73
- inserted_ids = []
74
  for parsed_data in reports:
75
  report = {
76
  "patient_id": ObjectId(patient_id),
@@ -80,7 +81,7 @@ class SheamiDB:
80
  }
81
  result = self.reports.insert_one(report)
82
  inserted_ids.append(result.inserted_id)
83
- return str(inserted_ids)
84
 
85
  def add_report(self, patient_id: str, file_name: str, parsed_data: any) -> str:
86
  report = {
@@ -305,3 +306,69 @@ class SheamiDB:
305
  )
306
 
307
  return result.modified_count
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  from modules.models import StandardizedReport
9
 
10
+
11
  class SheamiDB:
12
  def __init__(self, uri: str, db_name: str = "sheami"):
13
  """Initialize connection to MongoDB Atlas (or local Mongo)."""
 
71
  # REPORT FUNCTIONS
72
  # ---------------------------
73
  def add_report_v2(self, patient_id: str, reports: list[StandardizedReport]) -> str:
74
+ inserted_ids: list[ObjectId] = []
75
  for parsed_data in reports:
76
  report = {
77
  "patient_id": ObjectId(patient_id),
 
81
  }
82
  result = self.reports.insert_one(report)
83
  inserted_ids.append(result.inserted_id)
84
+ return ",".join([str(inserted_id) for inserted_id in inserted_ids])
85
 
86
  def add_report(self, patient_id: str, file_name: str, parsed_data: any) -> str:
87
  report = {
 
306
  )
307
 
308
  return result.modified_count
309
+
310
+ def aggregate_trends_from_report(self, patient_id: str, report_id: str):
311
+ """
312
+ Incrementally update patient trends based on a new report's tests.
313
+ - Fetches the report
314
+ - For each test, appends (date, value, unit) to trends[patient_id, test_name]
315
+ - Ensures no duplicate points for same report/test combo
316
+ """
317
+ report = self.reports.find_one({"_id": ObjectId(report_id)})
318
+ if not report:
319
+ raise ValueError(f"Report {report_id} not found")
320
+
321
+ # print("report = ",report)
322
+ tests = report.get("parsed_data_v2", {"lab_results": []}).get("lab_results", [])
323
+ if not tests:
324
+ return 0
325
+
326
+ updated = 0
327
+ for test in tests:
328
+ test_name = test.get("test_name")
329
+ if not test_name:
330
+ continue
331
+
332
+ value = test.get("result_value")
333
+ unit = test.get("test_unit")
334
+ test_date = (
335
+ test.get("test_date")
336
+ or report.get("uploaded_at")
337
+ or datetime.now(timezone.utc)
338
+ )
339
+
340
+ # Normalize date
341
+ if isinstance(test_date, (int, float)):
342
+ # handle timestamp
343
+ test_date = datetime.fromtimestamp(test_date, tz=timezone.utc)
344
+ elif isinstance(test_date, str):
345
+ try:
346
+ test_date = datetime.fromisoformat(test_date)
347
+ except Exception:
348
+ test_date = datetime.now(timezone.utc)
349
+
350
+ point = {
351
+ "date": test_date,
352
+ "value": value,
353
+ "unit": unit,
354
+ "report_id": ObjectId(report_id),
355
+ }
356
+
357
+ # Upsert trend doc
358
+ result = self.trends.update_one(
359
+ {"patient_id": ObjectId(patient_id), "test_name": test_name},
360
+ {
361
+ "$setOnInsert": {
362
+ "patient_id": ObjectId(patient_id),
363
+ "test_name": test_name,
364
+ "created_at": datetime.now(timezone.utc),
365
+ },
366
+ "$push": {"trend_data": point},
367
+ "$set": {"last_updated": datetime.now(timezone.utc)},
368
+ },
369
+ upsert=True,
370
+ )
371
+ updated += result.modified_count
372
+
373
+ # print("updated/inserted", updated, "trends")
374
+ return updated
modules/models.py CHANGED
@@ -3,6 +3,7 @@ from datetime import datetime
3
  from pydantic import BaseModel, Field
4
  from typing import List, Literal, Optional, TypedDict, Union
5
 
 
6
  class PatientInfo(BaseModel):
7
  name: Optional[str] = Field(None, description="Patient's full name")
8
  age: Optional[int] = Field(None, description="Patient's age in years")
@@ -84,6 +85,7 @@ class SheamiMilestone:
84
  class SheamiState(TypedDict):
85
  user_email: str
86
  patient_id: str
 
87
  run_id: str
88
  messages: list[str]
89
  thread_id: str
 
3
  from pydantic import BaseModel, Field
4
  from typing import List, Literal, Optional, TypedDict, Union
5
 
6
+
7
  class PatientInfo(BaseModel):
8
  name: Optional[str] = Field(None, description="Patient's full name")
9
  age: Optional[int] = Field(None, description="Patient's age in years")
 
85
  class SheamiState(TypedDict):
86
  user_email: str
87
  patient_id: str
88
+ report_id_list: str
89
  run_id: str
90
  messages: list[str]
91
  thread_id: str