# sheami/modules/db.py
# Uploaded via huggingface_hub (vikramvasudevan, commit db5222a).
from datetime import timezone, date
import os
import random
from typing import Any
from datetime import datetime, timezone
from bson import ObjectId
from dotenv import load_dotenv
from modules.models import StandardizedReport
from motor.motor_asyncio import AsyncIOMotorGridFSBucket
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket
import os
from dotenv import load_dotenv
import pandas as pd
class SheamiDB:
    """Async MongoDB data-access layer for the Sheami app, built on Motor."""

    def __init__(self, uri: str | None = None, db_name: str = "sheami"):
        """Connect to MongoDB and bind collection handles.

        When ``uri`` is not supplied, it is read from the MONGODB_URI
        environment variable after loading .env (override=True).
        """
        if not uri:
            load_dotenv(override=True)
            uri = os.getenv("MONGODB_URI")
        # Use Motor's AsyncIOMotorClient instead of PyMongo's AsyncMongoClient
        self.client = AsyncIOMotorClient(uri)
        self.db = self.client[db_name]
        # Collections
        self.users = self.db["users"]
        self.patients = self.db["patients"]
        self.reports = self.db["reports"]
        self.trends = self.db["trends"]
        self.final_reports = self.db["final_reports"]
        self.run_stats = self.db["run_stats"]
        self.vitals = self.db["vitals"]
        # Motor's GridFSBucket requires a MotorDatabase
        self.fs = AsyncIOMotorGridFSBucket(self.db)
# ---------------------------
# USER FUNCTIONS
# ---------------------------
async def add_user(self, email: str, name: str) -> str:
user = {"email": email, "name": name, "created_at": datetime.now(timezone.utc)}
result = await self.users.insert_one(user)
return str(result.inserted_id)
async def get_user(self, user_id: str) -> dict:
user = await self.users.find_one({"_id": ObjectId(user_id)})
return user
def convert_dob(self, dob: int | float | datetime):
if isinstance(dob, datetime):
transformed_dob = dob # PyMongo will store it as ISODate automatically
elif isinstance(dob, (int, float)):
transformed_dob = datetime.fromtimestamp(dob)
else:
transformed_dob = dob
return transformed_dob
# ---------------------------
# PATIENT FUNCTIONS
# ---------------------------
async def add_patient(
self, user_id: str, name: str, dob: int | float | datetime, gender: str
) -> str:
# Ensure DOB is stored as MongoDB Date
transformed_dob = self.convert_dob(dob)
patient = {
"user_id": ObjectId(user_id),
"name": name,
"dob": transformed_dob,
"gender": gender,
"created_at": datetime.now(timezone.utc),
}
result = await self.patients.insert_one(patient)
return str(result.inserted_id)
async def get_patient_by_id(
self, patient_id: str, fields: list[str] = [], serializable: bool = False
) -> Any | None:
patient = await self.patients.find_one({"_id": ObjectId(patient_id)})
if fields:
patient = {key: patient[key] for key in fields if key in patient}
if serializable:
patient = self.convert_to_serializable_data(data=patient)
return patient
async def get_patients_by_user(self, user_id: str) -> list:
cursor = self.patients.find({"user_id": ObjectId(user_id)}).sort("name")
# all the locked records are example records
example_cursor = self.patients.find(
{"user_id": ObjectId("68a40aa2208e5689f7342a6e"), "locked": True}
).sort("name")
matching_patients = await cursor.to_list(
length=None
) # length=None returns all documents
example_patients = await example_cursor.to_list(length=None)
sorted_list = sorted(
example_patients + matching_patients, key=lambda x: x["name"]
)
return sorted_list
# ---------------------------
# REPORT FUNCTIONS
# ---------------------------
async def add_report_v2(
self, patient_id: str, reports: list[StandardizedReport], run_id: str
) -> str:
inserted_ids: list[ObjectId] = []
for parsed_data in reports:
report = {
"patient_id": ObjectId(patient_id),
"uploaded_at": datetime.now(timezone.utc),
"file_name": parsed_data.original_report_file_name,
"parsed_data_v2": parsed_data.model_dump(),
"run_id": ObjectId(run_id),
}
result = await self.reports.insert_one(report)
inserted_ids.append(result.inserted_id)
return ",".join([str(inserted_id) for inserted_id in inserted_ids])
    async def add_report(
        self, patient_id: str, file_name: str, parsed_data: Any
    ) -> str:
        """Insert a legacy (v1) parsed report for a patient; returns its id.

        ``parsed_data`` is stored as-is under the "parsed_data" key (the v2
        path stores model dumps under "parsed_data_v2" instead).
        """
        report = {
            "patient_id": ObjectId(patient_id),
            "uploaded_at": datetime.now(timezone.utc),
            "file_name": file_name,
            "parsed_data": parsed_data,
        }
        result = await self.reports.insert_one(report)
        return str(result.inserted_id)
async def get_reports_by_patient(self, patient_id: str) -> list:
reports_cursor = self.reports.find({"patient_id": ObjectId(patient_id)}).sort(
"_id", -1
)
reports = await reports_cursor.to_list(
length=None
) # Fetch all sorted documents as a list
return reports
# ---------------------------
# TREND FUNCTIONS
# ---------------------------
async def add_or_update_trend(
self, patient_id: str, test_name: str, trend_data: list
):
"""Insert new trend or update existing one."""
await self.trends.update_one(
{"patient_id": ObjectId(patient_id), "test_name": test_name},
{
"$set": {
"trend_data": trend_data,
"last_updated": datetime.now(timezone.utc),
}
},
upsert=True,
)
async def get_trends_by_patient(
self, patient_id: str, fields: list[str] = None, serializable=False
) -> list:
cursor = self.trends.find({"patient_id": ObjectId(patient_id)}).sort(
"test_name"
)
trends = await cursor.to_list(length=None)
if fields:
trends = [
{field: trend[field] for field in fields if field in trend}
for trend in trends
]
if serializable:
trends = self.convert_to_serializable_data(data=trends)
return trends
# ---------------------------
# FINAL REPORT FUNCTIONS
# ---------------------------
async def add_final_report(
self,
patient_id: str,
summary: str,
recommendations: list,
trend_snapshots: list,
) -> str:
final_report = {
"patient_id": ObjectId(patient_id),
"generated_at": datetime.now(timezone.utc),
"summary": summary,
"recommendations": recommendations,
"trend_snapshots": trend_snapshots,
}
result = await self.final_reports.insert_one(final_report)
return str(result.inserted_id)
async def get_final_reports_by_patient(self, patient_id: str) -> list:
cursor = self.final_reports.find({"patient_id": ObjectId(patient_id)}).sort(
"_id", -1
)
final_reports = await cursor.to_list(length=None)
return final_reports
# ---------------------------
# FETCH FULL USER DATA
# ---------------------------
async def get_user_by_email(self, email: str) -> dict:
"""Fetch user by email."""
user = await self.users.find_one({"email": email})
return user
async def get_user_full_data(self, user_id: str) -> dict:
"""
Fetch user + all patients, reports, trends, final reports
for populating UI (tabbed layout).
"""
user = await self.get_user(user_id)
if not user:
return {}
# Get patients for user
patients = await self.get_patients_by_user(user_id)
full_patients = []
for patient in patients:
pid = str(patient["_id"])
# Fetch related collections
patient_reports = await self.get_reports_by_patient(pid)
patient_trends = await self.get_trends_by_patient(pid)
patient_final_reports = await self.get_final_reports_by_patient(pid)
full_patients.append(
{
"patient": patient,
"reports": patient_reports,
"trends": patient_trends,
"final_reports": patient_final_reports,
}
)
return {"user": user, "patients": full_patients}
async def update_patient(self, patient_id, fields: dict):
fields["dob"] = self.convert_dob(fields["dob"])
result = await self.patients.update_one(
{"_id": ObjectId(patient_id)}, {"$set": fields}
)
return result.modified_count > 0
async def delete_patient(self, patient_id: str):
try:
yield "⌛Deleting patient PDFs ... "
deleted_count = await self.delete_pdfs_by_patient_id(patient_id)
yield f"✅Deleted {deleted_count} patient PDFs ... "
yield "⌛Deleting patient reports ... "
result = await self.reports.delete_one({"patient_id": ObjectId(patient_id)})
yield f"✅Deleted {result.deleted_count} patient reports ... "
yield "⌛Deleting patient trends ... "
result = await self.trends.delete_one({"patient_id": ObjectId(patient_id)})
yield f"✅Deleted {result.deleted_count} patient trends ... "
yield "⌛Deleting patient final reports ... "
result = await self.final_reports.delete_one(
{"patient_id": ObjectId(patient_id)}
)
yield f"✅Deleted {result.deleted_count} patient final reports ... "
yield "⌛Deleting patient run stats ... "
result = await self.run_stats.delete_one(
{"patient_id": ObjectId(patient_id)}
)
yield f"✅Deleted {result.deleted_count} patient run stats ... "
yield "⌛Deleting patient ... "
result = await self.patients.delete_one({"_id": ObjectId(patient_id)})
yield f"✅Deleted {result.deleted_count} patient ... "
except Exception as e:
print(f"Error deleting patient: {e}")
yield f"❌ Error deleting patient {e}"
return
async def start_run(
self,
user_email: str,
patient_id: str,
source_file_names: list[str],
source_file_contents: list[str],
action: str = "upload test reports",
):
print("Getting details for user:", user_email, " and patient:", patient_id)
user = await self.get_user_by_email(user_email)
if not user:
raise Exception(f"User {user_email} not found!")
user_id = user.get("_id")
if not user_id:
raise Exception(f"User {user_email} not found!")
if not patient_id:
raise Exception("Patient not found!")
if not source_file_names:
raise Exception("No source files to process!")
stat_entry = {
"user_id": ObjectId(user_id),
"patient_id": ObjectId(patient_id),
"source_file_names": source_file_names,
"source_file_contents": source_file_contents,
"action": action,
"status": "inprogress",
"steps_completed": 0,
"steps_total": 5, # max_steps
"milestones": [
{
"milestone": "Initializing",
"start_timestamp": datetime.now(timezone.utc),
"end_timestamp": datetime.now(timezone.utc),
"status": "completed",
}
],
"created_at": datetime.now(timezone.utc),
}
result = await self.run_stats.insert_one(stat_entry)
return str(result.inserted_id)
async def update_run_stats(self, run_id: str, **kwargs):
"""
Update top-level fields in a run's stats document.
Example usage:
update_run_stats(run_id, steps_completed=2, steps_total=5)
"""
update_fields = {}
for key, value in kwargs.items():
if key in [
"steps_completed",
"steps_total",
"action",
"source_file_names",
"source_file_contents",
"status",
"message",
]:
update_fields[key] = value
if not update_fields:
raise ValueError("No valid run-level fields to update")
result = await self.run_stats.update_one(
{"_id": ObjectId(run_id)}, {"$set": update_fields}
)
return result.modified_count
    async def add_or_update_milestone(
        self, run_id: str, milestone: str, status: str = None, end: bool = False
    ):
        """
        Add a new milestone or update an existing one.
        - If milestone doesn't exist, pushes a new entry with start_timestamp.
        - If exists, updates status or end_timestamp.
        - If milestone completes (end=True), also increments steps_completed at run level.

        Returns the number of modified documents (0 or 1).

        NOTE(review): find-then-update is not atomic; two concurrent calls for
        the same new milestone could both push it — confirm a single writer
        per run is assumed.
        """
        run = await self.run_stats.find_one({"_id": ObjectId(run_id)})
        if not run:
            raise Exception(f"Run {run_id} not found")
        milestones = run.get("milestones", [])
        existing = next((m for m in milestones if m["milestone"] == milestone), None)
        if not existing:
            # add new milestone
            new_milestone = {
                "milestone": milestone,
                "start_timestamp": datetime.now(timezone.utc),
                "end_timestamp": None,
                "status": status or "inprogress",
            }
            result = await self.run_stats.update_one(
                {"_id": ObjectId(run_id)}, {"$push": {"milestones": new_milestone}}
            )
        else:
            # update existing milestone
            updates = {}
            if status:
                updates["milestones.$.status"] = status
            if end:
                updates["milestones.$.end_timestamp"] = datetime.now(timezone.utc)
            if not updates:
                raise ValueError("Nothing to update in milestone")
            update_ops = {"$set": updates}
            if end:
                # also increment steps_completed at run level
                update_ops["$inc"] = {"steps_completed": 1}
            # The positional $ operator targets the array element matched by
            # the "milestones.milestone" clause of the filter.
            result = await self.run_stats.update_one(
                {"_id": ObjectId(run_id), "milestones.milestone": milestone},
                update_ops,
            )
        return result.modified_count
async def aggregate_trends_from_report(self, patient_id: str, report_id: str):
"""
Incrementally update patient trends based on a new report's tests.
- Fetches the report
- For each test, appends (date, value, unit) to trends[patient_id, test_name]
- Ensures no duplicate points for same report/test combo
"""
report = await self.reports.find_one({"_id": ObjectId(report_id)})
if not report:
raise ValueError(f"Report {report_id} not found")
# print("report = ",report)
tests = report.get("parsed_data_v2", {"lab_results": []}).get("lab_results", [])
if not tests:
return 0
updated = 0
async def add_or_update_trend_data_point(test):
test_name = test.get("test_name")
value = test.get("result_value")
unit = test.get("test_unit")
test_date = test.get("test_date") or datetime.now(timezone.utc)
test_reference_range = test.get("test_reference_range")
inferred_range = test.get("inferred_range")
# Normalize test_date (keep your existing normalization here)...
if isinstance(test_date, (int, float)):
test_date = datetime.fromtimestamp(test_date, tz=timezone.utc)
elif isinstance(test_date, str):
try:
test_date = datetime.fromisoformat(test_date)
except Exception:
test_date = datetime.now(timezone.utc)
point = {
"date": test_date,
"value": value,
"unit": unit,
"report_id": ObjectId(report_id),
}
# Step 1: Check if trend_data with same date exists
existing_doc = await self.trends.find_one(
{
"patient_id": ObjectId(patient_id),
"test_name": test_name,
"trend_data.date": test_date,
},
projection={"trend_data.$": 1}, # Project only matched array element
)
if existing_doc:
# Step 2: Update the existing trend_data array element with new data
result = await self.trends.update_one(
{
"patient_id": ObjectId(patient_id),
"test_name": test_name,
"trend_data.date": test_date,
},
{
"$set": {
"trend_data.$.value": value,
"trend_data.$.unit": unit,
"trend_data.$.report_id": ObjectId(report_id),
"last_updated": datetime.now(timezone.utc),
"test_reference_range": test_reference_range,
"inferred_range": inferred_range,
},
"$setOnInsert": {
"patient_id": ObjectId(patient_id),
"test_name": test_name,
"created_at": datetime.now(timezone.utc),
},
},
)
else:
# Step 3: Insert new point as it does not exist yet
result = await self.trends.update_one(
{"patient_id": ObjectId(patient_id), "test_name": test_name},
{
"$setOnInsert": {
"patient_id": ObjectId(patient_id),
"test_name": test_name,
"created_at": datetime.now(timezone.utc),
},
"$push": {"trend_data": point},
"$set": {
"last_updated": datetime.now(timezone.utc),
"test_reference_range": test_reference_range,
"inferred_range": inferred_range,
"test_reference_range": test_reference_range,
"inferred_range": inferred_range,
},
},
upsert=True,
)
return result
for test in tests:
test_name = test.get("test_name")
if not test_name:
sub_results = test.get("sub_results", [])
if not sub_results:
continue
for sub_result in sub_results:
test_name = sub_result.get("test_name")
db_output = await add_or_update_trend_data_point(sub_result)
updated += db_output.modified_count
continue
else:
db_output = await add_or_update_trend_data_point(test)
updated += db_output.modified_count
# print("updated/inserted", updated, "trends")
return updated
async def aggregate_trends_snapshot(self, patient_id: str):
# fetch trends for this patient
trend_docs = await self.get_trends_by_patient(patient_id)
# extract "snapshots"
snapshots = []
for t in trend_docs:
td = t.get("trend_data", [])
if not td:
continue
# last point = most recent measurement
last_point = max(td, key=lambda x: x.get("date"))
snapshots.append(
{
"test_name": t.get("test_name", ""),
"latest_value": last_point.get("value"),
"latest_date": last_point.get("date"),
"unit": last_point.get("unit", ""),
"reference_range": last_point.get("reference_range", ""),
}
)
return snapshots
async def upload_bytes_to_fs(self, data: bytes, filename: str, patient_id):
# Open an upload stream
upload_stream = self.fs.open_upload_stream(
filename,
metadata={
"patient_id": patient_id,
"uploaded_at": datetime.now(timezone.utc),
},
)
# Write data to stream
await upload_stream.write(data)
await upload_stream.close()
# The file ID is in upload_stream._id
return upload_stream._id
# ---------------------------
# FINAL REPORT FUNCTIONS
# ---------------------------
async def add_final_report_v2(
self,
patient_id: str,
summary: str,
pdf_bytes: bytes = None,
file_name: str = None,
) -> str:
"""
Insert a final report.
If pdf_bytes is provided, stores it in GridFS and saves file_id in metadata.
"""
pdf_file_id = None
if pdf_bytes:
pdf_file_id = await self.upload_bytes_to_fs(
data=pdf_bytes, filename=file_name, patient_id=ObjectId(patient_id)
)
final_report = {
"patient_id": ObjectId(patient_id),
"generated_at": datetime.now(timezone.utc),
"summary": summary,
"trend_snapshots": await self.aggregate_trends_snapshot(
patient_id=patient_id
),
"pdf_file_id": pdf_file_id, # Reference to GridFS file
}
result = await self.final_reports.insert_one(final_report)
return str(result.inserted_id)
async def get_final_report_pdf(self, report_id: str) -> bytes | None:
"""
Fetch the PDF bytes for a given final_report (if stored).
"""
doc = await self.final_reports.find_one({"_id": ObjectId(report_id)})
if not doc or not doc.get("pdf_file_id"):
return None
grid_out = await self.fs.open_download_stream(doc["pdf_file_id"])
# Read the file data in chunks
file_data = b""
while True:
chunk = await grid_out.read(1024) # Read 1024 bytes at a time
if not chunk:
break
file_data += chunk
# await grid_out.close()
return file_data
    async def get_final_report_html(self, report_id: str) -> str:
        """
        Fetch the HTML for a given final_report (if stored).

        Returns the stored "summary" HTML string, or an empty <html></html>
        document when the report or its summary is missing. (Annotation fixed:
        this method always returns a str, never bytes or None.)
        """
        doc = await self.final_reports.find_one({"_id": ObjectId(report_id)})
        if not doc or not doc.get("summary"):
            return "<html></html>"  # empty tag
        return doc.get("summary")
async def get_run_stats_by_patient(self, patient_id: str) -> list:
cursor = self.run_stats.find({"patient_id": ObjectId(patient_id)}).sort(
"created_at", -1
)
run_stats = await cursor.to_list(length=None)
return run_stats
async def get_run_stats_by_id(self, id: str):
run_stat = await self.run_stats.find_one({"_id": ObjectId(id)})
return run_stat
async def delete_pdfs_by_patient_id(self, patient_id: str) -> int:
# Find all files with the specified patient_id in metadata
cursor = self.db.fs.files.find(
{"metadata.patient_id": ObjectId(patient_id)}, projection={"_id": 1}
)
deleted_count = 0
async for file_doc in cursor:
file_id = file_doc["_id"]
await self.fs.delete(file_id)
deleted_count += 1
return deleted_count
def convert_to_serializable_data(self, data):
"""
Recursively converts MongoDB-specific types to JSON serializable formats.
- ObjectId to string
- datetime to ISO 8601 string
Handles dict, list, and basic types.
"""
if isinstance(data, dict):
return {k: self.convert_to_serializable_data(v) for k, v in data.items()}
elif isinstance(data, list):
return [self.convert_to_serializable_data(i) for i in data]
elif isinstance(data, ObjectId):
return str(data)
elif isinstance(data, datetime):
return data.isoformat()
else:
return data
def normalize_to_date(self, reading_date):
"""
Convert a datetime or date to a datetime at 00:00:00
"""
if isinstance(reading_date, date) and not isinstance(reading_date, datetime):
# convert date to datetime at midnight
return datetime(reading_date.year, reading_date.month, reading_date.day)
elif isinstance(reading_date, datetime):
# strip time part
return datetime(reading_date.year, reading_date.month, reading_date.day)
else:
raise ValueError("reading_date must be a date or datetime")
async def save_readings_to_db(
self, patient_id: str, reading_date, new_readings: list, created_by: str
):
"""
Async save/merge readings for a patient/date using Motor.
"""
reading_date = self.normalize_to_date(reading_date)
# Find existing document
doc = await self.vitals.find_one(
{
"patient_id": ObjectId(patient_id),
"date": reading_date,
}
)
if not doc:
# Insert new
await self.vitals.insert_one(
{
"patient_id": ObjectId(patient_id),
"date": reading_date,
"readings": new_readings,
"created_by": created_by,
}
)
else:
# Merge: update existing readings by 'name', append new ones
existing = doc.get("readings", [])
merged = existing.copy()
existing_names = {r["name"]: r for r in existing}
for r in new_readings:
if r["name"] in existing_names:
# Update existing entry
existing_names[r["name"]].update(r)
else:
# Append new entry
merged.append(r)
# For entries that were updated in-place, ensure they are in merged
merged_names = {r["name"] for r in merged}
for name, r in existing_names.items():
if name not in merged_names:
merged.append(r)
await self.vitals.update_one(
{"_id": doc["_id"]},
{"$set": {"readings": merged, "created_by": created_by}},
)
async def get_vitals_by_patient(
self, patient_id: str, fields: list[str] = None, serializable=False
) -> list:
cursor = self.vitals.find({"patient_id": ObjectId(patient_id)}).sort("date", -1)
vitals = await cursor.to_list(length=None)
if fields:
vitals = [
{field: vital[field] for field in fields if field in vital}
for vital in vitals
]
if serializable:
vitals = self.convert_to_serializable_data(data=vitals)
# print("vitals = ",vitals)
return vitals
async def get_latest_vitals_by_patient(self, patient_id: str) -> dict:
# sort by date descending and get the first record
cursor = self.vitals.find({"patient_id": ObjectId(patient_id)}).sort("date", -1)
vitals = await cursor.to_list(length=None)
if len(vitals) > 0:
vitals = vitals[0]
else:
vitals = {}
return vitals
    async def get_vitals_trends_by_patient(self, patient_id):
        """Build per-vital trend series from the patient's vitals documents.

        Unwinds every reading across all dates, groups by reading name, and
        returns a list of {test_name, trend_data, unit, test_reference_range}
        dicts shaped like the lab trend documents. Blood-pressure readings
        (values like "120/80") are split into separate systolic and diastolic
        series.
        """
        docs = await self.vitals.aggregate([
            {"$match": {"patient_id": ObjectId(patient_id)}},
            {"$unwind": "$readings"},
            {"$project": {
                "date": 1,
                "name": "$readings.name",
                "value": "$readings.value",
                "unit": "$readings.unit"
            }},
            {"$sort": {"date": 1}}
        ]).to_list(length=None)
        df = pd.DataFrame(docs)
        trend_docs = []
        # An empty DataFrame has no "name" column, so skip grouping entirely.
        if "name" in df:
            for vital_name, group in df.groupby("name"):
                if vital_name.lower() in ["bp", "blood pressure"]:
                    # Split systolic/diastolic
                    systolic, diastolic = [], []
                    for _, row in group.iterrows():
                        try:
                            sys, dia = row["value"].split("/")
                            systolic.append({"date": row["date"], "value": int(sys)})
                            diastolic.append({"date": row["date"], "value": int(dia)})
                        except Exception:
                            # Skip malformed values that are not "sys/dia" ints.
                            continue
                    # Append as two separate test series
                    trend_docs.append({
                        "test_name": "BP - Systolic",
                        "trend_data": systolic,
                        "unit": "mmHg",
                        "test_reference_range": {}
                    })
                    trend_docs.append({
                        "test_name": "BP - Diastolic",
                        "trend_data": diastolic,
                        "unit": "mmHg",
                        "test_reference_range": {}
                    })
                else:
                    # Other vitals
                    trend_docs.append({
                        "test_name": vital_name,
                        "trend_data": group[["date", "value"]].to_dict("records"),
                        # assumes the unit is constant within a vital name —
                        # TODO confirm; only the first row's unit is kept.
                        "unit": group["unit"].iloc[0],
                        "test_reference_range": {}
                    })
        # 🔑 Sort by test_name so BP series appear together
        trend_docs = sorted(trend_docs, key=lambda x: x["test_name"])
        return trend_docs