"""
Comparison Validator for AegisLM Framework

Production-grade validation for experiment comparison results with
statistical significance testing, accuracy verification, and comprehensive
reporting.
"""

import numpy as np
from scipy import stats
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, timezone
from dataclasses import dataclass
from pathlib import Path
import sys
import logging

# Add parent directory to path for imports
current_dir = Path(__file__).parent
backend_dir = current_dir.parent
if str(backend_dir) not in sys.path:
    sys.path.insert(0, str(backend_dir))

logger = logging.getLogger(__name__)


def _utc_now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Yields the same naive-UTC value as ``datetime.utcnow()`` without that
    call's deprecation warning on Python 3.12+, so report timestamps remain
    comparable with existing naive values.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@dataclass
class ValidationWarning:
    """Validation warning information."""
    message: str
    severity: str  # "low", "medium", "high"
    metric_name: Optional[str] = None
    recommendation: Optional[str] = None


@dataclass
class ValidationError:
    """Validation error information."""
    message: str
    error_type: str  # "data_error", "statistical_error", "logic_error"
    metric_name: Optional[str] = None
    fix_required: bool = True


@dataclass
class StatisticalTest:
    """Statistical test result."""
    test_name: str
    statistic: float
    p_value: float
    is_significant: bool
    confidence_level: float
    # Cohen's d for two-sample comparisons, R^2 for linear trends.
    effect_size: Optional[float] = None


@dataclass
class MetricComparison:
    """Metric comparison result."""
    metric_name: str
    baseline_value: float
    comparison_value: float
    absolute_difference: float  # comparison - baseline
    relative_difference: float  # percent of the baseline value
    statistical_test: Optional[StatisticalTest] = None
    is_significant: bool = False
    # Interval for (comparison mean - baseline mean) at the configured level.
    confidence_interval: Optional[Tuple[float, float]] = None


@dataclass
class ComparisonValidationReport:
    """Comprehensive comparison validation report."""
    is_valid: bool
    overall_confidence: float
    validation_errors: List[ValidationError]
    validation_warnings: List[ValidationWarning]
    metric_comparisons: List[MetricComparison]
    statistical_tests: List[StatisticalTest]
    validation_timestamp: datetime
    recommendations: List[str]


class ComparisonValidator:
    """
    Validate experiment comparison results for accuracy and statistical significance.

    Provides comprehensive validation including statistical testing,
    accuracy verification, and actionable recommendations.
    """

    # Overall confidence required for a comparison report to be valid.
    MIN_COMPARISON_CONFIDENCE = 0.7
    # Overall confidence required for a trend report to be valid.
    MIN_TREND_CONFIDENCE = 0.6
    # A regression slope t-test needs n - 2 >= 1 degrees of freedom.
    MIN_TREND_POINTS = 3
    # |relative difference| in percent below which a change is "very small".
    SMALL_RELATIVE_DIFF_PCT = 1.0
    # |Cohen's d| above which an effect is flagged as large.
    LARGE_EFFECT_SIZE = 0.8
    # R^2 below which a trend is flagged as weak.
    WEAK_TREND_R2 = 0.3

    def __init__(self, significance_threshold: float = 0.05,
                 confidence_level: float = 0.95):
        """
        Initialize comparison validator.

        Args:
            significance_threshold: P-value threshold for statistical significance
            confidence_level: Confidence level for statistical tests
        """
        self.significance_threshold = significance_threshold
        self.confidence_level = confidence_level

    async def validate_comparison(self, baseline_results: Dict[str, Any],
                                  comparison_results: Dict[str, Any]) -> ComparisonValidationReport:
        """
        Validate comparison results for statistical significance and accuracy.

        Every metric in the baseline's ``metrics`` mapping is compared with the
        same-named metric of the comparison; a metric missing from the
        comparison is recorded as a ``data_error``.

        Args:
            baseline_results: Baseline experiment results
            comparison_results: Comparison experiment results

        Returns:
            ComparisonValidationReport: Comprehensive validation report. Any
            unexpected exception is logged and turned into an invalid report
            rather than propagated.
        """
        validation_errors: List[ValidationError] = []
        validation_warnings: List[ValidationWarning] = []
        metric_comparisons: List[MetricComparison] = []
        statistical_tests: List[StatisticalTest] = []

        try:
            validation_errors.extend(
                self._validate_result_completeness(baseline_results, comparison_results)
            )

            # `or {}` also covers an explicit `metrics: None`.
            baseline_metrics = baseline_results.get('metrics') or {}
            comparison_metrics = comparison_results.get('metrics') or {}

            for metric_name, baseline_metric in baseline_metrics.items():
                if metric_name not in comparison_metrics:
                    validation_errors.append(ValidationError(
                        message=f"Missing metric in comparison: {metric_name}",
                        error_type="data_error",
                        metric_name=metric_name
                    ))
                    continue

                comparison_result = await self._compare_metric(
                    metric_name, baseline_metric, comparison_metrics[metric_name]
                )
                metric_comparisons.append(comparison_result)
                if comparison_result.statistical_test is not None:
                    statistical_tests.append(comparison_result.statistical_test)

            validation_errors.extend(
                self._validate_comparison_logic(baseline_results, comparison_results)
            )
            validation_warnings.extend(self._generate_warnings(metric_comparisons))

            overall_confidence = self._calculate_overall_confidence(metric_comparisons, validation_errors)
            recommendations = self._generate_recommendations(
                metric_comparisons, validation_errors, validation_warnings
            )
            is_valid = (not validation_errors
                        and overall_confidence >= self.MIN_COMPARISON_CONFIDENCE)

            return ComparisonValidationReport(
                is_valid=is_valid,
                overall_confidence=overall_confidence,
                validation_errors=validation_errors,
                validation_warnings=validation_warnings,
                metric_comparisons=metric_comparisons,
                statistical_tests=statistical_tests,
                validation_timestamp=_utc_now(),
                recommendations=recommendations
            )

        except Exception as e:
            # logger.exception keeps the traceback for post-mortem debugging.
            logger.exception(f"Comparison validation failed: {e}")
            return self._failure_report(f"Validation failed: {str(e)}")

    async def validate_trend_analysis(self, trend_data: List[Dict[str, Any]]) -> ComparisonValidationReport:
        """
        Validate trend analysis results.

        The metric names are taken from the first data point; each metric's
        series across all points is analysed with a linear regression.

        Args:
            trend_data: List of trend data points

        Returns:
            ComparisonValidationReport: Trend validation report. Any
            unexpected exception is logged and turned into an invalid report.
        """
        validation_errors: List[ValidationError] = []
        validation_warnings: List[ValidationWarning] = []
        metric_comparisons: List[MetricComparison] = []
        statistical_tests: List[StatisticalTest] = []

        try:
            if len(trend_data) < self.MIN_TREND_POINTS:
                validation_errors.append(ValidationError(
                    message="Insufficient data points for trend analysis (minimum 3 required)",
                    error_type="data_error",
                    fix_required=True
                ))

            if trend_data:
                metrics = trend_data[0].get('metrics') or {}
                for metric_name in metrics:
                    trend_result = await self._analyze_metric_trend(metric_name, trend_data)
                    metric_comparisons.append(trend_result)
                    if trend_result.statistical_test is not None:
                        statistical_tests.append(trend_result.statistical_test)

            validation_warnings.extend(self._generate_trend_warnings(metric_comparisons))

            overall_confidence = self._calculate_overall_confidence(metric_comparisons, validation_errors)
            recommendations = self._generate_trend_recommendations(metric_comparisons, validation_errors)
            is_valid = (not validation_errors
                        and overall_confidence >= self.MIN_TREND_CONFIDENCE)

            return ComparisonValidationReport(
                is_valid=is_valid,
                overall_confidence=overall_confidence,
                validation_errors=validation_errors,
                validation_warnings=validation_warnings,
                metric_comparisons=metric_comparisons,
                statistical_tests=statistical_tests,
                validation_timestamp=_utc_now(),
                recommendations=recommendations
            )

        except Exception as e:
            logger.exception(f"Trend validation failed: {e}")
            return self._failure_report(f"Trend validation failed: {str(e)}")

    @staticmethod
    def _failure_report(message: str) -> ComparisonValidationReport:
        """Build the invalid, zero-confidence report returned when validation itself crashes."""
        return ComparisonValidationReport(
            is_valid=False,
            overall_confidence=0.0,
            validation_errors=[ValidationError(
                message=message,
                error_type="statistical_error",
                fix_required=True
            )],
            validation_warnings=[],
            metric_comparisons=[],
            statistical_tests=[],
            validation_timestamp=_utc_now(),
            recommendations=["Fix validation errors before proceeding"]
        )

    @staticmethod
    def _extract_values(metric_data: Dict[str, Any]) -> np.ndarray:
        """
        Collect the observations recorded for one metric as a 1-D float array.

        ``metric_data['values']`` (list or array of samples) takes priority;
        when it is absent or empty, a scalar ``metric_data['value']`` is used as
        a single observation. Returns an empty array when neither is present.
        """
        raw = metric_data.get('values')
        values = np.asarray(raw if raw is not None else [], dtype=float).ravel()
        if values.size == 0 and 'value' in metric_data:
            values = np.asarray([metric_data['value']], dtype=float)
        return values

    async def _compare_metric(self, metric_name: str, baseline_data: Dict[str, Any],
                              comparison_data: Dict[str, Any]) -> MetricComparison:
        """
        Compare a single metric between baseline and comparison.

        Means are compared directly. When both sides have at least two
        observations, a pooled-variance (Student) independent t-test is run and
        a confidence interval for the mean difference plus Cohen's d are
        computed. The t statistic, interval and effect size are all signed as
        comparison minus baseline.

        Args:
            metric_name: Name of the metric
            baseline_data: Baseline metric data
            comparison_data: Comparison metric data

        Returns:
            MetricComparison: Comparison result
        """
        baseline_vals = self._extract_values(baseline_data)
        comparison_vals = self._extract_values(comparison_data)

        # A side with no observations is reported as 0 rather than NaN.
        baseline_mean = float(np.mean(baseline_vals)) if baseline_vals.size else 0.0
        comparison_mean = float(np.mean(comparison_vals)) if comparison_vals.size else 0.0
        absolute_diff = comparison_mean - baseline_mean
        relative_diff = (absolute_diff / baseline_mean * 100) if baseline_mean != 0 else 0.0

        statistical_test = None
        is_significant = False
        confidence_interval = None

        n_base, n_comp = baseline_vals.size, comparison_vals.size
        if n_base > 1 and n_comp > 1:
            try:
                # (comparison, baseline) order so the statistic's sign matches absolute_diff.
                t_stat, p_value = stats.ttest_ind(comparison_vals, baseline_vals)
                p_value = float(p_value)

                # Pooled standard deviation, shared by the CI and Cohen's d.
                dof = n_base + n_comp - 2
                pooled_std = float(np.sqrt(
                    ((n_base - 1) * np.var(baseline_vals, ddof=1)
                     + (n_comp - 1) * np.var(comparison_vals, ddof=1)) / dof
                ))

                se = pooled_std * np.sqrt(1 / n_base + 1 / n_comp)
                t_critical = stats.t.ppf(1 - (1 - self.confidence_level) / 2, dof)
                margin_error = float(t_critical * se)
                confidence_interval = (absolute_diff - margin_error, absolute_diff + margin_error)

                effect_size = absolute_diff / pooled_std if pooled_std != 0 else 0.0

                # NaN p-values (e.g. zero variance on both sides) compare False.
                is_significant = bool(p_value < self.significance_threshold)
                statistical_test = StatisticalTest(
                    test_name="independent_t_test",
                    statistic=float(t_stat),
                    p_value=p_value,
                    is_significant=is_significant,
                    confidence_level=self.confidence_level,
                    effect_size=effect_size
                )
            except Exception as e:
                # Best effort: the descriptive comparison is still returned.
                logger.warning(f"Statistical test failed for {metric_name}: {e}")

        return MetricComparison(
            metric_name=metric_name,
            baseline_value=baseline_mean,
            comparison_value=comparison_mean,
            absolute_difference=absolute_diff,
            relative_difference=relative_diff,
            statistical_test=statistical_test,
            is_significant=is_significant,
            confidence_interval=confidence_interval
        )

    async def _analyze_metric_trend(self, metric_name: str,
                                    trend_data: List[Dict[str, Any]]) -> MetricComparison:
        """
        Analyze trend for a specific metric.

        Each data point contributes its ``value`` or, failing that, the mean of
        its non-empty ``values``; points without either are skipped. The first
        and last collected values are compared, and when at least
        ``MIN_TREND_POINTS`` values exist the regression slope is t-tested
        against zero (effect size = R^2).

        Args:
            metric_name: Name of the metric
            trend_data: Trend data points

        Returns:
            MetricComparison: Trend analysis result
        """
        values: List[float] = []
        for point in trend_data:
            metric_data = (point.get('metrics') or {}).get(metric_name, {})
            if 'value' in metric_data:
                values.append(float(metric_data['value']))
            elif 'values' in metric_data and len(metric_data['values']) > 0:
                # Skip empty sample lists instead of appending NaN.
                values.append(float(np.mean(metric_data['values'])))

        if len(values) < 2:
            return MetricComparison(
                metric_name=metric_name,
                baseline_value=values[0] if values else 0,
                comparison_value=values[-1] if values else 0,
                absolute_difference=0,
                relative_difference=0,
                is_significant=False
            )

        baseline_value = values[0]
        comparison_value = values[-1]
        absolute_diff = comparison_value - baseline_value
        relative_diff = (absolute_diff / baseline_value * 100) if baseline_value != 0 else 0

        statistical_test = None
        is_significant = False

        # With fewer points the slope t-test has no degrees of freedom.
        if len(values) >= self.MIN_TREND_POINTS:
            try:
                fit = stats.linregress(np.arange(len(values)), np.asarray(values))

                # Zero standard error means a perfect fit: infinitely strong
                # evidence for a nonzero slope, no evidence for a flat one.
                if fit.stderr > 0:
                    t_stat = float(fit.slope / fit.stderr)
                elif fit.slope != 0:
                    t_stat = float(np.copysign(np.inf, fit.slope))
                else:
                    t_stat = 0.0

                p_val = float(2 * stats.t.sf(abs(t_stat), len(values) - 2))
                is_significant = bool(p_val < self.significance_threshold)

                statistical_test = StatisticalTest(
                    test_name="linear_regression",
                    statistic=t_stat,
                    p_value=p_val,
                    is_significant=is_significant,
                    confidence_level=self.confidence_level,
                    effect_size=float(fit.rvalue ** 2)
                )
            except Exception as e:
                logger.warning(f"Trend analysis failed for {metric_name}: {e}")

        return MetricComparison(
            metric_name=metric_name,
            baseline_value=baseline_value,
            comparison_value=comparison_value,
            absolute_difference=absolute_diff,
            relative_difference=relative_diff,
            statistical_test=statistical_test,
            is_significant=is_significant
        )

    def _validate_result_completeness(self, baseline_results: Dict[str, Any],
                                      comparison_results: Dict[str, Any]) -> List[ValidationError]:
        """Report each required top-level field missing from either result set."""
        errors = []

        required_fields = ['metrics', 'timestamp', 'experiment_id']
        for field in required_fields:
            if field not in baseline_results:
                errors.append(ValidationError(
                    message=f"Missing required field in baseline: {field}",
                    error_type="data_error",
                    fix_required=True
                ))
            if field not in comparison_results:
                errors.append(ValidationError(
                    message=f"Missing required field in comparison: {field}",
                    error_type="data_error",
                    fix_required=True
                ))

        return errors

    def _validate_comparison_logic(self, baseline_results: Dict[str, Any],
                                   comparison_results: Dict[str, Any]) -> List[ValidationError]:
        """
        Validate logical consistency of comparison.

        Flags differing ``config.model_name`` or differing dataset name/version
        as ``logic_error`` entries with ``fix_required=False``.
        """
        errors = []

        baseline_config = baseline_results.get('config') or {}
        comparison_config = comparison_results.get('config') or {}

        if baseline_config.get('model_name') != comparison_config.get('model_name'):
            errors.append(ValidationError(
                message="Models are different - comparison may not be valid",
                error_type="logic_error",
                fix_required=False
            ))

        if (baseline_config.get('dataset_name') != comparison_config.get('dataset_name') or
                baseline_config.get('dataset_version') != comparison_config.get('dataset_version')):
            errors.append(ValidationError(
                message="Datasets are different - comparison may not be valid",
                error_type="logic_error",
                fix_required=False
            ))

        return errors

    @staticmethod
    def _effect_size(comparison: MetricComparison) -> Optional[float]:
        """Return the comparison's effect size, or None when no test produced one."""
        test = comparison.statistical_test
        return test.effect_size if test is not None else None

    def _generate_warnings(self, metric_comparisons: List[MetricComparison]) -> List[ValidationWarning]:
        """Generate warnings for tiny differences, non-significant tests and large effects."""
        warnings = []

        for comparison in metric_comparisons:
            if abs(comparison.relative_difference) < self.SMALL_RELATIVE_DIFF_PCT:
                warnings.append(ValidationWarning(
                    message=f"Very small difference detected for {comparison.metric_name}",
                    severity="low",
                    metric_name=comparison.metric_name,
                    recommendation="Consider if this difference is practically significant"
                ))

            if comparison.statistical_test and not comparison.is_significant:
                warnings.append(ValidationWarning(
                    message=f"No significant difference for {comparison.metric_name} (p={comparison.statistical_test.p_value:.3f})",
                    severity="medium",
                    metric_name=comparison.metric_name,
                    recommendation="Consider increasing sample size or effect size"
                ))

            effect_size = self._effect_size(comparison)
            if effect_size is not None and abs(effect_size) > self.LARGE_EFFECT_SIZE:
                warnings.append(ValidationWarning(
                    message=f"Large effect size for {comparison.metric_name}",
                    severity="medium",
                    metric_name=comparison.metric_name,
                    recommendation="Verify results are not due to outliers or data issues"
                ))

        return warnings

    def _generate_trend_warnings(self, metric_comparisons: List[MetricComparison]) -> List[ValidationWarning]:
        """Warn about every trend whose R^2 is below ``WEAK_TREND_R2`` (including exactly 0)."""
        warnings = []

        for comparison in metric_comparisons:
            r_squared = self._effect_size(comparison)
            if r_squared is not None and r_squared < self.WEAK_TREND_R2:
                warnings.append(ValidationWarning(
                    message=f"Weak trend correlation for {comparison.metric_name} (R² = {r_squared:.3f})",
                    severity="medium",
                    metric_name=comparison.metric_name,
                    recommendation="Consider if trend is meaningful or due to noise"
                ))

        return warnings

    def _calculate_overall_confidence(self, metric_comparisons: List[MetricComparison],
                                      validation_errors: List[ValidationError]) -> float:
        """
        Calculate overall confidence in comparison results.

        Returns 0.0 when any validation error exists and 0.5 when there is
        nothing to compare. Otherwise the score is
        ``0.7 * (fraction of significant metrics) + 0.3 * min(mean |effect size|, 1)``,
        clamped to [0, 1]. The absolute value keeps a significant *decrease*
        (negative Cohen's d) from lowering confidence.
        """
        if validation_errors:
            return 0.0
        if not metric_comparisons:
            return 0.5  # Neutral confidence

        significant_count = sum(1 for comp in metric_comparisons if comp.is_significant)
        significance_ratio = significant_count / len(metric_comparisons)

        effect_sizes = [abs(size) for size in map(self._effect_size, metric_comparisons)
                        if size is not None]
        avg_effect_size = float(np.mean(effect_sizes)) if effect_sizes else 0.0

        overall_confidence = significance_ratio * 0.7 + min(avg_effect_size, 1.0) * 0.3
        return float(min(max(overall_confidence, 0.0), 1.0))

    def _generate_recommendations(self, metric_comparisons: List[MetricComparison],
                                  validation_errors: List[ValidationError],
                                  validation_warnings: List[ValidationWarning]) -> List[str]:
        """
        Generate actionable recommendations.

        ``validation_warnings`` is accepted for interface compatibility but is
        not currently consulted.
        """
        recommendations = []

        if validation_errors:
            recommendations.append("Fix validation errors before proceeding with analysis")

        if any(comp.statistical_test and not comp.is_significant for comp in metric_comparisons):
            recommendations.append("Consider increasing sample size for metrics with non-significant results")

        if any(size is not None and abs(size) > self.LARGE_EFFECT_SIZE
               for size in map(self._effect_size, metric_comparisons)):
            recommendations.append("Verify large effect sizes are not due to data anomalies")

        if metric_comparisons:
            avg_confidence = self._calculate_overall_confidence(metric_comparisons, validation_errors)
            if avg_confidence < self.MIN_COMPARISON_CONFIDENCE:
                recommendations.append("Consider collecting more data to improve statistical power")

        return recommendations

    def _generate_trend_recommendations(self, metric_comparisons: List[MetricComparison],
                                        validation_errors: List[ValidationError]) -> List[str]:
        """Generate trend-specific recommendations."""
        recommendations = []

        if validation_errors:
            recommendations.append("Fix validation errors before proceeding with trend analysis")

        # `is not None` so a completely flat trend (R^2 == 0) still counts as weak.
        if any(size is not None and size < self.WEAK_TREND_R2
               for size in map(self._effect_size, metric_comparisons)):
            recommendations.append("Consider collecting more data points to strengthen trend analysis")

        return recommendations


# Factory function
def create_comparison_validator(significance_threshold: float = 0.05,
                                confidence_level: float = 0.95) -> ComparisonValidator:
    """
    Create a comparison validator instance.

    Args:
        significance_threshold: P-value threshold for significance
        confidence_level: Confidence level for tests

    Returns:
        ComparisonValidator: Configured validator
    """
    return ComparisonValidator(significance_threshold, confidence_level)