# [file name]: core/chat_manager.py
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import logging
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable
from langchain_core.runnables import RunnableConfig
from langchain_core.messages import BaseMessage
from langgraph.types import Command
from config.settings import settings
from models.state_models import MultiCountryLegalState
from utils.helpers import dict_to_message_obj
logger = logging.getLogger(__name__)
class LegalChatManager:
"""
Chat manager with full human-in-the-loop interrupt support.
Handles both synchronous (callback-based) and asynchronous (stored interrupt) modes.
"""
def __init__(self, graph, checkpointer):
self.graph = graph
self.checkpointer = checkpointer
self.active_sessions = {}
self.routing_stats = {
"benin": 0,
"madagascar": 0,
"unclear": 0,
"total_queries": 0
}
# Track pending interrupts by session
self.pending_interrupts = {}
async def chat(
self,
message: str,
session_id: str,
legal_context: Optional[Dict[str, str]] = None,
interrupt_handler: Optional[Callable[[Dict], str]] = None
) -> str:
"""
Process a chat message with session management and interrupt handling.
Args:
message: User message to process
session_id: Unique session identifier for conversation tracking
legal_context: Optional legal context (jurisdiction, user type, etc.)
interrupt_handler: Optional callback for handling interrupts synchronously.
If provided, interrupts are handled immediately within this call.
If None, interrupts are stored for later resolution via subsequent calls.
Returns:
Assistant's response text
Raises:
RuntimeError: If system is not initialized
"""
if not self.graph:
raise RuntimeError("System not initialized. Call setup_system() first.")
# Initialize or update session
self._initialize_session(session_id)
# Check if we have a pending interrupt for this session
if session_id in self.pending_interrupts:
return await self._handle_pending_interrupt(session_id, message)
# Prepare input state
input_state = self._prepare_input_state(message, session_id, legal_context)
config = RunnableConfig(
configurable={"thread_id": session_id},
recursion_limit=100
)
try:
# Track performance
start_time = datetime.now()
# 🔥 CRITICAL: Use streaming to detect interrupts in real-time
final_response = None
async for chunk in self.graph.astream(
MultiCountryLegalState(**input_state),
config,
stream_mode="updates"
):
# 🔥 Check if this chunk contains an interrupt
if "__interrupt__" in chunk:
interrupt_data = chunk["__interrupt__"]
logger.info(f"⏸️ Graph interrupted: {interrupt_data}")
# Extract interrupt info - handle tuple/Interrupt object format
# LangGraph returns interrupts as tuples containing Interrupt objects
if isinstance(interrupt_data, (list, tuple)):
interrupt_info = interrupt_data[0]
else:
interrupt_info = interrupt_data
# Handle Interrupt object (has .value attribute)
if hasattr(interrupt_info, 'value'):
interrupt_value = interrupt_info.value
elif isinstance(interrupt_info, dict):
interrupt_value = interrupt_info.get("value", interrupt_info)
else:
interrupt_value = {}
# Extract message from interrupt value
interrupt_message = interrupt_value.get("message", "") if isinstance(interrupt_value, dict) else ""
# 🔥 Two modes of operation:
# 1. Synchronous: If interrupt_handler provided, handle immediately
# 2. Asynchronous: Store interrupt and return, wait for next call
if interrupt_handler:
# SYNCHRONOUS MODE: Handle interrupt immediately
logger.info("📞 Calling synchronous interrupt handler")
moderator_response = interrupt_handler(interrupt_value)
# Resume immediately with the moderator's response
logger.info(f"🔄 Resuming graph with: {moderator_response}")
async for resume_chunk in self.graph.astream(
Command(resume=moderator_response),
config,
stream_mode="updates"
):
                            # Continue processing resumed chunks. Note: a nested
                            # interrupt raised during this synchronous resume is not
                            # re-dispatched to the handler and would stay pending.
for node_name, node_output in resume_chunk.items():
if node_name != "__interrupt__":
logger.debug(f"📦 Resume chunk from {node_name}")
# After resume completes, get final state
state = await self.graph.aget_state(config)
final_response = self._extract_response(state.values)
break
else:
# ASYNCHRONOUS MODE: Store interrupt and return
logger.info("💾 Storing interrupt for later resolution")
self.pending_interrupts[session_id] = {
"type": "human_approval",
"config": config,
"created_at": datetime.now(),
"interrupt_data": interrupt_info
}
return interrupt_message or self._get_default_approval_prompt()
# Process normal chunks (non-interrupt)
for node_name, node_output in chunk.items():
if node_name != "__interrupt__":
logger.debug(f"📦 Chunk from {node_name}")
# If no interrupt occurred, get final state
if final_response is None:
state = await self.graph.aget_state(config)
final_response = self._extract_response(state.values)
# Track performance
processing_time = (datetime.now() - start_time).total_seconds()
self._update_session_stats(session_id, processing_time)
self._update_routing_stats(final_response)
return final_response
except Exception as e:
logger.exception(f"Chat error for session {session_id}")
self._log_error(session_id, str(e))
return f"Erreur lors du traitement: {str(e)}"
async def _handle_pending_interrupt(self, session_id: str, message: str) -> str:
"""
Handle user response to a pending interrupt using Command(resume=...).
This is called when there's a stored interrupt waiting for resolution.
Args:
session_id: Session with pending interrupt
message: User's response (e.g., "approve" or "reject")
Returns:
Final response after resuming from interrupt
"""
interrupt_data = self.pending_interrupts.get(session_id)
if not interrupt_data:
return "Erreur: Aucune interruption en attente."
try:
logger.info(f"🔥 Resuming graph with moderator decision: {message}")
config = interrupt_data["config"]
# Use streaming to handle potential nested interrupts
final_response = None
async for chunk in self.graph.astream(
Command(resume=message),
config,
stream_mode="updates"
):
if "__interrupt__" in chunk:
# Another interrupt occurred during resume - store it
new_interrupt = chunk["__interrupt__"]
# Handle tuple/list format
if isinstance(new_interrupt, (list, tuple)):
interrupt_info = new_interrupt[0]
else:
interrupt_info = new_interrupt
# Extract value from Interrupt object
if hasattr(interrupt_info, 'value'):
interrupt_value = interrupt_info.value
elif isinstance(interrupt_info, dict):
interrupt_value = interrupt_info.get("value", interrupt_info)
else:
interrupt_value = {}
# Store the new interrupt
self.pending_interrupts[session_id] = {
"type": "human_approval",
"config": config,
"created_at": datetime.now(),
"interrupt_data": interrupt_info
}
interrupt_message = interrupt_value.get("message", "") if isinstance(interrupt_value, dict) else ""
return interrupt_message or self._get_default_approval_prompt()
# Process normal chunks
for node_name, node_output in chunk.items():
if node_name != "__interrupt__":
logger.debug(f"📦 Resume chunk from {node_name}")
# Get final state after successful resume
state = await self.graph.aget_state(config)
final_response = self._extract_response(state.values)
# Clean up the pending interrupt
del self.pending_interrupts[session_id]
# Update stats
self._update_routing_stats(final_response)
logger.info(f"✅ Graph resumed successfully for session {session_id}")
return final_response
except Exception as e:
logger.error(f"Error resuming from interrupt: {str(e)}")
# Clean up on error
if session_id in self.pending_interrupts:
del self.pending_interrupts[session_id]
return f"Erreur lors du traitement de la décision: {str(e)}"
def _get_default_approval_prompt(self) -> str:
"""Default approval prompt if interrupt message extraction fails"""
return """
🔒 **APPROBATION HUMAINE REQUISE**
Une demande d'assistance juridique nécessite votre approbation.
**Veuillez répondre avec:**
- "approve [raison]" pour approuver la demande
- "reject [raison]" pour rejeter la demande
**Exemples:**
- "approve Demande légitime de consultation"
- "reject Email invalide ou description trop vague"
**Votre décision:**
"""
def get_checkpointer_info(self) -> Dict[str, Any]:
"""Get information about the current checkpointer type"""
checkpointer_type = "unknown"
if hasattr(self.checkpointer, '__class__'):
class_name = self.checkpointer.__class__.__name__
if 'PostgresSaver' in class_name:
checkpointer_type = "postgres"
            elif 'MemorySaver' in class_name:  # matches both MemorySaver and InMemorySaver
checkpointer_type = "memory"
return {
"type": checkpointer_type,
"persistent": checkpointer_type == "postgres",
"description": "Persistent storage" if checkpointer_type == "postgres" else "In-memory (volatile)"
}
async def get_conversation_history(self, session_id: str) -> List[BaseMessage]:
"""
Get conversation history for a session.
Args:
session_id: Session identifier
Returns:
List of message objects from conversation history
"""
if not self.graph:
return []
config = RunnableConfig(configurable={"thread_id": session_id})
try:
state = await self.graph.aget_state(config)
if not state or not state.values:
return []
            s = state.values
            if isinstance(s, MultiCountryLegalState):
                s = s.model_dump()
            elif not isinstance(s, dict):
                s = {}
raw_messages = s.get("messages", [])
return [dict_to_message_obj(m) for m in raw_messages if isinstance(m, dict)]
except Exception as e:
logger.exception(f"Error getting conversation history for session {session_id}")
return []
def get_session_stats(self, session_id: str) -> Dict[str, Any]:
"""Get statistics for a specific session"""
return self.active_sessions.get(session_id, {})
def get_global_stats(self) -> Dict[str, Any]:
"""
Get global system statistics.
Returns:
Dictionary with routing stats, active sessions, and storage info
"""
stats = {
"routing_stats": self.routing_stats,
"active_sessions": len(self.active_sessions),
"total_queries": self.routing_stats["total_queries"],
"pending_interrupts": len(self.pending_interrupts)
}
# Add checkpointer info
stats.update(self.get_checkpointer_info())
return stats
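
    # Example shape of the returned payload (values illustrative); note that
    # get_checkpointer_info()'s keys are merged into the top level by update():
    #
    #     {
    #         "routing_stats": {"benin": 3, "madagascar": 1, "unclear": 0,
    #                           "total_queries": 4},
    #         "active_sessions": 2,
    #         "total_queries": 4,
    #         "pending_interrupts": 0,
    #         "type": "postgres",
    #         "persistent": True,
    #         "description": "Persistent storage",
    #     }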
def _initialize_session(self, session_id: str):
"""Initialize or update session tracking"""
if session_id not in self.active_sessions:
self.active_sessions[session_id] = {
"created": datetime.now(),
"query_count": 0,
"total_processing_time": 0,
"average_processing_time": 0,
"detected_countries": set(),
"last_activity": datetime.now()
}
session_info = self.active_sessions[session_id]
session_info["query_count"] += 1
session_info["last_activity"] = datetime.now()
def _prepare_input_state(
self,
message: str,
session_id: str,
legal_context: Optional[Dict[str, str]]
) -> Dict[str, Any]:
"""
Prepare input state for graph processing.
Args:
message: User message
session_id: Session identifier
legal_context: Optional legal context
Returns:
Dictionary with complete input state for graph
"""
ctx = legal_context or {
"jurisdiction": "Unknown",
"user_type": "general",
"document_type": "legal",
"detected_country": "unknown"
}
if ctx.get("detected_country") is None:
ctx["detected_country"] = "unknown"
return {
"messages": [{"role": "user", "content": message, "meta": {}}],
"legal_context": ctx,
"session_id": session_id,
"router_decision": None,
"search_results": None,
"route_explanation": None,
"last_search_query": None,
"detected_articles": [],
}
def _extract_response(self, result) -> str:
"""
Extract response text from graph result.
Args:
result: Graph execution result (state or dict)
Returns:
Assistant's response text
"""
if isinstance(result, MultiCountryLegalState):
r = result.model_dump()
elif isinstance(result, dict):
r = result
else:
r = {}
msgs = r.get("messages", [])
# Find the last assistant message
        for m in reversed(msgs):
            if not isinstance(m, dict):
                continue
            if (m.get("role") or "").lower() in ("assistant", "ai"):
                return m.get("content", "")
return "Désolé, je n'ai pas pu générer de réponse."
def _update_session_stats(self, session_id: str, processing_time: float):
"""Update session statistics with processing time"""
if session_id in self.active_sessions:
session_info = self.active_sessions[session_id]
session_info["total_processing_time"] += processing_time
session_info["average_processing_time"] = (
session_info["total_processing_time"] / session_info["query_count"]
)
def _update_routing_stats(self, response: str):
"""Update routing statistics based on response content"""
self.routing_stats["total_queries"] += 1
response_lower = response.lower()
if any(keyword in response_lower for keyword in ["bénin", "béninois", "béninoise"]):
self.routing_stats["benin"] += 1
elif any(keyword in response_lower for keyword in ["madagascar", "malgache", "malagasy"]):
self.routing_stats["madagascar"] += 1
else:
self.routing_stats["unclear"] += 1
def _log_error(self, session_id: str, error: str):
"""Log error for monitoring"""
logger.error(f"Session {session_id}: {error}")
def cleanup_inactive_sessions(self, max_age_hours: int = 24):
"""
Clean up sessions that have been inactive for too long.
Args:
max_age_hours: Maximum age in hours before cleanup
"""
cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
inactive_sessions = [
session_id for session_id, info in self.active_sessions.items()
if info["last_activity"].timestamp() < cutoff_time
]
# Also clean up pending interrupts for inactive sessions
for session_id in inactive_sessions:
if session_id in self.pending_interrupts:
del self.pending_interrupts[session_id]
del self.active_sessions[session_id]
logger.info(f"Cleaned up inactive session: {session_id}")
def has_pending_interrupt(self, session_id: str) -> bool:
"""
Check if there's a pending interrupt for a session.
Args:
session_id: Session identifier
Returns:
True if session has pending interrupt, False otherwise
"""
return session_id in self.pending_interrupts
def get_system_info(self) -> Dict[str, Any]:
"""
Get comprehensive system information.
Returns:
Dictionary with system status, storage info, and statistics
"""
return {
"system": {
"initialized": self.graph is not None,
"active_sessions": len(self.active_sessions),
"pending_interrupts": len(self.pending_interrupts),
"total_queries": self.routing_stats["total_queries"]
},
"storage": self.get_checkpointer_info(),
"routing": self.routing_stats,
"timestamp": datetime.now().isoformat()
}
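
# ---------------------------------------------------------------------------
# Illustrative smoke test for the asynchronous interrupt flow (a sketch, not
# part of the module's API). `build_graph` is a hypothetical helper standing
# in for however this project actually compiles its graph and checkpointer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        from core.graph_builder import build_graph  # hypothetical import

        graph, checkpointer = build_graph()
        manager = LegalChatManager(graph, checkpointer)

        # First call: may complete normally, or pause on a human-in-the-loop
        # interrupt and return the approval prompt.
        print(await manager.chat("Je voudrais consulter un avocat", "demo-async"))

        # If an interrupt is pending, the next message on the same session is
        # consumed by _handle_pending_interrupt as the moderator's decision.
        if manager.has_pending_interrupt("demo-async"):
            print(await manager.chat("approve Demande légitime", "demo-async"))

    asyncio.run(_demo())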