sniro23 committed on
Commit
9a21ee7
·
1 Parent(s): a8406a1

Perf: Instrument RAG pipeline for performance diagnostics

Browse files
Files changed (1) hide show
  1. src/enhanced_groq_medical_rag.py +23 -9
src/enhanced_groq_medical_rag.py CHANGED
@@ -98,14 +98,26 @@ class EnhancedGroqMedicalRAG:
98
  self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
99
  self.logger.info("✅ Cross-Encoder Re-ranker loaded")
100
 
101
- self.logger.info("🎯 Enhanced Medical RAG System ready - Medical-grade safety protocols active")
102
- self._test_groq_connection()
103
 
104
  def setup_logging(self):
105
  """Setup logging for the enhanced medical RAG system"""
106
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
107
  self.logger = logging.getLogger(__name__)
108
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  @retry(
110
  stop=stop_after_attempt(3),
111
  wait=wait_fixed(2),
@@ -341,12 +353,15 @@ class EnhancedGroqMedicalRAG:
341
 
342
  def query(self, query: str, history: Optional[List[Dict[str, str]]] = None, use_llm: bool = True) -> EnhancedMedicalResponse:
343
  """ENHANCED multi-stage medical query processing with comprehensive retrieval and timing."""
344
- start_time = time.time()
 
345
  try:
346
  self.logger.info(f"🔍 Processing enhanced medical query: {query[:50]}...")
347
 
348
  # Step 1: Analyze query for comprehensive understanding
 
349
  query_analysis = self.analyze_medical_query(query)
 
350
 
351
  # Step 2: Multi-stage comprehensive retrieval
352
  all_documents = []
@@ -376,7 +391,7 @@ class EnhancedGroqMedicalRAG:
376
  seen_content.add(doc.content)
377
 
378
  if not all_documents:
379
- return self._create_no_results_response(query, start_time)
380
 
381
  # Step 3: Advanced multi-criteria re-ranking
382
  reranked_docs = self._advanced_medical_reranking(query_analysis, all_documents)
@@ -426,7 +441,7 @@ class EnhancedGroqMedicalRAG:
426
  safety_status = "CONTEXT_ONLY"
427
 
428
  context_adherence_score = verification_result.verification_score if verification_result else 1.0
429
- query_time = time.time() - start_time
430
 
431
  enhanced_response = EnhancedMedicalResponse(
432
  answer=final_response,
@@ -442,11 +457,10 @@ class EnhancedGroqMedicalRAG:
442
 
443
  self.logger.info(f"🎯 Enhanced medical query completed in {query_time:.2f}s - Safety: {safety_status}")
444
  finally:
445
- end_time = time.time()
446
- processing_time = end_time - start_time
447
 
448
  if 'enhanced_response' in locals() and isinstance(enhanced_response, EnhancedMedicalResponse):
449
- enhanced_response.query_time = processing_time
450
  # Ensure other fields are not None
451
  if not hasattr(enhanced_response, 'answer') or enhanced_response.answer is None:
452
  enhanced_response.answer = "An error occurred during processing."
@@ -461,7 +475,7 @@ class EnhancedGroqMedicalRAG:
461
  answer="A critical error occurred. Unable to generate a full response.",
462
  confidence=0.0,
463
  sources=[],
464
- query_time=processing_time,
465
  verification_result=None,
466
  safety_status="ERROR",
467
  medical_entities_count=0,
 
98
  self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
99
  self.logger.info("✅ Cross-Encoder Re-ranker loaded")
100
 
101
+ # Add timers for performance diagnostics
102
+ self.timers = {}
103
 
104
  def setup_logging(self):
105
  """Setup logging for the enhanced medical RAG system"""
106
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
107
  self.logger = logging.getLogger(__name__)
108
 
109
+ def _start_timer(self, name: str):
110
+ """Starts a timer for a specific operation."""
111
+ self.timers[name] = time.time()
112
+
113
+ def _stop_timer(self, name: str):
114
+ """Stops a timer and logs the duration."""
115
+ if name in self.timers:
116
+ duration = time.time() - self.timers[name]
117
+ self.logger.info(f"⏱️ Timing: {name} took {duration:.2f}s")
118
+ return duration
119
+ return 0.0
120
+
121
  @retry(
122
  stop=stop_after_attempt(3),
123
  wait=wait_fixed(2),
 
353
 
354
  def query(self, query: str, history: Optional[List[Dict[str, str]]] = None, use_llm: bool = True) -> EnhancedMedicalResponse:
355
  """ENHANCED multi-stage medical query processing with comprehensive retrieval and timing."""
356
+ self._start_timer("Total Query Time")
357
+ total_processing_time = 0
358
  try:
359
  self.logger.info(f"🔍 Processing enhanced medical query: {query[:50]}...")
360
 
361
  # Step 1: Analyze query for comprehensive understanding
362
+ self._start_timer("Query Analysis")
363
  query_analysis = self.analyze_medical_query(query)
364
+ self._stop_timer("Query Analysis")
365
 
366
  # Step 2: Multi-stage comprehensive retrieval
367
  all_documents = []
 
391
  seen_content.add(doc.content)
392
 
393
  if not all_documents:
394
+ return self._create_no_results_response(query, self._stop_timer("Total Query Time"))
395
 
396
  # Step 3: Advanced multi-criteria re-ranking
397
  reranked_docs = self._advanced_medical_reranking(query_analysis, all_documents)
 
441
  safety_status = "CONTEXT_ONLY"
442
 
443
  context_adherence_score = verification_result.verification_score if verification_result else 1.0
444
+ query_time = self._stop_timer("Total Query Time") - total_processing_time
445
 
446
  enhanced_response = EnhancedMedicalResponse(
447
  answer=final_response,
 
457
 
458
  self.logger.info(f"🎯 Enhanced medical query completed in {query_time:.2f}s - Safety: {safety_status}")
459
  finally:
460
+ total_processing_time = self._stop_timer("Total Query Time")
 
461
 
462
  if 'enhanced_response' in locals() and isinstance(enhanced_response, EnhancedMedicalResponse):
463
+ enhanced_response.query_time = total_processing_time
464
  # Ensure other fields are not None
465
  if not hasattr(enhanced_response, 'answer') or enhanced_response.answer is None:
466
  enhanced_response.answer = "An error occurred during processing."
 
475
  answer="A critical error occurred. Unable to generate a full response.",
476
  confidence=0.0,
477
  sources=[],
478
+ query_time=total_processing_time,
479
  verification_result=None,
480
  safety_status="ERROR",
481
  medical_entities_count=0,