kn29 commited on
Commit
4be82b3
·
verified ·
1 Parent(s): 18411e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -452
app.py CHANGED
@@ -1,19 +1,13 @@
1
- # app_optimized.py - Performance-Optimized FastAPI App
2
  import asyncio
3
- import concurrent.futures
4
  import logging
5
  import os
6
  import sys
7
  import threading
8
  import time
9
- import traceback
10
- from collections import defaultdict
11
  from contextlib import asynccontextmanager
12
- from datetime import datetime, timedelta
13
- from functools import partial
14
  from typing import Any, Dict, List, Optional
15
 
16
- import numpy as np
17
  import pymongo
18
  from fastapi import FastAPI, HTTPException
19
  from fastapi.middleware.cors import CORSMiddleware
@@ -26,8 +20,7 @@ except ImportError:
26
  FAISS_AVAILABLE = False
27
 
28
  try:
29
- # Import the optimized SessionRAG class
30
- from rag import OptimizedSessionRAG, initialize_models
31
  RAG_AVAILABLE = True
32
  except ImportError:
33
  RAG_AVAILABLE = False
@@ -47,25 +40,17 @@ logger = logging.getLogger(__name__)
47
  MONGO_CLIENT = None
48
  DB = None
49
  RAG_MODELS_INITIALIZED = False
50
- SESSION_LAST_ACCESS = {}
 
 
51
  APP_STATE = {
52
  "startup_time": None,
53
  "mongodb_connected": False,
54
  "rag_models_ready": False,
55
- "active_sessions": 0,
56
  "total_queries": 0,
57
  "errors": []
58
  }
59
 
60
- # --- Optimized Session Management ---
61
- SESSION_STORES = {}
62
- STORE_LOCK = threading.RLock()
63
- CLEANUP_INTERVAL = 1800 # 30 minutes
64
- STORE_TTL = 1800 # 30 minutes
65
-
66
- # Thread pool for async operations
67
- EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=4)
68
-
69
  # --- Pydantic Models ---
70
  class ChatRequest(BaseModel):
71
  message: str = Field(..., min_length=1, max_length=5000)
@@ -90,25 +75,25 @@ class HealthResponse(BaseModel):
90
  uptime_seconds: float
91
  last_error: Optional[str] = None
92
 
 
93
  # --- Helper Functions ---
94
  def create_session_logger(session_id: str):
95
  return logging.LoggerAdapter(logger, {'session_id': session_id[:8]})
96
 
97
  def connect_mongodb():
 
98
  global MONGO_CLIENT, DB
99
  try:
100
  mongodb_url = os.getenv("MONGODB_URL", "mongodb://localhost:27017/")
101
  logger.info(f"Connecting to MongoDB...")
102
  MONGO_CLIENT = pymongo.MongoClient(
103
  mongodb_url,
104
- serverSelectionTimeoutMS=5000,
105
- maxPoolSize=50,
106
- waitQueueTimeoutMS=2500
107
  )
108
  MONGO_CLIENT.admin.command('ping')
109
  DB = MONGO_CLIENT["legal_rag_system"]
110
 
111
- # Ensure indices exist
112
  DB.chats.create_index("session_id", background=True)
113
  DB.chats.create_index("created_at", expireAfterSeconds=24 * 60 * 60, background=True)
114
  DB.sessions.create_index("session_id", unique=True, background=True)
@@ -119,13 +104,14 @@ def connect_mongodb():
119
  return True
120
  except Exception as e:
121
  logger.error(f"MongoDB connection failed: {e}")
 
122
  return False
123
 
124
  def init_rag_models():
125
- """Initialize shared RAG models once at startup"""
126
  global RAG_MODELS_INITIALIZED
127
  if not RAG_AVAILABLE or not FAISS_AVAILABLE:
128
- logger.error("RAG module or FAISS not available - cannot initialize models.")
129
  return False
130
  try:
131
  model_id = os.getenv("EMBEDDING_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
@@ -141,162 +127,163 @@ def init_rag_models():
141
  APP_STATE["errors"].append(f"RAG init failed: {str(e)}")
142
  return False
143
 
144
- async def load_session_from_mongodb_async(session_id: str) -> Dict[str, Any]:
145
- """OPTIMIZED: Load session from MongoDB in async way"""
146
- def _load_session_sync():
147
- session_logger = create_session_logger(session_id)
148
- session_logger.info(f"Loading session from MongoDB: {session_id}")
149
-
150
- if not DB:
151
- raise ValueError("Database not connected")
152
-
153
- # Load session metadata
154
- session_doc = DB.sessions.find_one({"session_id": session_id})
155
- if not session_doc:
156
- raise ValueError(f"Session {session_id} not found in database")
157
- if session_doc.get("status") != "completed":
158
- raise ValueError(f"Session not ready - status: {session_doc.get('status')}")
159
-
160
- # Load chunks with embeddings efficiently
161
- session_logger.info(f"Loading chunks with embeddings for: {session_doc.get('filename', 'unknown')}")
162
-
163
- # Use projection to only load needed fields
164
- chunks_cursor = DB.chunks.find(
165
- {"session_id": session_id},
166
- {
167
- "chunk_id": 1,
168
- "content": 1,
169
- "title": 1,
170
- "section_type": 1,
171
- "importance_score": 1,
172
- "entities": 1,
173
- "embedding": 1
174
- }
175
- ).sort("created_at", 1)
176
-
177
- chunks_list = list(chunks_cursor)
178
-
179
- if not chunks_list:
180
- raise ValueError(f"No chunks found for session {session_id}")
181
 
182
- # Create OptimizedSessionRAG instance
183
- groq_api_key = os.getenv("GROQ_API_KEY")
184
- session_rag = OptimizedSessionRAG(session_id, groq_api_key)
185
-
186
- session_logger.info(f"Loading existing session data with {len(chunks_list)} chunks")
187
- session_rag.load_existing_session_data(chunks_list)
 
 
188
 
189
- session_store = {
190
- "session_rag": session_rag,
191
- "indexed": True,
192
- "metadata": {
193
- "session_id": session_id,
194
- "title": session_doc.get("filename", "Document"),
195
- "chunk_count": len(chunks_list),
196
- "loaded_at": datetime.utcnow(),
197
- "load_time": getattr(session_rag, 'load_time', 0),
198
- "index_build_time": getattr(session_rag, 'index_build_time', 0),
199
- "document_info": {"filename": session_doc.get("filename", "Unknown")}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  }
201
  }
202
- session_logger.info(f"Session loaded from MongoDB in {session_rag.load_time:.2f}s")
203
- return session_store
204
 
205
- # Run in thread pool to avoid blocking
206
- return await asyncio.get_event_loop().run_in_executor(EXECUTOR, _load_session_sync)
207
-
208
- def cleanup_session_resources(session_id: str):
209
- """Safely clean up session resources"""
 
 
 
210
  with STORE_LOCK:
 
211
  if session_id in SESSION_STORES:
212
- store = SESSION_STORES.pop(session_id)
213
- session_rag = store.get("session_rag")
214
- if hasattr(session_rag, 'cleanup'):
215
- session_rag.cleanup()
216
- logger.info(f"Cleaned up session from memory: {session_id[:8]}")
217
-
218
- def cleanup_expired_sessions():
219
- """Clean up sessions that haven't been accessed recently"""
220
- now = datetime.utcnow()
221
- expired_ids = [
222
- sid for sid, last_access in list(SESSION_LAST_ACCESS.items())
223
- if (now - last_access).total_seconds() > STORE_TTL
224
- ]
225
- if expired_ids:
226
- logger.info(f"Cleaning up {len(expired_ids)} expired sessions")
227
- for sid in expired_ids:
228
- cleanup_session_resources(sid)
229
- SESSION_LAST_ACCESS.pop(sid, None)
230
-
231
- async def periodic_cleanup():
232
- while True:
233
- await asyncio.sleep(CLEANUP_INTERVAL)
234
- logger.info("Running periodic cleanup of expired sessions...")
235
- cleanup_expired_sessions()
236
 
237
  async def save_chat_message_safely(session_id: str, role: str, message: str):
238
- """Save chat messages asynchronously"""
239
- if not DB:
240
  return
241
  try:
242
- def _save_message():
243
- DB.chats.insert_one({
244
- "session_id": session_id,
245
- "role": role,
246
- "message": message,
 
247
  "created_at": datetime.utcnow()
248
- })
249
-
250
- await asyncio.get_event_loop().run_in_executor(EXECUTOR, _save_message)
251
  except Exception as e:
252
  logger.error(f"Failed to save chat message for session {session_id}: {e}")
253
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  # --- Application Lifespan ---
255
  @asynccontextmanager
256
  async def lifespan(app: FastAPI):
257
- cleanup_task = None
258
  APP_STATE["startup_time"] = datetime.utcnow()
259
- logger.info("Starting Optimized RAG Chat Service...")
260
 
 
261
  connect_mongodb()
262
  init_rag_models()
263
 
264
- cleanup_task = asyncio.create_task(periodic_cleanup())
265
- logger.info("Background session cleanup task started.")
266
 
267
  yield
268
 
 
269
  logger.info("Shutting down...")
270
- if cleanup_task:
271
- cleanup_task.cancel()
272
  if MONGO_CLIENT:
273
  MONGO_CLIENT.close()
274
- EXECUTOR.shutdown(wait=True)
275
- logger.info("Shutdown complete.")
276
 
277
  # --- FastAPI App ---
278
  app = FastAPI(
279
- title="Optimized RAG Chat Service",
280
- description="High-performance, session-isolated RAG chat service with pre-computed embeddings.",
281
  version="4.0.0",
282
  lifespan=lifespan
283
  )
284
 
285
  app.add_middleware(
286
  CORSMiddleware,
287
- allow_origins=["*"],
288
- allow_credentials=True,
289
- allow_methods=["*"],
290
  allow_headers=["*"],
291
  )
292
 
293
  @app.get("/")
294
  async def root():
295
- return {"service": "Optimized RAG Chat Service", "version": "4.0.0"}
 
 
 
 
296
 
297
  @app.get("/health", response_model=HealthResponse)
298
  async def health_check():
 
299
  uptime = (datetime.utcnow() - APP_STATE["startup_time"]).total_seconds()
 
300
  with STORE_LOCK:
301
  active_sessions = len(SESSION_STORES)
302
  indexed_sessions = sum(1 for s in SESSION_STORES.values() if s.get("indexed", False))
@@ -312,7 +299,7 @@ async def health_check():
312
  faiss_available=FAISS_AVAILABLE,
313
  active_sessions=active_sessions,
314
  memory_usage={
315
- "loaded_sessions": active_sessions,
316
  "indexed_sessions": indexed_sessions
317
  },
318
  uptime_seconds=uptime,
@@ -321,50 +308,51 @@ async def health_check():
321
 
322
  @app.post("/chat/{session_id}", response_model=ChatResponse)
323
  async def chat_with_document(session_id: str, request: ChatRequest):
 
 
 
 
 
 
 
324
  session_logger = create_session_logger(session_id)
325
  start_time = time.time()
326
 
327
  try:
328
- session_logger.info(f"Received chat request: {request.message[:100]}...")
329
-
330
- # Check if session is loaded in memory
331
- with STORE_LOCK:
332
- if session_id not in SESSION_STORES:
333
- # Lazy load: Load session from MongoDB asynchronously
334
- session_logger.info("Loading session from MongoDB for first chat request")
335
- try:
336
- session_store = await load_session_from_mongodb_async(session_id)
337
- SESSION_STORES[session_id] = session_store
338
- session_logger.info(f"Session loaded successfully from MongoDB in {session_store['metadata'].get('load_time', 0):.2f}s")
339
- except Exception as load_error:
340
- session_logger.error(f"Failed to load session: {load_error}")
341
- raise HTTPException(status_code=404, detail=f"Failed to load session: {str(load_error)}")
342
-
343
- # Update last access time
344
- SESSION_LAST_ACCESS[session_id] = datetime.utcnow()
345
- session_rag = SESSION_STORES[session_id]["session_rag"]
346
 
347
- session_logger.info(f"Processing query with OptimizedSessionRAG...")
348
-
349
- # Process the query using OptimizedSessionRAG - this is FAST now
350
- def _process_query():
351
- return session_rag.query_documents(request.message, top_k=5)
352
-
353
- result = await asyncio.get_event_loop().run_in_executor(EXECUTOR, _process_query)
 
 
 
 
 
 
 
 
 
 
 
354
 
355
  if 'error' in result:
356
- session_logger.error(f"Query processing error: {result['error']}")
357
  raise HTTPException(status_code=500, detail=result['error'])
358
 
359
  APP_STATE["total_queries"] += 1
360
  answer = result.get('answer', 'Unable to generate an answer.')
361
 
362
- # Save chat messages asynchronously (non-blocking)
363
  asyncio.create_task(save_chat_message_safely(session_id, "user", request.message))
364
  asyncio.create_task(save_chat_message_safely(session_id, "assistant", answer))
365
 
366
  processing_time = time.time() - start_time
367
- session_logger.info(f"Query processed successfully in {processing_time:.2f}s")
368
 
369
  return ChatResponse(
370
  success=True,
@@ -381,315 +369,71 @@ async def chat_with_document(session_id: str, request: ChatRequest):
381
  except Exception as e:
382
  session_logger.error(f"Chat processing failed: {e}", exc_info=True)
383
  APP_STATE["errors"].append(f"Chat error: {str(e)}")
384
- raise HTTPException(status_code=500, detail=f"Chat processing error: {str(e)}")
 
 
 
385
 
386
  @app.get("/history/{session_id}")
387
  async def get_session_history(session_id: str):
388
- """Retrieve chat history for a session"""
389
  if not DB:
390
  raise HTTPException(status_code=503, detail="Database not connected")
391
 
392
- def _get_history():
393
- try:
394
- chats_cursor = DB.chats.find({"session_id": session_id}).sort("created_at", -1).limit(50)
395
- return list(chats_cursor)[::-1] # Reverse to get chronological order
396
- except Exception as e:
397
- logger.error(f"Failed to get chat history for session {session_id}: {e}")
398
- return []
399
-
400
- history = await asyncio.get_event_loop().run_in_executor(EXECUTOR, _get_history)
401
- return {"session_id": session_id, "chat_history": history}
402
-
403
- @app.delete("/session/{session_id}")
404
- async def cleanup_session(session_id: str):
405
- """Manually clean up a specific session from memory"""
406
- cleanup_session_resources(session_id)
407
- SESSION_LAST_ACCESS.pop(session_id, None)
408
- return {"success": True, "message": f"Session {session_id} cleaned up."}
409
-
410
- @app.get("/session/{session_id}/info")
411
- async def get_session_info(session_id: str):
412
- """Get information about a loaded session"""
413
- with STORE_LOCK:
414
- if session_id not in SESSION_STORES:
415
- raise HTTPException(status_code=404, detail="Session not loaded in memory")
416
-
417
- metadata = SESSION_STORES[session_id]["metadata"]
418
- return {
419
- "session_id": session_id,
420
- "loaded": True,
421
- "metadata": metadata,
422
- "last_access": SESSION_LAST_ACCESS.get(session_id)
423
- }
424
-
425
- @app.get("/sessions")
426
- async def list_active_sessions():
427
- """List all currently active sessions in memory"""
428
- with STORE_LOCK:
429
- sessions = []
430
- for session_id, store in SESSION_STORES.items():
431
- metadata = store.get("metadata", {})
432
- sessions.append({
433
- "session_id": session_id,
434
- "title": metadata.get("title", "Unknown"),
435
- "chunk_count": metadata.get("chunk_count", 0),
436
- "loaded_at": metadata.get("loaded_at"),
437
- "load_time": metadata.get("load_time", 0),
438
- "last_access": SESSION_LAST_ACCESS.get(session_id)
439
- })
440
-
441
- return {
442
- "active_sessions": len(sessions),
443
- "sessions": sessions
444
- }
445
-
446
- # Add this enhanced logging to your chat endpoint in app.py
447
-
448
- @app.post("/chat/{session_id}", response_model=ChatResponse)
449
- async def chat_with_document(session_id: str, request: ChatRequest):
450
- session_logger = create_session_logger(session_id)
451
- start_time = time.time()
452
-
453
- try:
454
- session_logger.info(f"=== CHAT REQUEST START ===")
455
- session_logger.info(f"Session ID: {session_id}")
456
- session_logger.info(f"Message length: {len(request.message)}")
457
- session_logger.info(f"Message preview: {request.message[:100]}...")
458
- session_logger.info(f"RAG models initialized: {RAG_MODELS_INITIALIZED}")
459
- session_logger.info(f"MongoDB connected: {APP_STATE['mongodb_connected']}")
460
-
461
- # Check if session is already loaded in memory
462
- with STORE_LOCK:
463
- session_in_memory = session_id in SESSION_STORES
464
- session_logger.info(f"Session in memory: {session_in_memory}")
465
-
466
- if not session_in_memory:
467
- # Lazy load: Load session from MongoDB when first chat request comes
468
- session_logger.info("=== LAZY LOADING FROM MONGODB ===")
469
- session_logger.info("Loading session from MongoDB for first chat request")
470
-
471
- try:
472
- load_start = time.time()
473
- session_store = await asyncio.to_thread(load_session_from_mongodb, session_id)
474
- load_time = time.time() - load_start
475
- session_logger.info(f"Session loaded from MongoDB in {load_time:.2f}s")
476
-
477
- SESSION_STORES[session_id] = session_store
478
- session_logger.info(f"Session added to memory store")
479
- session_logger.info(f"Chunks loaded: {len(session_store.get('session_rag', {}).chunks_data) if hasattr(session_store.get('session_rag'), 'chunks_data') else 'unknown'}")
480
-
481
- except Exception as load_error:
482
- session_logger.error(f"Failed to load session: {load_error}", exc_info=True)
483
- raise HTTPException(status_code=404, detail=f"Failed to load session: {str(load_error)}")
484
-
485
- # Update last access time
486
- SESSION_LAST_ACCESS[session_id] = datetime.utcnow()
487
- session_rag = SESSION_STORES[session_id]["session_rag"]
488
- session_logger.info(f"Using session RAG instance: {type(session_rag)}")
489
-
490
- session_logger.info(f"=== PROCESSING QUERY ===")
491
-
492
- # Process the query using SessionRAG
493
- query_start = time.time()
494
- result = await asyncio.to_thread(session_rag.query_documents, request.message, top_k=5)
495
- query_time = time.time() - query_start
496
-
497
- session_logger.info(f"Query processed in {query_time:.2f}s")
498
- session_logger.info(f"Result keys: {list(result.keys()) if isinstance(result, dict) else 'not dict'}")
499
-
500
- if 'error' in result:
501
- session_logger.error(f"Query processing error: {result['error']}")
502
- raise HTTPException(status_code=500, detail=result['error'])
503
-
504
- APP_STATE["total_queries"] += 1
505
- answer = result.get('answer', 'Unable to generate an answer.')
506
- sources = result.get('sources', [])
507
-
508
- session_logger.info(f"Generated answer length: {len(answer)}")
509
- session_logger.info(f"Sources found: {len(sources)}")
510
-
511
- # Save chat messages asynchronously
512
- asyncio.create_task(save_chat_message_safely(session_id, "user", request.message))
513
- asyncio.create_task(save_chat_message_safely(session_id, "assistant", answer))
514
-
515
- processing_time = time.time() - start_time
516
- session_logger.info(f"=== CHAT REQUEST COMPLETE ===")
517
- session_logger.info(f"Total processing time: {processing_time:.2f}s")
518
-
519
- return ChatResponse(
520
- success=True,
521
- answer=answer,
522
- sources=sources,
523
- chat_history=[],
524
- processing_time=processing_time,
525
- session_id=session_id,
526
- query_analysis=result.get('query_analysis'),
527
- confidence=result.get('confidence')
528
- )
529
-
530
- except HTTPException:
531
- session_logger.error(f"HTTP Exception in chat processing")
532
- raise
533
- except Exception as e:
534
- session_logger.error(f"Chat processing failed: {e}", exc_info=True)
535
- APP_STATE["errors"].append(f"Chat error: {str(e)}")
536
- raise HTTPException(status_code=500, detail=f"Chat processing error: {str(e)}")
537
-
538
- # Also add this endpoint for better debugging
539
- @app.get("/debug/{session_id}")
540
- async def debug_session(session_id: str):
541
- """Debug endpoint to check session status"""
542
- with STORE_LOCK:
543
- session_in_memory = session_id in SESSION_STORES
544
- session_info = {}
545
-
546
- if session_in_memory:
547
- store = SESSION_STORES[session_id]
548
- session_rag = store.get("session_rag")
549
- session_info = {
550
- "in_memory": True,
551
- "indexed": store.get("indexed", False),
552
- "metadata": store.get("metadata", {}),
553
- "chunks_count": len(session_rag.chunks_data) if hasattr(session_rag, 'chunks_data') else 0,
554
- "has_dense_index": hasattr(session_rag, 'dense_index') and session_rag.dense_index is not None,
555
- "has_bm25_index": hasattr(session_rag, 'bm25_index') and session_rag.bm25_index is not None,
556
- }
557
- else:
558
- session_info = {"in_memory": False}
559
-
560
- # Check MongoDB
561
- mongodb_info = {"connected": False, "session_exists": False, "chunks_count": 0}
562
- if DB:
563
- mongodb_info["connected"] = True
564
- session_doc = DB.sessions.find_one({"session_id": session_id})
565
- if session_doc:
566
- mongodb_info["session_exists"] = True
567
- mongodb_info["session_status"] = session_doc.get("status")
568
- mongodb_info["filename"] = session_doc.get("filename")
569
- chunks_count = DB.chunks.count_documents({"session_id": session_id})
570
- mongodb_info["chunks_count"] = chunks_count
571
-
572
  return {
573
  "session_id": session_id,
574
- "memory": session_info,
575
- "mongodb": mongodb_info,
576
- "app_state": {
577
- "rag_models_ready": RAG_MODELS_INITIALIZED,
578
- "mongodb_connected": APP_STATE["mongodb_connected"],
579
- "active_sessions": len(SESSION_STORES)
580
- }
581
  }
582
 
583
- # Enhanced load_session_from_mongodb with better error handling
584
- def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
585
- """Load pre-existing session data and embeddings from MongoDB"""
586
- session_logger = create_session_logger(session_id)
587
- session_logger.info(f"=== LOADING SESSION FROM MONGODB ===")
588
- session_logger.info(f"Session ID: {session_id}")
589
-
590
  if not DB:
591
- session_logger.error("Database not connected")
592
- raise ValueError("Database not connected")
593
-
594
- # Load session metadata
595
- session_logger.info("Loading session metadata...")
596
- session_doc = DB.sessions.find_one({"session_id": session_id})
597
- if not session_doc:
598
- session_logger.error(f"Session {session_id} not found in database")
599
- raise ValueError(f"Session {session_id} not found in database")
600
-
601
- session_logger.info(f"Found session: {session_doc.get('filename', 'unknown')} - Status: {session_doc.get('status')}")
602
-
603
- if session_doc.get("status") != "completed":
604
- session_logger.error(f"Session not ready - status: {session_doc.get('status')}")
605
- raise ValueError(f"Session not ready - status: {session_doc.get('status')}")
606
-
607
- # Load chunks with embeddings from MongoDB
608
- session_logger.info(f"Loading chunks with embeddings for: {session_doc.get('filename', 'unknown')}")
609
- chunks_cursor = DB.chunks.find({"session_id": session_id}).sort("created_at", 1)
610
- chunks_list = list(chunks_cursor)
611
 
612
- session_logger.info(f"Found {len(chunks_list)} chunks in database")
 
 
613
 
614
- if not chunks_list:
615
- session_logger.error(f"No chunks found for session {session_id}")
616
- raise ValueError(f"No chunks found for session {session_id}")
617
-
618
- # Verify chunks have embeddings
619
- chunks_with_embeddings = 0
620
- for chunk in chunks_list:
621
- if chunk.get('embedding') is not None:
622
- chunks_with_embeddings += 1
623
 
624
- session_logger.info(f"Chunks with embeddings: {chunks_with_embeddings}/{len(chunks_list)}")
 
 
625
 
626
- if chunks_with_embeddings == 0:
627
- session_logger.error("No chunks have embeddings - document may not be fully processed")
628
- raise ValueError("No chunks have embeddings - document may not be fully processed")
 
 
629
 
630
- # Create SessionRAG instance and load pre-existing data
631
- try:
632
- session_logger.info("Creating SessionRAG instance...")
633
- groq_api_key = os.getenv("GROQ_API_KEY")
634
- session_rag = SessionRAG(session_id, groq_api_key)
635
-
636
- session_logger.info(f"Loading existing session data with {len(chunks_list)} chunks")
637
- load_start = time.time()
638
- session_rag.load_existing_session_data(chunks_list)
639
- load_duration = time.time() - load_start
640
- session_logger.info(f"Session data loaded in {load_duration:.2f}s")
641
-
642
- session_store = {
643
- "session_rag": session_rag,
644
- "indexed": True,
645
- "metadata": {
646
- "session_id": session_id,
647
- "title": session_doc.get("filename", "Document"),
648
- "chunk_count": len(chunks_list),
649
- "loaded_at": datetime.utcnow(),
650
- "load_time": load_duration,
651
- "document_info": {"filename": session_doc.get("filename", "Unknown")}
652
  }
653
- }
654
- session_logger.info("=== SESSION LOADING COMPLETE ===")
655
- return session_store
656
-
657
- except Exception as rag_error:
658
- session_logger.error(f"Failed to create/load SessionRAG: {rag_error}", exc_info=True)
659
- raise ValueError(f"Failed to initialize RAG system: {str(rag_error)}")
660
-
661
- # Add a health check specifically for the chat endpoint
662
- @app.get("/health/chat")
663
- async def chat_health_check():
664
- """Specific health check for chat functionality"""
665
- issues = []
666
-
667
- if not RAG_MODELS_INITIALIZED:
668
- issues.append("RAG models not initialized")
669
-
670
- if not APP_STATE["mongodb_connected"]:
671
- issues.append("MongoDB not connected")
672
-
673
- if not FAISS_AVAILABLE:
674
- issues.append("FAISS not available")
675
-
676
- groq_key = os.getenv("GROQ_API_KEY")
677
- if not groq_key:
678
- issues.append("GROQ API key not configured")
679
 
680
  return {
681
- "status": "healthy" if not issues else "unhealthy",
682
- "issues": issues,
683
- "components": {
684
- "rag_models": RAG_MODELS_INITIALIZED,
685
- "mongodb": APP_STATE["mongodb_connected"],
686
- "faiss": FAISS_AVAILABLE,
687
- "groq_key_configured": bool(groq_key)
688
- }
689
  }
690
 
 
691
  if __name__ == "__main__":
692
  import uvicorn
693
  port = int(os.getenv("PORT", 7861))
694
- logger.info(f"Starting optimized server on http://0.0.0.0:{port}")
695
- uvicorn.run("app_optimized:app", host="0.0.0.0", port=port, reload=False, workers=1)
 
 
1
  import asyncio
 
2
  import logging
3
  import os
4
  import sys
5
  import threading
6
  import time
 
 
7
  from contextlib import asynccontextmanager
8
+ from datetime import datetime
 
9
  from typing import Any, Dict, List, Optional
10
 
 
11
  import pymongo
12
  from fastapi import FastAPI, HTTPException
13
  from fastapi.middleware.cors import CORSMiddleware
 
20
  FAISS_AVAILABLE = False
21
 
22
  try:
23
+ from rag import SessionRAG, initialize_models
 
24
  RAG_AVAILABLE = True
25
  except ImportError:
26
  RAG_AVAILABLE = False
 
40
  MONGO_CLIENT = None
41
  DB = None
42
  RAG_MODELS_INITIALIZED = False
43
+ SESSION_STORES = {} # In-memory cache: {session_id: {session_rag, metadata, indexed}}
44
+ STORE_LOCK = threading.RLock()
45
+
46
  APP_STATE = {
47
  "startup_time": None,
48
  "mongodb_connected": False,
49
  "rag_models_ready": False,
 
50
  "total_queries": 0,
51
  "errors": []
52
  }
53
 
 
 
 
 
 
 
 
 
 
54
  # --- Pydantic Models ---
55
  class ChatRequest(BaseModel):
56
  message: str = Field(..., min_length=1, max_length=5000)
 
75
  uptime_seconds: float
76
  last_error: Optional[str] = None
77
 
78
+
79
  # --- Helper Functions ---
80
  def create_session_logger(session_id: str):
81
  return logging.LoggerAdapter(logger, {'session_id': session_id[:8]})
82
 
83
  def connect_mongodb():
84
+ """Connect to MongoDB Atlas"""
85
  global MONGO_CLIENT, DB
86
  try:
87
  mongodb_url = os.getenv("MONGODB_URL", "mongodb://localhost:27017/")
88
  logger.info(f"Connecting to MongoDB...")
89
  MONGO_CLIENT = pymongo.MongoClient(
90
  mongodb_url,
91
+ serverSelectionTimeoutMS=5000
 
 
92
  )
93
  MONGO_CLIENT.admin.command('ping')
94
  DB = MONGO_CLIENT["legal_rag_system"]
95
 
96
+ # Create indexes
97
  DB.chats.create_index("session_id", background=True)
98
  DB.chats.create_index("created_at", expireAfterSeconds=24 * 60 * 60, background=True)
99
  DB.sessions.create_index("session_id", unique=True, background=True)
 
104
  return True
105
  except Exception as e:
106
  logger.error(f"MongoDB connection failed: {e}")
107
+ APP_STATE["errors"].append(f"MongoDB error: {str(e)}")
108
  return False
109
 
110
  def init_rag_models():
111
+ """Initialize shared RAG models (embedding model, NLP model, etc.)"""
112
  global RAG_MODELS_INITIALIZED
113
  if not RAG_AVAILABLE or not FAISS_AVAILABLE:
114
+ logger.error("RAG module or FAISS not available")
115
  return False
116
  try:
117
  model_id = os.getenv("EMBEDDING_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
 
127
  APP_STATE["errors"].append(f"RAG init failed: {str(e)}")
128
  return False
129
 
130
+ def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
131
+ """
132
+ Load session data from MongoDB:
133
+ 1. Check if session exists in DB
134
+ 2. Load all chunks with pre-computed embeddings
135
+ 3. Create SessionRAG instance and rebuild indices from existing embeddings
136
+ """
137
+ session_logger = create_session_logger(session_id)
138
+ session_logger.info(f"Loading session from MongoDB: {session_id}")
139
+
140
+ if not DB:
141
+ raise ValueError("Database not connected")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ # 1. Load session metadata
144
+ session_doc = DB.sessions.find_one({"session_id": session_id})
145
+ if not session_doc:
146
+ raise ValueError(f"Session {session_id} not found in database")
147
+
148
+ # Check session status
149
+ if session_doc.get("status") != "completed":
150
+ raise ValueError(f"Session not ready - status: {session_doc.get('status')}")
151
 
152
+ # 2. Load chunks with embeddings from MongoDB
153
+ session_logger.info(f"Loading chunks for: {session_doc.get('filename', 'unknown')}")
154
+ chunks_cursor = DB.chunks.find({"session_id": session_id}).sort("created_at", 1)
155
+ chunks_list = list(chunks_cursor)
156
+
157
+ if not chunks_list:
158
+ raise ValueError(f"No chunks found for session {session_id}")
159
+
160
+ session_logger.info(f"Found {len(chunks_list)} chunks with pre-computed embeddings")
161
+
162
+ # 3. Create SessionRAG instance
163
+ groq_api_key = os.getenv("GROQ_API_KEY")
164
+ session_rag = SessionRAG(session_id, groq_api_key)
165
+
166
+ # 4. Load existing session data (rebuilds indices from stored embeddings)
167
+ session_logger.info(f"Rebuilding search indices from existing embeddings...")
168
+ session_rag.load_existing_session_data(chunks_list)
169
+
170
+ # 5. Create session store object
171
+ session_store = {
172
+ "session_rag": session_rag,
173
+ "indexed": True,
174
+ "metadata": {
175
+ "session_id": session_id,
176
+ "title": session_doc.get("filename", "Document"),
177
+ "chunk_count": len(chunks_list),
178
+ "loaded_at": datetime.utcnow(),
179
+ "document_info": {
180
+ "filename": session_doc.get("filename", "Unknown"),
181
+ "upload_date": session_doc.get("created_at")
182
  }
183
  }
184
+ }
 
185
 
186
+ session_logger.info("✓ Session loaded successfully with existing embeddings")
187
+ return session_store
188
+
189
+ def get_or_load_session(session_id: str) -> Dict[str, Any]:
190
+ """
191
+ Get session from memory cache, or load from MongoDB if not in memory.
192
+ Thread-safe with locking.
193
+ """
194
  with STORE_LOCK:
195
+ # Check if already loaded in memory
196
  if session_id in SESSION_STORES:
197
+ logger.info(f"Session {session_id[:8]} already in memory")
198
+ return SESSION_STORES[session_id]
199
+
200
+ # Not in memory - load from MongoDB
201
+ logger.info(f"Session {session_id[:8]} not in memory, loading from MongoDB...")
202
+ session_store = load_session_from_mongodb(session_id)
203
+ SESSION_STORES[session_id] = session_store
204
+ logger.info(f"Session {session_id[:8]} loaded and cached in memory")
205
+ return session_store
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  async def save_chat_message_safely(session_id: str, role: str, message: str):
208
+ """Save chat messages to MongoDB asynchronously"""
209
+ if not DB:
210
  return
211
  try:
212
+ await asyncio.to_thread(
213
+ DB.chats.insert_one,
214
+ {
215
+ "session_id": session_id,
216
+ "role": role,
217
+ "message": message,
218
  "created_at": datetime.utcnow()
219
+ }
220
+ )
 
221
  except Exception as e:
222
  logger.error(f"Failed to save chat message for session {session_id}: {e}")
223
 
224
+ def get_chat_history_safely(session_id: str, limit: int = 50) -> List[Dict[str, Any]]:
225
+ """Get chat history from MongoDB with error handling"""
226
+ if not DB:
227
+ return []
228
+ try:
229
+ chats_cursor = DB.chats.find({"session_id": session_id}).sort("created_at", -1).limit(limit)
230
+ return list(chats_cursor)[::-1] # Reverse for chronological order
231
+ except Exception as e:
232
+ logger.error(f"Failed to get chat history for session {session_id}: {e}")
233
+ return []
234
+
235
+
236
  # --- Application Lifespan ---
237
  @asynccontextmanager
238
  async def lifespan(app: FastAPI):
239
+ """Application startup and shutdown"""
240
  APP_STATE["startup_time"] = datetime.utcnow()
241
+ logger.info("Starting RAG Chat Service...")
242
 
243
+ # Initialize MongoDB and models
244
  connect_mongodb()
245
  init_rag_models()
246
 
247
+ logger.info("✓ Service ready")
 
248
 
249
  yield
250
 
251
+ # Cleanup on shutdown
252
  logger.info("Shutting down...")
 
 
253
  if MONGO_CLIENT:
254
  MONGO_CLIENT.close()
255
+ logger.info("✓ Shutdown complete")
256
+
257
 
258
  # --- FastAPI App ---
259
  app = FastAPI(
260
+ title="Session-based RAG Chat Service",
261
+ description="RAG system with MongoDB session persistence",
262
  version="4.0.0",
263
  lifespan=lifespan
264
  )
265
 
266
  app.add_middleware(
267
  CORSMiddleware,
268
+ allow_origins=["*"],
269
+ allow_credentials=True,
270
+ allow_methods=["*"],
271
  allow_headers=["*"],
272
  )
273
 
274
  @app.get("/")
275
  async def root():
276
+ return {
277
+ "service": "Session-based RAG Chat Service",
278
+ "version": "4.0.0",
279
+ "description": "Embeddings stored in MongoDB, lazy-loaded on demand"
280
+ }
281
 
282
  @app.get("/health", response_model=HealthResponse)
283
  async def health_check():
284
+ """Health check endpoint"""
285
  uptime = (datetime.utcnow() - APP_STATE["startup_time"]).total_seconds()
286
+
287
  with STORE_LOCK:
288
  active_sessions = len(SESSION_STORES)
289
  indexed_sessions = sum(1 for s in SESSION_STORES.values() if s.get("indexed", False))
 
299
  faiss_available=FAISS_AVAILABLE,
300
  active_sessions=active_sessions,
301
  memory_usage={
302
+ "loaded_sessions": active_sessions,
303
  "indexed_sessions": indexed_sessions
304
  },
305
  uptime_seconds=uptime,
 
308
 
309
  @app.post("/chat/{session_id}", response_model=ChatResponse)
310
  async def chat_with_document(session_id: str, request: ChatRequest):
311
+ """
312
+ Main chat endpoint:
313
+ 1. Load session from MongoDB if not in memory (lazy loading)
314
+ 2. Process query using RAG pipeline
315
+ 3. Save chat messages to MongoDB
316
+ 4. Return answer with sources
317
+ """
318
  session_logger = create_session_logger(session_id)
319
  start_time = time.time()
320
 
321
  try:
322
+ session_logger.info(f"Chat request: {request.message[:100]}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
+ # Get or load session (lazy loading from MongoDB)
325
+ try:
326
+ session_store = await asyncio.to_thread(get_or_load_session, session_id)
327
+ session_rag = session_store["session_rag"]
328
+ except Exception as load_error:
329
+ session_logger.error(f"Failed to load session: {load_error}")
330
+ raise HTTPException(
331
+ status_code=404,
332
+ detail=f"Session not found or failed to load: {str(load_error)}"
333
+ )
334
+
335
+ # Process query using RAG pipeline
336
+ session_logger.info(f"Processing query with RAG...")
337
+ result = await asyncio.to_thread(
338
+ session_rag.query_documents,
339
+ request.message,
340
+ top_k=5
341
+ )
342
 
343
  if 'error' in result:
344
+ session_logger.error(f"Query error: {result['error']}")
345
  raise HTTPException(status_code=500, detail=result['error'])
346
 
347
  APP_STATE["total_queries"] += 1
348
  answer = result.get('answer', 'Unable to generate an answer.')
349
 
350
+ # Save chat messages asynchronously to MongoDB
351
  asyncio.create_task(save_chat_message_safely(session_id, "user", request.message))
352
  asyncio.create_task(save_chat_message_safely(session_id, "assistant", answer))
353
 
354
  processing_time = time.time() - start_time
355
+ session_logger.info(f"Query processed in {processing_time:.2f}s")
356
 
357
  return ChatResponse(
358
  success=True,
 
369
  except Exception as e:
370
  session_logger.error(f"Chat processing failed: {e}", exc_info=True)
371
  APP_STATE["errors"].append(f"Chat error: {str(e)}")
372
+ raise HTTPException(
373
+ status_code=500,
374
+ detail=f"Chat processing error: {str(e)}"
375
+ )
376
 
377
  @app.get("/history/{session_id}")
378
  async def get_session_history(session_id: str):
379
+ """Get chat history for a session from MongoDB"""
380
  if not DB:
381
  raise HTTPException(status_code=503, detail="Database not connected")
382
 
383
+ history = await asyncio.to_thread(get_chat_history_safely, session_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  return {
385
  "session_id": session_id,
386
+ "chat_history": history,
387
+ "count": len(history)
 
 
 
 
 
388
  }
389
 
390
+ @app.get("/session/{session_id}/info")
391
+ async def get_session_info(session_id: str):
392
+ """Get session metadata from MongoDB"""
 
 
 
 
393
  if not DB:
394
+ raise HTTPException(status_code=503, detail="Database not connected")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
 
396
+ session_doc = await asyncio.to_thread(DB.sessions.find_one, {"session_id": session_id})
397
+ if not session_doc:
398
+ raise HTTPException(status_code=404, detail="Session not found")
399
 
400
+ # Convert ObjectId to string for JSON serialization
401
+ session_doc['_id'] = str(session_doc['_id'])
 
 
 
 
 
 
 
402
 
403
+ # Check if loaded in memory
404
+ with STORE_LOCK:
405
+ in_memory = session_id in SESSION_STORES
406
 
407
+ return {
408
+ "session_id": session_id,
409
+ "metadata": session_doc,
410
+ "in_memory": in_memory
411
+ }
412
 
413
+ @app.delete("/session/{session_id}/cache")
414
+ async def clear_session_cache(session_id: str):
415
+ """Remove session from memory cache (data remains in MongoDB)"""
416
+ with STORE_LOCK:
417
+ if session_id in SESSION_STORES:
418
+ store = SESSION_STORES.pop(session_id)
419
+ session_rag = store.get("session_rag")
420
+ if hasattr(session_rag, 'cleanup'):
421
+ session_rag.cleanup()
422
+ logger.info(f"Session {session_id[:8]} removed from memory cache")
423
+ return {
424
+ "success": True,
425
+ "message": f"Session removed from memory cache",
426
+ "note": "Data remains in MongoDB"
 
 
 
 
 
 
 
 
427
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
 
429
  return {
430
+ "success": False,
431
+ "message": "Session not found in memory cache"
 
 
 
 
 
 
432
  }
433
 
434
+
435
  if __name__ == "__main__":
436
  import uvicorn
437
  port = int(os.getenv("PORT", 7861))
438
+ logger.info(f"Starting server on http://0.0.0.0:{port}")
439
+ uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)