kn29 committed on
Commit
5cd1b9f
·
verified ·
1 Parent(s): 4c6aa01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -887
app.py CHANGED
@@ -1,19 +1,22 @@
1
- from fastapi import FastAPI, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel, Field
4
- import pymongo
5
- import os
6
- import numpy as np
7
- from datetime import datetime, timedelta
8
- import logging
9
- import traceback
10
- from typing import Dict, Any, Optional, List
11
  import asyncio
 
 
 
 
12
  import threading
13
  import time
 
14
  from collections import defaultdict
15
  from contextlib import asynccontextmanager
16
- import sys
 
 
 
 
 
 
 
 
17
 
18
  try:
19
  import faiss
@@ -21,6 +24,13 @@ try:
21
  except ImportError:
22
  FAISS_AVAILABLE = False
23
 
 
 
 
 
 
 
 
24
  # Configure comprehensive logging
25
  logging.basicConfig(
26
  level=logging.INFO,
@@ -32,32 +42,30 @@ logging.basicConfig(
32
  )
33
  logger = logging.getLogger(__name__)
34
 
35
- # Global state
36
  MONGO_CLIENT = None
37
  DB = None
38
- RAG_INITIALIZED = False
39
- RAG_MODULE = None
40
  APP_STATE = {
41
  "startup_time": None,
42
  "mongodb_connected": False,
43
- "rag_ready": False,
44
  "active_sessions": 0,
45
  "total_queries": 0,
46
  "errors": []
47
  }
48
 
49
- # Configuration - Session memory management
50
- CLEANUP_INTERVAL = 1800 # Run cleanup every 30 minutes (1800 seconds)
51
- STORE_TTL = 1800 # Sessions expire after 30 minutes of inactivity (1800 seconds)
 
 
 
 
52
 
53
- # You can adjust these values:
54
- # STORE_TTL = 900 # 15 minutes
55
- # STORE_TTL = 3600 # 1 hour
56
- # STORE_TTL = 7200 # 2 hours
57
-
58
- # Request/Response models with validation
59
  class ChatRequest(BaseModel):
60
- message: str = Field(..., min_length=1, max_length=5000, description="User's query message")
61
 
62
  class ChatResponse(BaseModel):
63
  success: bool
@@ -71,7 +79,7 @@ class ChatResponse(BaseModel):
71
  error_details: Optional[str] = None
72
 
73
  class InitRequest(BaseModel):
74
- force_reload: bool = Field(default=False, description="Force reload session even if already loaded")
75
 
76
  class InitResponse(BaseModel):
77
  success: bool
@@ -85,996 +93,331 @@ class InitResponse(BaseModel):
85
  class HealthResponse(BaseModel):
86
  status: str
87
  mongodb_connected: bool
88
- rag_initialized: bool
89
  faiss_available: bool
90
  active_sessions: int
91
  memory_usage: Dict[str, Any]
92
  uptime_seconds: float
93
  last_error: Optional[str] = None
94
 
 
 
95
  def create_session_logger(session_id: str):
96
- """Create a logger with session context"""
97
  return logging.LoggerAdapter(logger, {'session_id': session_id[:8]})
98
 
99
- def safe_import_rag():
100
- """Safely import RAG module with error handling"""
101
- global RAG_MODULE
102
- try:
103
- import rag
104
- RAG_MODULE = rag
105
- logger.info("RAG module imported successfully")
106
- return True
107
- except ImportError as e:
108
- logger.error(f"Failed to import RAG module: {e}")
109
- logger.error("Make sure rag.py is in the same directory and all dependencies are installed")
110
- return False
111
- except Exception as e:
112
- logger.error(f"Unexpected error importing RAG module: {e}")
113
- logger.error(traceback.format_exc())
114
- return False
115
-
116
  def connect_mongodb():
117
- """Initialize MongoDB connection with comprehensive error handling"""
118
  global MONGO_CLIENT, DB
119
-
120
  try:
121
  mongodb_url = os.getenv("MONGODB_URL", "mongodb://localhost:27017/")
122
- if not mongodb_url or mongodb_url == "mongodb://localhost:27017/":
123
- logger.warning("Using default MongoDB URL - set MONGODB_URL environment variable for production")
124
-
125
- logger.info(f"Connecting to MongoDB: {mongodb_url[:20]}...")
126
  MONGO_CLIENT = pymongo.MongoClient(
127
- mongodb_url,
128
- serverSelectionTimeoutMS=10000, # 10 second timeout
129
- connectTimeoutMS=10000,
130
- socketTimeoutMS=10000
131
  )
132
-
133
- # Test connection
134
  MONGO_CLIENT.admin.command('ping')
135
  DB = MONGO_CLIENT["legal_rag_system"]
136
-
137
- logger.info("Creating MongoDB indexes...")
138
- # Create indexes with error handling
139
- try:
140
- DB.chats.create_index("session_id", background=True)
141
- DB.chats.create_index("created_at", expireAfterSeconds=24*60*60, background=True)
142
- DB.chats.create_index([("session_id", 1), ("created_at", 1)], background=True)
143
- logger.info("MongoDB indexes created successfully")
144
- except Exception as idx_error:
145
- logger.warning(f"Index creation failed (non-critical): {idx_error}")
146
-
147
  APP_STATE["mongodb_connected"] = True
148
- logger.info("MongoDB connected and configured successfully")
149
  return True
150
-
151
- except pymongo.errors.ServerSelectionTimeoutError:
152
- logger.error("MongoDB connection timeout - check if MongoDB is running and accessible")
153
- return False
154
- except pymongo.errors.ConfigurationError as e:
155
- logger.error(f"MongoDB configuration error: {e}")
156
- return False
157
  except Exception as e:
158
  logger.error(f"MongoDB connection failed: {e}")
159
- logger.error(traceback.format_exc())
160
  return False
161
 
162
- def initialize_rag():
163
- """Initialize RAG system with comprehensive error handling"""
164
- global RAG_INITIALIZED
165
-
166
- if not RAG_MODULE:
167
- logger.error("RAG module not available - cannot initialize")
168
- return False
169
-
170
- if not FAISS_AVAILABLE:
171
- logger.error("FAISS library not available - RAG system requires FAISS")
172
  return False
173
-
174
  try:
175
  model_id = os.getenv("EMBEDDING_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
176
  groq_api_key = os.getenv("GROQ_API_KEY")
177
-
178
- logger.info(f"Initializing RAG system with embedding model: {model_id}")
179
-
180
- if groq_api_key:
181
- logger.info("Groq API key found - full RAG capabilities available")
182
- else:
183
- logger.warning("No Groq API key - some RAG features may be limited")
184
-
185
- # Initialize with timeout protection
186
- RAG_MODULE.initialize_models(model_id, groq_api_key)
187
-
188
- RAG_INITIALIZED = True
189
- APP_STATE["rag_ready"] = True
190
- logger.info("RAG system initialized successfully")
191
  return True
192
-
193
- except ImportError as e:
194
- logger.error(f"Missing dependencies for RAG initialization: {e}")
195
- return False
196
  except Exception as e:
197
- logger.error(f"RAG initialization failed: {e}")
198
- logger.error(traceback.format_exc())
199
  APP_STATE["errors"].append(f"RAG init failed: {str(e)}")
200
  return False
201
 
202
- def decode_embedding_safely(embedding_list: List[float]) -> np.ndarray:
203
- """Safely convert embedding from storage with validation"""
204
- try:
205
- if not embedding_list or not isinstance(embedding_list, list):
206
- raise ValueError("Invalid embedding data")
207
-
208
- embedding = np.array(embedding_list, dtype=np.float32)
209
-
210
- if embedding.size == 0:
211
- raise ValueError("Empty embedding")
212
-
213
- if np.isnan(embedding).any() or np.isinf(embedding).any():
214
- raise ValueError("Embedding contains invalid values")
215
-
216
- return embedding
217
-
218
- except Exception as e:
219
- logger.error(f"Failed to decode embedding: {e}")
220
- return np.array([])
221
-
222
  def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
223
- """Load session with comprehensive error handling and validation"""
224
  session_logger = create_session_logger(session_id)
225
-
226
  if not DB:
227
  raise ValueError("Database not connected")
228
-
229
- try:
230
- # Get and validate session metadata
231
- session_doc = DB.sessions.find_one({"session_id": session_id})
232
- if not session_doc:
233
- raise ValueError(f"Session {session_id} not found in database")
234
-
235
- session_status = session_doc.get("status")
236
- if session_status != "completed":
237
- raise ValueError(f"Session not ready - status: {session_status}")
238
-
239
- session_logger.info(f"Loading session: {session_doc.get('filename', 'unknown')}")
240
-
241
- # Load chunks with validation
242
- chunks_cursor = DB.chunks.find({"session_id": session_id}).sort("created_at", 1)
243
- chunks_list = list(chunks_cursor)
244
-
245
- if not chunks_list:
246
- raise ValueError(f"No chunks found for session {session_id}")
247
-
248
- session_logger.info(f"Found {len(chunks_list)} chunks")
249
-
250
- # Process chunks with validation
251
- processed_chunks = []
252
- embeddings_matrix = []
253
- failed_chunks = 0
254
-
255
- for i, chunk_doc in enumerate(chunks_list):
256
- try:
257
- # Validate required fields
258
- if 'text' not in chunk_doc or not chunk_doc['text'].strip():
259
- session_logger.warning(f"Chunk {i} missing or empty text")
260
- failed_chunks += 1
261
- continue
262
-
263
- # Decode embedding
264
- embedding_list = chunk_doc.get('embedding', [])
265
- embedding = decode_embedding_safely(embedding_list)
266
-
267
- if embedding.size == 0:
268
- session_logger.warning(f"Chunk {i} has invalid embedding")
269
- failed_chunks += 1
270
- continue
271
-
272
- # Create processed chunk
273
- processed_chunk = {
274
- 'id': chunk_doc.get('chunk_id', f'chunk_{i}'),
275
- 'text': chunk_doc['text'],
276
- 'title': chunk_doc.get('title', session_doc.get('filename', 'Document')),
277
- 'section_type': chunk_doc.get('section_type', 'content'),
278
- 'importance_score': float(chunk_doc.get('importance_score', 1.0)),
279
- 'entities': chunk_doc.get('entities', []),
280
- 'embedding': embedding
281
- }
282
-
283
- processed_chunks.append(processed_chunk)
284
- embeddings_matrix.append(embedding)
285
-
286
- except Exception as chunk_error:
287
- session_logger.error(f"Failed to process chunk {i}: {chunk_error}")
288
- failed_chunks += 1
289
- continue
290
-
291
- if not processed_chunks:
292
- raise ValueError(f"No valid chunks could be loaded (failed: {failed_chunks})")
293
-
294
- if failed_chunks > 0:
295
- session_logger.warning(f"Failed to load {failed_chunks} chunks, continuing with {len(processed_chunks)}")
296
-
297
- # Create embeddings matrix
298
- embeddings_matrix = np.vstack(embeddings_matrix).astype('float32')
299
-
300
- # Prepare session store
301
- session_store = {
302
- "chunks": processed_chunks,
303
- "embeddings_matrix": embeddings_matrix,
304
- "faiss_index": None,
305
- "indexed": False,
306
- "metadata": {
307
- "session_id": session_id,
308
- "title": session_doc.get("filename", "Document"),
309
- "chunk_count": len(processed_chunks),
310
- "failed_chunks": failed_chunks,
311
- "loaded_at": datetime.utcnow(),
312
- "document_info": {
313
- "filename": session_doc.get("filename", "Unknown"),
314
- "text_length": session_doc.get("text_length", 0),
315
- "word_count": session_doc.get("word_count", 0),
316
- "file_size": session_doc.get("file_size", 0),
317
- "processing_completed_at": session_doc.get("processing_completed_at")
318
- }
319
- }
320
  }
321
-
322
- session_logger.info(f"Session loaded successfully: {len(processed_chunks)} chunks")
323
- return session_store
324
-
325
- except Exception as e:
326
- session_logger.error(f"Failed to load session: {e}")
327
- session_logger.error(traceback.format_exc())
328
- raise
329
 
330
- def build_faiss_index_safely(session_id: str) -> Dict[str, Any]:
331
- """Build FAISS index with error handling"""
332
  session_logger = create_session_logger(session_id)
333
-
334
- if not FAISS_AVAILABLE:
335
- raise ValueError("FAISS library not available")
336
-
337
  with STORE_LOCK:
338
- if session_id not in SESSION_STORES:
339
- raise ValueError(f"Session {session_id} not loaded")
340
-
341
  store = SESSION_STORES[session_id]
342
- if store["indexed"]:
343
- session_logger.info("Session already indexed")
344
- return store["metadata"]
345
-
346
- chunks = store["chunks"]
347
- embeddings_matrix = store["embeddings_matrix"]
348
-
349
- try:
350
- session_logger.info(f"Building FAISS index for {len(chunks)} chunks...")
351
-
352
- # Validate embeddings matrix
353
- if embeddings_matrix.shape[0] != len(chunks):
354
- raise ValueError("Embeddings matrix size mismatch with chunks")
355
-
356
- # Create FAISS index
357
- dimension = embeddings_matrix.shape[1]
358
- faiss_index = faiss.IndexFlatIP(dimension)
359
- faiss_index.add(embeddings_matrix)
360
-
361
- # Initialize RAG system components
362
- if RAG_MODULE:
363
- RAG_MODULE.CHUNKS_DATA = chunks
364
- RAG_MODULE.DENSE_INDEX = faiss_index
365
-
366
- # Build additional indices
367
- session_logger.info("Building additional retrieval indices...")
368
-
369
- try:
370
- # BM25 index
371
- tokenized_corpus = [chunk['text'].lower().split() for chunk in chunks]
372
- RAG_MODULE.BM25_INDEX = RAG_MODULE.BM25Okapi(tokenized_corpus)
373
-
374
- # Token index
375
- RAG_MODULE.TOKEN_TO_CHUNKS = defaultdict(set)
376
- for i, chunk in enumerate(chunks):
377
- tokens = chunk['text'].lower().split()
378
- for token in tokens:
379
- RAG_MODULE.TOKEN_TO_CHUNKS[token].add(i)
380
-
381
- # Concept graph
382
- import networkx as nx
383
- RAG_MODULE.CONCEPT_GRAPH = nx.Graph()
384
- for i, chunk in enumerate(chunks):
385
- RAG_MODULE.CONCEPT_GRAPH.add_node(
386
- i,
387
- text=chunk['text'][:200],
388
- importance=chunk['importance_score']
389
- )
390
-
391
- # Add edges for shared entities
392
- for j, other_chunk in enumerate(chunks[i+1:], i+1):
393
- shared_entities = set(e.get('text', '') for e in chunk['entities']) & \
394
- set(e.get('text', '') for e in other_chunk['entities'])
395
- if shared_entities:
396
- RAG_MODULE.CONCEPT_GRAPH.add_edge(i, j, weight=len(shared_entities))
397
-
398
- except Exception as index_error:
399
- session_logger.warning(f"Failed to build some retrieval indices: {index_error}")
400
-
401
- # Mark as indexed
402
- with STORE_LOCK:
403
- SESSION_STORES[session_id]["faiss_index"] = faiss_index
404
- SESSION_STORES[session_id]["indexed"] = True
405
-
406
- session_logger.info("FAISS index built successfully")
407
- return SESSION_STORES[session_id]["metadata"]
408
-
409
- except Exception as e:
410
- session_logger.error(f"Failed to build FAISS index: {e}")
411
- session_logger.error(traceback.format_exc())
412
- raise
413
 
414
- def save_chat_message_safely(session_id: str, role: str, message: str):
415
- """Save chat message with error handling"""
416
- if not DB:
417
- logger.warning("Database not available - chat message not saved")
418
- return
419
 
420
- try:
421
- chat_doc = {
422
- "session_id": session_id,
423
- "role": role,
424
- "message": message,
425
- "created_at": datetime.utcnow()
426
- }
427
- DB.chats.insert_one(chat_doc)
428
- except Exception as e:
429
- logger.error(f"Failed to save chat message for session {session_id}: {e}")
430
 
431
  def get_chat_history_safely(session_id: str, limit: int = 50) -> List[Dict[str, Any]]:
432
- """Get chat history with error handling"""
433
- if not DB:
434
- return []
435
-
436
  try:
437
- chats_cursor = DB.chats.find(
438
- {"session_id": session_id}
439
- ).sort("created_at", 1).limit(limit)
440
-
441
- chat_history = []
442
- for chat_doc in chats_cursor:
443
- chat_history.append({
444
- "role": chat_doc["role"],
445
- "message": chat_doc["message"],
446
- "timestamp": chat_doc["created_at"].isoformat()
447
- })
448
-
449
- return chat_history
450
-
451
  except Exception as e:
452
  logger.error(f"Failed to get chat history for session {session_id}: {e}")
453
  return []
454
 
 
 
 
 
 
 
 
 
 
 
 
455
  def cleanup_expired_sessions():
456
- """Clean up only expired chat sessions from memory, keep server running"""
457
- try:
458
- current_time = datetime.utcnow()
459
- expired_sessions = []
460
-
461
- with STORE_LOCK:
462
- for session_id, store in SESSION_STORES.items():
463
- loaded_at = store["metadata"]["loaded_at"]
464
- age_seconds = (current_time - loaded_at).total_seconds()
465
-
466
- # Only expire sessions older than TTL (30 minutes)
467
- if age_seconds > STORE_TTL:
468
- expired_sessions.append(session_id)
469
-
470
- # Clean up expired sessions one by one
471
- for session_id in expired_sessions:
472
- try:
473
- store = SESSION_STORES[session_id]
474
-
475
- # Clean up session-specific RAG instance
476
- if "rag_instance" in store:
477
- store["rag_instance"].cleanup()
478
-
479
- # Clean up FAISS index
480
- if store.get("faiss_index"):
481
- del store["faiss_index"]
482
-
483
- # Remove session from memory
484
- del SESSION_STORES[session_id]
485
-
486
- age_minutes = (current_time - store["metadata"]["loaded_at"]).total_seconds() / 60
487
- logger.info(f"Expired session {session_id[:8]} removed from memory (age: {age_minutes:.1f} minutes)")
488
-
489
- except Exception as cleanup_error:
490
- logger.error(f"Error cleaning up session {session_id[:8]}: {cleanup_error}")
491
-
492
- # Update active session count
493
- APP_STATE["active_sessions"] = len(SESSION_STORES)
494
-
495
- if expired_sessions:
496
- logger.info(f"Memory cleanup completed: {len(expired_sessions)} expired sessions removed, {len(SESSION_STORES)} sessions still active")
497
- else:
498
- logger.debug(f"No expired sessions found. {len(SESSION_STORES)} sessions still active in memory")
499
-
500
- except Exception as e:
501
- logger.error(f"Session cleanup error: {e}")
502
- logger.error(traceback.format_exc())
503
 
504
  async def periodic_cleanup():
505
- """Periodic cleanup of expired sessions - keeps server running"""
506
- cleanup_count = 0
507
- try:
508
- while True:
509
- cleanup_count += 1
510
- logger.debug(f"Running session cleanup cycle #{cleanup_count}")
511
-
512
- cleanup_expired_sessions()
513
-
514
- # Sleep for cleanup interval (30 minutes)
515
- await asyncio.sleep(CLEANUP_INTERVAL)
516
-
517
- except asyncio.CancelledError:
518
- logger.info(f"Session cleanup task cancelled after {cleanup_count} cycles")
519
- raise
520
- except Exception as e:
521
- logger.error(f"Periodic cleanup error in cycle #{cleanup_count}: {e}")
522
- logger.error(traceback.format_exc())
523
-
524
- # Don't break the loop - keep trying to clean up
525
- await asyncio.sleep(60) # Wait 1 minute before retrying
526
 
527
- # Global cleanup task
528
- cleanup_task = None
529
 
 
530
  @asynccontextmanager
531
  async def lifespan(app: FastAPI):
532
- """Application lifespan with comprehensive error handling"""
533
- global cleanup_task
534
-
535
- # Startup
536
- logger.info("Starting Advanced RAG Chat Service...")
537
  APP_STATE["startup_time"] = datetime.utcnow()
538
-
539
- startup_success = True
540
-
541
- # Check FAISS availability
542
- if not FAISS_AVAILABLE:
543
- logger.error("FAISS library not available - this is required for RAG functionality")
544
- startup_success = False
545
-
546
- # Import RAG module
547
- if not safe_import_rag():
548
- logger.error("RAG module import failed")
549
- startup_success = False
550
-
551
- # Connect to MongoDB (non-critical failure)
552
- if not connect_mongodb():
553
- logger.error("MongoDB connection failed - continuing with limited functionality")
554
-
555
- # Initialize RAG system (non-critical failure for basic health checks)
556
- if RAG_MODULE and FAISS_AVAILABLE:
557
- if not initialize_rag():
558
- logger.error("RAG initialization failed - RAG features disabled")
559
-
560
- # Start cleanup task if MongoDB is available
561
- if APP_STATE["mongodb_connected"]:
562
- try:
563
- cleanup_task = asyncio.create_task(periodic_cleanup())
564
- logger.info("Background cleanup task started")
565
- except Exception as e:
566
- logger.error(f"Failed to start cleanup task: {e}")
567
-
568
- if startup_success:
569
- logger.info("Startup completed successfully")
570
- else:
571
- logger.warning("Startup completed with errors - some features may be disabled")
572
 
573
  yield
574
 
575
- # Shutdown
576
  logger.info("Shutting down...")
577
-
578
  if cleanup_task:
579
  cleanup_task.cancel()
580
- try:
581
- await cleanup_task
582
- except asyncio.CancelledError:
583
- pass
584
-
585
  if MONGO_CLIENT:
586
  MONGO_CLIENT.close()
587
-
588
- logger.info("Shutdown completed")
589
 
590
- # Initialize FastAPI app
 
591
  app = FastAPI(
592
  title="Advanced RAG Chat Service",
593
- description="Robust RAG-based chat service with comprehensive error handling",
594
- version="2.0.0",
595
  lifespan=lifespan
596
  )
597
 
598
- # CORS configuration
599
  app.add_middleware(
600
  CORSMiddleware,
601
- allow_origins=["*"],
602
- allow_credentials=True,
603
- allow_methods=["*"],
604
- allow_headers=["*"],
605
  )
606
 
607
- # Root endpoint
608
  @app.get("/")
609
  async def root():
610
- """Service information endpoint"""
611
- uptime = (datetime.utcnow() - APP_STATE["startup_time"]).total_seconds() if APP_STATE["startup_time"] else 0
612
-
613
- return {
614
- "service": "Advanced RAG Chat Service",
615
- "version": "2.0.0",
616
- "status": "running",
617
- "uptime_seconds": uptime,
618
- "components": {
619
- "mongodb": APP_STATE["mongodb_connected"],
620
- "rag_system": APP_STATE["rag_ready"],
621
- "faiss": FAISS_AVAILABLE
622
- },
623
- "active_sessions": len(SESSION_STORES),
624
- "total_queries": APP_STATE["total_queries"],
625
- "endpoints": {
626
- "health": "GET /health",
627
- "init": "POST /init/{session_id}",
628
- "chat": "POST /chat/{session_id}",
629
- "history": "GET /history/{session_id}",
630
- "cleanup": "DELETE /session/{session_id}",
631
- "status": "GET /sessions/active"
632
- }
633
- }
634
 
635
  @app.get("/health", response_model=HealthResponse)
636
  async def health_check():
637
- """Comprehensive health check"""
638
- try:
639
- # Test MongoDB connection
640
- mongodb_connected = False
641
- if DB:
642
- try:
643
- DB.command("ping")
644
- mongodb_connected = True
645
- except:
646
- pass
647
-
648
- # Calculate uptime
649
- uptime = 0
650
- if APP_STATE["startup_time"]:
651
- uptime = (datetime.utcnow() - APP_STATE["startup_time"]).total_seconds()
652
-
653
- # Memory usage
654
- with STORE_LOCK:
655
- memory_sessions = len(SESSION_STORES)
656
- indexed_sessions = sum(1 for store in SESSION_STORES.values() if store["indexed"])
657
-
658
- # Overall status
659
- status = "healthy"
660
- if not FAISS_AVAILABLE:
661
- status = "degraded"
662
- elif not mongodb_connected and not RAG_INITIALIZED:
663
- status = "unhealthy"
664
-
665
- last_error = APP_STATE["errors"][-1] if APP_STATE["errors"] else None
666
-
667
- return HealthResponse(
668
- status=status,
669
- mongodb_connected=mongodb_connected,
670
- rag_initialized=RAG_INITIALIZED,
671
- faiss_available=FAISS_AVAILABLE,
672
- active_sessions=memory_sessions,
673
- memory_usage={
674
- "loaded_sessions": memory_sessions,
675
- "indexed_sessions": indexed_sessions,
676
- "store_ttl_minutes": STORE_TTL // 60,
677
- "cleanup_interval_minutes": CLEANUP_INTERVAL // 60
678
- },
679
- uptime_seconds=uptime,
680
- last_error=last_error
681
- )
682
-
683
- except Exception as e:
684
- logger.error(f"Health check failed: {e}")
685
- return HealthResponse(
686
- status="unhealthy",
687
- mongodb_connected=False,
688
- rag_initialized=False,
689
- faiss_available=False,
690
- active_sessions=0,
691
- memory_usage={},
692
- uptime_seconds=0,
693
- last_error=str(e)
694
- )
695
 
696
  @app.post("/init/{session_id}", response_model=InitResponse)
697
  async def initialize_session(session_id: str, request: InitRequest):
698
- """Initialize session with comprehensive validation"""
699
  session_logger = create_session_logger(session_id)
700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
  try:
702
- # Validate prerequisites
703
- if not DB:
704
- raise HTTPException(status_code=503, detail="Database not connected")
705
-
706
- if not RAG_INITIALIZED:
707
- raise HTTPException(status_code=503, detail="RAG system not initialized")
708
-
709
- if not FAISS_AVAILABLE:
710
- raise HTTPException(status_code=503, detail="FAISS library not available")
711
-
712
- # Check if already initialized
713
- with STORE_LOCK:
714
- if session_id in SESSION_STORES and SESSION_STORES[session_id]["indexed"] and not request.force_reload:
715
- store = SESSION_STORES[session_id]
716
- metadata = store["metadata"]
717
- session_logger.info("Session already initialized")
718
- return InitResponse(
719
- success=True,
720
- session_id=session_id,
721
- message="Session already initialized",
722
- chunk_count=metadata["chunk_count"],
723
- title=metadata["title"],
724
- document_info=metadata["document_info"]
725
- )
726
-
727
- session_logger.info("Initializing session...")
728
 
729
- # Load session from MongoDB
730
- session_store = load_session_from_mongodb(session_id)
731
-
732
- # Store in memory
733
  with STORE_LOCK:
734
  SESSION_STORES[session_id] = session_store
735
- APP_STATE["active_sessions"] = len(SESSION_STORES)
736
-
737
- # Build FAISS index
738
- metadata = build_faiss_index_safely(session_id)
739
-
740
- session_logger.info(f"Session initialized: {metadata['chunk_count']} chunks ready")
741
 
 
742
  return InitResponse(
743
  success=True,
744
  session_id=session_id,
745
- message=f"Session initialized successfully with {metadata['chunk_count']} chunks",
746
- chunk_count=metadata["chunk_count"],
747
- title=metadata["title"],
748
- document_info=metadata["document_info"]
749
- )
750
-
751
- except HTTPException:
752
- raise
753
- except ValueError as e:
754
- session_logger.error(f"Session initialization validation error: {e}")
755
- return InitResponse(
756
- success=False,
757
- session_id=session_id,
758
- message="Session initialization failed",
759
- chunk_count=0,
760
- title="Error",
761
- error_details=str(e)
762
  )
763
  except Exception as e:
764
- session_logger.error(f"Session initialization error: {e}")
765
- session_logger.error(traceback.format_exc())
766
- APP_STATE["errors"].append(f"Init failed for {session_id[:8]}: {str(e)}")
767
- return InitResponse(
768
- success=False,
769
- session_id=session_id,
770
- message="Internal server error during initialization",
771
- chunk_count=0,
772
- title="Error",
773
- error_details="Internal server error"
774
- )
775
 
776
  @app.post("/chat/{session_id}", response_model=ChatResponse)
777
  async def chat_with_document(session_id: str, request: ChatRequest):
778
- """Chat endpoint with comprehensive error handling"""
779
  session_logger = create_session_logger(session_id)
780
  start_time = time.time()
781
 
782
  try:
783
- # Validate prerequisites
784
- if not DB:
785
- raise HTTPException(status_code=503, detail="Database not connected")
786
-
787
- if not RAG_INITIALIZED or not RAG_MODULE:
788
- raise HTTPException(status_code=503, detail="RAG system not initialized")
789
-
790
- # Validate session
791
  with STORE_LOCK:
792
- if session_id not in SESSION_STORES:
793
- raise HTTPException(
794
- status_code=400,
795
- detail=f"Session not initialized. Call /init/{session_id} first."
796
- )
797
-
798
- if not SESSION_STORES[session_id]["indexed"]:
799
- raise HTTPException(
800
- status_code=400,
801
- detail="Session not indexed properly. Try reinitializing."
802
- )
803
-
804
  session_logger.info(f"Processing query: {request.message[:100]}...")
 
 
 
805
 
806
- # Query RAG system
807
- try:
808
- result = RAG_MODULE.query_documents(request.message, top_k=5)
809
- APP_STATE["total_queries"] += 1
810
- except Exception as rag_error:
811
- session_logger.error(f"RAG query failed: {rag_error}")
812
- result = {
813
- 'error': f'RAG processing failed: {str(rag_error)}',
814
- 'answer': 'I apologize, but I encountered an error while processing your question. Please try again or rephrase your query.',
815
- 'sources': [],
816
- 'query_analysis': {},
817
- 'confidence': 0.0
818
- }
819
-
820
- if 'error' in result and not result.get('answer'):
821
- raise HTTPException(status_code=500, detail=result['error'])
822
-
823
- answer = result.get('answer', 'Unable to generate answer.')
824
- sources = result.get('sources', [])
825
- query_analysis = result.get('query_analysis', {})
826
- confidence = result.get('confidence', 0.0)
827
-
828
- # Save chat messages
829
- save_chat_message_safely(session_id, "user", request.message)
830
- save_chat_message_safely(session_id, "assistant", answer)
831
 
832
- # Get chat history
833
- chat_history = get_chat_history_safely(session_id)
 
834
 
835
  processing_time = time.time() - start_time
836
- session_logger.info(f"Query processed in {processing_time:.2f}s, confidence: {confidence:.1f}%")
837
-
838
- # Format sources
839
- formatted_sources = []
840
- for source in sources:
841
- try:
842
- formatted_source = {
843
- "chunk_id": source.get("chunk_id", ""),
844
- "title": source.get("title", ""),
845
- "section": source.get("section", ""),
846
- "relevance_score": float(source.get("relevance_score", 0.0)),
847
- "text_preview": source.get("excerpt", "")[:300] + ("..." if len(source.get("excerpt", "")) > 300 else ""),
848
- "entities": source.get("entities", [])
849
- }
850
- formatted_sources.append(formatted_source)
851
- except Exception as source_error:
852
- session_logger.warning(f"Failed to format source: {source_error}")
853
 
854
  return ChatResponse(
855
  success=True,
856
  answer=answer,
857
- sources=formatted_sources,
858
- chat_history=chat_history,
859
  processing_time=processing_time,
860
  session_id=session_id,
861
- query_analysis=query_analysis,
862
- confidence=confidence
863
  )
864
-
865
- except HTTPException:
866
- raise
867
  except Exception as e:
868
- session_logger.error(f"Chat processing failed: {e}")
869
- session_logger.error(traceback.format_exc())
870
- APP_STATE["errors"].append(f"Chat failed for {session_id[:8]}: {str(e)}")
871
-
872
- return ChatResponse(
873
- success=False,
874
- answer="I apologize, but I encountered an error while processing your question. Please try again.",
875
- sources=[],
876
- chat_history=get_chat_history_safely(session_id),
877
- processing_time=time.time() - start_time,
878
- session_id=session_id,
879
- error_details="Internal server error"
880
- )
881
 
882
  @app.get("/history/{session_id}")
883
  async def get_session_history(session_id: str):
884
- """Get chat history for a session"""
885
- session_logger = create_session_logger(session_id)
886
-
887
  if not DB:
888
  raise HTTPException(status_code=503, detail="Database not connected")
889
-
890
- try:
891
- chat_history = get_chat_history_safely(session_id, limit=100)
892
-
893
- session_logger.info(f"Retrieved {len(chat_history)} chat messages")
894
-
895
- return {
896
- "success": True,
897
- "session_id": session_id,
898
- "chat_history": chat_history,
899
- "total_messages": len(chat_history)
900
- }
901
-
902
- except Exception as e:
903
- session_logger.error(f"Failed to get chat history: {e}")
904
- raise HTTPException(status_code=500, detail=f"Failed to retrieve chat history: {str(e)}")
905
 
906
  @app.delete("/session/{session_id}")
907
  async def cleanup_session(session_id: str):
908
- """Clean up session from memory"""
909
- session_logger = create_session_logger(session_id)
910
-
911
- try:
912
- cleaned_up = False
913
-
914
- with STORE_LOCK:
915
- if session_id in SESSION_STORES:
916
- # Clean up session-specific RAG instance
917
- store = SESSION_STORES[session_id]
918
- if "rag_instance" in store:
919
- try:
920
- # Clean up any resources in the RAG instance
921
- store["rag_instance"].cleanup()
922
- except:
923
- pass
924
-
925
- # Clean up FAISS index
926
- if store.get("faiss_index"):
927
- del store["faiss_index"]
928
-
929
- del SESSION_STORES[session_id]
930
- APP_STATE["active_sessions"] = len(SESSION_STORES)
931
- cleaned_up = True
932
- session_logger.info("Session removed from memory")
933
-
934
- if not cleaned_up:
935
- session_logger.info("Session not found in memory")
936
-
937
- return {
938
- "success": True,
939
- "message": f"Session {session_id} cleaned up successfully"
940
- }
941
-
942
- except Exception as e:
943
- session_logger.error(f"Session cleanup failed: {e}")
944
- raise HTTPException(status_code=500, detail=f"Failed to cleanup session: {str(e)}")
945
 
946
- @app.get("/sessions/active")
947
- async def get_active_sessions():
948
- """Get information about active sessions in memory with TTL info"""
949
  try:
950
- current_time = datetime.utcnow()
951
-
952
- with STORE_LOCK:
953
- active_sessions = []
954
- for session_id, store in SESSION_STORES.items():
955
- metadata = store["metadata"]
956
- loaded_at = metadata["loaded_at"]
957
- age_seconds = (current_time - loaded_at).total_seconds()
958
- remaining_seconds = STORE_TTL - age_seconds
959
-
960
- active_sessions.append({
961
- "session_id": session_id,
962
- "title": metadata["title"],
963
- "chunk_count": metadata["chunk_count"],
964
- "indexed": store["indexed"],
965
- "has_rag_instance": "rag_instance" in store,
966
- "loaded_at": loaded_at.isoformat(),
967
- "age_minutes": age_seconds / 60,
968
- "remaining_minutes": max(0, remaining_seconds / 60),
969
- "expires_at": (loaded_at + timedelta(seconds=STORE_TTL)).isoformat(),
970
- "will_expire_soon": remaining_seconds < 300, # Less than 5 minutes
971
- "failed_chunks": metadata.get("failed_chunks", 0)
972
- })
973
-
974
- # Sort by remaining time (expiring soon first)
975
- active_sessions.sort(key=lambda x: x["remaining_minutes"])
976
-
977
- return {
978
- "success": True,
979
- "active_sessions": active_sessions,
980
- "total_sessions": len(active_sessions),
981
- "session_ttl_minutes": STORE_TTL / 60,
982
- "cleanup_interval_minutes": CLEANUP_INTERVAL / 60,
983
- "next_cleanup_in_minutes": CLEANUP_INTERVAL / 60 # Approximate
984
- }
985
-
986
- except Exception as e:
987
- logger.error(f"Failed to get active sessions: {e}")
988
- raise HTTPException(status_code=500, detail=f"Failed to get active sessions: {str(e)}")
989
-
990
- @app.post("/sessions/{session_id}/extend")
991
- async def extend_session_ttl(session_id: str):
992
- """Extend a session's TTL by resetting its load time (keep it alive longer)"""
993
- session_logger = create_session_logger(session_id)
994
-
995
- try:
996
- with STORE_LOCK:
997
- if session_id not in SESSION_STORES:
998
- raise HTTPException(status_code=404, detail="Session not found in memory")
999
-
1000
- # Reset the loaded_at timestamp to extend TTL
1001
- old_loaded_at = SESSION_STORES[session_id]["metadata"]["loaded_at"]
1002
- SESSION_STORES[session_id]["metadata"]["loaded_at"] = datetime.utcnow()
1003
-
1004
- session_logger.info(f"Session TTL extended (was loaded at: {old_loaded_at.isoformat()})")
1005
-
1006
- return {
1007
- "success": True,
1008
- "message": f"Session {session_id} TTL extended for another {STORE_TTL//60} minutes",
1009
- "new_expiry": (datetime.utcnow() + timedelta(seconds=STORE_TTL)).isoformat()
1010
- }
1011
-
1012
- except HTTPException:
1013
- raise
1014
- except Exception as e:
1015
- session_logger.error(f"Failed to extend session TTL: {e}")
1016
- raise HTTPException(status_code=500, detail=f"Failed to extend session TTL: {str(e)}")
1017
-
1018
- @app.post("/cleanup/run")
1019
- async def manual_cleanup():
1020
- """Manually trigger cleanup of expired sessions"""
1021
- try:
1022
- before_count = len(SESSION_STORES)
1023
- cleanup_expired_sessions()
1024
- after_count = len(SESSION_STORES)
1025
- cleaned_count = before_count - after_count
1026
-
1027
- return {
1028
- "success": True,
1029
- "message": f"Manual cleanup completed",
1030
- "sessions_before": before_count,
1031
- "sessions_after": after_count,
1032
- "sessions_cleaned": cleaned_count
1033
- }
1034
-
1035
- except Exception as e:
1036
- logger.error(f"Manual cleanup failed: {e}")
1037
- raise HTTPException(status_code=500, detail=f"Manual cleanup failed: {str(e)}")
1038
-
1039
- @app.get("/rag/status")
1040
- async def get_rag_status():
1041
- """Get RAG system status"""
1042
- try:
1043
- return {
1044
- "success": True,
1045
- "rag_initialized": RAG_INITIALIZED,
1046
- "faiss_available": FAISS_AVAILABLE,
1047
- "concurrency": {
1048
- "session_isolated_rag": True,
1049
- "async_processing": True,
1050
- "thread_pool_execution": True,
1051
- "no_global_state_conflicts": True
1052
- },
1053
- "optimization": {
1054
- "precomputed_embeddings": True,
1055
- "persistent_faiss_index": True,
1056
- "mongodb_persistence": True,
1057
- "memory_cleanup": True
1058
- },
1059
- "features": {
1060
- "multi_stage_retrieval": True,
1061
- "dense_retrieval": "FAISS + Session-Isolated Embeddings",
1062
- "sparse_retrieval": "BM25 per Session",
1063
- "entity_based_retrieval": "Legal NER + SpaCy",
1064
- "graph_based_retrieval": "Legal Concept Graph per Session",
1065
- "query_analysis": "Legal Intent Classification",
1066
- "answer_generation": "Groq LLM with IRAC Method"
1067
- },
1068
- "active_sessions": len(SESSION_STORES),
1069
- "total_queries_processed": APP_STATE["total_queries"]
1070
- }
1071
-
1072
  except Exception as e:
1073
- logger.error(f"Failed to get RAG status: {e}")
1074
- raise HTTPException(status_code=500, detail=f"Failed to get RAG status: {str(e)}")
1075
 
1076
  if __name__ == "__main__":
1077
  import uvicorn
1078
  port = int(os.getenv("PORT", 7861))
1079
- logger.info(f"Starting server on port {port}")
1080
- uvicorn.run(app, host="0.0.0.0", port=port)
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
+ import concurrent.futures
3
+ import logging
4
+ import os
5
+ import sys
6
  import threading
7
  import time
8
+ import traceback
9
  from collections import defaultdict
10
  from contextlib import asynccontextmanager
11
+ from datetime import datetime, timedelta
12
+ from functools import partial
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ import numpy as np
16
+ import pymongo
17
+ from fastapi import FastAPI, HTTPException
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from pydantic import BaseModel, Field
20
 
21
  try:
22
  import faiss
 
24
  except ImportError:
25
  FAISS_AVAILABLE = False
26
 
27
+ try:
28
+ # We import the SessionRAG class and the model initializer.
29
+ from rag import SessionRAG, initialize_models
30
+ RAG_AVAILABLE = True
31
+ except ImportError:
32
+ RAG_AVAILABLE = False
33
+
34
  # Configure comprehensive logging
35
  logging.basicConfig(
36
  level=logging.INFO,
 
42
  )
43
  logger = logging.getLogger(__name__)
44
 
45
# --- Global State ---
# MongoDB client and database handle; both are assigned by connect_mongodb().
MONGO_CLIENT = None
DB = None
# Set to True by init_rag_models() once the shared models have been loaded.
RAG_MODELS_INITIALIZED = False

# Process-wide status values; the health endpoint reports several of them.
APP_STATE = dict(
    startup_time=None,
    mongodb_connected=False,
    rag_models_ready=False,
    active_sessions=0,
    total_queries=0,
    errors=[],
)

# --- In-memory Session Management & Threading ---
# Maps session_id -> per-session store dict; guard every access with STORE_LOCK.
SESSION_STORES = {}
STORE_LOCK = threading.RLock()
# Both values are expressed in seconds (30 minutes each).
CLEANUP_INTERVAL = 30 * 60
STORE_TTL = 30 * 60
# Note: asyncio.to_thread() uses its own internal thread pool, so a dedicated
# global pool is not strictly necessary unless fine-grained control is needed.
65
 
66
+ # --- Pydantic Models ---
 
 
 
 
 
67
class ChatRequest(BaseModel):
    """Request body for a chat query."""

    # Required; must be between 1 and 5000 characters long.
    message: str = Field(
        ...,
        min_length=1,
        max_length=5000,
    )
69
 
70
  class ChatResponse(BaseModel):
71
  success: bool
 
79
  error_details: Optional[str] = None
80
 
81
class InitRequest(BaseModel):
    """Request body for session initialization."""

    # When True, a session that is already in memory is loaded again
    # instead of being reused.
    force_reload: bool = Field(
        default=False,
    )
83
 
84
  class InitResponse(BaseModel):
85
  success: bool
 
93
class HealthResponse(BaseModel):
    """Response payload of the health endpoint."""

    # Overall service state reported by the health check.
    status: str
    mongodb_connected: bool
    rag_models_initialized: bool
    faiss_available: bool
    # Number of sessions currently held in memory.
    active_sessions: int
    # Free-form counters describing the in-memory session stores.
    memory_usage: Dict[str, Any]
    uptime_seconds: float
    # Most recent recorded error message, if any.
    last_error: Optional[str] = None
102
 
103
+
104
+ # --- Helper Functions ---
105
def create_session_logger(session_id: str):
    """Return a LoggerAdapter whose extra context holds the first 8 characters of *session_id*."""
    short_id = session_id[:8]
    context = {'session_id': short_id}
    return logging.LoggerAdapter(logger, context)
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
def connect_mongodb():
    """Connect to MongoDB, select the database and ensure the chat indexes exist.

    The URL is read from the ``MONGODB_URL`` environment variable (falling back
    to a local instance). On success the module-level ``MONGO_CLIENT`` and
    ``DB`` are set and ``APP_STATE["mongodb_connected"]`` is flipped to True.

    Returns:
        True when the connection and index setup succeeded, False otherwise
        (the failure is logged, not raised).
    """
    global MONGO_CLIENT, DB

    chat_ttl_seconds = 24 * 60 * 60  # chat documents expire after one day

    try:
        url = os.environ.get("MONGODB_URL", "mongodb://localhost:27017/")
        logger.info("Connecting to MongoDB...")

        MONGO_CLIENT = pymongo.MongoClient(url, serverSelectionTimeoutMS=5000)
        # Force a round trip so an unreachable server fails here.
        MONGO_CLIENT.admin.command('ping')

        DB = MONGO_CLIENT["legal_rag_system"]
        chats = DB.chats
        chats.create_index("session_id", background=True)
        chats.create_index(
            "created_at",
            expireAfterSeconds=chat_ttl_seconds,
            background=True,
        )
    except Exception as exc:
        logger.error(f"MongoDB connection failed: {exc}")
        return False

    APP_STATE["mongodb_connected"] = True
    logger.info("MongoDB connected successfully")
    return True
126
 
127
def init_rag_models():
    """Initialize the shared RAG models once at startup.

    Reads ``EMBEDDING_MODEL_ID`` and ``GROQ_API_KEY`` from the environment and
    passes them to ``initialize_models``. On success the module-level
    ``RAG_MODELS_INITIALIZED`` flag and ``APP_STATE["rag_models_ready"]`` are
    set to True.

    Returns:
        True on success; False when the RAG module or FAISS is unavailable, or
        when initialization raised (the error is logged and appended to
        ``APP_STATE["errors"]``).
    """
    global RAG_MODELS_INITIALIZED

    if not (RAG_AVAILABLE and FAISS_AVAILABLE):
        logger.error("RAG module or FAISS not available - cannot initialize models.")
        return False

    try:
        embedding_model = os.getenv(
            "EMBEDDING_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2"
        )
        api_key = os.getenv("GROQ_API_KEY")
        logger.info(f"Initializing shared RAG models with embedding model: {embedding_model}")
        initialize_models(embedding_model, api_key)
        RAG_MODELS_INITIALIZED = True
        APP_STATE["rag_models_ready"] = True
        logger.info("Shared RAG models initialized successfully")
        return True
    except Exception as exc:
        logger.error(f"RAG model initialization failed: {exc}", exc_info=True)
        APP_STATE["errors"].append(f"RAG init failed: {str(exc)}")
        return False
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
    """Load a session's chunks from MongoDB and wrap them in a new SessionRAG.

    Args:
        session_id: Identifier of the session to load.

    Returns:
        A session store dict with the keys ``"session_rag"`` (the new
        SessionRAG instance), ``"indexed"`` (always False here) and
        ``"metadata"`` (session id, title, chunk count, load time and
        document info).

    Raises:
        ValueError: If the database is not connected, the session document
            does not exist, its status is not ``"completed"``, or it has no
            chunks.
    """
    session_logger = create_session_logger(session_id)

    # PyMongo Database objects raise NotImplementedError on truth-value
    # testing, so the "not connected" check must compare against None.
    if DB is None:
        raise ValueError("Database not connected")

    session_doc = DB.sessions.find_one({"session_id": session_id})
    if not session_doc:
        raise ValueError(f"Session {session_id} not found in database")
    if session_doc.get("status") != "completed":
        raise ValueError(f"Session not ready - status: {session_doc.get('status')}")

    session_logger.info(f"Loading chunks for: {session_doc.get('filename', 'unknown')}")
    chunks_list = list(DB.chunks.find({"session_id": session_id}).sort("created_at", 1))
    if not chunks_list:
        raise ValueError(f"No chunks found for session {session_id}")

    groq_api_key = os.getenv("GROQ_API_KEY")
    session_rag = SessionRAG(session_id, groq_api_key)

    # The SessionRAG instance processes the raw DB chunks and stores them internally
    processed_chunks = session_rag.process_db_chunks(chunks_list)

    session_store = {
        "session_rag": session_rag,
        "indexed": False,
        "metadata": {
            "session_id": session_id,
            "title": session_doc.get("filename", "Document"),
            "chunk_count": len(processed_chunks),
            "loaded_at": datetime.utcnow(),
            "document_info": {"filename": session_doc.get("filename", "Unknown")},
        },
    }
    session_logger.info("Session loaded and SessionRAG instance created.")
    return session_store
 
 
 
 
 
183
 
184
def build_indices_for_session(session_id: str) -> Dict[str, Any]:
    """Build all indices for an in-memory session and mark it as indexed.

    Args:
        session_id: Identifier of a session already present in
            ``SESSION_STORES``.

    Returns:
        The metadata dict of the session store that was indexed.

    Raises:
        KeyError: If the session is not present in ``SESSION_STORES``.
    """
    log = create_session_logger(session_id)

    # Grab the references under the lock; the index build itself runs unlocked.
    with STORE_LOCK:
        store = SESSION_STORES[session_id]
        rag = store["session_rag"]

    chunk_count = store['metadata']['chunk_count']
    log.info(f"Building indices for {chunk_count} chunks...")
    rag.build_all_indices(rag.chunks_data)

    with STORE_LOCK:
        current = SESSION_STORES[session_id]
        current["indexed"] = True

    log.info("All session-specific indices built successfully")
    return store["metadata"]
 
 
 
 
 
199
 
200
def get_chat_history_safely(session_id: str, limit: int = 50) -> List[Dict[str, Any]]:
    """Return up to *limit* most recent chat messages, oldest first.

    Args:
        session_id: Session whose chat messages are fetched.
        limit: Maximum number of messages to return.

    Returns:
        The chat documents in chronological order with ``_id`` converted to a
        string, or an empty list when the database is not connected or the
        query fails (the failure is logged, not raised).
    """
    # PyMongo Database objects raise NotImplementedError on truth-value
    # testing, so compare against None instead of using `not DB`.
    if DB is None:
        return []
    try:
        chats_cursor = (
            DB.chats.find({"session_id": session_id})
            .sort("created_at", -1)
            .limit(limit)
        )
        # Fetched newest-first so the limit keeps the latest messages;
        # reverse to get chronological order [oldest -> newest].
        history = list(chats_cursor)[::-1]
        for doc in history:
            # BSON ObjectId values are not JSON-serializable; expose them as
            # plain strings so the history can be returned from an endpoint.
            if "_id" in doc:
                doc["_id"] = str(doc["_id"])
        return history
    except Exception as e:
        logger.error(f"Failed to get chat history for session {session_id}: {e}")
        return []
210
 
211
+ # --- Session Cleanup ---
212
def cleanup_session_resources(session_id: str):
    """Remove a session from ``SESSION_STORES`` and release its RAG instance.

    Does nothing when the session is not in memory. If the stored
    ``"session_rag"`` object exposes a ``cleanup`` attribute, it is called
    while the store lock is still held.
    """
    with STORE_LOCK:
        if session_id not in SESSION_STORES:
            return

        removed = SESSION_STORES.pop(session_id)  # get and remove in one step
        rag = removed.get("session_rag")
        if hasattr(rag, 'cleanup'):
            rag.cleanup()
        logger.info(f"Cleaned up session from memory: {session_id[:8]}")
221
+
222
def cleanup_expired_sessions():
    """Remove every in-memory session whose age exceeds ``STORE_TTL`` seconds.

    Age is measured from the ``loaded_at`` timestamp in each session's
    metadata.
    """
    now = datetime.utcnow()

    # Iterate over a snapshot so removals cannot disturb the iteration.
    expired_ids = []
    for sid, store in list(SESSION_STORES.items()):
        age_seconds = (now - store["metadata"]["loaded_at"]).total_seconds()
        if age_seconds > STORE_TTL:
            expired_ids.append(sid)

    if not expired_ids:
        return

    logger.info(f"Found {len(expired_ids)} expired sessions to clean up.")
    for sid in expired_ids:
        cleanup_session_resources(sid)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
async def periodic_cleanup():
    """Run ``cleanup_expired_sessions`` every ``CLEANUP_INTERVAL`` seconds, forever.

    The loop ends only when the task is cancelled. A failure inside a single
    cleanup pass is logged and the loop continues, so one bad pass cannot
    silently stop all future cleanups.
    """
    while True:
        await asyncio.sleep(CLEANUP_INTERVAL)
        logger.info("Running periodic cleanup of expired sessions...")
        try:
            cleanup_expired_sessions()
        except Exception:
            # asyncio.CancelledError is not an Exception subclass, so task
            # cancellation still propagates normally.
            logger.exception("Periodic session cleanup failed; retrying on the next interval")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
 
 
241
 
242
+ # --- Application Lifespan ---
243
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start dependencies before serving, release them after.

    Startup records the start time, connects to MongoDB, initializes the RAG
    models and launches the periodic session cleanup task. Shutdown cancels
    that task and closes the MongoDB client.
    """
    # --- Startup ---
    APP_STATE["startup_time"] = datetime.utcnow()
    logger.info("Starting Advanced RAG Chat Service...")

    connect_mongodb()
    init_rag_models()

    background_cleanup = asyncio.create_task(periodic_cleanup())
    logger.info("Background session cleanup task started.")

    yield

    # --- Shutdown ---
    logger.info("Shutting down...")
    if background_cleanup is not None:
        background_cleanup.cancel()
    if MONGO_CLIENT:
        MONGO_CLIENT.close()
    logger.info("Shutdown complete.")
 
264
 
265
+
266
+ # --- FastAPI App and Endpoints ---
267
# The FastAPI application; startup and shutdown are handled by `lifespan`.
app = FastAPI(
    title="Advanced RAG Chat Service",
    description="A robust, session-isolated RAG chat service.",
    version="3.1.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
278
 
 
279
@app.get("/")
async def root():
    """Return the service name and version."""
    service_info = {
        "service": "Advanced RAG Chat Service",
        "version": "3.1.0",
    }
    return service_info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service health: dependency flags, session counts, uptime and last error.

    The status is "healthy" only when the RAG models are initialized and
    MongoDB is connected; otherwise it is "degraded".
    """
    uptime = (datetime.utcnow() - APP_STATE["startup_time"]).total_seconds()

    with STORE_LOCK:
        active_sessions = len(SESSION_STORES)
        indexed_sessions = len(
            [store for store in SESSION_STORES.values() if store["indexed"]]
        )

    mongo_ok = APP_STATE["mongodb_connected"]
    if RAG_MODELS_INITIALIZED and mongo_ok:
        status = "healthy"
    else:
        status = "degraded"

    errors = APP_STATE["errors"]
    last_error = errors[-1] if errors else None

    return HealthResponse(
        status=status,
        mongodb_connected=mongo_ok,
        rag_models_initialized=RAG_MODELS_INITIALIZED,
        faiss_available=FAISS_AVAILABLE,
        active_sessions=active_sessions,
        memory_usage={
            "loaded_sessions": active_sessions,
            "indexed_sessions": indexed_sessions,
        },
        uptime_seconds=uptime,
        last_error=last_error,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
@app.post("/init/{session_id}", response_model=InitResponse)
async def initialize_session(session_id: str, request: InitRequest):
    """Load a session into memory and build its indices.

    A session already in memory is reused unless ``request.force_reload`` is
    set. Loading and index building run in worker threads.

    Raises:
        HTTPException: 503 when the RAG models are not ready; 500 when loading
            or indexing fails (the session is then removed from memory).
    """
    session_logger = create_session_logger(session_id)

    def _build_response(message: str, metadata: Dict[str, Any]) -> InitResponse:
        # Missing metadata keys fall back to neutral defaults.
        return InitResponse(
            success=True,
            session_id=session_id,
            message=message,
            chunk_count=metadata.get('chunk_count', 0),
            title=metadata.get('title', 'Unknown Document'),
            document_info=metadata.get('document_info'),
        )

    if not RAG_MODELS_INITIALIZED:
        raise HTTPException(status_code=503, detail="RAG models are not ready.")

    with STORE_LOCK:
        already_loaded = session_id in SESSION_STORES
        if already_loaded and not request.force_reload:
            cached_metadata = SESSION_STORES[session_id]["metadata"]
            session_logger.info("Session already initialized.")
            return _build_response("Session already initialized", cached_metadata)

    try:
        session_logger.info("Loading session from database...")
        loaded_store = await asyncio.to_thread(load_session_from_mongodb, session_id)

        with STORE_LOCK:
            SESSION_STORES[session_id] = loaded_store

        session_logger.info("Building session indices...")
        metadata = await asyncio.to_thread(build_indices_for_session, session_id)

        session_logger.info("Session initialized successfully.")
        return _build_response(
            f"Session initialized with {metadata.get('chunk_count', 0)} chunks.",
            metadata,
        )
    except Exception as exc:
        session_logger.error(f"Session initialization failed: {exc}", exc_info=True)
        cleanup_session_resources(session_id)
        raise HTTPException(status_code=500, detail=f"Failed to initialize session: {exc}")
 
 
 
 
 
 
 
 
351
 
352
# Strong references to in-flight chat persistence tasks. The event loop keeps
# only weak references to tasks, so an otherwise unreferenced task may be
# garbage-collected before it finishes.
_CHAT_SAVE_TASKS = set()


@app.post("/chat/{session_id}", response_model=ChatResponse)
async def chat_with_document(session_id: str, request: ChatRequest):
    """Answer a chat query against an initialized, indexed session.

    The query runs in a worker thread. The user message and the answer are
    persisted by background tasks that do not delay the response.

    Raises:
        HTTPException: 400 when the session is not in memory or not indexed;
            500 when query processing fails.
    """
    session_logger = create_session_logger(session_id)
    start_time = time.time()

    try:
        with STORE_LOCK:
            store = SESSION_STORES.get(session_id)
            if not store or not store.get("indexed"):
                raise HTTPException(status_code=400, detail="Session not initialized or indexed.")
            session_rag = store["session_rag"]

        session_logger.info(f"Processing query: {request.message[:100]}...")
        result = await asyncio.to_thread(session_rag.query_documents, request.message, top_k=5)
        APP_STATE["total_queries"] += 1

        answer = result.get('answer', 'Unable to generate an answer.')

        # Non-blocking save to DB; each task is held in _CHAT_SAVE_TASKS until
        # it completes.
        for role, text in (("user", request.message), ("assistant", answer)):
            save_task = asyncio.create_task(save_chat_message_safely(session_id, role, text))
            _CHAT_SAVE_TASKS.add(save_task)
            save_task.add_done_callback(_CHAT_SAVE_TASKS.discard)

        processing_time = time.time() - start_time
        session_logger.info(f"Query processed in {processing_time:.2f}s.")

        return ChatResponse(
            success=True,
            answer=answer,
            sources=result.get('sources', []),
            chat_history=[],
            processing_time=processing_time,
            session_id=session_id,
            query_analysis=result.get('query_analysis'),
            confidence=result.get('confidence')
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must keep their status
        # code instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        session_logger.error(f"Chat processing failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred during chat processing: {e}")
 
 
 
 
 
 
 
 
 
 
 
392
 
393
@app.get("/history/{session_id}")
async def get_session_history(session_id: str):
    """Return the stored chat history for a session.

    Raises:
        HTTPException: 503 when the database is not connected.
    """
    # PyMongo Database objects raise NotImplementedError on truth-value
    # testing, so the connection check must compare against None.
    if DB is None:
        raise HTTPException(status_code=503, detail="Database not connected")
    # The database read is blocking, so it runs in a worker thread.
    history = await asyncio.to_thread(get_chat_history_safely, session_id)
    return {"session_id": session_id, "chat_history": history}
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
@app.delete("/session/{session_id}")
async def cleanup_session(session_id: str):
    """Remove one session from memory on request.

    The response reports success whether or not the session was in memory.
    """
    cleanup_session_resources(session_id)
    confirmation = f"Session {session_id} cleaned up."
    return {"success": True, "message": confirmation}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
async def save_chat_message_safely(session_id: str, role: str, message: str):
    """Persist one chat message without blocking the event loop.

    Does nothing when the database is not connected. Insert failures are
    logged and swallowed so a storage problem never breaks a chat response.

    Args:
        session_id: Session the message belongs to.
        role: Author of the message (callers in this file pass "user" or
            "assistant").
        message: Message text to store.
    """
    # PyMongo Database objects raise NotImplementedError on truth-value
    # testing, so compare against None instead of using `not DB`.
    if DB is None:
        return
    try:
        document = {
            "session_id": session_id,
            "role": role,
            "message": message,
            "created_at": datetime.utcnow(),
        }
        # insert_one is a blocking call, so it runs in a worker thread.
        await asyncio.to_thread(DB.chats.insert_one, document)
    except Exception as e:
        logger.error(f"Failed to save chat message for session {session_id}: {e}")
 
418
 
419
if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7861))
    # Auto-reload restarts the process on source changes, which discards all
    # in-memory sessions. It stays on by default; set UVICORN_RELOAD=false to
    # turn it off.
    reload_enabled = os.getenv("UVICORN_RELOAD", "true").strip().lower() in {"1", "true", "yes", "on"}
    logger.info(f"Starting server on http://0.0.0.0:{port}")
    # The import string "app:app" (rather than the app object) is required
    # for uvicorn's reload mode.
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=reload_enabled)