Spaces:
Sleeping
Sleeping
SAAHMATHWORKS
commited on
Commit
·
f37bf1d
1
Parent(s):
8badae9
ready for hugging face space
Browse files- api/config.py +32 -0
- api/main.py +24 -689
- api/models/__init__.py +4 -0
- api/models/schemas.py +22 -0
- api/routes/__init__.py +7 -0
- api/routes/chat.py +48 -0
- api/routes/health.py +16 -0
- api/routes/home.py +521 -0
- api/routes/sessions.py +56 -0
- api/services/__init__.py +6 -0
- api/services/chat_service.py +178 -0
- api/services/stream_service.py +66 -0
- api/utils/__init__.py +4 -0
- api/utils/startup.py +37 -0
- config/settings.py +1 -1
- core/chat_manager.py +272 -65
- core/graph_builder.py +18 -4
- core/human_approval_node.py +44 -80
- core/system_initializer.py +62 -30
- main.py +414 -564
- requirements.txt +1 -0
api/config.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/config.py
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
# Setup logging
|
| 7 |
+
logging.basicConfig(
|
| 8 |
+
level=logging.INFO,
|
| 9 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 10 |
+
)
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Settings:
|
| 15 |
+
"""Application settings"""
|
| 16 |
+
APP_TITLE: str = "Legal Assistant API"
|
| 17 |
+
APP_VERSION: str = "2.0.0"
|
| 18 |
+
APP_DESCRIPTION: str = "Multi-country legal RAG with streaming & human-in-the-loop"
|
| 19 |
+
|
| 20 |
+
# CORS
|
| 21 |
+
CORS_ORIGINS: list = ["*"]
|
| 22 |
+
|
| 23 |
+
# API Settings
|
| 24 |
+
STREAM_DELAY: float = 0.02 # Delay between tokens in streaming
|
| 25 |
+
|
| 26 |
+
# System
|
| 27 |
+
chat_manager: Optional[object] = None
|
| 28 |
+
graph: Optional[object] = None
|
| 29 |
+
system_initialized: bool = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
settings = Settings()
|
api/main.py
CHANGED
|
@@ -1,704 +1,39 @@
|
|
| 1 |
-
# api/main.py
|
| 2 |
-
import
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 5 |
-
|
| 6 |
-
from typing import Optional, Any, Dict, AsyncGenerator
|
| 7 |
-
from contextlib import asynccontextmanager
|
| 8 |
-
from fastapi import FastAPI, Query, HTTPException, Body
|
| 9 |
-
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
-
from
|
| 12 |
-
import
|
| 13 |
-
from
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
# Import your existing system
|
| 20 |
-
from core.system_initializer import setup_system
|
| 21 |
-
|
| 22 |
-
# Setup logging
|
| 23 |
-
logging.basicConfig(level=logging.INFO)
|
| 24 |
-
logger = logging.getLogger(__name__)
|
| 25 |
-
|
| 26 |
-
# Global variables
|
| 27 |
-
chat_manager = None
|
| 28 |
-
graph = None
|
| 29 |
-
system_initialized = False
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
# ============================================================================
|
| 33 |
-
# Pydantic Models
|
| 34 |
-
# ============================================================================
|
| 35 |
-
class ChatRequest(BaseModel):
|
| 36 |
-
message: str
|
| 37 |
-
session_id: Optional[str] = None
|
| 38 |
-
stream: bool = False # Option to enable/disable streaming
|
| 39 |
-
|
| 40 |
-
class ApprovalRequest(BaseModel):
|
| 41 |
-
decision: str # "approve" or "reject"
|
| 42 |
-
reason: Optional[str] = None
|
| 43 |
-
|
| 44 |
-
class ChatResponse(BaseModel):
|
| 45 |
-
response: str
|
| 46 |
-
session_id: str
|
| 47 |
-
has_interrupt: bool = False
|
| 48 |
-
interrupt_type: Optional[str] = None
|
| 49 |
-
interrupt_data: Optional[Dict] = None
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
# ============================================================================
|
| 53 |
-
# System Initialization
|
| 54 |
-
# ============================================================================
|
| 55 |
-
async def initialize_system():
|
| 56 |
-
global chat_manager, graph, system_initialized
|
| 57 |
-
try:
|
| 58 |
-
system = await setup_system()
|
| 59 |
-
chat_manager = system["chat_manager"]
|
| 60 |
-
graph = system["graph"]
|
| 61 |
-
system_initialized = True
|
| 62 |
-
logger.info("✅ Legal assistant system initialized")
|
| 63 |
-
except Exception as e:
|
| 64 |
-
logger.error(f"❌ Failed to initialize system: {e}")
|
| 65 |
-
system_initialized = False
|
| 66 |
-
|
| 67 |
-
@asynccontextmanager
|
| 68 |
-
async def lifespan(app: FastAPI):
|
| 69 |
-
logger.info("🚀 Starting Legal Assistant API...")
|
| 70 |
-
initialization_task = asyncio.create_task(initialize_system())
|
| 71 |
-
yield
|
| 72 |
-
logger.info("🛑 Shutting down...")
|
| 73 |
-
initialization_task.cancel()
|
| 74 |
-
try:
|
| 75 |
-
await initialization_task
|
| 76 |
-
except asyncio.CancelledError:
|
| 77 |
-
pass
|
| 78 |
-
|
| 79 |
|
| 80 |
-
#
|
| 81 |
-
# FastAPI App
|
| 82 |
-
# ============================================================================
|
| 83 |
app = FastAPI(
|
| 84 |
-
title=
|
| 85 |
-
version=
|
| 86 |
-
description=
|
| 87 |
lifespan=lifespan
|
| 88 |
)
|
| 89 |
|
|
|
|
| 90 |
app.add_middleware(
|
| 91 |
CORSMiddleware,
|
| 92 |
-
allow_origins=
|
| 93 |
allow_credentials=True,
|
| 94 |
allow_methods=["*"],
|
| 95 |
allow_headers=["*"],
|
| 96 |
)
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
# ============================================================================
|
| 100 |
-
# Streaming Helper
|
| 101 |
-
# ============================================================================
|
| 102 |
-
async def stream_chat_response(
|
| 103 |
-
message: str,
|
| 104 |
-
session_id: str
|
| 105 |
-
) -> AsyncGenerator[str, None]:
|
| 106 |
-
"""
|
| 107 |
-
Stream chat responses token by token.
|
| 108 |
-
If interrupt occurs, yields special interrupt event.
|
| 109 |
-
"""
|
| 110 |
-
try:
|
| 111 |
-
# Check for pending interrupt first
|
| 112 |
-
if chat_manager.has_pending_interrupt(session_id):
|
| 113 |
-
interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
|
| 114 |
-
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 115 |
-
|
| 116 |
-
if hasattr(interrupt_data_obj, 'value'):
|
| 117 |
-
interrupt_value = interrupt_data_obj.value
|
| 118 |
-
else:
|
| 119 |
-
interrupt_value = {}
|
| 120 |
-
|
| 121 |
-
# Yield interrupt event
|
| 122 |
-
yield f"data: {json.dumps({'type': 'interrupt', 'data': interrupt_value})}\n\n"
|
| 123 |
-
return
|
| 124 |
-
|
| 125 |
-
# Buffer to collect response
|
| 126 |
-
response_buffer = []
|
| 127 |
-
|
| 128 |
-
# Process message
|
| 129 |
-
response = await chat_manager.chat(
|
| 130 |
-
message=message,
|
| 131 |
-
session_id=session_id,
|
| 132 |
-
interrupt_handler=None # Async mode
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
# Check if interrupt occurred during processing
|
| 136 |
-
if chat_manager.has_pending_interrupt(session_id):
|
| 137 |
-
interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
|
| 138 |
-
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 139 |
-
|
| 140 |
-
if hasattr(interrupt_data_obj, 'value'):
|
| 141 |
-
interrupt_value = interrupt_data_obj.value
|
| 142 |
-
else:
|
| 143 |
-
interrupt_value = {}
|
| 144 |
-
|
| 145 |
-
# Yield interrupt event
|
| 146 |
-
yield f"data: {json.dumps({'type': 'interrupt', 'data': interrupt_value})}\n\n"
|
| 147 |
-
return
|
| 148 |
-
|
| 149 |
-
# Stream response token by token (simulate streaming)
|
| 150 |
-
words = response.split()
|
| 151 |
-
for i, word in enumerate(words):
|
| 152 |
-
token = word + (" " if i < len(words) - 1 else "")
|
| 153 |
-
yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
|
| 154 |
-
await asyncio.sleep(0.02) # Small delay for visual effect
|
| 155 |
-
|
| 156 |
-
# Send completion event
|
| 157 |
-
yield f"data: {json.dumps({'type': 'done', 'session_id': session_id})}\n\n"
|
| 158 |
-
|
| 159 |
-
except Exception as e:
|
| 160 |
-
logger.exception(f"Error in streaming: {e}")
|
| 161 |
-
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
# ============================================================================
|
| 165 |
-
# Routes
|
| 166 |
-
# ============================================================================
|
| 167 |
-
@app.get("/", response_class=HTMLResponse)
|
| 168 |
-
async def read_root():
|
| 169 |
-
"""Interactive chat interface with streaming support"""
|
| 170 |
-
return """
|
| 171 |
-
<html>
|
| 172 |
-
<head>
|
| 173 |
-
<title>Legal Assistant - Interactive Chat</title>
|
| 174 |
-
<meta charset="UTF-8">
|
| 175 |
-
<style>
|
| 176 |
-
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 177 |
-
body {
|
| 178 |
-
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
| 179 |
-
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 180 |
-
min-height: 100vh;
|
| 181 |
-
padding: 20px;
|
| 182 |
-
}
|
| 183 |
-
.container {
|
| 184 |
-
max-width: 900px;
|
| 185 |
-
margin: 0 auto;
|
| 186 |
-
background: white;
|
| 187 |
-
border-radius: 16px;
|
| 188 |
-
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
|
| 189 |
-
overflow: hidden;
|
| 190 |
-
}
|
| 191 |
-
.header {
|
| 192 |
-
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 193 |
-
color: white;
|
| 194 |
-
padding: 30px;
|
| 195 |
-
text-align: center;
|
| 196 |
-
}
|
| 197 |
-
.header h1 { margin-bottom: 10px; }
|
| 198 |
-
.header p { opacity: 0.9; }
|
| 199 |
-
.chat-container { padding: 20px; }
|
| 200 |
-
#messages {
|
| 201 |
-
height: 500px;
|
| 202 |
-
overflow-y: auto;
|
| 203 |
-
padding: 20px;
|
| 204 |
-
background: #f8f9fa;
|
| 205 |
-
border-radius: 12px;
|
| 206 |
-
margin-bottom: 20px;
|
| 207 |
-
}
|
| 208 |
-
.message {
|
| 209 |
-
margin: 15px 0;
|
| 210 |
-
padding: 12px 18px;
|
| 211 |
-
border-radius: 18px;
|
| 212 |
-
max-width: 75%;
|
| 213 |
-
animation: slideIn 0.3s ease;
|
| 214 |
-
word-wrap: break-word;
|
| 215 |
-
}
|
| 216 |
-
@keyframes slideIn {
|
| 217 |
-
from { opacity: 0; transform: translateY(10px); }
|
| 218 |
-
to { opacity: 1; transform: translateY(0); }
|
| 219 |
-
}
|
| 220 |
-
.user-message {
|
| 221 |
-
background: #667eea;
|
| 222 |
-
color: white;
|
| 223 |
-
margin-left: auto;
|
| 224 |
-
border-bottom-right-radius: 4px;
|
| 225 |
-
}
|
| 226 |
-
.assistant-message {
|
| 227 |
-
background: white;
|
| 228 |
-
color: #333;
|
| 229 |
-
border: 1px solid #e0e0e0;
|
| 230 |
-
border-bottom-left-radius: 4px;
|
| 231 |
-
}
|
| 232 |
-
.system-message {
|
| 233 |
-
background: #fff3cd;
|
| 234 |
-
color: #856404;
|
| 235 |
-
border: 1px solid #ffc107;
|
| 236 |
-
text-align: center;
|
| 237 |
-
margin: 10px auto;
|
| 238 |
-
max-width: 90%;
|
| 239 |
-
}
|
| 240 |
-
.interrupt-panel {
|
| 241 |
-
background: #fff3cd;
|
| 242 |
-
border: 2px solid #ffc107;
|
| 243 |
-
border-radius: 12px;
|
| 244 |
-
padding: 20px;
|
| 245 |
-
margin: 20px 0;
|
| 246 |
-
animation: pulse 2s infinite;
|
| 247 |
-
}
|
| 248 |
-
@keyframes pulse {
|
| 249 |
-
0%, 100% { box-shadow: 0 0 0 0 rgba(255, 193, 7, 0.4); }
|
| 250 |
-
50% { box-shadow: 0 0 0 10px rgba(255, 193, 7, 0); }
|
| 251 |
-
}
|
| 252 |
-
.interrupt-panel h3 {
|
| 253 |
-
color: #856404;
|
| 254 |
-
margin-bottom: 15px;
|
| 255 |
-
}
|
| 256 |
-
.interrupt-details {
|
| 257 |
-
background: white;
|
| 258 |
-
padding: 15px;
|
| 259 |
-
border-radius: 8px;
|
| 260 |
-
margin: 15px 0;
|
| 261 |
-
}
|
| 262 |
-
.interrupt-details p {
|
| 263 |
-
margin: 8px 0;
|
| 264 |
-
color: #333;
|
| 265 |
-
}
|
| 266 |
-
.approval-buttons {
|
| 267 |
-
display: flex;
|
| 268 |
-
gap: 10px;
|
| 269 |
-
margin-top: 15px;
|
| 270 |
-
}
|
| 271 |
-
.btn {
|
| 272 |
-
flex: 1;
|
| 273 |
-
padding: 12px 24px;
|
| 274 |
-
border: none;
|
| 275 |
-
border-radius: 8px;
|
| 276 |
-
font-weight: bold;
|
| 277 |
-
cursor: pointer;
|
| 278 |
-
transition: all 0.3s;
|
| 279 |
-
font-size: 14px;
|
| 280 |
-
}
|
| 281 |
-
.btn-approve {
|
| 282 |
-
background: #4caf50;
|
| 283 |
-
color: white;
|
| 284 |
-
}
|
| 285 |
-
.btn-approve:hover {
|
| 286 |
-
background: #45a049;
|
| 287 |
-
transform: translateY(-2px);
|
| 288 |
-
box-shadow: 0 4px 12px rgba(76, 175, 80, 0.4);
|
| 289 |
-
}
|
| 290 |
-
.btn-reject {
|
| 291 |
-
background: #f44336;
|
| 292 |
-
color: white;
|
| 293 |
-
}
|
| 294 |
-
.btn-reject:hover {
|
| 295 |
-
background: #da190b;
|
| 296 |
-
transform: translateY(-2px);
|
| 297 |
-
box-shadow: 0 4px 12px rgba(244, 67, 54, 0.4);
|
| 298 |
-
}
|
| 299 |
-
.input-area {
|
| 300 |
-
display: flex;
|
| 301 |
-
gap: 10px;
|
| 302 |
-
padding: 10px;
|
| 303 |
-
background: #f8f9fa;
|
| 304 |
-
border-radius: 12px;
|
| 305 |
-
}
|
| 306 |
-
#message-input {
|
| 307 |
-
flex: 1;
|
| 308 |
-
padding: 12px 16px;
|
| 309 |
-
border: 2px solid #e0e0e0;
|
| 310 |
-
border-radius: 8px;
|
| 311 |
-
font-size: 14px;
|
| 312 |
-
transition: border 0.3s;
|
| 313 |
-
}
|
| 314 |
-
#message-input:focus {
|
| 315 |
-
outline: none;
|
| 316 |
-
border-color: #667eea;
|
| 317 |
-
}
|
| 318 |
-
#send-btn {
|
| 319 |
-
padding: 12px 32px;
|
| 320 |
-
background: #667eea;
|
| 321 |
-
color: white;
|
| 322 |
-
border: none;
|
| 323 |
-
border-radius: 8px;
|
| 324 |
-
font-weight: bold;
|
| 325 |
-
cursor: pointer;
|
| 326 |
-
transition: all 0.3s;
|
| 327 |
-
}
|
| 328 |
-
#send-btn:hover:not(:disabled) {
|
| 329 |
-
background: #5568d3;
|
| 330 |
-
transform: translateY(-2px);
|
| 331 |
-
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
| 332 |
-
}
|
| 333 |
-
#send-btn:disabled {
|
| 334 |
-
background: #ccc;
|
| 335 |
-
cursor: not-allowed;
|
| 336 |
-
transform: none;
|
| 337 |
-
}
|
| 338 |
-
.status-bar {
|
| 339 |
-
padding: 10px;
|
| 340 |
-
background: #f8f9fa;
|
| 341 |
-
border-radius: 8px;
|
| 342 |
-
margin-top: 10px;
|
| 343 |
-
display: flex;
|
| 344 |
-
justify-content: space-between;
|
| 345 |
-
font-size: 12px;
|
| 346 |
-
color: #666;
|
| 347 |
-
}
|
| 348 |
-
.typing-indicator {
|
| 349 |
-
display: none;
|
| 350 |
-
padding: 10px;
|
| 351 |
-
color: #666;
|
| 352 |
-
font-style: italic;
|
| 353 |
-
}
|
| 354 |
-
.typing-indicator.active {
|
| 355 |
-
display: block;
|
| 356 |
-
}
|
| 357 |
-
</style>
|
| 358 |
-
</head>
|
| 359 |
-
<body>
|
| 360 |
-
<div class="container">
|
| 361 |
-
<div class="header">
|
| 362 |
-
<h1>🧑⚖️ Legal Assistant</h1>
|
| 363 |
-
<p>Multi-country legal support for Benin & Madagascar</p>
|
| 364 |
-
</div>
|
| 365 |
-
|
| 366 |
-
<div class="chat-container">
|
| 367 |
-
<div id="messages"></div>
|
| 368 |
-
<div id="interrupt-zone"></div>
|
| 369 |
-
<div class="typing-indicator" id="typing">Assistant is typing...</div>
|
| 370 |
-
|
| 371 |
-
<div class="input-area">
|
| 372 |
-
<input
|
| 373 |
-
type="text"
|
| 374 |
-
id="message-input"
|
| 375 |
-
placeholder="Type your legal question..."
|
| 376 |
-
autocomplete="off"
|
| 377 |
-
/>
|
| 378 |
-
<button id="send-btn" onclick="sendMessage()">
|
| 379 |
-
Send
|
| 380 |
-
</button>
|
| 381 |
-
</div>
|
| 382 |
-
|
| 383 |
-
<div class="status-bar">
|
| 384 |
-
<span>Session: <strong id="session-id">Not started</strong></span>
|
| 385 |
-
<span id="status">System ready</span>
|
| 386 |
-
</div>
|
| 387 |
-
</div>
|
| 388 |
-
</div>
|
| 389 |
-
|
| 390 |
-
<script>
|
| 391 |
-
let sessionId = null;
|
| 392 |
-
let currentAssistantMessage = null;
|
| 393 |
-
|
| 394 |
-
function addMessage(content, type) {
|
| 395 |
-
const messagesDiv = document.getElementById('messages');
|
| 396 |
-
const messageDiv = document.createElement('div');
|
| 397 |
-
messageDiv.className = `message ${type}-message`;
|
| 398 |
-
messageDiv.textContent = content;
|
| 399 |
-
messagesDiv.appendChild(messageDiv);
|
| 400 |
-
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
| 401 |
-
return messageDiv;
|
| 402 |
-
}
|
| 403 |
-
|
| 404 |
-
function showInterrupt(data) {
|
| 405 |
-
const interruptZone = document.getElementById('interrupt-zone');
|
| 406 |
-
interruptZone.innerHTML = `
|
| 407 |
-
<div class="interrupt-panel">
|
| 408 |
-
<h3>🔒 Human Approval Required</h3>
|
| 409 |
-
<div class="interrupt-details">
|
| 410 |
-
<p><strong>📧 Email:</strong> ${data.user_email || 'N/A'}</p>
|
| 411 |
-
<p><strong>🌍 Country:</strong> ${data.country || 'N/A'}</p>
|
| 412 |
-
<p><strong>📝 Description:</strong> ${data.description || 'N/A'}</p>
|
| 413 |
-
</div>
|
| 414 |
-
<div class="approval-buttons">
|
| 415 |
-
<button class="btn btn-approve" onclick="handleApproval('approve')">
|
| 416 |
-
✅ Approve Request
|
| 417 |
-
</button>
|
| 418 |
-
<button class="btn btn-reject" onclick="handleApproval('reject')">
|
| 419 |
-
❌ Reject Request
|
| 420 |
-
</button>
|
| 421 |
-
</div>
|
| 422 |
-
</div>
|
| 423 |
-
`;
|
| 424 |
-
}
|
| 425 |
-
|
| 426 |
-
function hideInterrupt() {
|
| 427 |
-
document.getElementById('interrupt-zone').innerHTML = '';
|
| 428 |
-
}
|
| 429 |
-
|
| 430 |
-
async function handleApproval(decision) {
|
| 431 |
-
const reason = decision === 'approve'
|
| 432 |
-
? 'Approved via web interface'
|
| 433 |
-
: 'Rejected via web interface';
|
| 434 |
-
|
| 435 |
-
try {
|
| 436 |
-
addMessage(`${decision === 'approve' ? '✅' : '❌'} Decision: ${decision}`, 'system');
|
| 437 |
-
hideInterrupt();
|
| 438 |
-
|
| 439 |
-
const response = await fetch(`/sessions/${sessionId}/approve`, {
|
| 440 |
-
method: 'POST',
|
| 441 |
-
headers: { 'Content-Type': 'application/json' },
|
| 442 |
-
body: JSON.stringify({ decision, reason })
|
| 443 |
-
});
|
| 444 |
-
|
| 445 |
-
const data = await response.json();
|
| 446 |
-
addMessage(data.response, 'assistant');
|
| 447 |
-
|
| 448 |
-
} catch (error) {
|
| 449 |
-
addMessage('Error: ' + error.message, 'system');
|
| 450 |
-
}
|
| 451 |
-
}
|
| 452 |
-
|
| 453 |
-
async function sendMessage() {
|
| 454 |
-
const input = document.getElementById('message-input');
|
| 455 |
-
const message = input.value.trim();
|
| 456 |
-
|
| 457 |
-
if (!message) return;
|
| 458 |
-
|
| 459 |
-
// Initialize session
|
| 460 |
-
if (!sessionId) {
|
| 461 |
-
sessionId = 'web_' + Date.now();
|
| 462 |
-
document.getElementById('session-id').textContent = sessionId;
|
| 463 |
-
}
|
| 464 |
-
|
| 465 |
-
// Add user message
|
| 466 |
-
addMessage(message, 'user');
|
| 467 |
-
input.value = '';
|
| 468 |
-
|
| 469 |
-
// Disable input
|
| 470 |
-
const sendBtn = document.getElementById('send-btn');
|
| 471 |
-
sendBtn.disabled = true;
|
| 472 |
-
input.disabled = true;
|
| 473 |
-
|
| 474 |
-
// Show typing indicator
|
| 475 |
-
document.getElementById('typing').classList.add('active');
|
| 476 |
-
|
| 477 |
-
try {
|
| 478 |
-
// Use streaming endpoint
|
| 479 |
-
const response = await fetch('/chat/stream', {
|
| 480 |
-
method: 'POST',
|
| 481 |
-
headers: { 'Content-Type': 'application/json' },
|
| 482 |
-
body: JSON.stringify({
|
| 483 |
-
message: message,
|
| 484 |
-
session_id: sessionId,
|
| 485 |
-
stream: true
|
| 486 |
-
})
|
| 487 |
-
});
|
| 488 |
-
|
| 489 |
-
const reader = response.body.getReader();
|
| 490 |
-
const decoder = new TextDecoder();
|
| 491 |
-
currentAssistantMessage = addMessage('', 'assistant');
|
| 492 |
-
|
| 493 |
-
while (true) {
|
| 494 |
-
const { value, done } = await reader.read();
|
| 495 |
-
if (done) break;
|
| 496 |
-
|
| 497 |
-
const chunk = decoder.decode(value);
|
| 498 |
-
const lines = chunk.split('\\n\\n');
|
| 499 |
-
|
| 500 |
-
for (const line of lines) {
|
| 501 |
-
if (line.startsWith('data: ')) {
|
| 502 |
-
try {
|
| 503 |
-
const data = JSON.parse(line.slice(6));
|
| 504 |
-
|
| 505 |
-
if (data.type === 'token') {
|
| 506 |
-
currentAssistantMessage.textContent += data.content;
|
| 507 |
-
} else if (data.type === 'interrupt') {
|
| 508 |
-
showInterrupt(data.data);
|
| 509 |
-
} else if (data.type === 'done') {
|
| 510 |
-
document.getElementById('status').textContent = 'Ready';
|
| 511 |
-
} else if (data.type === 'error') {
|
| 512 |
-
addMessage('Error: ' + data.message, 'system');
|
| 513 |
-
}
|
| 514 |
-
} catch (e) {
|
| 515 |
-
console.error('Parse error:', e);
|
| 516 |
-
}
|
| 517 |
-
}
|
| 518 |
-
}
|
| 519 |
-
}
|
| 520 |
-
|
| 521 |
-
} catch (error) {
|
| 522 |
-
addMessage('Error: ' + error.message, 'system');
|
| 523 |
-
} finally {
|
| 524 |
-
sendBtn.disabled = false;
|
| 525 |
-
input.disabled = false;
|
| 526 |
-
input.focus();
|
| 527 |
-
document.getElementById('typing').classList.remove('active');
|
| 528 |
-
}
|
| 529 |
-
}
|
| 530 |
-
|
| 531 |
-
// Enter key to send
|
| 532 |
-
document.getElementById('message-input').addEventListener('keypress', (e) => {
|
| 533 |
-
if (e.key === 'Enter' && !e.shiftKey) {
|
| 534 |
-
e.preventDefault();
|
| 535 |
-
sendMessage();
|
| 536 |
-
}
|
| 537 |
-
});
|
| 538 |
-
|
| 539 |
-
// Focus input on load
|
| 540 |
-
document.getElementById('message-input').focus();
|
| 541 |
-
</script>
|
| 542 |
-
</body>
|
| 543 |
-
</html>
|
| 544 |
-
"""
|
| 545 |
-
|
| 546 |
-
@app.get("/health")
|
| 547 |
-
async def health_check():
|
| 548 |
-
"""Health check endpoint"""
|
| 549 |
-
return {
|
| 550 |
-
"status": "healthy" if system_initialized else "starting",
|
| 551 |
-
"system_initialized": system_initialized,
|
| 552 |
-
"timestamp": datetime.now().isoformat()
|
| 553 |
-
}
|
| 554 |
-
|
| 555 |
-
@app.post("/chat")
|
| 556 |
-
async def chat_simple(request: ChatRequest):
|
| 557 |
-
"""
|
| 558 |
-
Simple non-streaming chat endpoint.
|
| 559 |
-
Best for handling interrupts - returns complete response.
|
| 560 |
-
"""
|
| 561 |
-
if not system_initialized or not chat_manager:
|
| 562 |
-
raise HTTPException(status_code=503, detail="System initializing...")
|
| 563 |
-
|
| 564 |
-
try:
|
| 565 |
-
session_id = request.session_id or f"api_{uuid4()}"
|
| 566 |
-
|
| 567 |
-
# Check for pending interrupt
|
| 568 |
-
if chat_manager.has_pending_interrupt(session_id):
|
| 569 |
-
interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
|
| 570 |
-
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 571 |
-
|
| 572 |
-
if hasattr(interrupt_data_obj, 'value'):
|
| 573 |
-
interrupt_value = interrupt_data_obj.value
|
| 574 |
-
else:
|
| 575 |
-
interrupt_value = {}
|
| 576 |
-
|
| 577 |
-
return ChatResponse(
|
| 578 |
-
response="⏸️ Pending approval request. Please approve/reject first.",
|
| 579 |
-
session_id=session_id,
|
| 580 |
-
has_interrupt=True,
|
| 581 |
-
interrupt_type="human_approval",
|
| 582 |
-
interrupt_data=interrupt_value
|
| 583 |
-
)
|
| 584 |
-
|
| 585 |
-
# Process message
|
| 586 |
-
response_text = await chat_manager.chat(
|
| 587 |
-
message=request.message,
|
| 588 |
-
session_id=session_id
|
| 589 |
-
)
|
| 590 |
-
|
| 591 |
-
# Check if interrupt occurred
|
| 592 |
-
has_interrupt = chat_manager.has_pending_interrupt(session_id)
|
| 593 |
-
interrupt_value = None
|
| 594 |
-
|
| 595 |
-
if has_interrupt:
|
| 596 |
-
interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
|
| 597 |
-
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 598 |
-
|
| 599 |
-
if hasattr(interrupt_data_obj, 'value'):
|
| 600 |
-
interrupt_value = interrupt_data_obj.value
|
| 601 |
-
elif isinstance(interrupt_data_obj, dict):
|
| 602 |
-
interrupt_value = interrupt_data_obj.get("value", {})
|
| 603 |
-
|
| 604 |
-
return ChatResponse(
|
| 605 |
-
response=response_text,
|
| 606 |
-
session_id=session_id,
|
| 607 |
-
has_interrupt=has_interrupt,
|
| 608 |
-
interrupt_type="human_approval" if has_interrupt else None,
|
| 609 |
-
interrupt_data=interrupt_value
|
| 610 |
-
)
|
| 611 |
-
|
| 612 |
-
except Exception as e:
|
| 613 |
-
logger.exception(f"Error in chat: {e}")
|
| 614 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 615 |
-
|
| 616 |
-
@app.post("/chat/stream")
|
| 617 |
-
async def chat_stream(request: ChatRequest):
|
| 618 |
-
"""
|
| 619 |
-
Streaming chat endpoint using Server-Sent Events (SSE).
|
| 620 |
-
Streams tokens in real-time and handles interrupts.
|
| 621 |
-
"""
|
| 622 |
-
if not system_initialized or not chat_manager:
|
| 623 |
-
raise HTTPException(status_code=503, detail="System initializing...")
|
| 624 |
-
|
| 625 |
-
session_id = request.session_id or f"api_{uuid4()}"
|
| 626 |
-
|
| 627 |
-
return StreamingResponse(
|
| 628 |
-
stream_chat_response(request.message, session_id),
|
| 629 |
-
media_type="text/event-stream",
|
| 630 |
-
headers={
|
| 631 |
-
"Cache-Control": "no-cache",
|
| 632 |
-
"Connection": "keep-alive",
|
| 633 |
-
"X-Accel-Buffering": "no"
|
| 634 |
-
}
|
| 635 |
-
)
|
| 636 |
-
|
| 637 |
-
@app.post("/sessions/{session_id}/approve")
|
| 638 |
-
async def approve_assistance(session_id: str, request: ApprovalRequest):
|
| 639 |
-
"""Approve or reject assistance request"""
|
| 640 |
-
if not chat_manager:
|
| 641 |
-
raise HTTPException(status_code=503, detail="System not initialized")
|
| 642 |
-
|
| 643 |
-
if not chat_manager.has_pending_interrupt(session_id):
|
| 644 |
-
raise HTTPException(status_code=404, detail="No pending interrupt")
|
| 645 |
-
|
| 646 |
-
try:
|
| 647 |
-
decision_text = f"{request.decision} {request.reason or ''}"
|
| 648 |
-
response_text = await chat_manager.chat(
|
| 649 |
-
message=decision_text,
|
| 650 |
-
session_id=session_id
|
| 651 |
-
)
|
| 652 |
-
|
| 653 |
-
return {
|
| 654 |
-
"status": "success",
|
| 655 |
-
"decision": request.decision,
|
| 656 |
-
"response": response_text,
|
| 657 |
-
"session_id": session_id
|
| 658 |
-
}
|
| 659 |
-
except Exception as e:
|
| 660 |
-
logger.exception(f"Error handling approval: {e}")
|
| 661 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 662 |
-
|
| 663 |
-
@app.get("/sessions/{session_id}/status")
|
| 664 |
-
async def get_session_status(session_id: str):
|
| 665 |
-
"""Get session status including interrupt state"""
|
| 666 |
-
if not chat_manager:
|
| 667 |
-
raise HTTPException(status_code=503, detail="System not initialized")
|
| 668 |
-
|
| 669 |
-
has_interrupt = chat_manager.has_pending_interrupt(session_id)
|
| 670 |
-
|
| 671 |
-
interrupt_data = None
|
| 672 |
-
if has_interrupt:
|
| 673 |
-
interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
|
| 674 |
-
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 675 |
-
|
| 676 |
-
if hasattr(interrupt_data_obj, 'value'):
|
| 677 |
-
interrupt_data = interrupt_data_obj.value
|
| 678 |
-
|
| 679 |
-
return {
|
| 680 |
-
"session_id": session_id,
|
| 681 |
-
"has_pending_interrupt": has_interrupt,
|
| 682 |
-
"interrupt_data": interrupt_data
|
| 683 |
-
}
|
| 684 |
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
""
|
| 688 |
-
if not chat_manager:
|
| 689 |
-
raise HTTPException(status_code=503, detail="System not initialized")
|
| 690 |
-
|
| 691 |
-
try:
|
| 692 |
-
history = await chat_manager.get_conversation_history(session_id)
|
| 693 |
-
return {
|
| 694 |
-
"session_id": session_id,
|
| 695 |
-
"history": [
|
| 696 |
-
{
|
| 697 |
-
"role": msg.type if hasattr(msg, 'type') else "unknown",
|
| 698 |
-
"content": msg.content if hasattr(msg, 'content') else str(msg)
|
| 699 |
-
}
|
| 700 |
-
for msg in history
|
| 701 |
-
]
|
| 702 |
-
}
|
| 703 |
-
except Exception as e:
|
| 704 |
-
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
| 1 |
+
# api/main.py - MAIN FILE (Clean & Simple)
|
| 2 |
+
from fastapi import FastAPI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
+
from api.config import settings
|
| 5 |
+
from api.utils import lifespan
|
| 6 |
+
from api.routes import (
|
| 7 |
+
home_router,
|
| 8 |
+
chat_router,
|
| 9 |
+
sessions_router,
|
| 10 |
+
health_router
|
| 11 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
# Create FastAPI app
|
|
|
|
|
|
|
| 14 |
app = FastAPI(
|
| 15 |
+
title=settings.APP_TITLE,
|
| 16 |
+
version=settings.APP_VERSION,
|
| 17 |
+
description=settings.APP_DESCRIPTION,
|
| 18 |
lifespan=lifespan
|
| 19 |
)
|
| 20 |
|
| 21 |
+
# Add CORS middleware
|
| 22 |
app.add_middleware(
|
| 23 |
CORSMiddleware,
|
| 24 |
+
allow_origins=settings.CORS_ORIGINS,
|
| 25 |
allow_credentials=True,
|
| 26 |
allow_methods=["*"],
|
| 27 |
allow_headers=["*"],
|
| 28 |
)
|
| 29 |
|
| 30 |
+
# Include routers
|
| 31 |
+
app.include_router(home_router)
|
| 32 |
+
app.include_router(chat_router)
|
| 33 |
+
app.include_router(sessions_router)
|
| 34 |
+
app.include_router(health_router)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
import uvicorn
|
| 39 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/models/__init__.py
CHANGED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/models/__init__.py
|
| 2 |
+
from .schemas import ChatRequest, ApprovalRequest, ChatResponse
|
| 3 |
+
|
| 4 |
+
__all__ = ["ChatRequest", "ApprovalRequest", "ChatResponse"]
|
api/models/schemas.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/models/schemas.py
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import Optional, Dict
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ChatRequest(BaseModel):
|
| 7 |
+
message: str
|
| 8 |
+
session_id: Optional[str] = None
|
| 9 |
+
stream: bool = False
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ApprovalRequest(BaseModel):
|
| 13 |
+
decision: str # "approve" or "reject"
|
| 14 |
+
reason: Optional[str] = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ChatResponse(BaseModel):
|
| 18 |
+
response: str
|
| 19 |
+
session_id: str
|
| 20 |
+
has_interrupt: bool = False
|
| 21 |
+
interrupt_type: Optional[str] = None
|
| 22 |
+
interrupt_data: Optional[Dict] = None
|
api/routes/__init__.py
CHANGED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/routes/__init__.py
|
| 2 |
+
from .home import router as home_router
|
| 3 |
+
from .chat import router as chat_router
|
| 4 |
+
from .sessions import router as sessions_router
|
| 5 |
+
from .health import router as health_router
|
| 6 |
+
|
| 7 |
+
__all__ = ["home_router", "chat_router", "sessions_router", "health_router"]
|
api/routes/chat.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException
|
| 2 |
+
from fastapi.responses import StreamingResponse
|
| 3 |
+
from api.models import ChatRequest, ChatResponse
|
| 4 |
+
from api.services import ChatService, StreamService
|
| 5 |
+
from api.config import settings
|
| 6 |
+
from uuid import uuid4
|
| 7 |
+
|
| 8 |
+
router = APIRouter(prefix="/chat", tags=["Chat"])
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@router.post("", response_model=ChatResponse)
|
| 12 |
+
async def chat_simple(request: ChatRequest):
|
| 13 |
+
"""
|
| 14 |
+
Simple non-streaming chat endpoint.
|
| 15 |
+
Best for handling interrupts - returns complete response.
|
| 16 |
+
"""
|
| 17 |
+
if not settings.system_initialized or not settings.chat_manager:
|
| 18 |
+
raise HTTPException(status_code=503, detail="System initializing...")
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
return await ChatService.process_message(
|
| 22 |
+
message=request.message,
|
| 23 |
+
session_id=request.session_id
|
| 24 |
+
)
|
| 25 |
+
except Exception as e:
|
| 26 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@router.post("/stream")
|
| 30 |
+
async def chat_stream(request: ChatRequest):
|
| 31 |
+
"""
|
| 32 |
+
Streaming chat endpoint using Server-Sent Events (SSE).
|
| 33 |
+
Streams tokens in real-time and handles interrupts.
|
| 34 |
+
"""
|
| 35 |
+
if not settings.system_initialized or not settings.chat_manager:
|
| 36 |
+
raise HTTPException(status_code=503, detail="System initializing...")
|
| 37 |
+
|
| 38 |
+
session_id = request.session_id or f"api_{uuid4()}"
|
| 39 |
+
|
| 40 |
+
return StreamingResponse(
|
| 41 |
+
StreamService.stream_chat_response(request.message, session_id),
|
| 42 |
+
media_type="text/event-stream",
|
| 43 |
+
headers={
|
| 44 |
+
"Cache-Control": "no-cache",
|
| 45 |
+
"Connection": "keep-alive",
|
| 46 |
+
"X-Accel-Buffering": "no"
|
| 47 |
+
}
|
| 48 |
+
)
|
api/routes/health.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/routes/health.py
|
| 2 |
+
from fastapi import APIRouter
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from api.config import settings
|
| 5 |
+
|
| 6 |
+
router = APIRouter(tags=["Health"])
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@router.get("/health")
|
| 10 |
+
async def health_check():
|
| 11 |
+
"""Health check endpoint"""
|
| 12 |
+
return {
|
| 13 |
+
"status": "healthy" if settings.system_initialized else "starting",
|
| 14 |
+
"system_initialized": settings.system_initialized,
|
| 15 |
+
"timestamp": datetime.now().isoformat()
|
| 16 |
+
}
|
api/routes/home.py
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
from fastapi.responses import HTMLResponse
|
| 3 |
+
|
| 4 |
+
router = APIRouter(tags=["Home"])
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@router.get("/", response_class=HTMLResponse)
|
| 8 |
+
async def root():
|
| 9 |
+
"""Interactive chat interface - Main UI"""
|
| 10 |
+
return """
|
| 11 |
+
<html>
|
| 12 |
+
<head>
|
| 13 |
+
<title>Legal Assistant - Interactive Chat</title>
|
| 14 |
+
<meta charset="UTF-8">
|
| 15 |
+
<style>
|
| 16 |
+
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 17 |
+
body {
|
| 18 |
+
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
| 19 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 20 |
+
min-height: 100vh;
|
| 21 |
+
padding: 20px;
|
| 22 |
+
}
|
| 23 |
+
.container {
|
| 24 |
+
max-width: 900px;
|
| 25 |
+
margin: 0 auto;
|
| 26 |
+
background: white;
|
| 27 |
+
border-radius: 16px;
|
| 28 |
+
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
|
| 29 |
+
overflow: hidden;
|
| 30 |
+
}
|
| 31 |
+
.header {
|
| 32 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 33 |
+
color: white;
|
| 34 |
+
padding: 30px;
|
| 35 |
+
text-align: center;
|
| 36 |
+
position: relative;
|
| 37 |
+
}
|
| 38 |
+
.header h1 { margin-bottom: 10px; }
|
| 39 |
+
.header p { opacity: 0.9; }
|
| 40 |
+
.api-link {
|
| 41 |
+
position: absolute;
|
| 42 |
+
top: 20px;
|
| 43 |
+
right: 20px;
|
| 44 |
+
padding: 8px 16px;
|
| 45 |
+
background: rgba(255,255,255,0.2);
|
| 46 |
+
color: white;
|
| 47 |
+
text-decoration: none;
|
| 48 |
+
border-radius: 5px;
|
| 49 |
+
font-size: 0.9em;
|
| 50 |
+
transition: all 0.3s;
|
| 51 |
+
}
|
| 52 |
+
.api-link:hover {
|
| 53 |
+
background: rgba(255,255,255,0.3);
|
| 54 |
+
transform: translateY(-2px);
|
| 55 |
+
}
|
| 56 |
+
.chat-container { padding: 20px; }
|
| 57 |
+
#messages {
|
| 58 |
+
height: 500px;
|
| 59 |
+
overflow-y: auto;
|
| 60 |
+
padding: 20px;
|
| 61 |
+
background: #f8f9fa;
|
| 62 |
+
border-radius: 12px;
|
| 63 |
+
margin-bottom: 20px;
|
| 64 |
+
}
|
| 65 |
+
.message {
|
| 66 |
+
margin: 15px 0;
|
| 67 |
+
padding: 12px 18px;
|
| 68 |
+
border-radius: 18px;
|
| 69 |
+
max-width: 75%;
|
| 70 |
+
animation: slideIn 0.3s ease;
|
| 71 |
+
word-wrap: break-word;
|
| 72 |
+
}
|
| 73 |
+
@keyframes slideIn {
|
| 74 |
+
from { opacity: 0; transform: translateY(10px); }
|
| 75 |
+
to { opacity: 1; transform: translateY(0); }
|
| 76 |
+
}
|
| 77 |
+
.user-message {
|
| 78 |
+
background: #667eea;
|
| 79 |
+
color: white;
|
| 80 |
+
margin-left: auto;
|
| 81 |
+
border-bottom-right-radius: 4px;
|
| 82 |
+
}
|
| 83 |
+
.assistant-message {
|
| 84 |
+
background: white;
|
| 85 |
+
color: #333;
|
| 86 |
+
border: 1px solid #e0e0e0;
|
| 87 |
+
border-bottom-left-radius: 4px;
|
| 88 |
+
}
|
| 89 |
+
.system-message {
|
| 90 |
+
background: #fff3cd;
|
| 91 |
+
color: #856404;
|
| 92 |
+
border: 1px solid #ffc107;
|
| 93 |
+
text-align: center;
|
| 94 |
+
margin: 10px auto;
|
| 95 |
+
max-width: 90%;
|
| 96 |
+
}
|
| 97 |
+
.interrupt-panel {
|
| 98 |
+
background: #fff3cd;
|
| 99 |
+
border: 2px solid #ffc107;
|
| 100 |
+
border-radius: 12px;
|
| 101 |
+
padding: 20px;
|
| 102 |
+
margin: 20px 0;
|
| 103 |
+
animation: pulse 2s infinite;
|
| 104 |
+
}
|
| 105 |
+
@keyframes pulse {
|
| 106 |
+
0%, 100% { box-shadow: 0 0 0 0 rgba(255, 193, 7, 0.4); }
|
| 107 |
+
50% { box-shadow: 0 0 0 10px rgba(255, 193, 7, 0); }
|
| 108 |
+
}
|
| 109 |
+
.interrupt-panel h3 {
|
| 110 |
+
color: #856404;
|
| 111 |
+
margin-bottom: 15px;
|
| 112 |
+
}
|
| 113 |
+
.interrupt-details {
|
| 114 |
+
background: white;
|
| 115 |
+
padding: 15px;
|
| 116 |
+
border-radius: 8px;
|
| 117 |
+
margin: 15px 0;
|
| 118 |
+
}
|
| 119 |
+
.interrupt-details p {
|
| 120 |
+
margin: 8px 0;
|
| 121 |
+
color: #333;
|
| 122 |
+
}
|
| 123 |
+
.approval-buttons {
|
| 124 |
+
display: flex;
|
| 125 |
+
gap: 10px;
|
| 126 |
+
margin-top: 15px;
|
| 127 |
+
}
|
| 128 |
+
.btn {
|
| 129 |
+
flex: 1;
|
| 130 |
+
padding: 12px 24px;
|
| 131 |
+
border: none;
|
| 132 |
+
border-radius: 8px;
|
| 133 |
+
font-weight: bold;
|
| 134 |
+
cursor: pointer;
|
| 135 |
+
transition: all 0.3s;
|
| 136 |
+
font-size: 14px;
|
| 137 |
+
}
|
| 138 |
+
.btn-approve {
|
| 139 |
+
background: #4caf50;
|
| 140 |
+
color: white;
|
| 141 |
+
}
|
| 142 |
+
.btn-approve:hover {
|
| 143 |
+
background: #45a049;
|
| 144 |
+
transform: translateY(-2px);
|
| 145 |
+
box-shadow: 0 4px 12px rgba(76, 175, 80, 0.4);
|
| 146 |
+
}
|
| 147 |
+
.btn-reject {
|
| 148 |
+
background: #f44336;
|
| 149 |
+
color: white;
|
| 150 |
+
}
|
| 151 |
+
.btn-reject:hover {
|
| 152 |
+
background: #da190b;
|
| 153 |
+
transform: translateY(-2px);
|
| 154 |
+
box-shadow: 0 4px 12px rgba(244, 67, 54, 0.4);
|
| 155 |
+
}
|
| 156 |
+
.input-area {
|
| 157 |
+
display: flex;
|
| 158 |
+
gap: 10px;
|
| 159 |
+
padding: 10px;
|
| 160 |
+
background: #f8f9fa;
|
| 161 |
+
border-radius: 12px;
|
| 162 |
+
}
|
| 163 |
+
#message-input {
|
| 164 |
+
flex: 1;
|
| 165 |
+
padding: 12px 16px;
|
| 166 |
+
border: 2px solid #e0e0e0;
|
| 167 |
+
border-radius: 8px;
|
| 168 |
+
font-size: 14px;
|
| 169 |
+
transition: border 0.3s;
|
| 170 |
+
}
|
| 171 |
+
#message-input:focus {
|
| 172 |
+
outline: none;
|
| 173 |
+
border-color: #667eea;
|
| 174 |
+
}
|
| 175 |
+
#send-btn {
|
| 176 |
+
padding: 12px 32px;
|
| 177 |
+
background: #667eea;
|
| 178 |
+
color: white;
|
| 179 |
+
border: none;
|
| 180 |
+
border-radius: 8px;
|
| 181 |
+
font-weight: bold;
|
| 182 |
+
cursor: pointer;
|
| 183 |
+
transition: all 0.3s;
|
| 184 |
+
}
|
| 185 |
+
#send-btn:hover:not(:disabled) {
|
| 186 |
+
background: #5568d3;
|
| 187 |
+
transform: translateY(-2px);
|
| 188 |
+
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
| 189 |
+
}
|
| 190 |
+
#send-btn:disabled {
|
| 191 |
+
background: #ccc;
|
| 192 |
+
cursor: not-allowed;
|
| 193 |
+
transform: none;
|
| 194 |
+
}
|
| 195 |
+
.status-bar {
|
| 196 |
+
padding: 10px;
|
| 197 |
+
background: #f8f9fa;
|
| 198 |
+
border-radius: 8px;
|
| 199 |
+
margin-top: 10px;
|
| 200 |
+
display: flex;
|
| 201 |
+
justify-content: space-between;
|
| 202 |
+
font-size: 12px;
|
| 203 |
+
color: #666;
|
| 204 |
+
}
|
| 205 |
+
.typing-indicator {
|
| 206 |
+
display: none;
|
| 207 |
+
padding: 10px;
|
| 208 |
+
color: #666;
|
| 209 |
+
font-style: italic;
|
| 210 |
+
}
|
| 211 |
+
.typing-indicator.active {
|
| 212 |
+
display: block;
|
| 213 |
+
}
|
| 214 |
+
</style>
|
| 215 |
+
</head>
|
| 216 |
+
<body>
|
| 217 |
+
<div class="container">
|
| 218 |
+
<div class="header">
|
| 219 |
+
<a href="/docs" class="api-link">📚 API Docs</a>
|
| 220 |
+
<h1>🧑⚖️ Legal Assistant</h1>
|
| 221 |
+
<p>Multi-country legal support for Benin & Madagascar</p>
|
| 222 |
+
</div>
|
| 223 |
+
|
| 224 |
+
<div class="chat-container">
|
| 225 |
+
<div id="messages"></div>
|
| 226 |
+
<div id="interrupt-zone"></div>
|
| 227 |
+
<div class="typing-indicator" id="typing">Assistant is typing...</div>
|
| 228 |
+
|
| 229 |
+
<div class="input-area">
|
| 230 |
+
<input
|
| 231 |
+
type="text"
|
| 232 |
+
id="message-input"
|
| 233 |
+
placeholder="Type your legal question..."
|
| 234 |
+
autocomplete="off"
|
| 235 |
+
/>
|
| 236 |
+
<button id="send-btn" onclick="sendMessage()">
|
| 237 |
+
Send
|
| 238 |
+
</button>
|
| 239 |
+
</div>
|
| 240 |
+
|
| 241 |
+
<div class="status-bar">
|
| 242 |
+
<span>Session: <strong id="session-id">Not started</strong></span>
|
| 243 |
+
<span id="status">System ready</span>
|
| 244 |
+
</div>
|
| 245 |
+
</div>
|
| 246 |
+
</div>
|
| 247 |
+
|
| 248 |
+
<script>
|
| 249 |
+
let sessionId = null;
|
| 250 |
+
let currentAssistantMessage = null;
|
| 251 |
+
|
| 252 |
+
function addMessage(content, type) {
|
| 253 |
+
const messagesDiv = document.getElementById('messages');
|
| 254 |
+
const messageDiv = document.createElement('div');
|
| 255 |
+
messageDiv.className = `message ${type}-message`;
|
| 256 |
+
messageDiv.textContent = content;
|
| 257 |
+
messagesDiv.appendChild(messageDiv);
|
| 258 |
+
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
| 259 |
+
return messageDiv;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
function showInterrupt(data) {
|
| 263 |
+
const interruptZone = document.getElementById('interrupt-zone');
|
| 264 |
+
interruptZone.innerHTML = `
|
| 265 |
+
<div class="interrupt-panel">
|
| 266 |
+
<h3>🔒 Human Approval Required</h3>
|
| 267 |
+
<div class="interrupt-details">
|
| 268 |
+
<p><strong>📧 Email:</strong> ${data.user_email || 'N/A'}</p>
|
| 269 |
+
<p><strong>🌍 Country:</strong> ${data.country || 'N/A'}</p>
|
| 270 |
+
<p><strong>📝 Description:</strong> ${data.description || 'N/A'}</p>
|
| 271 |
+
</div>
|
| 272 |
+
<div class="approval-buttons">
|
| 273 |
+
<button class="btn btn-approve" onclick="handleApproval('approve')">
|
| 274 |
+
✅ Approve Request
|
| 275 |
+
</button>
|
| 276 |
+
<button class="btn btn-reject" onclick="handleApproval('reject')">
|
| 277 |
+
❌ Reject Request
|
| 278 |
+
</button>
|
| 279 |
+
</div>
|
| 280 |
+
</div>
|
| 281 |
+
`;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
function hideInterrupt() {
|
| 285 |
+
document.getElementById('interrupt-zone').innerHTML = '';
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
async function handleApproval(decision) {
|
| 289 |
+
const reason = decision === 'approve'
|
| 290 |
+
? 'Approved via web interface'
|
| 291 |
+
: 'Rejected via web interface';
|
| 292 |
+
|
| 293 |
+
try {
|
| 294 |
+
addMessage(`${decision === 'approve' ? '✅' : '❌'} Decision: ${decision}`, 'system');
|
| 295 |
+
hideInterrupt();
|
| 296 |
+
|
| 297 |
+
const response = await fetch(`/sessions/${sessionId}/approve`, {
|
| 298 |
+
method: 'POST',
|
| 299 |
+
headers: { 'Content-Type': 'application/json' },
|
| 300 |
+
body: JSON.stringify({ decision, reason })
|
| 301 |
+
});
|
| 302 |
+
|
| 303 |
+
const data = await response.json();
|
| 304 |
+
addMessage(data.response, 'assistant');
|
| 305 |
+
|
| 306 |
+
} catch (error) {
|
| 307 |
+
addMessage('Error: ' + error.message, 'system');
|
| 308 |
+
}
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
async function sendMessage() {
|
| 312 |
+
const input = document.getElementById('message-input');
|
| 313 |
+
const message = input.value.trim();
|
| 314 |
+
|
| 315 |
+
if (!message) return;
|
| 316 |
+
|
| 317 |
+
if (!sessionId) {
|
| 318 |
+
sessionId = 'web_' + Date.now();
|
| 319 |
+
document.getElementById('session-id').textContent = sessionId;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
addMessage(message, 'user');
|
| 323 |
+
input.value = '';
|
| 324 |
+
|
| 325 |
+
const sendBtn = document.getElementById('send-btn');
|
| 326 |
+
sendBtn.disabled = true;
|
| 327 |
+
input.disabled = true;
|
| 328 |
+
|
| 329 |
+
document.getElementById('typing').classList.add('active');
|
| 330 |
+
|
| 331 |
+
try {
|
| 332 |
+
const response = await fetch('/chat/stream', {
|
| 333 |
+
method: 'POST',
|
| 334 |
+
headers: { 'Content-Type': 'application/json' },
|
| 335 |
+
body: JSON.stringify({
|
| 336 |
+
message: message,
|
| 337 |
+
session_id: sessionId,
|
| 338 |
+
stream: true
|
| 339 |
+
})
|
| 340 |
+
});
|
| 341 |
+
|
| 342 |
+
const reader = response.body.getReader();
|
| 343 |
+
const decoder = new TextDecoder();
|
| 344 |
+
currentAssistantMessage = addMessage('', 'assistant');
|
| 345 |
+
|
| 346 |
+
while (true) {
|
| 347 |
+
const { value, done } = await reader.read();
|
| 348 |
+
if (done) break;
|
| 349 |
+
|
| 350 |
+
const chunk = decoder.decode(value);
|
| 351 |
+
const lines = chunk.split('\\n\\n');
|
| 352 |
+
|
| 353 |
+
for (const line of lines) {
|
| 354 |
+
if (line.startsWith('data: ')) {
|
| 355 |
+
try {
|
| 356 |
+
const data = JSON.parse(line.slice(6));
|
| 357 |
+
|
| 358 |
+
if (data.type === 'token') {
|
| 359 |
+
currentAssistantMessage.textContent += data.content;
|
| 360 |
+
} else if (data.type === 'interrupt') {
|
| 361 |
+
showInterrupt(data.data);
|
| 362 |
+
} else if (data.type === 'done') {
|
| 363 |
+
document.getElementById('status').textContent = 'Ready';
|
| 364 |
+
} else if (data.type === 'error') {
|
| 365 |
+
addMessage('Error: ' + data.message, 'system');
|
| 366 |
+
}
|
| 367 |
+
} catch (e) {
|
| 368 |
+
console.error('Parse error:', e);
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
} catch (error) {
|
| 375 |
+
addMessage('Error: ' + error.message, 'system');
|
| 376 |
+
} finally {
|
| 377 |
+
sendBtn.disabled = false;
|
| 378 |
+
input.disabled = false;
|
| 379 |
+
input.focus();
|
| 380 |
+
document.getElementById('typing').classList.remove('active');
|
| 381 |
+
}
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
document.getElementById('message-input').addEventListener('keypress', (e) => {
|
| 385 |
+
if (e.key === 'Enter' && !e.shiftKey) {
|
| 386 |
+
e.preventDefault();
|
| 387 |
+
sendMessage();
|
| 388 |
+
}
|
| 389 |
+
});
|
| 390 |
+
|
| 391 |
+
document.getElementById('message-input').focus();
|
| 392 |
+
</script>
|
| 393 |
+
</body>
|
| 394 |
+
</html>
|
| 395 |
+
"""
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
@router.get("/about", response_class=HTMLResponse)
|
| 399 |
+
async def about():
|
| 400 |
+
"""About page with API information"""
|
| 401 |
+
return """
|
| 402 |
+
<!DOCTYPE html>
|
| 403 |
+
<html lang="en">
|
| 404 |
+
<head>
|
| 405 |
+
<meta charset="UTF-8">
|
| 406 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 407 |
+
<title>Legal Assistant API</title>
|
| 408 |
+
<style>
|
| 409 |
+
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 410 |
+
body {
|
| 411 |
+
font-family: 'Segoe UI', sans-serif;
|
| 412 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 413 |
+
min-height: 100vh;
|
| 414 |
+
display: flex;
|
| 415 |
+
align-items: center;
|
| 416 |
+
justify-content: center;
|
| 417 |
+
padding: 20px;
|
| 418 |
+
}
|
| 419 |
+
.container {
|
| 420 |
+
max-width: 900px;
|
| 421 |
+
width: 100%;
|
| 422 |
+
background: white;
|
| 423 |
+
border-radius: 20px;
|
| 424 |
+
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
|
| 425 |
+
overflow: hidden;
|
| 426 |
+
}
|
| 427 |
+
.header {
|
| 428 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 429 |
+
color: white;
|
| 430 |
+
padding: 60px 40px;
|
| 431 |
+
text-align: center;
|
| 432 |
+
}
|
| 433 |
+
.header h1 { font-size: 2.5em; margin-bottom: 15px; }
|
| 434 |
+
.header p { font-size: 1.2em; opacity: 0.95; }
|
| 435 |
+
.version {
|
| 436 |
+
background: rgba(255,255,255,0.2);
|
| 437 |
+
padding: 5px 15px;
|
| 438 |
+
border-radius: 20px;
|
| 439 |
+
margin-top: 15px;
|
| 440 |
+
display: inline-block;
|
| 441 |
+
}
|
| 442 |
+
.content { padding: 40px; }
|
| 443 |
+
.features {
|
| 444 |
+
display: grid;
|
| 445 |
+
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
| 446 |
+
gap: 20px;
|
| 447 |
+
margin-bottom: 40px;
|
| 448 |
+
}
|
| 449 |
+
.feature-card {
|
| 450 |
+
padding: 25px;
|
| 451 |
+
background: #f8f9fa;
|
| 452 |
+
border-radius: 12px;
|
| 453 |
+
border-left: 4px solid #667eea;
|
| 454 |
+
}
|
| 455 |
+
.feature-card h3 { color: #667eea; margin-bottom: 10px; }
|
| 456 |
+
.cta-section {
|
| 457 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 458 |
+
padding: 40px;
|
| 459 |
+
border-radius: 12px;
|
| 460 |
+
text-align: center;
|
| 461 |
+
}
|
| 462 |
+
.cta-section h2 { color: white; margin-bottom: 25px; }
|
| 463 |
+
.buttons { display: flex; gap: 15px; justify-content: center; flex-wrap: wrap; }
|
| 464 |
+
.btn {
|
| 465 |
+
padding: 15px 35px;
|
| 466 |
+
border-radius: 8px;
|
| 467 |
+
text-decoration: none;
|
| 468 |
+
font-weight: 600;
|
| 469 |
+
transition: all 0.3s;
|
| 470 |
+
}
|
| 471 |
+
.btn-primary { background: white; color: #667eea; }
|
| 472 |
+
.btn-primary:hover { transform: translateY(-3px); }
|
| 473 |
+
.footer {
|
| 474 |
+
padding: 30px;
|
| 475 |
+
background: #f8f9fa;
|
| 476 |
+
text-align: center;
|
| 477 |
+
color: #666;
|
| 478 |
+
}
|
| 479 |
+
</style>
|
| 480 |
+
</head>
|
| 481 |
+
<body>
|
| 482 |
+
<div class="container">
|
| 483 |
+
<div class="header">
|
| 484 |
+
<h1>⚖️ Legal Assistant API</h1>
|
| 485 |
+
<p>Multi-country legal support for Benin & Madagascar</p>
|
| 486 |
+
<span class="version">v2.0.0</span>
|
| 487 |
+
</div>
|
| 488 |
+
<div class="content">
|
| 489 |
+
<div class="features">
|
| 490 |
+
<div class="feature-card">
|
| 491 |
+
<h3>🤖 AI-Powered</h3>
|
| 492 |
+
<p>Advanced RAG technology</p>
|
| 493 |
+
</div>
|
| 494 |
+
<div class="feature-card">
|
| 495 |
+
<h3>🌍 Multi-Country</h3>
|
| 496 |
+
<p>Benin & Madagascar laws</p>
|
| 497 |
+
</div>
|
| 498 |
+
<div class="feature-card">
|
| 499 |
+
<h3>🔒 Human Approval</h3>
|
| 500 |
+
<p>Quality assurance</p>
|
| 501 |
+
</div>
|
| 502 |
+
<div class="feature-card">
|
| 503 |
+
<h3>⚡ Real-time</h3>
|
| 504 |
+
<p>Streaming responses</p>
|
| 505 |
+
</div>
|
| 506 |
+
</div>
|
| 507 |
+
<div class="cta-section">
|
| 508 |
+
<h2>Get Started</h2>
|
| 509 |
+
<div class="buttons">
|
| 510 |
+
<a href="/docs" class="btn btn-primary">📚 Swagger UI</a>
|
| 511 |
+
<a href="/redoc" class="btn btn-primary">📖 ReDoc</a>
|
| 512 |
+
</div>
|
| 513 |
+
</div>
|
| 514 |
+
</div>
|
| 515 |
+
<div class="footer">
|
| 516 |
+
<p>Built with FastAPI & LangGraph</p>
|
| 517 |
+
</div>
|
| 518 |
+
</div>
|
| 519 |
+
</body>
|
| 520 |
+
</html>
|
| 521 |
+
"""
|
api/routes/sessions.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/routes/sessions.py
|
| 2 |
+
from fastapi import APIRouter, HTTPException
|
| 3 |
+
from api.models import ApprovalRequest
|
| 4 |
+
from api.services import ChatService
|
| 5 |
+
|
| 6 |
+
router = APIRouter(prefix="/sessions", tags=["Sessions"])
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@router.post("/{session_id}/approve")
|
| 10 |
+
async def approve_assistance(session_id: str, request: ApprovalRequest):
|
| 11 |
+
"""Approve or reject assistance request"""
|
| 12 |
+
try:
|
| 13 |
+
return await ChatService.approve_request(
|
| 14 |
+
session_id=session_id,
|
| 15 |
+
decision=request.decision,
|
| 16 |
+
reason=request.reason
|
| 17 |
+
)
|
| 18 |
+
except Exception as e:
|
| 19 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@router.get("/{session_id}/status")
|
| 23 |
+
async def get_session_status(session_id: str):
|
| 24 |
+
"""Get session status including interrupt state"""
|
| 25 |
+
try:
|
| 26 |
+
return await ChatService.get_session_status(session_id)
|
| 27 |
+
except Exception as e:
|
| 28 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@router.get("/{session_id}/history")
|
| 32 |
+
async def get_history(session_id: str):
|
| 33 |
+
"""Get conversation history"""
|
| 34 |
+
try:
|
| 35 |
+
return await ChatService.get_history(session_id)
|
| 36 |
+
except Exception as e:
|
| 37 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@router.get("")
|
| 41 |
+
async def list_sessions():
|
| 42 |
+
"""List all active sessions with their status"""
|
| 43 |
+
try:
|
| 44 |
+
return await ChatService.list_sessions()
|
| 45 |
+
except Exception as e:
|
| 46 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@router.delete("/{session_id}")
|
| 50 |
+
async def clear_session(session_id: str):
|
| 51 |
+
"""Clear a session (useful for testing)"""
|
| 52 |
+
try:
|
| 53 |
+
return await ChatService.clear_session(session_id)
|
| 54 |
+
except Exception as e:
|
| 55 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 56 |
+
|
api/services/__init__.py
CHANGED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/services/__init__.py
|
| 2 |
+
# ============================================================================
|
| 3 |
+
from .chat_service import ChatService
|
| 4 |
+
from .stream_service import StreamService
|
| 5 |
+
|
| 6 |
+
__all__ = ["ChatService", "StreamService"]
|
api/services/chat_service.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/services/chat_service.py
|
| 2 |
+
from typing import Dict, Optional
|
| 3 |
+
from api.config import settings, logger
|
| 4 |
+
from api.models import ChatResponse
|
| 5 |
+
from uuid import uuid4
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ChatService:
|
| 9 |
+
"""Service for handling chat operations"""
|
| 10 |
+
|
| 11 |
+
@staticmethod
|
| 12 |
+
async def process_message(
|
| 13 |
+
message: str,
|
| 14 |
+
session_id: Optional[str] = None
|
| 15 |
+
) -> ChatResponse:
|
| 16 |
+
"""Process a chat message"""
|
| 17 |
+
if not settings.system_initialized or not settings.chat_manager:
|
| 18 |
+
raise Exception("System not initialized")
|
| 19 |
+
|
| 20 |
+
session_id = session_id or f"api_{uuid4()}"
|
| 21 |
+
|
| 22 |
+
# Check for pending interrupt
|
| 23 |
+
if settings.chat_manager.has_pending_interrupt(session_id):
|
| 24 |
+
interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
|
| 25 |
+
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 26 |
+
|
| 27 |
+
interrupt_value = (
|
| 28 |
+
interrupt_data_obj.value
|
| 29 |
+
if hasattr(interrupt_data_obj, 'value')
|
| 30 |
+
else {}
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
return ChatResponse(
|
| 34 |
+
response="⏸️ Pending approval request. Please approve/reject first.",
|
| 35 |
+
session_id=session_id,
|
| 36 |
+
has_interrupt=True,
|
| 37 |
+
interrupt_type="human_approval",
|
| 38 |
+
interrupt_data=interrupt_value
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Process message
|
| 42 |
+
response_text = await settings.chat_manager.chat(
|
| 43 |
+
message=message,
|
| 44 |
+
session_id=session_id
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# Check if interrupt occurred
|
| 48 |
+
has_interrupt = settings.chat_manager.has_pending_interrupt(session_id)
|
| 49 |
+
interrupt_value = None
|
| 50 |
+
|
| 51 |
+
if has_interrupt:
|
| 52 |
+
interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
|
| 53 |
+
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 54 |
+
|
| 55 |
+
if hasattr(interrupt_data_obj, 'value'):
|
| 56 |
+
interrupt_value = interrupt_data_obj.value
|
| 57 |
+
elif isinstance(interrupt_data_obj, dict):
|
| 58 |
+
interrupt_value = interrupt_data_obj.get("value", {})
|
| 59 |
+
|
| 60 |
+
return ChatResponse(
|
| 61 |
+
response=response_text,
|
| 62 |
+
session_id=session_id,
|
| 63 |
+
has_interrupt=has_interrupt,
|
| 64 |
+
interrupt_type="human_approval" if has_interrupt else None,
|
| 65 |
+
interrupt_data=interrupt_value
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
async def approve_request(
|
| 70 |
+
session_id: str,
|
| 71 |
+
decision: str,
|
| 72 |
+
reason: Optional[str] = None
|
| 73 |
+
) -> Dict:
|
| 74 |
+
"""Approve or reject an assistance request"""
|
| 75 |
+
if not settings.chat_manager:
|
| 76 |
+
raise Exception("System not initialized")
|
| 77 |
+
|
| 78 |
+
if not settings.chat_manager.has_pending_interrupt(session_id):
|
| 79 |
+
raise Exception("No pending interrupt")
|
| 80 |
+
|
| 81 |
+
decision_text = f"{decision} {reason or ''}"
|
| 82 |
+
response_text = await settings.chat_manager.chat(
|
| 83 |
+
message=decision_text,
|
| 84 |
+
session_id=session_id
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
return {
|
| 88 |
+
"status": "success",
|
| 89 |
+
"decision": decision,
|
| 90 |
+
"response": response_text,
|
| 91 |
+
"session_id": session_id
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
@staticmethod
|
| 95 |
+
async def get_session_status(session_id: str) -> Dict:
|
| 96 |
+
"""Get session status"""
|
| 97 |
+
if not settings.chat_manager:
|
| 98 |
+
raise Exception("System not initialized")
|
| 99 |
+
|
| 100 |
+
has_interrupt = settings.chat_manager.has_pending_interrupt(session_id)
|
| 101 |
+
interrupt_data = None
|
| 102 |
+
|
| 103 |
+
if has_interrupt:
|
| 104 |
+
interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
|
| 105 |
+
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 106 |
+
|
| 107 |
+
if hasattr(interrupt_data_obj, 'value'):
|
| 108 |
+
interrupt_data = interrupt_data_obj.value
|
| 109 |
+
|
| 110 |
+
return {
|
| 111 |
+
"session_id": session_id,
|
| 112 |
+
"has_pending_interrupt": has_interrupt,
|
| 113 |
+
"interrupt_data": interrupt_data
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
@staticmethod
|
| 117 |
+
async def get_history(session_id: str) -> Dict:
|
| 118 |
+
"""Get conversation history"""
|
| 119 |
+
if not settings.chat_manager:
|
| 120 |
+
raise Exception("System not initialized")
|
| 121 |
+
|
| 122 |
+
history = await settings.chat_manager.get_conversation_history(session_id)
|
| 123 |
+
|
| 124 |
+
return {
|
| 125 |
+
"session_id": session_id,
|
| 126 |
+
"history": [
|
| 127 |
+
{
|
| 128 |
+
"role": msg.type if hasattr(msg, 'type') else "unknown",
|
| 129 |
+
"content": msg.content if hasattr(msg, 'content') else str(msg)
|
| 130 |
+
}
|
| 131 |
+
for msg in history
|
| 132 |
+
]
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
@staticmethod
|
| 136 |
+
async def list_sessions() -> Dict:
|
| 137 |
+
"""List all active sessions"""
|
| 138 |
+
if not settings.chat_manager:
|
| 139 |
+
raise Exception("System not initialized")
|
| 140 |
+
|
| 141 |
+
sessions_info = []
|
| 142 |
+
for session_id in settings.chat_manager.pending_interrupts.keys():
|
| 143 |
+
interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
|
| 144 |
+
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 145 |
+
|
| 146 |
+
data = (
|
| 147 |
+
interrupt_data_obj.value
|
| 148 |
+
if hasattr(interrupt_data_obj, 'value')
|
| 149 |
+
else {}
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
sessions_info.append({
|
| 153 |
+
"session_id": session_id,
|
| 154 |
+
"has_interrupt": True,
|
| 155 |
+
"user_email": data.get("user_email"),
|
| 156 |
+
"country": data.get("country"),
|
| 157 |
+
"description": data.get("description")
|
| 158 |
+
})
|
| 159 |
+
|
| 160 |
+
return {
|
| 161 |
+
"total_pending": len(sessions_info),
|
| 162 |
+
"sessions": sessions_info
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
@staticmethod
|
| 166 |
+
async def clear_session(session_id: str) -> Dict:
|
| 167 |
+
"""Clear a session"""
|
| 168 |
+
if not settings.chat_manager:
|
| 169 |
+
raise Exception("System not initialized")
|
| 170 |
+
|
| 171 |
+
if session_id in settings.chat_manager.pending_interrupts:
|
| 172 |
+
del settings.chat_manager.pending_interrupts[session_id]
|
| 173 |
+
|
| 174 |
+
return {
|
| 175 |
+
"status": "success",
|
| 176 |
+
"message": f"Session {session_id} cleared"
|
| 177 |
+
}
|
| 178 |
+
|
api/services/stream_service.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/services/stream_service.py
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
from typing import AsyncGenerator
|
| 5 |
+
from api.config import settings, logger
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class StreamService:
|
| 9 |
+
"""Service for handling streaming responses"""
|
| 10 |
+
|
| 11 |
+
@staticmethod
|
| 12 |
+
async def stream_chat_response(
|
| 13 |
+
message: str,
|
| 14 |
+
session_id: str
|
| 15 |
+
) -> AsyncGenerator[str, None]:
|
| 16 |
+
"""Stream chat responses token by token"""
|
| 17 |
+
try:
|
| 18 |
+
# Check for pending interrupt first
|
| 19 |
+
if settings.chat_manager.has_pending_interrupt(session_id):
|
| 20 |
+
interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
|
| 21 |
+
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 22 |
+
|
| 23 |
+
interrupt_value = (
|
| 24 |
+
interrupt_data_obj.value
|
| 25 |
+
if hasattr(interrupt_data_obj, 'value')
|
| 26 |
+
else {}
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
yield f"data: {json.dumps({'type': 'interrupt', 'data': interrupt_value})}\n\n"
|
| 30 |
+
return
|
| 31 |
+
|
| 32 |
+
# Process message
|
| 33 |
+
response = await settings.chat_manager.chat(
|
| 34 |
+
message=message,
|
| 35 |
+
session_id=session_id,
|
| 36 |
+
interrupt_handler=None
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Check if interrupt occurred during processing
|
| 40 |
+
if settings.chat_manager.has_pending_interrupt(session_id):
|
| 41 |
+
interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
|
| 42 |
+
interrupt_data_obj = interrupt_info.get("interrupt_data")
|
| 43 |
+
|
| 44 |
+
interrupt_value = (
|
| 45 |
+
interrupt_data_obj.value
|
| 46 |
+
if hasattr(interrupt_data_obj, 'value')
|
| 47 |
+
else {}
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
yield f"data: {json.dumps({'type': 'interrupt', 'data': interrupt_value})}\n\n"
|
| 51 |
+
return
|
| 52 |
+
|
| 53 |
+
# Stream response token by token
|
| 54 |
+
words = response.split()
|
| 55 |
+
for i, word in enumerate(words):
|
| 56 |
+
token = word + (" " if i < len(words) - 1 else "")
|
| 57 |
+
yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
|
| 58 |
+
await asyncio.sleep(settings.STREAM_DELAY)
|
| 59 |
+
|
| 60 |
+
# Send completion event
|
| 61 |
+
yield f"data: {json.dumps({'type': 'done', 'session_id': session_id})}\n\n"
|
| 62 |
+
|
| 63 |
+
except Exception as e:
|
| 64 |
+
logger.exception(f"Error in streaming: {e}")
|
| 65 |
+
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
| 66 |
+
|
api/utils/__init__.py
CHANGED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/utils/__init__.py
|
| 2 |
+
from .startup import initialize_system, lifespan
|
| 3 |
+
|
| 4 |
+
__all__ = ["initialize_system", "lifespan"]
|
api/utils/startup.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/utils/startup.py
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
from fastapi import FastAPI
|
| 9 |
+
from core.system_initializer import setup_system
|
| 10 |
+
from api.config import settings, logger
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
async def initialize_system():
|
| 14 |
+
"""Initialize the legal assistant system"""
|
| 15 |
+
try:
|
| 16 |
+
system = await setup_system()
|
| 17 |
+
settings.chat_manager = system["chat_manager"]
|
| 18 |
+
settings.graph = system["graph"]
|
| 19 |
+
settings.system_initialized = True
|
| 20 |
+
logger.info("✅ Legal assistant system initialized")
|
| 21 |
+
except Exception as e:
|
| 22 |
+
logger.error(f"❌ Failed to initialize system: {e}")
|
| 23 |
+
settings.system_initialized = False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@asynccontextmanager
|
| 27 |
+
async def lifespan(app: FastAPI):
|
| 28 |
+
"""Application lifespan manager"""
|
| 29 |
+
logger.info("🚀 Starting Legal Assistant API...")
|
| 30 |
+
initialization_task = asyncio.create_task(initialize_system())
|
| 31 |
+
yield
|
| 32 |
+
logger.info("🛑 Shutting down...")
|
| 33 |
+
initialization_task.cancel()
|
| 34 |
+
try:
|
| 35 |
+
await initialization_task
|
| 36 |
+
except asyncio.CancelledError:
|
| 37 |
+
pass
|
config/settings.py
CHANGED
|
@@ -15,7 +15,7 @@ class Settings:
|
|
| 15 |
NEON_END_POINT = os.getenv("NEON_END_POINT")
|
| 16 |
|
| 17 |
# Database
|
| 18 |
-
DATABASE_URL = NEON_END_POINT
|
| 19 |
|
| 20 |
# Model Configurations
|
| 21 |
EMBEDDING_MODEL = "text-embedding-ada-002"
|
|
|
|
| 15 |
NEON_END_POINT = os.getenv("NEON_END_POINT")
|
| 16 |
|
| 17 |
# Database
|
| 18 |
+
# DATABASE_URL = NEON_END_POINT
|
| 19 |
|
| 20 |
# Model Configurations
|
| 21 |
EMBEDDING_MODEL = "text-embedding-ada-002"
|
core/chat_manager.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
# [file name]: core/chat_manager.py
|
| 2 |
-
# Add this as the FIRST lines of code (after docstrings)
|
| 3 |
import sys
|
| 4 |
from pathlib import Path
|
| 5 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
@@ -7,10 +6,11 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
| 7 |
import asyncio
|
| 8 |
import logging
|
| 9 |
from datetime import datetime
|
| 10 |
-
from typing import Dict, List, Optional
|
| 11 |
from langchain_core.runnables import RunnableConfig
|
| 12 |
from langchain_core.messages import BaseMessage
|
| 13 |
from langgraph.types import Command
|
|
|
|
| 14 |
|
| 15 |
from config.settings import settings
|
| 16 |
from models.state_models import MultiCountryLegalState
|
|
@@ -19,6 +19,11 @@ from utils.helpers import dict_to_message_obj
|
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
class LegalChatManager:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
def __init__(self, graph, checkpointer):
|
| 23 |
self.graph = graph
|
| 24 |
self.checkpointer = checkpointer
|
|
@@ -32,9 +37,30 @@ class LegalChatManager:
|
|
| 32 |
# Track pending interrupts by session
|
| 33 |
self.pending_interrupts = {}
|
| 34 |
|
| 35 |
-
async def chat(
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
if not self.graph:
|
| 39 |
raise RuntimeError("System not initialized. Call setup_system() first.")
|
| 40 |
|
|
@@ -56,31 +82,89 @@ class LegalChatManager:
|
|
| 56 |
# Track performance
|
| 57 |
start_time = datetime.now()
|
| 58 |
|
| 59 |
-
#
|
| 60 |
-
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
"
|
| 70 |
-
"
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# Track performance
|
| 76 |
processing_time = (datetime.now() - start_time).total_seconds()
|
| 77 |
self._update_session_stats(session_id, processing_time)
|
|
|
|
| 78 |
|
| 79 |
-
|
| 80 |
-
response = self._extract_response(result)
|
| 81 |
-
self._update_routing_stats(response)
|
| 82 |
-
|
| 83 |
-
return response
|
| 84 |
|
| 85 |
except Exception as e:
|
| 86 |
logger.exception(f"Chat error for session {session_id}")
|
|
@@ -88,32 +172,79 @@ class LegalChatManager:
|
|
| 88 |
return f"Erreur lors du traitement: {str(e)}"
|
| 89 |
|
| 90 |
async def _handle_pending_interrupt(self, session_id: str, message: str) -> str:
|
| 91 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
interrupt_data = self.pending_interrupts.get(session_id)
|
| 93 |
if not interrupt_data:
|
| 94 |
return "Erreur: Aucune interruption en attente."
|
| 95 |
|
| 96 |
try:
|
| 97 |
-
logger.info(f"
|
| 98 |
|
| 99 |
config = interrupt_data["config"]
|
| 100 |
|
| 101 |
-
#
|
| 102 |
-
|
| 103 |
-
|
| 104 |
Command(resume=message),
|
| 105 |
-
config
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
# Clean up the pending interrupt
|
| 109 |
del self.pending_interrupts[session_id]
|
| 110 |
|
| 111 |
-
#
|
| 112 |
-
|
| 113 |
-
self._update_routing_stats(response)
|
| 114 |
|
| 115 |
logger.info(f"✅ Graph resumed successfully for session {session_id}")
|
| 116 |
-
return
|
| 117 |
|
| 118 |
except Exception as e:
|
| 119 |
logger.error(f"Error resuming from interrupt: {str(e)}")
|
|
@@ -122,26 +253,12 @@ class LegalChatManager:
|
|
| 122 |
del self.pending_interrupts[session_id]
|
| 123 |
return f"Erreur lors du traitement de la décision: {str(e)}"
|
| 124 |
|
| 125 |
-
def
|
| 126 |
-
"""
|
| 127 |
-
|
| 128 |
-
if isinstance(state, MultiCountryLegalState):
|
| 129 |
-
state_dict = state.model_dump()
|
| 130 |
-
elif isinstance(state, dict):
|
| 131 |
-
state_dict = state
|
| 132 |
-
else:
|
| 133 |
-
state_dict = {}
|
| 134 |
-
|
| 135 |
-
user_email = state_dict.get("user_email", "Non spécifié")
|
| 136 |
-
country = state_dict.get("legal_context", {}).get("detected_country", "Non spécifié")
|
| 137 |
-
description = state_dict.get("assistance_description", "Non spécifié")
|
| 138 |
-
|
| 139 |
-
return f"""
|
| 140 |
🔒 **APPROBATION HUMAINE REQUISE**
|
| 141 |
|
| 142 |
-
|
| 143 |
-
🌍 **Pays**: {country}
|
| 144 |
-
📝 **Description**: {description}
|
| 145 |
|
| 146 |
**Veuillez répondre avec:**
|
| 147 |
- "approve [raison]" pour approuver la demande
|
|
@@ -154,10 +271,32 @@ class LegalChatManager:
|
|
| 154 |
**Votre décision:**
|
| 155 |
"""
|
| 156 |
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
async def get_conversation_history(self, session_id: str) -> List[BaseMessage]:
|
| 160 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
if not self.graph:
|
| 162 |
return []
|
| 163 |
|
|
@@ -183,18 +322,27 @@ class LegalChatManager:
|
|
| 183 |
logger.exception(f"Error getting conversation history for session {session_id}")
|
| 184 |
return []
|
| 185 |
|
| 186 |
-
def get_session_stats(self, session_id: str) -> Dict:
|
| 187 |
"""Get statistics for a specific session"""
|
| 188 |
return self.active_sessions.get(session_id, {})
|
| 189 |
|
| 190 |
-
def get_global_stats(self) -> Dict:
|
| 191 |
-
"""
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
"routing_stats": self.routing_stats,
|
| 194 |
"active_sessions": len(self.active_sessions),
|
| 195 |
"total_queries": self.routing_stats["total_queries"],
|
| 196 |
"pending_interrupts": len(self.pending_interrupts)
|
| 197 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
def _initialize_session(self, session_id: str):
|
| 200 |
"""Initialize or update session tracking"""
|
|
@@ -212,9 +360,23 @@ class LegalChatManager:
|
|
| 212 |
session_info["query_count"] += 1
|
| 213 |
session_info["last_activity"] = datetime.now()
|
| 214 |
|
| 215 |
-
def _prepare_input_state(
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
ctx = legal_context or {
|
| 219 |
"jurisdiction": "Unknown",
|
| 220 |
"user_type": "general",
|
|
@@ -237,7 +399,15 @@ class LegalChatManager:
|
|
| 237 |
}
|
| 238 |
|
| 239 |
def _extract_response(self, result) -> str:
|
| 240 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
if isinstance(result, MultiCountryLegalState):
|
| 242 |
r = result.model_dump()
|
| 243 |
elif isinstance(result, dict):
|
|
@@ -246,6 +416,7 @@ class LegalChatManager:
|
|
| 246 |
r = {}
|
| 247 |
|
| 248 |
msgs = r.get("messages", [])
|
|
|
|
| 249 |
for m in reversed(msgs):
|
| 250 |
if (m.get("role") or "").lower() in ("assistant", "ai"):
|
| 251 |
return m.get("content", "")
|
|
@@ -278,7 +449,12 @@ class LegalChatManager:
|
|
| 278 |
logger.error(f"Session {session_id}: {error}")
|
| 279 |
|
| 280 |
def cleanup_inactive_sessions(self, max_age_hours: int = 24):
|
| 281 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
|
| 283 |
|
| 284 |
inactive_sessions = [
|
|
@@ -291,4 +467,35 @@ class LegalChatManager:
|
|
| 291 |
if session_id in self.pending_interrupts:
|
| 292 |
del self.pending_interrupts[session_id]
|
| 293 |
del self.active_sessions[session_id]
|
| 294 |
-
logger.info(f"Cleaned up inactive session: {session_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# [file name]: core/chat_manager.py
|
|
|
|
| 2 |
import sys
|
| 3 |
from pathlib import Path
|
| 4 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
| 6 |
import asyncio
|
| 7 |
import logging
|
| 8 |
from datetime import datetime
|
| 9 |
+
from typing import Dict, List, Optional, Any, Callable
|
| 10 |
from langchain_core.runnables import RunnableConfig
|
| 11 |
from langchain_core.messages import BaseMessage
|
| 12 |
from langgraph.types import Command
|
| 13 |
+
from langgraph.errors import GraphInterrupt
|
| 14 |
|
| 15 |
from config.settings import settings
|
| 16 |
from models.state_models import MultiCountryLegalState
|
|
|
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
class LegalChatManager:
|
| 22 |
+
"""
|
| 23 |
+
Chat manager with full human-in-the-loop interrupt support.
|
| 24 |
+
Handles both synchronous (callback-based) and asynchronous (stored interrupt) modes.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
def __init__(self, graph, checkpointer):
|
| 28 |
self.graph = graph
|
| 29 |
self.checkpointer = checkpointer
|
|
|
|
| 37 |
# Track pending interrupts by session
|
| 38 |
self.pending_interrupts = {}
|
| 39 |
|
| 40 |
+
async def chat(
|
| 41 |
+
self,
|
| 42 |
+
message: str,
|
| 43 |
+
session_id: str,
|
| 44 |
+
legal_context: Optional[Dict[str, str]] = None,
|
| 45 |
+
interrupt_handler: Optional[Callable[[Dict], str]] = None
|
| 46 |
+
) -> str:
|
| 47 |
+
"""
|
| 48 |
+
Process a chat message with session management and interrupt handling.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
message: User message to process
|
| 52 |
+
session_id: Unique session identifier for conversation tracking
|
| 53 |
+
legal_context: Optional legal context (jurisdiction, user type, etc.)
|
| 54 |
+
interrupt_handler: Optional callback for handling interrupts synchronously.
|
| 55 |
+
If provided, interrupts are handled immediately within this call.
|
| 56 |
+
If None, interrupts are stored for later resolution via subsequent calls.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
Assistant's response text
|
| 60 |
+
|
| 61 |
+
Raises:
|
| 62 |
+
RuntimeError: If system is not initialized
|
| 63 |
+
"""
|
| 64 |
if not self.graph:
|
| 65 |
raise RuntimeError("System not initialized. Call setup_system() first.")
|
| 66 |
|
|
|
|
| 82 |
# Track performance
|
| 83 |
start_time = datetime.now()
|
| 84 |
|
| 85 |
+
# 🔥 CRITICAL: Use streaming to detect interrupts in real-time
|
| 86 |
+
final_response = None
|
| 87 |
|
| 88 |
+
async for chunk in self.graph.astream(
|
| 89 |
+
MultiCountryLegalState(**input_state),
|
| 90 |
+
config,
|
| 91 |
+
stream_mode="updates"
|
| 92 |
+
):
|
| 93 |
+
# 🔥 Check if this chunk contains an interrupt
|
| 94 |
+
if "__interrupt__" in chunk:
|
| 95 |
+
interrupt_data = chunk["__interrupt__"]
|
| 96 |
+
logger.info(f"⏸️ Graph interrupted: {interrupt_data}")
|
| 97 |
+
|
| 98 |
+
# Extract interrupt info - handle tuple/Interrupt object format
|
| 99 |
+
# LangGraph returns interrupts as tuples containing Interrupt objects
|
| 100 |
+
if isinstance(interrupt_data, (list, tuple)):
|
| 101 |
+
interrupt_info = interrupt_data[0]
|
| 102 |
+
else:
|
| 103 |
+
interrupt_info = interrupt_data
|
| 104 |
+
|
| 105 |
+
# Handle Interrupt object (has .value attribute)
|
| 106 |
+
if hasattr(interrupt_info, 'value'):
|
| 107 |
+
interrupt_value = interrupt_info.value
|
| 108 |
+
elif isinstance(interrupt_info, dict):
|
| 109 |
+
interrupt_value = interrupt_info.get("value", interrupt_info)
|
| 110 |
+
else:
|
| 111 |
+
interrupt_value = {}
|
| 112 |
+
|
| 113 |
+
# Extract message from interrupt value
|
| 114 |
+
interrupt_message = interrupt_value.get("message", "") if isinstance(interrupt_value, dict) else ""
|
| 115 |
+
|
| 116 |
+
# 🔥 Two modes of operation:
|
| 117 |
+
# 1. Synchronous: If interrupt_handler provided, handle immediately
|
| 118 |
+
# 2. Asynchronous: Store interrupt and return, wait for next call
|
| 119 |
+
|
| 120 |
+
if interrupt_handler:
|
| 121 |
+
# SYNCHRONOUS MODE: Handle interrupt immediately
|
| 122 |
+
logger.info("📞 Calling synchronous interrupt handler")
|
| 123 |
+
moderator_response = interrupt_handler(interrupt_value)
|
| 124 |
+
|
| 125 |
+
# Resume immediately with the moderator's response
|
| 126 |
+
logger.info(f"🔄 Resuming graph with: {moderator_response}")
|
| 127 |
+
async for resume_chunk in self.graph.astream(
|
| 128 |
+
Command(resume=moderator_response),
|
| 129 |
+
config,
|
| 130 |
+
stream_mode="updates"
|
| 131 |
+
):
|
| 132 |
+
# Continue processing resumed chunks
|
| 133 |
+
for node_name, node_output in resume_chunk.items():
|
| 134 |
+
if node_name != "__interrupt__":
|
| 135 |
+
logger.debug(f"📦 Resume chunk from {node_name}")
|
| 136 |
+
|
| 137 |
+
# After resume completes, get final state
|
| 138 |
+
state = await self.graph.aget_state(config)
|
| 139 |
+
final_response = self._extract_response(state.values)
|
| 140 |
+
break
|
| 141 |
+
else:
|
| 142 |
+
# ASYNCHRONOUS MODE: Store interrupt and return
|
| 143 |
+
logger.info("💾 Storing interrupt for later resolution")
|
| 144 |
+
self.pending_interrupts[session_id] = {
|
| 145 |
+
"type": "human_approval",
|
| 146 |
+
"config": config,
|
| 147 |
+
"created_at": datetime.now(),
|
| 148 |
+
"interrupt_data": interrupt_info
|
| 149 |
+
}
|
| 150 |
+
return interrupt_message or self._get_default_approval_prompt()
|
| 151 |
+
|
| 152 |
+
# Process normal chunks (non-interrupt)
|
| 153 |
+
for node_name, node_output in chunk.items():
|
| 154 |
+
if node_name != "__interrupt__":
|
| 155 |
+
logger.debug(f"📦 Chunk from {node_name}")
|
| 156 |
+
|
| 157 |
+
# If no interrupt occurred, get final state
|
| 158 |
+
if final_response is None:
|
| 159 |
+
state = await self.graph.aget_state(config)
|
| 160 |
+
final_response = self._extract_response(state.values)
|
| 161 |
|
| 162 |
# Track performance
|
| 163 |
processing_time = (datetime.now() - start_time).total_seconds()
|
| 164 |
self._update_session_stats(session_id, processing_time)
|
| 165 |
+
self._update_routing_stats(final_response)
|
| 166 |
|
| 167 |
+
return final_response
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
except Exception as e:
|
| 170 |
logger.exception(f"Chat error for session {session_id}")
|
|
|
|
| 172 |
return f"Erreur lors du traitement: {str(e)}"
|
| 173 |
|
| 174 |
async def _handle_pending_interrupt(self, session_id: str, message: str) -> str:
|
| 175 |
+
"""
|
| 176 |
+
Handle user response to a pending interrupt using Command(resume=...).
|
| 177 |
+
This is called when there's a stored interrupt waiting for resolution.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
session_id: Session with pending interrupt
|
| 181 |
+
message: User's response (e.g., "approve" or "reject")
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
Final response after resuming from interrupt
|
| 185 |
+
"""
|
| 186 |
interrupt_data = self.pending_interrupts.get(session_id)
|
| 187 |
if not interrupt_data:
|
| 188 |
return "Erreur: Aucune interruption en attente."
|
| 189 |
|
| 190 |
try:
|
| 191 |
+
logger.info(f"🔥 Resuming graph with moderator decision: {message}")
|
| 192 |
|
| 193 |
config = interrupt_data["config"]
|
| 194 |
|
| 195 |
+
# Use streaming to handle potential nested interrupts
|
| 196 |
+
final_response = None
|
| 197 |
+
async for chunk in self.graph.astream(
|
| 198 |
Command(resume=message),
|
| 199 |
+
config,
|
| 200 |
+
stream_mode="updates"
|
| 201 |
+
):
|
| 202 |
+
if "__interrupt__" in chunk:
|
| 203 |
+
# Another interrupt occurred during resume - store it
|
| 204 |
+
new_interrupt = chunk["__interrupt__"]
|
| 205 |
+
|
| 206 |
+
# Handle tuple/list format
|
| 207 |
+
if isinstance(new_interrupt, (list, tuple)):
|
| 208 |
+
interrupt_info = new_interrupt[0]
|
| 209 |
+
else:
|
| 210 |
+
interrupt_info = new_interrupt
|
| 211 |
+
|
| 212 |
+
# Extract value from Interrupt object
|
| 213 |
+
if hasattr(interrupt_info, 'value'):
|
| 214 |
+
interrupt_value = interrupt_info.value
|
| 215 |
+
elif isinstance(interrupt_info, dict):
|
| 216 |
+
interrupt_value = interrupt_info.get("value", interrupt_info)
|
| 217 |
+
else:
|
| 218 |
+
interrupt_value = {}
|
| 219 |
+
|
| 220 |
+
# Store the new interrupt
|
| 221 |
+
self.pending_interrupts[session_id] = {
|
| 222 |
+
"type": "human_approval",
|
| 223 |
+
"config": config,
|
| 224 |
+
"created_at": datetime.now(),
|
| 225 |
+
"interrupt_data": interrupt_info
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
interrupt_message = interrupt_value.get("message", "") if isinstance(interrupt_value, dict) else ""
|
| 229 |
+
return interrupt_message or self._get_default_approval_prompt()
|
| 230 |
+
|
| 231 |
+
# Process normal chunks
|
| 232 |
+
for node_name, node_output in chunk.items():
|
| 233 |
+
if node_name != "__interrupt__":
|
| 234 |
+
logger.debug(f"📦 Resume chunk from {node_name}")
|
| 235 |
+
|
| 236 |
+
# Get final state after successful resume
|
| 237 |
+
state = await self.graph.aget_state(config)
|
| 238 |
+
final_response = self._extract_response(state.values)
|
| 239 |
|
| 240 |
# Clean up the pending interrupt
|
| 241 |
del self.pending_interrupts[session_id]
|
| 242 |
|
| 243 |
+
# Update stats
|
| 244 |
+
self._update_routing_stats(final_response)
|
|
|
|
| 245 |
|
| 246 |
logger.info(f"✅ Graph resumed successfully for session {session_id}")
|
| 247 |
+
return final_response
|
| 248 |
|
| 249 |
except Exception as e:
|
| 250 |
logger.error(f"Error resuming from interrupt: {str(e)}")
|
|
|
|
| 253 |
del self.pending_interrupts[session_id]
|
| 254 |
return f"Erreur lors du traitement de la décision: {str(e)}"
|
| 255 |
|
| 256 |
+
def _get_default_approval_prompt(self) -> str:
|
| 257 |
+
"""Default approval prompt if interrupt message extraction fails"""
|
| 258 |
+
return """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
🔒 **APPROBATION HUMAINE REQUISE**
|
| 260 |
|
| 261 |
+
Une demande d'assistance juridique nécessite votre approbation.
|
|
|
|
|
|
|
| 262 |
|
| 263 |
**Veuillez répondre avec:**
|
| 264 |
- "approve [raison]" pour approuver la demande
|
|
|
|
| 271 |
**Votre décision:**
|
| 272 |
"""
|
| 273 |
|
| 274 |
+
def get_checkpointer_info(self) -> Dict[str, Any]:
|
| 275 |
+
"""Get information about the current checkpointer type"""
|
| 276 |
+
checkpointer_type = "unknown"
|
| 277 |
+
if hasattr(self.checkpointer, '__class__'):
|
| 278 |
+
class_name = self.checkpointer.__class__.__name__
|
| 279 |
+
if 'PostgresSaver' in class_name:
|
| 280 |
+
checkpointer_type = "postgres"
|
| 281 |
+
elif 'InMemorySaver' in class_name:
|
| 282 |
+
checkpointer_type = "memory"
|
| 283 |
+
|
| 284 |
+
return {
|
| 285 |
+
"type": checkpointer_type,
|
| 286 |
+
"persistent": checkpointer_type == "postgres",
|
| 287 |
+
"description": "Persistent storage" if checkpointer_type == "postgres" else "In-memory (volatile)"
|
| 288 |
+
}
|
| 289 |
|
| 290 |
async def get_conversation_history(self, session_id: str) -> List[BaseMessage]:
|
| 291 |
+
"""
|
| 292 |
+
Get conversation history for a session.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
session_id: Session identifier
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
List of message objects from conversation history
|
| 299 |
+
"""
|
| 300 |
if not self.graph:
|
| 301 |
return []
|
| 302 |
|
|
|
|
| 322 |
logger.exception(f"Error getting conversation history for session {session_id}")
|
| 323 |
return []
|
| 324 |
|
| 325 |
+
def get_session_stats(self, session_id: str) -> Dict[str, Any]:
|
| 326 |
"""Get statistics for a specific session"""
|
| 327 |
return self.active_sessions.get(session_id, {})
|
| 328 |
|
| 329 |
+
def get_global_stats(self) -> Dict[str, Any]:
|
| 330 |
+
"""
|
| 331 |
+
Get global system statistics.
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
Dictionary with routing stats, active sessions, and storage info
|
| 335 |
+
"""
|
| 336 |
+
stats = {
|
| 337 |
"routing_stats": self.routing_stats,
|
| 338 |
"active_sessions": len(self.active_sessions),
|
| 339 |
"total_queries": self.routing_stats["total_queries"],
|
| 340 |
"pending_interrupts": len(self.pending_interrupts)
|
| 341 |
}
|
| 342 |
+
|
| 343 |
+
# Add checkpointer info
|
| 344 |
+
stats.update(self.get_checkpointer_info())
|
| 345 |
+
return stats
|
| 346 |
|
| 347 |
def _initialize_session(self, session_id: str):
|
| 348 |
"""Initialize or update session tracking"""
|
|
|
|
| 360 |
session_info["query_count"] += 1
|
| 361 |
session_info["last_activity"] = datetime.now()
|
| 362 |
|
| 363 |
+
def _prepare_input_state(
|
| 364 |
+
self,
|
| 365 |
+
message: str,
|
| 366 |
+
session_id: str,
|
| 367 |
+
legal_context: Optional[Dict[str, str]]
|
| 368 |
+
) -> Dict[str, Any]:
|
| 369 |
+
"""
|
| 370 |
+
Prepare input state for graph processing.
|
| 371 |
+
|
| 372 |
+
Args:
|
| 373 |
+
message: User message
|
| 374 |
+
session_id: Session identifier
|
| 375 |
+
legal_context: Optional legal context
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
Dictionary with complete input state for graph
|
| 379 |
+
"""
|
| 380 |
ctx = legal_context or {
|
| 381 |
"jurisdiction": "Unknown",
|
| 382 |
"user_type": "general",
|
|
|
|
| 399 |
}
|
| 400 |
|
| 401 |
def _extract_response(self, result) -> str:
|
| 402 |
+
"""
|
| 403 |
+
Extract response text from graph result.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
result: Graph execution result (state or dict)
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
Assistant's response text
|
| 410 |
+
"""
|
| 411 |
if isinstance(result, MultiCountryLegalState):
|
| 412 |
r = result.model_dump()
|
| 413 |
elif isinstance(result, dict):
|
|
|
|
| 416 |
r = {}
|
| 417 |
|
| 418 |
msgs = r.get("messages", [])
|
| 419 |
+
# Find the last assistant message
|
| 420 |
for m in reversed(msgs):
|
| 421 |
if (m.get("role") or "").lower() in ("assistant", "ai"):
|
| 422 |
return m.get("content", "")
|
|
|
|
| 449 |
logger.error(f"Session {session_id}: {error}")
|
| 450 |
|
| 451 |
def cleanup_inactive_sessions(self, max_age_hours: int = 24):
|
| 452 |
+
"""
|
| 453 |
+
Clean up sessions that have been inactive for too long.
|
| 454 |
+
|
| 455 |
+
Args:
|
| 456 |
+
max_age_hours: Maximum age in hours before cleanup
|
| 457 |
+
"""
|
| 458 |
cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
|
| 459 |
|
| 460 |
inactive_sessions = [
|
|
|
|
| 467 |
if session_id in self.pending_interrupts:
|
| 468 |
del self.pending_interrupts[session_id]
|
| 469 |
del self.active_sessions[session_id]
|
| 470 |
+
logger.info(f"Cleaned up inactive session: {session_id}")
|
| 471 |
+
|
| 472 |
+
def has_pending_interrupt(self, session_id: str) -> bool:
|
| 473 |
+
"""
|
| 474 |
+
Check if there's a pending interrupt for a session.
|
| 475 |
+
|
| 476 |
+
Args:
|
| 477 |
+
session_id: Session identifier
|
| 478 |
+
|
| 479 |
+
Returns:
|
| 480 |
+
True if session has pending interrupt, False otherwise
|
| 481 |
+
"""
|
| 482 |
+
return session_id in self.pending_interrupts
|
| 483 |
+
|
| 484 |
+
def get_system_info(self) -> Dict[str, Any]:
|
| 485 |
+
"""
|
| 486 |
+
Get comprehensive system information.
|
| 487 |
+
|
| 488 |
+
Returns:
|
| 489 |
+
Dictionary with system status, storage info, and statistics
|
| 490 |
+
"""
|
| 491 |
+
return {
|
| 492 |
+
"system": {
|
| 493 |
+
"initialized": self.graph is not None,
|
| 494 |
+
"active_sessions": len(self.active_sessions),
|
| 495 |
+
"pending_interrupts": len(self.pending_interrupts),
|
| 496 |
+
"total_queries": self.routing_stats["total_queries"]
|
| 497 |
+
},
|
| 498 |
+
"storage": self.get_checkpointer_info(),
|
| 499 |
+
"routing": self.routing_stats,
|
| 500 |
+
"timestamp": datetime.now().isoformat()
|
| 501 |
+
}
|
core/graph_builder.py
CHANGED
|
@@ -6,8 +6,9 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
| 6 |
|
| 7 |
from langgraph.graph import StateGraph, START, END
|
| 8 |
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
|
|
| 9 |
import logging
|
| 10 |
-
from typing import Dict, List, Any
|
| 11 |
from langchain_core.runnables import RunnableConfig
|
| 12 |
|
| 13 |
from models.state_models import MultiCountryLegalState
|
|
@@ -31,7 +32,7 @@ class GraphBuilder:
|
|
| 31 |
self,
|
| 32 |
router: CountryRouter,
|
| 33 |
llm,
|
| 34 |
-
checkpointer: AsyncPostgresSaver,
|
| 35 |
# Country retrievers as a dictionary for easy extension
|
| 36 |
country_retrievers: Dict[str, LegalRetriever] = None
|
| 37 |
):
|
|
@@ -57,7 +58,19 @@ class GraphBuilder:
|
|
| 57 |
self.response_nodes = ResponseNodes(llm)
|
| 58 |
self.helper_nodes = HelperNodes(llm)
|
| 59 |
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def add_country(self, country_code: str, retriever: LegalRetriever):
|
| 63 |
"""Dynamically add a new country to the system"""
|
|
@@ -166,7 +179,8 @@ class GraphBuilder:
|
|
| 166 |
|
| 167 |
workflow.add_edge("process_assistance", "response")
|
| 168 |
|
| 169 |
-
|
|
|
|
| 170 |
return workflow
|
| 171 |
|
| 172 |
def _create_assistance_collect_wrapper(self):
|
|
|
|
| 6 |
|
| 7 |
from langgraph.graph import StateGraph, START, END
|
| 8 |
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
| 9 |
+
from langgraph.checkpoint.memory import InMemorySaver
|
| 10 |
import logging
|
| 11 |
+
from typing import Dict, List, Any, Union
|
| 12 |
from langchain_core.runnables import RunnableConfig
|
| 13 |
|
| 14 |
from models.state_models import MultiCountryLegalState
|
|
|
|
| 32 |
self,
|
| 33 |
router: CountryRouter,
|
| 34 |
llm,
|
| 35 |
+
checkpointer: Union[AsyncPostgresSaver, InMemorySaver],
|
| 36 |
# Country retrievers as a dictionary for easy extension
|
| 37 |
country_retrievers: Dict[str, LegalRetriever] = None
|
| 38 |
):
|
|
|
|
| 58 |
self.response_nodes = ResponseNodes(llm)
|
| 59 |
self.helper_nodes = HelperNodes(llm)
|
| 60 |
|
| 61 |
+
# Log checkpointer type
|
| 62 |
+
checkpointer_type = self._get_checkpointer_type()
|
| 63 |
+
logger.info(f"GraphBuilder initialized with {checkpointer_type} checkpointer and countries: {list(self.country_retrievers.keys())}")
|
| 64 |
+
|
| 65 |
+
def _get_checkpointer_type(self) -> str:
|
| 66 |
+
"""Determine the type of checkpointer being used"""
|
| 67 |
+
if hasattr(self.checkpointer, '__class__'):
|
| 68 |
+
class_name = self.checkpointer.__class__.__name__
|
| 69 |
+
if 'PostgresSaver' in class_name:
|
| 70 |
+
return "PostgreSQL"
|
| 71 |
+
elif 'InMemorySaver' in class_name:
|
| 72 |
+
return "in-memory"
|
| 73 |
+
return "unknown"
|
| 74 |
|
| 75 |
def add_country(self, country_code: str, retriever: LegalRetriever):
|
| 76 |
"""Dynamically add a new country to the system"""
|
|
|
|
| 179 |
|
| 180 |
workflow.add_edge("process_assistance", "response")
|
| 181 |
|
| 182 |
+
checkpointer_type = self._get_checkpointer_type()
|
| 183 |
+
logger.info(f"Scalable graph built with {checkpointer_type} checkpointer for {len(self.country_retrievers)} countries: {list(self.country_retrievers.keys())}")
|
| 184 |
return workflow
|
| 185 |
|
| 186 |
def _create_assistance_collect_wrapper(self):
|
core/human_approval_node.py
CHANGED
|
@@ -1,9 +1,4 @@
|
|
| 1 |
# core/human_approval_node.py
|
| 2 |
-
# Add this as the FIRST lines of code (after docstrings)
|
| 3 |
-
import sys
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 6 |
-
|
| 7 |
import logging
|
| 8 |
from typing import Literal
|
| 9 |
from langchain_core.runnables import RunnableConfig
|
|
@@ -22,68 +17,59 @@ class HumanApprovalNode:
|
|
| 22 |
self,
|
| 23 |
state: MultiCountryLegalState,
|
| 24 |
config: RunnableConfig
|
| 25 |
-
) -> Command[Literal["response"]]:
|
| 26 |
-
"""Process human approval with interrupt"""
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
return Command(
|
| 32 |
-
goto="response",
|
| 33 |
-
update={
|
| 34 |
-
"messages": [{
|
| 35 |
-
"role": "assistant",
|
| 36 |
-
"content": "❌ Données incomplètes pour l'approbation.",
|
| 37 |
-
"meta": {}
|
| 38 |
-
}]
|
| 39 |
-
}
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
logger.info(f"🔒 Human approval node triggered for {state.user_email}")
|
| 43 |
-
|
| 44 |
-
# Prepare interrupt message
|
| 45 |
-
interrupt_message = self._format_approval_request(state)
|
| 46 |
-
|
| 47 |
-
# Trigger interrupt and wait for human input
|
| 48 |
-
moderator_input = interrupt({
|
| 49 |
-
"type": "human_approval",
|
| 50 |
-
"user_email": state.user_email,
|
| 51 |
-
"country": self._get_country_display(state),
|
| 52 |
-
"description": state.assistance_description,
|
| 53 |
-
"message": interrupt_message
|
| 54 |
-
})
|
| 55 |
-
|
| 56 |
-
logger.info(f"📥 Received moderator input: {moderator_input}")
|
| 57 |
-
|
| 58 |
-
# Parse moderator decision
|
| 59 |
-
decision = self._parse_decision(moderator_input)
|
| 60 |
-
|
| 61 |
-
# Handle approval
|
| 62 |
-
if decision["approved"]:
|
| 63 |
-
return await self._handle_approval(state, decision)
|
| 64 |
-
else:
|
| 65 |
-
return await self._handle_rejection(state, decision)
|
| 66 |
-
|
| 67 |
-
except Exception as e:
|
| 68 |
-
logger.error(f"Error in approval node: {str(e)}", exc_info=True)
|
| 69 |
return Command(
|
| 70 |
goto="response",
|
| 71 |
update={
|
| 72 |
"approval_status": "error",
|
|
|
|
| 73 |
"messages": [{
|
| 74 |
"role": "assistant",
|
| 75 |
-
"content":
|
| 76 |
"meta": {}
|
| 77 |
}]
|
| 78 |
}
|
| 79 |
)
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
async def _handle_approval(
|
| 82 |
self,
|
| 83 |
state: MultiCountryLegalState,
|
| 84 |
decision: dict
|
| 85 |
-
) -> Command[Literal["
|
| 86 |
-
"""Handle approved request
|
| 87 |
logger.info(f"✅ Request APPROVED for {state.user_email}")
|
| 88 |
|
| 89 |
# Send email
|
|
@@ -95,38 +81,15 @@ class HumanApprovalNode:
|
|
| 95 |
)
|
| 96 |
logger.info(f"✅ Emails envoyés avec succès pour {state.user_email}")
|
| 97 |
|
| 98 |
-
# Build success message
|
| 99 |
-
if email_result.get("success"):
|
| 100 |
-
message_content = f"""✅ **DEMANDE APPROUVÉE ET ENVOYÉE**
|
| 101 |
-
|
| 102 |
-
📧 Un email de confirmation a été envoyé à: {state.user_email}
|
| 103 |
-
👨⚖️ Notre équipe juridique vous contactera sous 24-48 heures.
|
| 104 |
-
|
| 105 |
-
**Raison de l'approbation:** {decision['reason']}
|
| 106 |
-
**Approuvé par:** {decision['moderator_id']}
|
| 107 |
-
"""
|
| 108 |
-
else:
|
| 109 |
-
message_content = f"""⚠️ **DEMANDE APPROUVÉE MAIS ERREUR D'ENVOI**
|
| 110 |
-
|
| 111 |
-
La demande a été approuvée mais l'envoi d'email a échoué.
|
| 112 |
-
**Erreur:** {email_result.get('error', 'Unknown')}
|
| 113 |
-
|
| 114 |
-
Veuillez contacter directement: fitahiana@acfai.org
|
| 115 |
-
"""
|
| 116 |
-
|
| 117 |
return Command(
|
| 118 |
-
goto="
|
| 119 |
update={
|
| 120 |
"approval_status": "approved",
|
| 121 |
"approval_reason": decision["reason"],
|
| 122 |
"approved_by": decision["moderator_id"],
|
| 123 |
"approval_timestamp": datetime.now().isoformat(),
|
| 124 |
"email_status": "sent" if email_result.get("success") else "error",
|
| 125 |
-
"
|
| 126 |
-
"role": "assistant",
|
| 127 |
-
"content": message_content,
|
| 128 |
-
"meta": {"approval": "approved"}
|
| 129 |
-
}]
|
| 130 |
}
|
| 131 |
)
|
| 132 |
|
|
@@ -134,8 +97,8 @@ Veuillez contacter directement: fitahiana@acfai.org
|
|
| 134 |
self,
|
| 135 |
state: MultiCountryLegalState,
|
| 136 |
decision: dict
|
| 137 |
-
) -> Command[Literal["response"]]:
|
| 138 |
-
"""Handle rejected request"""
|
| 139 |
logger.info(f"❌ Request REJECTED for {state.user_email}")
|
| 140 |
|
| 141 |
message_content = f"""❌ **DEMANDE REFUSÉE**
|
|
@@ -154,6 +117,7 @@ Si vous pensez qu'il s'agit d'une erreur, veuillez reformuler votre demande avec
|
|
| 154 |
"approval_reason": decision["reason"],
|
| 155 |
"approved_by": decision["moderator_id"],
|
| 156 |
"approval_timestamp": datetime.now().isoformat(),
|
|
|
|
| 157 |
"messages": [{
|
| 158 |
"role": "assistant",
|
| 159 |
"content": message_content,
|
|
|
|
| 1 |
# core/human_approval_node.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import logging
|
| 3 |
from typing import Literal
|
| 4 |
from langchain_core.runnables import RunnableConfig
|
|
|
|
| 17 |
self,
|
| 18 |
state: MultiCountryLegalState,
|
| 19 |
config: RunnableConfig
|
| 20 |
+
) -> Command[Literal["process_assistance", "response"]]:
|
| 21 |
+
"""Process human approval with interrupt - uses Command(goto=...) pattern"""
|
| 22 |
+
|
| 23 |
+
# Validate required fields BEFORE interrupt
|
| 24 |
+
if not state.user_email or not state.assistance_description:
|
| 25 |
+
logger.warning("Missing required fields for approval")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
return Command(
|
| 27 |
goto="response",
|
| 28 |
update={
|
| 29 |
"approval_status": "error",
|
| 30 |
+
"approval_reason": "Données incomplètes",
|
| 31 |
"messages": [{
|
| 32 |
"role": "assistant",
|
| 33 |
+
"content": "❌ Données incomplètes pour l'approbation.",
|
| 34 |
"meta": {}
|
| 35 |
}]
|
| 36 |
}
|
| 37 |
)
|
| 38 |
+
|
| 39 |
+
logger.info(f"🔒 Human approval node triggered for {state.user_email}")
|
| 40 |
+
|
| 41 |
+
# Prepare interrupt message
|
| 42 |
+
interrupt_message = self._format_approval_request(state)
|
| 43 |
+
|
| 44 |
+
# 🔥 CRITICAL: DO NOT wrap interrupt() in try-except!
|
| 45 |
+
# Let GraphInterrupt propagate naturally to the graph executor
|
| 46 |
+
moderator_input = interrupt({
|
| 47 |
+
"type": "human_approval",
|
| 48 |
+
"user_email": state.user_email,
|
| 49 |
+
"country": self._get_country_display(state),
|
| 50 |
+
"description": state.assistance_description,
|
| 51 |
+
"message": interrupt_message
|
| 52 |
+
})
|
| 53 |
+
|
| 54 |
+
# 🎯 Code below ONLY executes after graph resumes from interrupt
|
| 55 |
+
logger.info(f"🔥 Received moderator input: {moderator_input}")
|
| 56 |
+
|
| 57 |
+
# Parse moderator decision
|
| 58 |
+
decision = self._parse_decision(moderator_input)
|
| 59 |
+
|
| 60 |
+
# Handle approval or rejection
|
| 61 |
+
if decision["approved"]:
|
| 62 |
+
return await self._handle_approval(state, decision)
|
| 63 |
+
else:
|
| 64 |
+
return await self._handle_rejection(state, decision)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
async def _handle_approval(
|
| 68 |
self,
|
| 69 |
state: MultiCountryLegalState,
|
| 70 |
decision: dict
|
| 71 |
+
) -> Command[Literal["process_assistance"]]:
|
| 72 |
+
"""Handle approved request - routes to process_assistance"""
|
| 73 |
logger.info(f"✅ Request APPROVED for {state.user_email}")
|
| 74 |
|
| 75 |
# Send email
|
|
|
|
| 81 |
)
|
| 82 |
logger.info(f"✅ Emails envoyés avec succès pour {state.user_email}")
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
return Command(
|
| 85 |
+
goto="process_assistance",
|
| 86 |
update={
|
| 87 |
"approval_status": "approved",
|
| 88 |
"approval_reason": decision["reason"],
|
| 89 |
"approved_by": decision["moderator_id"],
|
| 90 |
"approval_timestamp": datetime.now().isoformat(),
|
| 91 |
"email_status": "sent" if email_result.get("success") else "error",
|
| 92 |
+
"email_result": email_result
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
}
|
| 94 |
)
|
| 95 |
|
|
|
|
| 97 |
self,
|
| 98 |
state: MultiCountryLegalState,
|
| 99 |
decision: dict
|
| 100 |
+
) -> Command[Literal["response"]]:
|
| 101 |
+
"""Handle rejected request - routes directly to response"""
|
| 102 |
logger.info(f"❌ Request REJECTED for {state.user_email}")
|
| 103 |
|
| 104 |
message_content = f"""❌ **DEMANDE REFUSÉE**
|
|
|
|
| 117 |
"approval_reason": decision["reason"],
|
| 118 |
"approved_by": decision["moderator_id"],
|
| 119 |
"approval_timestamp": datetime.now().isoformat(),
|
| 120 |
+
"assistance_step": "completed",
|
| 121 |
"messages": [{
|
| 122 |
"role": "assistant",
|
| 123 |
"content": message_content,
|
core/system_initializer.py
CHANGED
|
@@ -6,6 +6,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
| 6 |
|
| 7 |
import logging
|
| 8 |
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
|
|
| 9 |
|
| 10 |
from core.graph_builder import GraphBuilder
|
| 11 |
from core.chat_manager import LegalChatManager
|
|
@@ -13,13 +14,12 @@ from core.router import CountryRouter
|
|
| 13 |
from database.mongodb_client import MongoDBClient
|
| 14 |
from database.postgres_checkpointer import PostgresCheckpointer
|
| 15 |
from langchain_openai import ChatOpenAI
|
| 16 |
-
from config import settings
|
| 17 |
-
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
async def setup_system():
|
| 22 |
-
"""Initialize the legal assistant system
|
| 23 |
|
| 24 |
try:
|
| 25 |
# 1. Initialize MongoDB using your existing class
|
|
@@ -55,32 +55,8 @@ async def setup_system():
|
|
| 55 |
|
| 56 |
router = CountryRouter()
|
| 57 |
|
| 58 |
-
# 5. Initialize
|
| 59 |
-
|
| 60 |
-
database_url = getattr(settings, 'DATABASE_URL', None)
|
| 61 |
-
|
| 62 |
-
if not database_url:
|
| 63 |
-
# Try alternative setting names
|
| 64 |
-
database_url = getattr(settings, 'POSTGRES_URL', None) or \
|
| 65 |
-
getattr(settings, 'POSTGRESQL_URL', None) or \
|
| 66 |
-
getattr(settings, 'DB_URL', None)
|
| 67 |
-
|
| 68 |
-
if not database_url:
|
| 69 |
-
raise Exception("No database URL found in settings")
|
| 70 |
-
|
| 71 |
-
logger.info(f"🔗 Using database URL: {database_url.split('@')[-1] if '@' in database_url else 'local'}") # Log safely
|
| 72 |
-
|
| 73 |
-
postgres_checkpointer = PostgresCheckpointer(
|
| 74 |
-
database_url=database_url, # Use actual database URL
|
| 75 |
-
max_connections=10,
|
| 76 |
-
min_connections=2
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
if not await postgres_checkpointer.initialize():
|
| 80 |
-
raise Exception("PostgreSQL checkpointer initialization failed")
|
| 81 |
-
|
| 82 |
-
checkpointer = postgres_checkpointer.get_checkpointer()
|
| 83 |
-
logger.info("✅ PostgreSQL checkpointer initialized for API")
|
| 84 |
|
| 85 |
# 6. Build graph
|
| 86 |
graph_builder = GraphBuilder(
|
|
@@ -106,4 +82,60 @@ async def setup_system():
|
|
| 106 |
|
| 107 |
except Exception as e:
|
| 108 |
logger.error(f"❌ Failed to initialize system: {e}")
|
| 109 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import logging
|
| 8 |
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
| 9 |
+
from langgraph.checkpoint.memory import InMemorySaver
|
| 10 |
|
| 11 |
from core.graph_builder import GraphBuilder
|
| 12 |
from core.chat_manager import LegalChatManager
|
|
|
|
| 14 |
from database.mongodb_client import MongoDBClient
|
| 15 |
from database.postgres_checkpointer import PostgresCheckpointer
|
| 16 |
from langchain_openai import ChatOpenAI
|
| 17 |
+
from config import settings
|
|
|
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
async def setup_system():
|
| 22 |
+
"""Initialize the legal assistant system with fallback to in-memory checkpointer"""
|
| 23 |
|
| 24 |
try:
|
| 25 |
# 1. Initialize MongoDB using your existing class
|
|
|
|
| 55 |
|
| 56 |
router = CountryRouter()
|
| 57 |
|
| 58 |
+
# 5. Initialize checkpointer with fallback logic
|
| 59 |
+
checkpointer = await _initialize_checkpointer_with_fallback()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# 6. Build graph
|
| 62 |
graph_builder = GraphBuilder(
|
|
|
|
| 82 |
|
| 83 |
except Exception as e:
|
| 84 |
logger.error(f"❌ Failed to initialize system: {e}")
|
| 85 |
+
raise
|
| 86 |
+
|
| 87 |
+
async def _initialize_checkpointer_with_fallback():
|
| 88 |
+
"""Initialize checkpointer with fallback to in-memory if PostgreSQL fails"""
|
| 89 |
+
|
| 90 |
+
# First, try to initialize PostgreSQL checkpointer
|
| 91 |
+
postgres_checkpointer = None
|
| 92 |
+
database_url = getattr(settings, 'DATABASE_URL', None)
|
| 93 |
+
|
| 94 |
+
if not database_url:
|
| 95 |
+
# Try alternative setting names
|
| 96 |
+
database_url = getattr(settings, 'POSTGRES_URL', None) or \
|
| 97 |
+
getattr(settings, 'POSTGRESQL_URL', None) or \
|
| 98 |
+
getattr(settings, 'DB_URL', None)
|
| 99 |
+
|
| 100 |
+
if database_url:
|
| 101 |
+
try:
|
| 102 |
+
logger.info(f"🔗 Attempting PostgreSQL connection: {database_url.split('@')[-1] if '@' in database_url else 'local'}")
|
| 103 |
+
|
| 104 |
+
postgres_checkpointer = PostgresCheckpointer(
|
| 105 |
+
database_url=database_url,
|
| 106 |
+
max_connections=10,
|
| 107 |
+
min_connections=2
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
if await postgres_checkpointer.initialize():
|
| 111 |
+
checkpointer = postgres_checkpointer.get_checkpointer()
|
| 112 |
+
logger.info("✅ PostgreSQL checkpointer initialized successfully")
|
| 113 |
+
return checkpointer
|
| 114 |
+
else:
|
| 115 |
+
logger.warning("❌ PostgreSQL checkpointer initialization failed, will fall back to in-memory")
|
| 116 |
+
|
| 117 |
+
except Exception as e:
|
| 118 |
+
logger.warning(f"❌ PostgreSQL connection failed: {e}, falling back to in-memory checkpointer")
|
| 119 |
+
|
| 120 |
+
else:
|
| 121 |
+
logger.warning("❌ No database URL found in settings, using in-memory checkpointer")
|
| 122 |
+
|
| 123 |
+
# Fall back to in-memory checkpointer
|
| 124 |
+
try:
|
| 125 |
+
checkpointer = InMemorySaver()
|
| 126 |
+
logger.info("✅ In-memory checkpointer initialized as fallback")
|
| 127 |
+
logger.warning("⚠️ Using in-memory checkpointer - conversation history will not persist across restarts")
|
| 128 |
+
return checkpointer
|
| 129 |
+
|
| 130 |
+
except Exception as e:
|
| 131 |
+
logger.error(f"❌ Even in-memory checkpointer failed: {e}")
|
| 132 |
+
raise Exception("Failed to initialize any checkpointer")
|
| 133 |
+
|
| 134 |
+
def get_checkpointer_type(checkpointer):
|
| 135 |
+
"""Utility function to check what type of checkpointer is being used"""
|
| 136 |
+
if hasattr(checkpointer, '__class__'):
|
| 137 |
+
if 'PostgresSaver' in checkpointer.__class__.__name__:
|
| 138 |
+
return "postgres"
|
| 139 |
+
elif 'InMemorySaver' in checkpointer.__class__.__name__:
|
| 140 |
+
return "memory"
|
| 141 |
+
return "unknown"
|
main.py
CHANGED
|
@@ -1,633 +1,483 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
|
| 4 |
-
Supports dynamic addition of new countries with clean architecture
|
| 5 |
"""
|
| 6 |
-
# Add this as the FIRST lines of code (after docstrings)
|
| 7 |
import sys
|
| 8 |
from pathlib import Path
|
| 9 |
-
sys.path.insert(0, str(Path(__file__).parent
|
| 10 |
|
| 11 |
import asyncio
|
| 12 |
import logging
|
| 13 |
import time
|
| 14 |
from datetime import datetime
|
| 15 |
-
from typing import
|
| 16 |
-
|
| 17 |
-
from
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
def __init__(self):
|
| 33 |
-
self.
|
| 34 |
-
self.postgres_checkpointer = PostgresCheckpointer(
|
| 35 |
-
database_url=settings.DATABASE_URL,
|
| 36 |
-
max_connections=10,
|
| 37 |
-
min_connections=2
|
| 38 |
-
)
|
| 39 |
-
self.router = None
|
| 40 |
-
# Dynamic country retrievers dictionary - easily extensible!
|
| 41 |
-
self.country_retrievers = {}
|
| 42 |
-
self.llm = None
|
| 43 |
-
self.graph = None
|
| 44 |
self.chat_manager = None
|
| 45 |
-
self.
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
| 49 |
try:
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# Initialize databases
|
| 54 |
-
if not self.mongo_client.connect():
|
| 55 |
-
raise Exception("MongoDB connection failed")
|
| 56 |
-
|
| 57 |
-
if not await self.postgres_checkpointer.initialize():
|
| 58 |
-
logging.warning("PostgreSQL initialization failed")
|
| 59 |
|
| 60 |
-
|
| 61 |
-
self.router = CountryRouter()
|
| 62 |
|
| 63 |
-
#
|
| 64 |
-
self.
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
self.
|
| 69 |
-
|
| 70 |
-
temperature=settings.CHAT_TEMPERATURE,
|
| 71 |
-
max_tokens=settings.CHAT_MAX_TOKENS
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
# Build scalable graph with country dictionary
|
| 75 |
-
graph_builder = GraphBuilder(
|
| 76 |
-
router=self.router,
|
| 77 |
-
llm=self.llm,
|
| 78 |
-
checkpointer=self.postgres_checkpointer.get_checkpointer(),
|
| 79 |
-
country_retrievers=self.country_retrievers # Pass the dictionary
|
| 80 |
-
)
|
| 81 |
-
|
| 82 |
-
workflow = graph_builder.build_graph()
|
| 83 |
-
|
| 84 |
-
# Compile with interrupt support
|
| 85 |
-
self.graph = workflow.compile(
|
| 86 |
-
checkpointer=self.postgres_checkpointer.get_checkpointer(),
|
| 87 |
-
interrupt_before=["human_approval"]
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
# Initialize chat manager
|
| 91 |
-
self.chat_manager = LegalChatManager(
|
| 92 |
-
self.graph,
|
| 93 |
-
self.postgres_checkpointer.get_checkpointer()
|
| 94 |
-
)
|
| 95 |
-
|
| 96 |
-
await self._perform_health_check()
|
| 97 |
-
|
| 98 |
-
self.initialized = True
|
| 99 |
-
logging.info(f"✅ System initialized with {len(self.country_retrievers)} countries")
|
| 100 |
-
self._print_system_info()
|
| 101 |
|
| 102 |
return True
|
| 103 |
|
| 104 |
except Exception as e:
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
traceback.print_exc()
|
| 108 |
return False
|
| 109 |
-
|
| 110 |
-
def
|
| 111 |
-
"""
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
self.country_retrievers["benin"] = LegalRetriever(
|
| 115 |
-
self.mongo_client.benin_vectorstore,
|
| 116 |
-
self.mongo_client.benin_collection
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
# Madagascar
|
| 120 |
-
if hasattr(self.mongo_client, 'madagascar_vectorstore'):
|
| 121 |
-
self.country_retrievers["madagascar"] = LegalRetriever(
|
| 122 |
-
self.mongo_client.madagascar_vectorstore,
|
| 123 |
-
self.mongo_client.madagascar_collection
|
| 124 |
-
)
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def add_country(self, country_code: str, vectorstore, collection) -> bool:
|
| 129 |
-
"""Dynamically add a new country to the running system"""
|
| 130 |
-
try:
|
| 131 |
-
if country_code in self.country_retrievers:
|
| 132 |
-
logging.warning(f"Country {country_code} already exists")
|
| 133 |
-
return False
|
| 134 |
-
|
| 135 |
-
new_retriever = LegalRetriever(vectorstore, collection)
|
| 136 |
-
self.country_retrievers[country_code] = new_retriever
|
| 137 |
-
|
| 138 |
-
# Rebuild graph if system is already initialized
|
| 139 |
-
if self.initialized:
|
| 140 |
-
graph_builder = GraphBuilder(
|
| 141 |
-
router=self.router,
|
| 142 |
-
llm=self.llm,
|
| 143 |
-
checkpointer=self.postgres_checkpointer.get_checkpointer(),
|
| 144 |
-
country_retrievers=self.country_retrievers
|
| 145 |
-
)
|
| 146 |
-
workflow = graph_builder.build_graph()
|
| 147 |
-
self.graph = workflow.compile(
|
| 148 |
-
checkpointer=self.postgres_checkpointer.get_checkpointer(),
|
| 149 |
-
interrupt_before=["human_approval"]
|
| 150 |
-
)
|
| 151 |
-
|
| 152 |
-
logging.info(f"🎉 Successfully added country: {country_code}")
|
| 153 |
-
return True
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
""
|
| 161 |
-
try:
|
| 162 |
-
health_status = await self.health_check()
|
| 163 |
-
|
| 164 |
-
unhealthy_components = [k for k, v in health_status.get('components', {}).items() if not v]
|
| 165 |
-
if unhealthy_components:
|
| 166 |
-
logging.warning(f"⚠️ Unhealthy components: {unhealthy_components}")
|
| 167 |
-
|
| 168 |
-
except Exception as e:
|
| 169 |
-
logging.warning(f"⚠️ Health check failed: {e}")
|
| 170 |
-
|
| 171 |
-
async def health_check(self) -> Dict[str, Any]:
|
| 172 |
-
"""Comprehensive system health check"""
|
| 173 |
-
health_status = {
|
| 174 |
-
"system_initialized": self.initialized,
|
| 175 |
-
"mongodb_connected": self.mongo_client.client is not None,
|
| 176 |
-
"postgres_healthy": {},
|
| 177 |
-
"interrupt_enabled": True,
|
| 178 |
-
"available_countries": list(self.country_retrievers.keys()),
|
| 179 |
-
"components": {
|
| 180 |
-
"router": self.router is not None,
|
| 181 |
-
"llm": self.llm is not None,
|
| 182 |
-
"graph": self.graph is not None,
|
| 183 |
-
"chat_manager": self.chat_manager is not None,
|
| 184 |
-
"country_retrievers": len(self.country_retrievers) > 0
|
| 185 |
-
},
|
| 186 |
-
"timestamp": datetime.now().isoformat(),
|
| 187 |
-
"settings": {
|
| 188 |
-
"chat_model": settings.CHAT_MODEL,
|
| 189 |
-
"embedding_model": settings.EMBEDDING_MODEL,
|
| 190 |
-
"max_search_results": settings.MAX_SEARCH_RESULTS
|
| 191 |
-
}
|
| 192 |
-
}
|
| 193 |
-
|
| 194 |
-
# Test MongoDB connection
|
| 195 |
-
if health_status["mongodb_connected"]:
|
| 196 |
-
try:
|
| 197 |
-
self.mongo_client.client.admin.command('ping')
|
| 198 |
-
health_status["mongodb_ping"] = True
|
| 199 |
-
except Exception as e:
|
| 200 |
-
health_status["mongodb_ping"] = False
|
| 201 |
-
health_status["mongodb_error"] = str(e)
|
| 202 |
|
| 203 |
-
#
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
"""Public chat interface"""
|
| 212 |
-
if not self.initialized:
|
| 213 |
-
raise RuntimeError("System not initialized. Call initialize() first.")
|
| 214 |
-
|
| 215 |
-
if not message or not message.strip():
|
| 216 |
-
raise ValueError("Message cannot be empty")
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
ctx.setdefault("jurisdiction", "Unknown")
|
| 222 |
-
ctx.setdefault("user_type", "general")
|
| 223 |
-
ctx.setdefault("document_type", "legal")
|
| 224 |
-
ctx.setdefault("detected_country", "unknown")
|
| 225 |
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
if not self.initialized:
|
| 237 |
-
raise RuntimeError("System not initialized")
|
| 238 |
-
return self.chat_manager.get_session_stats(session_id)
|
| 239 |
-
|
| 240 |
-
def get_global_stats(self) -> Dict[str, Any]:
|
| 241 |
-
"""Get global system statistics"""
|
| 242 |
-
if not self.initialized:
|
| 243 |
-
raise RuntimeError("System not initialized")
|
| 244 |
-
return self.chat_manager.get_global_stats()
|
| 245 |
-
|
| 246 |
-
def get_available_countries(self) -> List[str]:
|
| 247 |
-
"""Get list of available countries"""
|
| 248 |
-
return list(self.country_retrievers.keys())
|
| 249 |
-
|
| 250 |
-
async def cleanup(self):
|
| 251 |
-
"""Cleanup resources"""
|
| 252 |
-
try:
|
| 253 |
-
if self.mongo_client:
|
| 254 |
-
self.mongo_client.close()
|
| 255 |
-
if self.postgres_checkpointer:
|
| 256 |
-
await self.postgres_checkpointer.close()
|
| 257 |
-
logging.info("✅ System cleanup completed")
|
| 258 |
-
except Exception as e:
|
| 259 |
-
logging.error(f"❌ Error during cleanup: {e}")
|
| 260 |
-
|
| 261 |
-
def _print_system_info(self):
|
| 262 |
-
"""Print system configuration information"""
|
| 263 |
-
countries = list(self.country_retrievers.keys())
|
| 264 |
-
print("\n" + "="*60)
|
| 265 |
-
print("🚀 SCALABLE MULTI-COUNTRY LEGAL RAG SYSTEM")
|
| 266 |
-
print("="*60)
|
| 267 |
-
print(f"🌍 Available Countries: {', '.join(countries) if countries else 'None'}")
|
| 268 |
-
print(f"🤖 AI Model: {settings.CHAT_MODEL}")
|
| 269 |
-
print(f"💾 Database: MongoDB + PostgreSQL")
|
| 270 |
-
print(f"🔍 Vector Search: {settings.EMBEDDING_MODEL}")
|
| 271 |
-
print(f"⏸️ Interrupt Support: ENABLED")
|
| 272 |
-
print(f"🌡️ Temperature: {settings.CHAT_TEMPERATURE}")
|
| 273 |
-
print(f"📝 Max Tokens: {settings.CHAT_MAX_TOKENS}")
|
| 274 |
-
print("="*60)
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
class InterruptTester:
|
| 278 |
-
"""Specialized tester for human approval interrupts"""
|
| 279 |
-
|
| 280 |
-
def __init__(self, system: MultiCountryLegalRAGSystem):
|
| 281 |
-
self.system = system
|
| 282 |
-
self.test_results = []
|
| 283 |
-
|
| 284 |
-
async def test_assistance_workflow(self, test_name: str,
|
| 285 |
-
user_query: str,
|
| 286 |
-
user_email: str,
|
| 287 |
-
user_description: str,
|
| 288 |
-
moderator_response: str) -> Dict[str, Any]:
|
| 289 |
-
"""Test the complete assistance workflow with interrupt"""
|
| 290 |
-
print(f"\n🧪 Interrupt Test: {test_name}")
|
| 291 |
-
print(f"📝 User Query: {user_query}")
|
| 292 |
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
try:
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
current_response = await self.system.chat(user_query, session_id)
|
| 301 |
-
print(f"🤖 Response: {current_response[:150]}...")
|
| 302 |
|
| 303 |
-
#
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
|
|
|
| 308 |
|
| 309 |
-
|
| 310 |
-
if user_description and any(keyword in current_response.lower() for keyword in ["description", "décrire", "besoin"]):
|
| 311 |
-
print(f"3️⃣ Step 3: Providing description: {user_description[:50]}...")
|
| 312 |
-
current_response = await self.system.chat(user_description, session_id)
|
| 313 |
-
print(f"🤖 Response: {current_response[:150]}...")
|
| 314 |
|
| 315 |
-
#
|
| 316 |
-
|
| 317 |
-
print("4️⃣ Step 4: Confirming request...")
|
| 318 |
-
current_response = await self.system.chat("oui", session_id)
|
| 319 |
-
print(f"🤖 Response: {current_response[:150]}...")
|
| 320 |
|
| 321 |
-
#
|
| 322 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
print(f"👨⚖️ Moderator: {moderator_response}")
|
| 329 |
-
final_response = await self.system.chat(moderator_response, session_id)
|
| 330 |
-
print(f"✅ Final Response: {final_response[:200]}...")
|
| 331 |
-
|
| 332 |
-
result = {
|
| 333 |
-
"test_name": test_name,
|
| 334 |
-
"status": "PASS",
|
| 335 |
-
"interrupt_detected": True,
|
| 336 |
-
"moderator_decision": moderator_response,
|
| 337 |
-
"final_response": final_response,
|
| 338 |
-
"session_id": session_id
|
| 339 |
-
}
|
| 340 |
-
else:
|
| 341 |
-
print("⚠️ No interrupt detected in workflow")
|
| 342 |
-
result = {
|
| 343 |
-
"test_name": test_name,
|
| 344 |
-
"status": "FAIL",
|
| 345 |
-
"interrupt_detected": False,
|
| 346 |
-
"moderator_decision": None,
|
| 347 |
-
"final_response": current_response,
|
| 348 |
-
"error": "Interrupt not triggered",
|
| 349 |
-
"session_id": session_id
|
| 350 |
-
}
|
| 351 |
|
| 352 |
-
|
| 353 |
-
return result
|
| 354 |
|
|
|
|
|
|
|
| 355 |
except Exception as e:
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
"status": "ERROR",
|
| 360 |
-
"interrupt_detected": False,
|
| 361 |
-
"moderator_decision": None,
|
| 362 |
-
"final_response": current_response,
|
| 363 |
-
"error": str(e),
|
| 364 |
-
"session_id": session_id
|
| 365 |
-
}
|
| 366 |
-
self.test_results.append(error_result)
|
| 367 |
-
return error_result
|
| 368 |
-
|
| 369 |
-
def _check_for_interrupt(self, response: str, session_id: str) -> bool:
|
| 370 |
-
"""Enhanced interrupt detection"""
|
| 371 |
-
interrupt_indicators = [
|
| 372 |
-
"APPROBATION", "APPROVAL", "HUMAN", "MODERATOR",
|
| 373 |
-
"DÉCISION", "DECISION", "APPROUVER", "REJETER"
|
| 374 |
-
]
|
| 375 |
-
|
| 376 |
-
if any(indicator in response.upper() for indicator in interrupt_indicators):
|
| 377 |
-
return True
|
| 378 |
-
|
| 379 |
-
if (hasattr(self.system.chat_manager, 'pending_interrupts') and
|
| 380 |
-
session_id in self.system.chat_manager.pending_interrupts):
|
| 381 |
-
return True
|
| 382 |
-
|
| 383 |
-
return False
|
| 384 |
|
| 385 |
-
def
|
| 386 |
-
"""
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
total = len(self.test_results)
|
| 392 |
-
passed = len([r for r in self.test_results if r["status"] == "PASS"])
|
| 393 |
-
failed = len([r for r in self.test_results if r["status"] == "FAIL"])
|
| 394 |
-
errors = len([r for r in self.test_results if r["status"] == "ERROR"])
|
| 395 |
-
|
| 396 |
-
print(f"📈 Total Tests: {total}")
|
| 397 |
-
print(f"✅ Passed: {passed}")
|
| 398 |
-
print(f"❌ Failed: {failed}")
|
| 399 |
-
print(f"🚨 Errors: {errors}")
|
| 400 |
-
|
| 401 |
-
if passed > 0:
|
| 402 |
-
print(f"\n🎉 Successful Tests:")
|
| 403 |
-
for result in self.test_results:
|
| 404 |
-
if result["status"] == "PASS":
|
| 405 |
-
print(f" - {result['test_name']}")
|
| 406 |
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
for result in self.test_results:
|
| 410 |
-
if result["status"] in ["FAIL", "ERROR"]:
|
| 411 |
-
print(f" - {result['test_name']}: {result.get('error', 'Unknown error')}")
|
| 412 |
|
| 413 |
-
print("="*
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
print("
|
| 429 |
-
print("
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
{
|
| 440 |
-
"name": "Complete Workflow - Reject",
|
| 441 |
-
"user_query": "Contactez-moi",
|
| 442 |
-
"user_email": "test2@example.com",
|
| 443 |
-
"user_description": "J'ai besoin d'aide",
|
| 444 |
-
"moderator_response": "reject Description trop vague"
|
| 445 |
-
}
|
| 446 |
-
]
|
| 447 |
-
|
| 448 |
-
for scenario in test_scenarios:
|
| 449 |
-
await tester.test_assistance_workflow(
|
| 450 |
-
scenario["name"],
|
| 451 |
-
scenario["user_query"],
|
| 452 |
-
scenario["user_email"],
|
| 453 |
-
scenario["user_description"],
|
| 454 |
-
scenario["moderator_response"]
|
| 455 |
-
)
|
| 456 |
-
await asyncio.sleep(1)
|
| 457 |
-
|
| 458 |
-
tester.print_summary()
|
| 459 |
-
|
| 460 |
-
except Exception as e:
|
| 461 |
-
logging.error(f"❌ Error during testing: {e}")
|
| 462 |
-
finally:
|
| 463 |
-
await system.cleanup()
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
async def interactive_mode():
|
| 467 |
-
"""Run interactive chat mode"""
|
| 468 |
-
system = MultiCountryLegalRAGSystem()
|
| 469 |
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
print("❌ System initialization failed")
|
| 475 |
return
|
| 476 |
-
|
| 477 |
-
print("\n🎯 INTERACTIVE MODE - SCALABLE SYSTEM")
|
| 478 |
-
print("="*60)
|
| 479 |
-
print("Commands:")
|
| 480 |
-
print(" 'quit' - Exit")
|
| 481 |
-
print(" 'stats' - Show statistics")
|
| 482 |
-
print(" 'health' - Health check")
|
| 483 |
-
print(" 'countries' - List available countries")
|
| 484 |
-
print(" 'session' - Session info")
|
| 485 |
-
print("="*60)
|
| 486 |
-
|
| 487 |
-
session_id = f"interactive_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 488 |
-
print(f"Session ID: {session_id}")
|
| 489 |
-
print(f"Available: {', '.join(system.get_available_countries())}\n")
|
| 490 |
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
response = await system.chat(user_input, session_id)
|
| 526 |
-
response_time = time.time() - start_time
|
| 527 |
-
|
| 528 |
-
print(f"🤖 Assistant ({response_time:.2f}s): {response}\n")
|
| 529 |
-
|
| 530 |
-
# Check for interrupt
|
| 531 |
-
if (hasattr(system.chat_manager, 'pending_interrupts') and
|
| 532 |
-
session_id in system.chat_manager.pending_interrupts):
|
| 533 |
-
print("⏸️ 💡 SYSTEM PAUSED - Next message treated as moderator decision\n")
|
| 534 |
-
|
| 535 |
-
except KeyboardInterrupt:
|
| 536 |
-
print("\n👋 Goodbye!")
|
| 537 |
-
break
|
| 538 |
-
except Exception as e:
|
| 539 |
-
print(f"❌ Error: {str(e)}\n")
|
| 540 |
-
|
| 541 |
-
finally:
|
| 542 |
-
await system.cleanup()
|
| 543 |
-
|
| 544 |
|
| 545 |
-
async def
|
| 546 |
-
"""Run system
|
| 547 |
-
|
| 548 |
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
print(f" {component}: {'✅ OK' if status else '❌ Missing'}")
|
| 566 |
-
|
| 567 |
-
all_healthy = (health['system_initialized'] and
|
| 568 |
-
health['mongodb_connected'] and
|
| 569 |
-
all(health['components'].values()))
|
| 570 |
-
print(f"\n🎯 Overall Status: {'✅ HEALTHY' if all_healthy else '❌ UNHEALTHY'}")
|
| 571 |
|
| 572 |
-
|
| 573 |
-
|
| 574 |
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
|
| 579 |
-
async def
|
| 580 |
-
"""Run
|
| 581 |
-
|
| 582 |
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
|
| 610 |
-
|
|
|
|
| 611 |
import argparse
|
| 612 |
|
| 613 |
parser = argparse.ArgumentParser(
|
| 614 |
-
description="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
)
|
| 616 |
|
| 617 |
parser.add_argument(
|
| 618 |
"--mode",
|
| 619 |
-
choices=["interactive", "
|
| 620 |
default="interactive",
|
| 621 |
-
help="Run mode (default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 622 |
)
|
| 623 |
|
| 624 |
args = parser.parse_args()
|
| 625 |
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Multi-Country Legal RAG System - Interactive Testing Mode with Human-in-the-Loop Support
|
|
|
|
| 4 |
"""
|
|
|
|
| 5 |
import sys
|
| 6 |
from pathlib import Path
|
| 7 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 8 |
|
| 9 |
import asyncio
|
| 10 |
import logging
|
| 11 |
import time
|
| 12 |
from datetime import datetime
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
from core.system_initializer import setup_system
|
| 16 |
+
|
| 17 |
+
# Setup comprehensive logging
|
| 18 |
+
logging.basicConfig(
|
| 19 |
+
level=logging.INFO,
|
| 20 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 21 |
+
handlers=[
|
| 22 |
+
logging.StreamHandler(sys.stdout),
|
| 23 |
+
logging.FileHandler('legal_rag_system.log', mode='a')
|
| 24 |
+
]
|
| 25 |
+
)
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
class LegalRAGTester:
|
| 29 |
+
"""
|
| 30 |
+
Interactive tester for the Legal RAG system with support for:
|
| 31 |
+
- Human-in-the-loop interrupts
|
| 32 |
+
- Session management
|
| 33 |
+
- Statistics tracking
|
| 34 |
+
- Assistance workflow testing
|
| 35 |
+
"""
|
| 36 |
|
| 37 |
def __init__(self):
|
| 38 |
+
self.system = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
self.chat_manager = None
|
| 40 |
+
self.session_id = f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 41 |
+
self.message_count = 0
|
| 42 |
+
|
| 43 |
+
async def initialize(self):
|
| 44 |
+
"""Initialize the Legal RAG system"""
|
| 45 |
+
print("🚀 Initializing Legal RAG System...")
|
| 46 |
try:
|
| 47 |
+
self.system = await setup_system()
|
| 48 |
+
self.chat_manager = self.system["chat_manager"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
print("✅ System initialized successfully!")
|
|
|
|
| 51 |
|
| 52 |
+
# Print system info
|
| 53 |
+
info = self.chat_manager.get_checkpointer_info()
|
| 54 |
+
stats = self.chat_manager.get_global_stats()
|
| 55 |
|
| 56 |
+
print(f"📊 Checkpointer: {info['type']} ({info['description']})")
|
| 57 |
+
print(f"💾 Persistent: {info['persistent']}")
|
| 58 |
+
print(f"🎯 Session ID: {self.session_id}")
|
| 59 |
+
print(f"🌍 Available countries: Bénin, Madagascar")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
return True
|
| 62 |
|
| 63 |
except Exception as e:
|
| 64 |
+
logger.exception("Failed to initialize system")
|
| 65 |
+
print(f"❌ Initialization failed: {e}")
|
|
|
|
| 66 |
return False
|
| 67 |
+
|
| 68 |
+
def _handle_interrupt(self, interrupt_value: dict) -> str:
|
| 69 |
+
"""
|
| 70 |
+
Handle human-in-the-loop interrupts synchronously.
|
| 71 |
+
This is called when the graph needs human approval.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
+
Args:
|
| 74 |
+
interrupt_value: Interrupt data containing message and context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
Returns:
|
| 77 |
+
Moderator's decision (approve/reject)
|
| 78 |
+
"""
|
| 79 |
+
print("\n" + "="*70)
|
| 80 |
+
print("🔒 HUMAN APPROVAL REQUIRED")
|
| 81 |
+
print("="*70)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
# Extract interrupt information
|
| 84 |
+
message = interrupt_value.get("message", "")
|
| 85 |
+
email = interrupt_value.get("user_email", "N/A")
|
| 86 |
+
country = interrupt_value.get("country", "N/A")
|
| 87 |
+
description = interrupt_value.get("description", "N/A")
|
| 88 |
|
| 89 |
+
# Display formatted approval request
|
| 90 |
+
print(message)
|
| 91 |
+
print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
# Get moderator decision with validation
|
| 94 |
+
while True:
|
| 95 |
+
moderator_input = input("🔐 Moderator Decision: ").strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
if not moderator_input:
|
| 98 |
+
print("⚠️ Please provide a decision (approve/reject)")
|
| 99 |
+
continue
|
| 100 |
|
| 101 |
+
# Validate input
|
| 102 |
+
input_lower = moderator_input.lower()
|
| 103 |
+
if any(keyword in input_lower for keyword in ["approve", "approuver", "accept"]):
|
| 104 |
+
print("✅ Request APPROVED")
|
| 105 |
+
break
|
| 106 |
+
elif any(keyword in input_lower for keyword in ["reject", "rejeter", "refuse"]):
|
| 107 |
+
print("❌ Request REJECTED")
|
| 108 |
+
break
|
| 109 |
+
else:
|
| 110 |
+
print("⚠️ Invalid decision. Use 'approve [reason]' or 'reject [reason]'")
|
| 111 |
+
continue
|
| 112 |
|
| 113 |
+
print("="*70 + "\n")
|
| 114 |
+
return moderator_input
|
| 115 |
+
|
| 116 |
+
async def chat(self, message: str) -> Optional[str]:
|
| 117 |
+
"""
|
| 118 |
+
Send a chat message and get response with interrupt handling.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
+
Args:
|
| 121 |
+
message: User message
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
Assistant response or None on error
|
| 125 |
+
"""
|
| 126 |
+
if not self.chat_manager:
|
| 127 |
+
print("❌ System not initialized. Please restart.")
|
| 128 |
+
return None
|
| 129 |
|
| 130 |
try:
|
| 131 |
+
start_time = time.time()
|
| 132 |
+
self.message_count += 1
|
|
|
|
|
|
|
| 133 |
|
| 134 |
+
# Send message with interrupt handler
|
| 135 |
+
response = await self.chat_manager.chat(
|
| 136 |
+
message=message,
|
| 137 |
+
session_id=self.session_id,
|
| 138 |
+
interrupt_handler=self._handle_interrupt # Enable synchronous interrupt handling
|
| 139 |
+
)
|
| 140 |
|
| 141 |
+
response_time = time.time() - start_time
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
+
# Display response
|
| 144 |
+
print(f"\n🤖 Assistant ({response_time:.2f}s):")
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
# Format multi-line responses
|
| 147 |
+
for line in response.split('\n'):
|
| 148 |
+
if line.strip():
|
| 149 |
+
print(f" {line}")
|
| 150 |
+
else:
|
| 151 |
+
print()
|
| 152 |
|
| 153 |
+
# Check for pending interrupts (async mode fallback)
|
| 154 |
+
if self.chat_manager.has_pending_interrupt(self.session_id):
|
| 155 |
+
print("\n⏸️ 💡 System paused - waiting for moderator decision")
|
| 156 |
+
print(" Your next message will be treated as approval/rejection")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
+
return response
|
|
|
|
| 159 |
|
| 160 |
+
except KeyboardInterrupt:
|
| 161 |
+
raise # Re-raise to allow graceful shutdown
|
| 162 |
except Exception as e:
|
| 163 |
+
logger.exception(f"Error processing message: {message}")
|
| 164 |
+
print(f"❌ Error: {str(e)}")
|
| 165 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
+
def display_stats(self):
|
| 168 |
+
"""Display current system statistics"""
|
| 169 |
+
if not self.chat_manager:
|
| 170 |
+
print("❌ System not initialized")
|
| 171 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
+
stats = self.chat_manager.get_global_stats()
|
| 174 |
+
routing = stats.get('routing_stats', {})
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
+
print("\n" + "="*70)
|
| 177 |
+
print("📊 SYSTEM STATISTICS")
|
| 178 |
+
print("="*70)
|
| 179 |
+
print(f"💬 Session Messages: {self.message_count}")
|
| 180 |
+
print(f"🔢 Total System Queries: {stats.get('total_queries', 0)}")
|
| 181 |
+
print(f"👥 Active Sessions: {stats.get('active_sessions', 0)}")
|
| 182 |
+
print(f"⏸️ Pending Interrupts: {stats.get('pending_interrupts', 0)}")
|
| 183 |
+
print()
|
| 184 |
+
print("🔀 Routing Statistics:")
|
| 185 |
+
print(f" 📍 Bénin queries: {routing.get('benin', 0)}")
|
| 186 |
+
print(f" 📍 Madagascar queries: {routing.get('madagascar', 0)}")
|
| 187 |
+
print(f" ❓ Unclear queries: {routing.get('unclear', 0)}")
|
| 188 |
+
print()
|
| 189 |
+
print(f"💾 Storage: {stats.get('type', 'unknown')} - {stats.get('description', '')}")
|
| 190 |
+
print("="*70 + "\n")
|
| 191 |
|
| 192 |
+
def display_help(self):
|
| 193 |
+
"""Display help information"""
|
| 194 |
+
print("\n" + "="*70)
|
| 195 |
+
print("📋 AVAILABLE COMMANDS")
|
| 196 |
+
print("="*70)
|
| 197 |
+
print(" quit, exit, q - Exit the program")
|
| 198 |
+
print(" stats, s - Show system statistics")
|
| 199 |
+
print(" clear, cls - Clear the screen")
|
| 200 |
+
print(" help, h, ? - Show this help message")
|
| 201 |
+
print(" history - Show conversation history")
|
| 202 |
+
print(" reset - Reset current session")
|
| 203 |
+
print()
|
| 204 |
+
print("💡 TESTING TIPS:")
|
| 205 |
+
print(" - Legal queries: Ask about laws in Bénin or Madagascar")
|
| 206 |
+
print(" - Assistance: Say 'je veux parler à un avocat'")
|
| 207 |
+
print(" - Follow the prompts for email and description")
|
| 208 |
+
print(" - You'll be prompted for approval when needed")
|
| 209 |
+
print("="*70 + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
+
async def display_history(self):
|
| 212 |
+
"""Display conversation history"""
|
| 213 |
+
if not self.chat_manager:
|
| 214 |
+
print("❌ System not initialized")
|
|
|
|
| 215 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
+
try:
|
| 218 |
+
history = await self.chat_manager.get_conversation_history(self.session_id)
|
| 219 |
+
|
| 220 |
+
if not history:
|
| 221 |
+
print("\n💬 No conversation history yet\n")
|
| 222 |
+
return
|
| 223 |
+
|
| 224 |
+
print("\n" + "="*70)
|
| 225 |
+
print("💬 CONVERSATION HISTORY")
|
| 226 |
+
print("="*70)
|
| 227 |
+
|
| 228 |
+
for i, msg in enumerate(history, 1):
|
| 229 |
+
role = "👤 User" if msg.type == "human" else "🤖 Assistant"
|
| 230 |
+
content = msg.content[:100] + "..." if len(msg.content) > 100 else msg.content
|
| 231 |
+
print(f"{i}. {role}: {content}")
|
| 232 |
+
|
| 233 |
+
print("="*70 + "\n")
|
| 234 |
+
|
| 235 |
+
except Exception as e:
|
| 236 |
+
logger.exception("Error displaying history")
|
| 237 |
+
print(f"❌ Error: {e}\n")
|
| 238 |
+
|
| 239 |
+
def reset_session(self):
|
| 240 |
+
"""Reset the current session"""
|
| 241 |
+
old_session = self.session_id
|
| 242 |
+
self.session_id = f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 243 |
+
self.message_count = 0
|
| 244 |
+
print(f"\n🔄 Session reset")
|
| 245 |
+
print(f" Old: {old_session}")
|
| 246 |
+
print(f" New: {self.session_id}\n")
|
| 247 |
+
|
| 248 |
+
def clear_screen(self):
|
| 249 |
+
"""Clear the terminal screen"""
|
| 250 |
+
print("\033[H\033[J", end="")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
+
async def interactive_mode():
|
| 253 |
+
"""Run the system in interactive mode with full interrupt support"""
|
| 254 |
+
tester = LegalRAGTester()
|
| 255 |
|
| 256 |
+
# Initialize system
|
| 257 |
+
if not await tester.initialize():
|
| 258 |
+
print("❌ Failed to initialize. Exiting.")
|
| 259 |
+
return
|
| 260 |
+
|
| 261 |
+
# Display welcome message
|
| 262 |
+
print("\n" + "="*70)
|
| 263 |
+
print("🎯 LEGAL RAG INTERACTIVE TESTING MODE")
|
| 264 |
+
print("="*70)
|
| 265 |
+
print("Type 'help' for available commands")
|
| 266 |
+
print("="*70)
|
| 267 |
+
|
| 268 |
+
# Quick test questions
|
| 269 |
+
test_questions = [
|
| 270 |
+
"Bonjour, comment ça va?",
|
| 271 |
+
"Quelle est la procédure de divorce au Bénin?",
|
| 272 |
+
"Je veux parler à un avocat spécialisé",
|
| 273 |
+
"Quels sont les droits des enfants à Madagascar?",
|
| 274 |
+
"Quelles sont les conditions pour se marier au Bénin?",
|
| 275 |
+
]
|
| 276 |
+
|
| 277 |
+
print("\n💡 Quick test questions:")
|
| 278 |
+
for i, question in enumerate(test_questions, 1):
|
| 279 |
+
print(f" {i}. {question}")
|
| 280 |
+
print()
|
| 281 |
+
|
| 282 |
+
# Main interaction loop
|
| 283 |
+
while True:
|
| 284 |
+
try:
|
| 285 |
+
# Get user input
|
| 286 |
+
user_input = input("👤 You: ").strip()
|
| 287 |
|
| 288 |
+
if not user_input:
|
| 289 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
+
# Handle commands
|
| 292 |
+
cmd_lower = user_input.lower()
|
| 293 |
|
| 294 |
+
# Exit commands
|
| 295 |
+
if cmd_lower in ['quit', 'exit', 'q']:
|
| 296 |
+
print("\n👋 Goodbye! Thank you for testing.")
|
| 297 |
+
break
|
| 298 |
+
|
| 299 |
+
# Stats command
|
| 300 |
+
elif cmd_lower in ['stats', 's']:
|
| 301 |
+
tester.display_stats()
|
| 302 |
+
continue
|
| 303 |
+
|
| 304 |
+
# Clear screen
|
| 305 |
+
elif cmd_lower in ['clear', 'cls']:
|
| 306 |
+
tester.clear_screen()
|
| 307 |
+
continue
|
| 308 |
+
|
| 309 |
+
# Help command
|
| 310 |
+
elif cmd_lower in ['help', 'h', '?']:
|
| 311 |
+
tester.display_help()
|
| 312 |
+
continue
|
| 313 |
+
|
| 314 |
+
# History command
|
| 315 |
+
elif cmd_lower == 'history':
|
| 316 |
+
await tester.display_history()
|
| 317 |
+
continue
|
| 318 |
+
|
| 319 |
+
# Reset session
|
| 320 |
+
elif cmd_lower == 'reset':
|
| 321 |
+
tester.reset_session()
|
| 322 |
+
continue
|
| 323 |
+
|
| 324 |
+
# Regular chat message
|
| 325 |
+
else:
|
| 326 |
+
await tester.chat(user_input)
|
| 327 |
+
|
| 328 |
+
except KeyboardInterrupt:
|
| 329 |
+
print("\n\n⚠️ Interrupted by user")
|
| 330 |
+
print("Type 'quit' to exit or continue chatting\n")
|
| 331 |
+
continue
|
| 332 |
+
|
| 333 |
+
except Exception as e:
|
| 334 |
+
logger.exception("Unexpected error in main loop")
|
| 335 |
+
print(f"\n❌ Unexpected error: {e}\n")
|
| 336 |
|
| 337 |
+
async def demo_mode():
|
| 338 |
+
"""Run automated demo of the system capabilities"""
|
| 339 |
+
tester = LegalRAGTester()
|
| 340 |
|
| 341 |
+
print("🚀 Running Automated Demo...")
|
| 342 |
+
|
| 343 |
+
if not await tester.initialize():
|
| 344 |
+
print("❌ Failed to initialize. Exiting demo.")
|
| 345 |
+
return
|
| 346 |
+
|
| 347 |
+
demo_scenarios = [
|
| 348 |
+
{
|
| 349 |
+
"name": "Greeting",
|
| 350 |
+
"messages": ["Bonjour, je m'appelle Test"]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"name": "Legal Query - Bénin",
|
| 354 |
+
"messages": ["Quelle est la procédure de divorce au Bénin?"]
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"name": "Legal Query - Madagascar",
|
| 358 |
+
"messages": ["Quels sont les droits de succession à Madagascar?"]
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"name": "Conversation Repair",
|
| 362 |
+
"messages": ["Peux-tu répéter plus simplement?"]
|
| 363 |
+
},
|
| 364 |
+
{
|
| 365 |
+
"name": "Summary Request",
|
| 366 |
+
"messages": ["Résume notre conversation"]
|
| 367 |
+
}
|
| 368 |
+
]
|
| 369 |
+
|
| 370 |
+
print("\n" + "="*70)
|
| 371 |
+
print("🧪 DEMO SCENARIOS")
|
| 372 |
+
print("="*70 + "\n")
|
| 373 |
+
|
| 374 |
+
for i, scenario in enumerate(demo_scenarios, 1):
|
| 375 |
+
print(f"\n{'='*70}")
|
| 376 |
+
print(f"📋 Scenario {i}/{len(demo_scenarios)}: {scenario['name']}")
|
| 377 |
+
print('='*70)
|
| 378 |
|
| 379 |
+
for message in scenario['messages']:
|
| 380 |
+
print(f"\n🧪 Testing: '{message}'")
|
| 381 |
+
await tester.chat(message)
|
| 382 |
+
await asyncio.sleep(1.5) # Pause between messages
|
| 383 |
+
|
| 384 |
+
print("\n" + "="*70)
|
| 385 |
+
print("✅ DEMO COMPLETED")
|
| 386 |
+
print("="*70)
|
| 387 |
+
tester.display_stats()
|
| 388 |
+
|
| 389 |
+
async def test_assistance_workflow():
|
| 390 |
+
"""Test the complete assistance workflow with interrupts"""
|
| 391 |
+
tester = LegalRAGTester()
|
| 392 |
+
|
| 393 |
+
print("🧪 Testing Assistance Workflow with Human Approval...")
|
| 394 |
+
|
| 395 |
+
if not await tester.initialize():
|
| 396 |
+
print("❌ Failed to initialize. Exiting test.")
|
| 397 |
+
return
|
| 398 |
+
|
| 399 |
+
print("\n" + "="*70)
|
| 400 |
+
print("🔬 ASSISTANCE WORKFLOW TEST")
|
| 401 |
+
print("="*70 + "\n")
|
| 402 |
+
|
| 403 |
+
# Simulated assistance workflow
|
| 404 |
+
workflow_steps = [
|
| 405 |
+
"Je veux parler à un avocat",
|
| 406 |
+
"test@example.com",
|
| 407 |
+
"J'ai besoin d'aide pour un problème de divorce au Bénin",
|
| 408 |
+
"oui"
|
| 409 |
+
]
|
| 410 |
+
|
| 411 |
+
step_names = [
|
| 412 |
+
"1. Initiate assistance request",
|
| 413 |
+
"2. Provide email",
|
| 414 |
+
"3. Describe situation",
|
| 415 |
+
"4. Confirm request"
|
| 416 |
+
]
|
| 417 |
+
|
| 418 |
+
for step_name, message in zip(step_names, workflow_steps):
|
| 419 |
+
print(f"\n📍 {step_name}")
|
| 420 |
+
print(f" Message: '{message}'")
|
| 421 |
+
await tester.chat(message)
|
| 422 |
+
await asyncio.sleep(1)
|
| 423 |
+
|
| 424 |
+
print("\n" + "="*70)
|
| 425 |
+
print("✅ WORKFLOW TEST COMPLETED")
|
| 426 |
+
print("="*70)
|
| 427 |
+
tester.display_stats()
|
| 428 |
|
| 429 |
+
def main():
|
| 430 |
+
"""Main entry point with argument parsing"""
|
| 431 |
import argparse
|
| 432 |
|
| 433 |
parser = argparse.ArgumentParser(
|
| 434 |
+
description="Legal RAG System - Interactive Testing with Human-in-the-Loop Support",
|
| 435 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 436 |
+
epilog="""
|
| 437 |
+
Examples:
|
| 438 |
+
python main.py # Interactive mode (default)
|
| 439 |
+
python main.py --mode demo # Automated demo
|
| 440 |
+
python main.py --mode test # Test assistance workflow
|
| 441 |
+
python main.py --debug # Enable debug logging
|
| 442 |
+
"""
|
| 443 |
)
|
| 444 |
|
| 445 |
parser.add_argument(
|
| 446 |
"--mode",
|
| 447 |
+
choices=["interactive", "demo", "test"],
|
| 448 |
default="interactive",
|
| 449 |
+
help="Run mode: interactive (default), demo, or test"
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
parser.add_argument(
|
| 453 |
+
"--debug",
|
| 454 |
+
action="store_true",
|
| 455 |
+
help="Enable debug logging"
|
| 456 |
)
|
| 457 |
|
| 458 |
args = parser.parse_args()
|
| 459 |
|
| 460 |
+
# Adjust logging level if debug
|
| 461 |
+
if args.debug:
|
| 462 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
| 463 |
+
print("🐛 Debug logging enabled\n")
|
| 464 |
+
|
| 465 |
+
# Run selected mode
|
| 466 |
+
try:
|
| 467 |
+
if args.mode == "interactive":
|
| 468 |
+
asyncio.run(interactive_mode())
|
| 469 |
+
elif args.mode == "demo":
|
| 470 |
+
asyncio.run(demo_mode())
|
| 471 |
+
elif args.mode == "test":
|
| 472 |
+
asyncio.run(test_assistance_workflow())
|
| 473 |
+
|
| 474 |
+
except KeyboardInterrupt:
|
| 475 |
+
print("\n\n👋 Program interrupted by user. Goodbye!")
|
| 476 |
+
|
| 477 |
+
except Exception as e:
|
| 478 |
+
logger.exception("Fatal error occurred")
|
| 479 |
+
print(f"\n❌ Fatal error: {e}")
|
| 480 |
+
sys.exit(1)
|
| 481 |
+
|
| 482 |
+
if __name__ == "__main__":
|
| 483 |
+
main()
|
requirements.txt
CHANGED
|
@@ -6,6 +6,7 @@
|
|
| 6 |
# Web Framework
|
| 7 |
fastapi==0.118.2
|
| 8 |
uvicorn[standard]==0.37.0
|
|
|
|
| 9 |
|
| 10 |
# LangChain & LangGraph - Core only (dependencies will pull others)
|
| 11 |
langgraph==0.6.8
|
|
|
|
| 6 |
# Web Framework
|
| 7 |
fastapi==0.118.2
|
| 8 |
uvicorn[standard]==0.37.0
|
| 9 |
+
slowapi==0.1.9
|
| 10 |
|
| 11 |
# LangChain & LangGraph - Core only (dependencies will pull others)
|
| 12 |
langgraph==0.6.8
|