import os
from datetime import date, datetime, timezone
from typing import Any

import pandas as pd
from bson import ObjectId
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket

from modules.models import StandardizedReport


class SheamiDB:
    def __init__(self, uri: str | None = None, db_name: str = "sheami"):
        if not uri:
            load_dotenv(override=True)
            uri = os.getenv("MONGODB_URI")

        # Use Motor's AsyncIOMotorClient instead of PyMongo's AsyncMongoClient
        self.client = AsyncIOMotorClient(uri)
        self.db = self.client[db_name]

        # Collections
        self.users = self.db["users"]
        self.patients = self.db["patients"]
        self.reports = self.db["reports"]
        self.trends = self.db["trends"]
        self.final_reports = self.db["final_reports"]
        self.run_stats = self.db["run_stats"]
        self.vitals = self.db["vitals"]

        # Motor's GridFSBucket requires a MotorDatabase
        self.fs = AsyncIOMotorGridFSBucket(self.db)

    # ---------------------------
    # USER FUNCTIONS
    # ---------------------------
    async def add_user(self, email: str, name: str) -> str:
        user = {"email": email, "name": name, "created_at": datetime.now(timezone.utc)}
        result = await self.users.insert_one(user)
        return str(result.inserted_id)

    async def get_user(self, user_id: str) -> dict:
        user = await self.users.find_one({"_id": ObjectId(user_id)})
        return user

    def convert_dob(self, dob: int | float | datetime):
        if isinstance(dob, datetime):
            transformed_dob = dob  # PyMongo will store it as ISODate automatically
        elif isinstance(dob, (int, float)):
            transformed_dob = datetime.fromtimestamp(dob)
        else:
            transformed_dob = dob
        return transformed_dob

    # ---------------------------
    # PATIENT FUNCTIONS
    # ---------------------------
    async def add_patient(
        self, user_id: str, name: str, dob: int | float | datetime, gender: str
    ) -> str:
        # Ensure DOB is stored as a MongoDB Date
        transformed_dob = self.convert_dob(dob)
        patient = {
            "user_id": ObjectId(user_id),
            "name": name,
            "dob": transformed_dob,
            "gender": gender,
            "created_at": datetime.now(timezone.utc),
        }
        result = await self.patients.insert_one(patient)
        return str(result.inserted_id)

    async def get_patient_by_id(
        self,
        patient_id: str,
        fields: list[str] | None = None,
        serializable: bool = False,
    ) -> Any | None:
        patient = await self.patients.find_one({"_id": ObjectId(patient_id)})
        if patient and fields:
            patient = {key: patient[key] for key in fields if key in patient}
        if serializable:
            patient = self.convert_to_serializable_data(data=patient)
        return patient

    async def get_patients_by_user(self, user_id: str) -> list:
        cursor = self.patients.find({"user_id": ObjectId(user_id)}).sort("name")
        # all the locked records are example records
        example_cursor = self.patients.find(
            {"user_id": ObjectId("68a40aa2208e5689f7342a6e"), "locked": True}
        ).sort("name")
        matching_patients = await cursor.to_list(
            length=None
        )  # length=None returns all documents
        example_patients = await example_cursor.to_list(length=None)
        sorted_list = sorted(
            example_patients + matching_patients, key=lambda x: x["name"]
        )
        return sorted_list

    # ---------------------------
    # REPORT FUNCTIONS
    # ---------------------------
    async def add_report_v2(
        self, patient_id: str, reports: list[StandardizedReport], run_id: str
    ) -> str:
        inserted_ids: list[ObjectId] = []
        for parsed_data in reports:
            report = {
                "patient_id": ObjectId(patient_id),
                "uploaded_at": datetime.now(timezone.utc),
                "file_name": parsed_data.original_report_file_name,
"parsed_data_v2": parsed_data.model_dump(), "run_id": ObjectId(run_id), } result = await self.reports.insert_one(report) inserted_ids.append(result.inserted_id) return ",".join([str(inserted_id) for inserted_id in inserted_ids]) async def add_report( self, patient_id: str, file_name: str, parsed_data: any ) -> str: report = { "patient_id": ObjectId(patient_id), "uploaded_at": datetime.now(timezone.utc), "file_name": file_name, "parsed_data": parsed_data, } result = await self.reports.insert_one(report) return str(result.inserted_id) async def get_reports_by_patient(self, patient_id: str) -> list: reports_cursor = self.reports.find({"patient_id": ObjectId(patient_id)}).sort( "_id", -1 ) reports = await reports_cursor.to_list( length=None ) # Fetch all sorted documents as a list return reports # --------------------------- # TREND FUNCTIONS # --------------------------- async def add_or_update_trend( self, patient_id: str, test_name: str, trend_data: list ): """Insert new trend or update existing one.""" await self.trends.update_one( {"patient_id": ObjectId(patient_id), "test_name": test_name}, { "$set": { "trend_data": trend_data, "last_updated": datetime.now(timezone.utc), } }, upsert=True, ) async def get_trends_by_patient( self, patient_id: str, fields: list[str] = None, serializable=False ) -> list: cursor = self.trends.find({"patient_id": ObjectId(patient_id)}).sort( "test_name" ) trends = await cursor.to_list(length=None) if fields: trends = [ {field: trend[field] for field in fields if field in trend} for trend in trends ] if serializable: trends = self.convert_to_serializable_data(data=trends) return trends # --------------------------- # FINAL REPORT FUNCTIONS # --------------------------- async def add_final_report( self, patient_id: str, summary: str, recommendations: list, trend_snapshots: list, ) -> str: final_report = { "patient_id": ObjectId(patient_id), "generated_at": datetime.now(timezone.utc), "summary": summary, "recommendations": recommendations, "trend_snapshots": trend_snapshots, } result = await self.final_reports.insert_one(final_report) return str(result.inserted_id) async def get_final_reports_by_patient(self, patient_id: str) -> list: cursor = self.final_reports.find({"patient_id": ObjectId(patient_id)}).sort( "_id", -1 ) final_reports = await cursor.to_list(length=None) return final_reports # --------------------------- # FETCH FULL USER DATA # --------------------------- async def get_user_by_email(self, email: str) -> dict: """Fetch user by email.""" user = await self.users.find_one({"email": email}) return user async def get_user_full_data(self, user_id: str) -> dict: """ Fetch user + all patients, reports, trends, final reports for populating UI (tabbed layout). 
""" user = await self.get_user(user_id) if not user: return {} # Get patients for user patients = await self.get_patients_by_user(user_id) full_patients = [] for patient in patients: pid = str(patient["_id"]) # Fetch related collections patient_reports = await self.get_reports_by_patient(pid) patient_trends = await self.get_trends_by_patient(pid) patient_final_reports = await self.get_final_reports_by_patient(pid) full_patients.append( { "patient": patient, "reports": patient_reports, "trends": patient_trends, "final_reports": patient_final_reports, } ) return {"user": user, "patients": full_patients} async def update_patient(self, patient_id, fields: dict): fields["dob"] = self.convert_dob(fields["dob"]) result = await self.patients.update_one( {"_id": ObjectId(patient_id)}, {"$set": fields} ) return result.modified_count > 0 async def delete_patient(self, patient_id: str): try: yield "⌛Deleting patient PDFs ... " deleted_count = await self.delete_pdfs_by_patient_id(patient_id) yield f"✅Deleted {deleted_count} patient PDFs ... " yield "⌛Deleting patient reports ... " result = await self.reports.delete_one({"patient_id": ObjectId(patient_id)}) yield f"✅Deleted {result.deleted_count} patient reports ... " yield "⌛Deleting patient trends ... " result = await self.trends.delete_one({"patient_id": ObjectId(patient_id)}) yield f"✅Deleted {result.deleted_count} patient trends ... " yield "⌛Deleting patient final reports ... " result = await self.final_reports.delete_one( {"patient_id": ObjectId(patient_id)} ) yield f"✅Deleted {result.deleted_count} patient final reports ... " yield "⌛Deleting patient run stats ... " result = await self.run_stats.delete_one( {"patient_id": ObjectId(patient_id)} ) yield f"✅Deleted {result.deleted_count} patient run stats ... " yield "⌛Deleting patient ... " result = await self.patients.delete_one({"_id": ObjectId(patient_id)}) yield f"✅Deleted {result.deleted_count} patient ... " except Exception as e: print(f"Error deleting patient: {e}") yield f"❌ Error deleting patient {e}" return async def start_run( self, user_email: str, patient_id: str, source_file_names: list[str], source_file_contents: list[str], action: str = "upload test reports", ): print("Getting details for user:", user_email, " and patient:", patient_id) user = await self.get_user_by_email(user_email) if not user: raise Exception(f"User {user_email} not found!") user_id = user.get("_id") if not user_id: raise Exception(f"User {user_email} not found!") if not patient_id: raise Exception("Patient not found!") if not source_file_names: raise Exception("No source files to process!") stat_entry = { "user_id": ObjectId(user_id), "patient_id": ObjectId(patient_id), "source_file_names": source_file_names, "source_file_contents": source_file_contents, "action": action, "status": "inprogress", "steps_completed": 0, "steps_total": 5, # max_steps "milestones": [ { "milestone": "Initializing", "start_timestamp": datetime.now(timezone.utc), "end_timestamp": datetime.now(timezone.utc), "status": "completed", } ], "created_at": datetime.now(timezone.utc), } result = await self.run_stats.insert_one(stat_entry) return str(result.inserted_id) async def update_run_stats(self, run_id: str, **kwargs): """ Update top-level fields in a run's stats document. 
        Example usage:
            update_run_stats(run_id, steps_completed=2, steps_total=5)
        """
        update_fields = {}
        for key, value in kwargs.items():
            if key in [
                "steps_completed",
                "steps_total",
                "action",
                "source_file_names",
                "source_file_contents",
                "status",
                "message",
            ]:
                update_fields[key] = value

        if not update_fields:
            raise ValueError("No valid run-level fields to update")

        result = await self.run_stats.update_one(
            {"_id": ObjectId(run_id)}, {"$set": update_fields}
        )
        return result.modified_count

    async def add_or_update_milestone(
        self, run_id: str, milestone: str, status: str | None = None, end: bool = False
    ):
        """
        Add a new milestone or update an existing one.

        - If the milestone doesn't exist, push a new entry with start_timestamp.
        - If it exists, update its status and/or end_timestamp.
        - If the milestone completes (end=True), also increment steps_completed
          at the run level.
        """
        run = await self.run_stats.find_one({"_id": ObjectId(run_id)})
        if not run:
            raise Exception(f"Run {run_id} not found")

        milestones = run.get("milestones", [])
        existing = next((m for m in milestones if m["milestone"] == milestone), None)

        if not existing:
            # Add a new milestone
            new_milestone = {
                "milestone": milestone,
                "start_timestamp": datetime.now(timezone.utc),
                "end_timestamp": None,
                "status": status or "inprogress",
            }
            result = await self.run_stats.update_one(
                {"_id": ObjectId(run_id)}, {"$push": {"milestones": new_milestone}}
            )
        else:
            # Update the existing milestone
            updates = {}
            if status:
                updates["milestones.$.status"] = status
            if end:
                updates["milestones.$.end_timestamp"] = datetime.now(timezone.utc)

            if not updates:
                raise ValueError("Nothing to update in milestone")

            update_ops = {"$set": updates}
            if end:
                # Also increment steps_completed at the run level
                update_ops["$inc"] = {"steps_completed": 1}

            result = await self.run_stats.update_one(
                {"_id": ObjectId(run_id), "milestones.milestone": milestone},
                update_ops,
            )

        return result.modified_count

    async def aggregate_trends_from_report(self, patient_id: str, report_id: str):
        """
        Incrementally update patient trends based on a new report's tests.

        - Fetches the report
        - For each test, appends (date, value, unit) to trends[patient_id, test_name]
        - Ensures no duplicate points for the same report/test combo
        """
        report = await self.reports.find_one({"_id": ObjectId(report_id)})
        if not report:
            raise ValueError(f"Report {report_id} not found")
        # print("report = ", report)

        tests = report.get("parsed_data_v2", {"lab_results": []}).get("lab_results", [])
        if not tests:
            return 0

        updated = 0

        async def add_or_update_trend_data_point(test):
            test_name = test.get("test_name")
            value = test.get("result_value")
            unit = test.get("test_unit")
            test_date = test.get("test_date") or datetime.now(timezone.utc)
            test_reference_range = test.get("test_reference_range")
            inferred_range = test.get("inferred_range")

            # Normalize test_date to a datetime
            if isinstance(test_date, (int, float)):
                test_date = datetime.fromtimestamp(test_date, tz=timezone.utc)
            elif isinstance(test_date, str):
                try:
                    test_date = datetime.fromisoformat(test_date)
                except Exception:
                    test_date = datetime.now(timezone.utc)

            point = {
                "date": test_date,
                "value": value,
                "unit": unit,
                "report_id": ObjectId(report_id),
            }

            # Step 1: Check if a trend_data point with the same date exists
            existing_doc = await self.trends.find_one(
                {
                    "patient_id": ObjectId(patient_id),
                    "test_name": test_name,
                    "trend_data.date": test_date,
                },
                projection={"trend_data.$": 1},  # Project only the matched array element
            )

            if existing_doc:
                # Step 2: Update the existing trend_data array element with the new data
                result = await self.trends.update_one(
                    {
                        "patient_id": ObjectId(patient_id),
                        "test_name": test_name,
                        "trend_data.date": test_date,
                    },
                    {
                        "$set": {
                            "trend_data.$.value": value,
                            "trend_data.$.unit": unit,
                            "trend_data.$.report_id": ObjectId(report_id),
                            "last_updated": datetime.now(timezone.utc),
                            "test_reference_range": test_reference_range,
                            "inferred_range": inferred_range,
                        },
                    },
                )
            else:
                # Step 3: Insert the new point as it does not exist yet
                result = await self.trends.update_one(
                    {"patient_id": ObjectId(patient_id), "test_name": test_name},
                    {
                        "$setOnInsert": {
                            "patient_id": ObjectId(patient_id),
                            "test_name": test_name,
                            "created_at": datetime.now(timezone.utc),
                        },
                        "$push": {"trend_data": point},
                        "$set": {
                            "last_updated": datetime.now(timezone.utc),
                            "test_reference_range": test_reference_range,
                            "inferred_range": inferred_range,
                        },
                    },
                    upsert=True,
                )
            return result

        for test in tests:
            test_name = test.get("test_name")
            if not test_name:
                # Panels without a top-level test_name carry their values in sub_results
                sub_results = test.get("sub_results", [])
                if not sub_results:
                    continue
                for sub_result in sub_results:
                    test_name = sub_result.get("test_name")
                    db_output = await add_or_update_trend_data_point(sub_result)
                    updated += db_output.modified_count
                continue
            else:
                db_output = await add_or_update_trend_data_point(test)
                updated += db_output.modified_count

        # print("updated/inserted", updated, "trends")
        return updated

    async def aggregate_trends_snapshot(self, patient_id: str):
        # Fetch trends for this patient
        trend_docs = await self.get_trends_by_patient(patient_id)

        # Extract "snapshots"
        snapshots = []
        for t in trend_docs:
            td = t.get("trend_data", [])
            if not td:
                continue

            # Last point = most recent measurement
            last_point = max(td, key=lambda x: x.get("date"))

            snapshots.append(
                {
                    "test_name": t.get("test_name", ""),
                    "latest_value": last_point.get("value"),
                    "latest_date": last_point.get("date"),
                    "unit": last_point.get("unit", ""),
                    "reference_range": last_point.get("reference_range", ""),
                }
            )

        return snapshots

    async def upload_bytes_to_fs(self, data: bytes, filename: str, patient_id):
        # Open an upload stream
        upload_stream = self.fs.open_upload_stream(
            filename,
            metadata={
                "patient_id": patient_id,
                "uploaded_at": datetime.now(timezone.utc),
            },
        )

        # Write the data to the stream
        await upload_stream.write(data)
        await upload_stream.close()

        # The file ID is exposed on upload_stream._id
        return upload_stream._id

    # ---------------------------
    # FINAL REPORT FUNCTIONS
    # ---------------------------
    async def add_final_report_v2(
        self,
        patient_id: str,
        summary: str,
        pdf_bytes: bytes | None = None,
        file_name: str | None = None,
    ) -> str:
        """
        Insert a final report. If pdf_bytes is provided,
        stores it in GridFS and saves the file_id in metadata.
""" pdf_file_id = None if pdf_bytes: pdf_file_id = await self.upload_bytes_to_fs( data=pdf_bytes, filename=file_name, patient_id=ObjectId(patient_id) ) final_report = { "patient_id": ObjectId(patient_id), "generated_at": datetime.now(timezone.utc), "summary": summary, "trend_snapshots": await self.aggregate_trends_snapshot( patient_id=patient_id ), "pdf_file_id": pdf_file_id, # Reference to GridFS file } result = await self.final_reports.insert_one(final_report) return str(result.inserted_id) async def get_final_report_pdf(self, report_id: str) -> bytes | None: """ Fetch the PDF bytes for a given final_report (if stored). """ doc = await self.final_reports.find_one({"_id": ObjectId(report_id)}) if not doc or not doc.get("pdf_file_id"): return None grid_out = await self.fs.open_download_stream(doc["pdf_file_id"]) # Read the file data in chunks file_data = b"" while True: chunk = await grid_out.read(1024) # Read 1024 bytes at a time if not chunk: break file_data += chunk # await grid_out.close() return file_data async def get_final_report_html(self, report_id: str) -> bytes | None: """ Fetch the HTML for a given final_report (if stored). """ doc = await self.final_reports.find_one({"_id": ObjectId(report_id)}) if not doc or not doc.get("summary"): return "" # empty tag return doc.get("summary") async def get_run_stats_by_patient(self, patient_id: str) -> list: cursor = self.run_stats.find({"patient_id": ObjectId(patient_id)}).sort( "created_at", -1 ) run_stats = await cursor.to_list(length=None) return run_stats async def get_run_stats_by_id(self, id: str): run_stat = await self.run_stats.find_one({"_id": ObjectId(id)}) return run_stat async def delete_pdfs_by_patient_id(self, patient_id: str) -> int: # Find all files with the specified patient_id in metadata cursor = self.db.fs.files.find( {"metadata.patient_id": ObjectId(patient_id)}, projection={"_id": 1} ) deleted_count = 0 async for file_doc in cursor: file_id = file_doc["_id"] await self.fs.delete(file_id) deleted_count += 1 return deleted_count def convert_to_serializable_data(self, data): """ Recursively converts MongoDB-specific types to JSON serializable formats. - ObjectId to string - datetime to ISO 8601 string Handles dict, list, and basic types. """ if isinstance(data, dict): return {k: self.convert_to_serializable_data(v) for k, v in data.items()} elif isinstance(data, list): return [self.convert_to_serializable_data(i) for i in data] elif isinstance(data, ObjectId): return str(data) elif isinstance(data, datetime): return data.isoformat() else: return data def normalize_to_date(self, reading_date): """ Convert a datetime or date to a datetime at 00:00:00 """ if isinstance(reading_date, date) and not isinstance(reading_date, datetime): # convert date to datetime at midnight return datetime(reading_date.year, reading_date.month, reading_date.day) elif isinstance(reading_date, datetime): # strip time part return datetime(reading_date.year, reading_date.month, reading_date.day) else: raise ValueError("reading_date must be a date or datetime") async def save_readings_to_db( self, patient_id: str, reading_date, new_readings: list, created_by: str ): """ Async save/merge readings for a patient/date using Motor. 
""" reading_date = self.normalize_to_date(reading_date) # Find existing document doc = await self.vitals.find_one( { "patient_id": ObjectId(patient_id), "date": reading_date, } ) if not doc: # Insert new await self.vitals.insert_one( { "patient_id": ObjectId(patient_id), "date": reading_date, "readings": new_readings, "created_by": created_by, } ) else: # Merge: update existing readings by 'name', append new ones existing = doc.get("readings", []) merged = existing.copy() existing_names = {r["name"]: r for r in existing} for r in new_readings: if r["name"] in existing_names: # Update existing entry existing_names[r["name"]].update(r) else: # Append new entry merged.append(r) # For entries that were updated in-place, ensure they are in merged merged_names = {r["name"] for r in merged} for name, r in existing_names.items(): if name not in merged_names: merged.append(r) await self.vitals.update_one( {"_id": doc["_id"]}, {"$set": {"readings": merged, "created_by": created_by}}, ) async def get_vitals_by_patient( self, patient_id: str, fields: list[str] = None, serializable=False ) -> list: cursor = self.vitals.find({"patient_id": ObjectId(patient_id)}).sort("date", -1) vitals = await cursor.to_list(length=None) if fields: vitals = [ {field: vital[field] for field in fields if field in vital} for vital in vitals ] if serializable: vitals = self.convert_to_serializable_data(data=vitals) # print("vitals = ",vitals) return vitals async def get_latest_vitals_by_patient(self, patient_id: str) -> dict: # sort by date descending and get the first record cursor = self.vitals.find({"patient_id": ObjectId(patient_id)}).sort("date", -1) vitals = await cursor.to_list(length=None) if len(vitals) > 0: vitals = vitals[0] else: vitals = {} return vitals async def get_vitals_trends_by_patient(self, patient_id): docs = await self.vitals.aggregate([ {"$match": {"patient_id": ObjectId(patient_id)}}, {"$unwind": "$readings"}, {"$project": { "date": 1, "name": "$readings.name", "value": "$readings.value", "unit": "$readings.unit" }}, {"$sort": {"date": 1}} ]).to_list(length=None) df = pd.DataFrame(docs) trend_docs = [] if "name" in df: for vital_name, group in df.groupby("name"): if vital_name.lower() in ["bp", "blood pressure"]: # Split systolic/diastolic systolic, diastolic = [], [] for _, row in group.iterrows(): try: sys, dia = row["value"].split("/") systolic.append({"date": row["date"], "value": int(sys)}) diastolic.append({"date": row["date"], "value": int(dia)}) except Exception: continue # Append as two separate test series trend_docs.append({ "test_name": "BP - Systolic", "trend_data": systolic, "unit": "mmHg", "test_reference_range": {} }) trend_docs.append({ "test_name": "BP - Diastolic", "trend_data": diastolic, "unit": "mmHg", "test_reference_range": {} }) else: # Other vitals trend_docs.append({ "test_name": vital_name, "trend_data": group[["date", "value"]].to_dict("records"), "unit": group["unit"].iloc[0], "test_reference_range": {} }) # 🔑 Sort by test_name so BP series appear together trend_docs = sorted(trend_docs, key=lambda x: x["test_name"]) return trend_docs