| | from typing import Dict, List, Tuple, Any, Optional |
| | import numpy as np |
| | import random |
| | from logger_config import config_logger |
| | from cross_encoder_reranker import CrossEncoderReranker |
| |
|
| | logger = config_logger(__name__) |
| |
|
| |
|
class ChatbotValidator:
    """
    Handles automated validation and performance analysis for the chatbot.

    This testing module executes domain-specific queries, obtains chatbot
    responses, and evaluates them with a quality checker.
    """

    def __init__(self, chatbot, quality_checker, cross_encoder_model='cross-encoder/ms-marco-MiniLM-L-12-v2'):
        """
        Initialize the validator.

        Args:
            chatbot: RetrievalChatbot for inference
            quality_checker: ResponseQualityChecker
            cross_encoder_model: Model name handed to CrossEncoderReranker,
                used to re-rank the retrieved candidate responses.
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker
        self.reranker = CrossEncoderReranker(model_name=cross_encoder_model)

        # Curated per-domain test queries; run_validation samples from these.
        self.domain_queries = {
            'restaurant': [
                "Hi, I have a question about your restaurant. Do they take reservations?",
                "I'd like to make a reservation for dinner tonight after 6pm. Is that time available?",
                "Can you recommend an Italian restaurant with wood-fired pizza?",
            ],
            'movie': [
                "How much are movie tickets for two people?",
                "I'm looking for showings after 6pm?",
                "Is this at the new theater with reclining seats?",
            ],
            'ride_share': [
                "I need a ride from the airport to downtown.",
                "What is the cost for Lyft? How about Uber XL?",
                "Can you book a car for tomorrow morning?",
            ],
            'coffee': [
                "Can I customize my coffee?",
                "Can I order a mocha from you?",
                "Can I get my usual venti vanilla latte?",
            ],
            'pizza': [
                "Do you have any pizza specials or deals available?",
                "How long is the wait until the pizza is ready and delivered to me?",
                "Please repeat my pizza order for two medium pizzas with thick crust.",
            ],
            'auto': [
                "The car is making a funny noise when I turn, and I'm due for an oil change.",
                "Is my buddy John available to work on my car?",
                "My Jeep needs a repair. Can you help me with that?",
            ],
        }

    def run_validation(
        self,
        num_examples: int = 3,
        top_k: int = 10,
        domains: Optional[List[str]] = None,
        randomize: bool = False,
        seed: int = 42
    ) -> Dict[str, Any]:
        """
        Run validation across testable domains.

        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of domain keys to test. If None, test all.
            randomize: If True, randomly select queries from the domain lists
            seed: Random seed for consistent sampling if randomize=True

        Returns:
            Dict with validation metrics, including per-domain performance
            and a confidence (top-score) distribution.
        """
        logger.info("\n=== Running Automatic Validation ===")

        # Fall back to every known domain when none are specified.
        test_domains = domains if domains else list(self.domain_queries.keys())

        metrics_history = []
        domain_metrics = {}

        # Dedicated RNG so sampling is reproducible and does not disturb
        # the global random state.
        rng = random.Random(seed)

        for domain in test_domains:
            if domain not in self.domain_queries:
                logger.warning(f"Domain '{domain}' not found in domain_queries. Skipping.")
                continue

            all_queries = self.domain_queries[domain]
            if randomize:
                # sample() cannot draw more items than exist; clamp accordingly.
                queries = rng.sample(all_queries, min(num_examples, len(all_queries)))
            else:
                queries = all_queries[:num_examples]

            domain_metrics[domain] = []

            logger.info(f"\n=== Testing {domain.title()} Domain ===\n")

            for i, query in enumerate(queries, 1):
                logger.info(f"TEST CASE {i}: QUERY: {query}")

                responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=self.reranker)
                quality_metrics = self.quality_checker.check_response_quality(query, responses)

                # Tag each record with its domain so it can be aggregated later.
                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)
                self._log_validation_results(query, responses, quality_metrics)
                logger.info(f"Quality metrics: {quality_metrics}\n")

        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)

        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })

        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics

    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Calculate aggregate metrics over all tested queries.

        Returns an empty dict when there is nothing to aggregate.
        All values are plain Python numbers (not numpy scalars) for
        consistency with _analyze_confidence_distribution.
        """
        if not metrics_history:
            logger.warning("No metrics to aggregate. Returning empty summary.")
            return {}

        top_scores = [m.get('top_score', 0.0) for m in metrics_history]

        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': float(np.mean(top_scores)),
            'avg_diversity': float(np.mean([m.get('response_diversity', 0.0) for m in metrics_history])),
            'avg_relevance': float(np.mean([m.get('query_response_relevance', 0.0) for m in metrics_history])),
            'avg_length_score': float(np.mean([m.get('response_length_score', 0.0) for m in metrics_history])),
            'avg_score_gap': float(np.mean([m.get('top_3_score_gap', 0.0) for m in metrics_history])),
            'confidence_rate': float(np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in metrics_history])),
            'median_top_score': float(np.median(top_scores)),
            'score_std': float(np.std(top_scores)),
            'min_score': float(np.min(top_scores)),
            'max_score': float(np.max(top_scores))
        }
        return metrics

    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict[str, float]]:
        """
        Analyze performance by domain, returning a nested dict keyed by domain.

        Domains with no recorded metrics map to an empty dict.
        """
        analysis = {}

        for domain, metrics_list in domain_metrics.items():
            if not metrics_list:
                analysis[domain] = {}
                continue

            top_scores = [m.get('top_score', 0.0) for m in metrics_list]

            analysis[domain] = {
                'confidence_rate': float(np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in metrics_list])),
                'avg_relevance': float(np.mean([m.get('query_response_relevance', 0.0) for m in metrics_list])),
                'avg_diversity': float(np.mean([m.get('response_diversity', 0.0) for m in metrics_list])),
                'avg_top_score': float(np.mean(top_scores)),
                'num_samples': len(metrics_list)
            }

        return analysis

    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Analyze the distribution of top scores to gauge system confidence levels.

        Returns zeroed percentiles when there is no history.
        """
        if not metrics_history:
            return {'percentile_25': 0.0, 'percentile_50': 0.0,
                    'percentile_75': 0.0, 'percentile_90': 0.0}

        scores = [m.get('top_score', 0.0) for m in metrics_history]
        return {
            'percentile_25': float(np.percentile(scores, 25)),
            'percentile_50': float(np.percentile(scores, 50)),
            'percentile_75': float(np.percentile(scores, 75)),
            'percentile_90': float(np.percentile(scores, 90))
        }

    def _log_validation_results(
        self,
        query: str,
        responses: List[Tuple[str, float]],
        metrics: Dict[str, Any],
    ):
        """
        Log detailed validation results for each test case.

        Safe to call with an empty `responses` list (logs and returns
        instead of raising IndexError).
        """
        domain = metrics.get('domain', 'Unknown')
        is_confident = metrics.get('is_confident', False)

        logger.info(f"DOMAIN: {domain} | CONFIDENCE: {'Yes' if is_confident else 'No'}")

        # Guard: retrieval can come back empty; previously this crashed on
        # responses[0].
        if not responses:
            logger.info("SELECTED RESPONSE: NONE (No responses retrieved)")
            return

        # A response is surfaced when the checker is confident OR the top
        # score clears a 0.5 floor.
        if is_confident or responses[0][1] >= 0.5:
            logger.info(f"SELECTED RESPONSE: '{responses[0][0]}'")
        else:
            logger.info("SELECTED RESPONSE: NONE (Low Confidence)")

        logger.info("	Top 3 Responses:")
        for i, (resp_text, score) in enumerate(responses[:3], 1):
            logger.info(f"	{i}) Score: {score:.4f} | {resp_text}")

    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """
        Log a summary of all validation metrics and domain performance.
        """
        if not metrics:
            logger.info("No metrics to summarize.")
            return

        logger.info("\n=== Validation Summary ===")

        # Scalar overall metrics only; nested dicts are reported below.
        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")

        domain_perf = metrics.get('domain_performance', {})
        logger.info("\nDomain Performance:")
        for domain, domain_stats in domain_perf.items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_stats.items():
                logger.info(f"  {metric}: {value:.4f}")

        conf_analysis = metrics.get('confidence_analysis', {})
        logger.info("\nConfidence Distribution:")
        for pct, val in conf_analysis.items():
            logger.info(f"  {pct}: {val:.4f}")
| |
|
| |
|