SAAHMATHWORKS commited on
Commit
f37bf1d
·
1 Parent(s): 8badae9

ready for hugging face space

Browse files
api/config.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/config.py
2
+ import os
3
+ import logging
4
+ from typing import Optional
5
+
6
+ # Setup logging
7
+ logging.basicConfig(
8
+ level=logging.INFO,
9
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
10
+ )
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class Settings:
15
+ """Application settings"""
16
+ APP_TITLE: str = "Legal Assistant API"
17
+ APP_VERSION: str = "2.0.0"
18
+ APP_DESCRIPTION: str = "Multi-country legal RAG with streaming & human-in-the-loop"
19
+
20
+ # CORS
21
+ CORS_ORIGINS: list = ["*"]
22
+
23
+ # API Settings
24
+ STREAM_DELAY: float = 0.02 # Delay between tokens in streaming
25
+
26
+ # System
27
+ chat_manager: Optional[object] = None
28
+ graph: Optional[object] = None
29
+ system_initialized: bool = False
30
+
31
+
32
+ settings = Settings()
api/main.py CHANGED
@@ -1,704 +1,39 @@
1
- # api/main.py
2
- import sys
3
- from pathlib import Path
4
- sys.path.insert(0, str(Path(__file__).parent.parent))
5
-
6
- from typing import Optional, Any, Dict, AsyncGenerator
7
- from contextlib import asynccontextmanager
8
- from fastapi import FastAPI, Query, HTTPException, Body
9
- from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
- from pydantic import BaseModel
12
- import json
13
- from uuid import uuid4
14
- import logging
15
- import os
16
- import asyncio
17
- from datetime import datetime
18
-
19
- # Import your existing system
20
- from core.system_initializer import setup_system
21
-
22
- # Setup logging
23
- logging.basicConfig(level=logging.INFO)
24
- logger = logging.getLogger(__name__)
25
-
26
- # Global variables
27
- chat_manager = None
28
- graph = None
29
- system_initialized = False
30
-
31
-
32
- # ============================================================================
33
- # Pydantic Models
34
- # ============================================================================
35
- class ChatRequest(BaseModel):
36
- message: str
37
- session_id: Optional[str] = None
38
- stream: bool = False # Option to enable/disable streaming
39
-
40
- class ApprovalRequest(BaseModel):
41
- decision: str # "approve" or "reject"
42
- reason: Optional[str] = None
43
-
44
- class ChatResponse(BaseModel):
45
- response: str
46
- session_id: str
47
- has_interrupt: bool = False
48
- interrupt_type: Optional[str] = None
49
- interrupt_data: Optional[Dict] = None
50
-
51
-
52
- # ============================================================================
53
- # System Initialization
54
- # ============================================================================
55
- async def initialize_system():
56
- global chat_manager, graph, system_initialized
57
- try:
58
- system = await setup_system()
59
- chat_manager = system["chat_manager"]
60
- graph = system["graph"]
61
- system_initialized = True
62
- logger.info("✅ Legal assistant system initialized")
63
- except Exception as e:
64
- logger.error(f"❌ Failed to initialize system: {e}")
65
- system_initialized = False
66
-
67
- @asynccontextmanager
68
- async def lifespan(app: FastAPI):
69
- logger.info("🚀 Starting Legal Assistant API...")
70
- initialization_task = asyncio.create_task(initialize_system())
71
- yield
72
- logger.info("🛑 Shutting down...")
73
- initialization_task.cancel()
74
- try:
75
- await initialization_task
76
- except asyncio.CancelledError:
77
- pass
78
-
79
 
80
- # ============================================================================
81
- # FastAPI App
82
- # ============================================================================
83
  app = FastAPI(
84
- title="Legal Assistant API",
85
- version="2.0.0",
86
- description="Multi-country legal RAG with streaming & human-in-the-loop",
87
  lifespan=lifespan
88
  )
89
 
 
90
  app.add_middleware(
91
  CORSMiddleware,
92
- allow_origins=["*"],
93
  allow_credentials=True,
94
  allow_methods=["*"],
95
  allow_headers=["*"],
96
  )
97
 
 
 
 
 
 
98
 
99
- # ============================================================================
100
- # Streaming Helper
101
- # ============================================================================
102
- async def stream_chat_response(
103
- message: str,
104
- session_id: str
105
- ) -> AsyncGenerator[str, None]:
106
- """
107
- Stream chat responses token by token.
108
- If interrupt occurs, yields special interrupt event.
109
- """
110
- try:
111
- # Check for pending interrupt first
112
- if chat_manager.has_pending_interrupt(session_id):
113
- interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
114
- interrupt_data_obj = interrupt_info.get("interrupt_data")
115
-
116
- if hasattr(interrupt_data_obj, 'value'):
117
- interrupt_value = interrupt_data_obj.value
118
- else:
119
- interrupt_value = {}
120
-
121
- # Yield interrupt event
122
- yield f"data: {json.dumps({'type': 'interrupt', 'data': interrupt_value})}\n\n"
123
- return
124
-
125
- # Buffer to collect response
126
- response_buffer = []
127
-
128
- # Process message
129
- response = await chat_manager.chat(
130
- message=message,
131
- session_id=session_id,
132
- interrupt_handler=None # Async mode
133
- )
134
-
135
- # Check if interrupt occurred during processing
136
- if chat_manager.has_pending_interrupt(session_id):
137
- interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
138
- interrupt_data_obj = interrupt_info.get("interrupt_data")
139
-
140
- if hasattr(interrupt_data_obj, 'value'):
141
- interrupt_value = interrupt_data_obj.value
142
- else:
143
- interrupt_value = {}
144
-
145
- # Yield interrupt event
146
- yield f"data: {json.dumps({'type': 'interrupt', 'data': interrupt_value})}\n\n"
147
- return
148
-
149
- # Stream response token by token (simulate streaming)
150
- words = response.split()
151
- for i, word in enumerate(words):
152
- token = word + (" " if i < len(words) - 1 else "")
153
- yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
154
- await asyncio.sleep(0.02) # Small delay for visual effect
155
-
156
- # Send completion event
157
- yield f"data: {json.dumps({'type': 'done', 'session_id': session_id})}\n\n"
158
-
159
- except Exception as e:
160
- logger.exception(f"Error in streaming: {e}")
161
- yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
162
-
163
-
164
- # ============================================================================
165
- # Routes
166
- # ============================================================================
167
- @app.get("/", response_class=HTMLResponse)
168
- async def read_root():
169
- """Interactive chat interface with streaming support"""
170
- return """
171
- <html>
172
- <head>
173
- <title>Legal Assistant - Interactive Chat</title>
174
- <meta charset="UTF-8">
175
- <style>
176
- * { margin: 0; padding: 0; box-sizing: border-box; }
177
- body {
178
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
179
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
180
- min-height: 100vh;
181
- padding: 20px;
182
- }
183
- .container {
184
- max-width: 900px;
185
- margin: 0 auto;
186
- background: white;
187
- border-radius: 16px;
188
- box-shadow: 0 20px 60px rgba(0,0,0,0.3);
189
- overflow: hidden;
190
- }
191
- .header {
192
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
193
- color: white;
194
- padding: 30px;
195
- text-align: center;
196
- }
197
- .header h1 { margin-bottom: 10px; }
198
- .header p { opacity: 0.9; }
199
- .chat-container { padding: 20px; }
200
- #messages {
201
- height: 500px;
202
- overflow-y: auto;
203
- padding: 20px;
204
- background: #f8f9fa;
205
- border-radius: 12px;
206
- margin-bottom: 20px;
207
- }
208
- .message {
209
- margin: 15px 0;
210
- padding: 12px 18px;
211
- border-radius: 18px;
212
- max-width: 75%;
213
- animation: slideIn 0.3s ease;
214
- word-wrap: break-word;
215
- }
216
- @keyframes slideIn {
217
- from { opacity: 0; transform: translateY(10px); }
218
- to { opacity: 1; transform: translateY(0); }
219
- }
220
- .user-message {
221
- background: #667eea;
222
- color: white;
223
- margin-left: auto;
224
- border-bottom-right-radius: 4px;
225
- }
226
- .assistant-message {
227
- background: white;
228
- color: #333;
229
- border: 1px solid #e0e0e0;
230
- border-bottom-left-radius: 4px;
231
- }
232
- .system-message {
233
- background: #fff3cd;
234
- color: #856404;
235
- border: 1px solid #ffc107;
236
- text-align: center;
237
- margin: 10px auto;
238
- max-width: 90%;
239
- }
240
- .interrupt-panel {
241
- background: #fff3cd;
242
- border: 2px solid #ffc107;
243
- border-radius: 12px;
244
- padding: 20px;
245
- margin: 20px 0;
246
- animation: pulse 2s infinite;
247
- }
248
- @keyframes pulse {
249
- 0%, 100% { box-shadow: 0 0 0 0 rgba(255, 193, 7, 0.4); }
250
- 50% { box-shadow: 0 0 0 10px rgba(255, 193, 7, 0); }
251
- }
252
- .interrupt-panel h3 {
253
- color: #856404;
254
- margin-bottom: 15px;
255
- }
256
- .interrupt-details {
257
- background: white;
258
- padding: 15px;
259
- border-radius: 8px;
260
- margin: 15px 0;
261
- }
262
- .interrupt-details p {
263
- margin: 8px 0;
264
- color: #333;
265
- }
266
- .approval-buttons {
267
- display: flex;
268
- gap: 10px;
269
- margin-top: 15px;
270
- }
271
- .btn {
272
- flex: 1;
273
- padding: 12px 24px;
274
- border: none;
275
- border-radius: 8px;
276
- font-weight: bold;
277
- cursor: pointer;
278
- transition: all 0.3s;
279
- font-size: 14px;
280
- }
281
- .btn-approve {
282
- background: #4caf50;
283
- color: white;
284
- }
285
- .btn-approve:hover {
286
- background: #45a049;
287
- transform: translateY(-2px);
288
- box-shadow: 0 4px 12px rgba(76, 175, 80, 0.4);
289
- }
290
- .btn-reject {
291
- background: #f44336;
292
- color: white;
293
- }
294
- .btn-reject:hover {
295
- background: #da190b;
296
- transform: translateY(-2px);
297
- box-shadow: 0 4px 12px rgba(244, 67, 54, 0.4);
298
- }
299
- .input-area {
300
- display: flex;
301
- gap: 10px;
302
- padding: 10px;
303
- background: #f8f9fa;
304
- border-radius: 12px;
305
- }
306
- #message-input {
307
- flex: 1;
308
- padding: 12px 16px;
309
- border: 2px solid #e0e0e0;
310
- border-radius: 8px;
311
- font-size: 14px;
312
- transition: border 0.3s;
313
- }
314
- #message-input:focus {
315
- outline: none;
316
- border-color: #667eea;
317
- }
318
- #send-btn {
319
- padding: 12px 32px;
320
- background: #667eea;
321
- color: white;
322
- border: none;
323
- border-radius: 8px;
324
- font-weight: bold;
325
- cursor: pointer;
326
- transition: all 0.3s;
327
- }
328
- #send-btn:hover:not(:disabled) {
329
- background: #5568d3;
330
- transform: translateY(-2px);
331
- box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
332
- }
333
- #send-btn:disabled {
334
- background: #ccc;
335
- cursor: not-allowed;
336
- transform: none;
337
- }
338
- .status-bar {
339
- padding: 10px;
340
- background: #f8f9fa;
341
- border-radius: 8px;
342
- margin-top: 10px;
343
- display: flex;
344
- justify-content: space-between;
345
- font-size: 12px;
346
- color: #666;
347
- }
348
- .typing-indicator {
349
- display: none;
350
- padding: 10px;
351
- color: #666;
352
- font-style: italic;
353
- }
354
- .typing-indicator.active {
355
- display: block;
356
- }
357
- </style>
358
- </head>
359
- <body>
360
- <div class="container">
361
- <div class="header">
362
- <h1>🧑‍⚖️ Legal Assistant</h1>
363
- <p>Multi-country legal support for Benin & Madagascar</p>
364
- </div>
365
-
366
- <div class="chat-container">
367
- <div id="messages"></div>
368
- <div id="interrupt-zone"></div>
369
- <div class="typing-indicator" id="typing">Assistant is typing...</div>
370
-
371
- <div class="input-area">
372
- <input
373
- type="text"
374
- id="message-input"
375
- placeholder="Type your legal question..."
376
- autocomplete="off"
377
- />
378
- <button id="send-btn" onclick="sendMessage()">
379
- Send
380
- </button>
381
- </div>
382
-
383
- <div class="status-bar">
384
- <span>Session: <strong id="session-id">Not started</strong></span>
385
- <span id="status">System ready</span>
386
- </div>
387
- </div>
388
- </div>
389
-
390
- <script>
391
- let sessionId = null;
392
- let currentAssistantMessage = null;
393
-
394
- function addMessage(content, type) {
395
- const messagesDiv = document.getElementById('messages');
396
- const messageDiv = document.createElement('div');
397
- messageDiv.className = `message ${type}-message`;
398
- messageDiv.textContent = content;
399
- messagesDiv.appendChild(messageDiv);
400
- messagesDiv.scrollTop = messagesDiv.scrollHeight;
401
- return messageDiv;
402
- }
403
-
404
- function showInterrupt(data) {
405
- const interruptZone = document.getElementById('interrupt-zone');
406
- interruptZone.innerHTML = `
407
- <div class="interrupt-panel">
408
- <h3>🔒 Human Approval Required</h3>
409
- <div class="interrupt-details">
410
- <p><strong>📧 Email:</strong> ${data.user_email || 'N/A'}</p>
411
- <p><strong>🌍 Country:</strong> ${data.country || 'N/A'}</p>
412
- <p><strong>📝 Description:</strong> ${data.description || 'N/A'}</p>
413
- </div>
414
- <div class="approval-buttons">
415
- <button class="btn btn-approve" onclick="handleApproval('approve')">
416
- ✅ Approve Request
417
- </button>
418
- <button class="btn btn-reject" onclick="handleApproval('reject')">
419
- ❌ Reject Request
420
- </button>
421
- </div>
422
- </div>
423
- `;
424
- }
425
-
426
- function hideInterrupt() {
427
- document.getElementById('interrupt-zone').innerHTML = '';
428
- }
429
-
430
- async function handleApproval(decision) {
431
- const reason = decision === 'approve'
432
- ? 'Approved via web interface'
433
- : 'Rejected via web interface';
434
-
435
- try {
436
- addMessage(`${decision === 'approve' ? '✅' : '❌'} Decision: ${decision}`, 'system');
437
- hideInterrupt();
438
-
439
- const response = await fetch(`/sessions/${sessionId}/approve`, {
440
- method: 'POST',
441
- headers: { 'Content-Type': 'application/json' },
442
- body: JSON.stringify({ decision, reason })
443
- });
444
-
445
- const data = await response.json();
446
- addMessage(data.response, 'assistant');
447
-
448
- } catch (error) {
449
- addMessage('Error: ' + error.message, 'system');
450
- }
451
- }
452
-
453
- async function sendMessage() {
454
- const input = document.getElementById('message-input');
455
- const message = input.value.trim();
456
-
457
- if (!message) return;
458
-
459
- // Initialize session
460
- if (!sessionId) {
461
- sessionId = 'web_' + Date.now();
462
- document.getElementById('session-id').textContent = sessionId;
463
- }
464
-
465
- // Add user message
466
- addMessage(message, 'user');
467
- input.value = '';
468
-
469
- // Disable input
470
- const sendBtn = document.getElementById('send-btn');
471
- sendBtn.disabled = true;
472
- input.disabled = true;
473
-
474
- // Show typing indicator
475
- document.getElementById('typing').classList.add('active');
476
-
477
- try {
478
- // Use streaming endpoint
479
- const response = await fetch('/chat/stream', {
480
- method: 'POST',
481
- headers: { 'Content-Type': 'application/json' },
482
- body: JSON.stringify({
483
- message: message,
484
- session_id: sessionId,
485
- stream: true
486
- })
487
- });
488
-
489
- const reader = response.body.getReader();
490
- const decoder = new TextDecoder();
491
- currentAssistantMessage = addMessage('', 'assistant');
492
-
493
- while (true) {
494
- const { value, done } = await reader.read();
495
- if (done) break;
496
-
497
- const chunk = decoder.decode(value);
498
- const lines = chunk.split('\\n\\n');
499
-
500
- for (const line of lines) {
501
- if (line.startsWith('data: ')) {
502
- try {
503
- const data = JSON.parse(line.slice(6));
504
-
505
- if (data.type === 'token') {
506
- currentAssistantMessage.textContent += data.content;
507
- } else if (data.type === 'interrupt') {
508
- showInterrupt(data.data);
509
- } else if (data.type === 'done') {
510
- document.getElementById('status').textContent = 'Ready';
511
- } else if (data.type === 'error') {
512
- addMessage('Error: ' + data.message, 'system');
513
- }
514
- } catch (e) {
515
- console.error('Parse error:', e);
516
- }
517
- }
518
- }
519
- }
520
-
521
- } catch (error) {
522
- addMessage('Error: ' + error.message, 'system');
523
- } finally {
524
- sendBtn.disabled = false;
525
- input.disabled = false;
526
- input.focus();
527
- document.getElementById('typing').classList.remove('active');
528
- }
529
- }
530
-
531
- // Enter key to send
532
- document.getElementById('message-input').addEventListener('keypress', (e) => {
533
- if (e.key === 'Enter' && !e.shiftKey) {
534
- e.preventDefault();
535
- sendMessage();
536
- }
537
- });
538
-
539
- // Focus input on load
540
- document.getElementById('message-input').focus();
541
- </script>
542
- </body>
543
- </html>
544
- """
545
-
546
- @app.get("/health")
547
- async def health_check():
548
- """Health check endpoint"""
549
- return {
550
- "status": "healthy" if system_initialized else "starting",
551
- "system_initialized": system_initialized,
552
- "timestamp": datetime.now().isoformat()
553
- }
554
-
555
- @app.post("/chat")
556
- async def chat_simple(request: ChatRequest):
557
- """
558
- Simple non-streaming chat endpoint.
559
- Best for handling interrupts - returns complete response.
560
- """
561
- if not system_initialized or not chat_manager:
562
- raise HTTPException(status_code=503, detail="System initializing...")
563
-
564
- try:
565
- session_id = request.session_id or f"api_{uuid4()}"
566
-
567
- # Check for pending interrupt
568
- if chat_manager.has_pending_interrupt(session_id):
569
- interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
570
- interrupt_data_obj = interrupt_info.get("interrupt_data")
571
-
572
- if hasattr(interrupt_data_obj, 'value'):
573
- interrupt_value = interrupt_data_obj.value
574
- else:
575
- interrupt_value = {}
576
-
577
- return ChatResponse(
578
- response="⏸️ Pending approval request. Please approve/reject first.",
579
- session_id=session_id,
580
- has_interrupt=True,
581
- interrupt_type="human_approval",
582
- interrupt_data=interrupt_value
583
- )
584
-
585
- # Process message
586
- response_text = await chat_manager.chat(
587
- message=request.message,
588
- session_id=session_id
589
- )
590
-
591
- # Check if interrupt occurred
592
- has_interrupt = chat_manager.has_pending_interrupt(session_id)
593
- interrupt_value = None
594
-
595
- if has_interrupt:
596
- interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
597
- interrupt_data_obj = interrupt_info.get("interrupt_data")
598
-
599
- if hasattr(interrupt_data_obj, 'value'):
600
- interrupt_value = interrupt_data_obj.value
601
- elif isinstance(interrupt_data_obj, dict):
602
- interrupt_value = interrupt_data_obj.get("value", {})
603
-
604
- return ChatResponse(
605
- response=response_text,
606
- session_id=session_id,
607
- has_interrupt=has_interrupt,
608
- interrupt_type="human_approval" if has_interrupt else None,
609
- interrupt_data=interrupt_value
610
- )
611
-
612
- except Exception as e:
613
- logger.exception(f"Error in chat: {e}")
614
- raise HTTPException(status_code=500, detail=str(e))
615
-
616
- @app.post("/chat/stream")
617
- async def chat_stream(request: ChatRequest):
618
- """
619
- Streaming chat endpoint using Server-Sent Events (SSE).
620
- Streams tokens in real-time and handles interrupts.
621
- """
622
- if not system_initialized or not chat_manager:
623
- raise HTTPException(status_code=503, detail="System initializing...")
624
-
625
- session_id = request.session_id or f"api_{uuid4()}"
626
-
627
- return StreamingResponse(
628
- stream_chat_response(request.message, session_id),
629
- media_type="text/event-stream",
630
- headers={
631
- "Cache-Control": "no-cache",
632
- "Connection": "keep-alive",
633
- "X-Accel-Buffering": "no"
634
- }
635
- )
636
-
637
- @app.post("/sessions/{session_id}/approve")
638
- async def approve_assistance(session_id: str, request: ApprovalRequest):
639
- """Approve or reject assistance request"""
640
- if not chat_manager:
641
- raise HTTPException(status_code=503, detail="System not initialized")
642
-
643
- if not chat_manager.has_pending_interrupt(session_id):
644
- raise HTTPException(status_code=404, detail="No pending interrupt")
645
-
646
- try:
647
- decision_text = f"{request.decision} {request.reason or ''}"
648
- response_text = await chat_manager.chat(
649
- message=decision_text,
650
- session_id=session_id
651
- )
652
-
653
- return {
654
- "status": "success",
655
- "decision": request.decision,
656
- "response": response_text,
657
- "session_id": session_id
658
- }
659
- except Exception as e:
660
- logger.exception(f"Error handling approval: {e}")
661
- raise HTTPException(status_code=500, detail=str(e))
662
-
663
- @app.get("/sessions/{session_id}/status")
664
- async def get_session_status(session_id: str):
665
- """Get session status including interrupt state"""
666
- if not chat_manager:
667
- raise HTTPException(status_code=503, detail="System not initialized")
668
-
669
- has_interrupt = chat_manager.has_pending_interrupt(session_id)
670
-
671
- interrupt_data = None
672
- if has_interrupt:
673
- interrupt_info = chat_manager.pending_interrupts.get(session_id, {})
674
- interrupt_data_obj = interrupt_info.get("interrupt_data")
675
-
676
- if hasattr(interrupt_data_obj, 'value'):
677
- interrupt_data = interrupt_data_obj.value
678
-
679
- return {
680
- "session_id": session_id,
681
- "has_pending_interrupt": has_interrupt,
682
- "interrupt_data": interrupt_data
683
- }
684
 
685
- @app.get("/sessions/{session_id}/history")
686
- async def get_history(session_id: str):
687
- """Get conversation history"""
688
- if not chat_manager:
689
- raise HTTPException(status_code=503, detail="System not initialized")
690
-
691
- try:
692
- history = await chat_manager.get_conversation_history(session_id)
693
- return {
694
- "session_id": session_id,
695
- "history": [
696
- {
697
- "role": msg.type if hasattr(msg, 'type') else "unknown",
698
- "content": msg.content if hasattr(msg, 'content') else str(msg)
699
- }
700
- for msg in history
701
- ]
702
- }
703
- except Exception as e:
704
- raise HTTPException(status_code=500, detail=str(e))
 
1
+ # api/main.py - MAIN FILE (Clean & Simple)
2
+ from fastapi import FastAPI
 
 
 
 
 
 
 
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ from api.config import settings
5
+ from api.utils import lifespan
6
+ from api.routes import (
7
+ home_router,
8
+ chat_router,
9
+ sessions_router,
10
+ health_router
11
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ # Create FastAPI app
 
 
14
  app = FastAPI(
15
+ title=settings.APP_TITLE,
16
+ version=settings.APP_VERSION,
17
+ description=settings.APP_DESCRIPTION,
18
  lifespan=lifespan
19
  )
20
 
21
+ # Add CORS middleware
22
  app.add_middleware(
23
  CORSMiddleware,
24
+ allow_origins=settings.CORS_ORIGINS,
25
  allow_credentials=True,
26
  allow_methods=["*"],
27
  allow_headers=["*"],
28
  )
29
 
30
+ # Include routers
31
+ app.include_router(home_router)
32
+ app.include_router(chat_router)
33
+ app.include_router(sessions_router)
34
+ app.include_router(health_router)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ if __name__ == "__main__":
38
+ import uvicorn
39
+ uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
api/models/__init__.py CHANGED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # api/models/__init__.py
2
+ from .schemas import ChatRequest, ApprovalRequest, ChatResponse
3
+
4
+ __all__ = ["ChatRequest", "ApprovalRequest", "ChatResponse"]
api/models/schemas.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/models/schemas.py
2
+ from pydantic import BaseModel
3
+ from typing import Optional, Dict
4
+
5
+
6
+ class ChatRequest(BaseModel):
7
+ message: str
8
+ session_id: Optional[str] = None
9
+ stream: bool = False
10
+
11
+
12
+ class ApprovalRequest(BaseModel):
13
+ decision: str # "approve" or "reject"
14
+ reason: Optional[str] = None
15
+
16
+
17
+ class ChatResponse(BaseModel):
18
+ response: str
19
+ session_id: str
20
+ has_interrupt: bool = False
21
+ interrupt_type: Optional[str] = None
22
+ interrupt_data: Optional[Dict] = None
api/routes/__init__.py CHANGED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # api/routes/__init__.py
2
+ from .home import router as home_router
3
+ from .chat import router as chat_router
4
+ from .sessions import router as sessions_router
5
+ from .health import router as health_router
6
+
7
+ __all__ = ["home_router", "chat_router", "sessions_router", "health_router"]
api/routes/chat.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException
2
+ from fastapi.responses import StreamingResponse
3
+ from api.models import ChatRequest, ChatResponse
4
+ from api.services import ChatService, StreamService
5
+ from api.config import settings
6
+ from uuid import uuid4
7
+
8
+ router = APIRouter(prefix="/chat", tags=["Chat"])
9
+
10
+
11
+ @router.post("", response_model=ChatResponse)
12
+ async def chat_simple(request: ChatRequest):
13
+ """
14
+ Simple non-streaming chat endpoint.
15
+ Best for handling interrupts - returns complete response.
16
+ """
17
+ if not settings.system_initialized or not settings.chat_manager:
18
+ raise HTTPException(status_code=503, detail="System initializing...")
19
+
20
+ try:
21
+ return await ChatService.process_message(
22
+ message=request.message,
23
+ session_id=request.session_id
24
+ )
25
+ except Exception as e:
26
+ raise HTTPException(status_code=500, detail=str(e))
27
+
28
+
29
+ @router.post("/stream")
30
+ async def chat_stream(request: ChatRequest):
31
+ """
32
+ Streaming chat endpoint using Server-Sent Events (SSE).
33
+ Streams tokens in real-time and handles interrupts.
34
+ """
35
+ if not settings.system_initialized or not settings.chat_manager:
36
+ raise HTTPException(status_code=503, detail="System initializing...")
37
+
38
+ session_id = request.session_id or f"api_{uuid4()}"
39
+
40
+ return StreamingResponse(
41
+ StreamService.stream_chat_response(request.message, session_id),
42
+ media_type="text/event-stream",
43
+ headers={
44
+ "Cache-Control": "no-cache",
45
+ "Connection": "keep-alive",
46
+ "X-Accel-Buffering": "no"
47
+ }
48
+ )
api/routes/health.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/routes/health.py
2
+ from fastapi import APIRouter
3
+ from datetime import datetime
4
+ from api.config import settings
5
+
6
+ router = APIRouter(tags=["Health"])
7
+
8
+
9
+ @router.get("/health")
10
+ async def health_check():
11
+ """Health check endpoint"""
12
+ return {
13
+ "status": "healthy" if settings.system_initialized else "starting",
14
+ "system_initialized": settings.system_initialized,
15
+ "timestamp": datetime.now().isoformat()
16
+ }
api/routes/home.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from fastapi.responses import HTMLResponse
3
+
4
+ router = APIRouter(tags=["Home"])
5
+
6
+
7
+ @router.get("/", response_class=HTMLResponse)
8
+ async def root():
9
+ """Interactive chat interface - Main UI"""
10
+ return """
11
+ <html>
12
+ <head>
13
+ <title>Legal Assistant - Interactive Chat</title>
14
+ <meta charset="UTF-8">
15
+ <style>
16
+ * { margin: 0; padding: 0; box-sizing: border-box; }
17
+ body {
18
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
19
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
20
+ min-height: 100vh;
21
+ padding: 20px;
22
+ }
23
+ .container {
24
+ max-width: 900px;
25
+ margin: 0 auto;
26
+ background: white;
27
+ border-radius: 16px;
28
+ box-shadow: 0 20px 60px rgba(0,0,0,0.3);
29
+ overflow: hidden;
30
+ }
31
+ .header {
32
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
33
+ color: white;
34
+ padding: 30px;
35
+ text-align: center;
36
+ position: relative;
37
+ }
38
+ .header h1 { margin-bottom: 10px; }
39
+ .header p { opacity: 0.9; }
40
+ .api-link {
41
+ position: absolute;
42
+ top: 20px;
43
+ right: 20px;
44
+ padding: 8px 16px;
45
+ background: rgba(255,255,255,0.2);
46
+ color: white;
47
+ text-decoration: none;
48
+ border-radius: 5px;
49
+ font-size: 0.9em;
50
+ transition: all 0.3s;
51
+ }
52
+ .api-link:hover {
53
+ background: rgba(255,255,255,0.3);
54
+ transform: translateY(-2px);
55
+ }
56
+ .chat-container { padding: 20px; }
57
+ #messages {
58
+ height: 500px;
59
+ overflow-y: auto;
60
+ padding: 20px;
61
+ background: #f8f9fa;
62
+ border-radius: 12px;
63
+ margin-bottom: 20px;
64
+ }
65
+ .message {
66
+ margin: 15px 0;
67
+ padding: 12px 18px;
68
+ border-radius: 18px;
69
+ max-width: 75%;
70
+ animation: slideIn 0.3s ease;
71
+ word-wrap: break-word;
72
+ }
73
+ @keyframes slideIn {
74
+ from { opacity: 0; transform: translateY(10px); }
75
+ to { opacity: 1; transform: translateY(0); }
76
+ }
77
+ .user-message {
78
+ background: #667eea;
79
+ color: white;
80
+ margin-left: auto;
81
+ border-bottom-right-radius: 4px;
82
+ }
83
+ .assistant-message {
84
+ background: white;
85
+ color: #333;
86
+ border: 1px solid #e0e0e0;
87
+ border-bottom-left-radius: 4px;
88
+ }
89
+ .system-message {
90
+ background: #fff3cd;
91
+ color: #856404;
92
+ border: 1px solid #ffc107;
93
+ text-align: center;
94
+ margin: 10px auto;
95
+ max-width: 90%;
96
+ }
97
+ .interrupt-panel {
98
+ background: #fff3cd;
99
+ border: 2px solid #ffc107;
100
+ border-radius: 12px;
101
+ padding: 20px;
102
+ margin: 20px 0;
103
+ animation: pulse 2s infinite;
104
+ }
105
+ @keyframes pulse {
106
+ 0%, 100% { box-shadow: 0 0 0 0 rgba(255, 193, 7, 0.4); }
107
+ 50% { box-shadow: 0 0 0 10px rgba(255, 193, 7, 0); }
108
+ }
109
+ .interrupt-panel h3 {
110
+ color: #856404;
111
+ margin-bottom: 15px;
112
+ }
113
+ .interrupt-details {
114
+ background: white;
115
+ padding: 15px;
116
+ border-radius: 8px;
117
+ margin: 15px 0;
118
+ }
119
+ .interrupt-details p {
120
+ margin: 8px 0;
121
+ color: #333;
122
+ }
123
+ .approval-buttons {
124
+ display: flex;
125
+ gap: 10px;
126
+ margin-top: 15px;
127
+ }
128
+ .btn {
129
+ flex: 1;
130
+ padding: 12px 24px;
131
+ border: none;
132
+ border-radius: 8px;
133
+ font-weight: bold;
134
+ cursor: pointer;
135
+ transition: all 0.3s;
136
+ font-size: 14px;
137
+ }
138
+ .btn-approve {
139
+ background: #4caf50;
140
+ color: white;
141
+ }
142
+ .btn-approve:hover {
143
+ background: #45a049;
144
+ transform: translateY(-2px);
145
+ box-shadow: 0 4px 12px rgba(76, 175, 80, 0.4);
146
+ }
147
+ .btn-reject {
148
+ background: #f44336;
149
+ color: white;
150
+ }
151
+ .btn-reject:hover {
152
+ background: #da190b;
153
+ transform: translateY(-2px);
154
+ box-shadow: 0 4px 12px rgba(244, 67, 54, 0.4);
155
+ }
156
+ .input-area {
157
+ display: flex;
158
+ gap: 10px;
159
+ padding: 10px;
160
+ background: #f8f9fa;
161
+ border-radius: 12px;
162
+ }
163
+ #message-input {
164
+ flex: 1;
165
+ padding: 12px 16px;
166
+ border: 2px solid #e0e0e0;
167
+ border-radius: 8px;
168
+ font-size: 14px;
169
+ transition: border 0.3s;
170
+ }
171
+ #message-input:focus {
172
+ outline: none;
173
+ border-color: #667eea;
174
+ }
175
+ #send-btn {
176
+ padding: 12px 32px;
177
+ background: #667eea;
178
+ color: white;
179
+ border: none;
180
+ border-radius: 8px;
181
+ font-weight: bold;
182
+ cursor: pointer;
183
+ transition: all 0.3s;
184
+ }
185
+ #send-btn:hover:not(:disabled) {
186
+ background: #5568d3;
187
+ transform: translateY(-2px);
188
+ box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
189
+ }
190
+ #send-btn:disabled {
191
+ background: #ccc;
192
+ cursor: not-allowed;
193
+ transform: none;
194
+ }
195
+ .status-bar {
196
+ padding: 10px;
197
+ background: #f8f9fa;
198
+ border-radius: 8px;
199
+ margin-top: 10px;
200
+ display: flex;
201
+ justify-content: space-between;
202
+ font-size: 12px;
203
+ color: #666;
204
+ }
205
+ .typing-indicator {
206
+ display: none;
207
+ padding: 10px;
208
+ color: #666;
209
+ font-style: italic;
210
+ }
211
+ .typing-indicator.active {
212
+ display: block;
213
+ }
214
+ </style>
215
+ </head>
216
+ <body>
217
+ <div class="container">
218
+ <div class="header">
219
+ <a href="/docs" class="api-link">📚 API Docs</a>
220
+ <h1>🧑‍⚖️ Legal Assistant</h1>
221
+ <p>Multi-country legal support for Benin & Madagascar</p>
222
+ </div>
223
+
224
+ <div class="chat-container">
225
+ <div id="messages"></div>
226
+ <div id="interrupt-zone"></div>
227
+ <div class="typing-indicator" id="typing">Assistant is typing...</div>
228
+
229
+ <div class="input-area">
230
+ <input
231
+ type="text"
232
+ id="message-input"
233
+ placeholder="Type your legal question..."
234
+ autocomplete="off"
235
+ />
236
+ <button id="send-btn" onclick="sendMessage()">
237
+ Send
238
+ </button>
239
+ </div>
240
+
241
+ <div class="status-bar">
242
+ <span>Session: <strong id="session-id">Not started</strong></span>
243
+ <span id="status">System ready</span>
244
+ </div>
245
+ </div>
246
+ </div>
247
+
248
+ <script>
249
+ let sessionId = null;
250
+ let currentAssistantMessage = null;
251
+
252
+ function addMessage(content, type) {
253
+ const messagesDiv = document.getElementById('messages');
254
+ const messageDiv = document.createElement('div');
255
+ messageDiv.className = `message ${type}-message`;
256
+ messageDiv.textContent = content;
257
+ messagesDiv.appendChild(messageDiv);
258
+ messagesDiv.scrollTop = messagesDiv.scrollHeight;
259
+ return messageDiv;
260
+ }
261
+
262
+ function showInterrupt(data) {
263
+ const interruptZone = document.getElementById('interrupt-zone');
264
+ interruptZone.innerHTML = `
265
+ <div class="interrupt-panel">
266
+ <h3>🔒 Human Approval Required</h3>
267
+ <div class="interrupt-details">
268
+ <p><strong>📧 Email:</strong> ${data.user_email || 'N/A'}</p>
269
+ <p><strong>🌍 Country:</strong> ${data.country || 'N/A'}</p>
270
+ <p><strong>📝 Description:</strong> ${data.description || 'N/A'}</p>
271
+ </div>
272
+ <div class="approval-buttons">
273
+ <button class="btn btn-approve" onclick="handleApproval('approve')">
274
+ ✅ Approve Request
275
+ </button>
276
+ <button class="btn btn-reject" onclick="handleApproval('reject')">
277
+ ❌ Reject Request
278
+ </button>
279
+ </div>
280
+ </div>
281
+ `;
282
+ }
283
+
284
+ function hideInterrupt() {
285
+ document.getElementById('interrupt-zone').innerHTML = '';
286
+ }
287
+
288
+ async function handleApproval(decision) {
289
+ const reason = decision === 'approve'
290
+ ? 'Approved via web interface'
291
+ : 'Rejected via web interface';
292
+
293
+ try {
294
+ addMessage(`${decision === 'approve' ? '✅' : '❌'} Decision: ${decision}`, 'system');
295
+ hideInterrupt();
296
+
297
+ const response = await fetch(`/sessions/${sessionId}/approve`, {
298
+ method: 'POST',
299
+ headers: { 'Content-Type': 'application/json' },
300
+ body: JSON.stringify({ decision, reason })
301
+ });
302
+
303
+ const data = await response.json();
304
+ addMessage(data.response, 'assistant');
305
+
306
+ } catch (error) {
307
+ addMessage('Error: ' + error.message, 'system');
308
+ }
309
+ }
310
+
311
+ async function sendMessage() {
312
+ const input = document.getElementById('message-input');
313
+ const message = input.value.trim();
314
+
315
+ if (!message) return;
316
+
317
+ if (!sessionId) {
318
+ sessionId = 'web_' + Date.now();
319
+ document.getElementById('session-id').textContent = sessionId;
320
+ }
321
+
322
+ addMessage(message, 'user');
323
+ input.value = '';
324
+
325
+ const sendBtn = document.getElementById('send-btn');
326
+ sendBtn.disabled = true;
327
+ input.disabled = true;
328
+
329
+ document.getElementById('typing').classList.add('active');
330
+
331
+ try {
332
+ const response = await fetch('/chat/stream', {
333
+ method: 'POST',
334
+ headers: { 'Content-Type': 'application/json' },
335
+ body: JSON.stringify({
336
+ message: message,
337
+ session_id: sessionId,
338
+ stream: true
339
+ })
340
+ });
341
+
342
+ const reader = response.body.getReader();
343
+ const decoder = new TextDecoder();
344
+ currentAssistantMessage = addMessage('', 'assistant');
345
+
346
+ while (true) {
347
+ const { value, done } = await reader.read();
348
+ if (done) break;
349
+
350
+ const chunk = decoder.decode(value);
351
+ const lines = chunk.split('\\n\\n');
352
+
353
+ for (const line of lines) {
354
+ if (line.startsWith('data: ')) {
355
+ try {
356
+ const data = JSON.parse(line.slice(6));
357
+
358
+ if (data.type === 'token') {
359
+ currentAssistantMessage.textContent += data.content;
360
+ } else if (data.type === 'interrupt') {
361
+ showInterrupt(data.data);
362
+ } else if (data.type === 'done') {
363
+ document.getElementById('status').textContent = 'Ready';
364
+ } else if (data.type === 'error') {
365
+ addMessage('Error: ' + data.message, 'system');
366
+ }
367
+ } catch (e) {
368
+ console.error('Parse error:', e);
369
+ }
370
+ }
371
+ }
372
+ }
373
+
374
+ } catch (error) {
375
+ addMessage('Error: ' + error.message, 'system');
376
+ } finally {
377
+ sendBtn.disabled = false;
378
+ input.disabled = false;
379
+ input.focus();
380
+ document.getElementById('typing').classList.remove('active');
381
+ }
382
+ }
383
+
384
+ document.getElementById('message-input').addEventListener('keypress', (e) => {
385
+ if (e.key === 'Enter' && !e.shiftKey) {
386
+ e.preventDefault();
387
+ sendMessage();
388
+ }
389
+ });
390
+
391
+ document.getElementById('message-input').focus();
392
+ </script>
393
+ </body>
394
+ </html>
395
+ """
396
+
397
+
398
+ @router.get("/about", response_class=HTMLResponse)
399
+ async def about():
400
+ """About page with API information"""
401
+ return """
402
+ <!DOCTYPE html>
403
+ <html lang="en">
404
+ <head>
405
+ <meta charset="UTF-8">
406
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
407
+ <title>Legal Assistant API</title>
408
+ <style>
409
+ * { margin: 0; padding: 0; box-sizing: border-box; }
410
+ body {
411
+ font-family: 'Segoe UI', sans-serif;
412
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
413
+ min-height: 100vh;
414
+ display: flex;
415
+ align-items: center;
416
+ justify-content: center;
417
+ padding: 20px;
418
+ }
419
+ .container {
420
+ max-width: 900px;
421
+ width: 100%;
422
+ background: white;
423
+ border-radius: 20px;
424
+ box-shadow: 0 20px 60px rgba(0,0,0,0.3);
425
+ overflow: hidden;
426
+ }
427
+ .header {
428
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
429
+ color: white;
430
+ padding: 60px 40px;
431
+ text-align: center;
432
+ }
433
+ .header h1 { font-size: 2.5em; margin-bottom: 15px; }
434
+ .header p { font-size: 1.2em; opacity: 0.95; }
435
+ .version {
436
+ background: rgba(255,255,255,0.2);
437
+ padding: 5px 15px;
438
+ border-radius: 20px;
439
+ margin-top: 15px;
440
+ display: inline-block;
441
+ }
442
+ .content { padding: 40px; }
443
+ .features {
444
+ display: grid;
445
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
446
+ gap: 20px;
447
+ margin-bottom: 40px;
448
+ }
449
+ .feature-card {
450
+ padding: 25px;
451
+ background: #f8f9fa;
452
+ border-radius: 12px;
453
+ border-left: 4px solid #667eea;
454
+ }
455
+ .feature-card h3 { color: #667eea; margin-bottom: 10px; }
456
+ .cta-section {
457
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
458
+ padding: 40px;
459
+ border-radius: 12px;
460
+ text-align: center;
461
+ }
462
+ .cta-section h2 { color: white; margin-bottom: 25px; }
463
+ .buttons { display: flex; gap: 15px; justify-content: center; flex-wrap: wrap; }
464
+ .btn {
465
+ padding: 15px 35px;
466
+ border-radius: 8px;
467
+ text-decoration: none;
468
+ font-weight: 600;
469
+ transition: all 0.3s;
470
+ }
471
+ .btn-primary { background: white; color: #667eea; }
472
+ .btn-primary:hover { transform: translateY(-3px); }
473
+ .footer {
474
+ padding: 30px;
475
+ background: #f8f9fa;
476
+ text-align: center;
477
+ color: #666;
478
+ }
479
+ </style>
480
+ </head>
481
+ <body>
482
+ <div class="container">
483
+ <div class="header">
484
+ <h1>⚖️ Legal Assistant API</h1>
485
+ <p>Multi-country legal support for Benin & Madagascar</p>
486
+ <span class="version">v2.0.0</span>
487
+ </div>
488
+ <div class="content">
489
+ <div class="features">
490
+ <div class="feature-card">
491
+ <h3>🤖 AI-Powered</h3>
492
+ <p>Advanced RAG technology</p>
493
+ </div>
494
+ <div class="feature-card">
495
+ <h3>🌍 Multi-Country</h3>
496
+ <p>Benin & Madagascar laws</p>
497
+ </div>
498
+ <div class="feature-card">
499
+ <h3>🔒 Human Approval</h3>
500
+ <p>Quality assurance</p>
501
+ </div>
502
+ <div class="feature-card">
503
+ <h3>⚡ Real-time</h3>
504
+ <p>Streaming responses</p>
505
+ </div>
506
+ </div>
507
+ <div class="cta-section">
508
+ <h2>Get Started</h2>
509
+ <div class="buttons">
510
+ <a href="/docs" class="btn btn-primary">📚 Swagger UI</a>
511
+ <a href="/redoc" class="btn btn-primary">📖 ReDoc</a>
512
+ </div>
513
+ </div>
514
+ </div>
515
+ <div class="footer">
516
+ <p>Built with FastAPI & LangGraph</p>
517
+ </div>
518
+ </div>
519
+ </body>
520
+ </html>
521
+ """
api/routes/sessions.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/routes/sessions.py
2
+ from fastapi import APIRouter, HTTPException
3
+ from api.models import ApprovalRequest
4
+ from api.services import ChatService
5
+
6
+ router = APIRouter(prefix="/sessions", tags=["Sessions"])
7
+
8
+
9
+ @router.post("/{session_id}/approve")
10
+ async def approve_assistance(session_id: str, request: ApprovalRequest):
11
+ """Approve or reject assistance request"""
12
+ try:
13
+ return await ChatService.approve_request(
14
+ session_id=session_id,
15
+ decision=request.decision,
16
+ reason=request.reason
17
+ )
18
+ except Exception as e:
19
+ raise HTTPException(status_code=500, detail=str(e))
20
+
21
+
22
+ @router.get("/{session_id}/status")
23
+ async def get_session_status(session_id: str):
24
+ """Get session status including interrupt state"""
25
+ try:
26
+ return await ChatService.get_session_status(session_id)
27
+ except Exception as e:
28
+ raise HTTPException(status_code=500, detail=str(e))
29
+
30
+
31
+ @router.get("/{session_id}/history")
32
+ async def get_history(session_id: str):
33
+ """Get conversation history"""
34
+ try:
35
+ return await ChatService.get_history(session_id)
36
+ except Exception as e:
37
+ raise HTTPException(status_code=500, detail=str(e))
38
+
39
+
40
+ @router.get("")
41
+ async def list_sessions():
42
+ """List all active sessions with their status"""
43
+ try:
44
+ return await ChatService.list_sessions()
45
+ except Exception as e:
46
+ raise HTTPException(status_code=500, detail=str(e))
47
+
48
+
49
+ @router.delete("/{session_id}")
50
+ async def clear_session(session_id: str):
51
+ """Clear a session (useful for testing)"""
52
+ try:
53
+ return await ChatService.clear_session(session_id)
54
+ except Exception as e:
55
+ raise HTTPException(status_code=500, detail=str(e))
56
+
api/services/__init__.py CHANGED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # api/services/__init__.py
2
+ # ============================================================================
3
+ from .chat_service import ChatService
4
+ from .stream_service import StreamService
5
+
6
+ __all__ = ["ChatService", "StreamService"]
api/services/chat_service.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/services/chat_service.py
2
+ from typing import Dict, Optional
3
+ from api.config import settings, logger
4
+ from api.models import ChatResponse
5
+ from uuid import uuid4
6
+
7
+
8
+ class ChatService:
9
+ """Service for handling chat operations"""
10
+
11
+ @staticmethod
12
+ async def process_message(
13
+ message: str,
14
+ session_id: Optional[str] = None
15
+ ) -> ChatResponse:
16
+ """Process a chat message"""
17
+ if not settings.system_initialized or not settings.chat_manager:
18
+ raise Exception("System not initialized")
19
+
20
+ session_id = session_id or f"api_{uuid4()}"
21
+
22
+ # Check for pending interrupt
23
+ if settings.chat_manager.has_pending_interrupt(session_id):
24
+ interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
25
+ interrupt_data_obj = interrupt_info.get("interrupt_data")
26
+
27
+ interrupt_value = (
28
+ interrupt_data_obj.value
29
+ if hasattr(interrupt_data_obj, 'value')
30
+ else {}
31
+ )
32
+
33
+ return ChatResponse(
34
+ response="⏸️ Pending approval request. Please approve/reject first.",
35
+ session_id=session_id,
36
+ has_interrupt=True,
37
+ interrupt_type="human_approval",
38
+ interrupt_data=interrupt_value
39
+ )
40
+
41
+ # Process message
42
+ response_text = await settings.chat_manager.chat(
43
+ message=message,
44
+ session_id=session_id
45
+ )
46
+
47
+ # Check if interrupt occurred
48
+ has_interrupt = settings.chat_manager.has_pending_interrupt(session_id)
49
+ interrupt_value = None
50
+
51
+ if has_interrupt:
52
+ interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
53
+ interrupt_data_obj = interrupt_info.get("interrupt_data")
54
+
55
+ if hasattr(interrupt_data_obj, 'value'):
56
+ interrupt_value = interrupt_data_obj.value
57
+ elif isinstance(interrupt_data_obj, dict):
58
+ interrupt_value = interrupt_data_obj.get("value", {})
59
+
60
+ return ChatResponse(
61
+ response=response_text,
62
+ session_id=session_id,
63
+ has_interrupt=has_interrupt,
64
+ interrupt_type="human_approval" if has_interrupt else None,
65
+ interrupt_data=interrupt_value
66
+ )
67
+
68
+ @staticmethod
69
+ async def approve_request(
70
+ session_id: str,
71
+ decision: str,
72
+ reason: Optional[str] = None
73
+ ) -> Dict:
74
+ """Approve or reject an assistance request"""
75
+ if not settings.chat_manager:
76
+ raise Exception("System not initialized")
77
+
78
+ if not settings.chat_manager.has_pending_interrupt(session_id):
79
+ raise Exception("No pending interrupt")
80
+
81
+ decision_text = f"{decision} {reason or ''}"
82
+ response_text = await settings.chat_manager.chat(
83
+ message=decision_text,
84
+ session_id=session_id
85
+ )
86
+
87
+ return {
88
+ "status": "success",
89
+ "decision": decision,
90
+ "response": response_text,
91
+ "session_id": session_id
92
+ }
93
+
94
+ @staticmethod
95
+ async def get_session_status(session_id: str) -> Dict:
96
+ """Get session status"""
97
+ if not settings.chat_manager:
98
+ raise Exception("System not initialized")
99
+
100
+ has_interrupt = settings.chat_manager.has_pending_interrupt(session_id)
101
+ interrupt_data = None
102
+
103
+ if has_interrupt:
104
+ interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
105
+ interrupt_data_obj = interrupt_info.get("interrupt_data")
106
+
107
+ if hasattr(interrupt_data_obj, 'value'):
108
+ interrupt_data = interrupt_data_obj.value
109
+
110
+ return {
111
+ "session_id": session_id,
112
+ "has_pending_interrupt": has_interrupt,
113
+ "interrupt_data": interrupt_data
114
+ }
115
+
116
+ @staticmethod
117
+ async def get_history(session_id: str) -> Dict:
118
+ """Get conversation history"""
119
+ if not settings.chat_manager:
120
+ raise Exception("System not initialized")
121
+
122
+ history = await settings.chat_manager.get_conversation_history(session_id)
123
+
124
+ return {
125
+ "session_id": session_id,
126
+ "history": [
127
+ {
128
+ "role": msg.type if hasattr(msg, 'type') else "unknown",
129
+ "content": msg.content if hasattr(msg, 'content') else str(msg)
130
+ }
131
+ for msg in history
132
+ ]
133
+ }
134
+
135
+ @staticmethod
136
+ async def list_sessions() -> Dict:
137
+ """List all active sessions"""
138
+ if not settings.chat_manager:
139
+ raise Exception("System not initialized")
140
+
141
+ sessions_info = []
142
+ for session_id in settings.chat_manager.pending_interrupts.keys():
143
+ interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
144
+ interrupt_data_obj = interrupt_info.get("interrupt_data")
145
+
146
+ data = (
147
+ interrupt_data_obj.value
148
+ if hasattr(interrupt_data_obj, 'value')
149
+ else {}
150
+ )
151
+
152
+ sessions_info.append({
153
+ "session_id": session_id,
154
+ "has_interrupt": True,
155
+ "user_email": data.get("user_email"),
156
+ "country": data.get("country"),
157
+ "description": data.get("description")
158
+ })
159
+
160
+ return {
161
+ "total_pending": len(sessions_info),
162
+ "sessions": sessions_info
163
+ }
164
+
165
+ @staticmethod
166
+ async def clear_session(session_id: str) -> Dict:
167
+ """Clear a session"""
168
+ if not settings.chat_manager:
169
+ raise Exception("System not initialized")
170
+
171
+ if session_id in settings.chat_manager.pending_interrupts:
172
+ del settings.chat_manager.pending_interrupts[session_id]
173
+
174
+ return {
175
+ "status": "success",
176
+ "message": f"Session {session_id} cleared"
177
+ }
178
+
api/services/stream_service.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/services/stream_service.py
2
+ import json
3
+ import asyncio
4
+ from typing import AsyncGenerator
5
+ from api.config import settings, logger
6
+
7
+
8
+ class StreamService:
9
+ """Service for handling streaming responses"""
10
+
11
+ @staticmethod
12
+ async def stream_chat_response(
13
+ message: str,
14
+ session_id: str
15
+ ) -> AsyncGenerator[str, None]:
16
+ """Stream chat responses token by token"""
17
+ try:
18
+ # Check for pending interrupt first
19
+ if settings.chat_manager.has_pending_interrupt(session_id):
20
+ interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
21
+ interrupt_data_obj = interrupt_info.get("interrupt_data")
22
+
23
+ interrupt_value = (
24
+ interrupt_data_obj.value
25
+ if hasattr(interrupt_data_obj, 'value')
26
+ else {}
27
+ )
28
+
29
+ yield f"data: {json.dumps({'type': 'interrupt', 'data': interrupt_value})}\n\n"
30
+ return
31
+
32
+ # Process message
33
+ response = await settings.chat_manager.chat(
34
+ message=message,
35
+ session_id=session_id,
36
+ interrupt_handler=None
37
+ )
38
+
39
+ # Check if interrupt occurred during processing
40
+ if settings.chat_manager.has_pending_interrupt(session_id):
41
+ interrupt_info = settings.chat_manager.pending_interrupts.get(session_id, {})
42
+ interrupt_data_obj = interrupt_info.get("interrupt_data")
43
+
44
+ interrupt_value = (
45
+ interrupt_data_obj.value
46
+ if hasattr(interrupt_data_obj, 'value')
47
+ else {}
48
+ )
49
+
50
+ yield f"data: {json.dumps({'type': 'interrupt', 'data': interrupt_value})}\n\n"
51
+ return
52
+
53
+ # Stream response token by token
54
+ words = response.split()
55
+ for i, word in enumerate(words):
56
+ token = word + (" " if i < len(words) - 1 else "")
57
+ yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
58
+ await asyncio.sleep(settings.STREAM_DELAY)
59
+
60
+ # Send completion event
61
+ yield f"data: {json.dumps({'type': 'done', 'session_id': session_id})}\n\n"
62
+
63
+ except Exception as e:
64
+ logger.exception(f"Error in streaming: {e}")
65
+ yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
66
+
api/utils/__init__.py CHANGED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # api/utils/__init__.py
2
+ from .startup import initialize_system, lifespan
3
+
4
+ __all__ = ["initialize_system", "lifespan"]
api/utils/startup.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/utils/startup.py
2
+ import sys
3
+ from pathlib import Path
4
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
5
+
6
+ import asyncio
7
+ from contextlib import asynccontextmanager
8
+ from fastapi import FastAPI
9
+ from core.system_initializer import setup_system
10
+ from api.config import settings, logger
11
+
12
+
13
+ async def initialize_system():
14
+ """Initialize the legal assistant system"""
15
+ try:
16
+ system = await setup_system()
17
+ settings.chat_manager = system["chat_manager"]
18
+ settings.graph = system["graph"]
19
+ settings.system_initialized = True
20
+ logger.info("✅ Legal assistant system initialized")
21
+ except Exception as e:
22
+ logger.error(f"❌ Failed to initialize system: {e}")
23
+ settings.system_initialized = False
24
+
25
+
26
+ @asynccontextmanager
27
+ async def lifespan(app: FastAPI):
28
+ """Application lifespan manager"""
29
+ logger.info("🚀 Starting Legal Assistant API...")
30
+ initialization_task = asyncio.create_task(initialize_system())
31
+ yield
32
+ logger.info("🛑 Shutting down...")
33
+ initialization_task.cancel()
34
+ try:
35
+ await initialization_task
36
+ except asyncio.CancelledError:
37
+ pass
config/settings.py CHANGED
@@ -15,7 +15,7 @@ class Settings:
15
  NEON_END_POINT = os.getenv("NEON_END_POINT")
16
 
17
  # Database
18
- DATABASE_URL = NEON_END_POINT
19
 
20
  # Model Configurations
21
  EMBEDDING_MODEL = "text-embedding-ada-002"
 
15
  NEON_END_POINT = os.getenv("NEON_END_POINT")
16
 
17
  # Database
18
+ # DATABASE_URL = NEON_END_POINT
19
 
20
  # Model Configurations
21
  EMBEDDING_MODEL = "text-embedding-ada-002"
core/chat_manager.py CHANGED
@@ -1,5 +1,4 @@
1
  # [file name]: core/chat_manager.py
2
- # Add this as the FIRST lines of code (after docstrings)
3
  import sys
4
  from pathlib import Path
5
  sys.path.insert(0, str(Path(__file__).parent.parent))
@@ -7,10 +6,11 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
7
  import asyncio
8
  import logging
9
  from datetime import datetime
10
- from typing import Dict, List, Optional
11
  from langchain_core.runnables import RunnableConfig
12
  from langchain_core.messages import BaseMessage
13
  from langgraph.types import Command
 
14
 
15
  from config.settings import settings
16
  from models.state_models import MultiCountryLegalState
@@ -19,6 +19,11 @@ from utils.helpers import dict_to_message_obj
19
  logger = logging.getLogger(__name__)
20
 
21
  class LegalChatManager:
 
 
 
 
 
22
  def __init__(self, graph, checkpointer):
23
  self.graph = graph
24
  self.checkpointer = checkpointer
@@ -32,9 +37,30 @@ class LegalChatManager:
32
  # Track pending interrupts by session
33
  self.pending_interrupts = {}
34
 
35
- async def chat(self, message: str, session_id: str,
36
- legal_context: Optional[Dict[str, str]] = None) -> str:
37
- """Process a chat message with session management and interrupt handling"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  if not self.graph:
39
  raise RuntimeError("System not initialized. Call setup_system() first.")
40
 
@@ -56,31 +82,89 @@ class LegalChatManager:
56
  # Track performance
57
  start_time = datetime.now()
58
 
59
- # Process through graph
60
- result = await self.graph.ainvoke(MultiCountryLegalState(**input_state), config)
61
 
62
- # Check for interrupt
63
- state_snapshot = await self.graph.aget_state(config)
64
- if state_snapshot and state_snapshot.next:
65
- # Graph is paused at an interrupt
66
- logger.info(f"⏸️ Graph interrupted at: {state_snapshot.next}")
67
- self.pending_interrupts[session_id] = {
68
- "type": "human_approval",
69
- "config": config,
70
- "created_at": datetime.now(),
71
- "paused_at": state_snapshot.next
72
- }
73
- return self._get_approval_prompt_message(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # Track performance
76
  processing_time = (datetime.now() - start_time).total_seconds()
77
  self._update_session_stats(session_id, processing_time)
 
78
 
79
- # Extract and return response
80
- response = self._extract_response(result)
81
- self._update_routing_stats(response)
82
-
83
- return response
84
 
85
  except Exception as e:
86
  logger.exception(f"Chat error for session {session_id}")
@@ -88,32 +172,79 @@ class LegalChatManager:
88
  return f"Erreur lors du traitement: {str(e)}"
89
 
90
  async def _handle_pending_interrupt(self, session_id: str, message: str) -> str:
91
- """Handle user response to a pending interrupt using Command(resume=...)"""
 
 
 
 
 
 
 
 
 
 
92
  interrupt_data = self.pending_interrupts.get(session_id)
93
  if not interrupt_data:
94
  return "Erreur: Aucune interruption en attente."
95
 
96
  try:
97
- logger.info(f"📥 Resuming graph with moderator decision: {message}")
98
 
99
  config = interrupt_data["config"]
100
 
101
- # CRITICAL FIX: Use Command(resume=...) to properly resume from interrupt
102
- # This sends the user's message back to the interrupt() call
103
- result = await self.graph.ainvoke(
104
  Command(resume=message),
105
- config
106
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  # Clean up the pending interrupt
109
  del self.pending_interrupts[session_id]
110
 
111
- # Extract and return final response
112
- response = self._extract_response(result)
113
- self._update_routing_stats(response)
114
 
115
  logger.info(f"✅ Graph resumed successfully for session {session_id}")
116
- return response
117
 
118
  except Exception as e:
119
  logger.error(f"Error resuming from interrupt: {str(e)}")
@@ -122,26 +253,12 @@ class LegalChatManager:
122
  del self.pending_interrupts[session_id]
123
  return f"Erreur lors du traitement de la décision: {str(e)}"
124
 
125
- def _get_approval_prompt_message(self, state) -> str:
126
- """Generate message asking for human approval"""
127
- # Extract metadata from state
128
- if isinstance(state, MultiCountryLegalState):
129
- state_dict = state.model_dump()
130
- elif isinstance(state, dict):
131
- state_dict = state
132
- else:
133
- state_dict = {}
134
-
135
- user_email = state_dict.get("user_email", "Non spécifié")
136
- country = state_dict.get("legal_context", {}).get("detected_country", "Non spécifié")
137
- description = state_dict.get("assistance_description", "Non spécifié")
138
-
139
- return f"""
140
  🔒 **APPROBATION HUMAINE REQUISE**
141
 
142
- 📧 **Utilisateur**: {user_email}
143
- 🌍 **Pays**: {country}
144
- 📝 **Description**: {description}
145
 
146
  **Veuillez répondre avec:**
147
  - "approve [raison]" pour approuver la demande
@@ -154,10 +271,32 @@ class LegalChatManager:
154
  **Votre décision:**
155
  """
156
 
157
- # === EXISTING METHODS (unchanged) ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  async def get_conversation_history(self, session_id: str) -> List[BaseMessage]:
160
- """Get conversation history for a session"""
 
 
 
 
 
 
 
 
161
  if not self.graph:
162
  return []
163
 
@@ -183,18 +322,27 @@ class LegalChatManager:
183
  logger.exception(f"Error getting conversation history for session {session_id}")
184
  return []
185
 
186
- def get_session_stats(self, session_id: str) -> Dict:
187
  """Get statistics for a specific session"""
188
  return self.active_sessions.get(session_id, {})
189
 
190
- def get_global_stats(self) -> Dict:
191
- """Get global system statistics"""
192
- return {
 
 
 
 
 
193
  "routing_stats": self.routing_stats,
194
  "active_sessions": len(self.active_sessions),
195
  "total_queries": self.routing_stats["total_queries"],
196
  "pending_interrupts": len(self.pending_interrupts)
197
  }
 
 
 
 
198
 
199
  def _initialize_session(self, session_id: str):
200
  """Initialize or update session tracking"""
@@ -212,9 +360,23 @@ class LegalChatManager:
212
  session_info["query_count"] += 1
213
  session_info["last_activity"] = datetime.now()
214
 
215
- def _prepare_input_state(self, message: str, session_id: str,
216
- legal_context: Optional[Dict[str, str]]) -> Dict:
217
- """Prepare input state for graph processing"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  ctx = legal_context or {
219
  "jurisdiction": "Unknown",
220
  "user_type": "general",
@@ -237,7 +399,15 @@ class LegalChatManager:
237
  }
238
 
239
  def _extract_response(self, result) -> str:
240
- """Extract response text from graph result"""
 
 
 
 
 
 
 
 
241
  if isinstance(result, MultiCountryLegalState):
242
  r = result.model_dump()
243
  elif isinstance(result, dict):
@@ -246,6 +416,7 @@ class LegalChatManager:
246
  r = {}
247
 
248
  msgs = r.get("messages", [])
 
249
  for m in reversed(msgs):
250
  if (m.get("role") or "").lower() in ("assistant", "ai"):
251
  return m.get("content", "")
@@ -278,7 +449,12 @@ class LegalChatManager:
278
  logger.error(f"Session {session_id}: {error}")
279
 
280
  def cleanup_inactive_sessions(self, max_age_hours: int = 24):
281
- """Clean up sessions that have been inactive for too long"""
 
 
 
 
 
282
  cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
283
 
284
  inactive_sessions = [
@@ -291,4 +467,35 @@ class LegalChatManager:
291
  if session_id in self.pending_interrupts:
292
  del self.pending_interrupts[session_id]
293
  del self.active_sessions[session_id]
294
- logger.info(f"Cleaned up inactive session: {session_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # [file name]: core/chat_manager.py
 
2
  import sys
3
  from pathlib import Path
4
  sys.path.insert(0, str(Path(__file__).parent.parent))
 
6
  import asyncio
7
  import logging
8
  from datetime import datetime
9
+ from typing import Dict, List, Optional, Any, Callable
10
  from langchain_core.runnables import RunnableConfig
11
  from langchain_core.messages import BaseMessage
12
  from langgraph.types import Command
13
+ from langgraph.errors import GraphInterrupt
14
 
15
  from config.settings import settings
16
  from models.state_models import MultiCountryLegalState
 
19
  logger = logging.getLogger(__name__)
20
 
21
  class LegalChatManager:
22
+ """
23
+ Chat manager with full human-in-the-loop interrupt support.
24
+ Handles both synchronous (callback-based) and asynchronous (stored interrupt) modes.
25
+ """
26
+
27
  def __init__(self, graph, checkpointer):
28
  self.graph = graph
29
  self.checkpointer = checkpointer
 
37
  # Track pending interrupts by session
38
  self.pending_interrupts = {}
39
 
40
+ async def chat(
41
+ self,
42
+ message: str,
43
+ session_id: str,
44
+ legal_context: Optional[Dict[str, str]] = None,
45
+ interrupt_handler: Optional[Callable[[Dict], str]] = None
46
+ ) -> str:
47
+ """
48
+ Process a chat message with session management and interrupt handling.
49
+
50
+ Args:
51
+ message: User message to process
52
+ session_id: Unique session identifier for conversation tracking
53
+ legal_context: Optional legal context (jurisdiction, user type, etc.)
54
+ interrupt_handler: Optional callback for handling interrupts synchronously.
55
+ If provided, interrupts are handled immediately within this call.
56
+ If None, interrupts are stored for later resolution via subsequent calls.
57
+
58
+ Returns:
59
+ Assistant's response text
60
+
61
+ Raises:
62
+ RuntimeError: If system is not initialized
63
+ """
64
  if not self.graph:
65
  raise RuntimeError("System not initialized. Call setup_system() first.")
66
 
 
82
  # Track performance
83
  start_time = datetime.now()
84
 
85
+ # 🔥 CRITICAL: Use streaming to detect interrupts in real-time
86
+ final_response = None
87
 
88
+ async for chunk in self.graph.astream(
89
+ MultiCountryLegalState(**input_state),
90
+ config,
91
+ stream_mode="updates"
92
+ ):
93
+ # 🔥 Check if this chunk contains an interrupt
94
+ if "__interrupt__" in chunk:
95
+ interrupt_data = chunk["__interrupt__"]
96
+ logger.info(f"⏸️ Graph interrupted: {interrupt_data}")
97
+
98
+ # Extract interrupt info - handle tuple/Interrupt object format
99
+ # LangGraph returns interrupts as tuples containing Interrupt objects
100
+ if isinstance(interrupt_data, (list, tuple)):
101
+ interrupt_info = interrupt_data[0]
102
+ else:
103
+ interrupt_info = interrupt_data
104
+
105
+ # Handle Interrupt object (has .value attribute)
106
+ if hasattr(interrupt_info, 'value'):
107
+ interrupt_value = interrupt_info.value
108
+ elif isinstance(interrupt_info, dict):
109
+ interrupt_value = interrupt_info.get("value", interrupt_info)
110
+ else:
111
+ interrupt_value = {}
112
+
113
+ # Extract message from interrupt value
114
+ interrupt_message = interrupt_value.get("message", "") if isinstance(interrupt_value, dict) else ""
115
+
116
+ # 🔥 Two modes of operation:
117
+ # 1. Synchronous: If interrupt_handler provided, handle immediately
118
+ # 2. Asynchronous: Store interrupt and return, wait for next call
119
+
120
+ if interrupt_handler:
121
+ # SYNCHRONOUS MODE: Handle interrupt immediately
122
+ logger.info("📞 Calling synchronous interrupt handler")
123
+ moderator_response = interrupt_handler(interrupt_value)
124
+
125
+ # Resume immediately with the moderator's response
126
+ logger.info(f"🔄 Resuming graph with: {moderator_response}")
127
+ async for resume_chunk in self.graph.astream(
128
+ Command(resume=moderator_response),
129
+ config,
130
+ stream_mode="updates"
131
+ ):
132
+ # Continue processing resumed chunks
133
+ for node_name, node_output in resume_chunk.items():
134
+ if node_name != "__interrupt__":
135
+ logger.debug(f"📦 Resume chunk from {node_name}")
136
+
137
+ # After resume completes, get final state
138
+ state = await self.graph.aget_state(config)
139
+ final_response = self._extract_response(state.values)
140
+ break
141
+ else:
142
+ # ASYNCHRONOUS MODE: Store interrupt and return
143
+ logger.info("💾 Storing interrupt for later resolution")
144
+ self.pending_interrupts[session_id] = {
145
+ "type": "human_approval",
146
+ "config": config,
147
+ "created_at": datetime.now(),
148
+ "interrupt_data": interrupt_info
149
+ }
150
+ return interrupt_message or self._get_default_approval_prompt()
151
+
152
+ # Process normal chunks (non-interrupt)
153
+ for node_name, node_output in chunk.items():
154
+ if node_name != "__interrupt__":
155
+ logger.debug(f"📦 Chunk from {node_name}")
156
+
157
+ # If no interrupt occurred, get final state
158
+ if final_response is None:
159
+ state = await self.graph.aget_state(config)
160
+ final_response = self._extract_response(state.values)
161
 
162
  # Track performance
163
  processing_time = (datetime.now() - start_time).total_seconds()
164
  self._update_session_stats(session_id, processing_time)
165
+ self._update_routing_stats(final_response)
166
 
167
+ return final_response
 
 
 
 
168
 
169
  except Exception as e:
170
  logger.exception(f"Chat error for session {session_id}")
 
172
  return f"Erreur lors du traitement: {str(e)}"
173
 
174
  async def _handle_pending_interrupt(self, session_id: str, message: str) -> str:
175
+ """
176
+ Handle user response to a pending interrupt using Command(resume=...).
177
+ This is called when there's a stored interrupt waiting for resolution.
178
+
179
+ Args:
180
+ session_id: Session with pending interrupt
181
+ message: User's response (e.g., "approve" or "reject")
182
+
183
+ Returns:
184
+ Final response after resuming from interrupt
185
+ """
186
  interrupt_data = self.pending_interrupts.get(session_id)
187
  if not interrupt_data:
188
  return "Erreur: Aucune interruption en attente."
189
 
190
  try:
191
+ logger.info(f"🔥 Resuming graph with moderator decision: {message}")
192
 
193
  config = interrupt_data["config"]
194
 
195
+ # Use streaming to handle potential nested interrupts
196
+ final_response = None
197
+ async for chunk in self.graph.astream(
198
  Command(resume=message),
199
+ config,
200
+ stream_mode="updates"
201
+ ):
202
+ if "__interrupt__" in chunk:
203
+ # Another interrupt occurred during resume - store it
204
+ new_interrupt = chunk["__interrupt__"]
205
+
206
+ # Handle tuple/list format
207
+ if isinstance(new_interrupt, (list, tuple)):
208
+ interrupt_info = new_interrupt[0]
209
+ else:
210
+ interrupt_info = new_interrupt
211
+
212
+ # Extract value from Interrupt object
213
+ if hasattr(interrupt_info, 'value'):
214
+ interrupt_value = interrupt_info.value
215
+ elif isinstance(interrupt_info, dict):
216
+ interrupt_value = interrupt_info.get("value", interrupt_info)
217
+ else:
218
+ interrupt_value = {}
219
+
220
+ # Store the new interrupt
221
+ self.pending_interrupts[session_id] = {
222
+ "type": "human_approval",
223
+ "config": config,
224
+ "created_at": datetime.now(),
225
+ "interrupt_data": interrupt_info
226
+ }
227
+
228
+ interrupt_message = interrupt_value.get("message", "") if isinstance(interrupt_value, dict) else ""
229
+ return interrupt_message or self._get_default_approval_prompt()
230
+
231
+ # Process normal chunks
232
+ for node_name, node_output in chunk.items():
233
+ if node_name != "__interrupt__":
234
+ logger.debug(f"📦 Resume chunk from {node_name}")
235
+
236
+ # Get final state after successful resume
237
+ state = await self.graph.aget_state(config)
238
+ final_response = self._extract_response(state.values)
239
 
240
  # Clean up the pending interrupt
241
  del self.pending_interrupts[session_id]
242
 
243
+ # Update stats
244
+ self._update_routing_stats(final_response)
 
245
 
246
  logger.info(f"✅ Graph resumed successfully for session {session_id}")
247
+ return final_response
248
 
249
  except Exception as e:
250
  logger.error(f"Error resuming from interrupt: {str(e)}")
 
253
  del self.pending_interrupts[session_id]
254
  return f"Erreur lors du traitement de la décision: {str(e)}"
255
 
256
+ def _get_default_approval_prompt(self) -> str:
257
+ """Default approval prompt if interrupt message extraction fails"""
258
+ return """
 
 
 
 
 
 
 
 
 
 
 
 
259
  🔒 **APPROBATION HUMAINE REQUISE**
260
 
261
+ Une demande d'assistance juridique nécessite votre approbation.
 
 
262
 
263
  **Veuillez répondre avec:**
264
  - "approve [raison]" pour approuver la demande
 
271
  **Votre décision:**
272
  """
273
 
274
+ def get_checkpointer_info(self) -> Dict[str, Any]:
275
+ """Get information about the current checkpointer type"""
276
+ checkpointer_type = "unknown"
277
+ if hasattr(self.checkpointer, '__class__'):
278
+ class_name = self.checkpointer.__class__.__name__
279
+ if 'PostgresSaver' in class_name:
280
+ checkpointer_type = "postgres"
281
+ elif 'InMemorySaver' in class_name:
282
+ checkpointer_type = "memory"
283
+
284
+ return {
285
+ "type": checkpointer_type,
286
+ "persistent": checkpointer_type == "postgres",
287
+ "description": "Persistent storage" if checkpointer_type == "postgres" else "In-memory (volatile)"
288
+ }
289
 
290
  async def get_conversation_history(self, session_id: str) -> List[BaseMessage]:
291
+ """
292
+ Get conversation history for a session.
293
+
294
+ Args:
295
+ session_id: Session identifier
296
+
297
+ Returns:
298
+ List of message objects from conversation history
299
+ """
300
  if not self.graph:
301
  return []
302
 
 
322
  logger.exception(f"Error getting conversation history for session {session_id}")
323
  return []
324
 
325
+ def get_session_stats(self, session_id: str) -> Dict[str, Any]:
326
  """Get statistics for a specific session"""
327
  return self.active_sessions.get(session_id, {})
328
 
329
+ def get_global_stats(self) -> Dict[str, Any]:
330
+ """
331
+ Get global system statistics.
332
+
333
+ Returns:
334
+ Dictionary with routing stats, active sessions, and storage info
335
+ """
336
+ stats = {
337
  "routing_stats": self.routing_stats,
338
  "active_sessions": len(self.active_sessions),
339
  "total_queries": self.routing_stats["total_queries"],
340
  "pending_interrupts": len(self.pending_interrupts)
341
  }
342
+
343
+ # Add checkpointer info
344
+ stats.update(self.get_checkpointer_info())
345
+ return stats
346
 
347
  def _initialize_session(self, session_id: str):
348
  """Initialize or update session tracking"""
 
360
  session_info["query_count"] += 1
361
  session_info["last_activity"] = datetime.now()
362
 
363
+ def _prepare_input_state(
364
+ self,
365
+ message: str,
366
+ session_id: str,
367
+ legal_context: Optional[Dict[str, str]]
368
+ ) -> Dict[str, Any]:
369
+ """
370
+ Prepare input state for graph processing.
371
+
372
+ Args:
373
+ message: User message
374
+ session_id: Session identifier
375
+ legal_context: Optional legal context
376
+
377
+ Returns:
378
+ Dictionary with complete input state for graph
379
+ """
380
  ctx = legal_context or {
381
  "jurisdiction": "Unknown",
382
  "user_type": "general",
 
399
  }
400
 
401
  def _extract_response(self, result) -> str:
402
+ """
403
+ Extract response text from graph result.
404
+
405
+ Args:
406
+ result: Graph execution result (state or dict)
407
+
408
+ Returns:
409
+ Assistant's response text
410
+ """
411
  if isinstance(result, MultiCountryLegalState):
412
  r = result.model_dump()
413
  elif isinstance(result, dict):
 
416
  r = {}
417
 
418
  msgs = r.get("messages", [])
419
+ # Find the last assistant message
420
  for m in reversed(msgs):
421
  if (m.get("role") or "").lower() in ("assistant", "ai"):
422
  return m.get("content", "")
 
449
  logger.error(f"Session {session_id}: {error}")
450
 
451
  def cleanup_inactive_sessions(self, max_age_hours: int = 24):
452
+ """
453
+ Clean up sessions that have been inactive for too long.
454
+
455
+ Args:
456
+ max_age_hours: Maximum age in hours before cleanup
457
+ """
458
  cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
459
 
460
  inactive_sessions = [
 
467
  if session_id in self.pending_interrupts:
468
  del self.pending_interrupts[session_id]
469
  del self.active_sessions[session_id]
470
+ logger.info(f"Cleaned up inactive session: {session_id}")
471
+
472
+ def has_pending_interrupt(self, session_id: str) -> bool:
473
+ """
474
+ Check if there's a pending interrupt for a session.
475
+
476
+ Args:
477
+ session_id: Session identifier
478
+
479
+ Returns:
480
+ True if session has pending interrupt, False otherwise
481
+ """
482
+ return session_id in self.pending_interrupts
483
+
484
+ def get_system_info(self) -> Dict[str, Any]:
485
+ """
486
+ Get comprehensive system information.
487
+
488
+ Returns:
489
+ Dictionary with system status, storage info, and statistics
490
+ """
491
+ return {
492
+ "system": {
493
+ "initialized": self.graph is not None,
494
+ "active_sessions": len(self.active_sessions),
495
+ "pending_interrupts": len(self.pending_interrupts),
496
+ "total_queries": self.routing_stats["total_queries"]
497
+ },
498
+ "storage": self.get_checkpointer_info(),
499
+ "routing": self.routing_stats,
500
+ "timestamp": datetime.now().isoformat()
501
+ }
core/graph_builder.py CHANGED
@@ -6,8 +6,9 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
6
 
7
  from langgraph.graph import StateGraph, START, END
8
  from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
 
9
  import logging
10
- from typing import Dict, List, Any
11
  from langchain_core.runnables import RunnableConfig
12
 
13
  from models.state_models import MultiCountryLegalState
@@ -31,7 +32,7 @@ class GraphBuilder:
31
  self,
32
  router: CountryRouter,
33
  llm,
34
- checkpointer: AsyncPostgresSaver,
35
  # Country retrievers as a dictionary for easy extension
36
  country_retrievers: Dict[str, LegalRetriever] = None
37
  ):
@@ -57,7 +58,19 @@ class GraphBuilder:
57
  self.response_nodes = ResponseNodes(llm)
58
  self.helper_nodes = HelperNodes(llm)
59
 
60
- logger.info(f"GraphBuilder initialized with countries: {list(self.country_retrievers.keys())}")
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  def add_country(self, country_code: str, retriever: LegalRetriever):
63
  """Dynamically add a new country to the system"""
@@ -166,7 +179,8 @@ class GraphBuilder:
166
 
167
  workflow.add_edge("process_assistance", "response")
168
 
169
- logger.info(f"Scalable graph built for {len(self.country_retrievers)} countries: {list(self.country_retrievers.keys())}")
 
170
  return workflow
171
 
172
  def _create_assistance_collect_wrapper(self):
 
6
 
7
  from langgraph.graph import StateGraph, START, END
8
  from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
9
+ from langgraph.checkpoint.memory import InMemorySaver
10
  import logging
11
+ from typing import Dict, List, Any, Union
12
  from langchain_core.runnables import RunnableConfig
13
 
14
  from models.state_models import MultiCountryLegalState
 
32
  self,
33
  router: CountryRouter,
34
  llm,
35
+ checkpointer: Union[AsyncPostgresSaver, InMemorySaver],
36
  # Country retrievers as a dictionary for easy extension
37
  country_retrievers: Dict[str, LegalRetriever] = None
38
  ):
 
58
  self.response_nodes = ResponseNodes(llm)
59
  self.helper_nodes = HelperNodes(llm)
60
 
61
+ # Log checkpointer type
62
+ checkpointer_type = self._get_checkpointer_type()
63
+ logger.info(f"GraphBuilder initialized with {checkpointer_type} checkpointer and countries: {list(self.country_retrievers.keys())}")
64
+
65
+ def _get_checkpointer_type(self) -> str:
66
+ """Determine the type of checkpointer being used"""
67
+ if hasattr(self.checkpointer, '__class__'):
68
+ class_name = self.checkpointer.__class__.__name__
69
+ if 'PostgresSaver' in class_name:
70
+ return "PostgreSQL"
71
+ elif 'InMemorySaver' in class_name:
72
+ return "in-memory"
73
+ return "unknown"
74
 
75
  def add_country(self, country_code: str, retriever: LegalRetriever):
76
  """Dynamically add a new country to the system"""
 
179
 
180
  workflow.add_edge("process_assistance", "response")
181
 
182
+ checkpointer_type = self._get_checkpointer_type()
183
+ logger.info(f"Scalable graph built with {checkpointer_type} checkpointer for {len(self.country_retrievers)} countries: {list(self.country_retrievers.keys())}")
184
  return workflow
185
 
186
  def _create_assistance_collect_wrapper(self):
core/human_approval_node.py CHANGED
@@ -1,9 +1,4 @@
1
  # core/human_approval_node.py
2
- # Add this as the FIRST lines of code (after docstrings)
3
- import sys
4
- from pathlib import Path
5
- sys.path.insert(0, str(Path(__file__).parent.parent))
6
-
7
  import logging
8
  from typing import Literal
9
  from langchain_core.runnables import RunnableConfig
@@ -22,68 +17,59 @@ class HumanApprovalNode:
22
  self,
23
  state: MultiCountryLegalState,
24
  config: RunnableConfig
25
- ) -> Command[Literal["response"]]:
26
- """Process human approval with interrupt"""
27
- try:
28
- # Validate required fields
29
- if not state.user_email or not state.assistance_description:
30
- logger.warning("Missing required fields for approval")
31
- return Command(
32
- goto="response",
33
- update={
34
- "messages": [{
35
- "role": "assistant",
36
- "content": "❌ Données incomplètes pour l'approbation.",
37
- "meta": {}
38
- }]
39
- }
40
- )
41
-
42
- logger.info(f"🔒 Human approval node triggered for {state.user_email}")
43
-
44
- # Prepare interrupt message
45
- interrupt_message = self._format_approval_request(state)
46
-
47
- # Trigger interrupt and wait for human input
48
- moderator_input = interrupt({
49
- "type": "human_approval",
50
- "user_email": state.user_email,
51
- "country": self._get_country_display(state),
52
- "description": state.assistance_description,
53
- "message": interrupt_message
54
- })
55
-
56
- logger.info(f"📥 Received moderator input: {moderator_input}")
57
-
58
- # Parse moderator decision
59
- decision = self._parse_decision(moderator_input)
60
-
61
- # Handle approval
62
- if decision["approved"]:
63
- return await self._handle_approval(state, decision)
64
- else:
65
- return await self._handle_rejection(state, decision)
66
-
67
- except Exception as e:
68
- logger.error(f"Error in approval node: {str(e)}", exc_info=True)
69
  return Command(
70
  goto="response",
71
  update={
72
  "approval_status": "error",
 
73
  "messages": [{
74
  "role": "assistant",
75
- "content": f"❌ Erreur lors de l'approbation: {str(e)}",
76
  "meta": {}
77
  }]
78
  }
79
  )
80
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  async def _handle_approval(
82
  self,
83
  state: MultiCountryLegalState,
84
  decision: dict
85
- ) -> Command[Literal["response"]]:
86
- """Handle approved request (sends email and routes to response)"""
87
  logger.info(f"✅ Request APPROVED for {state.user_email}")
88
 
89
  # Send email
@@ -95,38 +81,15 @@ class HumanApprovalNode:
95
  )
96
  logger.info(f"✅ Emails envoyés avec succès pour {state.user_email}")
97
 
98
- # Build success message
99
- if email_result.get("success"):
100
- message_content = f"""✅ **DEMANDE APPROUVÉE ET ENVOYÉE**
101
-
102
- 📧 Un email de confirmation a été envoyé à: {state.user_email}
103
- 👨‍⚖️ Notre équipe juridique vous contactera sous 24-48 heures.
104
-
105
- **Raison de l'approbation:** {decision['reason']}
106
- **Approuvé par:** {decision['moderator_id']}
107
- """
108
- else:
109
- message_content = f"""⚠️ **DEMANDE APPROUVÉE MAIS ERREUR D'ENVOI**
110
-
111
- La demande a été approuvée mais l'envoi d'email a échoué.
112
- **Erreur:** {email_result.get('error', 'Unknown')}
113
-
114
- Veuillez contacter directement: fitahiana@acfai.org
115
- """
116
-
117
  return Command(
118
- goto="response",
119
  update={
120
  "approval_status": "approved",
121
  "approval_reason": decision["reason"],
122
  "approved_by": decision["moderator_id"],
123
  "approval_timestamp": datetime.now().isoformat(),
124
  "email_status": "sent" if email_result.get("success") else "error",
125
- "messages": [{
126
- "role": "assistant",
127
- "content": message_content,
128
- "meta": {"approval": "approved"}
129
- }]
130
  }
131
  )
132
 
@@ -134,8 +97,8 @@ Veuillez contacter directement: fitahiana@acfai.org
134
  self,
135
  state: MultiCountryLegalState,
136
  decision: dict
137
- ) -> Command[Literal["response"]]: # Updated: Removed "process_assistance"
138
- """Handle rejected request"""
139
  logger.info(f"❌ Request REJECTED for {state.user_email}")
140
 
141
  message_content = f"""❌ **DEMANDE REFUSÉE**
@@ -154,6 +117,7 @@ Si vous pensez qu'il s'agit d'une erreur, veuillez reformuler votre demande avec
154
  "approval_reason": decision["reason"],
155
  "approved_by": decision["moderator_id"],
156
  "approval_timestamp": datetime.now().isoformat(),
 
157
  "messages": [{
158
  "role": "assistant",
159
  "content": message_content,
 
1
  # core/human_approval_node.py
 
 
 
 
 
2
  import logging
3
  from typing import Literal
4
  from langchain_core.runnables import RunnableConfig
 
17
  self,
18
  state: MultiCountryLegalState,
19
  config: RunnableConfig
20
+ ) -> Command[Literal["process_assistance", "response"]]:
21
+ """Process human approval with interrupt - uses Command(goto=...) pattern"""
22
+
23
+ # Validate required fields BEFORE interrupt
24
+ if not state.user_email or not state.assistance_description:
25
+ logger.warning("Missing required fields for approval")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  return Command(
27
  goto="response",
28
  update={
29
  "approval_status": "error",
30
+ "approval_reason": "Données incomplètes",
31
  "messages": [{
32
  "role": "assistant",
33
+ "content": "❌ Données incomplètes pour l'approbation.",
34
  "meta": {}
35
  }]
36
  }
37
  )
38
+
39
+ logger.info(f"🔒 Human approval node triggered for {state.user_email}")
40
+
41
+ # Prepare interrupt message
42
+ interrupt_message = self._format_approval_request(state)
43
+
44
+ # 🔥 CRITICAL: DO NOT wrap interrupt() in try-except!
45
+ # Let GraphInterrupt propagate naturally to the graph executor
46
+ moderator_input = interrupt({
47
+ "type": "human_approval",
48
+ "user_email": state.user_email,
49
+ "country": self._get_country_display(state),
50
+ "description": state.assistance_description,
51
+ "message": interrupt_message
52
+ })
53
+
54
+ # 🎯 Code below ONLY executes after graph resumes from interrupt
55
+ logger.info(f"🔥 Received moderator input: {moderator_input}")
56
+
57
+ # Parse moderator decision
58
+ decision = self._parse_decision(moderator_input)
59
+
60
+ # Handle approval or rejection
61
+ if decision["approved"]:
62
+ return await self._handle_approval(state, decision)
63
+ else:
64
+ return await self._handle_rejection(state, decision)
65
+
66
+
67
  async def _handle_approval(
68
  self,
69
  state: MultiCountryLegalState,
70
  decision: dict
71
+ ) -> Command[Literal["process_assistance"]]:
72
+ """Handle approved request - routes to process_assistance"""
73
  logger.info(f"✅ Request APPROVED for {state.user_email}")
74
 
75
  # Send email
 
81
  )
82
  logger.info(f"✅ Emails envoyés avec succès pour {state.user_email}")
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  return Command(
85
+ goto="process_assistance",
86
  update={
87
  "approval_status": "approved",
88
  "approval_reason": decision["reason"],
89
  "approved_by": decision["moderator_id"],
90
  "approval_timestamp": datetime.now().isoformat(),
91
  "email_status": "sent" if email_result.get("success") else "error",
92
+ "email_result": email_result
 
 
 
 
93
  }
94
  )
95
 
 
97
  self,
98
  state: MultiCountryLegalState,
99
  decision: dict
100
+ ) -> Command[Literal["response"]]:
101
+ """Handle rejected request - routes directly to response"""
102
  logger.info(f"❌ Request REJECTED for {state.user_email}")
103
 
104
  message_content = f"""❌ **DEMANDE REFUSÉE**
 
117
  "approval_reason": decision["reason"],
118
  "approved_by": decision["moderator_id"],
119
  "approval_timestamp": datetime.now().isoformat(),
120
+ "assistance_step": "completed",
121
  "messages": [{
122
  "role": "assistant",
123
  "content": message_content,
core/system_initializer.py CHANGED
@@ -6,6 +6,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
6
 
7
  import logging
8
  from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
 
9
 
10
  from core.graph_builder import GraphBuilder
11
  from core.chat_manager import LegalChatManager
@@ -13,13 +14,12 @@ from core.router import CountryRouter
13
  from database.mongodb_client import MongoDBClient
14
  from database.postgres_checkpointer import PostgresCheckpointer
15
  from langchain_openai import ChatOpenAI
16
- from config import settings # Make sure this import is correct
17
-
18
 
19
  logger = logging.getLogger(__name__)
20
 
21
  async def setup_system():
22
- """Initialize the legal assistant system for API use"""
23
 
24
  try:
25
  # 1. Initialize MongoDB using your existing class
@@ -55,32 +55,8 @@ async def setup_system():
55
 
56
  router = CountryRouter()
57
 
58
- # 5. Initialize PostgreSQL checkpointer - FIXED DATABASE URL
59
- # Check what database URL setting you have
60
- database_url = getattr(settings, 'DATABASE_URL', None)
61
-
62
- if not database_url:
63
- # Try alternative setting names
64
- database_url = getattr(settings, 'POSTGRES_URL', None) or \
65
- getattr(settings, 'POSTGRESQL_URL', None) or \
66
- getattr(settings, 'DB_URL', None)
67
-
68
- if not database_url:
69
- raise Exception("No database URL found in settings")
70
-
71
- logger.info(f"🔗 Using database URL: {database_url.split('@')[-1] if '@' in database_url else 'local'}") # Log safely
72
-
73
- postgres_checkpointer = PostgresCheckpointer(
74
- database_url=database_url, # Use actual database URL
75
- max_connections=10,
76
- min_connections=2
77
- )
78
-
79
- if not await postgres_checkpointer.initialize():
80
- raise Exception("PostgreSQL checkpointer initialization failed")
81
-
82
- checkpointer = postgres_checkpointer.get_checkpointer()
83
- logger.info("✅ PostgreSQL checkpointer initialized for API")
84
 
85
  # 6. Build graph
86
  graph_builder = GraphBuilder(
@@ -106,4 +82,60 @@ async def setup_system():
106
 
107
  except Exception as e:
108
  logger.error(f"❌ Failed to initialize system: {e}")
109
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  import logging
8
  from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
9
+ from langgraph.checkpoint.memory import InMemorySaver
10
 
11
  from core.graph_builder import GraphBuilder
12
  from core.chat_manager import LegalChatManager
 
14
  from database.mongodb_client import MongoDBClient
15
  from database.postgres_checkpointer import PostgresCheckpointer
16
  from langchain_openai import ChatOpenAI
17
+ from config import settings
 
18
 
19
  logger = logging.getLogger(__name__)
20
 
21
  async def setup_system():
22
+ """Initialize the legal assistant system with fallback to in-memory checkpointer"""
23
 
24
  try:
25
  # 1. Initialize MongoDB using your existing class
 
55
 
56
  router = CountryRouter()
57
 
58
+ # 5. Initialize checkpointer with fallback logic
59
+ checkpointer = await _initialize_checkpointer_with_fallback()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # 6. Build graph
62
  graph_builder = GraphBuilder(
 
82
 
83
  except Exception as e:
84
  logger.error(f"❌ Failed to initialize system: {e}")
85
+ raise
86
+
87
+ async def _initialize_checkpointer_with_fallback():
88
+ """Initialize checkpointer with fallback to in-memory if PostgreSQL fails"""
89
+
90
+ # First, try to initialize PostgreSQL checkpointer
91
+ postgres_checkpointer = None
92
+ database_url = getattr(settings, 'DATABASE_URL', None)
93
+
94
+ if not database_url:
95
+ # Try alternative setting names
96
+ database_url = getattr(settings, 'POSTGRES_URL', None) or \
97
+ getattr(settings, 'POSTGRESQL_URL', None) or \
98
+ getattr(settings, 'DB_URL', None)
99
+
100
+ if database_url:
101
+ try:
102
+ logger.info(f"🔗 Attempting PostgreSQL connection: {database_url.split('@')[-1] if '@' in database_url else 'local'}")
103
+
104
+ postgres_checkpointer = PostgresCheckpointer(
105
+ database_url=database_url,
106
+ max_connections=10,
107
+ min_connections=2
108
+ )
109
+
110
+ if await postgres_checkpointer.initialize():
111
+ checkpointer = postgres_checkpointer.get_checkpointer()
112
+ logger.info("✅ PostgreSQL checkpointer initialized successfully")
113
+ return checkpointer
114
+ else:
115
+ logger.warning("❌ PostgreSQL checkpointer initialization failed, will fall back to in-memory")
116
+
117
+ except Exception as e:
118
+ logger.warning(f"❌ PostgreSQL connection failed: {e}, falling back to in-memory checkpointer")
119
+
120
+ else:
121
+ logger.warning("❌ No database URL found in settings, using in-memory checkpointer")
122
+
123
+ # Fall back to in-memory checkpointer
124
+ try:
125
+ checkpointer = InMemorySaver()
126
+ logger.info("✅ In-memory checkpointer initialized as fallback")
127
+ logger.warning("⚠️ Using in-memory checkpointer - conversation history will not persist across restarts")
128
+ return checkpointer
129
+
130
+ except Exception as e:
131
+ logger.error(f"❌ Even in-memory checkpointer failed: {e}")
132
+ raise Exception("Failed to initialize any checkpointer")
133
+
134
+ def get_checkpointer_type(checkpointer):
135
+ """Utility function to check what type of checkpointer is being used"""
136
+ if hasattr(checkpointer, '__class__'):
137
+ if 'PostgresSaver' in checkpointer.__class__.__name__:
138
+ return "postgres"
139
+ elif 'InMemorySaver' in checkpointer.__class__.__name__:
140
+ return "memory"
141
+ return "unknown"
main.py CHANGED
@@ -1,633 +1,483 @@
1
  #!/usr/bin/env python3
2
  """
3
- Scalable Multi-Country Legal RAG System
4
- Supports dynamic addition of new countries with clean architecture
5
  """
6
- # Add this as the FIRST lines of code (after docstrings)
7
  import sys
8
  from pathlib import Path
9
- sys.path.insert(0, str(Path(__file__).parent.parent))
10
 
11
  import asyncio
12
  import logging
13
  import time
14
  from datetime import datetime
15
- from typing import List, Dict, Any, Optional
16
-
17
- from MultiCountryRAG.config.settings import settings
18
- from MultiCountryRAG.database.mongodb_client import MongoDBClient
19
- from MultiCountryRAG.database.postgres_checkpointer import PostgresCheckpointer
20
- from MultiCountryRAG.core.router import CountryRouter
21
- from MultiCountryRAG.core.retriever import LegalRetriever
22
- from MultiCountryRAG.core.graph_builder import GraphBuilder
23
- from MultiCountryRAG.core.chat_manager import LegalChatManager
24
- from MultiCountryRAG.utils.logger import setup_logging
25
-
26
- import uuid
27
-
28
-
29
- class MultiCountryLegalRAGSystem:
30
- """Scalable system class supporting dynamic country addition"""
 
 
 
 
 
 
 
31
 
32
  def __init__(self):
33
- self.mongo_client = MongoDBClient()
34
- self.postgres_checkpointer = PostgresCheckpointer(
35
- database_url=settings.DATABASE_URL,
36
- max_connections=10,
37
- min_connections=2
38
- )
39
- self.router = None
40
- # Dynamic country retrievers dictionary - easily extensible!
41
- self.country_retrievers = {}
42
- self.llm = None
43
- self.graph = None
44
  self.chat_manager = None
45
- self.initialized = False
46
-
47
- async def initialize(self) -> bool:
48
- """Initialize the complete scalable system"""
 
 
49
  try:
50
- setup_logging()
51
- settings.validate()
52
-
53
- # Initialize databases
54
- if not self.mongo_client.connect():
55
- raise Exception("MongoDB connection failed")
56
-
57
- if not await self.postgres_checkpointer.initialize():
58
- logging.warning("PostgreSQL initialization failed")
59
 
60
- # Initialize core components
61
- self.router = CountryRouter()
62
 
63
- # Initialize default countries - easily extensible!
64
- self._initialize_default_countries()
 
65
 
66
- # Initialize LLM
67
- from langchain_openai import ChatOpenAI
68
- self.llm = ChatOpenAI(
69
- model=settings.CHAT_MODEL,
70
- temperature=settings.CHAT_TEMPERATURE,
71
- max_tokens=settings.CHAT_MAX_TOKENS
72
- )
73
-
74
- # Build scalable graph with country dictionary
75
- graph_builder = GraphBuilder(
76
- router=self.router,
77
- llm=self.llm,
78
- checkpointer=self.postgres_checkpointer.get_checkpointer(),
79
- country_retrievers=self.country_retrievers # Pass the dictionary
80
- )
81
-
82
- workflow = graph_builder.build_graph()
83
-
84
- # Compile with interrupt support
85
- self.graph = workflow.compile(
86
- checkpointer=self.postgres_checkpointer.get_checkpointer(),
87
- interrupt_before=["human_approval"]
88
- )
89
-
90
- # Initialize chat manager
91
- self.chat_manager = LegalChatManager(
92
- self.graph,
93
- self.postgres_checkpointer.get_checkpointer()
94
- )
95
-
96
- await self._perform_health_check()
97
-
98
- self.initialized = True
99
- logging.info(f"✅ System initialized with {len(self.country_retrievers)} countries")
100
- self._print_system_info()
101
 
102
  return True
103
 
104
  except Exception as e:
105
- logging.error(f" System initialization failed: {e}")
106
- import traceback
107
- traceback.print_exc()
108
  return False
109
-
110
- def _initialize_default_countries(self):
111
- """Initialize default countries - easily extensible!"""
112
- # Benin
113
- if hasattr(self.mongo_client, 'benin_vectorstore'):
114
- self.country_retrievers["benin"] = LegalRetriever(
115
- self.mongo_client.benin_vectorstore,
116
- self.mongo_client.benin_collection
117
- )
118
-
119
- # Madagascar
120
- if hasattr(self.mongo_client, 'madagascar_vectorstore'):
121
- self.country_retrievers["madagascar"] = LegalRetriever(
122
- self.mongo_client.madagascar_vectorstore,
123
- self.mongo_client.madagascar_collection
124
- )
125
 
126
- logging.info(f"🌍 Initialized {len(self.country_retrievers)} default countries")
127
-
128
- def add_country(self, country_code: str, vectorstore, collection) -> bool:
129
- """Dynamically add a new country to the running system"""
130
- try:
131
- if country_code in self.country_retrievers:
132
- logging.warning(f"Country {country_code} already exists")
133
- return False
134
-
135
- new_retriever = LegalRetriever(vectorstore, collection)
136
- self.country_retrievers[country_code] = new_retriever
137
-
138
- # Rebuild graph if system is already initialized
139
- if self.initialized:
140
- graph_builder = GraphBuilder(
141
- router=self.router,
142
- llm=self.llm,
143
- checkpointer=self.postgres_checkpointer.get_checkpointer(),
144
- country_retrievers=self.country_retrievers
145
- )
146
- workflow = graph_builder.build_graph()
147
- self.graph = workflow.compile(
148
- checkpointer=self.postgres_checkpointer.get_checkpointer(),
149
- interrupt_before=["human_approval"]
150
- )
151
-
152
- logging.info(f"🎉 Successfully added country: {country_code}")
153
- return True
154
 
155
- except Exception as e:
156
- logging.error(f"❌ Failed to add country {country_code}: {e}")
157
- return False
158
-
159
- async def _perform_health_check(self):
160
- """Perform health check after initialization"""
161
- try:
162
- health_status = await self.health_check()
163
-
164
- unhealthy_components = [k for k, v in health_status.get('components', {}).items() if not v]
165
- if unhealthy_components:
166
- logging.warning(f"⚠️ Unhealthy components: {unhealthy_components}")
167
-
168
- except Exception as e:
169
- logging.warning(f"⚠️ Health check failed: {e}")
170
-
171
- async def health_check(self) -> Dict[str, Any]:
172
- """Comprehensive system health check"""
173
- health_status = {
174
- "system_initialized": self.initialized,
175
- "mongodb_connected": self.mongo_client.client is not None,
176
- "postgres_healthy": {},
177
- "interrupt_enabled": True,
178
- "available_countries": list(self.country_retrievers.keys()),
179
- "components": {
180
- "router": self.router is not None,
181
- "llm": self.llm is not None,
182
- "graph": self.graph is not None,
183
- "chat_manager": self.chat_manager is not None,
184
- "country_retrievers": len(self.country_retrievers) > 0
185
- },
186
- "timestamp": datetime.now().isoformat(),
187
- "settings": {
188
- "chat_model": settings.CHAT_MODEL,
189
- "embedding_model": settings.EMBEDDING_MODEL,
190
- "max_search_results": settings.MAX_SEARCH_RESULTS
191
- }
192
- }
193
-
194
- # Test MongoDB connection
195
- if health_status["mongodb_connected"]:
196
- try:
197
- self.mongo_client.client.admin.command('ping')
198
- health_status["mongodb_ping"] = True
199
- except Exception as e:
200
- health_status["mongodb_ping"] = False
201
- health_status["mongodb_error"] = str(e)
202
 
203
- # Test PostgreSQL connection
204
- if hasattr(self.postgres_checkpointer, 'health_check'):
205
- postgres_health = await self.postgres_checkpointer.health_check()
206
- health_status["postgres_healthy"] = postgres_health
 
207
 
208
- return health_status
209
-
210
- async def chat(self, message: str, session_id: str = None, context: dict = None) -> str:
211
- """Public chat interface"""
212
- if not self.initialized:
213
- raise RuntimeError("System not initialized. Call initialize() first.")
214
-
215
- if not message or not message.strip():
216
- raise ValueError("Message cannot be empty")
217
 
218
- try:
219
- # Prepare context
220
- ctx = context or {}
221
- ctx.setdefault("jurisdiction", "Unknown")
222
- ctx.setdefault("user_type", "general")
223
- ctx.setdefault("document_type", "legal")
224
- ctx.setdefault("detected_country", "unknown")
225
 
226
- session_id = session_id or f"cli_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
 
 
227
 
228
- return await self.chat_manager.chat(message, session_id, ctx)
 
 
 
 
 
 
 
 
 
 
229
 
230
- except Exception as e:
231
- logging.error(f"❌ Chat error for session {session_id}: {e}")
232
- return f"❌ Désolé, une erreur s'est produite lors du traitement de votre demande. Veuillez réessayer."
233
-
234
- def get_session_info(self, session_id: str) -> Dict[str, Any]:
235
- """Get information about a specific session"""
236
- if not self.initialized:
237
- raise RuntimeError("System not initialized")
238
- return self.chat_manager.get_session_stats(session_id)
239
-
240
- def get_global_stats(self) -> Dict[str, Any]:
241
- """Get global system statistics"""
242
- if not self.initialized:
243
- raise RuntimeError("System not initialized")
244
- return self.chat_manager.get_global_stats()
245
-
246
- def get_available_countries(self) -> List[str]:
247
- """Get list of available countries"""
248
- return list(self.country_retrievers.keys())
249
-
250
- async def cleanup(self):
251
- """Cleanup resources"""
252
- try:
253
- if self.mongo_client:
254
- self.mongo_client.close()
255
- if self.postgres_checkpointer:
256
- await self.postgres_checkpointer.close()
257
- logging.info("✅ System cleanup completed")
258
- except Exception as e:
259
- logging.error(f"❌ Error during cleanup: {e}")
260
-
261
- def _print_system_info(self):
262
- """Print system configuration information"""
263
- countries = list(self.country_retrievers.keys())
264
- print("\n" + "="*60)
265
- print("🚀 SCALABLE MULTI-COUNTRY LEGAL RAG SYSTEM")
266
- print("="*60)
267
- print(f"🌍 Available Countries: {', '.join(countries) if countries else 'None'}")
268
- print(f"🤖 AI Model: {settings.CHAT_MODEL}")
269
- print(f"💾 Database: MongoDB + PostgreSQL")
270
- print(f"🔍 Vector Search: {settings.EMBEDDING_MODEL}")
271
- print(f"⏸️ Interrupt Support: ENABLED")
272
- print(f"🌡️ Temperature: {settings.CHAT_TEMPERATURE}")
273
- print(f"📝 Max Tokens: {settings.CHAT_MAX_TOKENS}")
274
- print("="*60)
275
-
276
-
277
- class InterruptTester:
278
- """Specialized tester for human approval interrupts"""
279
-
280
- def __init__(self, system: MultiCountryLegalRAGSystem):
281
- self.system = system
282
- self.test_results = []
283
-
284
- async def test_assistance_workflow(self, test_name: str,
285
- user_query: str,
286
- user_email: str,
287
- user_description: str,
288
- moderator_response: str) -> Dict[str, Any]:
289
- """Test the complete assistance workflow with interrupt"""
290
- print(f"\n🧪 Interrupt Test: {test_name}")
291
- print(f"📝 User Query: {user_query}")
292
 
293
- # session_id = f"test_{datetime.now().strftime('%H%M%S%f')}"
294
- session_id = f"interactive_{uuid.uuid4().hex[:8]}"
295
- current_response = ""
 
 
 
 
 
 
296
 
297
  try:
298
- # Step 1: Initial request
299
- print("1️⃣ Step 1: Initial assistance request...")
300
- current_response = await self.system.chat(user_query, session_id)
301
- print(f"🤖 Response: {current_response[:150]}...")
302
 
303
- # Step 2: Email collection
304
- if user_email and any(keyword in current_response.lower() for keyword in ["email", "adresse", "@"]):
305
- print(f"2️⃣ Step 2: Providing email: {user_email}")
306
- current_response = await self.system.chat(user_email, session_id)
307
- print(f"🤖 Response: {current_response[:150]}...")
 
308
 
309
- # Step 3: Description collection
310
- if user_description and any(keyword in current_response.lower() for keyword in ["description", "décrire", "besoin"]):
311
- print(f"3️⃣ Step 3: Providing description: {user_description[:50]}...")
312
- current_response = await self.system.chat(user_description, session_id)
313
- print(f"🤖 Response: {current_response[:150]}...")
314
 
315
- # Step 4: Confirmation
316
- if any(keyword in current_response.lower() for keyword in ["confirmer", "confirmation", "oui/non"]):
317
- print("4️⃣ Step 4: Confirming request...")
318
- current_response = await self.system.chat("oui", session_id)
319
- print(f"🤖 Response: {current_response[:150]}...")
320
 
321
- # Step 5: Check for interrupt
322
- interrupt_detected = self._check_for_interrupt(current_response, session_id)
 
 
 
 
323
 
324
- if interrupt_detected:
325
- print("⏸️ INTERRUPT DETECTED! Waiting for moderator...")
326
-
327
- # Step 6: Moderator decision
328
- print(f"👨‍⚖️ Moderator: {moderator_response}")
329
- final_response = await self.system.chat(moderator_response, session_id)
330
- print(f"✅ Final Response: {final_response[:200]}...")
331
-
332
- result = {
333
- "test_name": test_name,
334
- "status": "PASS",
335
- "interrupt_detected": True,
336
- "moderator_decision": moderator_response,
337
- "final_response": final_response,
338
- "session_id": session_id
339
- }
340
- else:
341
- print("⚠️ No interrupt detected in workflow")
342
- result = {
343
- "test_name": test_name,
344
- "status": "FAIL",
345
- "interrupt_detected": False,
346
- "moderator_decision": None,
347
- "final_response": current_response,
348
- "error": "Interrupt not triggered",
349
- "session_id": session_id
350
- }
351
 
352
- self.test_results.append(result)
353
- return result
354
 
 
 
355
  except Exception as e:
356
- logging.error(f" Test error: {e}")
357
- error_result = {
358
- "test_name": test_name,
359
- "status": "ERROR",
360
- "interrupt_detected": False,
361
- "moderator_decision": None,
362
- "final_response": current_response,
363
- "error": str(e),
364
- "session_id": session_id
365
- }
366
- self.test_results.append(error_result)
367
- return error_result
368
-
369
- def _check_for_interrupt(self, response: str, session_id: str) -> bool:
370
- """Enhanced interrupt detection"""
371
- interrupt_indicators = [
372
- "APPROBATION", "APPROVAL", "HUMAN", "MODERATOR",
373
- "DÉCISION", "DECISION", "APPROUVER", "REJETER"
374
- ]
375
-
376
- if any(indicator in response.upper() for indicator in interrupt_indicators):
377
- return True
378
-
379
- if (hasattr(self.system.chat_manager, 'pending_interrupts') and
380
- session_id in self.system.chat_manager.pending_interrupts):
381
- return True
382
-
383
- return False
384
 
385
- def print_summary(self):
386
- """Print test summary"""
387
- print("\n" + "="*80)
388
- print("📊 INTERRUPT TEST SUMMARY")
389
- print("="*80)
390
-
391
- total = len(self.test_results)
392
- passed = len([r for r in self.test_results if r["status"] == "PASS"])
393
- failed = len([r for r in self.test_results if r["status"] == "FAIL"])
394
- errors = len([r for r in self.test_results if r["status"] == "ERROR"])
395
-
396
- print(f"📈 Total Tests: {total}")
397
- print(f"✅ Passed: {passed}")
398
- print(f"❌ Failed: {failed}")
399
- print(f"🚨 Errors: {errors}")
400
-
401
- if passed > 0:
402
- print(f"\n🎉 Successful Tests:")
403
- for result in self.test_results:
404
- if result["status"] == "PASS":
405
- print(f" - {result['test_name']}")
406
 
407
- if failed > 0 or errors > 0:
408
- print(f"\n💥 Failed/Error Tests:")
409
- for result in self.test_results:
410
- if result["status"] in ["FAIL", "ERROR"]:
411
- print(f" - {result['test_name']}: {result.get('error', 'Unknown error')}")
412
 
413
- print("="*80)
414
-
415
-
416
- async def run_interrupt_tests():
417
- """Run specialized tests for human approval interrupts"""
418
- system = MultiCountryLegalRAGSystem()
419
- tester = InterruptTester(system)
 
 
 
 
 
 
 
 
420
 
421
- try:
422
- print("🚀 Initializing system...")
423
- success = await system.initialize()
424
- if not success:
425
- print("❌ System initialization failed")
426
- return
427
-
428
- print("\n🧪 STARTING INTERRUPT TESTS")
429
- print("="*60)
430
-
431
- test_scenarios = [
432
- {
433
- "name": "Complete Workflow - Approve",
434
- "user_query": "Je veux parler a un avocat",
435
- "user_email": "test@example.com",
436
- "user_description": "Consultation pour divorce au Benin",
437
- "moderator_response": "approve Demande legitime"
438
- },
439
- {
440
- "name": "Complete Workflow - Reject",
441
- "user_query": "Contactez-moi",
442
- "user_email": "test2@example.com",
443
- "user_description": "J'ai besoin d'aide",
444
- "moderator_response": "reject Description trop vague"
445
- }
446
- ]
447
-
448
- for scenario in test_scenarios:
449
- await tester.test_assistance_workflow(
450
- scenario["name"],
451
- scenario["user_query"],
452
- scenario["user_email"],
453
- scenario["user_description"],
454
- scenario["moderator_response"]
455
- )
456
- await asyncio.sleep(1)
457
-
458
- tester.print_summary()
459
-
460
- except Exception as e:
461
- logging.error(f"❌ Error during testing: {e}")
462
- finally:
463
- await system.cleanup()
464
-
465
-
466
- async def interactive_mode():
467
- """Run interactive chat mode"""
468
- system = MultiCountryLegalRAGSystem()
469
 
470
- try:
471
- print("🚀 Initializing system...")
472
- success = await system.initialize()
473
- if not success:
474
- print("❌ System initialization failed")
475
  return
476
-
477
- print("\n🎯 INTERACTIVE MODE - SCALABLE SYSTEM")
478
- print("="*60)
479
- print("Commands:")
480
- print(" 'quit' - Exit")
481
- print(" 'stats' - Show statistics")
482
- print(" 'health' - Health check")
483
- print(" 'countries' - List available countries")
484
- print(" 'session' - Session info")
485
- print("="*60)
486
-
487
- session_id = f"interactive_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
488
- print(f"Session ID: {session_id}")
489
- print(f"Available: {', '.join(system.get_available_countries())}\n")
490
 
491
- while True:
492
- try:
493
- user_input = input("👤 You: ").strip()
494
-
495
- if user_input.lower() in ['quit', 'exit', 'q']:
496
- break
497
- elif user_input.lower() == 'stats':
498
- stats = system.get_global_stats()
499
- print(f"\n📊 Statistics:")
500
- print(f" Total Queries: {stats.get('total_queries', 0)}")
501
- print(f" Active Sessions: {stats.get('active_sessions', 0)}")
502
- print(f" Pending Interrupts: {stats.get('pending_interrupts', 0)}")
503
- continue
504
- elif user_input.lower() == 'health':
505
- health = await system.health_check()
506
- print(f"\n❤️ System Health:")
507
- print(f" Status: {'✅ HEALTHY' if health['system_initialized'] else '❌ UNHEALTHY'}")
508
- print(f" Countries: {len(health['available_countries'])} available")
509
- print(f" MongoDB: {'✅ Connected' if health['mongodb_connected'] else '❌ Disconnected'}")
510
- continue
511
- elif user_input.lower() == 'countries':
512
- countries = system.get_available_countries()
513
- print(f"\n🌍 Available Countries: {', '.join(countries) if countries else 'None'}")
514
- continue
515
- elif user_input.lower() == 'session':
516
- info = system.get_session_info(session_id)
517
- print(f"\n📋 Session Info:")
518
- print(f" Queries: {info.get('query_count', 0)}")
519
- print(f" Avg Time: {info.get('average_processing_time', 0):.2f}s")
520
- continue
521
- elif not user_input:
522
- continue
523
-
524
- start_time = time.time()
525
- response = await system.chat(user_input, session_id)
526
- response_time = time.time() - start_time
527
-
528
- print(f"🤖 Assistant ({response_time:.2f}s): {response}\n")
529
-
530
- # Check for interrupt
531
- if (hasattr(system.chat_manager, 'pending_interrupts') and
532
- session_id in system.chat_manager.pending_interrupts):
533
- print("⏸️ 💡 SYSTEM PAUSED - Next message treated as moderator decision\n")
534
-
535
- except KeyboardInterrupt:
536
- print("\n👋 Goodbye!")
537
- break
538
- except Exception as e:
539
- print(f"❌ Error: {str(e)}\n")
540
-
541
- finally:
542
- await system.cleanup()
543
-
544
 
545
- async def health_check_mode():
546
- """Run system health check only"""
547
- system = MultiCountryLegalRAGSystem()
548
 
549
- try:
550
- print("🔍 Performing health check...")
551
- success = await system.initialize()
552
-
553
- if success:
554
- health = await system.health_check()
555
- print("\n" + "="*50)
556
- print("📋 SYSTEM HEALTH REPORT")
557
- print("="*50)
558
- print(f" System Initialized: {health['system_initialized']}")
559
- print(f"🌍 Available Countries: {len(health['available_countries'])}")
560
- print(f"💾 MongoDB: {'✅ Connected' if health['mongodb_connected'] else '❌ Disconnected'}")
561
- print(f"⏸️ Interrupt Support: {'✅ Enabled' if health['interrupt_enabled'] else '❌ Disabled'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
562
 
563
- print(f"\n🔧 Components:")
564
- for component, status in health['components'].items():
565
- print(f" {component}: {'✅ OK' if status else '❌ Missing'}")
566
-
567
- all_healthy = (health['system_initialized'] and
568
- health['mongodb_connected'] and
569
- all(health['components'].values()))
570
- print(f"\n🎯 Overall Status: {'✅ HEALTHY' if all_healthy else '❌ UNHEALTHY'}")
571
 
572
- else:
573
- print("❌ System initialization failed")
574
 
575
- finally:
576
- await system.cleanup()
577
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578
 
579
- async def quick_test_mode():
580
- """Run a quick single test"""
581
- system = MultiCountryLegalRAGSystem()
582
 
583
- try:
584
- print("🚀 Quick Test Mode")
585
- print("Initializing system...")
586
- success = await system.initialize()
587
- if not success:
588
- print("❌ System initialization failed")
589
- return
590
-
591
- test_query = "Bonjour, quelle est la procedure pour un divorce au Benin?"
592
- session_id = "quick_test"
593
-
594
- print(f"\n🧪 Testing: {test_query}")
595
- start_time = time.time()
596
- response = await system.chat(test_query, session_id)
597
- response_time = time.time() - start_time
598
-
599
- print(f" Response ({response_time:.2f}s): {response}")
600
-
601
- print(f"\n📊 System Info:")
602
- print(f" Available Countries: {', '.join(system.get_available_countries())}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
603
 
604
- except Exception as e:
605
- print(f" Quick test failed: {e}")
606
- finally:
607
- await system.cleanup()
608
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
 
610
- if __name__ == "__main__":
 
611
  import argparse
612
 
613
  parser = argparse.ArgumentParser(
614
- description="🚀 Scalable Multi-Country Legal RAG System"
 
 
 
 
 
 
 
 
615
  )
616
 
617
  parser.add_argument(
618
  "--mode",
619
- choices=["interactive", "health", "interrupt", "quick"],
620
  default="interactive",
621
- help="Run mode (default: interactive)"
 
 
 
 
 
 
622
  )
623
 
624
  args = parser.parse_args()
625
 
626
- if args.mode == "interactive":
627
- asyncio.run(interactive_mode())
628
- elif args.mode == "health":
629
- asyncio.run(health_check_mode())
630
- elif args.mode == "interrupt":
631
- asyncio.run(run_interrupt_tests())
632
- elif args.mode == "quick":
633
- asyncio.run(quick_test_mode())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Multi-Country Legal RAG System - Interactive Testing Mode with Human-in-the-Loop Support
 
4
  """
 
5
  import sys
6
  from pathlib import Path
7
+ sys.path.insert(0, str(Path(__file__).parent))
8
 
9
  import asyncio
10
  import logging
11
  import time
12
  from datetime import datetime
13
+ from typing import Optional
14
+
15
+ from core.system_initializer import setup_system
16
+
17
+ # Setup comprehensive logging
18
+ logging.basicConfig(
19
+ level=logging.INFO,
20
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
21
+ handlers=[
22
+ logging.StreamHandler(sys.stdout),
23
+ logging.FileHandler('legal_rag_system.log', mode='a')
24
+ ]
25
+ )
26
+ logger = logging.getLogger(__name__)
27
+
28
+ class LegalRAGTester:
29
+ """
30
+ Interactive tester for the Legal RAG system with support for:
31
+ - Human-in-the-loop interrupts
32
+ - Session management
33
+ - Statistics tracking
34
+ - Assistance workflow testing
35
+ """
36
 
37
  def __init__(self):
38
+ self.system = None
 
 
 
 
 
 
 
 
 
 
39
  self.chat_manager = None
40
+ self.session_id = f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
41
+ self.message_count = 0
42
+
43
+ async def initialize(self):
44
+ """Initialize the Legal RAG system"""
45
+ print("🚀 Initializing Legal RAG System...")
46
  try:
47
+ self.system = await setup_system()
48
+ self.chat_manager = self.system["chat_manager"]
 
 
 
 
 
 
 
49
 
50
+ print("✅ System initialized successfully!")
 
51
 
52
+ # Print system info
53
+ info = self.chat_manager.get_checkpointer_info()
54
+ stats = self.chat_manager.get_global_stats()
55
 
56
+ print(f"📊 Checkpointer: {info['type']} ({info['description']})")
57
+ print(f"💾 Persistent: {info['persistent']}")
58
+ print(f"🎯 Session ID: {self.session_id}")
59
+ print(f"🌍 Available countries: Bénin, Madagascar")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  return True
62
 
63
  except Exception as e:
64
+ logger.exception("Failed to initialize system")
65
+ print(f"❌ Initialization failed: {e}")
 
66
  return False
67
+
68
+ def _handle_interrupt(self, interrupt_value: dict) -> str:
69
+ """
70
+ Handle human-in-the-loop interrupts synchronously.
71
+ This is called when the graph needs human approval.
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ Args:
74
+ interrupt_value: Interrupt data containing message and context
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ Returns:
77
+ Moderator's decision (approve/reject)
78
+ """
79
+ print("\n" + "="*70)
80
+ print("🔒 HUMAN APPROVAL REQUIRED")
81
+ print("="*70)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # Extract interrupt information
84
+ message = interrupt_value.get("message", "")
85
+ email = interrupt_value.get("user_email", "N/A")
86
+ country = interrupt_value.get("country", "N/A")
87
+ description = interrupt_value.get("description", "N/A")
88
 
89
+ # Display formatted approval request
90
+ print(message)
91
+ print()
 
 
 
 
 
 
92
 
93
+ # Get moderator decision with validation
94
+ while True:
95
+ moderator_input = input("🔐 Moderator Decision: ").strip()
 
 
 
 
96
 
97
+ if not moderator_input:
98
+ print("⚠️ Please provide a decision (approve/reject)")
99
+ continue
100
 
101
+ # Validate input
102
+ input_lower = moderator_input.lower()
103
+ if any(keyword in input_lower for keyword in ["approve", "approuver", "accept"]):
104
+ print("✅ Request APPROVED")
105
+ break
106
+ elif any(keyword in input_lower for keyword in ["reject", "rejeter", "refuse"]):
107
+ print("❌ Request REJECTED")
108
+ break
109
+ else:
110
+ print("⚠️ Invalid decision. Use 'approve [reason]' or 'reject [reason]'")
111
+ continue
112
 
113
+ print("="*70 + "\n")
114
+ return moderator_input
115
+
116
+ async def chat(self, message: str) -> Optional[str]:
117
+ """
118
+ Send a chat message and get response with interrupt handling.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ Args:
121
+ message: User message
122
+
123
+ Returns:
124
+ Assistant response or None on error
125
+ """
126
+ if not self.chat_manager:
127
+ print("❌ System not initialized. Please restart.")
128
+ return None
129
 
130
  try:
131
+ start_time = time.time()
132
+ self.message_count += 1
 
 
133
 
134
+ # Send message with interrupt handler
135
+ response = await self.chat_manager.chat(
136
+ message=message,
137
+ session_id=self.session_id,
138
+ interrupt_handler=self._handle_interrupt # Enable synchronous interrupt handling
139
+ )
140
 
141
+ response_time = time.time() - start_time
 
 
 
 
142
 
143
+ # Display response
144
+ print(f"\n🤖 Assistant ({response_time:.2f}s):")
 
 
 
145
 
146
+ # Format multi-line responses
147
+ for line in response.split('\n'):
148
+ if line.strip():
149
+ print(f" {line}")
150
+ else:
151
+ print()
152
 
153
+ # Check for pending interrupts (async mode fallback)
154
+ if self.chat_manager.has_pending_interrupt(self.session_id):
155
+ print("\n⏸️ 💡 System paused - waiting for moderator decision")
156
+ print(" Your next message will be treated as approval/rejection")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ return response
 
159
 
160
+ except KeyboardInterrupt:
161
+ raise # Re-raise to allow graceful shutdown
162
  except Exception as e:
163
+ logger.exception(f"Error processing message: {message}")
164
+ print(f"❌ Error: {str(e)}")
165
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ def display_stats(self):
168
+ """Display current system statistics"""
169
+ if not self.chat_manager:
170
+ print(" System not initialized")
171
+ return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ stats = self.chat_manager.get_global_stats()
174
+ routing = stats.get('routing_stats', {})
 
 
 
175
 
176
+ print("\n" + "="*70)
177
+ print("📊 SYSTEM STATISTICS")
178
+ print("="*70)
179
+ print(f"💬 Session Messages: {self.message_count}")
180
+ print(f"🔢 Total System Queries: {stats.get('total_queries', 0)}")
181
+ print(f"👥 Active Sessions: {stats.get('active_sessions', 0)}")
182
+ print(f"⏸️ Pending Interrupts: {stats.get('pending_interrupts', 0)}")
183
+ print()
184
+ print("🔀 Routing Statistics:")
185
+ print(f" 📍 Bénin queries: {routing.get('benin', 0)}")
186
+ print(f" 📍 Madagascar queries: {routing.get('madagascar', 0)}")
187
+ print(f" ❓ Unclear queries: {routing.get('unclear', 0)}")
188
+ print()
189
+ print(f"💾 Storage: {stats.get('type', 'unknown')} - {stats.get('description', '')}")
190
+ print("="*70 + "\n")
191
 
192
+ def display_help(self):
193
+ """Display help information"""
194
+ print("\n" + "="*70)
195
+ print("📋 AVAILABLE COMMANDS")
196
+ print("="*70)
197
+ print(" quit, exit, q - Exit the program")
198
+ print(" stats, s - Show system statistics")
199
+ print(" clear, cls - Clear the screen")
200
+ print(" help, h, ? - Show this help message")
201
+ print(" history - Show conversation history")
202
+ print(" reset - Reset current session")
203
+ print()
204
+ print("💡 TESTING TIPS:")
205
+ print(" - Legal queries: Ask about laws in Bénin or Madagascar")
206
+ print(" - Assistance: Say 'je veux parler à un avocat'")
207
+ print(" - Follow the prompts for email and description")
208
+ print(" - You'll be prompted for approval when needed")
209
+ print("="*70 + "\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
+ async def display_history(self):
212
+ """Display conversation history"""
213
+ if not self.chat_manager:
214
+ print("❌ System not initialized")
 
215
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ try:
218
+ history = await self.chat_manager.get_conversation_history(self.session_id)
219
+
220
+ if not history:
221
+ print("\n💬 No conversation history yet\n")
222
+ return
223
+
224
+ print("\n" + "="*70)
225
+ print("💬 CONVERSATION HISTORY")
226
+ print("="*70)
227
+
228
+ for i, msg in enumerate(history, 1):
229
+ role = "👤 User" if msg.type == "human" else "🤖 Assistant"
230
+ content = msg.content[:100] + "..." if len(msg.content) > 100 else msg.content
231
+ print(f"{i}. {role}: {content}")
232
+
233
+ print("="*70 + "\n")
234
+
235
+ except Exception as e:
236
+ logger.exception("Error displaying history")
237
+ print(f"❌ Error: {e}\n")
238
+
239
+ def reset_session(self):
240
+ """Reset the current session"""
241
+ old_session = self.session_id
242
+ self.session_id = f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
243
+ self.message_count = 0
244
+ print(f"\n🔄 Session reset")
245
+ print(f" Old: {old_session}")
246
+ print(f" New: {self.session_id}\n")
247
+
248
+ def clear_screen(self):
249
+ """Clear the terminal screen"""
250
+ print("\033[H\033[J", end="")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
+ async def interactive_mode():
253
+ """Run the system in interactive mode with full interrupt support"""
254
+ tester = LegalRAGTester()
255
 
256
+ # Initialize system
257
+ if not await tester.initialize():
258
+ print("❌ Failed to initialize. Exiting.")
259
+ return
260
+
261
+ # Display welcome message
262
+ print("\n" + "="*70)
263
+ print("🎯 LEGAL RAG INTERACTIVE TESTING MODE")
264
+ print("="*70)
265
+ print("Type 'help' for available commands")
266
+ print("="*70)
267
+
268
+ # Quick test questions
269
+ test_questions = [
270
+ "Bonjour, comment ça va?",
271
+ "Quelle est la procédure de divorce au Bénin?",
272
+ "Je veux parler à un avocat spécialisé",
273
+ "Quels sont les droits des enfants à Madagascar?",
274
+ "Quelles sont les conditions pour se marier au Bénin?",
275
+ ]
276
+
277
+ print("\n💡 Quick test questions:")
278
+ for i, question in enumerate(test_questions, 1):
279
+ print(f" {i}. {question}")
280
+ print()
281
+
282
+ # Main interaction loop
283
+ while True:
284
+ try:
285
+ # Get user input
286
+ user_input = input("👤 You: ").strip()
287
 
288
+ if not user_input:
289
+ continue
 
 
 
 
 
 
290
 
291
+ # Handle commands
292
+ cmd_lower = user_input.lower()
293
 
294
+ # Exit commands
295
+ if cmd_lower in ['quit', 'exit', 'q']:
296
+ print("\n👋 Goodbye! Thank you for testing.")
297
+ break
298
+
299
+ # Stats command
300
+ elif cmd_lower in ['stats', 's']:
301
+ tester.display_stats()
302
+ continue
303
+
304
+ # Clear screen
305
+ elif cmd_lower in ['clear', 'cls']:
306
+ tester.clear_screen()
307
+ continue
308
+
309
+ # Help command
310
+ elif cmd_lower in ['help', 'h', '?']:
311
+ tester.display_help()
312
+ continue
313
+
314
+ # History command
315
+ elif cmd_lower == 'history':
316
+ await tester.display_history()
317
+ continue
318
+
319
+ # Reset session
320
+ elif cmd_lower == 'reset':
321
+ tester.reset_session()
322
+ continue
323
+
324
+ # Regular chat message
325
+ else:
326
+ await tester.chat(user_input)
327
+
328
+ except KeyboardInterrupt:
329
+ print("\n\n⚠️ Interrupted by user")
330
+ print("Type 'quit' to exit or continue chatting\n")
331
+ continue
332
+
333
+ except Exception as e:
334
+ logger.exception("Unexpected error in main loop")
335
+ print(f"\n❌ Unexpected error: {e}\n")
336
 
337
+ async def demo_mode():
338
+ """Run automated demo of the system capabilities"""
339
+ tester = LegalRAGTester()
340
 
341
+ print("🚀 Running Automated Demo...")
342
+
343
+ if not await tester.initialize():
344
+ print("❌ Failed to initialize. Exiting demo.")
345
+ return
346
+
347
+ demo_scenarios = [
348
+ {
349
+ "name": "Greeting",
350
+ "messages": ["Bonjour, je m'appelle Test"]
351
+ },
352
+ {
353
+ "name": "Legal Query - Bénin",
354
+ "messages": ["Quelle est la procédure de divorce au Bénin?"]
355
+ },
356
+ {
357
+ "name": "Legal Query - Madagascar",
358
+ "messages": ["Quels sont les droits de succession à Madagascar?"]
359
+ },
360
+ {
361
+ "name": "Conversation Repair",
362
+ "messages": ["Peux-tu répéter plus simplement?"]
363
+ },
364
+ {
365
+ "name": "Summary Request",
366
+ "messages": ["Résume notre conversation"]
367
+ }
368
+ ]
369
+
370
+ print("\n" + "="*70)
371
+ print("🧪 DEMO SCENARIOS")
372
+ print("="*70 + "\n")
373
+
374
+ for i, scenario in enumerate(demo_scenarios, 1):
375
+ print(f"\n{'='*70}")
376
+ print(f"📋 Scenario {i}/{len(demo_scenarios)}: {scenario['name']}")
377
+ print('='*70)
378
 
379
+ for message in scenario['messages']:
380
+ print(f"\n🧪 Testing: '{message}'")
381
+ await tester.chat(message)
382
+ await asyncio.sleep(1.5) # Pause between messages
383
+
384
+ print("\n" + "="*70)
385
+ print("✅ DEMO COMPLETED")
386
+ print("="*70)
387
+ tester.display_stats()
388
+
389
+ async def test_assistance_workflow():
390
+ """Test the complete assistance workflow with interrupts"""
391
+ tester = LegalRAGTester()
392
+
393
+ print("🧪 Testing Assistance Workflow with Human Approval...")
394
+
395
+ if not await tester.initialize():
396
+ print("❌ Failed to initialize. Exiting test.")
397
+ return
398
+
399
+ print("\n" + "="*70)
400
+ print("🔬 ASSISTANCE WORKFLOW TEST")
401
+ print("="*70 + "\n")
402
+
403
+ # Simulated assistance workflow
404
+ workflow_steps = [
405
+ "Je veux parler à un avocat",
406
+ "test@example.com",
407
+ "J'ai besoin d'aide pour un problème de divorce au Bénin",
408
+ "oui"
409
+ ]
410
+
411
+ step_names = [
412
+ "1. Initiate assistance request",
413
+ "2. Provide email",
414
+ "3. Describe situation",
415
+ "4. Confirm request"
416
+ ]
417
+
418
+ for step_name, message in zip(step_names, workflow_steps):
419
+ print(f"\n📍 {step_name}")
420
+ print(f" Message: '{message}'")
421
+ await tester.chat(message)
422
+ await asyncio.sleep(1)
423
+
424
+ print("\n" + "="*70)
425
+ print("✅ WORKFLOW TEST COMPLETED")
426
+ print("="*70)
427
+ tester.display_stats()
428
 
429
+ def main():
430
+ """Main entry point with argument parsing"""
431
  import argparse
432
 
433
  parser = argparse.ArgumentParser(
434
+ description="Legal RAG System - Interactive Testing with Human-in-the-Loop Support",
435
+ formatter_class=argparse.RawDescriptionHelpFormatter,
436
+ epilog="""
437
+ Examples:
438
+ python main.py # Interactive mode (default)
439
+ python main.py --mode demo # Automated demo
440
+ python main.py --mode test # Test assistance workflow
441
+ python main.py --debug # Enable debug logging
442
+ """
443
  )
444
 
445
  parser.add_argument(
446
  "--mode",
447
+ choices=["interactive", "demo", "test"],
448
  default="interactive",
449
+ help="Run mode: interactive (default), demo, or test"
450
+ )
451
+
452
+ parser.add_argument(
453
+ "--debug",
454
+ action="store_true",
455
+ help="Enable debug logging"
456
  )
457
 
458
  args = parser.parse_args()
459
 
460
+ # Adjust logging level if debug
461
+ if args.debug:
462
+ logging.getLogger().setLevel(logging.DEBUG)
463
+ print("🐛 Debug logging enabled\n")
464
+
465
+ # Run selected mode
466
+ try:
467
+ if args.mode == "interactive":
468
+ asyncio.run(interactive_mode())
469
+ elif args.mode == "demo":
470
+ asyncio.run(demo_mode())
471
+ elif args.mode == "test":
472
+ asyncio.run(test_assistance_workflow())
473
+
474
+ except KeyboardInterrupt:
475
+ print("\n\n👋 Program interrupted by user. Goodbye!")
476
+
477
+ except Exception as e:
478
+ logger.exception("Fatal error occurred")
479
+ print(f"\n❌ Fatal error: {e}")
480
+ sys.exit(1)
481
+
482
+ if __name__ == "__main__":
483
+ main()
requirements.txt CHANGED
@@ -6,6 +6,7 @@
6
  # Web Framework
7
  fastapi==0.118.2
8
  uvicorn[standard]==0.37.0
 
9
 
10
  # LangChain & LangGraph - Core only (dependencies will pull others)
11
  langgraph==0.6.8
 
6
  # Web Framework
7
  fastapi==0.118.2
8
  uvicorn[standard]==0.37.0
9
+ slowapi==0.1.9
10
 
11
  # LangChain & LangGraph - Core only (dependencies will pull others)
12
  langgraph==0.6.8