File size: 14,258 Bytes
45274c2
 
 
 
d85186d
45274c2
521a291
45274c2
21c55a3
521a291
89a41e2
d85186d
45274c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
632ba5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89a41e2
521a291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
632ba5e
45274c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521a291
 
 
 
 
 
 
 
89a41e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45274c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24ed5c4
 
 
 
 
 
45274c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24ed5c4
 
 
 
 
45274c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3cb67a
45274c2
 
 
 
 
 
 
 
 
24ed5c4
 
 
45274c2
 
 
 
 
 
 
d85186d
45274c2
 
d85186d
 
 
 
 
 
 
 
 
45274c2
 
 
24ed5c4
d85186d
 
 
 
 
 
24ed5c4
 
45274c2
 
 
 
 
 
 
 
 
 
 
d85186d
45274c2
 
 
d85186d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5c9fd8
 
45274c2
 
 
 
 
 
 
 
 
 
 
24ed5c4
45274c2
 
24ed5c4
45274c2
 
 
 
 
 
 
 
 
07e18ce
 
45274c2
24ed5c4
 
 
 
 
 
 
45274c2
24ed5c4
 
 
 
 
 
45274c2
 
07e18ce
 
24ed5c4
45274c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
# src/core/memory_manager.py

from src.data.connection import ActionFailed
from src.data.repositories import account as account_repo
from src.data.repositories import information as info_repo
from src.data.repositories import medical_memory as memory_repo
from src.data.repositories import patient as patient_repo
from src.data.repositories import session as session_repo
from src.models.account import Account
from src.models.patient import Patient
from src.models.session import Message, Session
from src.services import reranker, summariser
from src.services.nvidia import nvidia_chat
from src.utils.embeddings import EmbeddingClient
from src.utils.logger import logger
from src.utils.rotator import APIKeyRotator


class MemoryManager:
	"""
	A service layer that orchestrates data access and business logic for managing
	accounts, chat sessions, and long-term medical memory.
	"""
	def __init__(self, embedder: EmbeddingClient, max_sessions_per_user: int = 10):
		"""
		Stores the collaborators the other methods rely on.

		Args:
			embedder: Client used to embed summaries and questions.
			max_sessions_per_user: Value forwarded as `limit` when sessions are listed.
		"""
		self.max_sessions_per_user = max_sessions_per_user
		self.embedder = embedder

	# --- Account Management Facade ---

	def create_account(
		self,
		name: str = "Anonymous",
		role: str = "Other",
		specialty: str | None = None
	) -> str | None:
		"""
		Registers a new account through the account repository.

		Returns:
			The repository's return value (presumably the new account's ID),
			or None when the repository raises ActionFailed; the failure is logged.
		"""
		try:
			created = account_repo.create_account(name=name, role=role, specialty=specialty)
		except ActionFailed as e:
			logger().error(f"Failed to create account in MemoryManager: {e}")
			created = None
		return created

	def get_account(self, user_id: str) -> Account | None:
		"""Looks up one account by ID; logs and returns None on ActionFailed."""
		try:
			account = account_repo.get_account(user_id)
		except ActionFailed as e:
			logger().error(f"Failed to get account '{user_id}' in MemoryManager: {e}")
			return None
		else:
			return account

	def get_all_accounts(self, limit: int = 50) -> list[Account]:
		"""Lists accounts, forwarding `limit` to the repository; logs and returns [] on ActionFailed."""
		accounts = []
		try:
			accounts = account_repo.get_all_accounts(limit=limit)
		except ActionFailed as e:
			logger().error(f"Failed to get all accounts in MemoryManager: {e}")
		return accounts

	def search_accounts(self, query: str, limit: int = 10) -> list[Account]:
		"""Runs an account search for `query` (at most `limit` hits requested); logs and returns [] on ActionFailed."""
		try:
			matches = account_repo.search_accounts(query, limit=limit)
		except ActionFailed as e:
			logger().error(f"Failed to search accounts in MemoryManager: {e}")
			return []
		return matches

	# --- Patient Management Facade ---

	def create_patient(self, **kwargs) -> str | None:
		"""
		Creates a patient record, passing every keyword argument straight to the patient repository.

		Returns:
			The repository's return value (presumably the new patient's ID),
			or None when the repository raises ActionFailed; the failure is logged.
		"""
		try:
			created = patient_repo.create_patient(**kwargs)
		except ActionFailed as e:
			logger().error(f"Failed to create patient in MemoryManager: {e}")
			created = None
		return created

	def get_patient_by_id(self, patient_id: str) -> Patient | None:
		"""Fetches the patient with the given ID; logs and returns None on ActionFailed."""
		try:
			patient = patient_repo.get_patient_by_id(patient_id)
		except ActionFailed as e:
			logger().error(f"Failed to get patient '{patient_id}' in MemoryManager: {e}")
			return None
		else:
			return patient

	def update_patient_profile(self, patient_id: str, updates: dict) -> int:
		"""
		Applies `updates` to the patient's profile via the patient repository.

		Returns:
			The repository's integer result (presumably the number of modified
			records), or 0 when the repository raises ActionFailed; the failure is logged.
		"""
		modified = 0
		try:
			modified = patient_repo.update_patient_profile(patient_id, updates)
		except ActionFailed as e:
			logger().error(f"Failed to update patient '{patient_id}' in MemoryManager: {e}")
		return modified

	def search_patients(self, query: str, limit: int = 10) -> list[Patient]:
		"""Runs a patient search for `query` (at most `limit` hits requested); logs and returns [] on ActionFailed."""
		try:
			matches = patient_repo.search_patients(query, limit=limit)
		except ActionFailed as e:
			logger().error(f"Failed to search patients in MemoryManager: {e}")
			return []
		return matches

	# --- Session Management Facade ---

	def create_session(self, user_id: str, patient_id: str, title: str = "New Chat") -> Session | None:
		"""Opens a chat session for the given user and patient; logs and returns None on ActionFailed."""
		try:
			new_session = session_repo.create_session(user_id, patient_id, title)
		except ActionFailed as e:
			logger().error(f"Failed to create session in MemoryManager: {e}")
			new_session = None
		return new_session

	def get_session(self, session_id: str) -> Session | None:
		"""Loads one chat session by ID; logs and returns None on ActionFailed."""
		try:
			found = session_repo.get_session(session_id)
		except ActionFailed as e:
			logger().error(f"Failed to get session '{session_id}' in MemoryManager: {e}")
			return None
		else:
			return found

	def get_user_sessions(self, user_id: str) -> list[Session]:
		"""
		Lists a user's sessions, requesting at most `self.max_sessions_per_user` of them.

		Logs and returns [] when the repository raises ActionFailed.
		"""
		sessions = []
		try:
			sessions = session_repo.get_user_sessions(user_id, limit=self.max_sessions_per_user)
		except ActionFailed as e:
			logger().error(f"Failed to get user sessions for '{user_id}': {e}")
		return sessions

	def update_session_title(self, session_id: str, title: str) -> bool:
		"""Renames a session; returns the repository's result, or False (after logging) on ActionFailed."""
		try:
			updated = session_repo.update_session_title(session_id, title)
		except ActionFailed as e:
			logger().error(f"Failed to update title for session '{session_id}': {e}")
			updated = False
		return updated

	def list_patient_sessions(self, patient_id: str) -> list[Session]:
		"""
		Lists a patient's sessions, requesting at most `self.max_sessions_per_user` of them.

		Logs and returns [] when the repository raises ActionFailed.
		"""
		try:
			sessions = session_repo.list_patient_sessions(patient_id, limit=self.max_sessions_per_user)
		except ActionFailed as e:
			logger().error(f"Failed to get sessions for patient '{patient_id}': {e}")
			return []
		return sessions

	def delete_session(self, session_id: str) -> bool:
		"""Removes a chat session; returns the repository's result, or False (after logging) on ActionFailed."""
		deleted = False
		try:
			deleted = session_repo.delete_session(session_id)
		except ActionFailed as e:
			logger().error(f"Failed to delete session '{session_id}' in MemoryManager: {e}")
		return deleted

	def get_session_messages(self, session_id: str, limit: int | None = None) -> list[Message]:
		"""
		Reads the messages of one chat session.

		`limit` is handed to the repository unchanged (None by default).
		Logs and returns [] when the repository raises ActionFailed.
		"""
		try:
			messages = session_repo.get_session_messages(session_id, limit)
		except ActionFailed as e:
			logger().error(f"Failed to get messages for session '{session_id}': {e}")
			return []
		else:
			return messages

	# --- Core Business Logic ---

	async def process_medical_exchange(
		self,
		session_id: str,
		patient_id: str,
		doctor_id: str,
		question: str,
		answer: str,
		gemini_rotator: APIKeyRotator,
		nvidia_rotator: APIKeyRotator
	) -> str | None:
		"""
		Records one question/answer exchange and files its summary in medical memory.

		In order: both messages are appended to the session, a summary is
		produced by `_generate_summary`, the summary is embedded (best effort),
		the summary and embedding are stored with `memory_repo.create_memory`,
		and `_update_session_title_if_first_message` is invoked.

		Returns:
			The summary text, or None if ActionFailed or any other exception
			escapes one of the steps (the error is logged).
		"""
		try:
			# Step 1: store the user's question first, then the assistant's answer
			for content, from_user in ((question, True), (answer, False)):
				session_repo.add_message(session_id, content, sent_by_user=from_user)

			# Step 2: condense the exchange into a summary
			exchange_summary = await self._generate_summary(
				question=question,
				answer=answer,
				gemini_rotator=gemini_rotator,
				nvidia_rotator=nvidia_rotator
			)

			# Step 3: embed the summary; a failure is only logged and the
			# memory is then stored with embedding=None
			summary_vector = None
			if self.embedder:
				try:
					summary_vector = self.embedder.embed([exchange_summary])[0]
				except Exception as e:
					logger().warning(f"Failed to generate embedding for summary: {e}")

			# Step 4: persist summary and embedding
			memory_repo.create_memory(
				patient_id=patient_id,
				doctor_id=doctor_id,
				session_id=session_id,
				summary=exchange_summary,
				embedding=summary_vector
			)

			# Step 5: let the helper retitle the session when this was the first exchange
			await self._update_session_title_if_first_message(
				session_id=session_id,
				question=question,
				nvidia_rotator=nvidia_rotator
			)
		except ActionFailed as e:
			logger().error(f"Database error processing medical exchange for session '{session_id}': {e}")
			return None
		except Exception as e:
			logger().error(f"Unexpected error processing medical exchange: {e}")
			return None
		else:
			return exchange_summary

	async def get_enhanced_context(
		self,
		session_id: str,
		patient_id: str,
		question: str,
		nvidia_rotator: APIKeyRotator
	) -> str:
		"""
		Assembles the context string that accompanies a new question.

		Up to four sections are collected, in this order: recent memory
		summaries filtered for relevance (STM), semantically similar memory
		summaries (LTM), knowledge-base excerpts, and the last ten messages of
		the current session. A section that is empty or whose lookup fails is
		left out; lookup failures are logged as warnings.

		Returns:
			The collected sections joined by blank lines ("" if there are none).
		"""
		sections = []

		# Section 1: short-term memory, narrowed down by the relevance filter
		try:
			latest_memories = memory_repo.get_recent_memories(patient_id, limit=3)
			if latest_memories:
				kept_summaries = await self._filter_summaries_for_relevance(
					question=question,
					summaries=[memory.summary for memory in latest_memories],
					nvidia_rotator=nvidia_rotator
				)
				if kept_summaries:
					stm_body = "\n".join(kept_summaries)
					sections.append("Recent relevant medical context (STM):\n" + stm_body)
		except ActionFailed as e:
			logger().warning(f"Could not retrieve recent memories for enhanced context: {e}")

		# Section 2: long-term memory via embedding similarity (needs a working embedder)
		if self.embedder and self.embedder.is_available():
			try:
				question_vector = self.embedder.embed([question])[0]
				if question_vector:
					similar_memories = memory_repo.search_memories_semantic(
						patient_id=patient_id,
						query_embedding=question_vector,
						limit=2
					)
					if similar_memories:
						ltm_body = "\n".join([hit.summary for hit in similar_memories])
						sections.append("Semantically relevant medical history (LTM):\n" + ltm_body)
			except (ActionFailed, Exception) as e:
				logger().warning(f"Failed to perform LTM semantic search: {e}")

		# Section 3: knowledge-base excerpts (the helper returns "" when it has nothing)
		kb_section = await self._consult_knowledge_base(
			question=question,
			nvidia_rotator=nvidia_rotator
		)
		if kb_section:
			sections.append(kb_section)

		# Section 4: tail of the ongoing conversation
		try:
			current = session_repo.get_session(session_id)
			if current and current.messages:
				transcript_lines = []
				for msg in current.messages[-10:]:  # only the ten most recent messages
					transcript_lines.append(f"{'User' if msg.sent_by_user else 'Assistant'}: {msg.content}")
				sections.append("Current conversation:\n" + "\n".join(transcript_lines))
		except ActionFailed as e:
			logger().warning(f"Could not retrieve current session context: {e}")

		return "\n\n".join(section for section in sections if section)

	# --- Private Helper Methods ---

	async def _consult_knowledge_base(
		self,
		question: str,
		nvidia_rotator: APIKeyRotator
	) -> str:
		"""
		Looks up knowledge-base chunks for a question and renders them as context.

		The question is embedded, up to 10 candidate chunks are fetched with
		`info_repo.search_chunks_semantic`, `reranker.rerank_documents` keeps
		the top 3, and each kept chunk is rendered as a "[Source: ...]" entry
		under a fixed header.

		Returns:
			The rendered context, or "" when the embedder is missing or
			unavailable, any stage yields nothing, or an exception is raised
			(each case is logged).
		"""
		def _render(chunk) -> str:
			# One entry per chunk: source label on the first line, trimmed content below
			source = chunk.metadata.source
			content = chunk.content.strip()
			return f"[Source: {source}]\n{content}"

		if not (self.embedder and self.embedder.is_available()):
			logger().warning("Embedder not available, skipping knowledge base consultation.")
			return ""

		try:
			# Stage 1: turn the question into an embedding
			question_vector = self.embedder.embed([question])[0]
			if not question_vector:
				logger().warning("Failed to generate query embedding.")
				return ""

			# Stage 2: over-fetch candidates so the reranker has material to choose from
			candidates = info_repo.search_chunks_semantic(
				query_embedding=question_vector,
				limit=10
			)
			if not candidates:
				logger().info("No relevant chunks found in the knowledge base.")
				return ""

			# Stage 3: keep only the three best candidates
			best_chunks = await reranker.rerank_documents(
				query=question,
				documents=candidates,
				rotator=nvidia_rotator,
				top_k=3
			)
			if not best_chunks:
				logger().warning("Reranking failed to return any chunks.")
				return ""

			# Stage 4: render header plus entries, separated by blank lines
			context_header = "Consulted Knowledge Base for context:"
			formatted_chunks = [_render(chunk) for chunk in best_chunks]
			return f"{context_header}\n\n" + "\n\n".join(formatted_chunks)
		except ActionFailed as e:
			logger().error(f"A database error occurred while consulting the knowledge base: {e}")
			return ""
		except Exception as e:
			logger().error(f"An unexpected error occurred during knowledge base consultation: {e}")
			return ""

	async def _update_session_title_if_first_message(
		self,
		session_id: str,
		question: str,
		nvidia_rotator: APIKeyRotator
	) -> None:
		"""
		Gives a session a generated title once it holds exactly two messages.

		The title comes from `summariser.summarise_title_with_nvidia` (max 5
		words requested); if that yields nothing, the first 80 characters of
		the question are used. Any exception is logged as a warning and swallowed.
		"""
		try:
			session = self.get_session(session_id)
			# Exactly two messages means one question plus one answer so far
			if not session or len(session.messages) != 2:
				return
			generated = await summariser.summarise_title_with_nvidia(text=question, rotator=nvidia_rotator, max_words=5)
			title = generated or question[:80]  # fall back to a truncated question
			self.update_session_title(session_id=session_id, title=title)
		except Exception as e:
			logger().warning(f"Failed to auto-update session title for session '{session_id}': {e}")

	async def _generate_summary(
		self,
		question: str,
		answer: str,
		gemini_rotator: APIKeyRotator,
		nvidia_rotator: APIKeyRotator
	) -> str:
		"""
		Summarises a Q&A exchange, trying Gemini, then NVIDIA, then a non-AI fallback.

		The NVIDIA summariser is only called when Gemini returns a falsy result.
		If either call raises, or both return nothing, the result of
		`summariser.summarise_fallback` is returned instead.
		"""
		try:
			ai_summary = await summariser.summarise_qa_with_gemini(
				question=question,
				answer=answer,
				rotator=gemini_rotator
			)
			if not ai_summary:
				# Gemini produced nothing usable, so give NVIDIA a turn
				ai_summary = await summariser.summarise_qa_with_nvidia(
					question=question,
					answer=answer,
					rotator=nvidia_rotator
				)
			if ai_summary:
				return ai_summary
		except Exception as e:
			logger().warning(f"Failed to generate AI summary: {e}")

		# Reached after an exception or when neither service returned a summary
		return summariser.summarise_fallback(question=question, answer=answer)

	async def _filter_summaries_for_relevance(
		self,
		question: str,
		summaries: list[str],
		nvidia_rotator: APIKeyRotator
	) -> list[str]:
		"""
		Asks the NVIDIA chat model which of the given summaries relate to the question.

		Args:
			question: The new question the summaries are judged against.
			summaries: Candidate summaries, sent to the model one per line.
			nvidia_rotator: Passed through to `nvidia_chat`.

		Returns:
			The non-blank lines of the model's reply, each stripped of
			surrounding whitespace. [] when `summaries` is empty or the reply
			is empty or whitespace-only. If the call or the parsing raises,
			a warning is logged and `summaries` is returned unchanged.
		"""
		if not summaries:
			return []
		try:
			sys_prompt = "You are a medical AI assistant. Select only the most relevant recent medical context that directly relates to the new question. Return the selected items verbatim, separated by a newline. If none are relevant, return nothing."
			user_prompt = f"Question: {question}\n\nSelect relevant items from recent medical context:\n" + "\n".join(summaries)

			relevant_text = await nvidia_chat(sys_prompt, user_prompt, nvidia_rotator)
			if not relevant_text:
				return []
			# The reply may separate items with blank lines or use CRLF endings;
			# drop empty lines and trim each kept one so no blank entries are returned.
			return [line.strip() for line in relevant_text.split('\n') if line.strip()]
		except Exception as e:
			logger().warning(f"Failed to get AI reasoning for STM relevance: {e}")
			return summaries # Fallback to returning all summaries