# src/core/memory_manager.py
from src.data.connection import ActionFailed
from src.data.repositories import account as account_repo
from src.data.repositories import information as info_repo
from src.data.repositories import medical_memory as memory_repo
from src.data.repositories import patient as patient_repo
from src.data.repositories import session as session_repo
from src.models.account import Account
from src.models.patient import Patient
from src.models.session import Message, Session
from src.services import reranker, summariser
from src.services.nvidia import nvidia_chat
from src.utils.embeddings import EmbeddingClient
from src.utils.logger import logger
from src.utils.rotator import APIKeyRotator


class MemoryManager:
    """
    A service layer that orchestrates data access and business logic for
    managing accounts, chat sessions, and long-term medical memory.

    Facade methods wrap the repository layer, converting `ActionFailed`
    database errors into logged, safe fallback values (None / [] / 0 / False)
    so callers never have to handle repository exceptions directly.
    """

    def __init__(self, embedder: EmbeddingClient, max_sessions_per_user: int = 10):
        """
        Args:
            embedder: Client used to create embeddings for summaries and
                semantic queries. May be unavailable; callers degrade gracefully.
            max_sessions_per_user: Cap applied when listing sessions per
                user or per patient.
        """
        self.embedder = embedder
        self.max_sessions_per_user = max_sessions_per_user

    # --- Account Management Facade ---

    def create_account(
        self, name: str = "Anonymous", role: str = "Other", specialty: str | None = None
    ) -> str | None:
        """Creates a new user account. Returns the new account ID, or None on failure."""
        try:
            return account_repo.create_account(name=name, role=role, specialty=specialty)
        except ActionFailed as e:
            logger().error(f"Failed to create account in MemoryManager: {e}")
            return None

    def get_account(self, user_id: str) -> Account | None:
        """Retrieves a user account by its ID, or None if lookup fails."""
        try:
            return account_repo.get_account(user_id)
        except ActionFailed as e:
            logger().error(f"Failed to get account '{user_id}' in MemoryManager: {e}")
            return None

    def get_all_accounts(self, limit: int = 50) -> list[Account]:
        """Retrieves a list of all accounts (empty list on failure)."""
        try:
            return account_repo.get_all_accounts(limit=limit)
        except ActionFailed as e:
            logger().error(f"Failed to get all accounts in MemoryManager: {e}")
            return []

    def search_accounts(self, query: str, limit: int = 10) -> list[Account]:
        """Searches for accounts by name (empty list on failure)."""
        try:
            return account_repo.search_accounts(query, limit=limit)
        except ActionFailed as e:
            logger().error(f"Failed to search accounts in MemoryManager: {e}")
            return []

    # --- Patient Management Facade ---

    def create_patient(self, **kwargs) -> str | None:
        """Creates a new patient record. Returns the new patient ID, or None on failure."""
        try:
            return patient_repo.create_patient(**kwargs)
        except ActionFailed as e:
            logger().error(f"Failed to create patient in MemoryManager: {e}")
            return None

    def get_patient_by_id(self, patient_id: str) -> Patient | None:
        """Retrieves a patient by their unique ID, or None if lookup fails."""
        try:
            return patient_repo.get_patient_by_id(patient_id)
        except ActionFailed as e:
            logger().error(f"Failed to get patient '{patient_id}' in MemoryManager: {e}")
            return None

    def update_patient_profile(self, patient_id: str, updates: dict) -> int:
        """Updates a patient's profile. Returns the repository's update count (0 on failure)."""
        try:
            return patient_repo.update_patient_profile(patient_id, updates)
        except ActionFailed as e:
            logger().error(f"Failed to update patient '{patient_id}' in MemoryManager: {e}")
            return 0

    def search_patients(self, query: str, limit: int = 10) -> list[Patient]:
        """Searches for patients by name (empty list on failure)."""
        try:
            return patient_repo.search_patients(query, limit=limit)
        except ActionFailed as e:
            logger().error(f"Failed to search patients in MemoryManager: {e}")
            return []

    # --- Session Management Facade ---

    def create_session(self, user_id: str, patient_id: str, title: str = "New Chat") -> Session | None:
        """Creates a new chat session for a user, or None on failure."""
        try:
            return session_repo.create_session(user_id, patient_id, title)
        except ActionFailed as e:
            logger().error(f"Failed to create session in MemoryManager: {e}")
            return None

    def get_session(self, session_id: str) -> Session | None:
        """Retrieves a single chat session by its ID, or None if lookup fails."""
        try:
            return session_repo.get_session(session_id)
        except ActionFailed as e:
            logger().error(f"Failed to get session '{session_id}' in MemoryManager: {e}")
            return None

    def get_user_sessions(self, user_id: str) -> list[Session]:
        """Retrieves all sessions for a specific user, capped at max_sessions_per_user."""
        try:
            return session_repo.get_user_sessions(user_id, limit=self.max_sessions_per_user)
        except ActionFailed as e:
            logger().error(f"Failed to get user sessions for '{user_id}': {e}")
            return []

    def update_session_title(self, session_id: str, title: str) -> bool:
        """Updates the title of a session. Returns False on failure."""
        try:
            return session_repo.update_session_title(session_id, title)
        except ActionFailed as e:
            logger().error(f"Failed to update title for session '{session_id}': {e}")
            return False

    def list_patient_sessions(self, patient_id: str) -> list[Session]:
        """Retrieves all sessions for a specific patient, capped at max_sessions_per_user."""
        try:
            return session_repo.list_patient_sessions(patient_id, limit=self.max_sessions_per_user)
        except ActionFailed as e:
            logger().error(f"Failed to get sessions for patient '{patient_id}': {e}")
            return []

    def delete_session(self, session_id: str) -> bool:
        """Deletes a chat session. Returns False on failure."""
        try:
            return session_repo.delete_session(session_id)
        except ActionFailed as e:
            logger().error(f"Failed to delete session '{session_id}' in MemoryManager: {e}")
            return False

    def get_session_messages(self, session_id: str, limit: int | None = None) -> list[Message]:
        """Gets messages from a specific chat session (empty list on failure)."""
        try:
            return session_repo.get_session_messages(session_id, limit)
        except ActionFailed as e:
            logger().error(f"Failed to get messages for session '{session_id}': {e}")
            return []

    # --- Core Business Logic ---

    async def process_medical_exchange(
        self,
        session_id: str,
        patient_id: str,
        doctor_id: str,
        question: str,
        answer: str,
        gemini_rotator: APIKeyRotator,
        nvidia_rotator: APIKeyRotator,
    ) -> str | None:
        """
        Processes a medical Q&A exchange: adds messages to the session,
        generates a summary, creates an embedding, and saves it to
        long-term medical memory.

        Args:
            session_id: Session the exchange belongs to.
            patient_id: Patient the memory is filed under.
            doctor_id: Doctor attributed to the memory record.
            question: The user's question text.
            answer: The assistant's answer text.
            gemini_rotator: API key rotator for the Gemini summariser.
            nvidia_rotator: API key rotator for NVIDIA fallbacks/titles.

        Returns:
            The generated summary, or None if any database or unexpected
            error occurred (errors are logged, never raised).
        """
        try:
            # 1. Add messages to the current session
            session_repo.add_message(session_id, question, sent_by_user=True)
            session_repo.add_message(session_id, answer, sent_by_user=False)

            # 2. Generate a concise summary of the exchange
            summary = await self._generate_summary(
                question=question,
                answer=answer,
                gemini_rotator=gemini_rotator,
                nvidia_rotator=nvidia_rotator,
            )

            # 3. Generate an embedding for the summary for semantic search
            embedding = None
            if self.embedder:
                try:
                    embedding = self.embedder.embed([summary])[0]
                except Exception as e:
                    # Embedding is best-effort: memory is still saved without it.
                    logger().warning(f"Failed to generate embedding for summary: {e}")

            # 4. Save the summary and embedding to long-term medical memory
            memory_repo.create_memory(
                patient_id=patient_id,
                doctor_id=doctor_id,
                session_id=session_id,
                summary=summary,
                embedding=embedding,
            )

            # 5. Update the session title if this was the first exchange
            await self._update_session_title_if_first_message(
                session_id=session_id,
                question=question,
                nvidia_rotator=nvidia_rotator,
            )

            return summary
        except ActionFailed as e:
            logger().error(f"Database error processing medical exchange for session '{session_id}': {e}")
            return None
        except Exception as e:
            logger().error(f"Unexpected error processing medical exchange: {e}")
            return None

    async def get_enhanced_context(
        self,
        session_id: str,
        patient_id: str,
        question: str,
        nvidia_rotator: APIKeyRotator,
    ) -> str:
        """
        Builds a rich, multi-source context string for a new question,
        combining short-term memory, long-term semantic memory, information
        from the knowledge base, and current conversation.

        Each source is best-effort: failures are logged and the remaining
        sources are still assembled.

        Returns:
            The non-empty context sections joined by blank lines (may be "").
        """
        context_parts = []

        # 1. Get recent summaries (Short-Term Memory)
        try:
            recent_memories = memory_repo.get_recent_memories(patient_id, limit=3)
            if recent_memories:
                # Use NVIDIA to reason about relevance
                relevant_stm = await self._filter_summaries_for_relevance(
                    question=question,
                    summaries=[mem.summary for mem in recent_memories],
                    nvidia_rotator=nvidia_rotator,
                )
                if relevant_stm:
                    context_parts.append("Recent relevant medical context (STM):\n" + "\n".join(relevant_stm))
        except ActionFailed as e:
            logger().warning(f"Could not retrieve recent memories for enhanced context: {e}")

        # 2. Get semantically similar summaries (Long-Term Memory)
        if self.embedder and self.embedder.is_available():
            try:
                query_embedding = self.embedder.embed([question])[0]
                if query_embedding:
                    ltm_results = memory_repo.search_memories_semantic(
                        patient_id=patient_id,
                        query_embedding=query_embedding,
                        limit=2,
                    )
                    if ltm_results:
                        ltm_summaries = [result.summary for result in ltm_results]
                        context_parts.append(
                            "Semantically relevant medical history (LTM):\n" + "\n".join(ltm_summaries)
                        )
            except Exception as e:
                # `ActionFailed` is an Exception subclass, so a single broad
                # catch covers both DB and embedding failures here.
                logger().warning(f"Failed to perform LTM semantic search: {e}")

        # 3. Consult knowledge base
        info = await self._consult_knowledge_base(
            question=question,
            nvidia_rotator=nvidia_rotator,
        )
        if info:
            context_parts.append(info)

        # 4. Get current conversation context
        try:
            session = session_repo.get_session(session_id)
            if session and session.messages:
                session_context = "\n".join([
                    f"{'User' if msg.sent_by_user else 'Assistant'}: {msg.content}"
                    for msg in session.messages[-10:]  # Get last 10 messages
                ])
                context_parts.append("Current conversation:\n" + session_context)
        except ActionFailed as e:
            logger().warning(f"Could not retrieve current session context: {e}")

        return "\n\n".join(filter(None, context_parts))

    # --- Private Helper Methods ---

    async def _consult_knowledge_base(
        self,
        question: str,
        nvidia_rotator: APIKeyRotator,
    ) -> str:
        """
        Embeds a question, queries the knowledge base for relevant chunks,
        reranks them, and formats them into a context string.

        Returns:
            A formatted context section, or "" when the embedder is
            unavailable, nothing relevant is found, or any error occurs.
        """
        if not self.embedder or not self.embedder.is_available():
            logger().warning("Embedder not available, skipping knowledge base consultation.")
            return ""

        try:
            # 1. Embed the user's question
            query_embedding = self.embedder.embed([question])[0]
            if not query_embedding:
                logger().warning("Failed to generate query embedding.")
                return ""

            # 2. Retrieve initial candidates from MongoDB
            initial_chunks = info_repo.search_chunks_semantic(
                query_embedding=query_embedding,
                limit=10,  # Retrieve more candidates for the reranker to process
            )
            if not initial_chunks:
                logger().info("No relevant chunks found in the knowledge base.")
                return ""

            # 3. Rerank the results for semantic relevance
            reranked_chunks = await reranker.rerank_documents(
                query=question,
                documents=initial_chunks,
                rotator=nvidia_rotator,
                top_k=3,  # Keep the top 3 most relevant results
            )
            if not reranked_chunks:
                logger().warning("Reranking failed to return any chunks.")
                return ""

            # 4. Format the final response
            context_header = "Consulted Knowledge Base for context:"
            formatted_chunks = []
            for chunk in reranked_chunks:
                source = chunk.metadata.source
                content = chunk.content.strip()
                formatted_chunks.append(f"[Source: {source}]\n{content}")

            return f"{context_header}\n\n" + "\n\n".join(formatted_chunks)
        except ActionFailed as e:
            logger().error(f"A database error occurred while consulting the knowledge base: {e}")
        except Exception as e:
            logger().error(f"An unexpected error occurred during knowledge base consultation: {e}")

        return ""

    async def _update_session_title_if_first_message(
        self,
        session_id: str,
        question: str,
        nvidia_rotator: APIKeyRotator,
    ) -> None:
        """Updates the session title if it contains only the first Q&A pair."""
        try:
            session = self.get_session(session_id)
            # Check if it's the first user message and first assistant response
            if session and len(session.messages) == 2:
                title = await summariser.summarise_title_with_nvidia(
                    text=question, rotator=nvidia_rotator, max_words=5
                )
                if not title:
                    title = question[:80]  # Fallback to first 80 chars
                self.update_session_title(session_id=session_id, title=title)
        except Exception as e:
            logger().warning(f"Failed to auto-update session title for session '{session_id}': {e}")

    async def _generate_summary(
        self,
        question: str,
        answer: str,
        gemini_rotator: APIKeyRotator,
        nvidia_rotator: APIKeyRotator,
    ) -> str:
        """Generates a summary of a Q&A exchange, falling back to a basic format if AI fails."""
        try:
            summary = await summariser.summarise_qa_with_gemini(
                question=question, answer=answer, rotator=gemini_rotator
            )
            if summary:
                return summary

            # Fallback to NVIDIA if Gemini fails
            summary = await summariser.summarise_qa_with_nvidia(
                question=question, answer=answer, rotator=nvidia_rotator
            )
            if summary:
                return summary
        except Exception as e:
            logger().warning(f"Failed to generate AI summary: {e}")

        # Fallback for both exceptions and cases where services return None
        return summariser.summarise_fallback(question=question, answer=answer)

    async def _filter_summaries_for_relevance(
        self,
        question: str,
        summaries: list[str],
        nvidia_rotator: APIKeyRotator,
    ) -> list[str]:
        """Uses an AI model to select only the most relevant summaries for a given question."""
        if not summaries:
            return []

        try:
            sys_prompt = "You are a medical AI assistant. Select only the most relevant recent medical context that directly relates to the new question. Return the selected items verbatim, separated by a newline. If none are relevant, return nothing."
            user_prompt = f"Question: {question}\n\nSelect relevant items from recent medical context:\n" + "\n".join(summaries)
            relevant_text = await nvidia_chat(sys_prompt, user_prompt, nvidia_rotator)
            return relevant_text.strip().split('\n') if relevant_text and relevant_text.strip() else []
        except Exception as e:
            logger().warning(f"Failed to get AI reasoning for STM relevance: {e}")
            return summaries  # Fallback to returning all summaries