# core/chat_manager.py
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

import asyncio
import logging
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable

from langchain_core.runnables import RunnableConfig
from langchain_core.messages import BaseMessage
from langgraph.types import Command
from langgraph.errors import GraphInterrupt

from config.settings import settings
from models.state_models import MultiCountryLegalState
from utils.helpers import dict_to_message_obj

logger = logging.getLogger(__name__)


class LegalChatManager:
    """
    Chat manager with full human-in-the-loop interrupt support.
    Handles both synchronous (callback-based) and asynchronous (stored interrupt) modes.
    """

    def __init__(self, graph, checkpointer):
        self.graph = graph
        self.checkpointer = checkpointer
        self.active_sessions = {}
        self.routing_stats = {
            "benin": 0,
            "madagascar": 0,
            "unclear": 0,
            "total_queries": 0
        }
        # Track pending interrupts by session
        self.pending_interrupts = {}

    async def chat(
        self,
        message: str,
        session_id: str,
        legal_context: Optional[Dict[str, str]] = None,
        interrupt_handler: Optional[Callable[[Dict], str]] = None
    ) -> str:
        """
        Process a chat message with session management and interrupt handling.

        Args:
            message: User message to process
            session_id: Unique session identifier for conversation tracking
            legal_context: Optional legal context (jurisdiction, user type, etc.)
            interrupt_handler: Optional callback for handling interrupts synchronously.
                If provided, interrupts are handled immediately within this call.
                If None, interrupts are stored for later resolution via subsequent calls.

        Returns:
            Assistant's response text

        Raises:
            RuntimeError: If system is not initialized
        """
        if not self.graph:
            raise RuntimeError("System not initialized. Call setup_system() first.")

        # Initialize or update session
        self._initialize_session(session_id)

        # Check if we have a pending interrupt for this session
        if session_id in self.pending_interrupts:
            return await self._handle_pending_interrupt(session_id, message)

        # Prepare input state
        input_state = self._prepare_input_state(message, session_id, legal_context)
        config = RunnableConfig(
            configurable={"thread_id": session_id},
            recursion_limit=100
        )

        try:
            # Track performance
            start_time = datetime.now()

            # 🔥 CRITICAL: Use streaming to detect interrupts in real-time
            final_response = None
            async for chunk in self.graph.astream(
                MultiCountryLegalState(**input_state),
                config,
                stream_mode="updates"
            ):
                # 🔥 Check if this chunk contains an interrupt
                if "__interrupt__" in chunk:
                    interrupt_data = chunk["__interrupt__"]
                    logger.info(f"⏸️ Graph interrupted: {interrupt_data}")

                    # Extract interrupt info - handle tuple/Interrupt object format
                    # LangGraph returns interrupts as tuples containing Interrupt objects
                    if isinstance(interrupt_data, (list, tuple)):
                        interrupt_info = interrupt_data[0]
                    else:
                        interrupt_info = interrupt_data

                    # Handle Interrupt object (has .value attribute)
                    if hasattr(interrupt_info, 'value'):
                        interrupt_value = interrupt_info.value
                    elif isinstance(interrupt_info, dict):
                        interrupt_value = interrupt_info.get("value", interrupt_info)
                    else:
                        interrupt_value = {}

                    # Extract message from interrupt value
                    interrupt_message = (
                        interrupt_value.get("message", "")
                        if isinstance(interrupt_value, dict) else ""
                    )

                    # 🔥 Two modes of operation:
                    # 1. Synchronous: if interrupt_handler is provided, handle immediately
                    # 2. Asynchronous: store the interrupt and return, wait for the next call
                    if interrupt_handler:
                        # SYNCHRONOUS MODE: Handle interrupt immediately
                        logger.info("📞 Calling synchronous interrupt handler")
                        moderator_response = interrupt_handler(interrupt_value)

                        # Resume immediately with the moderator's response
                        logger.info(f"🔄 Resuming graph with: {moderator_response}")
                        async for resume_chunk in self.graph.astream(
                            Command(resume=moderator_response),
                            config,
                            stream_mode="updates"
                        ):
                            # Continue processing resumed chunks
                            for node_name, node_output in resume_chunk.items():
                                if node_name != "__interrupt__":
                                    logger.debug(f"📦 Resume chunk from {node_name}")

                        # After resume completes, get final state
                        state = await self.graph.aget_state(config)
                        final_response = self._extract_response(state.values)
                        break
                    else:
                        # ASYNCHRONOUS MODE: Store interrupt and return
                        logger.info("💾 Storing interrupt for later resolution")
                        self.pending_interrupts[session_id] = {
                            "type": "human_approval",
                            "config": config,
                            "created_at": datetime.now(),
                            "interrupt_data": interrupt_info
                        }
                        return interrupt_message or self._get_default_approval_prompt()

                # Process normal chunks (non-interrupt)
                for node_name, node_output in chunk.items():
                    if node_name != "__interrupt__":
                        logger.debug(f"📦 Chunk from {node_name}")

            # If no interrupt occurred, get final state
            if final_response is None:
                state = await self.graph.aget_state(config)
                final_response = self._extract_response(state.values)

            # Track performance
            processing_time = (datetime.now() - start_time).total_seconds()
            self._update_session_stats(session_id, processing_time)
            self._update_routing_stats(final_response)

            return final_response

        except Exception as e:
            logger.exception(f"Chat error for session {session_id}")
            self._log_error(session_id, str(e))
            return f"Erreur lors du traitement: {str(e)}"

    async def _handle_pending_interrupt(self, session_id: str, message: str) -> str:
        """
        Handle user response to a pending interrupt using Command(resume=...).
        This is called when there's a stored interrupt waiting for resolution.

        Args:
            session_id: Session with pending interrupt
            message: User's response (e.g., "approve" or "reject")

        Returns:
            Final response after resuming from interrupt
        """
        interrupt_data = self.pending_interrupts.get(session_id)
        if not interrupt_data:
            return "Erreur: Aucune interruption en attente."

        try:
            logger.info(f"🔥 Resuming graph with moderator decision: {message}")
            config = interrupt_data["config"]

            # Use streaming to handle potential nested interrupts
            final_response = None
            async for chunk in self.graph.astream(
                Command(resume=message),
                config,
                stream_mode="updates"
            ):
                if "__interrupt__" in chunk:
                    # Another interrupt occurred during resume - store it
                    new_interrupt = chunk["__interrupt__"]

                    # Handle tuple/list format
                    if isinstance(new_interrupt, (list, tuple)):
                        interrupt_info = new_interrupt[0]
                    else:
                        interrupt_info = new_interrupt

                    # Extract value from Interrupt object
                    if hasattr(interrupt_info, 'value'):
                        interrupt_value = interrupt_info.value
                    elif isinstance(interrupt_info, dict):
                        interrupt_value = interrupt_info.get("value", interrupt_info)
                    else:
                        interrupt_value = {}

                    # Store the new interrupt
                    self.pending_interrupts[session_id] = {
                        "type": "human_approval",
                        "config": config,
                        "created_at": datetime.now(),
                        "interrupt_data": interrupt_info
                    }

                    interrupt_message = (
                        interrupt_value.get("message", "")
                        if isinstance(interrupt_value, dict) else ""
                    )
                    return interrupt_message or self._get_default_approval_prompt()

                # Process normal chunks
                for node_name, node_output in chunk.items():
                    if node_name != "__interrupt__":
                        logger.debug(f"📦 Resume chunk from {node_name}")

            # Get final state after successful resume
            state = await self.graph.aget_state(config)
            final_response = self._extract_response(state.values)

            # Clean up the pending interrupt
            del self.pending_interrupts[session_id]

            # Update stats
            self._update_routing_stats(final_response)

            logger.info(f"✅ Graph resumed successfully for session {session_id}")
            return final_response

        except Exception as e:
            logger.error(f"Error resuming from interrupt: {str(e)}")
            # Clean up on error
            if session_id in self.pending_interrupts:
                del self.pending_interrupts[session_id]
            return f"Erreur lors du traitement de la décision: {str(e)}"

    def _get_default_approval_prompt(self) -> str:
        """Default approval prompt if interrupt message extraction fails"""
        return """
🔒 **APPROBATION HUMAINE REQUISE**

Une demande d'assistance juridique nécessite votre approbation.

**Veuillez répondre avec:**
- "approve [raison]" pour approuver la demande
- "reject [raison]" pour rejeter la demande

**Exemples:**
- "approve Demande légitime de consultation"
- "reject Email invalide ou description trop vague"

**Votre décision:**
"""

    def get_checkpointer_info(self) -> Dict[str, Any]:
        """Get information about the current checkpointer type"""
        checkpointer_type = "unknown"
        if hasattr(self.checkpointer, '__class__'):
            class_name = self.checkpointer.__class__.__name__
            if 'PostgresSaver' in class_name:
                checkpointer_type = "postgres"
            elif 'InMemorySaver' in class_name:
                checkpointer_type = "memory"

        return {
            "type": checkpointer_type,
            "persistent": checkpointer_type == "postgres",
            "description": "Persistent storage" if checkpointer_type == "postgres" else "In-memory (volatile)"
        }

    async def get_conversation_history(self, session_id: str) -> List[BaseMessage]:
        """
        Get conversation history for a session.

        Args:
            session_id: Session identifier

        Returns:
            List of message objects from conversation history
        """
        if not self.graph:
            return []

        config = RunnableConfig(configurable={"thread_id": session_id})

        try:
            state = await self.graph.aget_state(config)
            if not state or not state.values:
                return []

            s = state.values
            if isinstance(s, MultiCountryLegalState):
                s = s.model_dump()
            elif isinstance(s, dict):
                pass
            else:
                s = {}

            raw_messages = s.get("messages", [])
            return [dict_to_message_obj(m) for m in raw_messages if isinstance(m, dict)]

        except Exception as e:
            logger.exception(f"Error getting conversation history for session {session_id}")
            return []

    def get_session_stats(self, session_id: str) -> Dict[str, Any]:
        """Get statistics for a specific session"""
        return self.active_sessions.get(session_id, {})

    def get_global_stats(self) -> Dict[str, Any]:
        """
        Get global system statistics.

        Returns:
            Dictionary with routing stats, active sessions, and storage info
        """
        stats = {
            "routing_stats": self.routing_stats,
            "active_sessions": len(self.active_sessions),
            "total_queries": self.routing_stats["total_queries"],
            "pending_interrupts": len(self.pending_interrupts)
        }

        # Add checkpointer info
        stats.update(self.get_checkpointer_info())
        return stats

    def _initialize_session(self, session_id: str):
        """Initialize or update session tracking"""
        if session_id not in self.active_sessions:
            self.active_sessions[session_id] = {
                "created": datetime.now(),
                "query_count": 0,
                "total_processing_time": 0,
                "average_processing_time": 0,
                "detected_countries": set(),
                "last_activity": datetime.now()
            }

        session_info = self.active_sessions[session_id]
        session_info["query_count"] += 1
        session_info["last_activity"] = datetime.now()

    def _prepare_input_state(
        self,
        message: str,
        session_id: str,
        legal_context: Optional[Dict[str, str]]
    ) -> Dict[str, Any]:
        """
        Prepare input state for graph processing.

        Args:
            message: User message
            session_id: Session identifier
            legal_context: Optional legal context

        Returns:
            Dictionary with complete input state for graph
        """
        ctx = legal_context or {
            "jurisdiction": "Unknown",
            "user_type": "general",
            "document_type": "legal",
            "detected_country": "unknown"
        }
        if ctx.get("detected_country") is None:
            ctx["detected_country"] = "unknown"

        return {
            "messages": [{"role": "user", "content": message, "meta": {}}],
            "legal_context": ctx,
            "session_id": session_id,
            "router_decision": None,
            "search_results": None,
            "route_explanation": None,
            "last_search_query": None,
            "detected_articles": [],
        }

    def _extract_response(self, result) -> str:
        """
        Extract response text from graph result.

        Args:
            result: Graph execution result (state or dict)

        Returns:
            Assistant's response text
        """
        if isinstance(result, MultiCountryLegalState):
            r = result.model_dump()
        elif isinstance(result, dict):
            r = result
        else:
            r = {}

        msgs = r.get("messages", [])

        # Find the last assistant message
        for m in reversed(msgs):
            if (m.get("role") or "").lower() in ("assistant", "ai"):
                return m.get("content", "")

        return "Désolé, je n'ai pas pu générer de réponse."
    def _update_session_stats(self, session_id: str, processing_time: float):
        """Update session statistics with processing time"""
        if session_id in self.active_sessions:
            session_info = self.active_sessions[session_id]
            session_info["total_processing_time"] += processing_time
            session_info["average_processing_time"] = (
                session_info["total_processing_time"] / session_info["query_count"]
            )

    def _update_routing_stats(self, response: str):
        """Update routing statistics based on response content"""
        self.routing_stats["total_queries"] += 1

        response_lower = response.lower()
        if any(keyword in response_lower for keyword in ["bénin", "béninois", "béninoise"]):
            self.routing_stats["benin"] += 1
        elif any(keyword in response_lower for keyword in ["madagascar", "malgache", "malagasy"]):
            self.routing_stats["madagascar"] += 1
        else:
            self.routing_stats["unclear"] += 1

    def _log_error(self, session_id: str, error: str):
        """Log error for monitoring"""
        logger.error(f"Session {session_id}: {error}")

    def cleanup_inactive_sessions(self, max_age_hours: int = 24):
        """
        Clean up sessions that have been inactive for too long.

        Args:
            max_age_hours: Maximum age in hours before cleanup
        """
        cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)

        inactive_sessions = [
            session_id
            for session_id, info in self.active_sessions.items()
            if info["last_activity"].timestamp() < cutoff_time
        ]

        # Also clean up pending interrupts for inactive sessions
        for session_id in inactive_sessions:
            if session_id in self.pending_interrupts:
                del self.pending_interrupts[session_id]
            del self.active_sessions[session_id]
            logger.info(f"Cleaned up inactive session: {session_id}")

    def has_pending_interrupt(self, session_id: str) -> bool:
        """
        Check if there's a pending interrupt for a session.

        Args:
            session_id: Session identifier

        Returns:
            True if session has pending interrupt, False otherwise
        """
        return session_id in self.pending_interrupts

    def get_system_info(self) -> Dict[str, Any]:
        """
        Get comprehensive system information.

        Returns:
            Dictionary with system status, storage info, and statistics
        """
        return {
            "system": {
                "initialized": self.graph is not None,
                "active_sessions": len(self.active_sessions),
                "pending_interrupts": len(self.pending_interrupts),
                "total_queries": self.routing_stats["total_queries"]
            },
            "storage": self.get_checkpointer_info(),
            "routing": self.routing_stats,
            "timestamp": datetime.now().isoformat()
        }
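

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the manager API).
# It shows the two interrupt modes described in `chat()`:
#   - synchronous: pass an `interrupt_handler` callback and get the final
#     answer within a single call;
#   - asynchronous: omit the handler, receive the approval prompt, then send
#     the moderator's decision ("approve ..." / "reject ...") as the next message.
# `build_graph` is a hypothetical project-level factory; it stands in for
# whatever compiles the LangGraph graph with the given checkpointer.
# ---------------------------------------------------------------------------
async def _demo() -> None:
    from langgraph.checkpoint.memory import InMemorySaver  # assumed in-memory checkpointer

    checkpointer = InMemorySaver()
    graph = build_graph(checkpointer)  # hypothetical: replace with the project's graph factory
    manager = LegalChatManager(graph, checkpointer)

    # Synchronous mode: the callback answers the interrupt inline.
    answer = await manager.chat(
        "Quelles sont les conditions de création d'une SARL au Bénin ?",
        session_id="demo-sync",
        interrupt_handler=lambda payload: "approve Demande légitime",
    )
    print(answer)

    # Asynchronous mode: the first call returns the approval prompt,
    # the second call resumes the graph with the moderator's decision.
    prompt = await manager.chat("Je souhaite contacter un avocat.", session_id="demo-async")
    if manager.has_pending_interrupt("demo-async"):
        print(prompt)
        answer = await manager.chat("approve Demande légitime", session_id="demo-async")
        print(answer)


if __name__ == "__main__":
    asyncio.run(_demo())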