# core/chat_manager.py
import sys
from pathlib import Path

# Make the project root importable when this module is run directly
# (must happen before the config/models/utils imports below).
sys.path.insert(0, str(Path(__file__).parent.parent))

import asyncio
import logging
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable

from langchain_core.runnables import RunnableConfig
from langchain_core.messages import BaseMessage
from langgraph.types import Command
from langgraph.errors import GraphInterrupt

from config.settings import settings
from models.state_models import MultiCountryLegalState
from utils.helpers import dict_to_message_obj

logger = logging.getLogger(__name__)
class LegalChatManager:
    """
    Chat manager with full human-in-the-loop interrupt support.

    Two interrupt-resolution modes are supported:

    * Synchronous: a callback passed to ``chat()`` resolves the interrupt
      immediately within the same call.
    * Asynchronous: the interrupt is stored per session, and the *next*
      ``chat()`` message for that session is treated as the moderator's
      decision.
    """

    def __init__(self, graph, checkpointer):
        # Compiled LangGraph graph and its checkpointer (conversation persistence).
        self.graph = graph
        self.checkpointer = checkpointer
        # Per-session bookkeeping (timestamps, counters), keyed by session_id.
        self.active_sessions = {}
        # Aggregate routing counters across all sessions.
        self.routing_stats = {
            "benin": 0,
            "madagascar": 0,
            "unclear": 0,
            "total_queries": 0
        }
        # Track pending interrupts by session
        self.pending_interrupts = {}

    @staticmethod
    def _unwrap_interrupt(raw):
        """
        Normalize a raw ``__interrupt__`` payload from a stream chunk.

        LangGraph emits interrupts as a tuple/list of ``Interrupt`` objects
        (or, depending on version, a bare object or dict). Previously this
        unwrapping logic was duplicated in ``chat`` and
        ``_handle_pending_interrupt``.

        Args:
            raw: Value of the ``"__interrupt__"`` key in a stream chunk.

        Returns:
            Tuple ``(interrupt_info, interrupt_value)`` where
            ``interrupt_info`` is the first Interrupt-like object and
            ``interrupt_value`` is its payload (``{}`` when unrecognized).
        """
        info = raw[0] if isinstance(raw, (list, tuple)) else raw
        if hasattr(info, 'value'):
            # Interrupt object: payload lives in .value
            value = info.value
        elif isinstance(info, dict):
            value = info.get("value", info)
        else:
            value = {}
        return info, value

    async def chat(
        self,
        message: str,
        session_id: str,
        legal_context: Optional[Dict[str, str]] = None,
        interrupt_handler: Optional[Callable[[Dict], str]] = None
    ) -> str:
        """
        Process a chat message with session management and interrupt handling.

        Args:
            message: User message to process.
            session_id: Unique session identifier for conversation tracking.
            legal_context: Optional legal context (jurisdiction, user type, etc.).
            interrupt_handler: Optional callback for handling interrupts
                synchronously. If provided, interrupts are resolved
                immediately within this call; if None, interrupts are stored
                and resolved by the next call for this session.

        Returns:
            Assistant's response text, or the approval prompt when an
            interrupt was stored for asynchronous resolution.

        Raises:
            RuntimeError: If the system is not initialized.
        """
        if not self.graph:
            raise RuntimeError("System not initialized. Call setup_system() first.")
        # Initialize or update session bookkeeping.
        self._initialize_session(session_id)
        # A stored interrupt means this message is the moderator's decision.
        if session_id in self.pending_interrupts:
            return await self._handle_pending_interrupt(session_id, message)
        input_state = self._prepare_input_state(message, session_id, legal_context)
        config = RunnableConfig(
            configurable={"thread_id": session_id},
            recursion_limit=100
        )
        try:
            start_time = datetime.now()
            final_response = None
            # Stream updates so interrupts can be detected in real time.
            async for chunk in self.graph.astream(
                MultiCountryLegalState(**input_state),
                config,
                stream_mode="updates"
            ):
                if "__interrupt__" in chunk:
                    raw = chunk["__interrupt__"]
                    logger.info(f"⏸️ Graph interrupted: {raw}")
                    interrupt_info, interrupt_value = self._unwrap_interrupt(raw)
                    interrupt_message = (
                        interrupt_value.get("message", "")
                        if isinstance(interrupt_value, dict) else ""
                    )
                    if interrupt_handler:
                        # SYNCHRONOUS MODE: resolve immediately via callback.
                        logger.info("📞 Calling synchronous interrupt handler")
                        moderator_response = interrupt_handler(interrupt_value)
                        logger.info(f"🔄 Resuming graph with: {moderator_response}")
                        async for resume_chunk in self.graph.astream(
                            Command(resume=moderator_response),
                            config,
                            stream_mode="updates"
                        ):
                            # NOTE(review): a nested interrupt raised during this
                            # resume is not re-dispatched to the handler — confirm
                            # the graph cannot interrupt twice in sync mode.
                            for node_name, node_output in resume_chunk.items():
                                if node_name != "__interrupt__":
                                    logger.debug(f"📦 Resume chunk from {node_name}")
                        # After resume completes, read the final state.
                        state = await self.graph.aget_state(config)
                        final_response = self._extract_response(state.values)
                        break
                    else:
                        # ASYNCHRONOUS MODE: store the interrupt and return the
                        # approval prompt; the next message resolves it.
                        logger.info("💾 Storing interrupt for later resolution")
                        self.pending_interrupts[session_id] = {
                            "type": "human_approval",
                            "config": config,
                            "created_at": datetime.now(),
                            "interrupt_data": interrupt_info
                        }
                        return interrupt_message or self._get_default_approval_prompt()
                # Normal (non-interrupt) update chunk.
                for node_name, node_output in chunk.items():
                    if node_name != "__interrupt__":
                        logger.debug(f"📦 Chunk from {node_name}")
            # If no interrupt occurred, read the final state now.
            if final_response is None:
                state = await self.graph.aget_state(config)
                final_response = self._extract_response(state.values)
            processing_time = (datetime.now() - start_time).total_seconds()
            self._update_session_stats(session_id, processing_time)
            self._update_routing_stats(final_response)
            return final_response
        except Exception as e:
            logger.exception(f"Chat error for session {session_id}")
            self._log_error(session_id, str(e))
            return f"Erreur lors du traitement: {str(e)}"

    async def _handle_pending_interrupt(self, session_id: str, message: str) -> str:
        """
        Resume a graph paused on an interrupt using ``Command(resume=...)``.

        Called when a stored interrupt is awaiting the moderator's decision.

        Args:
            session_id: Session with a pending interrupt.
            message: Moderator's response (e.g. "approve ..." or "reject ...").

        Returns:
            Final response after resuming, or a new approval prompt if the
            graph interrupted again during the resume.
        """
        interrupt_data = self.pending_interrupts.get(session_id)
        if not interrupt_data:
            return "Erreur: Aucune interruption en attente."
        try:
            logger.info(f"🔥 Resuming graph with moderator decision: {message}")
            # Reuse the stored config so we resume the same thread/checkpoint.
            config = interrupt_data["config"]
            final_response = None
            async for chunk in self.graph.astream(
                Command(resume=message),
                config,
                stream_mode="updates"
            ):
                if "__interrupt__" in chunk:
                    # The graph interrupted again during the resume — store the
                    # new interrupt and ask the moderator for another decision.
                    interrupt_info, interrupt_value = self._unwrap_interrupt(
                        chunk["__interrupt__"]
                    )
                    self.pending_interrupts[session_id] = {
                        "type": "human_approval",
                        "config": config,
                        "created_at": datetime.now(),
                        "interrupt_data": interrupt_info
                    }
                    interrupt_message = (
                        interrupt_value.get("message", "")
                        if isinstance(interrupt_value, dict) else ""
                    )
                    return interrupt_message or self._get_default_approval_prompt()
                # Normal (non-interrupt) update chunk.
                for node_name, node_output in chunk.items():
                    if node_name != "__interrupt__":
                        logger.debug(f"📦 Resume chunk from {node_name}")
            # Successful resume: read final state, clear the pending interrupt.
            state = await self.graph.aget_state(config)
            final_response = self._extract_response(state.values)
            del self.pending_interrupts[session_id]
            self._update_routing_stats(final_response)
            logger.info(f"✅ Graph resumed successfully for session {session_id}")
            return final_response
        except Exception as e:
            logger.error(f"Error resuming from interrupt: {str(e)}")
            # Drop the stale interrupt so the session is not wedged.
            if session_id in self.pending_interrupts:
                del self.pending_interrupts[session_id]
            return f"Erreur lors du traitement de la décision: {str(e)}"

    def _get_default_approval_prompt(self) -> str:
        """Default approval prompt if interrupt message extraction fails."""
        return """
🔒 **APPROBATION HUMAINE REQUISE**
Une demande d'assistance juridique nécessite votre approbation.
**Veuillez répondre avec:**
- "approve [raison]" pour approuver la demande
- "reject [raison]" pour rejeter la demande
**Exemples:**
- "approve Demande légitime de consultation"
- "reject Email invalide ou description trop vague"
**Votre décision:**
"""

    def get_checkpointer_info(self) -> Dict[str, Any]:
        """
        Describe the current checkpointer backend.

        Returns:
            Dict with ``type`` ("postgres"/"memory"/"unknown"), ``persistent``
            flag, and a human-readable ``description``.
        """
        checkpointer_type = "unknown"
        if hasattr(self.checkpointer, '__class__'):
            # Classify by class name to avoid importing backend modules here.
            class_name = self.checkpointer.__class__.__name__
            if 'PostgresSaver' in class_name:
                checkpointer_type = "postgres"
            elif 'InMemorySaver' in class_name:
                checkpointer_type = "memory"
        return {
            "type": checkpointer_type,
            "persistent": checkpointer_type == "postgres",
            "description": "Persistent storage" if checkpointer_type == "postgres" else "In-memory (volatile)"
        }

    async def get_conversation_history(self, session_id: str) -> List["BaseMessage"]:
        """
        Get conversation history for a session.

        Args:
            session_id: Session identifier.

        Returns:
            List of message objects from conversation history; empty on any
            error or when the system is not initialized.
        """
        if not self.graph:
            return []
        config = RunnableConfig(configurable={"thread_id": session_id})
        try:
            state = await self.graph.aget_state(config)
            if not state or not state.values:
                return []
            s = state.values
            if isinstance(s, dict):
                pass
            elif hasattr(s, "model_dump"):
                # Pydantic state model (e.g. MultiCountryLegalState).
                s = s.model_dump()
            else:
                s = {}
            raw_messages = s.get("messages", [])
            # Only dict-shaped messages are convertible; skip anything else.
            return [dict_to_message_obj(m) for m in raw_messages if isinstance(m, dict)]
        except Exception:
            logger.exception(f"Error getting conversation history for session {session_id}")
            return []

    def get_session_stats(self, session_id: str) -> Dict[str, Any]:
        """Get statistics for a specific session ({} when unknown)."""
        return self.active_sessions.get(session_id, {})

    def get_global_stats(self) -> Dict[str, Any]:
        """
        Get global system statistics.

        Returns:
            Dictionary with routing stats, active session count, pending
            interrupt count, and checkpointer/storage info.
        """
        stats = {
            # Copy so callers cannot mutate the internal counters.
            "routing_stats": dict(self.routing_stats),
            "active_sessions": len(self.active_sessions),
            "total_queries": self.routing_stats["total_queries"],
            "pending_interrupts": len(self.pending_interrupts)
        }
        # Merge in checkpointer info (type / persistent / description).
        stats.update(self.get_checkpointer_info())
        return stats

    def _initialize_session(self, session_id: str):
        """Create session tracking on first use, then bump its counters."""
        if session_id not in self.active_sessions:
            self.active_sessions[session_id] = {
                "created": datetime.now(),
                "query_count": 0,
                "total_processing_time": 0,
                "average_processing_time": 0,
                "detected_countries": set(),
                "last_activity": datetime.now()
            }
        session_info = self.active_sessions[session_id]
        session_info["query_count"] += 1
        session_info["last_activity"] = datetime.now()

    def _prepare_input_state(
        self,
        message: str,
        session_id: str,
        legal_context: Optional[Dict[str, str]]
    ) -> Dict[str, Any]:
        """
        Prepare the initial graph state for a new user turn.

        Args:
            message: User message.
            session_id: Session identifier.
            legal_context: Optional legal context; defaulted when missing.

        Returns:
            Dictionary with the complete input state for the graph.
        """
        # Copy the caller's context so the normalization below does not
        # mutate the caller's dict in place.
        ctx = dict(legal_context) if legal_context else {
            "jurisdiction": "Unknown",
            "user_type": "general",
            "document_type": "legal",
            "detected_country": "unknown"
        }
        if ctx.get("detected_country") is None:
            ctx["detected_country"] = "unknown"
        return {
            "messages": [{"role": "user", "content": message, "meta": {}}],
            "legal_context": ctx,
            "session_id": session_id,
            "router_decision": None,
            "search_results": None,
            "route_explanation": None,
            "last_search_query": None,
            "detected_articles": [],
        }

    def _extract_response(self, result) -> str:
        """
        Extract the latest assistant reply from a graph result.

        Args:
            result: Graph execution result (pydantic state model or dict).

        Returns:
            Assistant's response text, or a French fallback message when no
            assistant message is present.
        """
        if isinstance(result, dict):
            r = result
        elif hasattr(result, "model_dump"):
            # Pydantic state model (e.g. MultiCountryLegalState).
            r = result.model_dump()
        else:
            r = {}
        # Walk backwards to the most recent assistant/ai message; skip any
        # non-dict entries defensively (they have no .get).
        for m in reversed(r.get("messages", [])):
            if isinstance(m, dict) and (m.get("role") or "").lower() in ("assistant", "ai"):
                return m.get("content", "")
        return "Désolé, je n'ai pas pu générer de réponse."

    def _update_session_stats(self, session_id: str, processing_time: float):
        """Accumulate processing time and recompute the session average."""
        if session_id in self.active_sessions:
            session_info = self.active_sessions[session_id]
            session_info["total_processing_time"] += processing_time
            # query_count is >= 1 here: _initialize_session always bumps it.
            session_info["average_processing_time"] = (
                session_info["total_processing_time"] / session_info["query_count"]
            )

    def _update_routing_stats(self, response: str):
        """Update routing counters from country keywords in the response."""
        self.routing_stats["total_queries"] += 1
        # Guard against None so a failed response never crashes stats.
        response_lower = (response or "").lower()
        if any(keyword in response_lower for keyword in ("bénin", "béninois", "béninoise")):
            self.routing_stats["benin"] += 1
        elif any(keyword in response_lower for keyword in ("madagascar", "malgache", "malagasy")):
            self.routing_stats["madagascar"] += 1
        else:
            self.routing_stats["unclear"] += 1

    def _log_error(self, session_id: str, error: str):
        """Log an error for monitoring."""
        logger.error(f"Session {session_id}: {error}")

    def cleanup_inactive_sessions(self, max_age_hours: int = 24):
        """
        Clean up sessions that have been inactive for too long.

        Also discards any pending interrupt belonging to a removed session.

        Args:
            max_age_hours: Maximum inactivity age in hours before cleanup.
        """
        cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
        inactive_sessions = [
            session_id for session_id, info in self.active_sessions.items()
            if info["last_activity"].timestamp() < cutoff_time
        ]
        for session_id in inactive_sessions:
            if session_id in self.pending_interrupts:
                del self.pending_interrupts[session_id]
            del self.active_sessions[session_id]
            logger.info(f"Cleaned up inactive session: {session_id}")

    def has_pending_interrupt(self, session_id: str) -> bool:
        """
        Check if there's a pending interrupt for a session.

        Args:
            session_id: Session identifier.

        Returns:
            True if the session has a pending interrupt, False otherwise.
        """
        return session_id in self.pending_interrupts

    def get_system_info(self) -> Dict[str, Any]:
        """
        Get comprehensive system information.

        Returns:
            Dictionary with system status, storage info, routing statistics,
            and an ISO-format timestamp.
        """
        return {
            "system": {
                "initialized": self.graph is not None,
                "active_sessions": len(self.active_sessions),
                "pending_interrupts": len(self.pending_interrupts),
                "total_queries": self.routing_stats["total_queries"]
            },
            "storage": self.get_checkpointer_info(),
            "routing": self.routing_stats,
            "timestamp": datetime.now().isoformat()
        }