LiamKhoaLe commited on
Commit
a052739
·
1 Parent(s): 883d505

Refactor mongodb.py

Browse files
src/api/routes/chat.py CHANGED
@@ -11,7 +11,7 @@ from src.services.medical_response import generate_medical_response
11
  from src.services.summariser import summarise_title_with_nvidia
12
  from src.utils.logger import get_logger
13
 
14
- from src.data.mongodb import ensure_session
15
 
16
  logger = get_logger("CHAT_ROUTES", __name__)
17
  router = APIRouter()
 
11
  from src.services.summariser import summarise_title_with_nvidia
12
  from src.utils.logger import get_logger
13
 
14
+ from src.data import ensure_session
15
 
16
  logger = get_logger("CHAT_ROUTES", __name__)
17
  router = APIRouter()
src/api/routes/session.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
6
  from src.core.state import MedicalState, get_state
7
  from src.models.chat import SessionRequest
8
  from src.utils.logger import get_logger
9
- from src.data.mongodb import list_patient_sessions, list_session_messages, ensure_session, delete_session, delete_session_messages
10
 
11
  logger = get_logger("SESSION_ROUTES", __name__)
12
  router = APIRouter()
 
6
  from src.core.state import MedicalState, get_state
7
  from src.models.chat import SessionRequest
8
  from src.utils.logger import get_logger
9
+ from src.data import list_patient_sessions, list_session_messages, ensure_session, delete_session, delete_session_messages
10
 
11
  logger = get_logger("SESSION_ROUTES", __name__)
12
  router = APIRouter()
src/api/routes/user.py CHANGED
@@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException
5
  from src.core.state import MedicalState, get_state
6
  from src.models.user import UserProfileRequest, PatientCreateRequest, PatientUpdateRequest, DoctorCreateRequest
7
  from src.utils.logger import get_logger
8
- from src.data.mongodb import create_account, create_doctor, get_doctor_by_name, search_doctors, get_all_doctors
9
 
10
  logger = get_logger("USER_ROUTES", __name__)
11
  router = APIRouter()
@@ -81,7 +81,7 @@ async def get_user_profile(
81
  raise HTTPException(status_code=500, detail=str(e))
82
 
83
  # -------------------- Patient APIs --------------------
84
- from src.data.mongodb import get_patient_by_id, create_patient, update_patient_profile, search_patients
85
 
86
  @router.get("/patients/search")
87
  async def search_patients_route(q: str, limit: int = 20):
 
5
  from src.core.state import MedicalState, get_state
6
  from src.models.user import UserProfileRequest, PatientCreateRequest, PatientUpdateRequest, DoctorCreateRequest
7
  from src.utils.logger import get_logger
8
+ from src.data import create_account, create_doctor, get_doctor_by_name, search_doctors, get_all_doctors
9
 
10
  logger = get_logger("USER_ROUTES", __name__)
11
  router = APIRouter()
 
81
  raise HTTPException(status_code=500, detail=str(e))
82
 
83
  # -------------------- Patient APIs --------------------
84
+ from src.data import get_patient_by_id, create_patient, update_patient_profile, search_patients
85
 
86
  @router.get("/patients/search")
87
  async def search_patients_route(q: str, limit: int = 20):
src/core/memory/history.py CHANGED
@@ -11,7 +11,7 @@ from src.services.summariser import (summarise_qa_with_gemini,
11
  summarise_qa_with_nvidia)
12
  from src.utils.embeddings import EmbeddingClient
13
  from src.utils.logger import get_logger
14
- from src.data.mongodb import save_memory_summary, save_chat_message, ensure_session, get_recent_memory_summaries, search_memory_summaries_semantic
15
 
16
  logger = get_logger("MED_HISTORY")
17
 
 
11
  summarise_qa_with_nvidia)
12
  from src.utils.embeddings import EmbeddingClient
13
  from src.utils.logger import get_logger
14
+ from src.data import save_memory_summary, save_chat_message, ensure_session, get_recent_memory_summaries, search_memory_summaries_semantic
15
 
16
  logger = get_logger("MED_HISTORY")
17
 
src/data/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/__init__.py
2
+ """
3
+ Data layer for MongoDB operations.
4
+ Organized into specialized modules for different data types.
5
+ """
6
+
7
+ from .connection import get_database, get_collection, close_connection
8
+ from .session import *
9
+ from .user import *
10
+ from .message import *
11
+ from .patient import *
12
+ from .medical import *
13
+ from .utils import create_index, backup_collection
14
+
15
+ __all__ = [
16
+ # Connection
17
+ 'get_database',
18
+ 'get_collection',
19
+ 'close_connection',
20
+ # Session functions
21
+ 'create_chat_session',
22
+ 'get_user_sessions',
23
+ 'ensure_session',
24
+ 'list_patient_sessions',
25
+ 'delete_session',
26
+ 'delete_session_messages',
27
+ 'delete_old_sessions',
28
+ # User functions
29
+ 'create_account',
30
+ 'update_account',
31
+ 'get_account_frame',
32
+ 'create_doctor',
33
+ 'get_doctor_by_name',
34
+ 'search_doctors',
35
+ 'get_all_doctors',
36
+ # Message functions
37
+ 'add_message',
38
+ 'get_session_messages',
39
+ 'save_chat_message',
40
+ 'list_session_messages',
41
+ # Patient functions
42
+ 'get_patient_by_id',
43
+ 'create_patient',
44
+ 'update_patient_profile',
45
+ 'search_patients',
46
+ # Medical functions
47
+ 'create_medical_record',
48
+ 'get_user_medical_records',
49
+ 'save_memory_summary',
50
+ 'get_recent_memory_summaries',
51
+ 'search_memory_summaries_semantic',
52
+ # Utility functions
53
+ 'create_index',
54
+ 'backup_collection',
55
+ ]
src/data/connection.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/connection.py
2
+ """
3
+ MongoDB connection management and base database operations.
4
+ """
5
+
6
+ import os
7
+ from typing import Any
8
+
9
+ from pymongo import MongoClient
10
+ from pymongo.collection import Collection
11
+ from pymongo.database import Database
12
+
13
+ from src.utils.logger import get_logger
14
+
15
+ logger = get_logger("MONGO")
16
+
17
+ # Global client instance
18
+ _mongo_client: MongoClient | None = None
19
+
20
+ # Collection Names
21
+ ACCOUNTS_COLLECTION = "accounts"
22
+ CHAT_SESSIONS_COLLECTION = "chat_sessions"
23
+ CHAT_MESSAGES_COLLECTION = "chat_messages"
24
+ MEDICAL_RECORDS_COLLECTION = "medical_records"
25
+ MEDICAL_MEMORY_COLLECTION = "medical_memory"
26
+ PATIENTS_COLLECTION = "patients"
27
+
28
+
29
+ def get_database() -> Database:
30
+ """Get database instance with connection management"""
31
+ global _mongo_client
32
+ if _mongo_client is None:
33
+ CONNECTION_STRING = os.getenv("MONGO_USER", "mongodb://127.0.0.1:27017/") # fall back to local host if no user is provided
34
+ try:
35
+ logger.info("Initializing MongoDB connection")
36
+ _mongo_client = MongoClient(CONNECTION_STRING)
37
+ except Exception as e:
38
+ logger.error(f"Failed to connect to MongoDB: {str(e)}")
39
+ # Pass the error down, code that calls this function should handle it
40
+ raise e
41
+ db_name = os.getenv("USER_DB", "medicaldiagnosissystem")
42
+ return _mongo_client[db_name]
43
+
44
+
45
+ def close_connection():
46
+ """Close MongoDB connection"""
47
+ global _mongo_client
48
+ if _mongo_client is not None:
49
+ # Close the connection and reset the client
50
+ _mongo_client.close()
51
+ _mongo_client = None
52
+
53
+
54
+ def get_collection(name: str, /) -> Collection:
55
+ """Get a MongoDB collection by name"""
56
+ db = get_database()
57
+ return db.get_collection(name)
src/data/medical/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/medical/__init__.py
2
+ """
3
+ Medical records and memory management operations for MongoDB.
4
+ """
5
+
6
+ from .operations import (
7
+ create_medical_record,
8
+ get_user_medical_records,
9
+ save_memory_summary,
10
+ get_recent_memory_summaries,
11
+ search_memory_summaries_semantic,
12
+ )
13
+
14
+ __all__ = [
15
+ 'create_medical_record',
16
+ 'get_user_medical_records',
17
+ 'save_memory_summary',
18
+ 'get_recent_memory_summaries',
19
+ 'search_memory_summaries_semantic',
20
+ ]
src/data/medical/operations.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/medical/operations.py
2
+ """
3
+ Medical records and memory management operations for MongoDB.
4
+ """
5
+
6
+ from datetime import datetime, timezone
7
+ from typing import Any
8
+
9
+ from pymongo import ASCENDING, DESCENDING
10
+
11
+ from ..connection import get_collection, MEDICAL_RECORDS_COLLECTION, MEDICAL_MEMORY_COLLECTION
12
+ from src.utils.logger import get_logger
13
+
14
+ logger = get_logger("MEDICAL_OPS")
15
+
16
+
17
+ def create_medical_record(
18
+ record_data: dict[str, Any],
19
+ /,
20
+ *,
21
+ collection_name: str = MEDICAL_RECORDS_COLLECTION
22
+ ) -> str:
23
+ """Create a new medical record"""
24
+ collection = get_collection(collection_name)
25
+ now = datetime.now(timezone.utc)
26
+ record_data["created_at"] = now
27
+ record_data["updated_at"] = now
28
+ result = collection.insert_one(record_data)
29
+ return str(result.inserted_id)
30
+
31
+
32
+ def get_user_medical_records(
33
+ user_id: str,
34
+ /,
35
+ *,
36
+ collection_name: str = MEDICAL_RECORDS_COLLECTION
37
+ ) -> list[dict[str, Any]]:
38
+ """Get medical records for a specific user"""
39
+ collection = get_collection(collection_name)
40
+ return list(collection.find({"user_id": user_id}).sort("created_at", ASCENDING))
41
+
42
+
43
+ def save_memory_summary(
44
+ *,
45
+ patient_id: str,
46
+ doctor_id: str,
47
+ summary: str,
48
+ embedding: list[float] | None = None,
49
+ created_at: datetime | None = None,
50
+ collection_name: str = MEDICAL_MEMORY_COLLECTION
51
+ ) -> str:
52
+ collection = get_collection(collection_name)
53
+ ts = created_at or datetime.now(timezone.utc)
54
+ doc = {
55
+ "patient_id": patient_id,
56
+ "doctor_id": doctor_id,
57
+ "summary": summary,
58
+ "created_at": ts
59
+ }
60
+ if embedding is not None:
61
+ doc["embedding"] = embedding
62
+ result = collection.insert_one(doc)
63
+ return str(result.inserted_id)
64
+
65
+
66
+ def get_recent_memory_summaries(
67
+ patient_id: str,
68
+ /,
69
+ *,
70
+ limit: int = 20,
71
+ collection_name: str = MEDICAL_MEMORY_COLLECTION
72
+ ) -> list[str]:
73
+ collection = get_collection(collection_name)
74
+ docs = list(collection.find({"patient_id": patient_id}).sort("created_at", DESCENDING).limit(limit))
75
+ return [d.get("summary", "") for d in docs]
76
+
77
+
78
+ def search_memory_summaries_semantic(
79
+ patient_id: str,
80
+ query_embedding: list[float],
81
+ /,
82
+ *,
83
+ limit: int = 5,
84
+ similarity_threshold: float = 0.5, # >= 50% semantic similarity
85
+ collection_name: str = MEDICAL_MEMORY_COLLECTION
86
+ ) -> list[dict[str, Any]]:
87
+ """
88
+ Search memory summaries using semantic similarity with embeddings.
89
+ Returns list of {summary, similarity_score, created_at} sorted by similarity.
90
+ """
91
+ collection = get_collection(collection_name)
92
+
93
+ # Get all summaries with embeddings for this patient
94
+ docs = list(collection.find({
95
+ "patient_id": patient_id,
96
+ "embedding": {"$exists": True}
97
+ }))
98
+
99
+ if not docs:
100
+ return []
101
+
102
+ # Calculate similarities
103
+ import numpy as np
104
+ query_vec = np.array(query_embedding, dtype="float32")
105
+ results = []
106
+
107
+ for doc in docs:
108
+ embedding = doc.get("embedding")
109
+ if not embedding:
110
+ continue
111
+
112
+ # Calculate cosine similarity
113
+ doc_vec = np.array(embedding, dtype="float32")
114
+ dot_product = np.dot(query_vec, doc_vec)
115
+ norm_query = np.linalg.norm(query_vec)
116
+ norm_doc = np.linalg.norm(doc_vec)
117
+
118
+ if norm_query == 0 or norm_doc == 0:
119
+ similarity = 0.0
120
+ else:
121
+ similarity = float(dot_product / (norm_query * norm_doc))
122
+
123
+ if similarity >= similarity_threshold:
124
+ results.append({
125
+ "summary": doc.get("summary", ""),
126
+ "similarity_score": similarity,
127
+ "created_at": doc.get("created_at"),
128
+ "session_id": doc.get("session_id") # if we add this field later
129
+ })
130
+
131
+ # Sort by similarity (highest first) and return top results
132
+ results.sort(key=lambda x: x["similarity_score"], reverse=True)
133
+ return results[:limit]
src/data/message/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/message/__init__.py
2
+ """
3
+ Message management operations for MongoDB.
4
+ """
5
+
6
+ from .operations import (
7
+ add_message,
8
+ get_session_messages,
9
+ save_chat_message,
10
+ list_session_messages,
11
+ )
12
+
13
+ __all__ = [
14
+ 'add_message',
15
+ 'get_session_messages',
16
+ 'save_chat_message',
17
+ 'list_session_messages',
18
+ ]
src/data/message/operations.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/message/operations.py
2
+ """
3
+ Message management operations for MongoDB.
4
+ """
5
+
6
+ from datetime import datetime, timezone
7
+ from typing import Any
8
+
9
+ from bson import ObjectId
10
+ from pymongo import ASCENDING
11
+
12
+ from ..connection import get_collection, CHAT_SESSIONS_COLLECTION, CHAT_MESSAGES_COLLECTION
13
+ from src.utils.logger import get_logger
14
+
15
+ logger = get_logger("MESSAGE_OPS")
16
+
17
+
18
+ def add_message(
19
+ session_id: str,
20
+ message_data: dict[str, Any],
21
+ /,
22
+ *,
23
+ collection_name: str = CHAT_SESSIONS_COLLECTION
24
+ ) -> str | None:
25
+ """Add a message to a chat session"""
26
+ collection = get_collection(collection_name)
27
+
28
+ # Verify session exists first
29
+ session = collection.find_one({
30
+ "$or": [
31
+ {"_id": session_id},
32
+ {"_id": ObjectId(session_id) if ObjectId.is_valid(session_id) else None}
33
+ ]
34
+ })
35
+ if not session:
36
+ logger.error(f"Failed to add message - session not found: {session_id}")
37
+ raise ValueError(f"Chat session not found: {session_id}")
38
+
39
+ now = datetime.now(timezone.utc)
40
+ message_data["timestamp"] = now
41
+ result = collection.update_one(
42
+ {"_id": session["_id"]},
43
+ {
44
+ "$push": {"messages": message_data},
45
+ "$set": {"updated_at": now}
46
+ }
47
+ )
48
+ return str(session_id) if result.modified_count > 0 else None
49
+
50
+
51
+ def get_session_messages(
52
+ session_id: str,
53
+ /,
54
+ limit: int | None = None,
55
+ *,
56
+ collection_name: str = CHAT_SESSIONS_COLLECTION
57
+ ) -> list[dict[str, Any]]:
58
+ """Get messages from a specific chat session"""
59
+ collection = get_collection(collection_name)
60
+ pipeline = [
61
+ {"$match": {"_id": session_id}},
62
+ {"$unwind": "$messages"},
63
+ {"$sort": {"messages.timestamp": -1}}
64
+ ]
65
+ if limit:
66
+ pipeline.append({"$limit": limit})
67
+ return [doc["messages"] for doc in collection.aggregate(pipeline)]
68
+
69
+
70
+ def save_chat_message(
71
+ *,
72
+ session_id: str,
73
+ patient_id: str,
74
+ doctor_id: str,
75
+ role: str,
76
+ content: str,
77
+ timestamp: datetime | None = None,
78
+ collection_name: str = CHAT_MESSAGES_COLLECTION
79
+ ) -> ObjectId:
80
+ collection = get_collection(collection_name)
81
+ ts = timestamp or datetime.now(timezone.utc)
82
+ doc = {
83
+ "session_id": session_id,
84
+ "patient_id": patient_id,
85
+ "doctor_id": doctor_id,
86
+ "role": role,
87
+ "content": content,
88
+ "timestamp": ts,
89
+ "created_at": ts
90
+ }
91
+ result = collection.insert_one(doc)
92
+ return result.inserted_id
93
+
94
+
95
+ def list_session_messages(
96
+ session_id: str,
97
+ /,
98
+ *,
99
+ patient_id: str | None = None,
100
+ limit: int | None = None,
101
+ collection_name: str = CHAT_MESSAGES_COLLECTION
102
+ ) -> list[dict[str, Any]]:
103
+ collection = get_collection(collection_name)
104
+
105
+ # First verify the session belongs to the patient
106
+ if patient_id:
107
+ session_collection = get_collection(CHAT_SESSIONS_COLLECTION)
108
+ session = session_collection.find_one({
109
+ "session_id": session_id,
110
+ "patient_id": patient_id
111
+ })
112
+ if not session:
113
+ logger.warning(f"Session {session_id} not found for patient {patient_id}")
114
+ return []
115
+
116
+ # Query messages with patient_id filter if provided
117
+ query = {"session_id": session_id}
118
+ if patient_id:
119
+ query["patient_id"] = patient_id
120
+
121
+ cursor = collection.find(query).sort("timestamp", ASCENDING)
122
+ if limit is not None:
123
+ cursor = cursor.limit(limit)
124
+ return list(cursor)
src/data/{mongodb.py → mongodb.py.backup} RENAMED
@@ -2,7 +2,7 @@
2
 
3
  """
4
  Interface for mongodb using pymongo.
5
- Current code is simply a proof of concept and is not ready for implementation.
6
  """
7
 
8
  from datetime import datetime, timedelta, timezone
 
2
 
3
  """
4
  Interface for mongodb using pymongo.
5
+ This file has been refactored.
6
  """
7
 
8
  from datetime import datetime, timedelta, timezone
src/data/patient/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/patient/__init__.py
2
+ """
3
+ Patient management operations for MongoDB.
4
+ """
5
+
6
+ from .operations import (
7
+ get_patient_by_id,
8
+ create_patient,
9
+ update_patient_profile,
10
+ search_patients,
11
+ )
12
+
13
+ __all__ = [
14
+ 'get_patient_by_id',
15
+ 'create_patient',
16
+ 'update_patient_profile',
17
+ 'search_patients',
18
+ ]
src/data/patient/operations.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/patient/operations.py
2
+ """
3
+ Patient management operations for MongoDB.
4
+ """
5
+
6
+ import re
7
+ from datetime import datetime, timezone
8
+ from typing import Any
9
+
10
+ from pymongo import ASCENDING
11
+
12
+ from ..connection import get_collection, PATIENTS_COLLECTION
13
+ from src.utils.logger import get_logger
14
+
15
+ logger = get_logger("PATIENT_OPS")
16
+
17
+
18
+ def _generate_patient_id() -> str:
19
+ """Generate zero-padded 8-digit ID"""
20
+ import random
21
+ return f"{random.randint(0, 99999999):08d}"
22
+
23
+
24
+ def get_patient_by_id(patient_id: str) -> dict[str, Any] | None:
25
+ collection = get_collection(PATIENTS_COLLECTION)
26
+ return collection.find_one({"patient_id": patient_id})
27
+
28
+
29
+ def create_patient(
30
+ *,
31
+ name: str,
32
+ age: int,
33
+ sex: str,
34
+ address: str | None = None,
35
+ phone: str | None = None,
36
+ email: str | None = None,
37
+ medications: list[str] | None = None,
38
+ past_assessment_summary: str | None = None,
39
+ assigned_doctor_id: str | None = None
40
+ ) -> dict[str, Any]:
41
+ collection = get_collection(PATIENTS_COLLECTION)
42
+ now = datetime.now(timezone.utc)
43
+ # Ensure unique 8-digit id
44
+ for _ in range(10):
45
+ pid = _generate_patient_id()
46
+ if not collection.find_one({"patient_id": pid}):
47
+ break
48
+ else:
49
+ raise RuntimeError("Failed to generate unique patient ID")
50
+ doc = {
51
+ "patient_id": pid,
52
+ "name": name,
53
+ "age": age,
54
+ "sex": sex,
55
+ "address": address,
56
+ "phone": phone,
57
+ "email": email,
58
+ "medications": medications or [],
59
+ "past_assessment_summary": past_assessment_summary or "",
60
+ "assigned_doctor_id": assigned_doctor_id,
61
+ "created_at": now,
62
+ "updated_at": now
63
+ }
64
+ collection.insert_one(doc)
65
+ return doc
66
+
67
+
68
+ def update_patient_profile(patient_id: str, updates: dict[str, Any]) -> int:
69
+ collection = get_collection(PATIENTS_COLLECTION)
70
+ updates["updated_at"] = datetime.now(timezone.utc)
71
+ result = collection.update_one({"patient_id": patient_id}, {"$set": updates})
72
+ return result.modified_count
73
+
74
+
75
+ def search_patients(query: str, limit: int = 10) -> list[dict[str, Any]]:
76
+ """Search patients by name (case-insensitive starts-with/contains) or partial patient_id."""
77
+ collection = get_collection(PATIENTS_COLLECTION)
78
+ if not query:
79
+ return []
80
+
81
+ logger.info(f"Searching patients with query: '{query}', limit: {limit}")
82
+
83
+ # Build a regex for name search and patient_id partial match
84
+ pattern = re.compile(re.escape(query), re.IGNORECASE)
85
+
86
+ try:
87
+ cursor = collection.find({
88
+ "$or": [
89
+ {"name": {"$regex": pattern}},
90
+ {"patient_id": {"$regex": pattern}}
91
+ ]
92
+ }).sort("name", ASCENDING).limit(limit)
93
+ results = []
94
+ for p in cursor:
95
+ p["_id"] = str(p.get("_id")) if p.get("_id") else None
96
+ results.append(p)
97
+ logger.info(f"Found {len(results)} patients matching query")
98
+ return results
99
+ except Exception as e:
100
+ logger.error(f"Error in search_patients: {e}")
101
+ return []
src/data/session/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/session/__init__.py
2
+ """
3
+ Session management operations for MongoDB.
4
+ """
5
+
6
+ from .operations import (
7
+ create_chat_session,
8
+ get_user_sessions,
9
+ ensure_session,
10
+ list_patient_sessions,
11
+ delete_session,
12
+ delete_session_messages,
13
+ delete_old_sessions,
14
+ )
15
+
16
+ __all__ = [
17
+ 'create_chat_session',
18
+ 'get_user_sessions',
19
+ 'ensure_session',
20
+ 'list_patient_sessions',
21
+ 'delete_session',
22
+ 'delete_session_messages',
23
+ 'delete_old_sessions',
24
+ ]
src/data/session/operations.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/session/operations.py
2
+ """
3
+ Session management operations for MongoDB.
4
+ """
5
+
6
+ from datetime import datetime, timedelta, timezone
7
+ from typing import Any
8
+
9
+ from bson import ObjectId
10
+ from pymongo import ASCENDING, DESCENDING
11
+
12
+ from ..connection import get_collection, CHAT_SESSIONS_COLLECTION, CHAT_MESSAGES_COLLECTION
13
+ from src.utils.logger import get_logger
14
+
15
+ logger = get_logger("SESSION_OPS")
16
+
17
+
18
+ def create_chat_session(
19
+ session_data: dict[str, Any],
20
+ /,
21
+ *,
22
+ collection_name: str = CHAT_SESSIONS_COLLECTION
23
+ ) -> str:
24
+ """Create a new chat session"""
25
+ collection = get_collection(collection_name)
26
+ now = datetime.now(timezone.utc)
27
+ session_data["created_at"] = now
28
+ session_data["updated_at"] = now
29
+ if "_id" not in session_data:
30
+ session_data["_id"] = str(ObjectId())
31
+ result = collection.insert_one(session_data)
32
+ return str(result.inserted_id)
33
+
34
+
35
+ def get_user_sessions(
36
+ user_id: str,
37
+ /,
38
+ limit: int = 20,
39
+ *,
40
+ collection_name: str = CHAT_SESSIONS_COLLECTION
41
+ ) -> list[dict[str, Any]]:
42
+ """Get chat sessions for a specific user"""
43
+ collection = get_collection(collection_name)
44
+ return list(collection.find(
45
+ {"user_id": user_id}
46
+ ).sort("updated_at", DESCENDING).limit(limit))
47
+
48
+
49
+ def ensure_session(
50
+ *,
51
+ session_id: str,
52
+ patient_id: str,
53
+ doctor_id: str,
54
+ title: str,
55
+ last_activity: datetime | None = None,
56
+ collection_name: str = CHAT_SESSIONS_COLLECTION
57
+ ) -> None:
58
+ collection = get_collection(collection_name)
59
+ now = datetime.now(timezone.utc)
60
+ collection.update_one(
61
+ {"session_id": session_id},
62
+ {"$set": {
63
+ "session_id": session_id,
64
+ "patient_id": patient_id,
65
+ "doctor_id": doctor_id,
66
+ "title": title,
67
+ "last_activity": (last_activity or now),
68
+ "updated_at": now
69
+ }, "$setOnInsert": {"created_at": now}},
70
+ upsert=True
71
+ )
72
+
73
+
74
+ def list_patient_sessions(
75
+ patient_id: str,
76
+ /,
77
+ *,
78
+ collection_name: str = CHAT_SESSIONS_COLLECTION
79
+ ) -> list[dict[str, Any]]:
80
+ collection = get_collection(collection_name)
81
+ sessions = list(collection.find({"patient_id": patient_id}).sort("last_activity", DESCENDING))
82
+ # Convert ObjectId to string for JSON serialization
83
+ for session in sessions:
84
+ if "_id" in session:
85
+ session["_id"] = str(session["_id"])
86
+ return sessions
87
+
88
+
89
+ def delete_session(
90
+ session_id: str,
91
+ /,
92
+ *,
93
+ collection_name: str = CHAT_SESSIONS_COLLECTION
94
+ ) -> bool:
95
+ """Delete a chat session from MongoDB"""
96
+ collection = get_collection(collection_name)
97
+ result = collection.delete_one({"session_id": session_id})
98
+ return result.deleted_count > 0
99
+
100
+
101
+ def delete_session_messages(
102
+ session_id: str,
103
+ /,
104
+ *,
105
+ collection_name: str = CHAT_MESSAGES_COLLECTION
106
+ ) -> int:
107
+ """Delete all messages for a session from MongoDB"""
108
+ collection = get_collection(collection_name)
109
+ result = collection.delete_many({"session_id": session_id})
110
+ return result.deleted_count
111
+
112
+
113
+ def delete_old_sessions(
114
+ days: int = 30,
115
+ *,
116
+ collection_name: str = CHAT_SESSIONS_COLLECTION
117
+ ) -> int:
118
+ """Delete chat sessions older than specified days"""
119
+ collection = get_collection(collection_name)
120
+ cutoff = datetime.now(timezone.utc) - timedelta(days=days)
121
+ result = collection.delete_many({
122
+ "updated_at": {"$lt": cutoff}
123
+ })
124
+ if result.deleted_count > 0:
125
+ logger.info(f"Deleted {result.deleted_count} old sessions (>{days} days)")
126
+ return result.deleted_count
src/data/user/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/user/__init__.py
2
+ """
3
+ User management operations for MongoDB.
4
+ """
5
+
6
+ from .operations import (
7
+ create_account,
8
+ update_account,
9
+ get_account_frame,
10
+ create_doctor,
11
+ get_doctor_by_name,
12
+ search_doctors,
13
+ get_all_doctors,
14
+ )
15
+
16
+ __all__ = [
17
+ 'create_account',
18
+ 'update_account',
19
+ 'get_account_frame',
20
+ 'create_doctor',
21
+ 'get_doctor_by_name',
22
+ 'search_doctors',
23
+ 'get_all_doctors',
24
+ ]
src/data/user/operations.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/user/operations.py
2
+ """
3
+ User management operations for MongoDB.
4
+ """
5
+
6
+ from datetime import datetime, timezone
7
+ from typing import Any
8
+
9
+ import re
10
+ from pandas import DataFrame
11
+ from pymongo import ASCENDING
12
+ from pymongo.errors import DuplicateKeyError
13
+
14
+ from ..connection import get_collection, ACCOUNTS_COLLECTION
15
+ from src.utils.logger import get_logger
16
+
17
+ logger = get_logger("USER_OPS")
18
+
19
+
20
+ def get_account_frame(
21
+ *,
22
+ collection_name: str = ACCOUNTS_COLLECTION
23
+ ) -> DataFrame:
24
+ """Get accounts as a pandas DataFrame"""
25
+ return DataFrame(get_collection(collection_name).find())
26
+
27
+
28
+ def create_account(
29
+ user_data: dict[str, Any],
30
+ /,
31
+ *,
32
+ collection_name: str = ACCOUNTS_COLLECTION
33
+ ) -> str:
34
+ """Create a new user account"""
35
+ collection = get_collection(collection_name)
36
+ now = datetime.now(timezone.utc)
37
+ user_data["created_at"] = now
38
+ user_data["updated_at"] = now
39
+ try:
40
+ result = collection.insert_one(user_data)
41
+ logger.info(f"Created new account: {result.inserted_id}")
42
+ return str(result.inserted_id)
43
+ except DuplicateKeyError as e:
44
+ logger.error(f"Failed to create account - duplicate key: {str(e)}")
45
+ raise DuplicateKeyError(f"Account already exists: {e}") from e
46
+
47
+
48
+ def update_account(
49
+ user_id: str,
50
+ updates: dict[str, Any],
51
+ /,
52
+ *,
53
+ collection_name: str = ACCOUNTS_COLLECTION
54
+ ) -> bool:
55
+ """Update an existing user account"""
56
+ collection = get_collection(collection_name)
57
+ updates["updated_at"] = datetime.now(timezone.utc)
58
+ result = collection.update_one(
59
+ {"_id": user_id},
60
+ {"$set": updates}
61
+ )
62
+ return result.modified_count > 0
63
+
64
+
65
+ def create_doctor(
66
+ *,
67
+ name: str,
68
+ role: str | None = None,
69
+ specialty: str | None = None,
70
+ medical_roles: list[str] | None = None
71
+ ) -> str:
72
+ """Create a new doctor profile"""
73
+ collection = get_collection(ACCOUNTS_COLLECTION)
74
+ now = datetime.now(timezone.utc)
75
+ doctor_doc = {
76
+ "name": name,
77
+ "role": role,
78
+ "specialty": specialty,
79
+ "medical_roles": medical_roles or [],
80
+ "created_at": now,
81
+ "updated_at": now
82
+ }
83
+ try:
84
+ result = collection.insert_one(doctor_doc)
85
+ logger.info(f"Created new doctor: {name} with id {result.inserted_id}")
86
+ return str(result.inserted_id)
87
+ except Exception as e:
88
+ logger.error(f"Error creating doctor: {e}")
89
+ raise e
90
+
91
+
92
+ def get_doctor_by_name(name: str) -> dict[str, Any] | None:
93
+ """Get doctor by name from accounts collection"""
94
+ collection = get_collection(ACCOUNTS_COLLECTION)
95
+ doctor = collection.find_one({
96
+ "name": name,
97
+ "role": {"$in": ["Doctor", "Healthcare Prof", "General Practitioner", "Cardiologist", "Pediatrician", "Neurologist", "Dermatologist"]}
98
+ })
99
+ if doctor:
100
+ doctor["_id"] = str(doctor.get("_id")) if doctor.get("_id") else None
101
+ return doctor
102
+
103
+
104
+ def search_doctors(query: str, limit: int = 10) -> list[dict[str, Any]]:
105
+ """Search doctors by name (case-insensitive contains) from accounts collection"""
106
+ collection = get_collection(ACCOUNTS_COLLECTION)
107
+ if not query:
108
+ return []
109
+
110
+ logger.info(f"Searching doctors with query: '{query}', limit: {limit}")
111
+
112
+ # Build a regex for name search
113
+ pattern = re.compile(re.escape(query), re.IGNORECASE)
114
+
115
+ try:
116
+ cursor = collection.find({
117
+ "name": {"$regex": pattern},
118
+ "role": {"$in": ["Doctor", "Healthcare Prof", "General Practitioner", "Cardiologist", "Pediatrician", "Neurologist", "Dermatologist"]}
119
+ }).sort("name", ASCENDING).limit(limit)
120
+ results = []
121
+ for d in cursor:
122
+ d["_id"] = str(d.get("_id")) if d.get("_id") else None
123
+ results.append(d)
124
+ logger.info(f"Found {len(results)} doctors matching query")
125
+ return results
126
+ except Exception as e:
127
+ logger.error(f"Error in search_doctors: {e}")
128
+ return []
129
+
130
+
131
+ def get_all_doctors(limit: int = 50) -> list[dict[str, Any]]:
132
+ """Get all doctors with optional limit from accounts collection"""
133
+ collection = get_collection(ACCOUNTS_COLLECTION)
134
+ try:
135
+ cursor = collection.find({
136
+ "role": {"$in": ["Doctor", "Healthcare Prof", "General Practitioner", "Cardiologist", "Pediatrician", "Neurologist", "Dermatologist"]}
137
+ }).sort("name", ASCENDING).limit(limit)
138
+ results = []
139
+ for d in cursor:
140
+ d["_id"] = str(d.get("_id")) if d.get("_id") else None
141
+ results.append(d)
142
+ logger.info(f"Retrieved {len(results)} doctors")
143
+ return results
144
+ except Exception as e:
145
+ logger.error(f"Error getting all doctors: {e}")
146
+ return []
src/data/utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/utils.py
2
+ """
3
+ Utility functions for MongoDB operations.
4
+ """
5
+
6
+ from datetime import datetime, timezone
7
+ from typing import Any
8
+
9
+ from pymongo import ASCENDING
10
+
11
+ from .connection import get_collection, get_database
12
+ from src.utils.logger import get_logger
13
+
14
+ logger = get_logger("MONGO_UTILS")
15
+
16
+
17
+ def create_index(
18
+ collection_name: str,
19
+ field_name: str,
20
+ /,
21
+ unique: bool = False
22
+ ) -> None:
23
+ """Create an index on a collection"""
24
+ collection = get_collection(collection_name)
25
+ collection.create_index([(field_name, ASCENDING)], unique=unique)
26
+
27
+
28
+ def backup_collection(collection_name: str) -> str:
29
+ """Create a backup of a collection"""
30
+ collection = get_collection(collection_name)
31
+ backup_name = f"{collection_name}_backup_{datetime.now(timezone.utc).strftime('%Y%m%d')}"
32
+ db = get_database()
33
+
34
+ # Drop existing backup if it exists
35
+ if backup_name in db.list_collection_names():
36
+ logger.info(f"Removing existing backup: {backup_name}")
37
+ db.drop_collection(backup_name)
38
+
39
+ db.create_collection(backup_name)
40
+ backup = db[backup_name]
41
+
42
+ doc_count = 0
43
+ for doc in collection.find():
44
+ backup.insert_one(doc)
45
+ doc_count += 1
46
+
47
+ logger.info(f"Created backup {backup_name} with {doc_count} documents")
48
+ return backup_name