import os
from datetime import datetime, timezone
from typing import Optional

from bson import ObjectId
from dotenv import load_dotenv
from pymongo import MongoClient


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime.

    Replaces ``datetime.utcnow()``, which is deprecated since Python 3.12 and
    returns a naive value. PyMongo stores aware datetimes converted to UTC, so
    the persisted instant is the same as before.
    """
    return datetime.now(timezone.utc)


class SheamiDB:
    """Data-access layer over the Sheami MongoDB database.

    Ids are accepted and returned as hex strings and converted to
    ``bson.ObjectId`` at this boundary. ``ObjectId()`` raises
    ``bson.errors.InvalidId`` for malformed id strings.

    The instance owns its ``MongoClient``. Call :meth:`close` when finished,
    or use the instance as a context manager (``with SheamiDB(uri) as db:``).
    """

    def __init__(self, uri: str, db_name: str = "sheami"):
        """Initialize connection to MongoDB Atlas (or local Mongo).

        Args:
            uri: MongoDB connection string.
            db_name: Name of the database to use (default ``"sheami"``).
        """
        self.client = MongoClient(uri)
        self.db = self.client[db_name]

        # Collections
        self.users = self.db["users"]
        self.patients = self.db["patients"]
        self.reports = self.db["reports"]
        self.trends = self.db["trends"]
        self.final_reports = self.db["final_reports"]

    # ---------------------------
    # LIFECYCLE
    # ---------------------------
    def close(self) -> None:
        """Close the underlying ``MongoClient`` and release its connections."""
        self.client.close()

    def __enter__(self) -> "SheamiDB":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the client; never suppress the in-flight exception.
        self.close()

    # ---------------------------
    # USER FUNCTIONS
    # ---------------------------
    def add_user(self, email: str, name: str) -> str:
        """Insert a user document and return its id as a hex string."""
        user = {
            "email": email,
            "name": name,
            "created_at": _utcnow(),
        }
        result = self.users.insert_one(user)
        return str(result.inserted_id)

    def get_user(self, user_id: str) -> Optional[dict]:
        """Return the user document with ``user_id``, or ``None`` if absent."""
        return self.users.find_one({"_id": ObjectId(user_id)})

    # ---------------------------
    # PATIENT FUNCTIONS
    # ---------------------------
    def add_patient(self, user_id: str, name: str, dob: str, gender: str) -> str:
        """Insert a patient owned by ``user_id`` and return its id as a string.

        ``dob`` is stored as given. The example below passes ``"YYYY-MM-DD"``,
        but no format is enforced here.
        """
        patient = {
            "user_id": ObjectId(user_id),
            "name": name,
            "dob": dob,
            "gender": gender,
            "created_at": _utcnow(),
        }
        result = self.patients.insert_one(patient)
        return str(result.inserted_id)

    def get_patients_by_user(self, user_id: str) -> list:
        """Return all patient documents whose ``user_id`` matches."""
        return list(self.patients.find({"user_id": ObjectId(user_id)}))

    # ---------------------------
    # REPORT FUNCTIONS
    # ---------------------------
    def add_report(self, patient_id: str, file_name: str, parsed_data: dict) -> str:
        """Store a parsed lab report for ``patient_id`` and return its id."""
        report = {
            "patient_id": ObjectId(patient_id),
            "uploaded_at": _utcnow(),
            "file_name": file_name,
            "parsed_data": parsed_data,
        }
        result = self.reports.insert_one(report)
        return str(result.inserted_id)

    def get_reports_by_patient(self, patient_id: str) -> list:
        """Return all report documents for ``patient_id``."""
        return list(self.reports.find({"patient_id": ObjectId(patient_id)}))

    # ---------------------------
    # TREND FUNCTIONS
    # ---------------------------
    def add_or_update_trend(self, patient_id: str, test_name: str, trend_data: list):
        """Insert new trend or update existing one.

        The document is keyed by ``(patient_id, test_name)``. On each call,
        ``trend_data`` replaces the stored series wholesale (``$set``), and
        ``last_updated`` is refreshed. ``upsert=True`` creates the document
        when no match exists.
        """
        self.trends.update_one(
            {"patient_id": ObjectId(patient_id), "test_name": test_name},
            {"$set": {"trend_data": trend_data, "last_updated": _utcnow()}},
            upsert=True,
        )

    def get_trends_by_patient(self, patient_id: str) -> list:
        """Return all trend documents for ``patient_id``."""
        return list(self.trends.find({"patient_id": ObjectId(patient_id)}))

    # ---------------------------
    # FINAL REPORT FUNCTIONS
    # ---------------------------
    def add_final_report(
        self,
        patient_id: str,
        summary: str,
        recommendations: list,
        trend_snapshots: list,
    ) -> str:
        """Store a generated final report for ``patient_id`` and return its id."""
        final_report = {
            "patient_id": ObjectId(patient_id),
            "generated_at": _utcnow(),
            "summary": summary,
            "recommendations": recommendations,
            "trend_snapshots": trend_snapshots,
        }
        result = self.final_reports.insert_one(final_report)
        return str(result.inserted_id)

    def get_final_reports_by_patient(self, patient_id: str) -> list:
        """Return all final-report documents for ``patient_id``."""
        return list(self.final_reports.find({"patient_id": ObjectId(patient_id)}))


# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    load_dotenv(override=True)

    db_uri = os.getenv("DB_URI")
    if not db_uri:
        # MongoClient(None) would silently fall back to localhost:27017.
        raise SystemExit("DB_URI is not set (environment or .env file)")

    with SheamiDB(db_uri) as db:
        # Add user
        user_id = db.add_user("doctor1@sheami.com", "Dr. Smith")

        # Add patient
        patient_id = db.add_patient(user_id, "John Doe", "1980-05-20", "male")

        # Add report
        parsed_data = {
            "tests": [
                {"name": "Hemoglobin", "value": 13.5, "unit": "g/dL", "reference_range": "13.0-17.0"},
                {"name": "Cholesterol", "value": 210, "unit": "mg/dL", "reference_range": "<200"},
            ]
        }
        report_id = db.add_report(patient_id, "bloodwork_july.pdf", parsed_data)

        # Add trend
        db.add_or_update_trend(patient_id, "Hemoglobin", [
            {"date": "2025-05-01", "value": 13.2},
            {"date": "2025-07-01", "value": 13.5},
            {"date": "2025-08-19", "value": 13.8},
        ])

        # Add final report
        final_report_id = db.add_final_report(
            patient_id,
            "Hemoglobin stable, cholesterol slightly high.",
            ["Maintain healthy diet", "Check cholesterol in 3 months"],
            [
                {"test_name": "Hemoglobin", "latest_value": 13.8, "direction": "stable"},
                {"test_name": "Cholesterol", "latest_value": 210, "direction": "increasing"},
            ],
        )

        print("User ID:", user_id)
        print("Patient ID:", patient_id)
        print("Report ID:", report_id)
        print("Final Report ID:", final_report_id)