Spaces:
Sleeping
Sleeping
SAAHMATHWORKS
commited on
Commit
·
fbdfc24
1
Parent(s):
44d7656
Initial deployment: Legal Assistant application
Browse files- .gitignore +10 -0
- Dockerfile +30 -0
- api/main.py +310 -0
- api/requirements.txt +4 -0
- app.py +7 -0
- config/__init__.py +22 -0
- config/constants.py +107 -0
- config/settings.py +55 -0
- core/assistance/__init__.py +5 -0
- core/assistance/email_service.py +40 -0
- core/assistance/workflow_nodes.py +186 -0
- core/chat_manager.py +289 -0
- core/conversation_repair.py +88 -0
- core/email_tool.py +187 -0
- core/graph_builder.py +266 -0
- core/human_approval_node.py +222 -0
- core/nodes/__init__.py +14 -0
- core/nodes/base_node.py +79 -0
- core/nodes/helper_nodes.py +147 -0
- core/nodes/response_nodes.py +219 -0
- core/nodes/retrieval_nodes.py +83 -0
- core/nodes/routing_nodes.py +193 -0
- core/prompts/__init__.py +3 -0
- core/prompts/prompt_templates.py +94 -0
- core/retriever.py +386 -0
- core/router.py +238 -0
- core/routing/__init__.py +3 -0
- core/routing/routing_logic.py +158 -0
- core/system_initializer.py +103 -0
- database/__init__py +0 -0
- database/mongodb_client.py +153 -0
- database/postgres_checkpointer.py +97 -0
- generate_graph.py +66 -0
- interfaces/__init__.py +0 -0
- interfaces/monitoring.py +109 -0
- interfaces/web_interface.py +96 -0
- main.py +629 -0
- models/__init__py +0 -0
- models/state_models.py +112 -0
- requirements.txt +13 -0
- utils/__init__.py +0 -0
- utils/helpers.py +68 -0
- utils/logger.py +74 -0
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.log
|
| 4 |
+
*.tmp
|
| 5 |
+
.cache/
|
| 6 |
+
.DS_Store
|
| 7 |
+
*.pkl
|
| 8 |
+
*.bin
|
| 9 |
+
*.ipynb
|
| 10 |
+
.ipynb_checkpoints/
|
Dockerfile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system dependencies
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
gcc \
|
| 8 |
+
curl \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
# Copy requirements and install Python dependencies
|
| 12 |
+
COPY requirements.txt .
|
| 13 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 14 |
+
|
| 15 |
+
# Copy application code
|
| 16 |
+
COPY . .
|
| 17 |
+
|
| 18 |
+
# Create non-root user
|
| 19 |
+
RUN useradd -m -u 1000 user
|
| 20 |
+
USER user
|
| 21 |
+
|
| 22 |
+
# Expose port (Hugging Face uses 7860)
|
| 23 |
+
EXPOSE 7860
|
| 24 |
+
|
| 25 |
+
# Health check
|
| 26 |
+
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
|
| 27 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 28 |
+
|
| 29 |
+
# Start command
|
| 30 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
api/main.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/main.py
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from contextlib import asynccontextmanager
|
| 4 |
+
from fastapi import FastAPI, Query, HTTPException
|
| 5 |
+
from fastapi.responses import StreamingResponse, HTMLResponse
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from langchain_core.messages import AIMessageChunk
|
| 8 |
+
import json
|
| 9 |
+
from uuid import uuid4
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import asyncio
|
| 13 |
+
|
| 14 |
+
# Import your existing system
|
| 15 |
+
from core.system_initializer import setup_system
|
| 16 |
+
from models.state_models import MultiCountryLegalState
|
| 17 |
+
|
| 18 |
+
# Setup logging
|
| 19 |
+
logging.basicConfig(level=logging.INFO)
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
# Global variables
|
| 23 |
+
chat_manager = None
|
| 24 |
+
graph = None
|
| 25 |
+
system_initialized = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
async def initialize_system():
|
| 29 |
+
global chat_manager, graph, system_initialized
|
| 30 |
+
try:
|
| 31 |
+
# Check for required environment variables based on YOUR settings
|
| 32 |
+
required_vars = ['OPENAI_API_KEY', 'MONGO_URI', 'NEON_DB_URL', 'NEON_END_POINT']
|
| 33 |
+
missing_vars = [var for var in required_vars if not os.getenv(var)]
|
| 34 |
+
|
| 35 |
+
if missing_vars:
|
| 36 |
+
logger.warning(f"⚠️ Missing environment variables: {missing_vars}")
|
| 37 |
+
logger.warning("System will start but may not function properly")
|
| 38 |
+
|
| 39 |
+
system = await setup_system()
|
| 40 |
+
chat_manager = system["chat_manager"]
|
| 41 |
+
graph = system["graph"]
|
| 42 |
+
system_initialized = True
|
| 43 |
+
logger.info("✅ Legal assistant system initialized for Hugging Face")
|
| 44 |
+
except Exception as e:
|
| 45 |
+
logger.error(f"❌ Failed to initialize system: {e}")
|
| 46 |
+
system_initialized = False
|
| 47 |
+
|
| 48 |
+
@asynccontextmanager
|
| 49 |
+
async def lifespan(app: FastAPI):
|
| 50 |
+
"""Modern lifespan event handler"""
|
| 51 |
+
# Startup logic
|
| 52 |
+
logger.info("🚀 Starting Legal Assistant API...")
|
| 53 |
+
|
| 54 |
+
# Initialize system in background
|
| 55 |
+
initialization_task = asyncio.create_task(initialize_system())
|
| 56 |
+
|
| 57 |
+
yield # App runs here
|
| 58 |
+
|
| 59 |
+
# Shutdown logic
|
| 60 |
+
logger.info("🛑 Shutting down Legal Assistant API...")
|
| 61 |
+
initialization_task.cancel()
|
| 62 |
+
try:
|
| 63 |
+
await initialization_task
|
| 64 |
+
except asyncio.CancelledError:
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
app = FastAPI(
|
| 68 |
+
title="Legal Assistant API",
|
| 69 |
+
version="1.0.0",
|
| 70 |
+
description="Multi-country legal RAG system for Benin and Madagascar",
|
| 71 |
+
docs_url="/docs",
|
| 72 |
+
redoc_url="/redoc",
|
| 73 |
+
lifespan=lifespan
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Add CORS middleware
|
| 77 |
+
app.add_middleware(
|
| 78 |
+
CORSMiddleware,
|
| 79 |
+
allow_origins=["*"],
|
| 80 |
+
allow_credentials=True,
|
| 81 |
+
allow_methods=["*"],
|
| 82 |
+
allow_headers=["*"],
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
@app.get("/", response_class=HTMLResponse)
|
| 86 |
+
async def read_root():
|
| 87 |
+
"""Simple homepage for better UX"""
|
| 88 |
+
return """
|
| 89 |
+
<html>
|
| 90 |
+
<head>
|
| 91 |
+
<title>Legal Assistant API</title>
|
| 92 |
+
<style>
|
| 93 |
+
body { font-family: Arial, sans-serif; margin: 40px; }
|
| 94 |
+
.container { max-width: 800px; margin: 0 auto; }
|
| 95 |
+
.card { border: 1px solid #ddd; padding: 20px; margin: 10px 0; border-radius: 8px; }
|
| 96 |
+
.status-ready { color: green; }
|
| 97 |
+
.status-starting { color: orange; }
|
| 98 |
+
.status-error { color: red; }
|
| 99 |
+
</style>
|
| 100 |
+
</head>
|
| 101 |
+
<body>
|
| 102 |
+
<div class="container">
|
| 103 |
+
<h1>🧑⚖️ Legal Assistant API</h1>
|
| 104 |
+
<p>Multi-country legal RAG system for Benin and Madagascar</p>
|
| 105 |
+
|
| 106 |
+
<div class="card">
|
| 107 |
+
<h3>📚 Available Endpoints</h3>
|
| 108 |
+
<ul>
|
| 109 |
+
<li><a href="/docs">API Documentation</a></li>
|
| 110 |
+
<li><a href="/health">Health Check</a></li>
|
| 111 |
+
<li><strong>GET /chat</strong> - Streaming chat</li>
|
| 112 |
+
<li><strong>GET /sessions/{id}/history</strong> - Conversation history</li>
|
| 113 |
+
</ul>
|
| 114 |
+
</div>
|
| 115 |
+
|
| 116 |
+
<div class="card">
|
| 117 |
+
<h3>🔧 System Status</h3>
|
| 118 |
+
<div id="status">
|
| 119 |
+
<p>Loading system status...</p>
|
| 120 |
+
</div>
|
| 121 |
+
</div>
|
| 122 |
+
|
| 123 |
+
<script>
|
| 124 |
+
async function updateStatus() {
|
| 125 |
+
try {
|
| 126 |
+
const response = await fetch('/health');
|
| 127 |
+
const data = await response.json();
|
| 128 |
+
|
| 129 |
+
const statusEl = document.getElementById('status');
|
| 130 |
+
let statusClass = 'status-starting';
|
| 131 |
+
let statusText = '🔄 Starting...';
|
| 132 |
+
|
| 133 |
+
if (data.system_initialized) {
|
| 134 |
+
statusClass = 'status-ready';
|
| 135 |
+
statusText = '✅ System Ready';
|
| 136 |
+
} else if (data.status === 'error') {
|
| 137 |
+
statusClass = 'status-error';
|
| 138 |
+
statusText = '❌ System Error';
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
statusEl.innerHTML = `
|
| 142 |
+
<p class="${statusClass}"><strong>${statusText}</strong></p>
|
| 143 |
+
<p><strong>MongoDB:</strong> ${data.mongodb_connected ? '✅ Connected' : '❌ Disconnected'}</p>
|
| 144 |
+
<p><strong>Countries:</strong> ${data.available_countries?.join(', ') || 'Loading...'}</p>
|
| 145 |
+
<p><strong>OpenAI:</strong> ${data.openai_configured ? '✅ Configured' : '❌ Not Configured'}</p>
|
| 146 |
+
`;
|
| 147 |
+
} catch (error) {
|
| 148 |
+
document.getElementById('status').innerHTML =
|
| 149 |
+
'<p class="status-error">❌ Failed to load system status</p>';
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
updateStatus();
|
| 154 |
+
setInterval(updateStatus, 5000);
|
| 155 |
+
</script>
|
| 156 |
+
</div>
|
| 157 |
+
</body>
|
| 158 |
+
</html>
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
@app.get("/health")
|
| 162 |
+
async def health_check():
|
| 163 |
+
"""Enhanced health check with your specific environment variables"""
|
| 164 |
+
return {
|
| 165 |
+
"status": "healthy" if system_initialized else "starting",
|
| 166 |
+
"system_initialized": system_initialized,
|
| 167 |
+
"service": "Legal Assistant API",
|
| 168 |
+
"available_countries": ["benin", "madagascar"] if system_initialized else [],
|
| 169 |
+
"mongodb_connected": system_initialized and bool(os.getenv("MONGO_URI")),
|
| 170 |
+
"openai_configured": bool(os.getenv("OPENAI_API_KEY")),
|
| 171 |
+
"neon_postgres_configured": bool(os.getenv("NEON_END_POINT")),
|
| 172 |
+
"missing_variables": [var for var in ['OPENAI_API_KEY', 'MONGO_URI', 'NEON_DB_URL', 'NEON_END_POINT'] if not os.getenv(var)],
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
def serialize_ai_message_chunk(chunk):
|
| 176 |
+
"""Serialize AI message chunks for streaming"""
|
| 177 |
+
if isinstance(chunk, AIMessageChunk):
|
| 178 |
+
return chunk.content
|
| 179 |
+
else:
|
| 180 |
+
raise TypeError(
|
| 181 |
+
f"Object of type {type(chunk).__name__} is not correctly formatted for serialisation"
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
async def generate_legal_chat_responses(message: str, session_id: Optional[str] = None):
|
| 185 |
+
"""Generate streaming responses for legal chat"""
|
| 186 |
+
if not system_initialized:
|
| 187 |
+
yield f"data: {json.dumps({'type': 'error', 'message': 'System is still starting up. Please try again in a moment.'})}\n\n"
|
| 188 |
+
yield f"data: {json.dumps({'type': 'end'})}\n\n"
|
| 189 |
+
return
|
| 190 |
+
|
| 191 |
+
is_new_conversation = session_id is None
|
| 192 |
+
|
| 193 |
+
if is_new_conversation:
|
| 194 |
+
session_id = f"api_{uuid4()}"
|
| 195 |
+
logger.info(f"🆕 New conversation session: {session_id}")
|
| 196 |
+
yield f"data: {json.dumps({'type': 'session', 'session_id': session_id})}\n\n"
|
| 197 |
+
else:
|
| 198 |
+
logger.info(f"🔄 Continuing session: {session_id}")
|
| 199 |
+
|
| 200 |
+
try:
|
| 201 |
+
input_state = {
|
| 202 |
+
"messages": [{"role": "user", "content": message, "meta": {}}],
|
| 203 |
+
"legal_context": {
|
| 204 |
+
"jurisdiction": "Unknown",
|
| 205 |
+
"user_type": "general",
|
| 206 |
+
"document_type": "legal",
|
| 207 |
+
"detected_country": "unknown"
|
| 208 |
+
},
|
| 209 |
+
"session_id": session_id,
|
| 210 |
+
"router_decision": None,
|
| 211 |
+
"search_results": None,
|
| 212 |
+
"route_explanation": None,
|
| 213 |
+
"last_search_query": None,
|
| 214 |
+
"detected_articles": [],
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
config = {
|
| 218 |
+
"configurable": {
|
| 219 |
+
"thread_id": session_id
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
events = graph.astream_events(
|
| 224 |
+
MultiCountryLegalState(**input_state),
|
| 225 |
+
version="v2",
|
| 226 |
+
config=config
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
current_content = ""
|
| 230 |
+
current_node = ""
|
| 231 |
+
|
| 232 |
+
async for event in events:
|
| 233 |
+
event_type = event["event"]
|
| 234 |
+
node_name = event.get("name", "")
|
| 235 |
+
|
| 236 |
+
if node_name != current_node:
|
| 237 |
+
current_node = node_name
|
| 238 |
+
yield f"data: {json.dumps({'type': 'node_transition', 'node': node_name})}\n\n"
|
| 239 |
+
|
| 240 |
+
if event_type == "on_chat_model_stream":
|
| 241 |
+
chunk_content = serialize_ai_message_chunk(event["data"]["chunk"])
|
| 242 |
+
current_content += chunk_content
|
| 243 |
+
yield f"data: {json.dumps({'type': 'content', 'content': chunk_content})}\n\n"
|
| 244 |
+
|
| 245 |
+
elif event_type == "on_chat_model_end":
|
| 246 |
+
yield f"data: {json.dumps({'type': 'content_end'})}\n\n"
|
| 247 |
+
|
| 248 |
+
elif event_type == "on_chain_start" and "retrieval" in node_name:
|
| 249 |
+
country = node_name.replace("_retrieval", "")
|
| 250 |
+
yield f"data: {json.dumps({'type': 'search_start', 'country': country})}\n\n"
|
| 251 |
+
|
| 252 |
+
elif event_type == "on_chain_end" and "retrieval" in node_name:
|
| 253 |
+
country = node_name.replace("_retrieval", "")
|
| 254 |
+
yield f"data: {json.dumps({'type': 'search_end', 'country': country})}\n\n"
|
| 255 |
+
|
| 256 |
+
elif event_type == "on_tool_end":
|
| 257 |
+
tool_name = event["name"]
|
| 258 |
+
yield f"data: {json.dumps({'type': 'tool_complete', 'tool': tool_name})}\n\n"
|
| 259 |
+
|
| 260 |
+
elif event_type == "on_graph_end":
|
| 261 |
+
yield f"data: {json.dumps({'type': 'graph_end'})}\n\n"
|
| 262 |
+
|
| 263 |
+
except Exception as e:
|
| 264 |
+
logger.error(f"Error in streaming: {e}")
|
| 265 |
+
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
| 266 |
+
|
| 267 |
+
yield f"data: {json.dumps({'type': 'end'})}\n\n"
|
| 268 |
+
|
| 269 |
+
@app.get("/chat")
|
| 270 |
+
async def chat_stream(
|
| 271 |
+
message: str = Query(..., description="User message"),
|
| 272 |
+
session_id: Optional[str] = Query(None, description="Existing session ID")
|
| 273 |
+
):
|
| 274 |
+
"""Streaming chat endpoint with initialization check"""
|
| 275 |
+
if not system_initialized:
|
| 276 |
+
raise HTTPException(
|
| 277 |
+
status_code=503,
|
| 278 |
+
detail="System is still starting up. Please try again in a moment."
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
return StreamingResponse(
|
| 282 |
+
generate_legal_chat_responses(message, session_id),
|
| 283 |
+
media_type="text/event-stream",
|
| 284 |
+
headers={
|
| 285 |
+
"Cache-Control": "no-cache",
|
| 286 |
+
"Connection": "keep-alive",
|
| 287 |
+
}
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
@app.get("/sessions/{session_id}/history")
|
| 291 |
+
async def get_conversation_history(session_id: str):
|
| 292 |
+
"""Get conversation history for a session"""
|
| 293 |
+
if not chat_manager:
|
| 294 |
+
return {"error": "System not initialized"}
|
| 295 |
+
|
| 296 |
+
try:
|
| 297 |
+
history = await chat_manager.get_conversation_history(session_id)
|
| 298 |
+
return {
|
| 299 |
+
"session_id": session_id,
|
| 300 |
+
"history": [
|
| 301 |
+
{
|
| 302 |
+
"role": msg.role if hasattr(msg, 'role') else msg.get('role', 'unknown'),
|
| 303 |
+
"content": msg.content if hasattr(msg, 'content') else msg.get('content', ''),
|
| 304 |
+
"timestamp": getattr(msg, 'timestamp', None)
|
| 305 |
+
}
|
| 306 |
+
for msg in history
|
| 307 |
+
]
|
| 308 |
+
}
|
| 309 |
+
except Exception as e:
|
| 310 |
+
return {"error": str(e)}
|
api/requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.104.1
|
| 2 |
+
uvicorn==0.24.0
|
| 3 |
+
python-multipart==0.0.6
|
| 4 |
+
pydantic==2.5.0
|
app.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py - Main entry point for Hugging Face Spaces
|
| 2 |
+
from api.main import app
|
| 3 |
+
|
| 4 |
+
# Hugging Face Spaces will automatically use this 'app' variable
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
import uvicorn
|
| 7 |
+
uvicorn.run(app, host="0.0.0.0", port=7860) # HF uses port 7860
|
config/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config/__init__.py
|
| 2 |
+
from .settings import settings
|
| 3 |
+
from .constants import (
|
| 4 |
+
COUNTRY_PATTERNS,
|
| 5 |
+
ARTICLE_PATTERNS,
|
| 6 |
+
CATEGORY_KEYWORDS,
|
| 7 |
+
DOCUMENT_TYPE_KEYWORDS,
|
| 8 |
+
DOCUMENT_TYPE_DESCRIPTIONS,
|
| 9 |
+
LEGAL_CONTEXTS,
|
| 10 |
+
USER_TYPE_CONTEXTS
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
'settings',
|
| 15 |
+
'COUNTRY_PATTERNS',
|
| 16 |
+
'ARTICLE_PATTERNS',
|
| 17 |
+
'CATEGORY_KEYWORDS',
|
| 18 |
+
'DOCUMENT_TYPE_KEYWORDS',
|
| 19 |
+
'DOCUMENT_TYPE_DESCRIPTIONS',
|
| 20 |
+
'LEGAL_CONTEXTS',
|
| 21 |
+
'USER_TYPE_CONTEXTS'
|
| 22 |
+
]
|
config/constants.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
|
| 3 |
+
# Country patterns for routing
|
| 4 |
+
COUNTRY_PATTERNS = {
|
| 5 |
+
"benin": [
|
| 6 |
+
r"\bbénin\b", r"\bbeninois\b", r"\bbéninoise\b", r"\bbenin\b",
|
| 7 |
+
r"\bdahomey\b", r"\bporto-novo\b", r"\bcotonou\b",
|
| 8 |
+
r"\bdroit béninois\b", r"\bloi béninoise\b"
|
| 9 |
+
],
|
| 10 |
+
"madagascar": [
|
| 11 |
+
r"\bmadagascar\b", r"\bmalgache\b", r"\bmalagasy\b",
|
| 12 |
+
r"\bantananarivo\b", r"\bmadagasikara\b",
|
| 13 |
+
r"\bdroit malgache\b", r"\bloi malgache\b"
|
| 14 |
+
]
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
# Article detection patterns
|
| 18 |
+
ARTICLE_PATTERNS = [
|
| 19 |
+
r"article[s]?\s+(\d+(?:\s+(?:et|à|\-)\s+\d+)*)",
|
| 20 |
+
r"art\.?\s*(\d+(?:\s+(?:et|à|\-)\s+\d+)*)",
|
| 21 |
+
r"articles?\s+(\d+)\s*à\s*(\d+)",
|
| 22 |
+
r"art\.?\s*(\d+)\s*au\s*(\d+)",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
# Legal domain categories
|
| 26 |
+
CATEGORY_KEYWORDS = {
|
| 27 |
+
"mariage": "Code des personnes et de la famille",
|
| 28 |
+
"divorce": "Code des personnes et de la famille",
|
| 29 |
+
"héritage": "Code des personnes et de la famille",
|
| 30 |
+
"succession": "Code des personnes et de la famille",
|
| 31 |
+
"adoption": "Code des personnes et de la famille",
|
| 32 |
+
"enfant": "Code des personnes et de la famille",
|
| 33 |
+
"pension": "Code des personnes et de la famille",
|
| 34 |
+
|
| 35 |
+
"infraction": "Droit pénal",
|
| 36 |
+
"délit": "Droit pénal",
|
| 37 |
+
"crime": "Droit pénal",
|
| 38 |
+
"peine": "Droit pénal",
|
| 39 |
+
"prison": "Droit pénal",
|
| 40 |
+
|
| 41 |
+
"entreprise": "Droit commercial",
|
| 42 |
+
"commerce": "Droit commercial",
|
| 43 |
+
"contrat": "Droit commercial",
|
| 44 |
+
"société": "Droit commercial",
|
| 45 |
+
|
| 46 |
+
"administration": "Droit administratif",
|
| 47 |
+
"fonctionnaire": "Droit administratif",
|
| 48 |
+
"service public": "Droit administratif"
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
# Document type detection keywords
|
| 52 |
+
DOCUMENT_TYPE_KEYWORDS = {
|
| 53 |
+
"case_study": [
|
| 54 |
+
"jurisprudence", "arrêt", "décision", "tribunal", "cours", "jugement",
|
| 55 |
+
"affaire", "procès", "litige", "contentieux", "précédent", "cas",
|
| 56 |
+
"cour d'appel", "cour suprême", "conseil d'état"
|
| 57 |
+
],
|
| 58 |
+
"articles": [
|
| 59 |
+
"article", "loi", "code", "décret", "texte", "disposition",
|
| 60 |
+
"règlement", "ordonnance", "prescription", "norme"
|
| 61 |
+
]
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
# Document type descriptions
|
| 65 |
+
DOCUMENT_TYPE_DESCRIPTIONS = {
|
| 66 |
+
"articles": "Textes législatifs et réglementaires (lois, codes, décrets)",
|
| 67 |
+
"case_study": "Jurisprudence et décisions de justice (arrêts, jugements)"
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
# Legal context templates
|
| 71 |
+
LEGAL_CONTEXTS = {
|
| 72 |
+
"benin": {
|
| 73 |
+
"jurisdiction": "Bénin",
|
| 74 |
+
"user_type": "citizen",
|
| 75 |
+
"document_type": "Code des personnes et de la famille",
|
| 76 |
+
"language": "français",
|
| 77 |
+
"legal_system": "civil_law"
|
| 78 |
+
},
|
| 79 |
+
"madagascar": {
|
| 80 |
+
"jurisdiction": "Madagascar",
|
| 81 |
+
"user_type": "citizen",
|
| 82 |
+
"document_type": "legal",
|
| 83 |
+
"language": "français",
|
| 84 |
+
"legal_system": "mixed_civil_customary"
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
# User type contexts
|
| 89 |
+
USER_TYPE_CONTEXTS = {
|
| 90 |
+
"citizen": {
|
| 91 |
+
"expertise_level": "basic",
|
| 92 |
+
"response_style": "accessible",
|
| 93 |
+
"include_procedures": True
|
| 94 |
+
},
|
| 95 |
+
"lawyer": {
|
| 96 |
+
"expertise_level": "advanced",
|
| 97 |
+
"response_style": "technical",
|
| 98 |
+
"include_precedents": True
|
| 99 |
+
},
|
| 100 |
+
"student": {
|
| 101 |
+
"expertise_level": "intermediate",
|
| 102 |
+
"response_style": "educational",
|
| 103 |
+
"include_examples": True
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
# LAW_KEYWORDS a été supprimé comme demandé - le filtre "titre" n'est plus utilisé
|
config/settings.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
# Change to:
|
| 5 |
+
try:
|
| 6 |
+
load_dotenv("../.env", override=True)
|
| 7 |
+
except:
|
| 8 |
+
pass # Ignore if .env file doesn't exist (like on Hugging Face)
|
| 9 |
+
|
| 10 |
+
class Settings:
|
| 11 |
+
# API Keys
|
| 12 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 13 |
+
MONGO_URI = os.environ.get("MONGO_URI")
|
| 14 |
+
NEON_DB_URL = os.environ.get("NEON_DB_URL")
|
| 15 |
+
NEON_END_POINT = os.getenv("NEON_END_POINT")
|
| 16 |
+
|
| 17 |
+
# Database
|
| 18 |
+
DATABASE_URL = NEON_END_POINT
|
| 19 |
+
|
| 20 |
+
# Model Configurations
|
| 21 |
+
EMBEDDING_MODEL = "text-embedding-ada-002"
|
| 22 |
+
CHAT_MODEL = "gpt-4o-mini"
|
| 23 |
+
CHAT_MODEL_2 = "gpt-3.5-turbo"
|
| 24 |
+
CHAT_TEMPERATURE = 0.1
|
| 25 |
+
CHAT_MAX_TOKENS = 2000
|
| 26 |
+
|
| 27 |
+
# Vector Search
|
| 28 |
+
VECTOR_INDEX_NAME = "vector_index"
|
| 29 |
+
TEXT_KEY = "contenu"
|
| 30 |
+
EMBEDDING_KEY = "vecteur_embedding"
|
| 31 |
+
|
| 32 |
+
# Collections
|
| 33 |
+
BENIN_COLLECTION = "legal_documents"
|
| 34 |
+
MADAGASCAR_COLLECTION = "legal_documents_madagascar"
|
| 35 |
+
DATABASE_NAME = "legal_db"
|
| 36 |
+
|
| 37 |
+
# Search Parameters
|
| 38 |
+
MAX_SEARCH_RESULTS = 10
|
| 39 |
+
MAX_CONVERSATION_HISTORY = 8
|
| 40 |
+
|
| 41 |
+
def validate(self):
|
| 42 |
+
missing = []
|
| 43 |
+
if not self.OPENAI_API_KEY:
|
| 44 |
+
missing.append("OPENAI_API_KEY")
|
| 45 |
+
if not self.MONGO_URI:
|
| 46 |
+
missing.append("MONGO_URI")
|
| 47 |
+
if not self.NEON_DB_URL:
|
| 48 |
+
missing.append("NEON_DB_URL")
|
| 49 |
+
if not self.NEON_END_POINT:
|
| 50 |
+
missing.append("NEON_END_POINT")
|
| 51 |
+
|
| 52 |
+
if missing:
|
| 53 |
+
raise ValueError(f"Missing environment variables: {', '.join(missing)}")
|
| 54 |
+
|
| 55 |
+
settings = Settings()
|
core/assistance/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/assistance/__init__.py
|
| 2 |
+
from .workflow_nodes import AssistanceWorkflowNodes
|
| 3 |
+
from .email_service import AssistanceEmailService
|
| 4 |
+
|
| 5 |
+
__all__ = ["AssistanceWorkflowNodes", "AssistanceEmailService"]
|
core/assistance/email_service.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/assistance/email_service.py
|
| 2 |
+
"""
|
| 3 |
+
Wrapper for email functionality - provides a consistent interface
|
| 4 |
+
"""
|
| 5 |
+
import re
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Optional, Dict
|
| 8 |
+
from core.email_tool import LegalAssistanceEmailer
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AssistanceEmailService:
|
| 14 |
+
"""Service wrapper for email operations"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self.emailer = LegalAssistanceEmailer()
|
| 18 |
+
|
| 19 |
+
def extract_email_from_text(self, text: str) -> Optional[str]:
|
| 20 |
+
"""Extract email from text"""
|
| 21 |
+
return self.emailer.extract_email_from_text(text)
|
| 22 |
+
|
| 23 |
+
def validate_email(self, email: str) -> bool:
|
| 24 |
+
"""Validate email format"""
|
| 25 |
+
return self.emailer.validate_email(email)
|
| 26 |
+
|
| 27 |
+
def send_assistance_request(
|
| 28 |
+
self,
|
| 29 |
+
user_email: str,
|
| 30 |
+
user_query: str,
|
| 31 |
+
assistance_description: str,
|
| 32 |
+
country: str
|
| 33 |
+
) -> Dict:
|
| 34 |
+
"""Send assistance request emails"""
|
| 35 |
+
return self.emailer.send_assistance_request(
|
| 36 |
+
user_email=user_email,
|
| 37 |
+
user_query=user_query,
|
| 38 |
+
assistance_description=assistance_description,
|
| 39 |
+
country=country
|
| 40 |
+
)
|
core/assistance/workflow_nodes.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# core/assistance/workflow_nodes.py
|
| 2 |
+
import logging
|
| 3 |
+
import re
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
+
from langchain_core.runnables import RunnableConfig
|
| 6 |
+
from models.state_models import MultiCountryLegalState
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class AssistanceWorkflowNodes:
|
| 11 |
+
def __init__(self):
|
| 12 |
+
self.email_pattern = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
|
| 13 |
+
|
| 14 |
+
async def collect_assistance_info_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 15 |
+
"""Collect assistance information (email, description)"""
|
| 16 |
+
s = state.model_dump()
|
| 17 |
+
assistance_step = s.get("assistance_step", "collecting_email")
|
| 18 |
+
user_input = s.get("messages", [{}])[-1].get("content", "") if s.get("messages") else ""
|
| 19 |
+
|
| 20 |
+
logger.info(f"📝 Collecting assistance info - step: {assistance_step}")
|
| 21 |
+
logger.debug(f"User input: {user_input}")
|
| 22 |
+
# 🔥 NEW: Check for cancellation commands
|
| 23 |
+
cancellation_keywords = ["annuler", "cancel", "stop", "arrêter", "je ne veux plus", "plus besoin", "abandonner"]
|
| 24 |
+
if any(keyword in user_input for keyword in cancellation_keywords):
|
| 25 |
+
logger.info("🚫 User requested cancellation of assistance workflow")
|
| 26 |
+
return {
|
| 27 |
+
"assistance_step": "cancelled",
|
| 28 |
+
"assistance_requested": False,
|
| 29 |
+
"user_email": None,
|
| 30 |
+
"assistance_description": None,
|
| 31 |
+
"messages": [{
|
| 32 |
+
"role": "assistant",
|
| 33 |
+
"content": "✅ Votre demande d'assistance a été annulée. Comment puis-je vous aider autrement ?",
|
| 34 |
+
"meta": {"assistance_cancelled": True}
|
| 35 |
+
}]
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
if assistance_step == "collecting_email":
|
| 39 |
+
if not user_input:
|
| 40 |
+
logger.info(f"ℹ️ Waiting for email input")
|
| 41 |
+
return {
|
| 42 |
+
"assistance_step": "collecting_email",
|
| 43 |
+
"messages": [] # Response node will generate the message
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
if self.email_pattern.match(user_input):
|
| 47 |
+
logger.info(f"📧 Email collected: {user_input}")
|
| 48 |
+
return {
|
| 49 |
+
"assistance_step": "collecting_description",
|
| 50 |
+
"user_email": user_input,
|
| 51 |
+
"assistance_requested": True,
|
| 52 |
+
"messages": [] # Response node will generate the message
|
| 53 |
+
}
|
| 54 |
+
else:
|
| 55 |
+
logger.warning(f"Invalid email: {user_input}")
|
| 56 |
+
return {
|
| 57 |
+
"assistance_step": "collecting_email",
|
| 58 |
+
"messages": [{
|
| 59 |
+
"role": "assistant",
|
| 60 |
+
"content": """⚠️ L'adresse email fournie semble invalide. Veuillez fournir une adresse email valide.
|
| 61 |
+
|
| 62 |
+
📧 **Veuillez me fournir votre adresse email :**""",
|
| 63 |
+
"meta": {"assistance_step": "collecting_email"}
|
| 64 |
+
}]
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
elif assistance_step == "collecting_description":
|
| 68 |
+
if not user_input or len(user_input.strip()) < 10:
|
| 69 |
+
logger.info(f"ℹ️ Waiting for description input")
|
| 70 |
+
return {
|
| 71 |
+
"assistance_step": "collecting_description",
|
| 72 |
+
"messages": [] # Response node will generate the message
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
# Detect country from the description
|
| 76 |
+
detected_country = MultiCountryLegalState.detect_country(user_input)
|
| 77 |
+
|
| 78 |
+
logger.info(f"📝 Description collected: {user_input[:50]}...")
|
| 79 |
+
logger.info(f"🌍 Detected country: {detected_country}")
|
| 80 |
+
|
| 81 |
+
# Return the update - move to confirmation step
|
| 82 |
+
return {
|
| 83 |
+
"assistance_description": user_input,
|
| 84 |
+
"assistance_step": "confirming_send",
|
| 85 |
+
"country": detected_country,
|
| 86 |
+
"legal_context": {
|
| 87 |
+
**state.legal_context,
|
| 88 |
+
"detected_country": detected_country
|
| 89 |
+
},
|
| 90 |
+
"messages": [] # Response node will generate the confirmation message
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
return {}
|
| 94 |
+
|
| 95 |
+
async def confirm_assistance_send_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 96 |
+
"""Confirm assistance request before sending to legal team"""
|
| 97 |
+
s = state.model_dump()
|
| 98 |
+
user_input = s.get("messages", [{}])[-1].get("content", "").lower().strip() if s.get("messages") else ""
|
| 99 |
+
|
| 100 |
+
logger.info(f"✅ Confirmation node - user input: {user_input[:30]}...")
|
| 101 |
+
|
| 102 |
+
if user_input in ["oui", "yes", "ok", "confirmer"]:
|
| 103 |
+
logger.info("✅ User confirmed assistance request")
|
| 104 |
+
return {
|
| 105 |
+
"assistance_step": "confirmed",
|
| 106 |
+
"messages": [] # Let response node or approval node handle the message
|
| 107 |
+
}
|
| 108 |
+
elif user_input in ["non", "no", "cancel", "annuler"]:
|
| 109 |
+
logger.info("❌ User cancelled assistance request")
|
| 110 |
+
return {
|
| 111 |
+
"assistance_step": "cancelled",
|
| 112 |
+
"assistance_requested": False,
|
| 113 |
+
"messages": [{
|
| 114 |
+
"role": "assistant",
|
| 115 |
+
"content": """❌ Votre demande a été annulée.
|
| 116 |
+
|
| 117 |
+
Si vous changez d'avis, vous pouvez relancer une demande en disant "Je veux parler à un avocat".""",
|
| 118 |
+
"meta": {"assistance_step": "cancelled"}
|
| 119 |
+
}]
|
| 120 |
+
}
|
| 121 |
+
else:
|
| 122 |
+
logger.info("ℹ️ Awaiting valid confirmation")
|
| 123 |
+
return {
|
| 124 |
+
"assistance_step": "confirming_send",
|
| 125 |
+
"messages": [{
|
| 126 |
+
"role": "assistant",
|
| 127 |
+
"content": f"""⚠️ Veuillez confirmer avec "oui" ou "non".
|
| 128 |
+
|
| 129 |
+
📋 **RÉCAPITULATIF DE VOTRE DEMANDE :**
|
| 130 |
+
|
| 131 |
+
📧 **Email :** {s.get("user_email")}
|
| 132 |
+
📝 **Description :** {s.get("assistance_description")}
|
| 133 |
+
|
| 134 |
+
✅ **Confirmez-vous l'envoi de cette demande à notre équipe juridique ?**
|
| 135 |
+
|
| 136 |
+
Répondez par :
|
| 137 |
+
- **"oui"** pour confirmer et envoyer
|
| 138 |
+
- **"non"** pour annuler et modifier""",
|
| 139 |
+
"meta": {"assistance_step": "confirming_send"}
|
| 140 |
+
}]
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
def route_assistance(self, state: MultiCountryLegalState) -> str:
|
| 144 |
+
"""Route assistance workflow based on current state"""
|
| 145 |
+
s = state.model_dump()
|
| 146 |
+
assistance_step = s.get("assistance_step", "collecting_email")
|
| 147 |
+
|
| 148 |
+
logger.info(f"📋 Assistance step: {assistance_step}")
|
| 149 |
+
logger.info(f" - Has email: {s.get('user_email') is not None} ({s.get('user_email')})")
|
| 150 |
+
logger.info(f" - Has description: {s.get('assistance_description') is not None} ({s.get('assistance_description')})")
|
| 151 |
+
|
| 152 |
+
if assistance_step == "collecting_email" and not s.get("user_email"):
|
| 153 |
+
logger.info("→ Routing to: need_email (waiting for email)")
|
| 154 |
+
return "need_email"
|
| 155 |
+
elif assistance_step == "collecting_description" and not s.get("assistance_description"):
|
| 156 |
+
logger.info("→ Routing to: need_description (waiting for description)")
|
| 157 |
+
return "need_description"
|
| 158 |
+
elif assistance_step == "confirming_send" and s.get("user_email") and s.get("assistance_description"):
|
| 159 |
+
logger.info("→ Routing to: ready_to_confirm (awaiting user confirmation)")
|
| 160 |
+
return "ready_to_confirm"
|
| 161 |
+
elif assistance_step == "cancelled":
|
| 162 |
+
logger.info("→ Routing to: cancelled")
|
| 163 |
+
return "cancelled"
|
| 164 |
+
|
| 165 |
+
logger.info("→ Routing to: need_email (default)")
|
| 166 |
+
return "need_email"
|
| 167 |
+
|
| 168 |
+
def route_after_confirmation(self, state: MultiCountryLegalState) -> str:
|
| 169 |
+
"""Route after confirmation step"""
|
| 170 |
+
s = state.model_dump()
|
| 171 |
+
assistance_step = s.get("assistance_step")
|
| 172 |
+
last_message = s.get("messages", [{}])[-1] if s.get("messages") else {}
|
| 173 |
+
user_input = last_message.get("content", "").lower().strip() if last_message.get("role") == "user" else ""
|
| 174 |
+
|
| 175 |
+
logger.info(f"📋 Confirmation step: {assistance_step}")
|
| 176 |
+
logger.info(f" - Last user message: '{user_input}'")
|
| 177 |
+
|
| 178 |
+
if assistance_step == "confirmed":
|
| 179 |
+
logger.info("→ Routing to: confirmed (human approval)")
|
| 180 |
+
return "confirmed"
|
| 181 |
+
elif assistance_step == "cancelled":
|
| 182 |
+
logger.info("→ Routing to: cancelled")
|
| 183 |
+
return "cancelled"
|
| 184 |
+
else:
|
| 185 |
+
logger.info("→ Routing to: needs_correction (need clarification)")
|
| 186 |
+
return "needs_correction"
|
core/chat_manager.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/chat_manager.py
|
| 2 |
+
import asyncio
|
| 3 |
+
import logging
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import Dict, List, Optional
|
| 6 |
+
from langchain_core.runnables import RunnableConfig
|
| 7 |
+
from langchain_core.messages import BaseMessage
|
| 8 |
+
from langgraph.types import Command
|
| 9 |
+
|
| 10 |
+
from config.settings import settings
|
| 11 |
+
from models.state_models import MultiCountryLegalState
|
| 12 |
+
from utils.helpers import dict_to_message_obj
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
class LegalChatManager:
|
| 17 |
+
def __init__(self, graph, checkpointer):
|
| 18 |
+
self.graph = graph
|
| 19 |
+
self.checkpointer = checkpointer
|
| 20 |
+
self.active_sessions = {}
|
| 21 |
+
self.routing_stats = {
|
| 22 |
+
"benin": 0,
|
| 23 |
+
"madagascar": 0,
|
| 24 |
+
"unclear": 0,
|
| 25 |
+
"total_queries": 0
|
| 26 |
+
}
|
| 27 |
+
# Track pending interrupts by session
|
| 28 |
+
self.pending_interrupts = {}
|
| 29 |
+
|
| 30 |
+
async def chat(self, message: str, session_id: str,
|
| 31 |
+
legal_context: Optional[Dict[str, str]] = None) -> str:
|
| 32 |
+
"""Process a chat message with session management and interrupt handling"""
|
| 33 |
+
if not self.graph:
|
| 34 |
+
raise RuntimeError("System not initialized. Call setup_system() first.")
|
| 35 |
+
|
| 36 |
+
# Initialize or update session
|
| 37 |
+
self._initialize_session(session_id)
|
| 38 |
+
|
| 39 |
+
# Check if we have a pending interrupt for this session
|
| 40 |
+
if session_id in self.pending_interrupts:
|
| 41 |
+
return await self._handle_pending_interrupt(session_id, message)
|
| 42 |
+
|
| 43 |
+
# Prepare input state
|
| 44 |
+
input_state = self._prepare_input_state(message, session_id, legal_context)
|
| 45 |
+
config = RunnableConfig(
|
| 46 |
+
configurable={"thread_id": session_id},
|
| 47 |
+
recursion_limit=100
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
# Track performance
|
| 52 |
+
start_time = datetime.now()
|
| 53 |
+
|
| 54 |
+
# Process through graph
|
| 55 |
+
result = await self.graph.ainvoke(MultiCountryLegalState(**input_state), config)
|
| 56 |
+
|
| 57 |
+
# Check for interrupt
|
| 58 |
+
state_snapshot = await self.graph.aget_state(config)
|
| 59 |
+
if state_snapshot and state_snapshot.next:
|
| 60 |
+
# Graph is paused at an interrupt
|
| 61 |
+
logger.info(f"⏸️ Graph interrupted at: {state_snapshot.next}")
|
| 62 |
+
self.pending_interrupts[session_id] = {
|
| 63 |
+
"type": "human_approval",
|
| 64 |
+
"config": config,
|
| 65 |
+
"created_at": datetime.now(),
|
| 66 |
+
"paused_at": state_snapshot.next
|
| 67 |
+
}
|
| 68 |
+
return self._get_approval_prompt_message(result)
|
| 69 |
+
|
| 70 |
+
# Track performance
|
| 71 |
+
processing_time = (datetime.now() - start_time).total_seconds()
|
| 72 |
+
self._update_session_stats(session_id, processing_time)
|
| 73 |
+
|
| 74 |
+
# Extract and return response
|
| 75 |
+
response = self._extract_response(result)
|
| 76 |
+
self._update_routing_stats(response)
|
| 77 |
+
|
| 78 |
+
return response
|
| 79 |
+
|
| 80 |
+
except Exception as e:
|
| 81 |
+
logger.exception(f"Chat error for session {session_id}")
|
| 82 |
+
self._log_error(session_id, str(e))
|
| 83 |
+
return f"Erreur lors du traitement: {str(e)}"
|
| 84 |
+
|
| 85 |
+
async def _handle_pending_interrupt(self, session_id: str, message: str) -> str:
|
| 86 |
+
"""Handle user response to a pending interrupt using Command(resume=...)"""
|
| 87 |
+
interrupt_data = self.pending_interrupts.get(session_id)
|
| 88 |
+
if not interrupt_data:
|
| 89 |
+
return "Erreur: Aucune interruption en attente."
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
logger.info(f"📥 Resuming graph with moderator decision: {message}")
|
| 93 |
+
|
| 94 |
+
config = interrupt_data["config"]
|
| 95 |
+
|
| 96 |
+
# CRITICAL FIX: Use Command(resume=...) to properly resume from interrupt
|
| 97 |
+
# This sends the user's message back to the interrupt() call
|
| 98 |
+
result = await self.graph.ainvoke(
|
| 99 |
+
Command(resume=message),
|
| 100 |
+
config
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Clean up the pending interrupt
|
| 104 |
+
del self.pending_interrupts[session_id]
|
| 105 |
+
|
| 106 |
+
# Extract and return final response
|
| 107 |
+
response = self._extract_response(result)
|
| 108 |
+
self._update_routing_stats(response)
|
| 109 |
+
|
| 110 |
+
logger.info(f"✅ Graph resumed successfully for session {session_id}")
|
| 111 |
+
return response
|
| 112 |
+
|
| 113 |
+
except Exception as e:
|
| 114 |
+
logger.error(f"Error resuming from interrupt: {str(e)}")
|
| 115 |
+
# Clean up on error
|
| 116 |
+
if session_id in self.pending_interrupts:
|
| 117 |
+
del self.pending_interrupts[session_id]
|
| 118 |
+
return f"Erreur lors du traitement de la décision: {str(e)}"
|
| 119 |
+
|
| 120 |
+
def _get_approval_prompt_message(self, state) -> str:
|
| 121 |
+
"""Generate message asking for human approval"""
|
| 122 |
+
# Extract metadata from state
|
| 123 |
+
if isinstance(state, MultiCountryLegalState):
|
| 124 |
+
state_dict = state.model_dump()
|
| 125 |
+
elif isinstance(state, dict):
|
| 126 |
+
state_dict = state
|
| 127 |
+
else:
|
| 128 |
+
state_dict = {}
|
| 129 |
+
|
| 130 |
+
user_email = state_dict.get("user_email", "Non spécifié")
|
| 131 |
+
country = state_dict.get("legal_context", {}).get("detected_country", "Non spécifié")
|
| 132 |
+
description = state_dict.get("assistance_description", "Non spécifié")
|
| 133 |
+
|
| 134 |
+
return f"""
|
| 135 |
+
🔒 **APPROBATION HUMAINE REQUISE**
|
| 136 |
+
|
| 137 |
+
📧 **Utilisateur**: {user_email}
|
| 138 |
+
🌍 **Pays**: {country}
|
| 139 |
+
📝 **Description**: {description}
|
| 140 |
+
|
| 141 |
+
**Veuillez répondre avec:**
|
| 142 |
+
- "approve [raison]" pour approuver la demande
|
| 143 |
+
- "reject [raison]" pour rejeter la demande
|
| 144 |
+
|
| 145 |
+
**Exemples:**
|
| 146 |
+
- "approve Demande légitime de consultation"
|
| 147 |
+
- "reject Email invalide ou description trop vague"
|
| 148 |
+
|
| 149 |
+
**Votre décision:**
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
# === EXISTING METHODS (unchanged) ===
|
| 153 |
+
|
| 154 |
+
async def get_conversation_history(self, session_id: str) -> List[BaseMessage]:
|
| 155 |
+
"""Get conversation history for a session"""
|
| 156 |
+
if not self.graph:
|
| 157 |
+
return []
|
| 158 |
+
|
| 159 |
+
config = RunnableConfig(configurable={"thread_id": session_id})
|
| 160 |
+
|
| 161 |
+
try:
|
| 162 |
+
state = await self.graph.aget_state(config)
|
| 163 |
+
if not state or not state.values:
|
| 164 |
+
return []
|
| 165 |
+
|
| 166 |
+
s = state.values
|
| 167 |
+
if isinstance(s, MultiCountryLegalState):
|
| 168 |
+
s = s.model_dump()
|
| 169 |
+
elif isinstance(s, dict):
|
| 170 |
+
pass
|
| 171 |
+
else:
|
| 172 |
+
s = {}
|
| 173 |
+
|
| 174 |
+
raw_messages = s.get("messages", [])
|
| 175 |
+
return [dict_to_message_obj(m) for m in raw_messages if isinstance(m, dict)]
|
| 176 |
+
|
| 177 |
+
except Exception as e:
|
| 178 |
+
logger.exception(f"Error getting conversation history for session {session_id}")
|
| 179 |
+
return []
|
| 180 |
+
|
| 181 |
+
def get_session_stats(self, session_id: str) -> Dict:
|
| 182 |
+
"""Get statistics for a specific session"""
|
| 183 |
+
return self.active_sessions.get(session_id, {})
|
| 184 |
+
|
| 185 |
+
def get_global_stats(self) -> Dict:
|
| 186 |
+
"""Get global system statistics"""
|
| 187 |
+
return {
|
| 188 |
+
"routing_stats": self.routing_stats,
|
| 189 |
+
"active_sessions": len(self.active_sessions),
|
| 190 |
+
"total_queries": self.routing_stats["total_queries"],
|
| 191 |
+
"pending_interrupts": len(self.pending_interrupts)
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
def _initialize_session(self, session_id: str):
|
| 195 |
+
"""Initialize or update session tracking"""
|
| 196 |
+
if session_id not in self.active_sessions:
|
| 197 |
+
self.active_sessions[session_id] = {
|
| 198 |
+
"created": datetime.now(),
|
| 199 |
+
"query_count": 0,
|
| 200 |
+
"total_processing_time": 0,
|
| 201 |
+
"average_processing_time": 0,
|
| 202 |
+
"detected_countries": set(),
|
| 203 |
+
"last_activity": datetime.now()
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
session_info = self.active_sessions[session_id]
|
| 207 |
+
session_info["query_count"] += 1
|
| 208 |
+
session_info["last_activity"] = datetime.now()
|
| 209 |
+
|
| 210 |
+
def _prepare_input_state(self, message: str, session_id: str,
|
| 211 |
+
legal_context: Optional[Dict[str, str]]) -> Dict:
|
| 212 |
+
"""Prepare input state for graph processing"""
|
| 213 |
+
ctx = legal_context or {
|
| 214 |
+
"jurisdiction": "Unknown",
|
| 215 |
+
"user_type": "general",
|
| 216 |
+
"document_type": "legal",
|
| 217 |
+
"detected_country": "unknown"
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
if ctx.get("detected_country") is None:
|
| 221 |
+
ctx["detected_country"] = "unknown"
|
| 222 |
+
|
| 223 |
+
return {
|
| 224 |
+
"messages": [{"role": "user", "content": message, "meta": {}}],
|
| 225 |
+
"legal_context": ctx,
|
| 226 |
+
"session_id": session_id,
|
| 227 |
+
"router_decision": None,
|
| 228 |
+
"search_results": None,
|
| 229 |
+
"route_explanation": None,
|
| 230 |
+
"last_search_query": None,
|
| 231 |
+
"detected_articles": [],
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
def _extract_response(self, result) -> str:
|
| 235 |
+
"""Extract response text from graph result"""
|
| 236 |
+
if isinstance(result, MultiCountryLegalState):
|
| 237 |
+
r = result.model_dump()
|
| 238 |
+
elif isinstance(result, dict):
|
| 239 |
+
r = result
|
| 240 |
+
else:
|
| 241 |
+
r = {}
|
| 242 |
+
|
| 243 |
+
msgs = r.get("messages", [])
|
| 244 |
+
for m in reversed(msgs):
|
| 245 |
+
if (m.get("role") or "").lower() in ("assistant", "ai"):
|
| 246 |
+
return m.get("content", "")
|
| 247 |
+
|
| 248 |
+
return "Désolé, je n'ai pas pu générer de réponse."
|
| 249 |
+
|
| 250 |
+
def _update_session_stats(self, session_id: str, processing_time: float):
|
| 251 |
+
"""Update session statistics with processing time"""
|
| 252 |
+
if session_id in self.active_sessions:
|
| 253 |
+
session_info = self.active_sessions[session_id]
|
| 254 |
+
session_info["total_processing_time"] += processing_time
|
| 255 |
+
session_info["average_processing_time"] = (
|
| 256 |
+
session_info["total_processing_time"] / session_info["query_count"]
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
def _update_routing_stats(self, response: str):
|
| 260 |
+
"""Update routing statistics based on response content"""
|
| 261 |
+
self.routing_stats["total_queries"] += 1
|
| 262 |
+
|
| 263 |
+
response_lower = response.lower()
|
| 264 |
+
if any(keyword in response_lower for keyword in ["bénin", "béninois", "béninoise"]):
|
| 265 |
+
self.routing_stats["benin"] += 1
|
| 266 |
+
elif any(keyword in response_lower for keyword in ["madagascar", "malgache", "malagasy"]):
|
| 267 |
+
self.routing_stats["madagascar"] += 1
|
| 268 |
+
else:
|
| 269 |
+
self.routing_stats["unclear"] += 1
|
| 270 |
+
|
| 271 |
+
def _log_error(self, session_id: str, error: str):
|
| 272 |
+
"""Log error for monitoring"""
|
| 273 |
+
logger.error(f"Session {session_id}: {error}")
|
| 274 |
+
|
| 275 |
+
def cleanup_inactive_sessions(self, max_age_hours: int = 24):
|
| 276 |
+
"""Clean up sessions that have been inactive for too long"""
|
| 277 |
+
cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
|
| 278 |
+
|
| 279 |
+
inactive_sessions = [
|
| 280 |
+
session_id for session_id, info in self.active_sessions.items()
|
| 281 |
+
if info["last_activity"].timestamp() < cutoff_time
|
| 282 |
+
]
|
| 283 |
+
|
| 284 |
+
# Also clean up pending interrupts for inactive sessions
|
| 285 |
+
for session_id in inactive_sessions:
|
| 286 |
+
if session_id in self.pending_interrupts:
|
| 287 |
+
del self.pending_interrupts[session_id]
|
| 288 |
+
del self.active_sessions[session_id]
|
| 289 |
+
logger.info(f"Cleaned up inactive session: {session_id}")
|
core/conversation_repair.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/conversation_repair.py
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict, List, Optional, Any
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
class ConversationRepair:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
self.meta_keywords = [
|
| 11 |
+
"pas compris", "mal compris", "reformuler", "autrement",
|
| 12 |
+
"différemment", "répéter", "redire", "expliquer autrement",
|
| 13 |
+
"plus simple", "plus clair", "clarifier", "précisez",
|
| 14 |
+
"explique mieux", "développe", "approfondis", "que veux-tu dire",
|
| 15 |
+
"c'est-à-dire", "concrètement", "en pratique", "recommence",
|
| 16 |
+
"ce n'est pas ça", "tu n'as pas compris", "erreur", "faux"
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
def detect_repair_intent(self, query: str, conversation_history: List[Dict]) -> bool:
|
| 20 |
+
"""Simple detection - just check if this is a repair request"""
|
| 21 |
+
query_lower = query.lower()
|
| 22 |
+
return any(keyword in query_lower for keyword in self.meta_keywords)
|
| 23 |
+
|
| 24 |
+
async def generate_repair_response(self, query: str, conversation_history: List[Dict], llm) -> str:
|
| 25 |
+
"""Unified LLM-powered repair handling"""
|
| 26 |
+
try:
|
| 27 |
+
# Build conversation context
|
| 28 |
+
context = self._build_conversation_context(conversation_history)
|
| 29 |
+
|
| 30 |
+
repair_prompt = self._build_repair_prompt(query, context)
|
| 31 |
+
|
| 32 |
+
# Use LLM for intelligent repair response
|
| 33 |
+
from langchain_core.messages import HumanMessage
|
| 34 |
+
response = await llm.ainvoke([HumanMessage(content=repair_prompt)])
|
| 35 |
+
|
| 36 |
+
return response.content if hasattr(response, 'content') else str(response)
|
| 37 |
+
|
| 38 |
+
except Exception as e:
|
| 39 |
+
logger.error(f"LLM repair generation failed: {e}")
|
| 40 |
+
return self._generate_fallback_response()
|
| 41 |
+
|
| 42 |
+
def _build_conversation_context(self, conversation_history: List[Dict]) -> str:
|
| 43 |
+
"""Build conversation context for LLM"""
|
| 44 |
+
if not conversation_history:
|
| 45 |
+
return "Aucun contexte de conversation"
|
| 46 |
+
|
| 47 |
+
# Get relevant conversation history
|
| 48 |
+
relevant_messages = conversation_history[-6:] # Last 6 messages
|
| 49 |
+
|
| 50 |
+
context_lines = []
|
| 51 |
+
for msg in relevant_messages:
|
| 52 |
+
role = "Utilisateur" if msg.get("role") == "user" else "Assistant"
|
| 53 |
+
content = msg.get("content", "")
|
| 54 |
+
context_lines.append(f"{role}: {content}")
|
| 55 |
+
|
| 56 |
+
return "\n".join(context_lines)
|
| 57 |
+
|
| 58 |
+
def _build_repair_prompt(self, current_query: str, conversation_context: str) -> str:
|
| 59 |
+
"""Build intelligent repair prompt"""
|
| 60 |
+
return f"""
|
| 61 |
+
Vous êtes un assistant juridique expert. L'utilisateur exprime un problème de compréhension ou demande une clarification.
|
| 62 |
+
|
| 63 |
+
**CONTEXTE DE LA CONVERSATION:**
|
| 64 |
+
{conversation_context}
|
| 65 |
+
|
| 66 |
+
**REQUÊTE ACTUELLE DE L'UTILISATEUR:**
|
| 67 |
+
"{current_query}"
|
| 68 |
+
|
| 69 |
+
**ANALYSE REQUISE:**
|
| 70 |
+
1. Identifiez le type de problème : incompréhension, besoin de clarification, reformulation, correction d'erreur
|
| 71 |
+
2. Analysez quel aspect précis pose problème dans la conversation
|
| 72 |
+
3. Adaptez votre réponse au contexte juridique si pertinent
|
| 73 |
+
|
| 74 |
+
**INSTRUCTIONS POUR LA RÉPONSE:**
|
| 75 |
+
- Accusez réception du problème de compréhension
|
| 76 |
+
- Fournissez une clarification adaptée et utile
|
| 77 |
+
- Si c'est juridique, simplifiez sans perdre la précision légale
|
| 78 |
+
- Utilisez des exemples concrets si pertinent
|
| 79 |
+
- Proposez des pistes pour avancer
|
| 80 |
+
- Gardez un ton professionnel et empathique
|
| 81 |
+
- Maximum 5-7 phrases
|
| 82 |
+
|
| 83 |
+
**RÉPONSE:**
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def _generate_fallback_response(self) -> str:
|
| 87 |
+
"""Fallback if LLM fails"""
|
| 88 |
+
return "Je m'excuse pour ce malentendu. Pouvez-vous reformuler votre demande ou préciser ce qui n'était pas clair ?"
|
core/email_tool.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# acfai_project/core/email_tool.py
|
| 2 |
+
import os
|
| 3 |
+
import smtplib
|
| 4 |
+
import logging
|
| 5 |
+
from email.mime.text import MIMEText # Correction: MIMEText au lieu de MimeText
|
| 6 |
+
from email.mime.multipart import MIMEMultipart # Correction: MIMEMultipart au lieu de MimeMultipart
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
import re
|
| 9 |
+
import datetime # Ajout pour la date
|
| 10 |
+
|
| 11 |
+
from config.settings import settings
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
class LegalAssistanceEmailer:
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self.email_address = os.getenv("EMAIL_ADDRESS")
|
| 18 |
+
self.email_password = os.getenv("EMAIL_APP_PASSWORD")
|
| 19 |
+
self.lawyer_email = os.getenv("LAWYER_EMAIL", "fitahiana@acfai.org")
|
| 20 |
+
self.smtp_server = "smtp.gmail.com"
|
| 21 |
+
self.smtp_port = 587
|
| 22 |
+
|
| 23 |
+
def is_assistance_request(self, query: str) -> bool:
|
| 24 |
+
"""Détecte si l'utilisateur demande une assistance humaine"""
|
| 25 |
+
assistance_keywords = [
|
| 26 |
+
"parler à un avocat", "avocat humain", "assistance humaine",
|
| 27 |
+
"contactez-moi", "rappelez-moi", "assistance téléphonique",
|
| 28 |
+
"besoin d'un avocat", "consultation juridique", "avocat réel",
|
| 29 |
+
"aide humaine", "contact humain", "échange avec un avocat",
|
| 30 |
+
"assisté", "assisté par", "être assisté"
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
query_lower = query.lower()
|
| 34 |
+
return any(keyword in query_lower for keyword in assistance_keywords)
|
| 35 |
+
|
| 36 |
+
def extract_email_from_text(self, text: str) -> Optional[str]:
|
| 37 |
+
"""Extrait un email d'un texte"""
|
| 38 |
+
email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
|
| 39 |
+
matches = re.findall(email_pattern, text)
|
| 40 |
+
return matches[0] if matches else None
|
| 41 |
+
|
| 42 |
+
def validate_email(self, email: str) -> bool:
|
| 43 |
+
"""Valide le format d'un email"""
|
| 44 |
+
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
| 45 |
+
return re.match(pattern, email) is not None
|
| 46 |
+
|
| 47 |
+
def send_assistance_request(self, user_email: str, user_query: str,
|
| 48 |
+
assistance_description: str, country: str) -> Dict[str, any]:
|
| 49 |
+
"""Envoie les emails de confirmation à l'utilisateur et à l'avocat"""
|
| 50 |
+
try:
|
| 51 |
+
# Validation des emails
|
| 52 |
+
if not self.validate_email(user_email):
|
| 53 |
+
return {"success": False, "error": "Format d'email utilisateur invalide"}
|
| 54 |
+
|
| 55 |
+
if not self.validate_email(self.lawyer_email):
|
| 56 |
+
return {"success": False, "error": "Format d'email avocat invalide"}
|
| 57 |
+
|
| 58 |
+
# Connexion SMTP
|
| 59 |
+
server = smtplib.SMTP(self.smtp_server, self.smtp_port)
|
| 60 |
+
server.starttls()
|
| 61 |
+
server.login(self.email_address, self.email_password)
|
| 62 |
+
|
| 63 |
+
# Email à l'utilisateur
|
| 64 |
+
user_email_sent = self._send_user_confirmation(server, user_email, user_query, country)
|
| 65 |
+
|
| 66 |
+
# Email à l'avocat
|
| 67 |
+
lawyer_email_sent = self._send_lawyer_notification(server, user_email, user_query,
|
| 68 |
+
assistance_description, country)
|
| 69 |
+
|
| 70 |
+
server.quit()
|
| 71 |
+
|
| 72 |
+
if user_email_sent and lawyer_email_sent:
|
| 73 |
+
logger.info(f"✅ Emails envoyés avec succès pour {user_email}")
|
| 74 |
+
return {
|
| 75 |
+
"success": True,
|
| 76 |
+
"message": "Demande d'assistance envoyée avec succès",
|
| 77 |
+
"user_email": user_email,
|
| 78 |
+
"lawyer_email": self.lawyer_email
|
| 79 |
+
}
|
| 80 |
+
else:
|
| 81 |
+
return {"success": False, "error": "Échec de l'envoi des emails"}
|
| 82 |
+
|
| 83 |
+
except Exception as e:
|
| 84 |
+
logger.error(f"❌ Erreur d'envoi d'email: {e}")
|
| 85 |
+
return {"success": False, "error": f"Erreur SMTP: {str(e)}"}
|
| 86 |
+
|
| 87 |
+
def _send_user_confirmation(self, server, user_email: str, user_query: str, country: str) -> bool:
|
| 88 |
+
"""Envoie l'email de confirmation à l'utilisateur"""
|
| 89 |
+
try:
|
| 90 |
+
message = MIMEMultipart() # Correction: MIMEMultipart
|
| 91 |
+
message["From"] = self.email_address
|
| 92 |
+
message["To"] = user_email
|
| 93 |
+
message["Subject"] = "📧 Confirmation de votre demande d'assistance juridique"
|
| 94 |
+
|
| 95 |
+
body = f"""
|
| 96 |
+
<html>
|
| 97 |
+
<body>
|
| 98 |
+
<h2 style="color: #2E86AB;">Confirmation de votre demande d'assistance juridique</h2>
|
| 99 |
+
|
| 100 |
+
<p>Bonjour,</p>
|
| 101 |
+
|
| 102 |
+
<p>Nous accusons réception de votre demande d'assistance juridique concernant :</p>
|
| 103 |
+
|
| 104 |
+
<div style="background-color: #f8f9fa; padding: 15px; border-left: 4px solid #2E86AB;">
|
| 105 |
+
<strong>Question initiale :</strong> {user_query}<br>
|
| 106 |
+
<strong>Juridiction concernée :</strong> {country}<br>
|
| 107 |
+
<strong>Votre email :</strong> {user_email}
|
| 108 |
+
</div>
|
| 109 |
+
|
| 110 |
+
<p>��� <strong>Notre équipe juridique a été notifiée</strong> et vous contactera dans les plus brefs délais.</p>
|
| 111 |
+
|
| 112 |
+
<h3>📞 Prochaines étapes :</h3>
|
| 113 |
+
<ul>
|
| 114 |
+
<li>Un avocat spécialisé vous contactera à l'adresse {user_email}</li>
|
| 115 |
+
<li>Préparez les documents relatifs à votre situation</li>
|
| 116 |
+
<li>Durée de réponse estimée : 24-48 heures</li>
|
| 117 |
+
</ul>
|
| 118 |
+
|
| 119 |
+
<p>Pour toute urgence, vous pouvez répondre directement à cet email.</p>
|
| 120 |
+
|
| 121 |
+
<hr>
|
| 122 |
+
<p style="color: #6c757d;">
|
| 123 |
+
<small>
|
| 124 |
+
ACFAI - Assistance Juridique Intelligente<br>
|
| 125 |
+
Email : {self.lawyer_email}<br>
|
| 126 |
+
Ceci est un message automatique, merci de ne pas y répondre directement.
|
| 127 |
+
</small>
|
| 128 |
+
</p>
|
| 129 |
+
</body>
|
| 130 |
+
</html>
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
message.attach(MIMEText(body, "html")) # Correction: MIMEText
|
| 134 |
+
server.send_message(message)
|
| 135 |
+
return True
|
| 136 |
+
|
| 137 |
+
except Exception as e:
|
| 138 |
+
logger.error(f"Erreur envoi email utilisateur: {e}")
|
| 139 |
+
return False
|
| 140 |
+
|
| 141 |
+
def _send_lawyer_notification(self, server, user_email: str, user_query: str,
|
| 142 |
+
assistance_description: str, country: str) -> bool:
|
| 143 |
+
"""Envoie la notification à l'avocat"""
|
| 144 |
+
try:
|
| 145 |
+
message = MIMEMultipart() # Correction: MIMEMultipart
|
| 146 |
+
message["From"] = self.email_address
|
| 147 |
+
message["To"] = self.lawyer_email
|
| 148 |
+
message["Subject"] = f"🔔 Nouvelle demande d'assistance juridique - {country}"
|
| 149 |
+
|
| 150 |
+
body = f"""
|
| 151 |
+
<html>
|
| 152 |
+
<body>
|
| 153 |
+
<h2 style="color: #A23B72;">Nouvelle demande d'assistance juridique</h2>
|
| 154 |
+
|
| 155 |
+
<div style="background-color: #fff3cd; padding: 15px; border-left: 4px solid #ffc107;">
|
| 156 |
+
<h3>📋 Informations de la demande :</h3>
|
| 157 |
+
<p><strong>Utilisateur :</strong> {user_email}</p>
|
| 158 |
+
<p><strong>Pays/Juridiction :</strong> {country}</p>
|
| 159 |
+
<p><strong>Question initiale :</strong> {user_query}</p>
|
| 160 |
+
<p><strong>Description de l'assistance demandée :</strong><br>{assistance_description}</p>
|
| 161 |
+
</div>
|
| 162 |
+
|
| 163 |
+
<h3>🚀 Action requise :</h3>
|
| 164 |
+
<ul>
|
| 165 |
+
<li>Contacter l'utilisateur à : {user_email}</li>
|
| 166 |
+
<li>Spécialité requise : Droit {country}</li>
|
| 167 |
+
<li>Priorité : Normale</li>
|
| 168 |
+
</ul>
|
| 169 |
+
|
| 170 |
+
<hr>
|
| 171 |
+
<p style="color: #6c757d;">
|
| 172 |
+
<small>
|
| 173 |
+
Système Automatique ACFAI - {settings.CHAT_MODEL}<br>
|
| 174 |
+
Généré le : {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}
|
| 175 |
+
</small>
|
| 176 |
+
</p>
|
| 177 |
+
</body>
|
| 178 |
+
</html>
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
message.attach(MIMEText(body, "html")) # Correction: MIMEText
|
| 182 |
+
server.send_message(message)
|
| 183 |
+
return True
|
| 184 |
+
|
| 185 |
+
except Exception as e:
|
| 186 |
+
logger.error(f"Erreur envoi email avocat: {e}")
|
| 187 |
+
return False
|
core/graph_builder.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/graph_builder.py
|
| 2 |
+
from langgraph.graph import StateGraph, START, END
|
| 3 |
+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, List, Any
|
| 6 |
+
from langchain_core.runnables import RunnableConfig
|
| 7 |
+
|
| 8 |
+
from models.state_models import MultiCountryLegalState
|
| 9 |
+
from core.router import CountryRouter
|
| 10 |
+
from core.retriever import LegalRetriever
|
| 11 |
+
from core.conversation_repair import ConversationRepair
|
| 12 |
+
from core.human_approval_node import HumanApprovalNode
|
| 13 |
+
|
| 14 |
+
# Import modular components
|
| 15 |
+
from core.nodes.routing_nodes import RoutingNodes
|
| 16 |
+
from core.assistance.workflow_nodes import AssistanceWorkflowNodes
|
| 17 |
+
from core.nodes.retrieval_nodes import RetrievalNodes
|
| 18 |
+
from core.nodes.response_nodes import ResponseNodes
|
| 19 |
+
from core.nodes.helper_nodes import HelperNodes
|
| 20 |
+
from core.routing.routing_logic import RoutingLogic
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
class GraphBuilder:
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
router: CountryRouter,
|
| 28 |
+
llm,
|
| 29 |
+
checkpointer: AsyncPostgresSaver,
|
| 30 |
+
# Country retrievers as a dictionary for easy extension
|
| 31 |
+
country_retrievers: Dict[str, LegalRetriever] = None
|
| 32 |
+
):
|
| 33 |
+
self.router = router
|
| 34 |
+
self.llm = llm
|
| 35 |
+
self.checkpointer = checkpointer
|
| 36 |
+
|
| 37 |
+
# Initialize country retrievers - easily extensible!
|
| 38 |
+
self.country_retrievers = country_retrievers or {}
|
| 39 |
+
|
| 40 |
+
# Initialize modular components
|
| 41 |
+
self.conversation_repair = ConversationRepair()
|
| 42 |
+
self.human_approval = HumanApprovalNode()
|
| 43 |
+
self.routing_logic = RoutingLogic()
|
| 44 |
+
|
| 45 |
+
# Initialize node groups
|
| 46 |
+
self.routing_nodes = RoutingNodes(router, self.conversation_repair, llm)
|
| 47 |
+
self.assistance_nodes = AssistanceWorkflowNodes()
|
| 48 |
+
|
| 49 |
+
# Dynamic retrieval nodes based on available countries
|
| 50 |
+
self.retrieval_nodes = RetrievalNodes(self.country_retrievers)
|
| 51 |
+
|
| 52 |
+
self.response_nodes = ResponseNodes(llm)
|
| 53 |
+
self.helper_nodes = HelperNodes(llm)
|
| 54 |
+
|
| 55 |
+
logger.info(f"GraphBuilder initialized with countries: {list(self.country_retrievers.keys())}")
|
| 56 |
+
|
| 57 |
+
def add_country(self, country_code: str, retriever: LegalRetriever):
|
| 58 |
+
"""Dynamically add a new country to the system"""
|
| 59 |
+
self.country_retrievers[country_code] = retriever
|
| 60 |
+
self.retrieval_nodes = RetrievalNodes(self.country_retrievers) # Re-initialize
|
| 61 |
+
logger.info(f"Added country: {country_code}")
|
| 62 |
+
|
| 63 |
+
def build_graph(self) -> StateGraph:
|
| 64 |
+
"""Build simplified flow with all routing categories"""
|
| 65 |
+
workflow = StateGraph(MultiCountryLegalState)
|
| 66 |
+
|
| 67 |
+
# Core nodes
|
| 68 |
+
workflow.add_node("router", self.routing_nodes.router_node)
|
| 69 |
+
workflow.add_node("response", self.response_nodes.response_generation_node)
|
| 70 |
+
|
| 71 |
+
# Country retrieval nodes - dynamically created
|
| 72 |
+
country_nodes = {}
|
| 73 |
+
for country_code in self.country_retrievers.keys():
|
| 74 |
+
node_name = f"{country_code}_retrieval"
|
| 75 |
+
workflow.add_node(node_name, self._create_country_retrieval_node(country_code))
|
| 76 |
+
country_nodes[country_code] = node_name
|
| 77 |
+
|
| 78 |
+
# Handler nodes
|
| 79 |
+
workflow.add_node("greeting_handler", self.routing_nodes.greeting_small_talk_node)
|
| 80 |
+
workflow.add_node("repair_handler", self.routing_nodes.conversation_repair_node)
|
| 81 |
+
workflow.add_node("summary_handler", self.helper_nodes.conversation_summarization_node)
|
| 82 |
+
workflow.add_node("unclear_handler", self.helper_nodes.unclear_route_node)
|
| 83 |
+
workflow.add_node("out_of_scope_handler", self.helper_nodes.out_of_scope_node)
|
| 84 |
+
|
| 85 |
+
# Assistance nodes - Using wrapper methods to ensure correct signatures
|
| 86 |
+
workflow.add_node("assistance_collect_info", self._create_assistance_collect_wrapper())
|
| 87 |
+
workflow.add_node("assistance_confirm", self._create_assistance_confirm_wrapper())
|
| 88 |
+
workflow.add_node("human_approval", self.human_approval.process_approval)
|
| 89 |
+
workflow.add_node("process_assistance", self._create_process_assistance_node)
|
| 90 |
+
|
| 91 |
+
# Main flow
|
| 92 |
+
workflow.add_edge(START, "router")
|
| 93 |
+
|
| 94 |
+
# Router directly routes to appropriate nodes
|
| 95 |
+
workflow.add_conditional_edges(
|
| 96 |
+
"router",
|
| 97 |
+
self._route_after_router,
|
| 98 |
+
{
|
| 99 |
+
**country_nodes, # benin_retrieval, madagascar_retrieval, etc.
|
| 100 |
+
"greeting_small_talk": "greeting_handler",
|
| 101 |
+
"conversation_repair": "repair_handler",
|
| 102 |
+
"conversation_summarization": "summary_handler",
|
| 103 |
+
"unclear": "unclear_handler",
|
| 104 |
+
"out_of_scope": "out_of_scope_handler",
|
| 105 |
+
"assistance_request": "assistance_collect_info"
|
| 106 |
+
}
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# All handlers go to response
|
| 110 |
+
workflow.add_edge("greeting_handler", "response")
|
| 111 |
+
workflow.add_edge("repair_handler", "response")
|
| 112 |
+
workflow.add_edge("summary_handler", "response")
|
| 113 |
+
workflow.add_edge("unclear_handler", "response")
|
| 114 |
+
workflow.add_edge("out_of_scope_handler", "response")
|
| 115 |
+
|
| 116 |
+
# Country nodes go to response
|
| 117 |
+
for country_code in self.country_retrievers.keys():
|
| 118 |
+
workflow.add_edge(f"{country_code}_retrieval", "response")
|
| 119 |
+
|
| 120 |
+
# Assistance sub-flow
|
| 121 |
+
workflow.add_conditional_edges(
|
| 122 |
+
"assistance_collect_info",
|
| 123 |
+
self.routing_logic.route_after_info_collection,
|
| 124 |
+
{
|
| 125 |
+
"need_email": "response", # Ask for email
|
| 126 |
+
"need_description": "response", # Ask for description
|
| 127 |
+
"ready_to_confirm": "assistance_confirm",
|
| 128 |
+
"cancelled": "response"
|
| 129 |
+
}
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# CRITICAL FIX: After response, only continue assistance if we have new user input
|
| 133 |
+
workflow.add_conditional_edges(
|
| 134 |
+
"response",
|
| 135 |
+
self._route_after_response,
|
| 136 |
+
{
|
| 137 |
+
"continue_assistance": "assistance_collect_info",
|
| 138 |
+
"end": END
|
| 139 |
+
}
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
workflow.add_conditional_edges(
|
| 143 |
+
"assistance_confirm",
|
| 144 |
+
self.routing_logic.route_after_confirmation,
|
| 145 |
+
{
|
| 146 |
+
"confirmed": "human_approval",
|
| 147 |
+
"cancelled": "response",
|
| 148 |
+
"needs_correction": "response"
|
| 149 |
+
}
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
workflow.add_conditional_edges(
|
| 153 |
+
"human_approval",
|
| 154 |
+
self.routing_logic.route_after_human_approval,
|
| 155 |
+
{
|
| 156 |
+
"approved": "process_assistance",
|
| 157 |
+
"rejected": "response",
|
| 158 |
+
"interrupt": "response"
|
| 159 |
+
}
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
workflow.add_edge("process_assistance", "response")
|
| 163 |
+
|
| 164 |
+
logger.info(f"Scalable graph built for {len(self.country_retrievers)} countries: {list(self.country_retrievers.keys())}")
|
| 165 |
+
return workflow
|
| 166 |
+
|
| 167 |
+
def _create_assistance_collect_wrapper(self):
|
| 168 |
+
"""Wrapper to ensure proper method signature for assistance collection"""
|
| 169 |
+
async def wrapper(state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 170 |
+
result = await self.assistance_nodes.collect_assistance_info_node(state, config)
|
| 171 |
+
# Ensure supplemental_message is included if not present
|
| 172 |
+
if "supplemental_message" not in result:
|
| 173 |
+
result["supplemental_message"] = ""
|
| 174 |
+
return result
|
| 175 |
+
return wrapper
|
| 176 |
+
|
| 177 |
+
def _create_assistance_confirm_wrapper(self):
|
| 178 |
+
"""Wrapper to ensure proper method signature for assistance confirmation"""
|
| 179 |
+
async def wrapper(state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 180 |
+
result = await self.assistance_nodes.confirm_assistance_send_node(state, config)
|
| 181 |
+
# Ensure supplemental_message is included if not present
|
| 182 |
+
if "supplemental_message" not in result:
|
| 183 |
+
result["supplemental_message"] = ""
|
| 184 |
+
return result
|
| 185 |
+
return wrapper
|
| 186 |
+
|
| 187 |
+
def _route_after_router(self, state: MultiCountryLegalState) -> str:
|
| 188 |
+
"""Route directly from router - single source of truth"""
|
| 189 |
+
router_decision = state.router_decision or "unclear"
|
| 190 |
+
logger.debug(f"Routing from router: {router_decision}")
|
| 191 |
+
return router_decision
|
| 192 |
+
|
| 193 |
+
def _route_after_response(self, state: MultiCountryLegalState) -> str:
|
| 194 |
+
"""Route after response - check if we should continue assistance workflow"""
|
| 195 |
+
# Check if we're in the middle of an assistance workflow
|
| 196 |
+
assistance_step = state.assistance_step
|
| 197 |
+
if assistance_step and assistance_step not in [None, "cancelled", "completed"]:
|
| 198 |
+
# CRITICAL FIX: Only continue if we have new user input to process
|
| 199 |
+
# This prevents infinite loops when no new user input is available
|
| 200 |
+
has_new_user_input = self._has_new_user_input(state)
|
| 201 |
+
|
| 202 |
+
if has_new_user_input:
|
| 203 |
+
logger.info(f"🔄 Continuing assistance workflow from response: {assistance_step}")
|
| 204 |
+
return "continue_assistance"
|
| 205 |
+
else:
|
| 206 |
+
logger.info("⏸️ No new user input - waiting for user response")
|
| 207 |
+
return "end"
|
| 208 |
+
|
| 209 |
+
# Normal end of conversation
|
| 210 |
+
logger.debug("✅ Ending conversation - no assistance workflow active")
|
| 211 |
+
return "end"
|
| 212 |
+
|
| 213 |
+
def _has_new_user_input(self, state: MultiCountryLegalState) -> bool:
|
| 214 |
+
"""Check if there's new user input to process in assistance workflow"""
|
| 215 |
+
if not state.messages:
|
| 216 |
+
return False
|
| 217 |
+
|
| 218 |
+
# Get the last message
|
| 219 |
+
last_message = state.messages[-1] if state.messages else None
|
| 220 |
+
|
| 221 |
+
# Check if the last message is from user and not already processed
|
| 222 |
+
if last_message and last_message.get("role") == "user":
|
| 223 |
+
# Check message metadata to see if it's been processed in current assistance step
|
| 224 |
+
message_meta = last_message.get("meta", {})
|
| 225 |
+
processed_in_step = message_meta.get("processed_in_assistance_step")
|
| 226 |
+
current_step = state.assistance_step
|
| 227 |
+
|
| 228 |
+
# If this message hasn't been processed in the current assistance step, it's new input
|
| 229 |
+
if processed_in_step != current_step:
|
| 230 |
+
logger.info(f"📥 New user input detected for assistance step: {current_step}")
|
| 231 |
+
return True
|
| 232 |
+
|
| 233 |
+
logger.info("📭 No new user input detected")
|
| 234 |
+
return False
|
| 235 |
+
|
| 236 |
+
def _create_country_retrieval_node(self, country_code: str):
|
| 237 |
+
"""Create a dynamic country retrieval node (closure factory)"""
|
| 238 |
+
async def country_retrieval_node(state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 239 |
+
logger.info(f"Country retrieval for: {country_code}")
|
| 240 |
+
return await self.retrieval_nodes.country_retrieval_node(state, config, country_code)
|
| 241 |
+
return country_retrieval_node
|
| 242 |
+
|
| 243 |
+
async def _create_process_assistance_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 244 |
+
"""Process assistance after approval"""
|
| 245 |
+
logger.info("Processing assistance request")
|
| 246 |
+
|
| 247 |
+
# Mark assistance as completed with supplemental message
|
| 248 |
+
return {
|
| 249 |
+
"email_status": "sent",
|
| 250 |
+
"approval_status": "approved",
|
| 251 |
+
"assistance_step": "completed",
|
| 252 |
+
"messages": [],
|
| 253 |
+
# "supplemental_message": "Votre demande d'assistance a été traitée avec succès et envoyée à notre équipe juridique."
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
def debug_state(self, state: MultiCountryLegalState, step: str) -> None:
|
| 257 |
+
"""Debug state information"""
|
| 258 |
+
if logger.isEnabledFor(logging.DEBUG):
|
| 259 |
+
logger.debug(f"=== STATE DEBUG at {step} ===")
|
| 260 |
+
logger.debug(f"Router decision: {getattr(state, 'router_decision', 'None')}")
|
| 261 |
+
logger.debug(f"Assistance step: {getattr(state, 'assistance_step', 'None')}")
|
| 262 |
+
logger.debug(f"User email: {getattr(state, 'user_email', 'None')}")
|
| 263 |
+
logger.debug(f"Assistance description: {getattr(state, 'assistance_description', 'None')}")
|
| 264 |
+
logger.debug(f"Supplemental message: {getattr(state, 'supplemental_message', 'None')}")
|
| 265 |
+
logger.debug(f"Available countries: {list(self.country_retrievers.keys())}")
|
| 266 |
+
logger.debug("=== END STATE DEBUG ===")
|
core/human_approval_node.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# core/human_approval_node.py
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Literal
|
| 4 |
+
from langchain_core.runnables import RunnableConfig
|
| 5 |
+
from langgraph.types import interrupt, Command
|
| 6 |
+
from models.state_models import MultiCountryLegalState
|
| 7 |
+
from core.assistance.email_service import AssistanceEmailService
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class HumanApprovalNode:
|
| 13 |
+
def __init__(self):
|
| 14 |
+
self.email_service = AssistanceEmailService()
|
| 15 |
+
|
| 16 |
+
async def process_approval(
|
| 17 |
+
self,
|
| 18 |
+
state: MultiCountryLegalState,
|
| 19 |
+
config: RunnableConfig
|
| 20 |
+
) -> Command[Literal["response"]]:
|
| 21 |
+
"""Process human approval with interrupt"""
|
| 22 |
+
try:
|
| 23 |
+
# Validate required fields
|
| 24 |
+
if not state.user_email or not state.assistance_description:
|
| 25 |
+
logger.warning("Missing required fields for approval")
|
| 26 |
+
return Command(
|
| 27 |
+
goto="response",
|
| 28 |
+
update={
|
| 29 |
+
"messages": [{
|
| 30 |
+
"role": "assistant",
|
| 31 |
+
"content": "❌ Données incomplètes pour l'approbation.",
|
| 32 |
+
"meta": {}
|
| 33 |
+
}]
|
| 34 |
+
}
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
logger.info(f"🔒 Human approval node triggered for {state.user_email}")
|
| 38 |
+
|
| 39 |
+
# Prepare interrupt message
|
| 40 |
+
interrupt_message = self._format_approval_request(state)
|
| 41 |
+
|
| 42 |
+
# Trigger interrupt and wait for human input
|
| 43 |
+
moderator_input = interrupt({
|
| 44 |
+
"type": "human_approval",
|
| 45 |
+
"user_email": state.user_email,
|
| 46 |
+
"country": self._get_country_display(state),
|
| 47 |
+
"description": state.assistance_description,
|
| 48 |
+
"message": interrupt_message
|
| 49 |
+
})
|
| 50 |
+
|
| 51 |
+
logger.info(f"📥 Received moderator input: {moderator_input}")
|
| 52 |
+
|
| 53 |
+
# Parse moderator decision
|
| 54 |
+
decision = self._parse_decision(moderator_input)
|
| 55 |
+
|
| 56 |
+
# Handle approval
|
| 57 |
+
if decision["approved"]:
|
| 58 |
+
return await self._handle_approval(state, decision)
|
| 59 |
+
else:
|
| 60 |
+
return await self._handle_rejection(state, decision)
|
| 61 |
+
|
| 62 |
+
except Exception as e:
|
| 63 |
+
logger.error(f"Error in approval node: {str(e)}", exc_info=True)
|
| 64 |
+
return Command(
|
| 65 |
+
goto="response",
|
| 66 |
+
update={
|
| 67 |
+
"approval_status": "error",
|
| 68 |
+
"messages": [{
|
| 69 |
+
"role": "assistant",
|
| 70 |
+
"content": f"❌ Erreur lors de l'approbation: {str(e)}",
|
| 71 |
+
"meta": {}
|
| 72 |
+
}]
|
| 73 |
+
}
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
async def _handle_approval(
|
| 77 |
+
self,
|
| 78 |
+
state: MultiCountryLegalState,
|
| 79 |
+
decision: dict
|
| 80 |
+
) -> Command[Literal["response"]]:
|
| 81 |
+
"""Handle approved request (sends email and routes to response)"""
|
| 82 |
+
logger.info(f"✅ Request APPROVED for {state.user_email}")
|
| 83 |
+
|
| 84 |
+
# Send email
|
| 85 |
+
email_result = self.email_service.send_assistance_request(
|
| 86 |
+
user_email=state.user_email,
|
| 87 |
+
user_query=state.last_search_query or "Demande d'assistance",
|
| 88 |
+
assistance_description=state.assistance_description,
|
| 89 |
+
country=self._get_country_display(state)
|
| 90 |
+
)
|
| 91 |
+
logger.info(f"✅ Emails envoyés avec succès pour {state.user_email}")
|
| 92 |
+
|
| 93 |
+
# Build success message
|
| 94 |
+
if email_result.get("success"):
|
| 95 |
+
message_content = f"""✅ **DEMANDE APPROUVÉE ET ENVOYÉE**
|
| 96 |
+
|
| 97 |
+
📧 Un email de confirmation a été envoyé à: {state.user_email}
|
| 98 |
+
👨⚖️ Notre équipe juridique vous contactera sous 24-48 heures.
|
| 99 |
+
|
| 100 |
+
**Raison de l'approbation:** {decision['reason']}
|
| 101 |
+
**Approuvé par:** {decision['moderator_id']}
|
| 102 |
+
"""
|
| 103 |
+
else:
|
| 104 |
+
message_content = f"""⚠️ **DEMANDE APPROUVÉE MAIS ERREUR D'ENVOI**
|
| 105 |
+
|
| 106 |
+
La demande a été approuvée mais l'envoi d'email a échoué.
|
| 107 |
+
**Erreur:** {email_result.get('error', 'Unknown')}
|
| 108 |
+
|
| 109 |
+
Veuillez contacter directement: fitahiana@acfai.org
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
return Command(
|
| 113 |
+
goto="response",
|
| 114 |
+
update={
|
| 115 |
+
"approval_status": "approved",
|
| 116 |
+
"approval_reason": decision["reason"],
|
| 117 |
+
"approved_by": decision["moderator_id"],
|
| 118 |
+
"approval_timestamp": datetime.now().isoformat(),
|
| 119 |
+
"email_status": "sent" if email_result.get("success") else "error",
|
| 120 |
+
"messages": [{
|
| 121 |
+
"role": "assistant",
|
| 122 |
+
"content": message_content,
|
| 123 |
+
"meta": {"approval": "approved"}
|
| 124 |
+
}]
|
| 125 |
+
}
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
async def _handle_rejection(
|
| 129 |
+
self,
|
| 130 |
+
state: MultiCountryLegalState,
|
| 131 |
+
decision: dict
|
| 132 |
+
) -> Command[Literal["response"]]: # Updated: Removed "process_assistance"
|
| 133 |
+
"""Handle rejected request"""
|
| 134 |
+
logger.info(f"❌ Request REJECTED for {state.user_email}")
|
| 135 |
+
|
| 136 |
+
message_content = f"""❌ **DEMANDE REFUSÉE**
|
| 137 |
+
|
| 138 |
+
Votre demande d'assistance n'a pas été approuvée.
|
| 139 |
+
|
| 140 |
+
**Raison:** {decision['reason']}
|
| 141 |
+
|
| 142 |
+
Si vous pensez qu'il s'agit d'une erreur, veuillez reformuler votre demande avec plus de détails.
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
return Command(
|
| 146 |
+
goto="response",
|
| 147 |
+
update={
|
| 148 |
+
"approval_status": "rejected",
|
| 149 |
+
"approval_reason": decision["reason"],
|
| 150 |
+
"approved_by": decision["moderator_id"],
|
| 151 |
+
"approval_timestamp": datetime.now().isoformat(),
|
| 152 |
+
"messages": [{
|
| 153 |
+
"role": "assistant",
|
| 154 |
+
"content": message_content,
|
| 155 |
+
"meta": {"approval": "rejected"}
|
| 156 |
+
}]
|
| 157 |
+
}
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
def _format_approval_request(self, state: MultiCountryLegalState) -> str:
|
| 161 |
+
"""Format the approval request message"""
|
| 162 |
+
return f"""
|
| 163 |
+
🔒 **APPROBATION HUMAINE REQUISE**
|
| 164 |
+
|
| 165 |
+
📧 **Email:** {state.user_email}
|
| 166 |
+
🌍 **Pays:** {self._get_country_display(state)}
|
| 167 |
+
📝 **Description:** {state.assistance_description}
|
| 168 |
+
🔍 **Requête initiale:** {state.last_search_query or 'Non spécifiée'}
|
| 169 |
+
|
| 170 |
+
**Instructions:**
|
| 171 |
+
- Tapez "approve [raison]" pour approuver
|
| 172 |
+
- Tapez "reject [raison]" pour rejeter
|
| 173 |
+
|
| 174 |
+
**Exemples:**
|
| 175 |
+
- "approve Demande légitime"
|
| 176 |
+
- "reject Email invalide"
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
def _parse_decision(self, user_input: str) -> dict:
|
| 180 |
+
"""Parse moderator decision from input"""
|
| 181 |
+
if not user_input or not isinstance(user_input, str):
|
| 182 |
+
return {
|
| 183 |
+
"approved": False,
|
| 184 |
+
"reason": "Input invalide",
|
| 185 |
+
"moderator_id": "system"
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
input_lower = user_input.lower().strip()
|
| 189 |
+
|
| 190 |
+
# Check for approval keywords
|
| 191 |
+
approve_keywords = ["approve", "approuver", "oui", "yes", "ok", "accept"]
|
| 192 |
+
is_approved = any(kw in input_lower for kw in approve_keywords)
|
| 193 |
+
|
| 194 |
+
# Extract reason (text after the decision keyword)
|
| 195 |
+
reason = user_input.strip()
|
| 196 |
+
for keyword in approve_keywords + ["reject", "rejeter", "non", "no"]:
|
| 197 |
+
if keyword in input_lower:
|
| 198 |
+
parts = user_input.split(keyword, 1)
|
| 199 |
+
if len(parts) > 1 and parts[1].strip():
|
| 200 |
+
reason = parts[1].strip()
|
| 201 |
+
break
|
| 202 |
+
|
| 203 |
+
if not reason or reason == user_input:
|
| 204 |
+
reason = "Approuvé par modérateur" if is_approved else "Refusé par modérateur"
|
| 205 |
+
|
| 206 |
+
return {
|
| 207 |
+
"approved": is_approved,
|
| 208 |
+
"reason": reason,
|
| 209 |
+
"moderator_id": "human_moderator"
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
def _get_country_display(self, state: MultiCountryLegalState) -> str:
|
| 213 |
+
"""Get country display name"""
|
| 214 |
+
country = state.country or state.legal_context.get("detected_country", "unknown")
|
| 215 |
+
if country == "unknown" and state.assistance_description:
|
| 216 |
+
country = MultiCountryLegalState.detect_country(state.assistance_description)
|
| 217 |
+
country_map = {
|
| 218 |
+
"benin": "Bénin",
|
| 219 |
+
"madagascar": "Madagascar"
|
| 220 |
+
}
|
| 221 |
+
logger.debug(f"Country from state: {state.country}, legal_context: {state.legal_context.get('detected_country')}, description: {country}")
|
| 222 |
+
return country_map.get(country, "Non spécifié")
|
core/nodes/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/nodes/__init__.py
|
| 2 |
+
from .routing_nodes import RoutingNodes
|
| 3 |
+
from .retrieval_nodes import RetrievalNodes
|
| 4 |
+
from .response_nodes import ResponseNodes
|
| 5 |
+
from .helper_nodes import HelperNodes
|
| 6 |
+
|
| 7 |
+
# Remove AssistanceNodes from exports since it's moved to core/assistance/
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"RoutingNodes",
|
| 11 |
+
"RetrievalNodes",
|
| 12 |
+
"ResponseNodes",
|
| 13 |
+
"HelperNodes"
|
| 14 |
+
]
|
core/nodes/base_node.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/nodes/base_node.py
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict, List, Optional, Any
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from langchain_core.runnables import RunnableConfig
|
| 6 |
+
|
| 7 |
+
from models.state_models import MultiCountryLegalState
|
| 8 |
+
from utils.helpers import dict_to_message_obj
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class BaseNode:
|
| 13 |
+
"""Base class with common utilities for all nodes"""
|
| 14 |
+
|
| 15 |
+
def _get_last_human_message(self, messages: List[Dict]) -> Optional[Dict]:
|
| 16 |
+
"""Get the last human message from conversation"""
|
| 17 |
+
if not messages:
|
| 18 |
+
return None
|
| 19 |
+
for msg in reversed(messages):
|
| 20 |
+
if msg.get("role", "").lower() in ("user", "human"):
|
| 21 |
+
return msg
|
| 22 |
+
return None
|
| 23 |
+
|
| 24 |
+
def _has_complete_response(self, messages: List[Dict]) -> bool:
|
| 25 |
+
"""Check if there's already an assistant response in recent messages"""
|
| 26 |
+
if not messages:
|
| 27 |
+
return False
|
| 28 |
+
for msg in reversed(messages):
|
| 29 |
+
if msg.get("role") == "assistant" and msg.get("content"):
|
| 30 |
+
return True
|
| 31 |
+
return False
|
| 32 |
+
|
| 33 |
+
def _create_error_message(self, error: str) -> Dict[str, Any]:
|
| 34 |
+
"""Create standardized error message"""
|
| 35 |
+
return {
|
| 36 |
+
"role": "assistant",
|
| 37 |
+
"content": f"Désolé, une erreur s'est produite lors du traitement de votre demande: {error}",
|
| 38 |
+
"meta": {
|
| 39 |
+
"is_error": True,
|
| 40 |
+
"timestamp": self._get_timestamp()
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
def _create_error_state(self, error: str) -> Dict[str, Any]:
|
| 45 |
+
"""Create error state with message"""
|
| 46 |
+
return {
|
| 47 |
+
"messages": [self._create_error_message(error)],
|
| 48 |
+
"search_results": f"Error: {error}"
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
def _get_timestamp(self) -> str:
|
| 52 |
+
"""Get current timestamp for message metadata"""
|
| 53 |
+
return datetime.now().isoformat()
|
| 54 |
+
|
| 55 |
+
def _update_legal_context(self, legal_context: Dict, country: str) -> Dict:
|
| 56 |
+
"""Update legal context with country information"""
|
| 57 |
+
updated = legal_context.copy() if legal_context else {}
|
| 58 |
+
|
| 59 |
+
if country in ["benin", "madagascar"]:
|
| 60 |
+
updated["detected_country"] = country
|
| 61 |
+
if country == "benin":
|
| 62 |
+
updated["jurisdiction"] = "Bénin"
|
| 63 |
+
elif country == "madagascar":
|
| 64 |
+
updated["jurisdiction"] = "Madagascar"
|
| 65 |
+
else:
|
| 66 |
+
updated["jurisdiction"] = "Unknown"
|
| 67 |
+
updated["detected_country"] = "unknown"
|
| 68 |
+
|
| 69 |
+
return updated
|
| 70 |
+
|
| 71 |
+
def _create_router_response(self, country: str, explanation: str, legal_context: Dict) -> Dict[str, Any]:
|
| 72 |
+
"""Create standardized router response"""
|
| 73 |
+
updated_context = self._update_legal_context(legal_context, country)
|
| 74 |
+
return {
|
| 75 |
+
"router_decision": country,
|
| 76 |
+
"route_explanation": explanation,
|
| 77 |
+
"legal_context": updated_context,
|
| 78 |
+
"primary_intent": country
|
| 79 |
+
}
|
core/nodes/helper_nodes.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/nodes/helper_nodes.py
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict, Any, List
|
| 4 |
+
from langchain_core.runnables import RunnableConfig
|
| 5 |
+
from langchain_core.messages import HumanMessage
|
| 6 |
+
|
| 7 |
+
from models.state_models import MultiCountryLegalState
|
| 8 |
+
from .base_node import BaseNode
|
| 9 |
+
from core.prompts.prompt_templates import PromptTemplates
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class HelperNodes(BaseNode):
|
| 14 |
+
"""Helper nodes for unclear routes and summarization"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, llm):
|
| 17 |
+
self.llm = llm
|
| 18 |
+
self.prompts = PromptTemplates()
|
| 19 |
+
|
| 20 |
+
async def out_of_scope_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 21 |
+
"""Handle out-of-scope questions - redirect to legal domain"""
|
| 22 |
+
try:
|
| 23 |
+
logger.info("🚫 Out of scope question detected")
|
| 24 |
+
|
| 25 |
+
redirect_message = {
|
| 26 |
+
"role": "assistant",
|
| 27 |
+
"content": (
|
| 28 |
+
"Je suis un assistant juridique spécialisé dans le droit du Bénin et de Madagascar. "
|
| 29 |
+
"Je ne peux répondre qu'aux questions relatives au droit et aux procédures juridiques.\n\n"
|
| 30 |
+
"Comment puis-je vous aider avec vos questions juridiques ?"
|
| 31 |
+
),
|
| 32 |
+
"meta": {
|
| 33 |
+
"is_out_of_scope": True,
|
| 34 |
+
"timestamp": self._get_timestamp()
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
return {
|
| 39 |
+
"messages": [redirect_message],
|
| 40 |
+
"current_country": "out_of_scope",
|
| 41 |
+
"search_results": "Out of scope query - no legal search performed"
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
except Exception as e:
|
| 45 |
+
logger.error(f"Error in out_of_scope handler: {str(e)}")
|
| 46 |
+
return self._create_error_state(f"Error in out_of_scope: {str(e)}")
|
| 47 |
+
|
| 48 |
+
async def unclear_route_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 49 |
+
"""Handle unclear routing cases - for ambiguous legal queries"""
|
| 50 |
+
try:
|
| 51 |
+
s = state.model_dump()
|
| 52 |
+
route_explanation = s.get("route_explanation", "")
|
| 53 |
+
|
| 54 |
+
# This is now only for unclear LEGAL queries
|
| 55 |
+
clarification_msg = {
|
| 56 |
+
"role": "assistant",
|
| 57 |
+
"content": self.prompts.get_clarification_message(),
|
| 58 |
+
"meta": {
|
| 59 |
+
"requires_clarification": True,
|
| 60 |
+
"timestamp": self._get_timestamp()
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
return {
|
| 65 |
+
"messages": [clarification_msg],
|
| 66 |
+
"search_results": "Country clarification needed"
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
except Exception as e:
|
| 70 |
+
logger.error(f"Error in unclear route handling: {str(e)}")
|
| 71 |
+
return self._create_error_state(f"Error in unclear route: {str(e)}")
|
| 72 |
+
|
| 73 |
+
async def conversation_summarization_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 74 |
+
"""Generate summary of conversation history"""
|
| 75 |
+
try:
|
| 76 |
+
s = state.model_dump()
|
| 77 |
+
messages = s.get("messages", [])
|
| 78 |
+
|
| 79 |
+
logger.info(f"📋 Generating conversation summary for {len(messages)} messages")
|
| 80 |
+
|
| 81 |
+
summary = await self._generate_conversation_summary(messages)
|
| 82 |
+
|
| 83 |
+
return {
|
| 84 |
+
"messages": [{
|
| 85 |
+
"role": "assistant",
|
| 86 |
+
"content": summary,
|
| 87 |
+
"meta": {
|
| 88 |
+
"is_summary": True,
|
| 89 |
+
"conversation_length": len(messages),
|
| 90 |
+
"timestamp": self._get_timestamp()
|
| 91 |
+
}
|
| 92 |
+
}],
|
| 93 |
+
"search_results": "Conversation summary generated - no legal search performed"
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
except Exception as e:
|
| 97 |
+
logger.error(f"Error in conversation summarization: {str(e)}")
|
| 98 |
+
return self._create_error_state(f"Error in summarization: {str(e)}")
|
| 99 |
+
|
| 100 |
+
async def _generate_conversation_summary(self, messages: List[Dict]) -> str:
|
| 101 |
+
"""Use LLM to generate conversation summary"""
|
| 102 |
+
conversation_messages = [
|
| 103 |
+
msg for msg in messages
|
| 104 |
+
if msg.get("role") in ["user", "assistant"]
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
if len(conversation_messages) <= 2:
|
| 108 |
+
return "Notre conversation vient juste de commencer. Nous n'avons pas encore beaucoup échangé."
|
| 109 |
+
|
| 110 |
+
conversation_text = ""
|
| 111 |
+
for i, msg in enumerate(conversation_messages):
|
| 112 |
+
role = "Utilisateur" if msg.get("role") == "user" else "Assistant"
|
| 113 |
+
content = msg.get("content", "")
|
| 114 |
+
conversation_text += f"{role}: {content}\n\n"
|
| 115 |
+
|
| 116 |
+
summary_prompt = f"""
|
| 117 |
+
Vous êtes un assistant juridique. Résumez la conversation suivante entre l'utilisateur et vous-même.
|
| 118 |
+
|
| 119 |
+
**CONVERSATION:**
|
| 120 |
+
{conversation_text}
|
| 121 |
+
|
| 122 |
+
**INSTRUCTIONS:**
|
| 123 |
+
- Faites un résumé concis et clair
|
| 124 |
+
- Mettez en évidence les points juridiques principaux discutés
|
| 125 |
+
- Mentionnez les pays concernés (Bénin/Madagascar) si pertinents
|
| 126 |
+
- Gardez un ton professionnel mais accessible
|
| 127 |
+
- Maximum 5-7 phrases
|
| 128 |
+
|
| 129 |
+
**RÉSUMÉ:**
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
response = await self.llm.ainvoke([HumanMessage(content=summary_prompt)])
|
| 134 |
+
return response.content if hasattr(response, 'content') else str(response)
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.error(f"LLM summarization failed: {e}")
|
| 137 |
+
user_messages = [m for m in conversation_messages if m.get("role") == "user"]
|
| 138 |
+
assistant_messages = [m for m in conversation_messages if m.get("role") == "assistant"]
|
| 139 |
+
|
| 140 |
+
return f"""**Résumé de notre conversation:**
|
| 141 |
+
|
| 142 |
+
- **Échanges totaux**: {len(conversation_messages)} messages
|
| 143 |
+
- **Questions de l'utilisateur**: {len(user_messages)}
|
| 144 |
+
- **Réponses fournies**: {len(assistant_messages)}
|
| 145 |
+
- **Dernier échange**: {conversation_messages[-1].get('content', '')[:100]}...
|
| 146 |
+
|
| 147 |
+
*Pour un résumé détaillé, veuillez reposer votre question.*"""
|
core/nodes/response_nodes.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/nodes/response_nodes.py
|
| 2 |
+
import logging
|
| 3 |
+
import time
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
from langchain_core.runnables import RunnableConfig
|
| 7 |
+
|
| 8 |
+
from models.state_models import MultiCountryLegalState
|
| 9 |
+
from utils.helpers import dict_to_message_obj, message_obj_to_dict
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class ResponseNodes:
|
| 14 |
+
def __init__(self, llm):
|
| 15 |
+
self.llm = llm
|
| 16 |
+
|
| 17 |
+
async def response_generation_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 18 |
+
"""Generate appropriate responses based on current state"""
|
| 19 |
+
|
| 20 |
+
assistance_step = state.assistance_step
|
| 21 |
+
|
| 22 |
+
# Handle assistance workflow responses
|
| 23 |
+
if assistance_step == "collecting_email":
|
| 24 |
+
response_content = """
|
| 25 |
+
Je vois que vous souhaitez parler à un avocat. Pour vous aider, j'ai besoin de votre adresse email pour que notre équipe puisse vous contacter.
|
| 26 |
+
|
| 27 |
+
📧 **Veuillez me fournir votre adresse email :**
|
| 28 |
+
"""
|
| 29 |
+
return {
|
| 30 |
+
"messages": [{
|
| 31 |
+
"role": "assistant",
|
| 32 |
+
"content": response_content,
|
| 33 |
+
"meta": {"assistance_step": "collecting_email"}
|
| 34 |
+
}],
|
| 35 |
+
"supplemental_message": "" # Clear any previous supplemental messages
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
elif assistance_step == "collecting_description":
|
| 39 |
+
response_content = f"""
|
| 40 |
+
Merci ! Votre email ({state.user_email}) a été enregistré.
|
| 41 |
+
|
| 42 |
+
📝 **Veuillez maintenant décrire brièvement votre situation :**
|
| 43 |
+
- Quelle est votre question juridique ?
|
| 44 |
+
- De quel pays s'agit-il ?
|
| 45 |
+
- Quel type d'assistance recherchez-vous ?
|
| 46 |
+
|
| 47 |
+
Cette description aidera notre équipe à mieux vous orienter.
|
| 48 |
+
"""
|
| 49 |
+
return {
|
| 50 |
+
"messages": [{
|
| 51 |
+
"role": "assistant",
|
| 52 |
+
"content": response_content,
|
| 53 |
+
"meta": {"assistance_step": "collecting_description"}
|
| 54 |
+
}],
|
| 55 |
+
"supplemental_message": "" # Clear any previous supplemental messages
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
elif assistance_step == "confirming_send":
|
| 59 |
+
response_content = f"""
|
| 60 |
+
📋 **RÉCAPITULATIF DE VOTRE DEMANDE :**
|
| 61 |
+
|
| 62 |
+
📧 **Email :** {state.user_email}
|
| 63 |
+
📝 **Description :** {state.assistance_description}
|
| 64 |
+
|
| 65 |
+
✅ **Confirmez-vous l'envoi de cette demande à notre équipe juridique ?**
|
| 66 |
+
|
| 67 |
+
Répondez par :
|
| 68 |
+
- **"oui"** pour confirmer et envoyer
|
| 69 |
+
- **"non"** pour annuler et modifier
|
| 70 |
+
"""
|
| 71 |
+
return {
|
| 72 |
+
"messages": [{
|
| 73 |
+
"role": "assistant",
|
| 74 |
+
"content": response_content,
|
| 75 |
+
"meta": {"assistance_step": "confirming_send"}
|
| 76 |
+
}],
|
| 77 |
+
"supplemental_message": "" # Clear any previous supplemental messages
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
else:
|
| 81 |
+
# Default LLM response for non-assistance flows
|
| 82 |
+
return await self._generate_llm_response(state, config)
|
| 83 |
+
|
| 84 |
+
async def _generate_llm_response(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 85 |
+
"""Generate LLM-based response for normal conversation flows"""
|
| 86 |
+
try:
|
| 87 |
+
# Include supplemental message in the response if present
|
| 88 |
+
supplemental_message = state.supplemental_message or ""
|
| 89 |
+
|
| 90 |
+
# Synthesize response using LLM
|
| 91 |
+
response_content = await self._synthesize_response(state, supplemental_message)
|
| 92 |
+
|
| 93 |
+
return {
|
| 94 |
+
"messages": [{
|
| 95 |
+
"role": "assistant",
|
| 96 |
+
"content": response_content,
|
| 97 |
+
"meta": {
|
| 98 |
+
"timestamp": datetime.now().isoformat(),
|
| 99 |
+
"generated_by": "llm"
|
| 100 |
+
}
|
| 101 |
+
}],
|
| 102 |
+
"supplemental_message": "" # Clear after using
|
| 103 |
+
}
|
| 104 |
+
except Exception as e:
|
| 105 |
+
logger.error(f"Error generating LLM response: {str(e)}")
|
| 106 |
+
return {
|
| 107 |
+
"messages": [{
|
| 108 |
+
"role": "assistant",
|
| 109 |
+
"content": self._create_error_message(str(e)),
|
| 110 |
+
"meta": {"is_error": True}
|
| 111 |
+
}],
|
| 112 |
+
"supplemental_message": f"Erreur: {str(e)}"
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
async def _synthesize_response(self, state: MultiCountryLegalState, supplemental_message: str = "") -> str:
|
| 116 |
+
"""Synthesize final response based on graph execution"""
|
| 117 |
+
s = state.model_dump()
|
| 118 |
+
|
| 119 |
+
# Build context-aware system prompt
|
| 120 |
+
system_prompt = self._build_system_prompt(state, supplemental_message)
|
| 121 |
+
conversation_messages = self._build_conversation_messages(system_prompt, s.get("messages", []))
|
| 122 |
+
|
| 123 |
+
# Always use LLM to generate final response
|
| 124 |
+
logger.info("🧠 Generating final response with LLM")
|
| 125 |
+
ai_resp = await self.llm.ainvoke(conversation_messages)
|
| 126 |
+
|
| 127 |
+
return ai_resp.content if hasattr(ai_resp, 'content') else str(ai_resp)
|
| 128 |
+
|
| 129 |
+
def _build_system_prompt(self, state: MultiCountryLegalState, supplemental_message: str = "") -> str:
|
| 130 |
+
"""Build context-aware system prompt"""
|
| 131 |
+
s = state.model_dump()
|
| 132 |
+
|
| 133 |
+
base_prompt = """Vous êtes un assistant juridique expert spécialisé dans le droit du Bénin et de Madagascar.
|
| 134 |
+
|
| 135 |
+
TÂCHE: Fournir une réponse claire, précise et utile à l'utilisateur.
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
# Add supplemental message if available
|
| 139 |
+
if supplemental_message:
|
| 140 |
+
base_prompt += f"\nMESSAGE IMPORTANT: {supplemental_message}\n"
|
| 141 |
+
|
| 142 |
+
# Add legal context if available
|
| 143 |
+
country_name = s.get("legal_context", {}).get("jurisdiction", "Unknown")
|
| 144 |
+
if country_name != "Unknown":
|
| 145 |
+
base_prompt += f"\nCONTEXTE JURIDIQUE: Vous répondez dans le cadre du droit {country_name}.\n"
|
| 146 |
+
|
| 147 |
+
# Add search results if available
|
| 148 |
+
search_results = s.get("search_results", "")
|
| 149 |
+
if search_results and "RECHERCHE JURIDIQUE" in search_results:
|
| 150 |
+
base_prompt += f"\nINFORMATIONS JURIDIQUES DISPONIBLES:\n{search_results}\n"
|
| 151 |
+
base_prompt += """
|
| 152 |
+
INSTRUCTIONS POUR LA RÉPONSE JURIDIQUE:
|
| 153 |
+
- Basez-vous sur les informations juridiques disponibles
|
| 154 |
+
- Citez les articles de loi pertinents si possible
|
| 155 |
+
- Soyez précis mais accessible aux non-juristes
|
| 156 |
+
- Indiquez si certaines informations manquent
|
| 157 |
+
"""
|
| 158 |
+
else:
|
| 159 |
+
base_prompt += """
|
| 160 |
+
INSTRUCTIONS GÉNÉRALES:
|
| 161 |
+
- Répondez de manière naturelle et utile
|
| 162 |
+
- Adaptez votre ton au contexte de la conversation
|
| 163 |
+
- Soyez empathique et professionnel
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
# Add assistance context if relevant
|
| 167 |
+
if s.get("assistance_requested"):
|
| 168 |
+
base_prompt += "\nCONTEXTE ASSISTANCE: L'utilisateur a demandé à parler à un avocat.\n"
|
| 169 |
+
|
| 170 |
+
if s.get("approval_status") == "rejected":
|
| 171 |
+
base_prompt += "\nCONTEXTE: La demande d'assistance a été rejetée. Expliquez poliment et proposez des alternatives.\n"
|
| 172 |
+
elif s.get("approval_status") == "approved":
|
| 173 |
+
base_prompt += "\nCONTEXTE: La demande d'assistance a été approuvée. Confirmez et donnez les prochaines étapes.\n"
|
| 174 |
+
|
| 175 |
+
return base_prompt
|
| 176 |
+
|
| 177 |
+
def _build_conversation_messages(self, system_prompt: str, messages: list) -> list:
|
| 178 |
+
"""Build conversation messages for LLM"""
|
| 179 |
+
from langchain_core.messages import SystemMessage
|
| 180 |
+
|
| 181 |
+
conversation_messages = [SystemMessage(content=system_prompt)]
|
| 182 |
+
|
| 183 |
+
# Include recent conversation history (last 6 messages)
|
| 184 |
+
recent_messages = messages[-6:] if len(messages) > 6 else messages
|
| 185 |
+
|
| 186 |
+
# Convert to message objects
|
| 187 |
+
conversation_messages.extend(dict_to_message_obj(m) for m in recent_messages)
|
| 188 |
+
|
| 189 |
+
return conversation_messages
|
| 190 |
+
|
| 191 |
+
async def human_approval_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 192 |
+
"""Handle human approval interrupts"""
|
| 193 |
+
logger.info("👨⚖️ Human approval node - triggering interrupt")
|
| 194 |
+
|
| 195 |
+
return {
|
| 196 |
+
"approval_status": "pending",
|
| 197 |
+
"messages": [{
|
| 198 |
+
"role": "assistant",
|
| 199 |
+
"content": "⏳ Votre demande d'assistance nécessite une approbation manuelle. Un modérateur va examiner votre demande.",
|
| 200 |
+
"meta": {"requires_approval": True}
|
| 201 |
+
}],
|
| 202 |
+
"supplemental_message": ""
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
async def process_assistance_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 206 |
+
"""Process assistance after approval - let LLM generate final message"""
|
| 207 |
+
logger.info("📧 Processing assistance request")
|
| 208 |
+
|
| 209 |
+
return {
|
| 210 |
+
"email_status": "sent",
|
| 211 |
+
"approval_status": "approved",
|
| 212 |
+
"assistance_step": "completed",
|
| 213 |
+
"messages": [], # Empty messages so LLM generates the final response
|
| 214 |
+
"supplemental_message": "Votre demande d'assistance a été traitée avec succès."
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
def _create_error_message(self, error: str) -> str:
|
| 218 |
+
"""Create error message"""
|
| 219 |
+
return f"❌ Désolé, une erreur s'est produite: {error}\n\nVeuillez réessayer ou reformuler votre demande."
|
core/nodes/retrieval_nodes.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/nodes/retrieval_nodes.py
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict, Any
|
| 4 |
+
from langchain_core.runnables import RunnableConfig
|
| 5 |
+
|
| 6 |
+
from models.state_models import MultiCountryLegalState
|
| 7 |
+
from core.retriever import LegalRetriever
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class RetrievalNodes:
|
| 12 |
+
"""Scalable legal retrieval nodes for any number of countries"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, country_retrievers: Dict[str, LegalRetriever]):
|
| 15 |
+
self.country_retrievers = country_retrievers
|
| 16 |
+
|
| 17 |
+
async def country_retrieval_node(self, state: MultiCountryLegalState, config: RunnableConfig, country_code: str) -> Dict[str, Any]:
|
| 18 |
+
"""Generic country retrieval for any country"""
|
| 19 |
+
try:
|
| 20 |
+
if country_code not in self.country_retrievers:
|
| 21 |
+
logger.error(f"❌ Country not configured: {country_code}")
|
| 22 |
+
return {
|
| 23 |
+
"search_results": f"Country {country_code} not available",
|
| 24 |
+
"detected_articles": [],
|
| 25 |
+
"supplemental_message": f"Pays {country_code} non configuré dans le système."
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
retriever = self.country_retrievers[country_code]
|
| 29 |
+
s = state.model_dump()
|
| 30 |
+
last_human = self._get_last_human_message(s.get("messages", []))
|
| 31 |
+
|
| 32 |
+
if not last_human:
|
| 33 |
+
return {
|
| 34 |
+
"search_results": f"No query for {country_code} retrieval",
|
| 35 |
+
"detected_articles": [],
|
| 36 |
+
"supplemental_message": "Aucune requête trouvée pour la recherche."
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
user_query = last_human.get("content", "").strip()
|
| 40 |
+
if not user_query:
|
| 41 |
+
return {
|
| 42 |
+
"search_results": f"Empty query for {country_code} retrieval",
|
| 43 |
+
"detected_articles": [],
|
| 44 |
+
"supplemental_message": "Requête vide pour la recherche."
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
logger.info(f"🌍 Performing {country_code} retrieval for: '{user_query[:50]}...'")
|
| 48 |
+
|
| 49 |
+
enhanced_docs, detected_articles, applied_filters, supplemental_message = await retriever.smart_legal_query(user_query, country_code)
|
| 50 |
+
|
| 51 |
+
search_results = retriever.format_search_results(
|
| 52 |
+
user_query, enhanced_docs, detected_articles, applied_filters, country_code, supplemental_message
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
logger.info(f"📚 Retrieved {len(enhanced_docs)} documents for {country_code}")
|
| 56 |
+
|
| 57 |
+
return {
|
| 58 |
+
"search_results": search_results,
|
| 59 |
+
"detected_articles": detected_articles,
|
| 60 |
+
"last_search_query": user_query,
|
| 61 |
+
"supplemental_message": supplemental_message, # Pass the supplemental message to state
|
| 62 |
+
# Store complex data in search_metadata instead of legal_context
|
| 63 |
+
"search_metadata": {
|
| 64 |
+
"applied_filters": applied_filters,
|
| 65 |
+
"documents_count": len(enhanced_docs),
|
| 66 |
+
"supplemental_message": supplemental_message
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error(f"Error in {country_code} retrieval: {str(e)}")
|
| 72 |
+
return {
|
| 73 |
+
"search_results": f"Erreur lors de la recherche {country_code}: {str(e)}",
|
| 74 |
+
"detected_articles": [],
|
| 75 |
+
"supplemental_message": f"Erreur lors de la recherche: {str(e)}"
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
def _get_last_human_message(self, messages: list) -> Dict[str, Any]:
|
| 79 |
+
"""Get the last human message"""
|
| 80 |
+
for msg in reversed(messages):
|
| 81 |
+
if msg.get("role") in ["user", "human"]:
|
| 82 |
+
return msg
|
| 83 |
+
return {}
|
core/nodes/routing_nodes.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/nodes/routing_nodes.py
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict, Any
|
| 4 |
+
from langchain_core.runnables import RunnableConfig
|
| 5 |
+
|
| 6 |
+
from models.state_models import MultiCountryLegalState
|
| 7 |
+
from core.router import CountryRouter
|
| 8 |
+
from .base_node import BaseNode
|
| 9 |
+
from core.prompts.prompt_templates import PromptTemplates
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class RoutingNodes(BaseNode):
|
| 14 |
+
"""Router, greeting, and conversation repair nodes"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, router: CountryRouter, conversation_repair, llm):
|
| 17 |
+
self.router = router
|
| 18 |
+
self.conversation_repair = conversation_repair
|
| 19 |
+
self.llm = llm
|
| 20 |
+
self.prompts = PromptTemplates()
|
| 21 |
+
|
| 22 |
+
async def router_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 23 |
+
"""Enhanced router that detects primary intent with state awareness"""
|
| 24 |
+
try:
|
| 25 |
+
s = state.model_dump()
|
| 26 |
+
|
| 27 |
+
# CRITICAL: Check if we're continuing an assistance workflow
|
| 28 |
+
# This prevents the router from misclassifying continuation messages
|
| 29 |
+
assistance_step = s.get("assistance_step")
|
| 30 |
+
if assistance_step and assistance_step not in [None, "cancelled", "completed"]:
|
| 31 |
+
logger.info(f"⏩ Bypassing router - continuing assistance at step: {assistance_step}")
|
| 32 |
+
return {
|
| 33 |
+
"router_decision": "assistance_request",
|
| 34 |
+
"route_explanation": f"Continuing assistance workflow: {assistance_step}",
|
| 35 |
+
"assistance_step": assistance_step, # Ensure step persists
|
| 36 |
+
"assistance_requested": True
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
# Normal routing for new messages
|
| 40 |
+
return await self._perform_normal_routing(state, s)
|
| 41 |
+
|
| 42 |
+
except Exception as e:
|
| 43 |
+
logger.error(f"Router error: {str(e)}")
|
| 44 |
+
legal_context = state.legal_context if hasattr(state, 'legal_context') else {}
|
| 45 |
+
return self._create_router_response("unclear", f"Router error: {str(e)}", legal_context)
|
| 46 |
+
|
| 47 |
+
async def _perform_normal_routing(self, state: MultiCountryLegalState, state_dict: Dict) -> Dict[str, Any]:
|
| 48 |
+
"""Perform normal routing for new user queries"""
|
| 49 |
+
if not state_dict.get("messages"):
|
| 50 |
+
logger.warning("No messages in state for router")
|
| 51 |
+
return self._create_router_response("unclear", "No messages in state", state_dict.get("legal_context", {}))
|
| 52 |
+
|
| 53 |
+
last_human = self._get_last_human_message(state_dict.get("messages", []))
|
| 54 |
+
if not last_human:
|
| 55 |
+
logger.warning("No user query found in router")
|
| 56 |
+
return self._create_router_response("unclear", "No user query found", state_dict.get("legal_context", {}))
|
| 57 |
+
|
| 58 |
+
user_query = last_human.get("content", "").strip()
|
| 59 |
+
if not user_query:
|
| 60 |
+
logger.warning("Empty user query in router")
|
| 61 |
+
return self._create_router_response("unclear", "Empty user query", state_dict.get("legal_context", {}))
|
| 62 |
+
|
| 63 |
+
logger.info(f"🔀 Routing query: '{user_query[:50]}...'")
|
| 64 |
+
routing_result = await self.router.route_query(user_query, state_dict["messages"])
|
| 65 |
+
|
| 66 |
+
primary_intent = routing_result.country
|
| 67 |
+
logger.info(f"🎯 Router decision: {primary_intent} ({routing_result.confidence}) - {routing_result.method}")
|
| 68 |
+
|
| 69 |
+
updated_context = self._update_legal_context(state_dict["legal_context"], primary_intent)
|
| 70 |
+
|
| 71 |
+
response = {
|
| 72 |
+
"router_decision": primary_intent,
|
| 73 |
+
"route_explanation": f"{routing_result.method}: {routing_result.explanation}",
|
| 74 |
+
"legal_context": updated_context,
|
| 75 |
+
"primary_intent": primary_intent
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
# If this is an assistance request, initialize the workflow
|
| 79 |
+
if primary_intent == "assistance_request":
|
| 80 |
+
response.update({
|
| 81 |
+
"assistance_step": "collecting_email",
|
| 82 |
+
"assistance_requested": True
|
| 83 |
+
})
|
| 84 |
+
|
| 85 |
+
return response
|
| 86 |
+
|
| 87 |
+
async def greeting_small_talk_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 88 |
+
"""Handle greetings and small talk"""
|
| 89 |
+
try:
|
| 90 |
+
s = state.model_dump()
|
| 91 |
+
last_human = self._get_last_human_message(s.get("messages", []))
|
| 92 |
+
user_query = last_human.get("content", "").lower() if last_human else ""
|
| 93 |
+
|
| 94 |
+
logger.info(f"👋 Handling greeting/small_talk: '{user_query[:30]}...'")
|
| 95 |
+
|
| 96 |
+
greeting_response = self.prompts.generate_greeting_response(user_query)
|
| 97 |
+
|
| 98 |
+
return {
|
| 99 |
+
"messages": [{
|
| 100 |
+
"role": "assistant",
|
| 101 |
+
"content": greeting_response,
|
| 102 |
+
"meta": {
|
| 103 |
+
"is_greeting": True,
|
| 104 |
+
"timestamp": self._get_timestamp()
|
| 105 |
+
}
|
| 106 |
+
}],
|
| 107 |
+
"search_results": "Greeting handled - no legal search performed"
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
except Exception as e:
|
| 111 |
+
logger.error(f"Error in greeting node: {str(e)}")
|
| 112 |
+
return self._create_error_state(f"Error in greeting: {str(e)}")
|
| 113 |
+
|
| 114 |
+
async def conversation_repair_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]:
|
| 115 |
+
"""Unified repair handling with LLM"""
|
| 116 |
+
try:
|
| 117 |
+
s = state.model_dump()
|
| 118 |
+
last_human = self._get_last_human_message(s.get("messages", []))
|
| 119 |
+
user_query = last_human.get("content", "") if last_human else ""
|
| 120 |
+
|
| 121 |
+
logger.info(f"🔧 Handling repair request: '{user_query[:30]}...'")
|
| 122 |
+
|
| 123 |
+
repair_response = await self.conversation_repair.generate_repair_response(
|
| 124 |
+
user_query, s.get("messages", []), self.llm
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
return {
|
| 128 |
+
"messages": [{
|
| 129 |
+
"role": "assistant",
|
| 130 |
+
"content": repair_response,
|
| 131 |
+
"meta": {
|
| 132 |
+
"is_repair_response": True,
|
| 133 |
+
"timestamp": self._get_timestamp()
|
| 134 |
+
}
|
| 135 |
+
}],
|
| 136 |
+
"search_results": "Repair handled - no legal search performed"
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"Error in repair node: {str(e)}")
|
| 141 |
+
return self._create_error_state(f"Error in repair: {str(e)}")
|
| 142 |
+
|
| 143 |
+
def _create_router_response(self, decision: str, explanation: str, legal_context: Dict) -> Dict[str, Any]:
|
| 144 |
+
"""Create a standardized router response"""
|
| 145 |
+
return {
|
| 146 |
+
"router_decision": decision,
|
| 147 |
+
"route_explanation": explanation,
|
| 148 |
+
"legal_context": legal_context,
|
| 149 |
+
"primary_intent": decision
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
def _get_last_human_message(self, messages: list) -> Dict[str, Any]:
|
| 153 |
+
"""Get the last human message from conversation history"""
|
| 154 |
+
for msg in reversed(messages):
|
| 155 |
+
if msg.get("role") in ["user", "human"]:
|
| 156 |
+
return msg
|
| 157 |
+
return {}
|
| 158 |
+
|
| 159 |
+
def _update_legal_context(self, legal_context: Dict, primary_intent: str) -> Dict:
|
| 160 |
+
"""Update legal context based on routing decision"""
|
| 161 |
+
updated_context = legal_context.copy()
|
| 162 |
+
|
| 163 |
+
# Map router decisions to detected_country
|
| 164 |
+
country_mapping = {
|
| 165 |
+
"benin": "benin",
|
| 166 |
+
"madagascar": "madagascar",
|
| 167 |
+
"assistance_request": updated_context.get("detected_country", "unknown"),
|
| 168 |
+
"greeting_small_talk": "unknown",
|
| 169 |
+
"conversation_repair": updated_context.get("detected_country", "unknown"),
|
| 170 |
+
"conversation_summarization": updated_context.get("detected_country", "unknown"),
|
| 171 |
+
"unclear": "unknown",
|
| 172 |
+
"out_of_scope": "unknown"
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
updated_context["detected_country"] = country_mapping.get(primary_intent, "unknown")
|
| 176 |
+
updated_context["primary_intent"] = primary_intent
|
| 177 |
+
|
| 178 |
+
return updated_context
|
| 179 |
+
|
| 180 |
+
def _get_timestamp(self) -> str:
|
| 181 |
+
"""Get current timestamp"""
|
| 182 |
+
from datetime import datetime
|
| 183 |
+
return datetime.now().isoformat()
|
| 184 |
+
|
| 185 |
+
def _create_error_state(self, error_message: str) -> Dict[str, Any]:
|
| 186 |
+
"""Create error state response"""
|
| 187 |
+
return {
|
| 188 |
+
"messages": [{
|
| 189 |
+
"role": "assistant",
|
| 190 |
+
"content": f"❌ Désolé, une erreur s'est produite. Veuillez réessayer.",
|
| 191 |
+
"meta": {"error": error_message}
|
| 192 |
+
}]
|
| 193 |
+
}
|
core/prompts/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .prompt_templates import PromptTemplates
|
| 2 |
+
|
| 3 |
+
__all__ = ["PromptTemplates"]
|
core/prompts/prompt_templates.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/prompts/prompt_templates.py
|
| 2 |
+
class PromptTemplates:
|
| 3 |
+
"""All prompt templates used in the graph"""
|
| 4 |
+
|
| 5 |
+
@staticmethod
|
| 6 |
+
def get_email_request_message() -> str:
|
| 7 |
+
return """📧 **Demande d'assistance juridique**
|
| 8 |
+
|
| 9 |
+
Pour vous mettre en relation avec un avocat, j'ai besoin de votre adresse email.
|
| 10 |
+
|
| 11 |
+
**Votre email :**"""
|
| 12 |
+
|
| 13 |
+
@staticmethod
|
| 14 |
+
def get_description_prompt(email: str) -> str:
|
| 15 |
+
return f"""📝 **Description de votre besoin**
|
| 16 |
+
|
| 17 |
+
Merci ! Email enregistré : {email}
|
| 18 |
+
|
| 19 |
+
Maintenant, décrivez-moi **comment vous souhaitez être assisté(e)** :
|
| 20 |
+
|
| 21 |
+
Exemples :
|
| 22 |
+
• "Consultation téléphonique de 30 minutes sur le droit de la famille"
|
| 23 |
+
• "Avis écrit sur un contrat de travail"
|
| 24 |
+
• "Accompagnement pour une procédure de divorce"
|
| 25 |
+
• "Explication sur mes droits successoraux"
|
| 26 |
+
|
| 27 |
+
**Votre description :**"""
|
| 28 |
+
|
| 29 |
+
@staticmethod
|
| 30 |
+
def get_confirmation_prompt(data: dict) -> str:
|
| 31 |
+
email = data.get("email", "Non fourni")
|
| 32 |
+
description = data.get("description", "Non fournie")
|
| 33 |
+
|
| 34 |
+
return f"""✅ **Confirmation d'envoi**
|
| 35 |
+
|
| 36 |
+
Veuillez confirmer l'envoi de votre demande d'assistance :
|
| 37 |
+
|
| 38 |
+
📧 **Email** : {email}
|
| 39 |
+
📋 **Description** : {description}
|
| 40 |
+
|
| 41 |
+
**L'avocat vous contactera directement dans les 24-48 heures.**
|
| 42 |
+
|
| 43 |
+
🔔 **Confirmez-vous l'envoi ?** (répondez par OUI/NON)"""
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
def get_missing_info_prompt(current_step: str, has_email: bool) -> str:
|
| 47 |
+
if current_step == "collecting_email":
|
| 48 |
+
return "📧 **Email manquant** : Pourriez-vous me donner votre adresse email ?"
|
| 49 |
+
else:
|
| 50 |
+
return "📝 **Description manquante** : Pourriez-vous décrire comment vous souhaitez être assisté(e) ?"
|
| 51 |
+
|
| 52 |
+
@staticmethod
|
| 53 |
+
def get_non_legal_response() -> str:
|
| 54 |
+
return """🔍 **Hors de mon domaine d'expertise**
|
| 55 |
+
|
| 56 |
+
Je suis un assistant juridique spécialisé pour le Bénin et Madagascar.
|
| 57 |
+
|
| 58 |
+
**Je peux vous aider avec :**
|
| 59 |
+
⚖️ **Questions juridiques** : lois, droits, procédures
|
| 60 |
+
📚 **Textes de loi** : articles, codes, décrets
|
| 61 |
+
🔧 **Assistance légale** : démarches, formalités
|
| 62 |
+
👨⚖️ **Connexion avocat** : assistance humaine
|
| 63 |
+
|
| 64 |
+
**Exemples de questions que je peux traiter :**
|
| 65 |
+
• "Procédure de divorce au Bénin"
|
| 66 |
+
• "Droits des enfants à Madagascar"
|
| 67 |
+
• "Articles sur le droit du travail"
|
| 68 |
+
• "Comment contacter un avocat ?"
|
| 69 |
+
|
| 70 |
+
Posez-moi une question juridique !"""
|
| 71 |
+
|
| 72 |
+
@staticmethod
|
| 73 |
+
def get_clarification_message() -> str:
|
| 74 |
+
return """Je ne peux pas déterminer de quel pays vous parlez. Pourriez-vous préciser si votre question concerne le droit du **Bénin** ou de **Madagascar** ?"""
|
| 75 |
+
|
| 76 |
+
@staticmethod
|
| 77 |
+
def generate_greeting_response(query: str) -> str:
|
| 78 |
+
"""Generate appropriate greeting responses"""
|
| 79 |
+
query_lower = query.lower()
|
| 80 |
+
|
| 81 |
+
if any(word in query_lower for word in ["bonjour", "hello", "hi"]):
|
| 82 |
+
return "👋 Bonjour ! Je suis votre assistant juridique spécialisé pour le Bénin et Madagascar. Comment puis-je vous aider aujourd'hui ?"
|
| 83 |
+
elif any(word in query_lower for word in ["salut", "coucou"]):
|
| 84 |
+
return "👋 Salut ! Je suis votre assistant juridique. Posez-moi vos questions sur le droit béninois ou malgache !"
|
| 85 |
+
elif any(word in query_lower for word in ["comment ça va", "ça va", "comment vas-tu"]):
|
| 86 |
+
return "😊 Je vais très bien, merci ! Je suis prêt à vous aider avec vos questions juridiques sur le Bénin ou Madagascar."
|
| 87 |
+
elif any(word in query_lower for word in ["merci", "thanks"]):
|
| 88 |
+
return "🤝 Je vous en prie ! N'hésitez pas si vous avez d'autres questions juridiques."
|
| 89 |
+
elif any(word in query_lower for word in ["au revoir", "bye", "à bientôt"]):
|
| 90 |
+
return "👋 Au revoir ! N'hésitez pas à revenir si vous avez besoin d'assistance juridique."
|
| 91 |
+
elif any(word in query_lower for word in ["qui es-tu", "ton nom", "te présenter"]):
|
| 92 |
+
return "⚖️ Je suis un assistant juridique IA spécialisé dans les droits du Bénin et de Madagascar. Je peux vous aider à trouver des informations sur les lois, articles, et procédures juridiques."
|
| 93 |
+
else:
|
| 94 |
+
return "👋 Bonjour ! Je suis votre assistant juridique. Posez-moi vos questions sur le droit béninois ou malgache !"
|
core/retriever.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# core/retriever.py
|
| 2 |
+
import re
|
| 3 |
+
import logging
|
| 4 |
+
import asyncio
|
| 5 |
+
from typing import List, Dict, Any, Tuple
|
| 6 |
+
from langchain_core.documents import Document
|
| 7 |
+
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
|
| 8 |
+
|
| 9 |
+
from config.settings import settings
|
| 10 |
+
from config.constants import ARTICLE_PATTERNS, CATEGORY_KEYWORDS, DOCUMENT_TYPE_KEYWORDS
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class LegalRetriever:
|
| 15 |
+
def __init__(self, vectorstore: MongoDBAtlasVectorSearch, collection):
|
| 16 |
+
self.vectorstore = vectorstore
|
| 17 |
+
self.collection = collection
|
| 18 |
+
|
| 19 |
+
async def smart_legal_query(self, user_query: str, country: str) -> Tuple[List[Document], List[str], Dict[str, Any], str]:
|
| 20 |
+
"""Perform smart legal search with automatic fallback and custom messages - ASYNC VERSION"""
|
| 21 |
+
try:
|
| 22 |
+
# Détection initiale du type de document
|
| 23 |
+
initial_doc_type = self._detect_document_type(user_query.lower())
|
| 24 |
+
pre_filter = self._build_pre_filters(user_query, country)
|
| 25 |
+
|
| 26 |
+
logger.info(f"📋 Filtre doc_type initial: {initial_doc_type}")
|
| 27 |
+
logger.info(f"🔍 Recherche {country} avec filtres: {pre_filter}")
|
| 28 |
+
|
| 29 |
+
# Première recherche
|
| 30 |
+
enhanced_docs, detected_articles, applied_filters = await self._perform_search_async(
|
| 31 |
+
user_query, country, pre_filter
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
message_supplementaire = ""
|
| 35 |
+
|
| 36 |
+
# Fallback automatique si aucun résultat pour case_study
|
| 37 |
+
if not enhanced_docs and initial_doc_type == "case_study":
|
| 38 |
+
logger.info("🔄 Fallback: Aucun case_study trouvé, recherche dans les articles")
|
| 39 |
+
|
| 40 |
+
# Create new filter for articles - DON'T rebuild, just modify
|
| 41 |
+
fallback_filter = pre_filter.copy() # Copy the original filter
|
| 42 |
+
fallback_filter["doc_type"] = "articles" # Force articles type
|
| 43 |
+
|
| 44 |
+
logger.info(f"🔄 Fallback filter: {fallback_filter}") # Log the fallback filter
|
| 45 |
+
|
| 46 |
+
enhanced_docs, detected_articles, applied_filters = await self._perform_search_async(
|
| 47 |
+
user_query, country, fallback_filter
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Mark that fallback was used
|
| 51 |
+
applied_filters["original_search"] = "case_study"
|
| 52 |
+
applied_filters["fallback_to"] = "articles"
|
| 53 |
+
applied_filters["fallback_used"] = True
|
| 54 |
+
|
| 55 |
+
# Message personnalisé pour le fallback
|
| 56 |
+
if enhanced_docs:
|
| 57 |
+
message_supplementaire = (
|
| 58 |
+
"⚠️ Nous nous excusons, mais aucune décision de justice n'a été trouvée pour votre requête. "
|
| 59 |
+
"La base de données sera enrichie avec des décisions de justice prochainement. "
|
| 60 |
+
"En attendant, voici des articles de loi pertinents qui peuvent vous aider."
|
| 61 |
+
)
|
| 62 |
+
else:
|
| 63 |
+
# Check if it's a MongoDB error
|
| 64 |
+
if "mongodb_error" in applied_filters:
|
| 65 |
+
message_supplementaire = (
|
| 66 |
+
"⚠️ Nous nous excusons, mais une erreur technique s'est produite lors de la recherche. "
|
| 67 |
+
"Nous travaillons à résoudre ce problème. Veuillez réessayer dans quelques instants."
|
| 68 |
+
)
|
| 69 |
+
else:
|
| 70 |
+
message_supplementaire = (
|
| 71 |
+
"⚠️ Nous nous excusons, mais aucune décision de justice n'a été trouvée pour votre requête. "
|
| 72 |
+
"La base de données sera enrichie avec des décisions de justice prochainement. "
|
| 73 |
+
"De plus, aucun article de loi correspondant n'a été trouvé. "
|
| 74 |
+
"Essayez de reformuler votre question avec des termes plus généraux."
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
logger.info(f"🔍 Search completed: {len(enhanced_docs)} documents found")
|
| 78 |
+
logger.info(f"📢 Supplemental message: {message_supplementaire[:100] if message_supplementaire else 'None'}")
|
| 79 |
+
return enhanced_docs, detected_articles, applied_filters, message_supplementaire
|
| 80 |
+
|
| 81 |
+
except Exception as e:
|
| 82 |
+
logger.error(f"Error in smart_legal_query: {str(e)}")
|
| 83 |
+
# Return empty results on error
|
| 84 |
+
return [], [], {"error": str(e)}, f"Erreur lors de la recherche: {str(e)}"
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
async def _perform_search_async(self, user_query: str, country: str, pre_filter: Dict) -> Tuple[List[Document], List[str], Dict[str, Any]]:
|
| 88 |
+
"""Perform search with given filters - ASYNC VERSION"""
|
| 89 |
+
try:
|
| 90 |
+
detected_articles = self._detect_articles(user_query)
|
| 91 |
+
enhanced_query = self._enhance_query(user_query, detected_articles)
|
| 92 |
+
|
| 93 |
+
logger.info(f"🔢 Articles détectés: {detected_articles}")
|
| 94 |
+
logger.info(f"🔍 Requête enrichie: {enhanced_query[:100]}...")
|
| 95 |
+
|
| 96 |
+
# CRITICAL FIX: Run synchronous vectorstore operation in thread pool
|
| 97 |
+
docs = await asyncio.get_event_loop().run_in_executor(
|
| 98 |
+
None, # Use default thread pool
|
| 99 |
+
lambda: self.vectorstore.similarity_search(
|
| 100 |
+
enhanced_query,
|
| 101 |
+
k=settings.MAX_SEARCH_RESULTS,
|
| 102 |
+
pre_filter=pre_filter
|
| 103 |
+
)
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
logger.info(f"🎯 Vector search returned {len(docs)} raw documents")
|
| 107 |
+
|
| 108 |
+
if docs:
|
| 109 |
+
logger.info(f"📄 First result metadata: {docs[0].metadata}")
|
| 110 |
+
else:
|
| 111 |
+
logger.warning(f"⚠️ No documents found with filters: {pre_filter}")
|
| 112 |
+
await self._debug_search_issue(pre_filter)
|
| 113 |
+
|
| 114 |
+
enhanced_docs = self.enhance_with_article_context(docs)
|
| 115 |
+
return enhanced_docs, detected_articles, pre_filter
|
| 116 |
+
|
| 117 |
+
except Exception as e:
|
| 118 |
+
logger.error(f"Error in _perform_search_async: {str(e)}")
|
| 119 |
+
# Mark the filter with MongoDB error for better error handling
|
| 120 |
+
error_filter = pre_filter.copy()
|
| 121 |
+
error_filter["mongodb_error"] = str(e)
|
| 122 |
+
return [], [], error_filter
|
| 123 |
+
|
| 124 |
+
async def _debug_search_issue(self, pre_filter: Dict):
|
| 125 |
+
"""Debug why search returned no results"""
|
| 126 |
+
try:
|
| 127 |
+
# Check total document count
|
| 128 |
+
total_count = await asyncio.get_event_loop().run_in_executor(
|
| 129 |
+
None,
|
| 130 |
+
lambda: self.collection.count_documents({})
|
| 131 |
+
)
|
| 132 |
+
logger.info(f"🔢 Total documents in collection: {total_count}")
|
| 133 |
+
|
| 134 |
+
# Check documents matching country filter
|
| 135 |
+
country_count = await asyncio.get_event_loop().run_in_executor(
|
| 136 |
+
None,
|
| 137 |
+
lambda: self.collection.count_documents({"pays": pre_filter.get("pays")})
|
| 138 |
+
)
|
| 139 |
+
logger.info(f"🌍 Documents for country {pre_filter.get('pays')}: {country_count}")
|
| 140 |
+
|
| 141 |
+
# Check documents with doc_type
|
| 142 |
+
doc_type_count = await asyncio.get_event_loop().run_in_executor(
|
| 143 |
+
None,
|
| 144 |
+
lambda: self.collection.count_documents({
|
| 145 |
+
"pays": pre_filter.get("pays"),
|
| 146 |
+
"doc_type": pre_filter.get("doc_type")
|
| 147 |
+
})
|
| 148 |
+
)
|
| 149 |
+
logger.info(f"📋 Documents with doc_type {pre_filter.get('doc_type')}: {doc_type_count}")
|
| 150 |
+
|
| 151 |
+
# Check documents with embeddings
|
| 152 |
+
embedding_count = await asyncio.get_event_loop().run_in_executor(
|
| 153 |
+
None,
|
| 154 |
+
lambda: self.collection.count_documents({
|
| 155 |
+
"pays": pre_filter.get("pays"),
|
| 156 |
+
"embedding": {"$exists": True, "$ne": None}
|
| 157 |
+
})
|
| 158 |
+
)
|
| 159 |
+
logger.info(f"🎯 Documents with embeddings: {embedding_count}")
|
| 160 |
+
|
| 161 |
+
# Sample document check
|
| 162 |
+
sample_doc = await asyncio.get_event_loop().run_in_executor(
|
| 163 |
+
None,
|
| 164 |
+
lambda: self.collection.find_one({"pays": pre_filter.get("pays")})
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
if sample_doc:
|
| 168 |
+
logger.info(f"📄 Sample document keys: {list(sample_doc.keys())}")
|
| 169 |
+
logger.info(f"📄 Sample doc_type: {sample_doc.get('doc_type', 'NOT_SET')}")
|
| 170 |
+
else:
|
| 171 |
+
logger.error("❌ No sample document found!")
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"Error in debug: {str(e)}")
|
| 175 |
+
|
| 176 |
+
def _build_pre_filters(self, query: str, country: str) -> Dict[str, Any]:
|
| 177 |
+
"""Build search filters based on query and country"""
|
| 178 |
+
# Filtre pays obligatoire - MAKE SURE EXACT MATCH
|
| 179 |
+
country_mapping = {
|
| 180 |
+
"benin": "Bénin",
|
| 181 |
+
"madagascar": "Madagascar"
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
pre_filter = {"pays": country_mapping.get(country.lower(), country)}
|
| 185 |
+
|
| 186 |
+
# Filtre doc_type pour différencier articles et études de cas
|
| 187 |
+
query_lower = query.lower()
|
| 188 |
+
detected_doc_type = self._detect_document_type(query_lower)
|
| 189 |
+
pre_filter["doc_type"] = detected_doc_type
|
| 190 |
+
|
| 191 |
+
logger.info(f"🏷️ Using country filter: {pre_filter['pays']}")
|
| 192 |
+
logger.info(f"📋 Using doc_type filter: {detected_doc_type}")
|
| 193 |
+
|
| 194 |
+
# Filtres par catégorie (optionnels)
|
| 195 |
+
logger.info("ℹ️ No category filter applied - using all available family law documents")
|
| 196 |
+
# for keyword, category in CATEGORY_KEYWORDS.items():
|
| 197 |
+
# if keyword in query_lower:
|
| 198 |
+
# pre_filter["categorie"] = category
|
| 199 |
+
# logger.info(f"🏷️ Filtre catégorie: {category}")
|
| 200 |
+
# break
|
| 201 |
+
|
| 202 |
+
return pre_filter
|
| 203 |
+
|
| 204 |
+
def _detect_document_type(self, query_lower: str) -> str:
|
| 205 |
+
"""Détecte le type de document basé sur les mots-clés de la requête"""
|
| 206 |
+
# Mots-clés pour les études de cas
|
| 207 |
+
case_study_indicators = [
|
| 208 |
+
"jurisprudence", "arrêt", "décision", "tribunal", "cours", "jugement",
|
| 209 |
+
"affaire", "procès", "litige", "contentieux", "précédent", "cas",
|
| 210 |
+
"cour d'appel", "cour suprême", "conseil d'état", "juridiction"
|
| 211 |
+
]
|
| 212 |
+
|
| 213 |
+
# Mots-clés pour les articles
|
| 214 |
+
articles_indicators = [
|
| 215 |
+
"article", "loi", "code", "décret", "texte", "disposition",
|
| 216 |
+
"règlement", "ordonnance", "prescription", "norme", "chapitre", "titre"
|
| 217 |
+
]
|
| 218 |
+
|
| 219 |
+
case_study_score = sum(1 for keyword in case_study_indicators if keyword in query_lower)
|
| 220 |
+
articles_score = sum(1 for keyword in articles_indicators if keyword in query_lower)
|
| 221 |
+
|
| 222 |
+
if case_study_score > articles_score and case_study_score > 0:
|
| 223 |
+
return "case_study"
|
| 224 |
+
elif articles_score > 0:
|
| 225 |
+
return "articles"
|
| 226 |
+
else:
|
| 227 |
+
# Par défaut, on cherche les articles de loi
|
| 228 |
+
return "articles"
|
| 229 |
+
|
| 230 |
+
def _detect_articles(self, query: str) -> List[str]:
|
| 231 |
+
"""Detect article references in query"""
|
| 232 |
+
detected_articles = []
|
| 233 |
+
for pattern in ARTICLE_PATTERNS:
|
| 234 |
+
matches = re.findall(pattern, query.lower())
|
| 235 |
+
for match in matches:
|
| 236 |
+
if isinstance(match, tuple):
|
| 237 |
+
nums = [n for n in match if n.isdigit()]
|
| 238 |
+
detected_articles.extend(nums)
|
| 239 |
+
else:
|
| 240 |
+
nums = re.findall(r"\d+", match)
|
| 241 |
+
detected_articles.extend(nums)
|
| 242 |
+
|
| 243 |
+
return sorted(list(set(detected_articles)))
|
| 244 |
+
|
| 245 |
+
def _enhance_query(self, query: str, detected_articles: List[str]) -> str:
|
| 246 |
+
"""Enhance query with article context"""
|
| 247 |
+
if detected_articles:
|
| 248 |
+
enhanced = f"article {' '.join(detected_articles)} {query}"
|
| 249 |
+
logger.info(f"🔢 Requête enrichie avec articles: {detected_articles}")
|
| 250 |
+
return enhanced
|
| 251 |
+
return query
|
| 252 |
+
|
| 253 |
+
def enhance_with_article_context(self, results: List[Document]) -> List[Document]:
|
| 254 |
+
"""Enhance search results with referenced article context"""
|
| 255 |
+
enhanced_results = []
|
| 256 |
+
for result in results:
|
| 257 |
+
enhanced_results.append(result)
|
| 258 |
+
|
| 259 |
+
# Pour les documents de type "articles", on peut ajouter les références
|
| 260 |
+
if result.metadata.get("doc_type") == "articles":
|
| 261 |
+
article_refs = result.metadata.get("article_references", [])
|
| 262 |
+
resolved_refs = result.metadata.get("resolved_references", {})
|
| 263 |
+
|
| 264 |
+
for article_num in article_refs[:3]:
|
| 265 |
+
if article_num in resolved_refs:
|
| 266 |
+
ref_doc = Document(
|
| 267 |
+
page_content=f"Article {article_num} (Référencé): {resolved_refs[article_num][:500]}...",
|
| 268 |
+
metadata={
|
| 269 |
+
**result.metadata,
|
| 270 |
+
"is_reference": True,
|
| 271 |
+
"referenced_article": article_num,
|
| 272 |
+
"doc_type": "article_reference"
|
| 273 |
+
},
|
| 274 |
+
)
|
| 275 |
+
enhanced_results.append(ref_doc)
|
| 276 |
+
|
| 277 |
+
return enhanced_results
|
| 278 |
+
def format_search_results(self, query: str, enhanced_docs: List[Document],
|
| 279 |
+
detected_articles: List[str], applied_filters: Dict[str, Any],
|
| 280 |
+
country: str, supplemental_message: str = "") -> str:
|
| 281 |
+
"""Format search results for system prompt"""
|
| 282 |
+
country_name = "Bénin" if country == "benin" else "Madagascar"
|
| 283 |
+
|
| 284 |
+
if not enhanced_docs:
|
| 285 |
+
doc_type = applied_filters.get("doc_type", "articles")
|
| 286 |
+
|
| 287 |
+
# Check if this was an error case
|
| 288 |
+
if "error" in applied_filters:
|
| 289 |
+
return f"""
|
| 290 |
+
**🚨 ERREUR DE RECHERCHE - {country_name.upper()}**
|
| 291 |
+
|
| 292 |
+
Une erreur s'est produite lors de la recherche: {applied_filters['error']}
|
| 293 |
+
|
| 294 |
+
**Informations de débogage:**
|
| 295 |
+
- **Requête**: "{query}"
|
| 296 |
+
- **Pays**: {country_name}
|
| 297 |
+
- **Type de document recherché**: {doc_type}
|
| 298 |
+
- **Filtres**: {applied_filters}
|
| 299 |
+
|
| 300 |
+
Veuillez réessayer ou contacter le support technique.
|
| 301 |
+
"""
|
| 302 |
+
|
| 303 |
+
if applied_filters.get("fallback_used"):
|
| 304 |
+
# Cas où le fallback a été utilisé mais n'a rien trouvé non plus
|
| 305 |
+
mongodb_error_note = ""
|
| 306 |
+
if "mongodb_error" in applied_filters:
|
| 307 |
+
mongodb_error_note = f"\n\n**⚠️ Erreur technique**: {applied_filters['mongodb_error'][:200]}..."
|
| 308 |
+
|
| 309 |
+
return f"""
|
| 310 |
+
**🔍 RECHERCHE JURIDIQUE - {country_name.upper()}**
|
| 311 |
+
|
| 312 |
+
{supplemental_message}
|
| 313 |
+
|
| 314 |
+
**💡 Informations :**
|
| 315 |
+
- Votre recherche portait sur des **décisions de justice (jurisprudence)**
|
| 316 |
+
- Aucune décision de justice n'a été trouvée dans la base de données
|
| 317 |
+
- Aucun article de loi correspondant n'a été trouvé non plus
|
| 318 |
+
{mongodb_error_note}
|
| 319 |
+
|
| 320 |
+
**Suggestion**: Essayez de reformuler votre requête avec des termes plus généraux.
|
| 321 |
+
|
| 322 |
+
**Recherche effectuée**:
|
| 323 |
+
- Type initial: {applied_filters.get('original_search', 'N/A')}
|
| 324 |
+
- Fallback vers: {applied_filters.get('fallback_to', 'N/A')}
|
| 325 |
+
- Pays: {country_name}
|
| 326 |
+
"""
|
| 327 |
+
else:
|
| 328 |
+
# Cas normal sans fallback
|
| 329 |
+
return f"""
|
| 330 |
+
**🔍 RECHERCHE JURIDIQUE - {country_name.upper()}**
|
| 331 |
+
|
| 332 |
+
Aucun document trouvé avec les critères suivants:
|
| 333 |
+
- **Type de document**: {doc_type}
|
| 334 |
+
- **Catégorie**: {applied_filters.get('categorie', 'Toutes')}
|
| 335 |
+
- **Requête**: "{query}"
|
| 336 |
+
|
| 337 |
+
**Suggestion**: Essayez avec des termes plus généraux ou vérifiez l'orthographe.
|
| 338 |
+
|
| 339 |
+
**Filtres appliqués**: {applied_filters}
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
# Si des documents sont trouvés
|
| 343 |
+
doc_type = applied_filters.get("doc_type", "articles")
|
| 344 |
+
doc_type_fr = "articles de loi" if doc_type == "articles" else "études de cas/jurisprudence"
|
| 345 |
+
|
| 346 |
+
fallback_note = ""
|
| 347 |
+
if applied_filters.get("fallback_used"):
|
| 348 |
+
fallback_note = f"""
|
| 349 |
+
**💡 {supplemental_message}**
|
| 350 |
+
|
| 351 |
+
---
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
search_results = f"""
|
| 355 |
+
**🔍 RECHERCHE JURIDIQUE - {country_name.upper()}**
|
| 356 |
+
**Type de documents**: {doc_type_fr}
|
| 357 |
+
**Requête**: "{query}"
|
| 358 |
+
**Juridiction**: {country_name}
|
| 359 |
+
**Articles détectés**: {', '.join(detected_articles) if detected_articles else 'Aucun'}
|
| 360 |
+
**Documents trouvés**: {len(enhanced_docs)}
|
| 361 |
+
|
| 362 |
+
{fallback_note}
|
| 363 |
+
"""
|
| 364 |
+
|
| 365 |
+
# Formatage des documents trouvés
|
| 366 |
+
main_docs = [doc for doc in enhanced_docs if not doc.metadata.get("is_reference", False)]
|
| 367 |
+
|
| 368 |
+
for i, doc in enumerate(main_docs[:5]):
|
| 369 |
+
doc_type = doc.metadata.get("doc_type", "inconnu")
|
| 370 |
+
source = doc.metadata.get('source', 'Non spécifié')
|
| 371 |
+
content = doc.page_content[:600]
|
| 372 |
+
|
| 373 |
+
search_results += f"""
|
| 374 |
+
**📄 DOCUMENT {i+1}** (Type: {doc_type})
|
| 375 |
+
- **Source**: {source}
|
| 376 |
+
- **Contenu**: {content}...
|
| 377 |
+
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
return search_results
|
| 381 |
+
|
| 382 |
+
# BACKWARD COMPATIBILITY: Keep sync version for any remaining sync calls
|
| 383 |
+
def smart_legal_query_sync(self, user_query: str, country: str) -> Tuple[List[Document], List[str], Dict[str, Any], str]:
|
| 384 |
+
"""Synchronous version for backward compatibility"""
|
| 385 |
+
logger.warning("Using sync version of smart_legal_query - consider migrating to async")
|
| 386 |
+
return asyncio.run(self.smart_legal_query(user_query, country))
|
core/router.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/router.py
|
| 2 |
+
import re
|
| 3 |
+
import logging
|
| 4 |
+
import json
|
| 5 |
+
from typing import Dict, List, Optional, Literal, Any
|
| 6 |
+
from langchain_openai import ChatOpenAI
|
| 7 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
| 8 |
+
|
| 9 |
+
from config.settings import settings
|
| 10 |
+
from models.state_models import RoutingResult
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class CountryRouter:
|
| 15 |
+
def __init__(self):
|
| 16 |
+
self.llm = ChatOpenAI(
|
| 17 |
+
model=settings.CHAT_MODEL_2,
|
| 18 |
+
temperature=0.1,
|
| 19 |
+
max_tokens=200
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
async def route_query(self, query: str, conversation_history: List[Dict]) -> RoutingResult:
|
| 23 |
+
"""Unified LLM-powered routing"""
|
| 24 |
+
try:
|
| 25 |
+
# Build conversation context
|
| 26 |
+
context = self._build_conversation_context(conversation_history)
|
| 27 |
+
|
| 28 |
+
# LLM routing prompt
|
| 29 |
+
routing_prompt = self._build_routing_prompt(query, context)
|
| 30 |
+
|
| 31 |
+
logger.info(f"🔀 Routing query: '{query[:50]}...'")
|
| 32 |
+
|
| 33 |
+
# Call LLM for routing decision
|
| 34 |
+
response = await self.llm.ainvoke([SystemMessage(content=routing_prompt)])
|
| 35 |
+
routing_result = self._parse_routing_response(response.content)
|
| 36 |
+
|
| 37 |
+
logger.info(f"🎯 Router decision: {routing_result.country} ({routing_result.confidence})")
|
| 38 |
+
|
| 39 |
+
return routing_result
|
| 40 |
+
|
| 41 |
+
except Exception as e:
|
| 42 |
+
logger.error(f"Router error: {e}")
|
| 43 |
+
# Fallback to unclear
|
| 44 |
+
return RoutingResult(
|
| 45 |
+
country="unclear",
|
| 46 |
+
confidence="low",
|
| 47 |
+
method="error_fallback",
|
| 48 |
+
explanation=f"Router error: {str(e)}"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
def _build_routing_prompt(self, query: str, context: str) -> str:
|
| 52 |
+
"""Build comprehensive routing prompt"""
|
| 53 |
+
return f"""
|
| 54 |
+
Vous êtes un routeur intelligent pour un assistant juridique spécialisé dans le droit du Bénin et de Madagascar.
|
| 55 |
+
|
| 56 |
+
**TÂCHE:** Analyser la requête utilisateur et déterminer la meilleure destination.
|
| 57 |
+
|
| 58 |
+
**DESTINATIONS POSSIBLES:**
|
| 59 |
+
- "benin": Questions juridiques concernant le Bénin (lois, procédures, droits)
|
| 60 |
+
- "madagascar": Questions juridiques concernant Madagascar (lois, procédures, droits)
|
| 61 |
+
- "assistance_request": Demande pour parler à un avocat humain
|
| 62 |
+
- "greeting_small_talk": Salutations, présentations, remerciements (politesse uniquement)
|
| 63 |
+
- "conversation_repair": Incompréhension, demande de clarification
|
| 64 |
+
- "conversation_summarization": Demande de résumé de la conversation
|
| 65 |
+
- "out_of_scope": Questions NON juridiques (café, météo, sports, recettes, etc.)
|
| 66 |
+
- "unclear": Intention juridique incertaine
|
| 67 |
+
|
| 68 |
+
**REQUÊTE:** "{query}"
|
| 69 |
+
|
| 70 |
+
**CONTEXTE DE CONVERSATION:**
|
| 71 |
+
{context}
|
| 72 |
+
|
| 73 |
+
**RÈGLES DE CLASSIFICATION:**
|
| 74 |
+
|
| 75 |
+
1. **greeting_small_talk** - UNIQUEMENT pour politesse basique:
|
| 76 |
+
- Salutations: "bonjour", "salut", "hello", "bonsoir", "au revoir"
|
| 77 |
+
- Présentations brèves: "je m'appelle X", "mon nom est X"
|
| 78 |
+
- Remerciements: "merci", "merci beaucoup"
|
| 79 |
+
- Politesses simples: "comment ça va", "ça va bien"
|
| 80 |
+
- Questions sur l'identité de l'assistant: "qui es-tu", "comment tu t'appelles"
|
| 81 |
+
|
| 82 |
+
2. **benin** - Pour questions juridiques sur le Bénin:
|
| 83 |
+
- Mentions explicites: "bénin", "benin", "béninois"
|
| 84 |
+
- Villes: "cotonou", "porto-novo"
|
| 85 |
+
- Lois/procédures béninoises
|
| 86 |
+
|
| 87 |
+
3. **madagascar** - Pour questions juridiques sur Madagascar:
|
| 88 |
+
- Mentions explicites: "madagascar", "malgache"
|
| 89 |
+
- Villes: "antananarivo", "toamasina"
|
| 90 |
+
- Lois/procédures malgaches
|
| 91 |
+
|
| 92 |
+
4. **assistance_request** - Demande d'aide humaine:
|
| 93 |
+
- "parler à un avocat"
|
| 94 |
+
- "contacter un avocat"
|
| 95 |
+
- "assistance téléphonique"
|
| 96 |
+
- "besoin d'aide juridique personnalisée"
|
| 97 |
+
|
| 98 |
+
5. **conversation_repair** - Problèmes de compréhension:
|
| 99 |
+
- "je n'ai pas compris"
|
| 100 |
+
- "répète s'il te plaît"
|
| 101 |
+
- "explique autrement"
|
| 102 |
+
- "qu'est-ce que tu veux dire"
|
| 103 |
+
|
| 104 |
+
6. **conversation_summarization** - Demande de résumé:
|
| 105 |
+
- "résume notre conversation"
|
| 106 |
+
- "récapitulatif"
|
| 107 |
+
- "qu'avons-nous dit"
|
| 108 |
+
|
| 109 |
+
7. **out_of_scope** - Questions clairement NON juridiques:
|
| 110 |
+
- Météo/Climat: "température à Douala", "il va pleuvoir?"
|
| 111 |
+
- Nourriture: "recette de ndolé", "fais-moi un café"
|
| 112 |
+
- Sport: "résultat du match", "qui a gagné?"
|
| 113 |
+
- Technologie: "comment réparer mon téléphone", "meilleur ordinateur"
|
| 114 |
+
- Divertissement: "raconte une blague", "parle-moi de musique"
|
| 115 |
+
- Santé non-juridique: "symptômes grippe", "remèdes traditionnels"
|
| 116 |
+
- **Règle clé**: AUCUN aspect juridique ou lien avec le droit
|
| 117 |
+
|
| 118 |
+
8. **unclear** - Questions juridiques MAIS pays/détails manquants:
|
| 119 |
+
- "J'ai un problème de divorce" (quel pays?)
|
| 120 |
+
- "Comment créer une entreprise" (Bénin ou Madagascar?)
|
| 121 |
+
- "Besoin d'aide juridique" (trop vague)
|
| 122 |
+
- "Question sur l'héritage" (juridiction non précisée)
|
| 123 |
+
- **Règle clé**: Intention juridique évidente MAIS manque de précision sur le pays ou les détails
|
| 124 |
+
|
| 125 |
+
**EXEMPLES COMPLETS:**
|
| 126 |
+
- "Bonjour" → {{"destination": "greeting_small_talk", "confidence": "high", "reasoning": "Salutation simple"}}
|
| 127 |
+
- "je m'appelle Thibaut" → {{"destination": "greeting_small_talk", "confidence": "high", "reasoning": "Présentation personnelle"}}
|
| 128 |
+
- "comment est-ce que je m'appelle" → {{"destination": "greeting_small_talk", "confidence": "high", "reasoning": "Question personnelle de rappel"}}
|
| 129 |
+
- "salut comment ça va" → {{"destination": "greeting_small_talk", "confidence": "high", "reasoning": "Salutation et politesse"}}
|
| 130 |
+
- "merci beaucoup" → {{"destination": "greeting_small_talk", "confidence": "high", "reasoning": "Remerciement"}}
|
| 131 |
+
- "qui es-tu" → {{"destination": "greeting_small_talk", "confidence": "high", "reasoning": "Question sur l'identité de l'assistant"}}
|
| 132 |
+
- "procedure divorce Bénin" → {{"destination": "benin", "confidence": "high", "reasoning": "Question juridique explicite sur le Bénin"}}
|
| 133 |
+
- "loi foncière Madagascar" → {{"destination": "madagascar", "confidence": "high", "reasoning": "Question juridique sur Madagascar"}}
|
| 134 |
+
- "Je veux parler à un avocat" → {{"destination": "assistance_request", "confidence": "high", "reasoning": "Demande explicite d'assistance humaine"}}
|
| 135 |
+
- "Je n'ai pas compris" → {{"destination": "conversation_repair", "confidence": "high", "reasoning": "Demande de clarification"}}
|
| 136 |
+
- "résume notre conversation" → {{"destination": "conversation_summarization", "confidence": "high", "reasoning": "Demande de résumé"}}
|
| 137 |
+
- "fais-moi un café" → {{"destination": "out_of_scope", "confidence": "high", "reasoning": "Demande sans rapport avec le droit"}}
|
| 138 |
+
- "quelle est la météo" → {{"destination": "out_of_scope", "confidence": "high", "reasoning": "Question météorologique, non juridique"}}
|
| 139 |
+
- "température à Douala" → {{"destination": "out_of_scope", "confidence": "high", "reasoning": "Question climatique, hors domaine juridique"}}
|
| 140 |
+
- "raconte-moi une blague" → {{"destination": "out_of_scope", "confidence": "high", "reasoning": "Demande de divertissement, non juridique"}}
|
| 141 |
+
- "J'ai un problème de divorce" → {{"destination": "unclear", "confidence": "medium", "reasoning": "Question juridique mais pays non précisé"}}
|
| 142 |
+
- "Comment créer une entreprise" → {{"destination": "unclear", "confidence": "medium", "reasoning": "Question juridique mais juridiction manquante"}}
|
| 143 |
+
|
| 144 |
+
**IMPORTANT:**
|
| 145 |
+
- **out_of_scope**: Questions SANS aucun aspect juridique (météo, sport, nourriture, etc.)
|
| 146 |
+
- **unclear**: Questions AVEC intention juridique MAIS manque de précision sur le pays
|
| 147 |
+
- Les présentations, salutations et remerciements sont "greeting_small_talk"
|
| 148 |
+
- Seules les questions JURIDIQUES avec pays identifié vont vers "benin" ou "madagascar"
|
| 149 |
+
|
| 150 |
+
**FORMAT DE RÉPONSE:**
|
| 151 |
+
Répondez UNIQUEMENT au format JSON valide:
|
| 152 |
+
{{
|
| 153 |
+
"destination": "benin|madagascar|assistance_request|greeting_small_talk|conversation_repair|conversation_summarization|unclear",
|
| 154 |
+
"confidence": "high|medium|low",
|
| 155 |
+
"reasoning": "explication brève et claire"
|
| 156 |
+
}}
|
| 157 |
+
|
| 158 |
+
**RÉPONSE:**
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def _parse_routing_response(self, response_text: str) -> RoutingResult:
|
| 162 |
+
"""Parse LLM routing response"""
|
| 163 |
+
try:
|
| 164 |
+
# Extract JSON from response
|
| 165 |
+
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
| 166 |
+
if not json_match:
|
| 167 |
+
raise ValueError("No JSON found in response")
|
| 168 |
+
|
| 169 |
+
result = json.loads(json_match.group())
|
| 170 |
+
|
| 171 |
+
# Validate required fields
|
| 172 |
+
destination = result.get("destination", "unclear")
|
| 173 |
+
confidence = result.get("confidence", "low")
|
| 174 |
+
reasoning = result.get("reasoning", "No reasoning provided")
|
| 175 |
+
|
| 176 |
+
# Map destination to RoutingResult country field
|
| 177 |
+
valid_destinations = [
|
| 178 |
+
"benin", "madagascar", "unclear", "greeting_small_talk",
|
| 179 |
+
"conversation_repair", "assistance_request", "conversation_summarization",
|
| 180 |
+
"out_of_scope"
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
if destination not in valid_destinations:
|
| 184 |
+
logger.warning(f"Invalid destination from LLM: {destination}, defaulting to unclear")
|
| 185 |
+
destination = "unclear"
|
| 186 |
+
confidence = "low"
|
| 187 |
+
reasoning = f"Destination invalide: {destination}"
|
| 188 |
+
|
| 189 |
+
return RoutingResult(
|
| 190 |
+
country=destination,
|
| 191 |
+
confidence=confidence,
|
| 192 |
+
method="llm_routing",
|
| 193 |
+
explanation=reasoning
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.error(f"Error parsing routing response: {e}")
|
| 198 |
+
logger.error(f"Raw response: {response_text}")
|
| 199 |
+
|
| 200 |
+
return RoutingResult(
|
| 201 |
+
country="unclear",
|
| 202 |
+
confidence="low",
|
| 203 |
+
method="parse_error",
|
| 204 |
+
explanation=f"Parse error: {str(e)}"
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def _build_conversation_context(self, conversation_history: List[Dict]) -> str:
|
| 208 |
+
"""Build conversation context"""
|
| 209 |
+
if not conversation_history:
|
| 210 |
+
return "Aucun contexte de conversation"
|
| 211 |
+
|
| 212 |
+
# Get last 6 messages for context
|
| 213 |
+
recent_messages = conversation_history[-6:]
|
| 214 |
+
context_lines = []
|
| 215 |
+
|
| 216 |
+
for msg in recent_messages:
|
| 217 |
+
role = "Utilisateur" if msg.get("role") in ["user", "human"] else "Assistant"
|
| 218 |
+
content = msg.get("content", "")
|
| 219 |
+
context_lines.append(f"{role}: {content}")
|
| 220 |
+
|
| 221 |
+
return "\n".join(context_lines)
|
| 222 |
+
|
| 223 |
+
async def health_check(self) -> Dict[str, Any]:
|
| 224 |
+
"""Router health check"""
|
| 225 |
+
try:
|
| 226 |
+
# Test with a simple query
|
| 227 |
+
test_result = await self.route_query("test", [])
|
| 228 |
+
return {
|
| 229 |
+
"status": "healthy",
|
| 230 |
+
"llm_responding": True,
|
| 231 |
+
"last_test_result": test_result.model_dump()
|
| 232 |
+
}
|
| 233 |
+
except Exception as e:
|
| 234 |
+
return {
|
| 235 |
+
"status": "unhealthy",
|
| 236 |
+
"llm_responding": False,
|
| 237 |
+
"error": str(e)
|
| 238 |
+
}
|
core/routing/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .routing_logic import RoutingLogic
|
| 2 |
+
|
| 3 |
+
__all__ = ["RoutingLogic"]
|
core/routing/routing_logic.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/routing/routing_logic.py
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Literal
|
| 4 |
+
from models.state_models import MultiCountryLegalState
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
class RoutingLogic:
|
| 9 |
+
"""Centralized routing logic for graph edges"""
|
| 10 |
+
|
| 11 |
+
def route_after_info_collection(
|
| 12 |
+
self,
|
| 13 |
+
state: MultiCountryLegalState
|
| 14 |
+
) -> Literal["need_email", "need_description", "ready_to_confirm", "cancelled"]:
|
| 15 |
+
"""Route based on current assistance step and collected data"""
|
| 16 |
+
|
| 17 |
+
step = state.assistance_step
|
| 18 |
+
has_email = bool(state.user_email)
|
| 19 |
+
has_description = bool(state.assistance_description)
|
| 20 |
+
|
| 21 |
+
logger.info(f"📋 Assistance step: {step}")
|
| 22 |
+
logger.info(f" - Has email: {has_email} ({state.user_email})")
|
| 23 |
+
logger.info(f" - Has description: {has_description} ({state.assistance_description})")
|
| 24 |
+
|
| 25 |
+
# 🔥 NEW: Handle cancellation first
|
| 26 |
+
if step == "cancelled":
|
| 27 |
+
logger.info("🔄 Assistance workflow cancelled by user")
|
| 28 |
+
return "cancelled"
|
| 29 |
+
|
| 30 |
+
# Route based on current step progression
|
| 31 |
+
if step == "collecting_email":
|
| 32 |
+
if not has_email:
|
| 33 |
+
logger.info("→ Routing to: need_email (waiting for email)")
|
| 34 |
+
return "need_email"
|
| 35 |
+
else:
|
| 36 |
+
# Email collected, move to description
|
| 37 |
+
logger.info("→ Routing to: need_description (email collected)")
|
| 38 |
+
return "need_description"
|
| 39 |
+
|
| 40 |
+
elif step == "collecting_description":
|
| 41 |
+
if not has_description:
|
| 42 |
+
logger.info("→ Routing to: need_description (waiting for description)")
|
| 43 |
+
return "need_description"
|
| 44 |
+
else:
|
| 45 |
+
# Description collected, ready for confirmation
|
| 46 |
+
logger.info("→ Routing to: ready_to_confirm (both collected)")
|
| 47 |
+
return "ready_to_confirm"
|
| 48 |
+
|
| 49 |
+
elif step == "confirming_send":
|
| 50 |
+
# We're already in confirmation step - stay here until user confirms
|
| 51 |
+
logger.info("→ Routing to: ready_to_confirm (awaiting user confirmation)")
|
| 52 |
+
return "ready_to_confirm"
|
| 53 |
+
|
| 54 |
+
else:
|
| 55 |
+
# Default fallback logic
|
| 56 |
+
if not has_email:
|
| 57 |
+
logger.info("→ Routing to: need_email (default)")
|
| 58 |
+
return "need_email"
|
| 59 |
+
elif not has_description:
|
| 60 |
+
logger.info("→ Routing to: need_description (default)")
|
| 61 |
+
return "need_description"
|
| 62 |
+
else:
|
| 63 |
+
logger.info("→ Routing to: ready_to_confirm (default)")
|
| 64 |
+
return "ready_to_confirm"
|
| 65 |
+
|
| 66 |
+
def route_after_confirmation(
|
| 67 |
+
self,
|
| 68 |
+
state: MultiCountryLegalState
|
| 69 |
+
) -> Literal["confirmed", "cancelled", "needs_correction"]:
|
| 70 |
+
"""Route based on user's confirmation response and current step"""
|
| 71 |
+
|
| 72 |
+
step = state.assistance_step
|
| 73 |
+
last_message = self._get_last_user_message(state)
|
| 74 |
+
|
| 75 |
+
logger.info(f"📋 Confirmation step: {step}")
|
| 76 |
+
logger.info(f" - Last user message: '{last_message}'")
|
| 77 |
+
|
| 78 |
+
# 🔥 NEW: Handle cancellation from confirmation step
|
| 79 |
+
if step == "cancelled":
|
| 80 |
+
logger.info("→ Routing to: cancelled (workflow cancelled)")
|
| 81 |
+
return "cancelled"
|
| 82 |
+
|
| 83 |
+
elif step == "confirmed":
|
| 84 |
+
logger.info("→ Routing to: confirmed (human approval)")
|
| 85 |
+
return "confirmed"
|
| 86 |
+
|
| 87 |
+
elif step == "confirming_send":
|
| 88 |
+
# In confirmation step, check user response
|
| 89 |
+
user_response = last_message.lower().strip() if last_message else ""
|
| 90 |
+
|
| 91 |
+
if user_response in ["oui", "yes", "ok", "confirm", "confirmer", "c'est bon", "d'accord", "envoyer", "valider"]:
|
| 92 |
+
logger.info("→ Routing to: confirmed (user confirmed)")
|
| 93 |
+
return "confirmed"
|
| 94 |
+
|
| 95 |
+
elif user_response in ["non", "no", "cancel", "annuler", "pas maintenant", "arrêter", "stop", "je ne veux plus"]:
|
| 96 |
+
logger.info("→ Routing to: cancelled (user cancelled)")
|
| 97 |
+
return "cancelled"
|
| 98 |
+
|
| 99 |
+
else:
|
| 100 |
+
# User provided description or unclear response - go to response to ask again
|
| 101 |
+
logger.info("→ Routing to: needs_correction (need clarification)")
|
| 102 |
+
return "needs_correction"
|
| 103 |
+
|
| 104 |
+
else:
|
| 105 |
+
logger.info("→ Routing to: needs_correction (invalid state)")
|
| 106 |
+
return "needs_correction"
|
| 107 |
+
|
| 108 |
+
def route_after_human_approval(
|
| 109 |
+
self,
|
| 110 |
+
state: MultiCountryLegalState
|
| 111 |
+
) -> Literal["approved", "rejected", "interrupt"]:
|
| 112 |
+
"""Route based on human approval status"""
|
| 113 |
+
|
| 114 |
+
approval_status = state.approval_status
|
| 115 |
+
logger.info(f"📋 Approval status: {approval_status}")
|
| 116 |
+
|
| 117 |
+
if approval_status == "approved":
|
| 118 |
+
logger.info("→ Routing to: approved (process assistance)")
|
| 119 |
+
return "approved"
|
| 120 |
+
|
| 121 |
+
elif approval_status == "rejected":
|
| 122 |
+
logger.info("→ Routing to: rejected (response)")
|
| 123 |
+
return "rejected"
|
| 124 |
+
|
| 125 |
+
else:
|
| 126 |
+
# Still waiting for approval or error state
|
| 127 |
+
logger.info("→ Routing to: interrupt (waiting for decision)")
|
| 128 |
+
return "interrupt"
|
| 129 |
+
|
| 130 |
+
def _get_last_user_message(self, state: MultiCountryLegalState) -> str:
|
| 131 |
+
"""Extract the last user message from state"""
|
| 132 |
+
if not state.messages:
|
| 133 |
+
return ""
|
| 134 |
+
|
| 135 |
+
for msg in reversed(state.messages):
|
| 136 |
+
if hasattr(msg, 'role'):
|
| 137 |
+
role = msg.role
|
| 138 |
+
else:
|
| 139 |
+
role = msg.get('role', '')
|
| 140 |
+
|
| 141 |
+
if role in ['user', 'human']:
|
| 142 |
+
if hasattr(msg, 'content'):
|
| 143 |
+
return msg.content
|
| 144 |
+
else:
|
| 145 |
+
return msg.get('content', '')
|
| 146 |
+
|
| 147 |
+
return ""
|
| 148 |
+
|
| 149 |
+
def _looks_like_description(self, text: str) -> bool:
|
| 150 |
+
"""Check if text looks like a description rather than a confirmation"""
|
| 151 |
+
description_indicators = [
|
| 152 |
+
"j'ai besoin", "je veux", "je souhaite", "aide pour", "divorce",
|
| 153 |
+
"mariage", "héritage", "contrat", "travail", "familial", "bénin", "madagascar",
|
| 154 |
+
"problème", "situation", "question", "demande"
|
| 155 |
+
]
|
| 156 |
+
|
| 157 |
+
text_lower = text.lower()
|
| 158 |
+
return any(indicator in text_lower for indicator in description_indicators)
|
core/system_initializer.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: core/system_initializer.py
|
| 2 |
+
import logging
|
| 3 |
+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
| 4 |
+
|
| 5 |
+
from core.graph_builder import GraphBuilder
|
| 6 |
+
from core.chat_manager import LegalChatManager
|
| 7 |
+
from core.router import CountryRouter
|
| 8 |
+
from database.mongodb_client import MongoDBClient
|
| 9 |
+
from database.postgres_checkpointer import PostgresCheckpointer
|
| 10 |
+
from langchain_openai import ChatOpenAI
|
| 11 |
+
from config import settings # Make sure this import is correct
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
async def setup_system():
|
| 16 |
+
"""Initialize the legal assistant system for API use"""
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
# 1. Initialize MongoDB using your existing class
|
| 20 |
+
mongo_client = MongoDBClient()
|
| 21 |
+
if not mongo_client.connect():
|
| 22 |
+
raise Exception("MongoDB connection failed")
|
| 23 |
+
|
| 24 |
+
logger.info("✅ MongoDB connected successfully")
|
| 25 |
+
|
| 26 |
+
# 2. Use your existing vector stores directly from the client
|
| 27 |
+
vector_store_benin = mongo_client.benin_vectorstore
|
| 28 |
+
collection_benin = mongo_client.benin_collection
|
| 29 |
+
vector_store_madagascar = mongo_client.madagascar_vectorstore
|
| 30 |
+
collection_madagascar = mongo_client.madagascar_collection
|
| 31 |
+
|
| 32 |
+
# 3. Initialize retrievers
|
| 33 |
+
from core.retriever import LegalRetriever
|
| 34 |
+
benin_retriever = LegalRetriever(vector_store_benin, collection_benin)
|
| 35 |
+
madagascar_retriever = LegalRetriever(vector_store_madagascar, collection_madagascar)
|
| 36 |
+
|
| 37 |
+
country_retrievers = {
|
| 38 |
+
"benin": benin_retriever,
|
| 39 |
+
"madagascar": madagascar_retriever
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
# 4. Initialize LLM and router
|
| 43 |
+
llm = ChatOpenAI(
|
| 44 |
+
model="gpt-4o-mini",
|
| 45 |
+
temperature=0.1,
|
| 46 |
+
max_tokens=2000,
|
| 47 |
+
streaming=True
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
router = CountryRouter()
|
| 51 |
+
|
| 52 |
+
# 5. Initialize PostgreSQL checkpointer - FIXED DATABASE URL
|
| 53 |
+
# Check what database URL setting you have
|
| 54 |
+
database_url = getattr(settings, 'DATABASE_URL', None)
|
| 55 |
+
|
| 56 |
+
if not database_url:
|
| 57 |
+
# Try alternative setting names
|
| 58 |
+
database_url = getattr(settings, 'POSTGRES_URL', None) or \
|
| 59 |
+
getattr(settings, 'POSTGRESQL_URL', None) or \
|
| 60 |
+
getattr(settings, 'DB_URL', None)
|
| 61 |
+
|
| 62 |
+
if not database_url:
|
| 63 |
+
raise Exception("No database URL found in settings")
|
| 64 |
+
|
| 65 |
+
logger.info(f"🔗 Using database URL: {database_url.split('@')[-1] if '@' in database_url else 'local'}") # Log safely
|
| 66 |
+
|
| 67 |
+
postgres_checkpointer = PostgresCheckpointer(
|
| 68 |
+
database_url=database_url, # Use actual database URL
|
| 69 |
+
max_connections=10,
|
| 70 |
+
min_connections=2
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
if not await postgres_checkpointer.initialize():
|
| 74 |
+
raise Exception("PostgreSQL checkpointer initialization failed")
|
| 75 |
+
|
| 76 |
+
checkpointer = postgres_checkpointer.get_checkpointer()
|
| 77 |
+
logger.info("✅ PostgreSQL checkpointer initialized for API")
|
| 78 |
+
|
| 79 |
+
# 6. Build graph
|
| 80 |
+
graph_builder = GraphBuilder(
|
| 81 |
+
router=router,
|
| 82 |
+
llm=llm,
|
| 83 |
+
checkpointer=checkpointer,
|
| 84 |
+
country_retrievers=country_retrievers
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
workflow = graph_builder.build_graph()
|
| 88 |
+
app = workflow.compile(checkpointer=checkpointer)
|
| 89 |
+
|
| 90 |
+
# 7. Initialize chat manager
|
| 91 |
+
chat_manager = LegalChatManager(app, checkpointer)
|
| 92 |
+
|
| 93 |
+
logger.info("✅ API System initialized successfully")
|
| 94 |
+
|
| 95 |
+
return {
|
| 96 |
+
"chat_manager": chat_manager,
|
| 97 |
+
"graph": app,
|
| 98 |
+
"checkpointer": checkpointer
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.error(f"❌ Failed to initialize system: {e}")
|
| 103 |
+
raise
|
database/__init__py
ADDED
|
File without changes
|
database/mongodb_client.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pymongo import MongoClient, ReadPreference
|
| 2 |
+
from pymongo.errors import ServerSelectionTimeoutError, ConnectionFailure
|
| 3 |
+
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
|
| 4 |
+
from langchain_openai import OpenAIEmbeddings
|
| 5 |
+
from typing import Dict
|
| 6 |
+
import logging
|
| 7 |
+
from config.settings import settings
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class MongoDBClient:
|
| 12 |
+
def __init__(self):
|
| 13 |
+
self.client = None
|
| 14 |
+
self.db = None
|
| 15 |
+
self.benin_collection = None
|
| 16 |
+
self.madagascar_collection = None
|
| 17 |
+
self.benin_vectorstore = None
|
| 18 |
+
self.madagascar_vectorstore = None
|
| 19 |
+
self.embedding_model = None
|
| 20 |
+
|
| 21 |
+
def connect(self):
|
| 22 |
+
"""Connect to MongoDB and initialize collections"""
|
| 23 |
+
try:
|
| 24 |
+
# CRITICAL FIX: Add read preference to allow reading from secondary nodes
|
| 25 |
+
self.client = MongoClient(
|
| 26 |
+
settings.MONGO_URI,
|
| 27 |
+
|
| 28 |
+
# Allow reading from secondary nodes when primary is unavailable
|
| 29 |
+
read_preference=ReadPreference.SECONDARY_PREFERRED,
|
| 30 |
+
|
| 31 |
+
# Reduce timeouts to fail faster (instead of 30s)
|
| 32 |
+
serverSelectionTimeoutMS=10000, # 10 seconds
|
| 33 |
+
connectTimeoutMS=10000,
|
| 34 |
+
socketTimeoutMS=10000,
|
| 35 |
+
|
| 36 |
+
# Retry configuration
|
| 37 |
+
retryWrites=True,
|
| 38 |
+
retryReads=True,
|
| 39 |
+
|
| 40 |
+
# Connection pool settings
|
| 41 |
+
maxPoolSize=50,
|
| 42 |
+
minPoolSize=10,
|
| 43 |
+
|
| 44 |
+
# Write concern (for writes to still work)
|
| 45 |
+
w='majority',
|
| 46 |
+
journal=True
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Test the connection
|
| 50 |
+
self.client.admin.command('ping')
|
| 51 |
+
logger.info("✅ MongoDB connection test successful")
|
| 52 |
+
|
| 53 |
+
self.db = self.client[settings.DATABASE_NAME]
|
| 54 |
+
|
| 55 |
+
# Initialize collections
|
| 56 |
+
self.benin_collection = self.db[settings.BENIN_COLLECTION]
|
| 57 |
+
self.madagascar_collection = self.db[settings.MADAGASCAR_COLLECTION]
|
| 58 |
+
|
| 59 |
+
# Verify collections exist and have data
|
| 60 |
+
benin_count = self.benin_collection.count_documents({})
|
| 61 |
+
madagascar_count = self.madagascar_collection.count_documents({})
|
| 62 |
+
logger.info(f"📊 Bénin collection: {benin_count} documents")
|
| 63 |
+
logger.info(f"📊 Madagascar collection: {madagascar_count} documents")
|
| 64 |
+
|
| 65 |
+
# Initialize embedding model
|
| 66 |
+
self.embedding_model = OpenAIEmbeddings(
|
| 67 |
+
model=settings.EMBEDDING_MODEL,
|
| 68 |
+
openai_api_key=settings.OPENAI_API_KEY
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Initialize vector stores with read preference
|
| 72 |
+
self.benin_vectorstore = MongoDBAtlasVectorSearch(
|
| 73 |
+
collection=self.benin_collection,
|
| 74 |
+
embedding=self.embedding_model,
|
| 75 |
+
index_name=settings.VECTOR_INDEX_NAME,
|
| 76 |
+
text_key=settings.TEXT_KEY,
|
| 77 |
+
embedding_key=settings.EMBEDDING_KEY,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
self.madagascar_vectorstore = MongoDBAtlasVectorSearch(
|
| 81 |
+
collection=self.madagascar_collection,
|
| 82 |
+
embedding=self.embedding_model,
|
| 83 |
+
index_name=settings.VECTOR_INDEX_NAME,
|
| 84 |
+
text_key=settings.TEXT_KEY,
|
| 85 |
+
embedding_key=settings.EMBEDDING_KEY,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
print("✅ MongoDB connected successfully with SECONDARY_PREFERRED read preference")
|
| 89 |
+
return True
|
| 90 |
+
|
| 91 |
+
except (ServerSelectionTimeoutError, ConnectionFailure) as e:
|
| 92 |
+
logger.error(f"❌ MongoDB connection failed: {e}")
|
| 93 |
+
logger.error("🔍 Possible issues:")
|
| 94 |
+
logger.error(" 1. MongoDB Atlas cluster is paused")
|
| 95 |
+
logger.error(" 2. Network connectivity issues")
|
| 96 |
+
logger.error(" 3. IP address not whitelisted in Atlas")
|
| 97 |
+
logger.error(" 4. Cluster is undergoing maintenance")
|
| 98 |
+
print(f"❌ MongoDB connection failed: {e}")
|
| 99 |
+
return False
|
| 100 |
+
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.error(f"❌ Unexpected error during MongoDB connection: {e}")
|
| 103 |
+
print(f"❌ MongoDB connection failed: {e}")
|
| 104 |
+
return False
|
| 105 |
+
|
| 106 |
+
def get_collection_stats(self) -> Dict:
|
| 107 |
+
"""Get statistics for both collections"""
|
| 108 |
+
if not self.client:
|
| 109 |
+
return {}
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
benin_count = self.benin_collection.count_documents({})
|
| 113 |
+
madagascar_count = self.madagascar_collection.count_documents({})
|
| 114 |
+
|
| 115 |
+
# Sample document to check schema
|
| 116 |
+
benin_sample = self.benin_collection.find_one()
|
| 117 |
+
madagascar_sample = self.madagascar_collection.find_one()
|
| 118 |
+
|
| 119 |
+
# Check for documents by doc_type
|
| 120 |
+
benin_case_study_count = self.benin_collection.count_documents({"doc_type": "case_study"})
|
| 121 |
+
benin_articles_count = self.benin_collection.count_documents({"doc_type": "articles"})
|
| 122 |
+
madagascar_case_study_count = self.madagascar_collection.count_documents({"doc_type": "case_study"})
|
| 123 |
+
madagascar_articles_count = self.madagascar_collection.count_documents({"doc_type": "articles"})
|
| 124 |
+
|
| 125 |
+
return {
|
| 126 |
+
"benin": {
|
| 127 |
+
"total_documents": benin_count,
|
| 128 |
+
"case_study_count": benin_case_study_count,
|
| 129 |
+
"articles_count": benin_articles_count,
|
| 130 |
+
"has_embeddings": bool(benin_sample and 'vecteur_embedding' in benin_sample),
|
| 131 |
+
"sample_fields": list(benin_sample.keys()) if benin_sample else [],
|
| 132 |
+
"sample_doc_type": benin_sample.get('doc_type', 'NOT_SET') if benin_sample else None
|
| 133 |
+
},
|
| 134 |
+
"madagascar": {
|
| 135 |
+
"total_documents": madagascar_count,
|
| 136 |
+
"case_study_count": madagascar_case_study_count,
|
| 137 |
+
"articles_count": madagascar_articles_count,
|
| 138 |
+
"has_embeddings": bool(madagascar_sample and 'vecteur_embedding' in madagascar_sample),
|
| 139 |
+
"sample_fields": list(madagascar_sample.keys()) if madagascar_sample else [],
|
| 140 |
+
"sample_doc_type": madagascar_sample.get('doc_type', 'NOT_SET') if madagascar_sample else None
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
except Exception as e:
|
| 144 |
+
logger.error(f"Error getting collection stats: {e}")
|
| 145 |
+
print(f"Error getting collection stats: {e}")
|
| 146 |
+
return {}
|
| 147 |
+
|
| 148 |
+
def close(self):
|
| 149 |
+
"""Close MongoDB connection"""
|
| 150 |
+
if self.client:
|
| 151 |
+
self.client.close()
|
| 152 |
+
logger.info("✅ MongoDB connection closed")
|
| 153 |
+
print("✅ MongoDB connection closed")
|
database/postgres_checkpointer.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# database/postgres_checkpointer.py - CORRECT VERSION
|
| 2 |
+
from psycopg_pool import AsyncConnectionPool
|
| 3 |
+
from psycopg.rows import dict_row
|
| 4 |
+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver # ✅ Correct import
|
| 5 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class PostgresCheckpointer:
|
| 12 |
+
def __init__(self, database_url: str, max_connections: int = 10, min_connections: int = 2):
|
| 13 |
+
self.database_url = database_url
|
| 14 |
+
self.max_connections = max_connections
|
| 15 |
+
self.min_connections = min_connections
|
| 16 |
+
self.pool: Optional[AsyncConnectionPool] = None
|
| 17 |
+
self.checkpointer: Optional[AsyncPostgresSaver] = None # ✅ Correct type
|
| 18 |
+
self._is_initialized = False
|
| 19 |
+
|
| 20 |
+
async def initialize(self) -> bool:
|
| 21 |
+
"""Initialize PostgreSQL connection pool and checkpointer"""
|
| 22 |
+
try:
|
| 23 |
+
# Create async connection pool
|
| 24 |
+
self.pool = AsyncConnectionPool(
|
| 25 |
+
conninfo=self.database_url,
|
| 26 |
+
max_size=self.max_connections,
|
| 27 |
+
min_size=self.min_connections,
|
| 28 |
+
kwargs={"row_factory": dict_row, "autocommit": True},
|
| 29 |
+
open=False,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
await self.pool.open()
|
| 33 |
+
|
| 34 |
+
# ✅ CORRECT: Use AsyncPostgresSaver with AsyncConnectionPool
|
| 35 |
+
self.checkpointer = AsyncPostgresSaver(self.pool)
|
| 36 |
+
await self.checkpointer.setup() # ✅ Async setup method
|
| 37 |
+
|
| 38 |
+
self._is_initialized = True
|
| 39 |
+
logger.info("✅ PostgreSQL checkpointer initialized successfully with AsyncPostgresSaver")
|
| 40 |
+
return True
|
| 41 |
+
|
| 42 |
+
except Exception as e:
|
| 43 |
+
logger.error(f"❌ PostgreSQL initialization failed: {e}")
|
| 44 |
+
|
| 45 |
+
# Fallback to in-memory
|
| 46 |
+
try:
|
| 47 |
+
from langgraph.checkpoint.memory_aio import AsyncMemorySaver # ✅ Async memory saver
|
| 48 |
+
self.checkpointer = AsyncMemorySaver()
|
| 49 |
+
logger.warning("🔄 Falling back to async in-memory checkpointer")
|
| 50 |
+
self._is_initialized = True
|
| 51 |
+
return True
|
| 52 |
+
except ImportError:
|
| 53 |
+
# Fallback to sync MemorySaver if async not available
|
| 54 |
+
self.checkpointer = MemorySaver()
|
| 55 |
+
logger.warning("🔄 Falling back to sync in-memory checkpointer")
|
| 56 |
+
self._is_initialized = True
|
| 57 |
+
return True
|
| 58 |
+
except Exception as fallback_error:
|
| 59 |
+
logger.error(f"❌ Even fallback failed: {fallback_error}")
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
+
async def close(self):
|
| 63 |
+
"""Close connections with proper cleanup"""
|
| 64 |
+
if self.pool:
|
| 65 |
+
await self.pool.close()
|
| 66 |
+
logger.info("✅ PostgreSQL connection pool closed")
|
| 67 |
+
|
| 68 |
+
self._is_initialized = False
|
| 69 |
+
|
| 70 |
+
async def health_check(self) -> dict:
|
| 71 |
+
"""Check the health of the PostgreSQL connection"""
|
| 72 |
+
if not self._is_initialized or not self.pool:
|
| 73 |
+
return {"status": "uninitialized", "healthy": False}
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
async with self.pool.connection() as conn:
|
| 77 |
+
async with conn.cursor() as cur:
|
| 78 |
+
await cur.execute("SELECT 1")
|
| 79 |
+
result = await cur.fetchone()
|
| 80 |
+
|
| 81 |
+
return {
|
| 82 |
+
"status": "healthy",
|
| 83 |
+
"healthy": True,
|
| 84 |
+
"connection_count": self.pool.size if hasattr(self.pool, 'size') else "unknown"
|
| 85 |
+
}
|
| 86 |
+
except Exception as e:
|
| 87 |
+
return {"status": f"unhealthy: {str(e)}", "healthy": False}
|
| 88 |
+
|
| 89 |
+
def is_initialized(self) -> bool:
|
| 90 |
+
"""Check if checkpointer is properly initialized"""
|
| 91 |
+
return self._is_initialized and self.checkpointer is not None
|
| 92 |
+
|
| 93 |
+
def get_checkpointer(self) -> AsyncPostgresSaver:
|
| 94 |
+
"""Get the underlying checkpointer instance"""
|
| 95 |
+
if not self.is_initialized():
|
| 96 |
+
raise RuntimeError("Checkpointer not initialized. Call initialize() first.")
|
| 97 |
+
return self.checkpointer
|
generate_graph.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# generate_graph.py (example)
|
| 2 |
+
from graphviz import Digraph
|
| 3 |
+
|
| 4 |
+
def generate_graph():
|
| 5 |
+
dot = Digraph(comment='Legal RAG System Workflow')
|
| 6 |
+
dot.attr(rankdir='TB', size='8,5')
|
| 7 |
+
dot.attr('node', shape='box', style='filled', fillcolor='#e6f3ff')
|
| 8 |
+
|
| 9 |
+
# Core nodes
|
| 10 |
+
dot.node("START", fillcolor="#90ee90")
|
| 11 |
+
dot.node("ROUTER", label="router")
|
| 12 |
+
dot.node("RESPONSE", label="response", fillcolor="#98fb98")
|
| 13 |
+
|
| 14 |
+
# Country nodes
|
| 15 |
+
dot.node("BENIN_RETRIEVAL", label="benin_retrieval")
|
| 16 |
+
dot.node("MADAGASCAR_RETRIEVAL", label="madagascar_retrieval")
|
| 17 |
+
|
| 18 |
+
# Handler nodes
|
| 19 |
+
dot.node("GREETING", label="greeting_handler")
|
| 20 |
+
dot.node("REPAIR", label="repair_handler")
|
| 21 |
+
dot.node("SUMMARY", label="summary_handler")
|
| 22 |
+
dot.node("UNCLEAR", label="unclear_handler")
|
| 23 |
+
dot.node("OUT_OF_SCOPE", label="out_of_scope_handler")
|
| 24 |
+
|
| 25 |
+
# Assistance nodes
|
| 26 |
+
dot.node("ASSIST_COLLECT", label="assistance_collect_info")
|
| 27 |
+
dot.node("ASSIST_CONFIRM", label="assistance_confirm")
|
| 28 |
+
dot.node("HUMAN_APPROVAL", label="human_approval", fillcolor="#ffa07a")
|
| 29 |
+
|
| 30 |
+
# End node
|
| 31 |
+
dot.node("END", fillcolor="#ff9999")
|
| 32 |
+
|
| 33 |
+
# Edges
|
| 34 |
+
dot.edge("START", "ROUTER")
|
| 35 |
+
dot.edge("ROUTER", "BENIN_RETRIEVAL", label="benin")
|
| 36 |
+
dot.edge("ROUTER", "MADAGASCAR_RETRIEVAL", label="madagascar")
|
| 37 |
+
dot.edge("ROUTER", "GREETING", label="greeting_small_talk")
|
| 38 |
+
dot.edge("ROUTER", "REPAIR", label="conversation_repair")
|
| 39 |
+
dot.edge("ROUTER", "SUMMARY", label="conversation_summarization")
|
| 40 |
+
dot.edge("ROUTER", "UNCLEAR", label="unclear")
|
| 41 |
+
dot.edge("ROUTER", "OUT_OF_SCOPE", label="out_of_scope")
|
| 42 |
+
dot.edge("ROUTER", "ASSIST_COLLECT", label="assistance_request")
|
| 43 |
+
|
| 44 |
+
dot.edge("GREETING", "RESPONSE")
|
| 45 |
+
dot.edge("REPAIR", "RESPONSE")
|
| 46 |
+
dot.edge("SUMMARY", "RESPONSE")
|
| 47 |
+
dot.edge("UNCLEAR", "RESPONSE")
|
| 48 |
+
dot.edge("OUT_OF_SCOPE", "RESPONSE")
|
| 49 |
+
|
| 50 |
+
dot.edge("ASSIST_COLLECT", "RESPONSE", label="need_email/need_description")
|
| 51 |
+
dot.edge("ASSIST_COLLECT", "ASSIST_CONFIRM", label="ready_to_confirm")
|
| 52 |
+
dot.edge("ASSIST_COLLECT", "RESPONSE", label="cancelled")
|
| 53 |
+
|
| 54 |
+
dot.edge("ASSIST_CONFIRM", "HUMAN_APPROVAL", label="confirmed")
|
| 55 |
+
dot.edge("ASSIST_CONFIRM", "RESPONSE", label="needs_correction/cancelled")
|
| 56 |
+
|
| 57 |
+
dot.edge("HUMAN_APPROVAL", "RESPONSE")
|
| 58 |
+
|
| 59 |
+
dot.edge("RESPONSE", "ASSIST_COLLECT", label="continue_assistance")
|
| 60 |
+
dot.edge("RESPONSE", "END", label="end")
|
| 61 |
+
|
| 62 |
+
dot.render('legal_rag_workflow', format='png', cleanup=True)
|
| 63 |
+
print("Graph visualization generated: legal_rag_workflow.png")
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
|
| 66 |
+
generate_graph()
|
interfaces/__init__.py
ADDED
|
File without changes
|
interfaces/monitoring.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from typing import Dict, List
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
class LegalRAGMonitor:
|
| 6 |
+
"""Monitoring and error tracking for the legal RAG system"""
|
| 7 |
+
|
| 8 |
+
def __init__(self):
|
| 9 |
+
self.error_log = []
|
| 10 |
+
self.performance_metrics = {
|
| 11 |
+
"query_times": [],
|
| 12 |
+
"routing_accuracy": [],
|
| 13 |
+
"retrieval_success_rate": 0
|
| 14 |
+
}
|
| 15 |
+
self.alerts = []
|
| 16 |
+
|
| 17 |
+
def log_error(self, error_type: str, message: str, context: Dict = None):
|
| 18 |
+
"""Log errors for analysis"""
|
| 19 |
+
error_entry = {
|
| 20 |
+
"timestamp": datetime.now(),
|
| 21 |
+
"type": error_type,
|
| 22 |
+
"message": message,
|
| 23 |
+
"context": context or {}
|
| 24 |
+
}
|
| 25 |
+
self.error_log.append(error_entry)
|
| 26 |
+
logging.error(f"[{error_type}] {message}")
|
| 27 |
+
|
| 28 |
+
# Check for alert conditions
|
| 29 |
+
self._check_alerts(error_type, error_entry)
|
| 30 |
+
|
| 31 |
+
def track_query_performance(self, query_time: float, success: bool):
|
| 32 |
+
"""Track query performance metrics"""
|
| 33 |
+
self.performance_metrics["query_times"].append(query_time)
|
| 34 |
+
|
| 35 |
+
# Update success rate
|
| 36 |
+
current_rate = self.performance_metrics["retrieval_success_rate"]
|
| 37 |
+
total_queries = len(self.performance_metrics["query_times"])
|
| 38 |
+
|
| 39 |
+
if success:
|
| 40 |
+
self.performance_metrics["retrieval_success_rate"] = (
|
| 41 |
+
(current_rate * (total_queries - 1) + 1) / total_queries
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def get_health_report(self) -> Dict:
|
| 45 |
+
"""Generate system health report"""
|
| 46 |
+
query_times = self.performance_metrics["query_times"]
|
| 47 |
+
|
| 48 |
+
return {
|
| 49 |
+
"error_count": len(self.error_log),
|
| 50 |
+
"recent_errors": self.error_log[-5:],
|
| 51 |
+
"avg_query_time": sum(query_times) / len(query_times) if query_times else 0,
|
| 52 |
+
"success_rate": self.performance_metrics["retrieval_success_rate"],
|
| 53 |
+
"total_queries": len(query_times),
|
| 54 |
+
"active_alerts": len(self.alerts)
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
def _check_alerts(self, error_type: str, error_entry: Dict):
|
| 58 |
+
"""Check if error should trigger an alert"""
|
| 59 |
+
# Example alert conditions
|
| 60 |
+
if error_type == "database_connection":
|
| 61 |
+
self.alerts.append({
|
| 62 |
+
"type": "critical",
|
| 63 |
+
"message": "Database connection failure",
|
| 64 |
+
"timestamp": datetime.now(),
|
| 65 |
+
"error": error_entry
|
| 66 |
+
})
|
| 67 |
+
|
| 68 |
+
# Clean old alerts (keep only last 24 hours)
|
| 69 |
+
cutoff_time = datetime.now().timestamp() - (24 * 3600)
|
| 70 |
+
self.alerts = [
|
| 71 |
+
alert for alert in self.alerts
|
| 72 |
+
if alert["timestamp"].timestamp() > cutoff_time
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
class AlertManager:
|
| 76 |
+
"""Manage system alerts and notifications"""
|
| 77 |
+
|
| 78 |
+
def __init__(self):
|
| 79 |
+
self.alerts = []
|
| 80 |
+
self.subscribers = []
|
| 81 |
+
|
| 82 |
+
def add_alert(self, alert_type: str, message: str, severity: str = "warning"):
|
| 83 |
+
"""Add a new alert"""
|
| 84 |
+
alert = {
|
| 85 |
+
"type": alert_type,
|
| 86 |
+
"message": message,
|
| 87 |
+
"severity": severity,
|
| 88 |
+
"timestamp": datetime.now(),
|
| 89 |
+
"acknowledged": False
|
| 90 |
+
}
|
| 91 |
+
self.alerts.append(alert)
|
| 92 |
+
self._notify_subscribers(alert)
|
| 93 |
+
|
| 94 |
+
def acknowledge_alert(self, alert_index: int):
|
| 95 |
+
"""Acknowledge an alert"""
|
| 96 |
+
if 0 <= alert_index < len(self.alerts):
|
| 97 |
+
self.alerts[alert_index]["acknowledged"] = True
|
| 98 |
+
|
| 99 |
+
def subscribe(self, callback):
|
| 100 |
+
"""Subscribe to alert notifications"""
|
| 101 |
+
self.subscribers.append(callback)
|
| 102 |
+
|
| 103 |
+
def _notify_subscribers(self, alert):
|
| 104 |
+
"""Notify all subscribers of a new alert"""
|
| 105 |
+
for subscriber in self.subscribers:
|
| 106 |
+
try:
|
| 107 |
+
subscriber(alert)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
logging.error(f"Error notifying subscriber: {e}")
|
interfaces/web_interface.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException, Depends
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
+
import uvicorn
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
from core.chat_manager import LegalChatManager
|
| 9 |
+
|
| 10 |
+
# Pydantic models for API
|
| 11 |
+
class ChatRequest(BaseModel):
|
| 12 |
+
query: str
|
| 13 |
+
session_id: Optional[str] = None
|
| 14 |
+
context: Optional[Dict] = None
|
| 15 |
+
|
| 16 |
+
class ChatResponse(BaseModel):
|
| 17 |
+
response: str
|
| 18 |
+
session_id: str
|
| 19 |
+
session_stats: Dict
|
| 20 |
+
error: Optional[str] = None
|
| 21 |
+
|
| 22 |
+
class HealthResponse(BaseModel):
|
| 23 |
+
status: str
|
| 24 |
+
stats: Dict
|
| 25 |
+
timestamp: str
|
| 26 |
+
|
| 27 |
+
class LegalRAGAPI:
|
| 28 |
+
def __init__(self, chat_manager: LegalChatManager):
|
| 29 |
+
self.app = FastAPI(title="Legal RAG API", version="1.0.0")
|
| 30 |
+
self.chat_manager = chat_manager
|
| 31 |
+
self._setup_middleware()
|
| 32 |
+
self._setup_routes()
|
| 33 |
+
|
| 34 |
+
def _setup_middleware(self):
|
| 35 |
+
"""Setup CORS and other middleware"""
|
| 36 |
+
self.app.add_middleware(
|
| 37 |
+
CORSMiddleware,
|
| 38 |
+
allow_origins=["*"],
|
| 39 |
+
allow_credentials=True,
|
| 40 |
+
allow_methods=["*"],
|
| 41 |
+
allow_headers=["*"],
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def _setup_routes(self):
|
| 45 |
+
"""Setup API routes"""
|
| 46 |
+
|
| 47 |
+
@self.app.get("/")
|
| 48 |
+
async def root():
|
| 49 |
+
return {"message": "Legal RAG API is running"}
|
| 50 |
+
|
| 51 |
+
@self.app.post("/chat", response_model=ChatResponse)
|
| 52 |
+
async def chat_endpoint(request: ChatRequest):
|
| 53 |
+
try:
|
| 54 |
+
session_id = request.session_id or f"web_{datetime.now().timestamp()}"
|
| 55 |
+
|
| 56 |
+
response = await self.chat_manager.chat(
|
| 57 |
+
request.query,
|
| 58 |
+
session_id,
|
| 59 |
+
request.context
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
session_stats = self.chat_manager.get_session_stats(session_id)
|
| 63 |
+
|
| 64 |
+
return ChatResponse(
|
| 65 |
+
response=response,
|
| 66 |
+
session_id=session_id,
|
| 67 |
+
session_stats=session_stats,
|
| 68 |
+
error=None
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
except Exception as e:
|
| 72 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 73 |
+
|
| 74 |
+
@self.app.get("/health", response_model=HealthResponse)
|
| 75 |
+
async def health_check():
|
| 76 |
+
return HealthResponse(
|
| 77 |
+
status="healthy",
|
| 78 |
+
stats=self.chat_manager.get_global_stats(),
|
| 79 |
+
timestamp=datetime.now().isoformat()
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
@self.app.get("/sessions/{session_id}/history")
|
| 83 |
+
async def get_session_history(session_id: str):
|
| 84 |
+
try:
|
| 85 |
+
history = await self.chat_manager.get_conversation_history(session_id)
|
| 86 |
+
return {
|
| 87 |
+
"session_id": session_id,
|
| 88 |
+
"message_count": len(history),
|
| 89 |
+
"messages": history
|
| 90 |
+
}
|
| 91 |
+
except Exception as e:
|
| 92 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 93 |
+
|
| 94 |
+
def run(self, host: str = "0.0.0.0", port: int = 8000):
|
| 95 |
+
"""Run the API server"""
|
| 96 |
+
uvicorn.run(self.app, host=host, port=port)
|
main.py
ADDED
|
@@ -0,0 +1,629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Scalable Multi-Country Legal RAG System
|
| 4 |
+
Supports dynamic addition of new countries with clean architecture
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
import logging
|
| 9 |
+
import time
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from typing import List, Dict, Any, Optional
|
| 12 |
+
|
| 13 |
+
from config.settings import settings
|
| 14 |
+
from database.mongodb_client import MongoDBClient
|
| 15 |
+
from database.postgres_checkpointer import PostgresCheckpointer
|
| 16 |
+
from core.router import CountryRouter
|
| 17 |
+
from core.retriever import LegalRetriever
|
| 18 |
+
from core.graph_builder import GraphBuilder
|
| 19 |
+
from core.chat_manager import LegalChatManager
|
| 20 |
+
from utils.logger import setup_logging
|
| 21 |
+
|
| 22 |
+
import uuid
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class MultiCountryLegalRAGSystem:
|
| 26 |
+
"""Scalable system class supporting dynamic country addition"""
|
| 27 |
+
|
| 28 |
+
def __init__(self):
|
| 29 |
+
self.mongo_client = MongoDBClient()
|
| 30 |
+
self.postgres_checkpointer = PostgresCheckpointer(
|
| 31 |
+
database_url=settings.DATABASE_URL,
|
| 32 |
+
max_connections=10,
|
| 33 |
+
min_connections=2
|
| 34 |
+
)
|
| 35 |
+
self.router = None
|
| 36 |
+
# Dynamic country retrievers dictionary - easily extensible!
|
| 37 |
+
self.country_retrievers = {}
|
| 38 |
+
self.llm = None
|
| 39 |
+
self.graph = None
|
| 40 |
+
self.chat_manager = None
|
| 41 |
+
self.initialized = False
|
| 42 |
+
|
| 43 |
+
async def initialize(self) -> bool:
|
| 44 |
+
"""Initialize the complete scalable system"""
|
| 45 |
+
try:
|
| 46 |
+
setup_logging()
|
| 47 |
+
settings.validate()
|
| 48 |
+
|
| 49 |
+
# Initialize databases
|
| 50 |
+
if not self.mongo_client.connect():
|
| 51 |
+
raise Exception("MongoDB connection failed")
|
| 52 |
+
|
| 53 |
+
if not await self.postgres_checkpointer.initialize():
|
| 54 |
+
logging.warning("PostgreSQL initialization failed")
|
| 55 |
+
|
| 56 |
+
# Initialize core components
|
| 57 |
+
self.router = CountryRouter()
|
| 58 |
+
|
| 59 |
+
# Initialize default countries - easily extensible!
|
| 60 |
+
self._initialize_default_countries()
|
| 61 |
+
|
| 62 |
+
# Initialize LLM
|
| 63 |
+
from langchain_openai import ChatOpenAI
|
| 64 |
+
self.llm = ChatOpenAI(
|
| 65 |
+
model=settings.CHAT_MODEL,
|
| 66 |
+
temperature=settings.CHAT_TEMPERATURE,
|
| 67 |
+
max_tokens=settings.CHAT_MAX_TOKENS
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Build scalable graph with country dictionary
|
| 71 |
+
graph_builder = GraphBuilder(
|
| 72 |
+
router=self.router,
|
| 73 |
+
llm=self.llm,
|
| 74 |
+
checkpointer=self.postgres_checkpointer.get_checkpointer(),
|
| 75 |
+
country_retrievers=self.country_retrievers # Pass the dictionary
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
workflow = graph_builder.build_graph()
|
| 79 |
+
|
| 80 |
+
# Compile with interrupt support
|
| 81 |
+
self.graph = workflow.compile(
|
| 82 |
+
checkpointer=self.postgres_checkpointer.get_checkpointer(),
|
| 83 |
+
interrupt_before=["human_approval"]
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Initialize chat manager
|
| 87 |
+
self.chat_manager = LegalChatManager(
|
| 88 |
+
self.graph,
|
| 89 |
+
self.postgres_checkpointer.get_checkpointer()
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
await self._perform_health_check()
|
| 93 |
+
|
| 94 |
+
self.initialized = True
|
| 95 |
+
logging.info(f"✅ System initialized with {len(self.country_retrievers)} countries")
|
| 96 |
+
self._print_system_info()
|
| 97 |
+
|
| 98 |
+
return True
|
| 99 |
+
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logging.error(f"❌ System initialization failed: {e}")
|
| 102 |
+
import traceback
|
| 103 |
+
traceback.print_exc()
|
| 104 |
+
return False
|
| 105 |
+
|
| 106 |
+
def _initialize_default_countries(self):
|
| 107 |
+
"""Initialize default countries - easily extensible!"""
|
| 108 |
+
# Benin
|
| 109 |
+
if hasattr(self.mongo_client, 'benin_vectorstore'):
|
| 110 |
+
self.country_retrievers["benin"] = LegalRetriever(
|
| 111 |
+
self.mongo_client.benin_vectorstore,
|
| 112 |
+
self.mongo_client.benin_collection
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Madagascar
|
| 116 |
+
if hasattr(self.mongo_client, 'madagascar_vectorstore'):
|
| 117 |
+
self.country_retrievers["madagascar"] = LegalRetriever(
|
| 118 |
+
self.mongo_client.madagascar_vectorstore,
|
| 119 |
+
self.mongo_client.madagascar_collection
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
logging.info(f"🌍 Initialized {len(self.country_retrievers)} default countries")
|
| 123 |
+
|
| 124 |
+
def add_country(self, country_code: str, vectorstore, collection) -> bool:
|
| 125 |
+
"""Dynamically add a new country to the running system"""
|
| 126 |
+
try:
|
| 127 |
+
if country_code in self.country_retrievers:
|
| 128 |
+
logging.warning(f"Country {country_code} already exists")
|
| 129 |
+
return False
|
| 130 |
+
|
| 131 |
+
new_retriever = LegalRetriever(vectorstore, collection)
|
| 132 |
+
self.country_retrievers[country_code] = new_retriever
|
| 133 |
+
|
| 134 |
+
# Rebuild graph if system is already initialized
|
| 135 |
+
if self.initialized:
|
| 136 |
+
graph_builder = GraphBuilder(
|
| 137 |
+
router=self.router,
|
| 138 |
+
llm=self.llm,
|
| 139 |
+
checkpointer=self.postgres_checkpointer.get_checkpointer(),
|
| 140 |
+
country_retrievers=self.country_retrievers
|
| 141 |
+
)
|
| 142 |
+
workflow = graph_builder.build_graph()
|
| 143 |
+
self.graph = workflow.compile(
|
| 144 |
+
checkpointer=self.postgres_checkpointer.get_checkpointer(),
|
| 145 |
+
interrupt_before=["human_approval"]
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
logging.info(f"🎉 Successfully added country: {country_code}")
|
| 149 |
+
return True
|
| 150 |
+
|
| 151 |
+
except Exception as e:
|
| 152 |
+
logging.error(f"❌ Failed to add country {country_code}: {e}")
|
| 153 |
+
return False
|
| 154 |
+
|
| 155 |
+
async def _perform_health_check(self):
|
| 156 |
+
"""Perform health check after initialization"""
|
| 157 |
+
try:
|
| 158 |
+
health_status = await self.health_check()
|
| 159 |
+
|
| 160 |
+
unhealthy_components = [k for k, v in health_status.get('components', {}).items() if not v]
|
| 161 |
+
if unhealthy_components:
|
| 162 |
+
logging.warning(f"⚠️ Unhealthy components: {unhealthy_components}")
|
| 163 |
+
|
| 164 |
+
except Exception as e:
|
| 165 |
+
logging.warning(f"⚠️ Health check failed: {e}")
|
| 166 |
+
|
| 167 |
+
async def health_check(self) -> Dict[str, Any]:
|
| 168 |
+
"""Comprehensive system health check"""
|
| 169 |
+
health_status = {
|
| 170 |
+
"system_initialized": self.initialized,
|
| 171 |
+
"mongodb_connected": self.mongo_client.client is not None,
|
| 172 |
+
"postgres_healthy": {},
|
| 173 |
+
"interrupt_enabled": True,
|
| 174 |
+
"available_countries": list(self.country_retrievers.keys()),
|
| 175 |
+
"components": {
|
| 176 |
+
"router": self.router is not None,
|
| 177 |
+
"llm": self.llm is not None,
|
| 178 |
+
"graph": self.graph is not None,
|
| 179 |
+
"chat_manager": self.chat_manager is not None,
|
| 180 |
+
"country_retrievers": len(self.country_retrievers) > 0
|
| 181 |
+
},
|
| 182 |
+
"timestamp": datetime.now().isoformat(),
|
| 183 |
+
"settings": {
|
| 184 |
+
"chat_model": settings.CHAT_MODEL,
|
| 185 |
+
"embedding_model": settings.EMBEDDING_MODEL,
|
| 186 |
+
"max_search_results": settings.MAX_SEARCH_RESULTS
|
| 187 |
+
}
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
# Test MongoDB connection
|
| 191 |
+
if health_status["mongodb_connected"]:
|
| 192 |
+
try:
|
| 193 |
+
self.mongo_client.client.admin.command('ping')
|
| 194 |
+
health_status["mongodb_ping"] = True
|
| 195 |
+
except Exception as e:
|
| 196 |
+
health_status["mongodb_ping"] = False
|
| 197 |
+
health_status["mongodb_error"] = str(e)
|
| 198 |
+
|
| 199 |
+
# Test PostgreSQL connection
|
| 200 |
+
if hasattr(self.postgres_checkpointer, 'health_check'):
|
| 201 |
+
postgres_health = await self.postgres_checkpointer.health_check()
|
| 202 |
+
health_status["postgres_healthy"] = postgres_health
|
| 203 |
+
|
| 204 |
+
return health_status
|
| 205 |
+
|
| 206 |
+
async def chat(self, message: str, session_id: str = None, context: dict = None) -> str:
|
| 207 |
+
"""Public chat interface"""
|
| 208 |
+
if not self.initialized:
|
| 209 |
+
raise RuntimeError("System not initialized. Call initialize() first.")
|
| 210 |
+
|
| 211 |
+
if not message or not message.strip():
|
| 212 |
+
raise ValueError("Message cannot be empty")
|
| 213 |
+
|
| 214 |
+
try:
|
| 215 |
+
# Prepare context
|
| 216 |
+
ctx = context or {}
|
| 217 |
+
ctx.setdefault("jurisdiction", "Unknown")
|
| 218 |
+
ctx.setdefault("user_type", "general")
|
| 219 |
+
ctx.setdefault("document_type", "legal")
|
| 220 |
+
ctx.setdefault("detected_country", "unknown")
|
| 221 |
+
|
| 222 |
+
session_id = session_id or f"cli_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 223 |
+
|
| 224 |
+
return await self.chat_manager.chat(message, session_id, ctx)
|
| 225 |
+
|
| 226 |
+
except Exception as e:
|
| 227 |
+
logging.error(f"❌ Chat error for session {session_id}: {e}")
|
| 228 |
+
return f"❌ Désolé, une erreur s'est produite lors du traitement de votre demande. Veuillez réessayer."
|
| 229 |
+
|
| 230 |
+
def get_session_info(self, session_id: str) -> Dict[str, Any]:
|
| 231 |
+
"""Get information about a specific session"""
|
| 232 |
+
if not self.initialized:
|
| 233 |
+
raise RuntimeError("System not initialized")
|
| 234 |
+
return self.chat_manager.get_session_stats(session_id)
|
| 235 |
+
|
| 236 |
+
def get_global_stats(self) -> Dict[str, Any]:
|
| 237 |
+
"""Get global system statistics"""
|
| 238 |
+
if not self.initialized:
|
| 239 |
+
raise RuntimeError("System not initialized")
|
| 240 |
+
return self.chat_manager.get_global_stats()
|
| 241 |
+
|
| 242 |
+
def get_available_countries(self) -> List[str]:
|
| 243 |
+
"""Get list of available countries"""
|
| 244 |
+
return list(self.country_retrievers.keys())
|
| 245 |
+
|
| 246 |
+
async def cleanup(self):
|
| 247 |
+
"""Cleanup resources"""
|
| 248 |
+
try:
|
| 249 |
+
if self.mongo_client:
|
| 250 |
+
self.mongo_client.close()
|
| 251 |
+
if self.postgres_checkpointer:
|
| 252 |
+
await self.postgres_checkpointer.close()
|
| 253 |
+
logging.info("✅ System cleanup completed")
|
| 254 |
+
except Exception as e:
|
| 255 |
+
logging.error(f"❌ Error during cleanup: {e}")
|
| 256 |
+
|
| 257 |
+
def _print_system_info(self):
|
| 258 |
+
"""Print system configuration information"""
|
| 259 |
+
countries = list(self.country_retrievers.keys())
|
| 260 |
+
print("\n" + "="*60)
|
| 261 |
+
print("🚀 SCALABLE MULTI-COUNTRY LEGAL RAG SYSTEM")
|
| 262 |
+
print("="*60)
|
| 263 |
+
print(f"🌍 Available Countries: {', '.join(countries) if countries else 'None'}")
|
| 264 |
+
print(f"🤖 AI Model: {settings.CHAT_MODEL}")
|
| 265 |
+
print(f"💾 Database: MongoDB + PostgreSQL")
|
| 266 |
+
print(f"🔍 Vector Search: {settings.EMBEDDING_MODEL}")
|
| 267 |
+
print(f"⏸️ Interrupt Support: ENABLED")
|
| 268 |
+
print(f"🌡️ Temperature: {settings.CHAT_TEMPERATURE}")
|
| 269 |
+
print(f"📝 Max Tokens: {settings.CHAT_MAX_TOKENS}")
|
| 270 |
+
print("="*60)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class InterruptTester:
|
| 274 |
+
"""Specialized tester for human approval interrupts"""
|
| 275 |
+
|
| 276 |
+
def __init__(self, system: MultiCountryLegalRAGSystem):
|
| 277 |
+
self.system = system
|
| 278 |
+
self.test_results = []
|
| 279 |
+
|
| 280 |
+
async def test_assistance_workflow(self, test_name: str,
|
| 281 |
+
user_query: str,
|
| 282 |
+
user_email: str,
|
| 283 |
+
user_description: str,
|
| 284 |
+
moderator_response: str) -> Dict[str, Any]:
|
| 285 |
+
"""Test the complete assistance workflow with interrupt"""
|
| 286 |
+
print(f"\n🧪 Interrupt Test: {test_name}")
|
| 287 |
+
print(f"📝 User Query: {user_query}")
|
| 288 |
+
|
| 289 |
+
# session_id = f"test_{datetime.now().strftime('%H%M%S%f')}"
|
| 290 |
+
session_id = f"interactive_{uuid.uuid4().hex[:8]}"
|
| 291 |
+
current_response = ""
|
| 292 |
+
|
| 293 |
+
try:
|
| 294 |
+
# Step 1: Initial request
|
| 295 |
+
print("1️⃣ Step 1: Initial assistance request...")
|
| 296 |
+
current_response = await self.system.chat(user_query, session_id)
|
| 297 |
+
print(f"🤖 Response: {current_response[:150]}...")
|
| 298 |
+
|
| 299 |
+
# Step 2: Email collection
|
| 300 |
+
if user_email and any(keyword in current_response.lower() for keyword in ["email", "adresse", "@"]):
|
| 301 |
+
print(f"2️⃣ Step 2: Providing email: {user_email}")
|
| 302 |
+
current_response = await self.system.chat(user_email, session_id)
|
| 303 |
+
print(f"🤖 Response: {current_response[:150]}...")
|
| 304 |
+
|
| 305 |
+
# Step 3: Description collection
|
| 306 |
+
if user_description and any(keyword in current_response.lower() for keyword in ["description", "décrire", "besoin"]):
|
| 307 |
+
print(f"3️⃣ Step 3: Providing description: {user_description[:50]}...")
|
| 308 |
+
current_response = await self.system.chat(user_description, session_id)
|
| 309 |
+
print(f"🤖 Response: {current_response[:150]}...")
|
| 310 |
+
|
| 311 |
+
# Step 4: Confirmation
|
| 312 |
+
if any(keyword in current_response.lower() for keyword in ["confirmer", "confirmation", "oui/non"]):
|
| 313 |
+
print("4️⃣ Step 4: Confirming request...")
|
| 314 |
+
current_response = await self.system.chat("oui", session_id)
|
| 315 |
+
print(f"🤖 Response: {current_response[:150]}...")
|
| 316 |
+
|
| 317 |
+
# Step 5: Check for interrupt
|
| 318 |
+
interrupt_detected = self._check_for_interrupt(current_response, session_id)
|
| 319 |
+
|
| 320 |
+
if interrupt_detected:
|
| 321 |
+
print("⏸️ INTERRUPT DETECTED! Waiting for moderator...")
|
| 322 |
+
|
| 323 |
+
# Step 6: Moderator decision
|
| 324 |
+
print(f"👨⚖️ Moderator: {moderator_response}")
|
| 325 |
+
final_response = await self.system.chat(moderator_response, session_id)
|
| 326 |
+
print(f"✅ Final Response: {final_response[:200]}...")
|
| 327 |
+
|
| 328 |
+
result = {
|
| 329 |
+
"test_name": test_name,
|
| 330 |
+
"status": "PASS",
|
| 331 |
+
"interrupt_detected": True,
|
| 332 |
+
"moderator_decision": moderator_response,
|
| 333 |
+
"final_response": final_response,
|
| 334 |
+
"session_id": session_id
|
| 335 |
+
}
|
| 336 |
+
else:
|
| 337 |
+
print("⚠️ No interrupt detected in workflow")
|
| 338 |
+
result = {
|
| 339 |
+
"test_name": test_name,
|
| 340 |
+
"status": "FAIL",
|
| 341 |
+
"interrupt_detected": False,
|
| 342 |
+
"moderator_decision": None,
|
| 343 |
+
"final_response": current_response,
|
| 344 |
+
"error": "Interrupt not triggered",
|
| 345 |
+
"session_id": session_id
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
self.test_results.append(result)
|
| 349 |
+
return result
|
| 350 |
+
|
| 351 |
+
except Exception as e:
|
| 352 |
+
logging.error(f"❌ Test error: {e}")
|
| 353 |
+
error_result = {
|
| 354 |
+
"test_name": test_name,
|
| 355 |
+
"status": "ERROR",
|
| 356 |
+
"interrupt_detected": False,
|
| 357 |
+
"moderator_decision": None,
|
| 358 |
+
"final_response": current_response,
|
| 359 |
+
"error": str(e),
|
| 360 |
+
"session_id": session_id
|
| 361 |
+
}
|
| 362 |
+
self.test_results.append(error_result)
|
| 363 |
+
return error_result
|
| 364 |
+
|
| 365 |
+
def _check_for_interrupt(self, response: str, session_id: str) -> bool:
|
| 366 |
+
"""Enhanced interrupt detection"""
|
| 367 |
+
interrupt_indicators = [
|
| 368 |
+
"APPROBATION", "APPROVAL", "HUMAN", "MODERATOR",
|
| 369 |
+
"DÉCISION", "DECISION", "APPROUVER", "REJETER"
|
| 370 |
+
]
|
| 371 |
+
|
| 372 |
+
if any(indicator in response.upper() for indicator in interrupt_indicators):
|
| 373 |
+
return True
|
| 374 |
+
|
| 375 |
+
if (hasattr(self.system.chat_manager, 'pending_interrupts') and
|
| 376 |
+
session_id in self.system.chat_manager.pending_interrupts):
|
| 377 |
+
return True
|
| 378 |
+
|
| 379 |
+
return False
|
| 380 |
+
|
| 381 |
+
def print_summary(self):
|
| 382 |
+
"""Print test summary"""
|
| 383 |
+
print("\n" + "="*80)
|
| 384 |
+
print("📊 INTERRUPT TEST SUMMARY")
|
| 385 |
+
print("="*80)
|
| 386 |
+
|
| 387 |
+
total = len(self.test_results)
|
| 388 |
+
passed = len([r for r in self.test_results if r["status"] == "PASS"])
|
| 389 |
+
failed = len([r for r in self.test_results if r["status"] == "FAIL"])
|
| 390 |
+
errors = len([r for r in self.test_results if r["status"] == "ERROR"])
|
| 391 |
+
|
| 392 |
+
print(f"📈 Total Tests: {total}")
|
| 393 |
+
print(f"✅ Passed: {passed}")
|
| 394 |
+
print(f"❌ Failed: {failed}")
|
| 395 |
+
print(f"🚨 Errors: {errors}")
|
| 396 |
+
|
| 397 |
+
if passed > 0:
|
| 398 |
+
print(f"\n🎉 Successful Tests:")
|
| 399 |
+
for result in self.test_results:
|
| 400 |
+
if result["status"] == "PASS":
|
| 401 |
+
print(f" - {result['test_name']}")
|
| 402 |
+
|
| 403 |
+
if failed > 0 or errors > 0:
|
| 404 |
+
print(f"\n💥 Failed/Error Tests:")
|
| 405 |
+
for result in self.test_results:
|
| 406 |
+
if result["status"] in ["FAIL", "ERROR"]:
|
| 407 |
+
print(f" - {result['test_name']}: {result.get('error', 'Unknown error')}")
|
| 408 |
+
|
| 409 |
+
print("="*80)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
async def run_interrupt_tests():
|
| 413 |
+
"""Run specialized tests for human approval interrupts"""
|
| 414 |
+
system = MultiCountryLegalRAGSystem()
|
| 415 |
+
tester = InterruptTester(system)
|
| 416 |
+
|
| 417 |
+
try:
|
| 418 |
+
print("🚀 Initializing system...")
|
| 419 |
+
success = await system.initialize()
|
| 420 |
+
if not success:
|
| 421 |
+
print("❌ System initialization failed")
|
| 422 |
+
return
|
| 423 |
+
|
| 424 |
+
print("\n🧪 STARTING INTERRUPT TESTS")
|
| 425 |
+
print("="*60)
|
| 426 |
+
|
| 427 |
+
test_scenarios = [
|
| 428 |
+
{
|
| 429 |
+
"name": "Complete Workflow - Approve",
|
| 430 |
+
"user_query": "Je veux parler a un avocat",
|
| 431 |
+
"user_email": "test@example.com",
|
| 432 |
+
"user_description": "Consultation pour divorce au Benin",
|
| 433 |
+
"moderator_response": "approve Demande legitime"
|
| 434 |
+
},
|
| 435 |
+
{
|
| 436 |
+
"name": "Complete Workflow - Reject",
|
| 437 |
+
"user_query": "Contactez-moi",
|
| 438 |
+
"user_email": "test2@example.com",
|
| 439 |
+
"user_description": "J'ai besoin d'aide",
|
| 440 |
+
"moderator_response": "reject Description trop vague"
|
| 441 |
+
}
|
| 442 |
+
]
|
| 443 |
+
|
| 444 |
+
for scenario in test_scenarios:
|
| 445 |
+
await tester.test_assistance_workflow(
|
| 446 |
+
scenario["name"],
|
| 447 |
+
scenario["user_query"],
|
| 448 |
+
scenario["user_email"],
|
| 449 |
+
scenario["user_description"],
|
| 450 |
+
scenario["moderator_response"]
|
| 451 |
+
)
|
| 452 |
+
await asyncio.sleep(1)
|
| 453 |
+
|
| 454 |
+
tester.print_summary()
|
| 455 |
+
|
| 456 |
+
except Exception as e:
|
| 457 |
+
logging.error(f"❌ Error during testing: {e}")
|
| 458 |
+
finally:
|
| 459 |
+
await system.cleanup()
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
async def interactive_mode():
|
| 463 |
+
"""Run interactive chat mode"""
|
| 464 |
+
system = MultiCountryLegalRAGSystem()
|
| 465 |
+
|
| 466 |
+
try:
|
| 467 |
+
print("🚀 Initializing system...")
|
| 468 |
+
success = await system.initialize()
|
| 469 |
+
if not success:
|
| 470 |
+
print("❌ System initialization failed")
|
| 471 |
+
return
|
| 472 |
+
|
| 473 |
+
print("\n🎯 INTERACTIVE MODE - SCALABLE SYSTEM")
|
| 474 |
+
print("="*60)
|
| 475 |
+
print("Commands:")
|
| 476 |
+
print(" 'quit' - Exit")
|
| 477 |
+
print(" 'stats' - Show statistics")
|
| 478 |
+
print(" 'health' - Health check")
|
| 479 |
+
print(" 'countries' - List available countries")
|
| 480 |
+
print(" 'session' - Session info")
|
| 481 |
+
print("="*60)
|
| 482 |
+
|
| 483 |
+
session_id = f"interactive_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 484 |
+
print(f"Session ID: {session_id}")
|
| 485 |
+
print(f"Available: {', '.join(system.get_available_countries())}\n")
|
| 486 |
+
|
| 487 |
+
while True:
|
| 488 |
+
try:
|
| 489 |
+
user_input = input("👤 You: ").strip()
|
| 490 |
+
|
| 491 |
+
if user_input.lower() in ['quit', 'exit', 'q']:
|
| 492 |
+
break
|
| 493 |
+
elif user_input.lower() == 'stats':
|
| 494 |
+
stats = system.get_global_stats()
|
| 495 |
+
print(f"\n📊 Statistics:")
|
| 496 |
+
print(f" Total Queries: {stats.get('total_queries', 0)}")
|
| 497 |
+
print(f" Active Sessions: {stats.get('active_sessions', 0)}")
|
| 498 |
+
print(f" Pending Interrupts: {stats.get('pending_interrupts', 0)}")
|
| 499 |
+
continue
|
| 500 |
+
elif user_input.lower() == 'health':
|
| 501 |
+
health = await system.health_check()
|
| 502 |
+
print(f"\n❤️ System Health:")
|
| 503 |
+
print(f" Status: {'✅ HEALTHY' if health['system_initialized'] else '❌ UNHEALTHY'}")
|
| 504 |
+
print(f" Countries: {len(health['available_countries'])} available")
|
| 505 |
+
print(f" MongoDB: {'✅ Connected' if health['mongodb_connected'] else '❌ Disconnected'}")
|
| 506 |
+
continue
|
| 507 |
+
elif user_input.lower() == 'countries':
|
| 508 |
+
countries = system.get_available_countries()
|
| 509 |
+
print(f"\n🌍 Available Countries: {', '.join(countries) if countries else 'None'}")
|
| 510 |
+
continue
|
| 511 |
+
elif user_input.lower() == 'session':
|
| 512 |
+
info = system.get_session_info(session_id)
|
| 513 |
+
print(f"\n📋 Session Info:")
|
| 514 |
+
print(f" Queries: {info.get('query_count', 0)}")
|
| 515 |
+
print(f" Avg Time: {info.get('average_processing_time', 0):.2f}s")
|
| 516 |
+
continue
|
| 517 |
+
elif not user_input:
|
| 518 |
+
continue
|
| 519 |
+
|
| 520 |
+
start_time = time.time()
|
| 521 |
+
response = await system.chat(user_input, session_id)
|
| 522 |
+
response_time = time.time() - start_time
|
| 523 |
+
|
| 524 |
+
print(f"🤖 Assistant ({response_time:.2f}s): {response}\n")
|
| 525 |
+
|
| 526 |
+
# Check for interrupt
|
| 527 |
+
if (hasattr(system.chat_manager, 'pending_interrupts') and
|
| 528 |
+
session_id in system.chat_manager.pending_interrupts):
|
| 529 |
+
print("⏸️ 💡 SYSTEM PAUSED - Next message treated as moderator decision\n")
|
| 530 |
+
|
| 531 |
+
except KeyboardInterrupt:
|
| 532 |
+
print("\n👋 Goodbye!")
|
| 533 |
+
break
|
| 534 |
+
except Exception as e:
|
| 535 |
+
print(f"❌ Error: {str(e)}\n")
|
| 536 |
+
|
| 537 |
+
finally:
|
| 538 |
+
await system.cleanup()
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
async def health_check_mode():
|
| 542 |
+
"""Run system health check only"""
|
| 543 |
+
system = MultiCountryLegalRAGSystem()
|
| 544 |
+
|
| 545 |
+
try:
|
| 546 |
+
print("🔍 Performing health check...")
|
| 547 |
+
success = await system.initialize()
|
| 548 |
+
|
| 549 |
+
if success:
|
| 550 |
+
health = await system.health_check()
|
| 551 |
+
print("\n" + "="*50)
|
| 552 |
+
print("📋 SYSTEM HEALTH REPORT")
|
| 553 |
+
print("="*50)
|
| 554 |
+
print(f"✅ System Initialized: {health['system_initialized']}")
|
| 555 |
+
print(f"🌍 Available Countries: {len(health['available_countries'])}")
|
| 556 |
+
print(f"💾 MongoDB: {'✅ Connected' if health['mongodb_connected'] else '❌ Disconnected'}")
|
| 557 |
+
print(f"⏸️ Interrupt Support: {'✅ Enabled' if health['interrupt_enabled'] else '❌ Disabled'}")
|
| 558 |
+
|
| 559 |
+
print(f"\n🔧 Components:")
|
| 560 |
+
for component, status in health['components'].items():
|
| 561 |
+
print(f" {component}: {'✅ OK' if status else '❌ Missing'}")
|
| 562 |
+
|
| 563 |
+
all_healthy = (health['system_initialized'] and
|
| 564 |
+
health['mongodb_connected'] and
|
| 565 |
+
all(health['components'].values()))
|
| 566 |
+
print(f"\n🎯 Overall Status: {'✅ HEALTHY' if all_healthy else '❌ UNHEALTHY'}")
|
| 567 |
+
|
| 568 |
+
else:
|
| 569 |
+
print("❌ System initialization failed")
|
| 570 |
+
|
| 571 |
+
finally:
|
| 572 |
+
await system.cleanup()
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
async def quick_test_mode():
|
| 576 |
+
"""Run a quick single test"""
|
| 577 |
+
system = MultiCountryLegalRAGSystem()
|
| 578 |
+
|
| 579 |
+
try:
|
| 580 |
+
print("🚀 Quick Test Mode")
|
| 581 |
+
print("Initializing system...")
|
| 582 |
+
success = await system.initialize()
|
| 583 |
+
if not success:
|
| 584 |
+
print("❌ System initialization failed")
|
| 585 |
+
return
|
| 586 |
+
|
| 587 |
+
test_query = "Bonjour, quelle est la procedure pour un divorce au Benin?"
|
| 588 |
+
session_id = "quick_test"
|
| 589 |
+
|
| 590 |
+
print(f"\n🧪 Testing: {test_query}")
|
| 591 |
+
start_time = time.time()
|
| 592 |
+
response = await system.chat(test_query, session_id)
|
| 593 |
+
response_time = time.time() - start_time
|
| 594 |
+
|
| 595 |
+
print(f"✅ Response ({response_time:.2f}s): {response}")
|
| 596 |
+
|
| 597 |
+
print(f"\n📊 System Info:")
|
| 598 |
+
print(f" Available Countries: {', '.join(system.get_available_countries())}")
|
| 599 |
+
|
| 600 |
+
except Exception as e:
|
| 601 |
+
print(f"❌ Quick test failed: {e}")
|
| 602 |
+
finally:
|
| 603 |
+
await system.cleanup()
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
if __name__ == "__main__":
|
| 607 |
+
import argparse
|
| 608 |
+
|
| 609 |
+
parser = argparse.ArgumentParser(
|
| 610 |
+
description="🚀 Scalable Multi-Country Legal RAG System"
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
parser.add_argument(
|
| 614 |
+
"--mode",
|
| 615 |
+
choices=["interactive", "health", "interrupt", "quick"],
|
| 616 |
+
default="interactive",
|
| 617 |
+
help="Run mode (default: interactive)"
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
args = parser.parse_args()
|
| 621 |
+
|
| 622 |
+
if args.mode == "interactive":
|
| 623 |
+
asyncio.run(interactive_mode())
|
| 624 |
+
elif args.mode == "health":
|
| 625 |
+
asyncio.run(health_check_mode())
|
| 626 |
+
elif args.mode == "interrupt":
|
| 627 |
+
asyncio.run(run_interrupt_tests())
|
| 628 |
+
elif args.mode == "quick":
|
| 629 |
+
asyncio.run(quick_test_mode())
|
models/__init__py
ADDED
|
File without changes
|
models/state_models.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [file name]: models/state_models.py
|
| 2 |
+
from typing import List, Dict, Any, Optional, Annotated, Literal, Union
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
import operator
|
| 5 |
+
|
| 6 |
+
class MultiCountryLegalState(BaseModel):
|
| 7 |
+
messages: Annotated[List[Dict[str, Any]], operator.add] = Field(default_factory=list)
|
| 8 |
+
legal_context: Dict[str, Any] = Field(
|
| 9 |
+
default_factory=lambda: {
|
| 10 |
+
"jurisdiction": "Unknown",
|
| 11 |
+
"user_type": "general",
|
| 12 |
+
"document_type": "legal",
|
| 13 |
+
"detected_country": "unknown"
|
| 14 |
+
}
|
| 15 |
+
)
|
| 16 |
+
# FIX: Make supplemental_message handle concurrent updates
|
| 17 |
+
supplemental_message: Optional[str] = Field(
|
| 18 |
+
default="",
|
| 19 |
+
description="Supplemental message to display to user (e.g., fallback messages, apologies)"
|
| 20 |
+
)
|
| 21 |
+
session_id: Optional[str] = None
|
| 22 |
+
last_search_query: Optional[str] = None
|
| 23 |
+
detected_articles: Annotated[List[str], operator.add] = Field(default_factory=list)
|
| 24 |
+
router_decision: Optional[str] = None
|
| 25 |
+
search_results: Optional[str] = None
|
| 26 |
+
route_explanation: Optional[str] = None
|
| 27 |
+
country: Optional[str] = Field(default=None)
|
| 28 |
+
|
| 29 |
+
# Assistance email fields
|
| 30 |
+
assistance_requested: bool = Field(default=False)
|
| 31 |
+
user_email: Optional[str] = None
|
| 32 |
+
assistance_description: Optional[str] = None
|
| 33 |
+
email_status: Optional[str] = None # "pending", "sent", "error"
|
| 34 |
+
assistance_step: Optional[str] = Field(default=None) # "collecting_email", "collecting_description", "confirming_send"
|
| 35 |
+
pending_assistance_data: Dict[str, Any] = Field(default_factory=dict)
|
| 36 |
+
|
| 37 |
+
# Conversation repair tracking
|
| 38 |
+
repair_type: Optional[str] = None
|
| 39 |
+
original_query: Optional[str] = None
|
| 40 |
+
misunderstanding_count: int = Field(default=0)
|
| 41 |
+
|
| 42 |
+
# Enhanced routing support
|
| 43 |
+
primary_intent: Optional[str] = Field(default=None)
|
| 44 |
+
|
| 45 |
+
# NEW: Human approval fields
|
| 46 |
+
approval_status: Optional[str] = Field(default=None) # "pending", "approved", "rejected"
|
| 47 |
+
approval_reason: Optional[str] = Field(default=None)
|
| 48 |
+
approved_by: Optional[str] = Field(default=None)
|
| 49 |
+
approval_timestamp: Optional[str] = Field(default=None)
|
| 50 |
+
|
| 51 |
+
# Conversation summary fields
|
| 52 |
+
summary_generated: bool = Field(default=False)
|
| 53 |
+
last_summary_timestamp: Optional[str] = Field(default=None)
|
| 54 |
+
|
| 55 |
+
# NEW: Search-related fields to prevent storing complex data in legal_context
|
| 56 |
+
search_metadata: Dict[str, Any] = Field(default_factory=dict)
|
| 57 |
+
|
| 58 |
+
@staticmethod
|
| 59 |
+
def detect_country(text: str) -> str:
|
| 60 |
+
"""
|
| 61 |
+
Detect country from text based on keywords.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
text: User input text to analyze
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
Country code: "benin", "madagascar", or "unknown"
|
| 68 |
+
"""
|
| 69 |
+
if not text:
|
| 70 |
+
return "unknown"
|
| 71 |
+
|
| 72 |
+
text_lower = text.lower()
|
| 73 |
+
|
| 74 |
+
# Benin keywords
|
| 75 |
+
benin_keywords = [
|
| 76 |
+
"bénin", "benin", "béninois", "béninoise",
|
| 77 |
+
"cotonou", "porto-novo", "porto novo",
|
| 78 |
+
"dahomey" # Historical name
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
# Madagascar keywords
|
| 82 |
+
madagascar_keywords = [
|
| 83 |
+
"madagascar", "malgache", "malagasy",
|
| 84 |
+
"antananarivo", "tananarive", "tana",
|
| 85 |
+
"toamasina", "tamatave"
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
# Check for country mentions
|
| 89 |
+
benin_score = sum(1 for keyword in benin_keywords if keyword in text_lower)
|
| 90 |
+
madagascar_score = sum(1 for keyword in madagascar_keywords if keyword in text_lower)
|
| 91 |
+
|
| 92 |
+
if benin_score > madagascar_score and benin_score > 0:
|
| 93 |
+
return "benin"
|
| 94 |
+
elif madagascar_score > benin_score and madagascar_score > 0:
|
| 95 |
+
return "madagascar"
|
| 96 |
+
|
| 97 |
+
return "unknown"
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class RoutingResult(BaseModel):
|
| 101 |
+
country: Literal["benin", "madagascar", "unclear", "greeting_small_talk",
|
| 102 |
+
"conversation_repair", "assistance_request", "conversation_summarization", "out_of_scope"]
|
| 103 |
+
confidence: Literal["high", "medium", "low"]
|
| 104 |
+
method: str
|
| 105 |
+
explanation: str
|
| 106 |
+
|
| 107 |
+
class SearchResult(BaseModel):
|
| 108 |
+
documents: List[Any]
|
| 109 |
+
detected_articles: List[str]
|
| 110 |
+
applied_filters: Dict[str, Any]
|
| 111 |
+
query: str
|
| 112 |
+
country: str
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
|
| 2 |
+
fastapi==0.104.1
|
| 3 |
+
uvicorn[standard]==0.24.0
|
| 4 |
+
langchain-openai==0.0.8
|
| 5 |
+
langchain-core==0.1.33
|
| 6 |
+
langgraph==0.0.52
|
| 7 |
+
langchain-mongodb==0.0.3
|
| 8 |
+
pymongo==4.6.1
|
| 9 |
+
openai==1.3.9
|
| 10 |
+
pydantic==2.5.0
|
| 11 |
+
python-dotenv==1.0.0
|
| 12 |
+
psycopg[binary]==3.1.13
|
| 13 |
+
langgraph-checkpoint-postgres==0.0.3
|
utils/__init__.py
ADDED
|
File without changes
|
utils/helpers.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any
|
| 2 |
+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
|
| 3 |
+
|
| 4 |
+
def dict_to_message_obj(d: Dict[str, Any]) -> BaseMessage:
|
| 5 |
+
"""Convert dictionary to LangChain message object"""
|
| 6 |
+
role = d.get("role", "").lower()
|
| 7 |
+
content = d.get("content", "")
|
| 8 |
+
meta = d.get("meta", {}) or {}
|
| 9 |
+
|
| 10 |
+
if role in ("user", "human", "humanmessage"):
|
| 11 |
+
return HumanMessage(content=content, metadata=meta)
|
| 12 |
+
if role in ("assistant", "ai", "aimessage"):
|
| 13 |
+
return AIMessage(content=content, metadata=meta)
|
| 14 |
+
return SystemMessage(content=content, metadata=meta)
|
| 15 |
+
|
| 16 |
+
def message_obj_to_dict(msg: Any) -> Dict[str, Any]:
|
| 17 |
+
"""Convert LangChain message object to dictionary"""
|
| 18 |
+
content = getattr(msg, "content", str(msg))
|
| 19 |
+
meta = getattr(msg, "metadata", {}) or {}
|
| 20 |
+
|
| 21 |
+
if isinstance(msg, HumanMessage):
|
| 22 |
+
role = "user"
|
| 23 |
+
elif isinstance(msg, AIMessage):
|
| 24 |
+
role = "assistant"
|
| 25 |
+
elif isinstance(msg, SystemMessage):
|
| 26 |
+
role = "system"
|
| 27 |
+
else:
|
| 28 |
+
role = meta.get("role", "assistant")
|
| 29 |
+
|
| 30 |
+
return {"role": role, "content": content, "meta": meta}
|
| 31 |
+
|
| 32 |
+
def validate_country_code(country: str) -> str:
|
| 33 |
+
"""Validate and normalize country code"""
|
| 34 |
+
country = country.lower().strip()
|
| 35 |
+
if country in ["benin", "bj", "bénin"]:
|
| 36 |
+
return "benin"
|
| 37 |
+
elif country in ["madagascar", "mg", "madagasikara"]:
|
| 38 |
+
return "madagascar"
|
| 39 |
+
else:
|
| 40 |
+
return "unclear"
|
| 41 |
+
|
| 42 |
+
def format_legal_citation(article_number: str, law_title: str, country: str) -> str:
|
| 43 |
+
"""Format legal citation in standard format"""
|
| 44 |
+
country_formats = {
|
| 45 |
+
"benin": f"Article {article_number} du {law_title} (Bénin)",
|
| 46 |
+
"madagascar": f"Article {article_number} du {law_title} (Madagascar)"
|
| 47 |
+
}
|
| 48 |
+
return country_formats.get(country, f"Article {article_number} du {law_title}")
|
| 49 |
+
|
| 50 |
+
def safe_get(dictionary: Dict, key: str, default: Any = None) -> Any:
|
| 51 |
+
"""Safely get value from dictionary with default"""
|
| 52 |
+
if isinstance(dictionary, dict):
|
| 53 |
+
return dictionary.get(key, default)
|
| 54 |
+
return default
|
| 55 |
+
|
| 56 |
+
def truncate_text(text: str, max_length: int = 500) -> str:
|
| 57 |
+
"""Truncate text to specified length"""
|
| 58 |
+
if len(text) <= max_length:
|
| 59 |
+
return text
|
| 60 |
+
return text[:max_length] + "..."
|
| 61 |
+
|
| 62 |
+
def calculate_confidence_score(patterns_found: int, llm_confidence: str) -> float:
|
| 63 |
+
"""Calculate a numerical confidence score"""
|
| 64 |
+
pattern_score = min(patterns_found * 0.3, 0.6) # Max 0.6 from patterns
|
| 65 |
+
llm_scores = {"high": 0.8, "medium": 0.5, "low": 0.2}
|
| 66 |
+
llm_score = llm_scores.get(llm_confidence, 0.2)
|
| 67 |
+
|
| 68 |
+
return min(pattern_score + llm_score, 1.0)
|
utils/logger.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import sys
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import Dict, Any
|
| 5 |
+
|
| 6 |
+
def setup_logging(level=logging.INFO):
|
| 7 |
+
"""Setup comprehensive logging configuration"""
|
| 8 |
+
|
| 9 |
+
# Create formatter
|
| 10 |
+
formatter = logging.Formatter(
|
| 11 |
+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 12 |
+
datefmt='%Y-%m-%d %H:%M:%S'
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# Console handler
|
| 16 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
| 17 |
+
console_handler.setLevel(level)
|
| 18 |
+
console_handler.setFormatter(formatter)
|
| 19 |
+
|
| 20 |
+
# File handler
|
| 21 |
+
file_handler = logging.FileHandler(f'legal_rag_{datetime.now().strftime("%Y%m%d")}.log')
|
| 22 |
+
file_handler.setLevel(logging.DEBUG)
|
| 23 |
+
file_handler.setFormatter(formatter)
|
| 24 |
+
|
| 25 |
+
# Configure root logger
|
| 26 |
+
logging.basicConfig(
|
| 27 |
+
level=level,
|
| 28 |
+
handlers=[console_handler, file_handler],
|
| 29 |
+
force=True
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# Specific logger configurations
|
| 33 |
+
legal_logger = logging.getLogger('legal_rag')
|
| 34 |
+
legal_logger.setLevel(logging.DEBUG)
|
| 35 |
+
|
| 36 |
+
mongodb_logger = logging.getLogger('pymongo')
|
| 37 |
+
mongodb_logger.setLevel(logging.WARNING)
|
| 38 |
+
|
| 39 |
+
print("✅ Logging setup completed")
|
| 40 |
+
|
| 41 |
+
class PerformanceLogger:
|
| 42 |
+
"""Logger for performance monitoring"""
|
| 43 |
+
|
| 44 |
+
def __init__(self):
|
| 45 |
+
self.metrics = {
|
| 46 |
+
"query_times": [],
|
| 47 |
+
"routing_times": [],
|
| 48 |
+
"retrieval_times": [],
|
| 49 |
+
"generation_times": []
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
def log_query_time(self, session_id: str, duration: float):
|
| 53 |
+
"""Log query processing time"""
|
| 54 |
+
self.metrics["query_times"].append({
|
| 55 |
+
"session_id": session_id,
|
| 56 |
+
"duration": duration,
|
| 57 |
+
"timestamp": datetime.now()
|
| 58 |
+
})
|
| 59 |
+
logging.info(f"Query processed in {duration:.2f}s for session {session_id}")
|
| 60 |
+
|
| 61 |
+
def log_routing_decision(self, session_id: str, decision: str, confidence: str, method: str):
|
| 62 |
+
"""Log routing decisions"""
|
| 63 |
+
logging.debug(f"Routing: session={session_id}, decision={decision}, confidence={confidence}, method={method}")
|
| 64 |
+
|
| 65 |
+
def get_performance_report(self) -> Dict[str, Any]:
|
| 66 |
+
"""Generate performance report"""
|
| 67 |
+
query_times = [m["duration"] for m in self.metrics["query_times"]]
|
| 68 |
+
|
| 69 |
+
return {
|
| 70 |
+
"total_queries": len(query_times),
|
| 71 |
+
"average_query_time": sum(query_times) / len(query_times) if query_times else 0,
|
| 72 |
+
"max_query_time": max(query_times) if query_times else 0,
|
| 73 |
+
"min_query_time": min(query_times) if query_times else 0
|
| 74 |
+
}
|