Upload folder using huggingface_hub
Browse files- graph.py +6 -1
- home.py +61 -45
- modules/db.py +69 -2
- modules/models.py +2 -0
graph.py
CHANGED
|
@@ -519,9 +519,14 @@ async def fn_interpreter_node(state: SheamiState):
|
|
| 519 |
get_db().update_run_stats(run_id=state["run_id"], status="completed")
|
| 520 |
|
| 521 |
## add parsed reports
|
| 522 |
-
get_db().add_report_v2(
|
| 523 |
patient_id=state["patient_id"], reports=state["standardized_reports"]
|
| 524 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
|
| 526 |
return state
|
| 527 |
|
|
|
|
| 519 |
get_db().update_run_stats(run_id=state["run_id"], status="completed")
|
| 520 |
|
| 521 |
## add parsed reports
|
| 522 |
+
report_id_list = get_db().add_report_v2(
|
| 523 |
patient_id=state["patient_id"], reports=state["standardized_reports"]
|
| 524 |
)
|
| 525 |
+
state["report_id_list"] = report_id_list
|
| 526 |
+
|
| 527 |
+
logger.info("report_id_list = %s", report_id_list)
|
| 528 |
+
for report_id in report_id_list.split(","):
|
| 529 |
+
get_db().aggregate_trends_from_report(state["patient_id"], report_id)
|
| 530 |
|
| 531 |
return state
|
| 532 |
|
home.py
CHANGED
|
@@ -53,10 +53,45 @@ def flatten_reports_v2(reports: List[Dict[str, Any]]) -> pd.DataFrame:
|
|
| 53 |
lab_results = parsed_v2.get("lab_results", [])
|
| 54 |
|
| 55 |
if not lab_results:
|
| 56 |
-
rows.append(
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
"test_name": "",
|
| 61 |
"result_value": "",
|
| 62 |
"test_unit": "",
|
|
@@ -65,39 +100,10 @@ def flatten_reports_v2(reports: List[Dict[str, Any]]) -> pd.DataFrame:
|
|
| 65 |
"ref_raw": "",
|
| 66 |
"test_date": "",
|
| 67 |
"inferred_range": "",
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
for t in lab_results:
|
| 71 |
-
ref = t.get("test_reference_range", {})
|
| 72 |
-
rows.append({
|
| 73 |
-
"report_id": rid,
|
| 74 |
-
"uploaded_at": uploaded_at,
|
| 75 |
-
"file_name": file_name,
|
| 76 |
-
"test_name": t.get("test_name", ""),
|
| 77 |
-
"result_value": t.get("result_value", ""),
|
| 78 |
-
"test_unit": t.get("test_unit", ""),
|
| 79 |
-
"ref_min": ref.get("min","-"),
|
| 80 |
-
"ref_max": ref.get("max","-"),
|
| 81 |
-
"ref_raw": ref.get("raw", "-"),
|
| 82 |
-
"test_date": t.get("test_date", ""),
|
| 83 |
-
"inferred_range": t.get("inferred_range", ""),
|
| 84 |
-
})
|
| 85 |
-
if not rows:
|
| 86 |
-
rows = [{
|
| 87 |
-
"report_id": "",
|
| 88 |
-
"uploaded_at": "",
|
| 89 |
-
"file_name": "",
|
| 90 |
-
"test_name": "",
|
| 91 |
-
"result_value": "",
|
| 92 |
-
"test_unit": "",
|
| 93 |
-
"ref_min": "",
|
| 94 |
-
"ref_max": "",
|
| 95 |
-
"ref_raw": "",
|
| 96 |
-
"test_date": "",
|
| 97 |
-
"inferred_range": "",
|
| 98 |
-
}]
|
| 99 |
df = pd.DataFrame(rows)
|
| 100 |
-
df = df.fillna("-")
|
| 101 |
return df
|
| 102 |
|
| 103 |
|
|
@@ -370,9 +376,13 @@ with gr.Blocks(
|
|
| 370 |
user_email_state = gr.State("")
|
| 371 |
patient_id_state = gr.State("")
|
| 372 |
# show them if you want
|
| 373 |
-
shown_email = gr.Textbox(label="User Email", interactive=False)
|
| 374 |
-
shown_patient = gr.Textbox(label="Patient ID", interactive=False)
|
| 375 |
-
get_gradio_block(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
with gr.Sidebar() as sheami_sidebar: # Sidebar
|
| 378 |
# gr.Markdown("### Sidebar")
|
|
@@ -514,18 +524,24 @@ with gr.Blocks(
|
|
| 514 |
# open modal and set states
|
| 515 |
def show_upload_reports_modal(user_email, patient_id):
|
| 516 |
return [
|
| 517 |
-
user_email,
|
| 518 |
-
patient_id,
|
| 519 |
-
gr.update(value=user_email),
|
| 520 |
-
gr.update(value=patient_id),
|
| 521 |
gr.update(visible=True),
|
| 522 |
]
|
| 523 |
|
| 524 |
upload_reports_btn.click(close_side_bar, outputs=[sheami_sidebar]).then(
|
| 525 |
show_upload_reports_modal,
|
| 526 |
inputs=[email_in, patient_list],
|
| 527 |
-
outputs=[
|
| 528 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
|
| 530 |
if __name__ == "__main__":
|
| 531 |
demo.launch()
|
|
|
|
| 53 |
lab_results = parsed_v2.get("lab_results", [])
|
| 54 |
|
| 55 |
if not lab_results:
|
| 56 |
+
rows.append(
|
| 57 |
+
{
|
| 58 |
+
"report_id": rid,
|
| 59 |
+
"uploaded_at": uploaded_at,
|
| 60 |
+
"file_name": file_name,
|
| 61 |
+
"test_name": "",
|
| 62 |
+
"result_value": "",
|
| 63 |
+
"test_unit": "",
|
| 64 |
+
"ref_min": "",
|
| 65 |
+
"ref_max": "",
|
| 66 |
+
"ref_raw": "",
|
| 67 |
+
"test_date": "",
|
| 68 |
+
"inferred_range": "",
|
| 69 |
+
}
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
for t in lab_results:
|
| 73 |
+
ref = t.get("test_reference_range", {})
|
| 74 |
+
rows.append(
|
| 75 |
+
{
|
| 76 |
+
"report_id": rid,
|
| 77 |
+
"uploaded_at": uploaded_at,
|
| 78 |
+
"file_name": file_name,
|
| 79 |
+
"test_name": t.get("test_name", ""),
|
| 80 |
+
"result_value": t.get("result_value", ""),
|
| 81 |
+
"test_unit": t.get("test_unit", ""),
|
| 82 |
+
"ref_min": ref.get("min", "-"),
|
| 83 |
+
"ref_max": ref.get("max", "-"),
|
| 84 |
+
"ref_raw": ref.get("raw", "-"),
|
| 85 |
+
"test_date": t.get("test_date", ""),
|
| 86 |
+
"inferred_range": t.get("inferred_range", ""),
|
| 87 |
+
}
|
| 88 |
+
)
|
| 89 |
+
if not rows:
|
| 90 |
+
rows = [
|
| 91 |
+
{
|
| 92 |
+
"report_id": "",
|
| 93 |
+
"uploaded_at": "",
|
| 94 |
+
"file_name": "",
|
| 95 |
"test_name": "",
|
| 96 |
"result_value": "",
|
| 97 |
"test_unit": "",
|
|
|
|
| 100 |
"ref_raw": "",
|
| 101 |
"test_date": "",
|
| 102 |
"inferred_range": "",
|
| 103 |
+
}
|
| 104 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
df = pd.DataFrame(rows)
|
| 106 |
+
df = df.fillna("-") # 👈 normalize missing values
|
| 107 |
return df
|
| 108 |
|
| 109 |
|
|
|
|
| 376 |
user_email_state = gr.State("")
|
| 377 |
patient_id_state = gr.State("")
|
| 378 |
# show them if you want
|
| 379 |
+
shown_email = gr.Textbox(label="User Email", interactive=False, visible=False)
|
| 380 |
+
shown_patient = gr.Textbox(label="Patient ID", interactive=False, visible=False)
|
| 381 |
+
get_gradio_block(
|
| 382 |
+
container=upload_reports_modal,
|
| 383 |
+
user_email_state=user_email_state,
|
| 384 |
+
patient_id_state=patient_id_state,
|
| 385 |
+
)
|
| 386 |
|
| 387 |
with gr.Sidebar() as sheami_sidebar: # Sidebar
|
| 388 |
# gr.Markdown("### Sidebar")
|
|
|
|
| 524 |
# open modal and set states
|
| 525 |
def show_upload_reports_modal(user_email, patient_id):
|
| 526 |
return [
|
| 527 |
+
user_email,
|
| 528 |
+
patient_id,
|
| 529 |
+
gr.update(value=user_email),
|
| 530 |
+
gr.update(value=patient_id),
|
| 531 |
gr.update(visible=True),
|
| 532 |
]
|
| 533 |
|
| 534 |
upload_reports_btn.click(close_side_bar, outputs=[sheami_sidebar]).then(
|
| 535 |
show_upload_reports_modal,
|
| 536 |
inputs=[email_in, patient_list],
|
| 537 |
+
outputs=[
|
| 538 |
+
user_email_state,
|
| 539 |
+
patient_id_state,
|
| 540 |
+
shown_email,
|
| 541 |
+
shown_patient,
|
| 542 |
+
upload_reports_modal,
|
| 543 |
+
],
|
| 544 |
+
)
|
| 545 |
|
| 546 |
if __name__ == "__main__":
|
| 547 |
demo.launch()
|
modules/db.py
CHANGED
|
@@ -7,6 +7,7 @@ from dotenv import load_dotenv
|
|
| 7 |
|
| 8 |
from modules.models import StandardizedReport
|
| 9 |
|
|
|
|
| 10 |
class SheamiDB:
|
| 11 |
def __init__(self, uri: str, db_name: str = "sheami"):
|
| 12 |
"""Initialize connection to MongoDB Atlas (or local Mongo)."""
|
|
@@ -70,7 +71,7 @@ class SheamiDB:
|
|
| 70 |
# REPORT FUNCTIONS
|
| 71 |
# ---------------------------
|
| 72 |
def add_report_v2(self, patient_id: str, reports: list[StandardizedReport]) -> str:
|
| 73 |
-
inserted_ids = []
|
| 74 |
for parsed_data in reports:
|
| 75 |
report = {
|
| 76 |
"patient_id": ObjectId(patient_id),
|
|
@@ -80,7 +81,7 @@ class SheamiDB:
|
|
| 80 |
}
|
| 81 |
result = self.reports.insert_one(report)
|
| 82 |
inserted_ids.append(result.inserted_id)
|
| 83 |
-
return str(inserted_ids)
|
| 84 |
|
| 85 |
def add_report(self, patient_id: str, file_name: str, parsed_data: any) -> str:
|
| 86 |
report = {
|
|
@@ -305,3 +306,69 @@ class SheamiDB:
|
|
| 305 |
)
|
| 306 |
|
| 307 |
return result.modified_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from modules.models import StandardizedReport
|
| 9 |
|
| 10 |
+
|
| 11 |
class SheamiDB:
|
| 12 |
def __init__(self, uri: str, db_name: str = "sheami"):
|
| 13 |
"""Initialize connection to MongoDB Atlas (or local Mongo)."""
|
|
|
|
| 71 |
# REPORT FUNCTIONS
|
| 72 |
# ---------------------------
|
| 73 |
def add_report_v2(self, patient_id: str, reports: list[StandardizedReport]) -> str:
|
| 74 |
+
inserted_ids: list[ObjectId] = []
|
| 75 |
for parsed_data in reports:
|
| 76 |
report = {
|
| 77 |
"patient_id": ObjectId(patient_id),
|
|
|
|
| 81 |
}
|
| 82 |
result = self.reports.insert_one(report)
|
| 83 |
inserted_ids.append(result.inserted_id)
|
| 84 |
+
return ",".join([str(inserted_id) for inserted_id in inserted_ids])
|
| 85 |
|
| 86 |
def add_report(self, patient_id: str, file_name: str, parsed_data: any) -> str:
|
| 87 |
report = {
|
|
|
|
| 306 |
)
|
| 307 |
|
| 308 |
return result.modified_count
|
| 309 |
+
|
| 310 |
+
def aggregate_trends_from_report(self, patient_id: str, report_id: str):
|
| 311 |
+
"""
|
| 312 |
+
Incrementally update patient trends based on a new report's tests.
|
| 313 |
+
- Fetches the report
|
| 314 |
+
- For each test, appends (date, value, unit) to trends[patient_id, test_name]
|
| 315 |
+
- Ensures no duplicate points for same report/test combo
|
| 316 |
+
"""
|
| 317 |
+
report = self.reports.find_one({"_id": ObjectId(report_id)})
|
| 318 |
+
if not report:
|
| 319 |
+
raise ValueError(f"Report {report_id} not found")
|
| 320 |
+
|
| 321 |
+
# print("report = ",report)
|
| 322 |
+
tests = report.get("parsed_data_v2", {"lab_results": []}).get("lab_results", [])
|
| 323 |
+
if not tests:
|
| 324 |
+
return 0
|
| 325 |
+
|
| 326 |
+
updated = 0
|
| 327 |
+
for test in tests:
|
| 328 |
+
test_name = test.get("test_name")
|
| 329 |
+
if not test_name:
|
| 330 |
+
continue
|
| 331 |
+
|
| 332 |
+
value = test.get("result_value")
|
| 333 |
+
unit = test.get("test_unit")
|
| 334 |
+
test_date = (
|
| 335 |
+
test.get("test_date")
|
| 336 |
+
or report.get("uploaded_at")
|
| 337 |
+
or datetime.now(timezone.utc)
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# Normalize date
|
| 341 |
+
if isinstance(test_date, (int, float)):
|
| 342 |
+
# handle timestamp
|
| 343 |
+
test_date = datetime.fromtimestamp(test_date, tz=timezone.utc)
|
| 344 |
+
elif isinstance(test_date, str):
|
| 345 |
+
try:
|
| 346 |
+
test_date = datetime.fromisoformat(test_date)
|
| 347 |
+
except Exception:
|
| 348 |
+
test_date = datetime.now(timezone.utc)
|
| 349 |
+
|
| 350 |
+
point = {
|
| 351 |
+
"date": test_date,
|
| 352 |
+
"value": value,
|
| 353 |
+
"unit": unit,
|
| 354 |
+
"report_id": ObjectId(report_id),
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
# Upsert trend doc
|
| 358 |
+
result = self.trends.update_one(
|
| 359 |
+
{"patient_id": ObjectId(patient_id), "test_name": test_name},
|
| 360 |
+
{
|
| 361 |
+
"$setOnInsert": {
|
| 362 |
+
"patient_id": ObjectId(patient_id),
|
| 363 |
+
"test_name": test_name,
|
| 364 |
+
"created_at": datetime.now(timezone.utc),
|
| 365 |
+
},
|
| 366 |
+
"$push": {"trend_data": point},
|
| 367 |
+
"$set": {"last_updated": datetime.now(timezone.utc)},
|
| 368 |
+
},
|
| 369 |
+
upsert=True,
|
| 370 |
+
)
|
| 371 |
+
updated += result.modified_count
|
| 372 |
+
|
| 373 |
+
# print("updated/inserted", updated, "trends")
|
| 374 |
+
return updated
|
modules/models.py
CHANGED
|
@@ -3,6 +3,7 @@ from datetime import datetime
|
|
| 3 |
from pydantic import BaseModel, Field
|
| 4 |
from typing import List, Literal, Optional, TypedDict, Union
|
| 5 |
|
|
|
|
| 6 |
class PatientInfo(BaseModel):
|
| 7 |
name: Optional[str] = Field(None, description="Patient's full name")
|
| 8 |
age: Optional[int] = Field(None, description="Patient's age in years")
|
|
@@ -84,6 +85,7 @@ class SheamiMilestone:
|
|
| 84 |
class SheamiState(TypedDict):
|
| 85 |
user_email: str
|
| 86 |
patient_id: str
|
|
|
|
| 87 |
run_id: str
|
| 88 |
messages: list[str]
|
| 89 |
thread_id: str
|
|
|
|
| 3 |
from pydantic import BaseModel, Field
|
| 4 |
from typing import List, Literal, Optional, TypedDict, Union
|
| 5 |
|
| 6 |
+
|
| 7 |
class PatientInfo(BaseModel):
|
| 8 |
name: Optional[str] = Field(None, description="Patient's full name")
|
| 9 |
age: Optional[int] = Field(None, description="Patient's age in years")
|
|
|
|
| 85 |
class SheamiState(TypedDict):
|
| 86 |
user_email: str
|
| 87 |
patient_id: str
|
| 88 |
+
report_id_list: str
|
| 89 |
run_id: str
|
| 90 |
messages: list[str]
|
| 91 |
thread_id: str
|