# sheami / db.py
# Uploaded by vikramvasudevan using huggingface_hub (commit 7911979, verified).
import os
from datetime import datetime, timezone
from typing import Optional

from bson import ObjectId
from dotenv import load_dotenv
from pymongo import MongoClient
class SheamiDB:
    """Small data-access layer over the Sheami MongoDB database.

    Wraps five collections -- ``users``, ``patients``, ``reports``, ``trends``
    and ``final_reports`` -- with insert/query helpers. Public helpers take
    ids as hex strings and convert them to ``bson.ObjectId`` before querying
    or storing them as references; new ids are returned as hex strings.

    The instance owns its ``MongoClient``. Call :meth:`close` when finished,
    or use the instance as a context manager::

        with SheamiDB(uri) as db:
            db.add_user("a@b.com", "A")
    """

    def __init__(self, uri: str, db_name: str = "sheami"):
        """Initialize connection to MongoDB Atlas (or local Mongo).

        Args:
            uri: MongoDB connection string, passed straight to ``MongoClient``.
            db_name: Database to use; defaults to ``"sheami"``.
        """
        self.client = MongoClient(uri)
        self.db = self.client[db_name]
        # Collections
        self.users = self.db["users"]
        self.patients = self.db["patients"]
        self.reports = self.db["reports"]
        self.trends = self.db["trends"]
        self.final_reports = self.db["final_reports"]

    # ---------------------------
    # LIFECYCLE
    # ---------------------------
    def close(self) -> None:
        """Close the underlying ``MongoClient``."""
        self.client.close()

    def __enter__(self) -> "SheamiDB":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release the client; returning None lets exceptions propagate.
        self.close()

    @staticmethod
    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Replaces the deprecated, naive ``datetime.utcnow()``. PyMongo treats
        naive datetimes as UTC and converts aware ones to UTC, so the stored
        instant is the same either way.
        """
        return datetime.now(timezone.utc)

    # ---------------------------
    # USER FUNCTIONS
    # ---------------------------
    def add_user(self, email: str, name: str) -> str:
        """Insert a user and return its new id as a hex string."""
        user = {
            "email": email,
            "name": name,
            "created_at": self._now(),
        }
        result = self.users.insert_one(user)
        return str(result.inserted_id)

    def get_user(self, user_id: str) -> Optional[dict]:
        """Return the user document with this id, or ``None`` if absent.

        Raises:
            bson.errors.InvalidId: if ``user_id`` is not a valid ObjectId.
        """
        return self.users.find_one({"_id": ObjectId(user_id)})

    # ---------------------------
    # PATIENT FUNCTIONS
    # ---------------------------
    def add_patient(self, user_id: str, name: str, dob: str, gender: str) -> str:
        """Insert a patient owned by ``user_id`` and return its new id.

        ``dob`` and ``gender`` are stored verbatim as strings.
        """
        patient = {
            "user_id": ObjectId(user_id),
            "name": name,
            "dob": dob,
            "gender": gender,
            "created_at": self._now(),
        }
        result = self.patients.insert_one(patient)
        return str(result.inserted_id)

    def get_patients_by_user(self, user_id: str) -> list:
        """Return all patient documents whose ``user_id`` matches."""
        return list(self.patients.find({"user_id": ObjectId(user_id)}))

    # ---------------------------
    # REPORT FUNCTIONS
    # ---------------------------
    def add_report(self, patient_id: str, file_name: str, parsed_data: dict) -> str:
        """Store a parsed report for ``patient_id`` and return its new id."""
        report = {
            "patient_id": ObjectId(patient_id),
            "uploaded_at": self._now(),
            "file_name": file_name,
            "parsed_data": parsed_data,
        }
        result = self.reports.insert_one(report)
        return str(result.inserted_id)

    def get_reports_by_patient(self, patient_id: str) -> list:
        """Return all report documents for ``patient_id``."""
        return list(self.reports.find({"patient_id": ObjectId(patient_id)}))

    # ---------------------------
    # TREND FUNCTIONS
    # ---------------------------
    def add_or_update_trend(self, patient_id: str, test_name: str, trend_data: list):
        """Insert new trend or update existing one.

        A trend is keyed by ``(patient_id, test_name)``. ``trend_data`` and
        ``last_updated`` are overwritten on every call. On first insert the
        upsert copies the filter fields into the new document.
        """
        self.trends.update_one(
            {"patient_id": ObjectId(patient_id), "test_name": test_name},
            {"$set": {"trend_data": trend_data, "last_updated": self._now()}},
            upsert=True,
        )

    def get_trends_by_patient(self, patient_id: str) -> list:
        """Return all trend documents for ``patient_id``."""
        return list(self.trends.find({"patient_id": ObjectId(patient_id)}))

    # ---------------------------
    # FINAL REPORT FUNCTIONS
    # ---------------------------
    def add_final_report(self, patient_id: str, summary: str, recommendations: list, trend_snapshots: list) -> str:
        """Store a generated final report for ``patient_id`` and return its id."""
        final_report = {
            "patient_id": ObjectId(patient_id),
            "generated_at": self._now(),
            "summary": summary,
            "recommendations": recommendations,
            "trend_snapshots": trend_snapshots,
        }
        result = self.final_reports.insert_one(final_report)
        return str(result.inserted_id)

    def get_final_reports_by_patient(self, patient_id: str) -> list:
        """Return all final-report documents for ``patient_id``."""
        return list(self.final_reports.find({"patient_id": ObjectId(patient_id)}))
# ---------------------------
# Example usage
# ---------------------------
def main() -> None:
    """Seed the database with one example user, patient, report, trend and final report.

    Reads the connection string from the ``DB_URI`` environment variable
    (a ``.env`` file is loaded first and overrides existing variables).
    """
    load_dotenv(override=True)
    uri = os.getenv("DB_URI")
    if not uri:
        # Without this check MongoClient(None) would quietly target the
        # default local server and write demo data there.
        raise SystemExit("DB_URI is not set; define it in the environment or a .env file.")

    db = SheamiDB(uri)
    try:
        # Add user
        user_id = db.add_user("doctor1@sheami.com", "Dr. Smith")
        # Add patient
        patient_id = db.add_patient(user_id, "John Doe", "1980-05-20", "male")
        # Add report
        parsed_data = {
            "tests": [
                {"name": "Hemoglobin", "value": 13.5, "unit": "g/dL", "reference_range": "13.0-17.0"},
                {"name": "Cholesterol", "value": 210, "unit": "mg/dL", "reference_range": "<200"},
            ]
        }
        report_id = db.add_report(patient_id, "bloodwork_july.pdf", parsed_data)
        # Add trend
        db.add_or_update_trend(patient_id, "Hemoglobin", [
            {"date": "2025-05-01", "value": 13.2},
            {"date": "2025-07-01", "value": 13.5},
            {"date": "2025-08-19", "value": 13.8},
        ])
        # Add final report
        final_report_id = db.add_final_report(
            patient_id,
            "Hemoglobin stable, cholesterol slightly high.",
            ["Maintain healthy diet", "Check cholesterol in 3 months"],
            [
                {"test_name": "Hemoglobin", "latest_value": 13.8, "direction": "stable"},
                {"test_name": "Cholesterol", "latest_value": 210, "direction": "increasing"},
            ],
        )
        print("User ID:", user_id)
        print("Patient ID:", patient_id)
        print("Report ID:", report_id)
        print("Final Report ID:", final_report_id)
    finally:
        # Release the connection pool even if an insert fails.
        db.client.close()


if __name__ == "__main__":
    main()