LiamKhoaLe commited on
Commit
4159427
·
1 Parent(s): 03e916b

Update the app connection to MongoDB user storage + chat history LTM/STM. Distinguish patient/doctor roles on chat session

Browse files
src/api/routes/chat.py CHANGED
@@ -11,6 +11,8 @@ from src.services.medical_response import generate_medical_response
11
  from src.services.summariser import summarise_title_with_nvidia
12
  from src.utils.logger import get_logger
13
 
 
 
14
  logger = get_logger("CHAT_ROUTES", __name__)
15
  router = APIRouter()
16
 
@@ -19,14 +21,16 @@ async def chat_endpoint(
19
  request: ChatRequest,
20
  state: MedicalState = Depends(get_state)
21
  ):
22
- """Handle chat messages and generate medical responses"""
 
 
23
  start_time = time.time()
24
 
25
  try:
26
  logger.info(f"Chat request from user {request.user_id} in session {request.session_id}")
27
  logger.info(f"Message: {request.message[:100]}...") # Log first 100 chars of message
28
 
29
- # Get or create user profile
30
  user_profile = state.memory_system.get_user(request.user_id)
31
  if not user_profile:
32
  state.memory_system.create_user(request.user_id, request.user_role or "Anonymous")
@@ -37,7 +41,7 @@ async def chat_endpoint(
37
  else:
38
  logger.warning("Failued to retrieve user")
39
 
40
- # Get or create session
41
  session = state.memory_system.get_session(request.session_id)
42
  if not session:
43
  session_id = state.memory_system.create_session(request.user_id, request.title or "New Chat")
@@ -45,11 +49,15 @@ async def chat_endpoint(
45
  session = state.memory_system.get_session(session_id)
46
  logger.info(f"Created new session: {session_id}")
47
 
48
- # Get medical context from memory
 
 
 
49
  medical_context = state.history_manager.get_conversation_context(
50
  request.user_id,
51
  request.session_id,
52
- request.message
 
53
  )
54
 
55
  # Generate response using Gemini AI
@@ -70,7 +78,10 @@ async def chat_endpoint(
70
  request.message,
71
  response,
72
  state.gemini_rotator,
73
- state.nvidia_rotator
 
 
 
74
  )
75
  except Exception as e:
76
  logger.warning(f"Failed to process medical exchange: {e}")
 
11
  from src.services.summariser import summarise_title_with_nvidia
12
  from src.utils.logger import get_logger
13
 
14
+ from src.data.mongodb import ensure_session
15
+
16
  logger = get_logger("CHAT_ROUTES", __name__)
17
  router = APIRouter()
18
 
 
21
  request: ChatRequest,
22
  state: MedicalState = Depends(get_state)
23
  ):
24
+ """
25
+ Process a chat message, generate response, and persist short-term cache + long-term Mongo.
26
+ """
27
  start_time = time.time()
28
 
29
  try:
30
  logger.info(f"Chat request from user {request.user_id} in session {request.session_id}")
31
  logger.info(f"Message: {request.message[:100]}...") # Log first 100 chars of message
32
 
33
+ # Get or create user profile (doctor as current user profile)
34
  user_profile = state.memory_system.get_user(request.user_id)
35
  if not user_profile:
36
  state.memory_system.create_user(request.user_id, request.user_role or "Anonymous")
 
41
  else:
42
  logger.warning("Failued to retrieve user")
43
 
44
+ # Get or create session (cache)
45
  session = state.memory_system.get_session(request.session_id)
46
  if not session:
47
  session_id = state.memory_system.create_session(request.user_id, request.title or "New Chat")
 
49
  session = state.memory_system.get_session(session_id)
50
  logger.info(f"Created new session: {session_id}")
51
 
52
+ # Ensure session exists in Mongo with patient/doctor context
53
+ ensure_session(session_id=request.session_id, patient_id=request.patient_id, doctor_id=request.doctor_id, title=request.title or "New Chat", last_activity=datetime.now(timezone.utc))
54
+
55
+ # Get medical context from memory (short-term) + Mongo long-term
56
  medical_context = state.history_manager.get_conversation_context(
57
  request.user_id,
58
  request.session_id,
59
+ request.message,
60
+ patient_id=request.patient_id
61
  )
62
 
63
  # Generate response using Gemini AI
 
78
  request.message,
79
  response,
80
  state.gemini_rotator,
81
+ state.nvidia_rotator,
82
+ patient_id=request.patient_id,
83
+ doctor_id=request.doctor_id,
84
+ session_title=request.title or "New Chat"
85
  )
86
  except Exception as e:
87
  logger.warning(f"Failed to process medical exchange: {e}")
src/api/routes/session.py CHANGED
@@ -6,6 +6,7 @@ from datetime import datetime
6
  from src.core.state import MedicalState, get_state
7
  from src.models.chat import SessionRequest
8
  from src.utils.logger import get_logger
 
9
 
10
  logger = get_logger("SESSION_ROUTES", __name__)
11
  router = APIRouter()
@@ -15,9 +16,11 @@ async def create_chat_session(
15
  request: SessionRequest,
16
  state: MedicalState = Depends(get_state)
17
  ):
18
- """Create a new chat session"""
19
  try:
20
  session_id = state.memory_system.create_session(request.user_id, request.title or "New Chat")
 
 
21
  return {"session_id": session_id, "message": "Session created successfully"}
22
  except Exception as e:
23
  logger.error(f"Error creating session: {e}")
@@ -28,7 +31,7 @@ async def get_chat_session(
28
  session_id: str,
29
  state: MedicalState = Depends(get_state)
30
  ):
31
- """Get chat session details and messages"""
32
  try:
33
  session = state.memory_system.get_session(session_id)
34
  if not session:
@@ -52,6 +55,30 @@ async def get_chat_session(
52
  logger.error(f"Error getting session: {e}")
53
  raise HTTPException(status_code=500, detail=str(e))
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @router.delete("/sessions/{session_id}")
56
  async def delete_chat_session(
57
  session_id: str,
 
6
  from src.core.state import MedicalState, get_state
7
  from src.models.chat import SessionRequest
8
  from src.utils.logger import get_logger
9
+ from src.data.mongodb import list_patient_sessions, list_session_messages, ensure_session
10
 
11
  logger = get_logger("SESSION_ROUTES", __name__)
12
  router = APIRouter()
 
16
  request: SessionRequest,
17
  state: MedicalState = Depends(get_state)
18
  ):
19
+ """Create a new chat session (cache + Mongo)"""
20
  try:
21
  session_id = state.memory_system.create_session(request.user_id, request.title or "New Chat")
22
+ # Also ensure in Mongo with patient/doctor
23
+ ensure_session(session_id=session_id, patient_id=request.patient_id, doctor_id=request.doctor_id, title=request.title or "New Chat")
24
  return {"session_id": session_id, "message": "Session created successfully"}
25
  except Exception as e:
26
  logger.error(f"Error creating session: {e}")
 
31
  session_id: str,
32
  state: MedicalState = Depends(get_state)
33
  ):
34
+ """Get session from cache (for quick preview)"""
35
  try:
36
  session = state.memory_system.get_session(session_id)
37
  if not session:
 
55
  logger.error(f"Error getting session: {e}")
56
  raise HTTPException(status_code=500, detail=str(e))
57
 
58
+ @router.get("/patients/{patient_id}/sessions")
59
+ async def list_sessions_for_patient(patient_id: str):
60
+ """List sessions for a patient from Mongo"""
61
+ try:
62
+ return {"sessions": list_patient_sessions(patient_id)}
63
+ except Exception as e:
64
+ logger.error(f"Error listing sessions: {e}")
65
+ raise HTTPException(status_code=500, detail=str(e))
66
+
67
+ @router.get("/sessions/{session_id}/messages")
68
+ async def list_messages_for_session(session_id: str, limit: int | None = None):
69
+ """List messages for a session from Mongo"""
70
+ try:
71
+ msgs = list_session_messages(session_id, limit=limit)
72
+ # ensure JSON-friendly timestamps
73
+ for m in msgs:
74
+ if isinstance(m.get("timestamp"), datetime):
75
+ m["timestamp"] = m["timestamp"].isoformat()
76
+ m["_id"] = str(m["_id"]) if "_id" in m else None
77
+ return {"messages": msgs}
78
+ except Exception as e:
79
+ logger.error(f"Error listing messages: {e}")
80
+ raise HTTPException(status_code=500, detail=str(e))
81
+
82
  @router.delete("/sessions/{session_id}")
83
  async def delete_chat_session(
84
  session_id: str,
src/core/memory/history.py CHANGED
@@ -4,14 +4,16 @@ import json
4
  from typing import Any
5
 
6
  import numpy as np
 
7
 
8
  from src.services.nvidia import nvidia_chat
9
  from src.services.summariser import (summarise_qa_with_gemini,
10
  summarise_qa_with_nvidia)
11
  from src.utils.embeddings import EmbeddingClient
12
  from src.utils.logger import get_logger
 
13
 
14
- logger = get_logger("RAG", __name__)
15
 
16
  def _safe_json(s: str) -> Any:
17
  try:
@@ -91,9 +93,9 @@ class MedicalHistoryManager:
91
  self.memory = memory
92
  self.embedder = embedder
93
 
94
- async def process_medical_exchange(self, user_id: str, session_id: str, question: str, answer: str, gemini_rotator, nvidia_rotator=None) -> str:
95
  """
96
- Process a medical Q&A exchange and store it in memory
97
  """
98
  try:
99
  # Check if we have valid API keys
@@ -117,13 +119,21 @@ class MedicalHistoryManager:
117
  logger.warning(f"Failed to create AI summary: {e}")
118
  summary = f"q: {question}\na: {answer}"
119
 
120
- # Store in memory
121
- self.memory.add(user_id, summary)
 
122
 
123
- # Add to session history
124
  self.memory.add_message_to_session(session_id, "user", question)
125
  self.memory.add_message_to_session(session_id, "assistant", answer)
126
 
 
 
 
 
 
 
 
127
  # Update session title if it's the first message
128
  session = self.memory.get_session(session_id)
129
  if session and len(session.messages) == 2: # Just user + assistant
@@ -137,18 +147,50 @@ class MedicalHistoryManager:
137
  logger.error(f"Error processing medical exchange: {e}")
138
  # Fallback: store without summary
139
  summary = f"q: {question}\na: {answer}"
140
- self.memory.add(user_id, summary)
 
141
  self.memory.add_message_to_session(session_id, "user", question)
142
  self.memory.add_message_to_session(session_id, "assistant", answer)
143
  return summary
144
 
145
- def get_conversation_context(self, user_id: str, session_id: str, question: str) -> str:
146
  """
147
- Get relevant conversation context for a new question
148
  """
149
- return self.memory.get_medical_context(user_id, session_id, question)
 
 
150
 
151
- def get_user_medical_history(self, user_id: str, limit: int = 10) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  """
153
  Get user's medical history (QA summaries)
154
  """
 
4
  from typing import Any
5
 
6
  import numpy as np
7
+ from datetime import datetime, timezone
8
 
9
  from src.services.nvidia import nvidia_chat
10
  from src.services.summariser import (summarise_qa_with_gemini,
11
  summarise_qa_with_nvidia)
12
  from src.utils.embeddings import EmbeddingClient
13
  from src.utils.logger import get_logger
14
+ from src.data.mongodb import save_memory_summary, save_chat_message, ensure_session, get_recent_memory_summaries
15
 
16
+ logger = get_logger("MED_HISTORY")
17
 
18
  def _safe_json(s: str) -> Any:
19
  try:
 
93
  self.memory = memory
94
  self.embedder = embedder
95
 
96
+ async def process_medical_exchange(self, user_id: str, session_id: str, question: str, answer: str, gemini_rotator, nvidia_rotator=None, *, patient_id: str | None = None, doctor_id: str | None = None, session_title: str | None = None) -> str:
97
  """
98
+ Process a medical Q&A exchange and store it in memory and MongoDB
99
  """
100
  try:
101
  # Check if we have valid API keys
 
119
  logger.warning(f"Failed to create AI summary: {e}")
120
  summary = f"q: {question}\na: {answer}"
121
 
122
+ # Short-term cache under patient_id when available
123
+ cache_key = patient_id or user_id
124
+ self.memory.add(cache_key, summary)
125
 
126
+ # Add to session history in cache
127
  self.memory.add_message_to_session(session_id, "user", question)
128
  self.memory.add_message_to_session(session_id, "assistant", answer)
129
 
130
+ # Persist to MongoDB with patient/doctor context
131
+ if patient_id and doctor_id:
132
+ ensure_session(session_id=session_id, patient_id=patient_id, doctor_id=doctor_id, title=session_title or "New Chat", last_activity=datetime.now(timezone.utc))
133
+ save_chat_message(session_id=session_id, patient_id=patient_id, doctor_id=doctor_id, role="user", content=question)
134
+ save_chat_message(session_id=session_id, patient_id=patient_id, doctor_id=doctor_id, role="assistant", content=answer)
135
+ save_memory_summary(patient_id=patient_id, doctor_id=doctor_id, summary=summary)
136
+
137
  # Update session title if it's the first message
138
  session = self.memory.get_session(session_id)
139
  if session and len(session.messages) == 2: # Just user + assistant
 
147
  logger.error(f"Error processing medical exchange: {e}")
148
  # Fallback: store without summary
149
  summary = f"q: {question}\na: {answer}"
150
+ cache_key = patient_id or user_id
151
+ self.memory.add(cache_key, summary)
152
  self.memory.add_message_to_session(session_id, "user", question)
153
  self.memory.add_message_to_session(session_id, "assistant", answer)
154
  return summary
155
 
156
+ def get_conversation_context(self, user_id: str, session_id: str, question: str, *, patient_id: str | None = None) -> str:
157
  """
158
+ Get relevant conversation context combining short-term cache (3) and long-term Mongo (20)
159
  """
160
+ # Short-term summaries
161
+ cache_key = patient_id or user_id
162
+ recent_qa = self.memory.recent(cache_key, 3)
163
 
164
+ # Long-term summaries from Mongo (exclude ones already likely in cache by time order)
165
+ long_term = []
166
+ if patient_id:
167
+ try:
168
+ long_term = get_recent_memory_summaries(patient_id, limit=20)
169
+ except Exception as e:
170
+ logger.warning(f"Failed to fetch long-term memory: {e}")
171
+
172
+ # Get current session messages for context
173
+ session = self.memory.get_session(session_id)
174
+ session_context = ""
175
+ if session:
176
+ recent_messages = session.get_messages(10)
177
+ session_context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in recent_messages])
178
+
179
+ # Combine context
180
+ context_parts = []
181
+ combined = []
182
+ if long_term:
183
+ combined.extend(long_term[::-1]) # oldest to newest within limit
184
+ if recent_qa:
185
+ combined.extend(recent_qa[::-1])
186
+ if combined:
187
+ context_parts.append("Recent medical context:\n" + "\n".join(combined[-20:]))
188
+ if session_context:
189
+ context_parts.append("Current conversation:\n" + session_context)
190
+
191
+ return "\n\n".join(context_parts) if context_parts else ""
192
+
193
+ def get_user_medical_history(self, user_id: str, limit: int = 20) -> list[str]:
194
  """
195
  Get user's medical history (QA summaries)
196
  """
src/core/state.py CHANGED
@@ -19,7 +19,8 @@ class MedicalState:
19
 
20
  def initialize(self):
21
  """Initialize all core components"""
22
- self.memory_system = MemoryLRU(capacity=50, max_sessions_per_user=20)
 
23
  self.embedding_client = create_embedding_client("all-MiniLM-L6-v2", dimension=384)
24
  self.history_manager = MedicalHistoryManager(self.memory_system, self.embedding_client)
25
  self.gemini_rotator = APIKeyRotator("GEMINI_API_", max_slots=5)
@@ -31,6 +32,7 @@ class MedicalState:
31
  cls._instance = MedicalState()
32
  return cls._instance
33
 
 
34
  def get_state() -> MedicalState:
35
  """FastAPI dependency for getting application state"""
36
  return MedicalState.get_instance()
 
19
 
20
  def initialize(self):
21
  """Initialize all core components"""
22
+ # Keep only 3 short-term summaries/messages in cache
23
+ self.memory_system = MemoryLRU(capacity=3, max_sessions_per_user=20)
24
  self.embedding_client = create_embedding_client("all-MiniLM-L6-v2", dimension=384)
25
  self.history_manager = MedicalHistoryManager(self.memory_system, self.embedding_client)
26
  self.gemini_rotator = APIKeyRotator("GEMINI_API_", max_slots=5)
 
32
  cls._instance = MedicalState()
33
  return cls._instance
34
 
35
+
36
  def get_state() -> MedicalState:
37
  """FastAPI dependency for getting application state"""
38
  return MedicalState.get_instance()
src/data/medical_kb.py CHANGED
@@ -1,6 +1,7 @@
1
  # data/medical_kb.py
2
- # Medical Knowledge Base for the Medical AI Assistant
3
 
 
 
4
  MEDICAL_KB = {
5
  "symptoms": {
6
  "fever": "Fever is a temporary increase in body temperature, often due to illness. Normal body temperature is around 98.6°F (37°C).",
 
1
  # data/medical_kb.py
 
2
 
3
+ # TODO: This should be replaced with a more robust knowledge base system that can be updated by the user.
4
+ # Medical Knowledge Base for the Medical AI Assistant
5
  MEDICAL_KB = {
6
  "symptoms": {
7
  "fever": "Fever is a temporary increase in body temperature, often due to illness. Normal body temperature is around 98.6°F (37°C).",
src/data/mongodb.py CHANGED
@@ -16,6 +16,7 @@ from pymongo.database import Database
16
  from pymongo.errors import DuplicateKeyError
17
 
18
  from src.utils.logger import get_logger
 
19
 
20
  logger = get_logger("MONGO")
21
 
@@ -32,8 +33,7 @@ def get_database() -> Database:
32
  """Get database instance with connection management"""
33
  global _mongo_client
34
  if _mongo_client is None:
35
- # TODO This needs to use an environment variable when deployed
36
- CONNECTION_STRING = "mongodb://127.0.0.1:27017/"
37
  try:
38
  logger.info("Initializing MongoDB connection")
39
  _mongo_client = MongoClient(CONNECTION_STRING)
@@ -41,13 +41,14 @@ def get_database() -> Database:
41
  logger.error(f"Failed to connect to MongoDB: {str(e)}")
42
  # Pass the error down, code that calls this function should handle it
43
  raise e
44
- return _mongo_client['medicaldiagnosissystem']
 
45
 
46
  def close_connection():
47
  """Close MongoDB connection"""
48
  global _mongo_client
49
  if _mongo_client is not None:
50
- logger.info("Closing MongoDB connection")
51
  _mongo_client.close()
52
  _mongo_client = None
53
 
@@ -251,3 +252,114 @@ def backup_collection(collection_name: str) -> str:
251
 
252
  logger.info(f"Created backup {backup_name} with {doc_count} documents")
253
  return backup_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from pymongo.errors import DuplicateKeyError
17
 
18
  from src.utils.logger import get_logger
19
+ import os
20
 
21
  logger = get_logger("MONGO")
22
 
 
33
  """Get database instance with connection management"""
34
  global _mongo_client
35
  if _mongo_client is None:
36
+ CONNECTION_STRING = os.getenv("MONGO_URI", "mongodb://127.0.0.1:27017/")
 
37
  try:
38
  logger.info("Initializing MongoDB connection")
39
  _mongo_client = MongoClient(CONNECTION_STRING)
 
41
  logger.error(f"Failed to connect to MongoDB: {str(e)}")
42
  # Pass the error down, code that calls this function should handle it
43
  raise e
44
+ db_name = os.getenv("MONGO_DB", "medicaldiagnosissystem")
45
+ return _mongo_client[db_name]
46
 
47
  def close_connection():
48
  """Close MongoDB connection"""
49
  global _mongo_client
50
  if _mongo_client is not None:
51
+ # Close the connection and reset the client
52
  _mongo_client.close()
53
  _mongo_client = None
54
 
 
252
 
253
  logger.info(f"Created backup {backup_name} with {doc_count} documents")
254
  return backup_name
255
+
256
+ # New: Chat and Medical Memory Persistence Helpers
257
+
258
+ CHAT_MESSAGES_COLLECTION = "chat_messages"
259
+ MEDICAL_MEMORY_COLLECTION = "medical_memory"
260
+ PATIENTS_COLLECTION = "patients"
261
+
262
+
263
+ def ensure_session(
264
+ *,
265
+ session_id: str,
266
+ patient_id: str,
267
+ doctor_id: str,
268
+ title: str,
269
+ last_activity: datetime | None = None,
270
+ collection_name: str = CHAT_SESSIONS_COLLECTION
271
+ ) -> None:
272
+ collection = get_collection(collection_name)
273
+ now = datetime.now(timezone.utc)
274
+ collection.update_one(
275
+ {"session_id": session_id},
276
+ {"$set": {
277
+ "session_id": session_id,
278
+ "patient_id": patient_id,
279
+ "doctor_id": doctor_id,
280
+ "title": title,
281
+ "last_activity": (last_activity or now),
282
+ "updated_at": now
283
+ }, "$setOnInsert": {"created_at": now}},
284
+ upsert=True
285
+ )
286
+
287
+
288
+ def save_chat_message(
289
+ *,
290
+ session_id: str,
291
+ patient_id: str,
292
+ doctor_id: str,
293
+ role: str,
294
+ content: str,
295
+ timestamp: datetime | None = None,
296
+ collection_name: str = CHAT_MESSAGES_COLLECTION
297
+ ) -> ObjectId:
298
+ collection = get_collection(collection_name)
299
+ ts = timestamp or datetime.now(timezone.utc)
300
+ doc = {
301
+ "session_id": session_id,
302
+ "patient_id": patient_id,
303
+ "doctor_id": doctor_id,
304
+ "role": role,
305
+ "content": content,
306
+ "timestamp": ts,
307
+ "created_at": ts
308
+ }
309
+ result = collection.insert_one(doc)
310
+ return result.inserted_id
311
+
312
+
313
+ def list_session_messages(
314
+ session_id: str,
315
+ /,
316
+ *,
317
+ limit: int | None = None,
318
+ collection_name: str = CHAT_MESSAGES_COLLECTION
319
+ ) -> list[dict[str, Any]]:
320
+ collection = get_collection(collection_name)
321
+ cursor = collection.find({"session_id": session_id}).sort("timestamp", ASCENDING)
322
+ if limit is not None:
323
+ cursor = cursor.limit(limit)
324
+ return list(cursor)
325
+
326
+
327
+ def save_memory_summary(
328
+ *,
329
+ patient_id: str,
330
+ doctor_id: str,
331
+ summary: str,
332
+ created_at: datetime | None = None,
333
+ collection_name: str = MEDICAL_MEMORY_COLLECTION
334
+ ) -> ObjectId:
335
+ collection = get_collection(collection_name)
336
+ ts = created_at or datetime.now(timezone.utc)
337
+ result = collection.insert_one({
338
+ "patient_id": patient_id,
339
+ "doctor_id": doctor_id,
340
+ "summary": summary,
341
+ "created_at": ts
342
+ })
343
+ return result.inserted_id
344
+
345
+
346
+ def get_recent_memory_summaries(
347
+ patient_id: str,
348
+ /,
349
+ *,
350
+ limit: int = 20,
351
+ collection_name: str = MEDICAL_MEMORY_COLLECTION
352
+ ) -> list[str]:
353
+ collection = get_collection(collection_name)
354
+ docs = list(collection.find({"patient_id": patient_id}).sort("created_at", DESCENDING).limit(limit))
355
+ return [d.get("summary", "") for d in docs]
356
+
357
+
358
+ def list_patient_sessions(
359
+ patient_id: str,
360
+ /,
361
+ *,
362
+ collection_name: str = CHAT_SESSIONS_COLLECTION
363
+ ) -> list[dict[str, Any]]:
364
+ collection = get_collection(collection_name)
365
+ return list(collection.find({"patient_id": patient_id}).sort("last_activity", DESCENDING))
src/models/chat.py CHANGED
@@ -4,6 +4,8 @@ from pydantic import BaseModel
4
 
5
  class ChatRequest(BaseModel):
6
  user_id: str
 
 
7
  session_id: str
8
  message: str
9
  user_role: str | None = "Medical Professional"
@@ -18,6 +20,8 @@ class ChatResponse(BaseModel):
18
 
19
  class SessionRequest(BaseModel):
20
  user_id: str
 
 
21
  title: str | None = "New Chat"
22
 
23
  class SummariseRequest(BaseModel):
 
4
 
5
  class ChatRequest(BaseModel):
6
  user_id: str
7
+ patient_id: str
8
+ doctor_id: str
9
  session_id: str
10
  message: str
11
  user_role: str | None = "Medical Professional"
 
20
 
21
  class SessionRequest(BaseModel):
22
  user_id: str
23
+ patient_id: str
24
+ doctor_id: str
25
  title: str | None = "New Chat"
26
 
27
  class SummariseRequest(BaseModel):
static/css/styles.css CHANGED
@@ -858,3 +858,48 @@ body {
858
  color: var(--primary-color);
859
  margin-right: var(--spacing-sm);
860
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858
  color: var(--primary-color);
859
  margin-right: var(--spacing-sm);
860
  }
861
+
862
+ .patient-section {
863
+ padding: var(--spacing-lg);
864
+ border-bottom: 1px solid var(--border-color);
865
+ }
866
+
867
+ .patient-header {
868
+ font-weight: 600;
869
+ margin-bottom: var(--spacing-sm);
870
+ color: var(--text-primary);
871
+ }
872
+
873
+ .patient-input-group {
874
+ display: flex;
875
+ gap: var(--spacing-sm);
876
+ align-items: center;
877
+ }
878
+
879
+ .patient-input {
880
+ flex: 1;
881
+ padding: 8px 10px;
882
+ border: 1px solid var(--border-color);
883
+ border-radius: 6px;
884
+ background: var(--bg-primary);
885
+ color: var(--text-primary);
886
+ }
887
+
888
+ .patient-load-btn {
889
+ padding: 8px 10px;
890
+ background: var(--primary-color);
891
+ color: #fff;
892
+ border: none;
893
+ border-radius: 6px;
894
+ cursor: pointer;
895
+ }
896
+
897
+ .patient-load-btn:hover {
898
+ background: var(--primary-hover);
899
+ }
900
+
901
+ .patient-status {
902
+ margin-top: var(--spacing-sm);
903
+ font-size: 0.8rem;
904
+ color: var(--text-secondary);
905
+ }
static/index.html CHANGED
@@ -26,7 +26,7 @@
26
  </div>
27
  <div class="user-info">
28
  <div class="user-name" id="userName">Anonymous</div>
29
- <div class="user-status">Medical Professional</div>
30
  </div>
31
  <button class="user-menu-btn" id="userMenuBtn">
32
  <i class="fas fa-ellipsis-v"></i>
@@ -34,6 +34,18 @@
34
  </div>
35
  </div>
36
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  <div class="chat-sessions" id="chatSessions">
38
  <!-- Chat sessions will be populated here -->
39
  </div>
 
26
  </div>
27
  <div class="user-info">
28
  <div class="user-name" id="userName">Anonymous</div>
29
+ <div class="user-status" id="userStatus">Medical Professional</div>
30
  </div>
31
  <button class="user-menu-btn" id="userMenuBtn">
32
  <i class="fas fa-ellipsis-v"></i>
 
34
  </div>
35
  </div>
36
 
37
+ <!-- Patient Login Section -->
38
+ <div class="patient-section">
39
+ <div class="patient-header">Patient</div>
40
+ <div class="patient-input-group">
41
+ <input type="text" id="patientIdInput" class="patient-input" placeholder="Enter 8-digit Patient ID" maxlength="8" inputmode="numeric" pattern="\\d{8}">
42
+ <button class="patient-load-btn" id="loadPatientBtn" title="Load Patient">
43
+ <i class="fas fa-user-injured"></i>
44
+ </button>
45
+ </div>
46
+ <div class="patient-status" id="patientStatus">No patient selected</div>
47
+ </div>
48
+
49
  <div class="chat-sessions" id="chatSessions">
50
  <!-- Chat sessions will be populated here -->
51
  </div>
static/js/app.js CHANGED
@@ -2,7 +2,8 @@
2
 
3
  class MedicalChatbotApp {
4
  constructor() {
5
- this.currentUser = null;
 
6
  this.currentSession = null;
7
  this.memory = new Map(); // In-memory storage for demo
8
  this.isLoading = false;
@@ -31,6 +32,20 @@ class MedicalChatbotApp {
31
  this.startNewChat();
32
  });
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  // Send button and input
35
  document.getElementById('sendBtn').addEventListener('click', () => {
36
  this.sendMessage();
@@ -267,11 +282,48 @@ I'm here to help you with medical questions, diagnosis assistance, and healthcar
267
  How can I assist you today?`;
268
  }
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  async sendMessage() {
271
  const input = document.getElementById('chatInput');
272
  const message = input.value.trim();
273
 
274
  if (!message || this.isLoading) return;
 
 
 
 
 
 
275
 
276
  // Clear input
277
  input.value = '';
@@ -322,6 +374,8 @@ How can I assist you today?`;
322
  },
323
  body: JSON.stringify({
324
  user_id: this.currentUser.id,
 
 
325
  session_id: this.currentSession?.id || 'default',
326
  message: message,
327
  user_role: this.currentUser.role,
@@ -344,7 +398,8 @@ How can I assist you today?`;
344
  message: error.message,
345
  stack: error.stack,
346
  user: this.currentUser,
347
- session: this.currentSession
 
348
  });
349
 
350
  // Only return mock response if it's a network error, not a server error
 
2
 
3
  class MedicalChatbotApp {
4
  constructor() {
5
+ this.currentUser = null; // doctor
6
+ this.currentPatientId = null;
7
  this.currentSession = null;
8
  this.memory = new Map(); // In-memory storage for demo
9
  this.isLoading = false;
 
32
  this.startNewChat();
33
  });
34
 
35
+ // Patient load button
36
+ const loadBtn = document.getElementById('loadPatientBtn');
37
+ if (loadBtn) {
38
+ loadBtn.addEventListener('click', () => this.loadPatient());
39
+ }
40
+ const patientInput = document.getElementById('patientIdInput');
41
+ if (patientInput) {
42
+ patientInput.addEventListener('keydown', (e) => {
43
+ if (e.key === 'Enter') {
44
+ this.loadPatient();
45
+ }
46
+ });
47
+ }
48
+
49
  // Send button and input
50
  document.getElementById('sendBtn').addEventListener('click', () => {
51
  this.sendMessage();
 
282
  How can I assist you today?`;
283
  }
284
 
285
+ async loadPatient() {
286
+ const input = document.getElementById('patientIdInput');
287
+ const status = document.getElementById('patientStatus');
288
+ const id = (input?.value || '').trim();
289
+ if (!/^\d{8}$/.test(id)) {
290
+ status.textContent = 'Invalid patient ID. Use 8 digits.';
291
+ status.style.color = 'var(--warning-color)';
292
+ return;
293
+ }
294
+ // For now we accept ID and load sessions from backend
295
+ this.currentPatientId = id;
296
+ status.textContent = `Patient: ${id}`;
297
+ status.style.color = 'var(--text-secondary)';
298
+ await this.fetchAndRenderPatientSessions();
299
+ }
300
+
301
+ async fetchAndRenderPatientSessions() {
302
+ if (!this.currentPatientId) return;
303
+ try {
304
+ const resp = await fetch(`/sessions/patients/${this.currentPatientId}/sessions`.replace('/sessions/', '/sessions/'));
305
+ if (resp.ok) {
306
+ const data = await resp.json();
307
+ // Map to sidebar session cards if needed. For now, rely on local sessions until full backend sync is added.
308
+ // Future: hydrate local UI from data.sessions
309
+ }
310
+ } catch (e) {
311
+ console.error('Failed to load patient sessions', e);
312
+ }
313
+ this.loadChatSessions();
314
+ }
315
+
316
  async sendMessage() {
317
  const input = document.getElementById('chatInput');
318
  const message = input.value.trim();
319
 
320
  if (!message || this.isLoading) return;
321
+ if (!this.currentPatientId) {
322
+ const status = document.getElementById('patientStatus');
323
+ status.textContent = 'Select a patient before chatting.';
324
+ status.style.color = 'var(--warning-color)';
325
+ return;
326
+ }
327
 
328
  // Clear input
329
  input.value = '';
 
374
  },
375
  body: JSON.stringify({
376
  user_id: this.currentUser.id,
377
+ patient_id: this.currentPatientId,
378
+ doctor_id: this.currentUser.id,
379
  session_id: this.currentSession?.id || 'default',
380
  message: message,
381
  user_role: this.currentUser.role,
 
398
  message: error.message,
399
  stack: error.stack,
400
  user: this.currentUser,
401
+ session: this.currentSession,
402
+ patientId: this.currentPatientId
403
  });
404
 
405
  // Only return mock response if it's a network error, not a server error