| """ |
| Comparison Validator for AegisLM Framework |
| |
| Production-grade validation for experiment comparison results with |
| statistical significance testing, accuracy verification, and comprehensive reporting. |
| """ |
|
|
| import numpy as np |
| from scipy import stats |
| from typing import Dict, Any, List, Optional, Tuple |
| from datetime import datetime |
| from dataclasses import dataclass |
| from pathlib import Path |
| import sys |
| import logging |
|
|
| |
# Make the parent of this module's directory importable.  The entry is put at
# the front of sys.path, and only when it is not already listed, so importing
# this module more than once does not add duplicates.
current_dir = Path(__file__).parent
backend_dir = current_dir.parent
if str(backend_dir) not in sys.path:
    sys.path.insert(0, str(backend_dir))

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class ValidationWarning:
    """A non-blocking finding produced while validating a comparison."""

    # Human-readable description of the finding.
    message: str
    # Severity label; the validator in this module emits "low" and "medium".
    severity: str
    # Metric the warning refers to, or None when it is not metric-specific.
    metric_name: Optional[str] = None
    # Suggested follow-up action, if any.
    recommendation: Optional[str] = None
|
|
|
|
@dataclass
class ValidationError:
    """A problem found during validation that makes the report invalid."""

    # Human-readable description of the problem.
    message: str
    # Category label; this module uses "data_error", "logic_error" and
    # "statistical_error".
    error_type: str
    # Metric the error refers to, or None when it is not metric-specific.
    metric_name: Optional[str] = None
    # Whether the problem has to be fixed before results can be used.
    fix_required: bool = True
|
|
|
|
@dataclass
class StatisticalTest:
    """Outcome of a single statistical test."""

    # Identifier of the test; this module uses "independent_t_test" and
    # "linear_regression".
    test_name: str
    # Value of the test statistic.
    statistic: float
    # P-value associated with the statistic.
    p_value: float
    # True when the p-value is below the validator's significance threshold.
    is_significant: bool
    # Confidence level the validator was configured with.
    confidence_level: float
    # Effect size, when one was computed (Cohen's d for the t-test, R² for
    # the regression in this module).
    effect_size: Optional[float] = None
|
|
|
|
@dataclass
class MetricComparison:
    """Baseline-versus-comparison summary for one metric."""

    # Name of the metric being compared.
    metric_name: str
    # Representative baseline value (a sample mean, or the first trend point).
    baseline_value: float
    # Representative comparison value (a sample mean, or the last trend point).
    comparison_value: float
    # comparison_value - baseline_value.
    absolute_difference: float
    # absolute_difference as a percentage of baseline_value.
    relative_difference: float
    # Statistical test backing the comparison, when one could be run.
    statistical_test: Optional[StatisticalTest] = None
    # Whether the difference / trend was found statistically significant.
    is_significant: bool = False
    # (lower, upper) bounds for the difference, when computed.
    confidence_interval: Optional[Tuple[float, float]] = None
|
|
|
|
@dataclass
class ComparisonValidationReport:
    """Complete result of validating a comparison or a trend analysis."""

    # Overall verdict of the validation.
    is_valid: bool
    # Aggregate confidence score for the results.
    overall_confidence: float
    # Problems found during validation.
    validation_errors: List[ValidationError]
    # Non-blocking findings.
    validation_warnings: List[ValidationWarning]
    # One entry per metric that was compared.
    metric_comparisons: List[MetricComparison]
    # Statistical tests that were actually run.
    statistical_tests: List[StatisticalTest]
    # When the report was generated.
    validation_timestamp: datetime
    # Suggested follow-up actions.
    recommendations: List[str]
|
|
|
|
class ComparisonValidator:
    """
    Validate experiment comparison results for accuracy and statistical significance.

    Two entry points are provided.  ``validate_comparison`` compares a
    baseline result set with a comparison result set metric by metric using
    an independent two-sample t-test.  ``validate_trend_analysis`` fits a
    linear regression through an ordered series of result sets.  Both return
    a ``ComparisonValidationReport``; any ``Exception`` raised while building
    it is logged and turned into an invalid report instead of propagating.
    """

    # Minimum overall confidence for a comparison / trend report to be valid.
    MIN_COMPARISON_CONFIDENCE = 0.7
    MIN_TREND_CONFIDENCE = 0.6
    # An absolute effect size above this is flagged for manual verification.
    LARGE_EFFECT_SIZE = 0.8
    # An R² below this is reported as a weak trend.
    WEAK_TREND_R_SQUARED = 0.3
    # Relative differences (in percent) below this are reported as very small.
    SMALL_RELATIVE_DIFFERENCE = 1.0

    def __init__(self, significance_threshold: float = 0.05, confidence_level: float = 0.95):
        """
        Initialize comparison validator.

        Args:
            significance_threshold: P-value threshold for statistical significance
            confidence_level: Confidence level for statistical tests
        """
        self.significance_threshold = significance_threshold
        self.confidence_level = confidence_level

    async def validate_comparison(self, baseline_results: Dict[str, Any],
                                  comparison_results: Dict[str, Any]) -> ComparisonValidationReport:
        """
        Validate comparison results for statistical significance and accuracy.

        Every metric found under ``baseline_results['metrics']`` is compared
        with the metric of the same name in ``comparison_results['metrics']``;
        a metric missing from the comparison is recorded as a ``data_error``.

        Args:
            baseline_results: Baseline experiment results
            comparison_results: Comparison experiment results

        Returns:
            ComparisonValidationReport: Comprehensive validation report.  The
            report is valid only when there are no validation errors and the
            overall confidence reaches ``MIN_COMPARISON_CONFIDENCE``.
        """
        validation_errors = []
        metric_comparisons = []
        statistical_tests = []

        try:
            validation_errors.extend(
                self._validate_result_completeness(baseline_results, comparison_results)
            )

            baseline_metrics = baseline_results.get('metrics', {})
            comparison_metrics = comparison_results.get('metrics', {})

            for metric_name in baseline_metrics:
                if metric_name not in comparison_metrics:
                    validation_errors.append(ValidationError(
                        message=f"Missing metric in comparison: {metric_name}",
                        error_type="data_error",
                        metric_name=metric_name
                    ))
                    continue

                result = await self._compare_metric(
                    metric_name,
                    baseline_metrics[metric_name],
                    comparison_metrics[metric_name]
                )
                metric_comparisons.append(result)
                if result.statistical_test:
                    statistical_tests.append(result.statistical_test)

            validation_errors.extend(
                self._validate_comparison_logic(baseline_results, comparison_results)
            )
            validation_warnings = list(self._generate_warnings(metric_comparisons))
            overall_confidence = self._calculate_overall_confidence(
                metric_comparisons, validation_errors
            )
            recommendations = self._generate_recommendations(
                metric_comparisons, validation_errors, validation_warnings
            )

            is_valid = (not validation_errors
                        and overall_confidence >= self.MIN_COMPARISON_CONFIDENCE)

            return ComparisonValidationReport(
                is_valid=is_valid,
                overall_confidence=overall_confidence,
                validation_errors=validation_errors,
                validation_warnings=validation_warnings,
                metric_comparisons=metric_comparisons,
                statistical_tests=statistical_tests,
                # NOTE(review): utcnow() yields a naive datetime and is
                # deprecated since Python 3.12; kept so consumers keep
                # receiving naive UTC timestamps.
                validation_timestamp=datetime.utcnow(),
                recommendations=recommendations
            )

        except Exception as e:
            logger.error("Comparison validation failed: %s", e)
            return self._failure_report(f"Validation failed: {str(e)}")

    async def validate_trend_analysis(self, trend_data: List[Dict[str, Any]]) -> ComparisonValidationReport:
        """
        Validate trend analysis results.

        The metric names are taken from the first data point; each of them is
        analysed across all points with ``_analyze_metric_trend``.  Fewer than
        three data points is recorded as a ``data_error`` (the metrics are
        still analysed so the report carries their first/last values).

        Args:
            trend_data: List of trend data points

        Returns:
            ComparisonValidationReport: Trend validation report.  The report
            is valid only when there are no validation errors and the overall
            confidence reaches ``MIN_TREND_CONFIDENCE``.
        """
        validation_errors = []
        metric_comparisons = []
        statistical_tests = []

        try:
            if len(trend_data) < 3:
                validation_errors.append(ValidationError(
                    message="Insufficient data points for trend analysis (minimum 3 required)",
                    error_type="data_error",
                    fix_required=True
                ))

            if trend_data:
                for metric_name in trend_data[0].get('metrics', {}):
                    trend_result = await self._analyze_metric_trend(metric_name, trend_data)
                    metric_comparisons.append(trend_result)
                    if trend_result.statistical_test:
                        statistical_tests.append(trend_result.statistical_test)

            validation_warnings = list(self._generate_trend_warnings(metric_comparisons))
            overall_confidence = self._calculate_overall_confidence(
                metric_comparisons, validation_errors
            )
            recommendations = self._generate_trend_recommendations(
                metric_comparisons, validation_errors
            )

            is_valid = (not validation_errors
                        and overall_confidence >= self.MIN_TREND_CONFIDENCE)

            return ComparisonValidationReport(
                is_valid=is_valid,
                overall_confidence=overall_confidence,
                validation_errors=validation_errors,
                validation_warnings=validation_warnings,
                metric_comparisons=metric_comparisons,
                statistical_tests=statistical_tests,
                validation_timestamp=datetime.utcnow(),
                recommendations=recommendations
            )

        except Exception as e:
            logger.error("Trend validation failed: %s", e)
            return self._failure_report(f"Trend validation failed: {str(e)}")

    @staticmethod
    def _failure_report(message: str) -> ComparisonValidationReport:
        """Build the invalid, zero-confidence report returned when validation itself fails."""
        return ComparisonValidationReport(
            is_valid=False,
            overall_confidence=0.0,
            validation_errors=[ValidationError(
                message=message,
                error_type="statistical_error",
                fix_required=True
            )],
            validation_warnings=[],
            metric_comparisons=[],
            statistical_tests=[],
            validation_timestamp=datetime.utcnow(),
            recommendations=["Fix validation errors before proceeding"]
        )

    @staticmethod
    def _sample_values(metric_data: Dict[str, Any]) -> List[Any]:
        """Return the samples of a metric: its 'values', else its single 'value', else []."""
        values = metric_data.get('values')
        # len() instead of truthiness so a numpy array under 'values' works too.
        if values is None or len(values) == 0:
            return [metric_data['value']] if 'value' in metric_data else []
        return list(values)

    async def _compare_metric(self, metric_name: str, baseline_data: Dict[str, Any],
                              comparison_data: Dict[str, Any]) -> MetricComparison:
        """
        Compare a single metric between baseline and comparison.

        The means of both samples are always compared.  When both sides have
        more than one sample, an independent two-sample t-test, a confidence
        interval for the difference of means (pooled standard deviation) and
        Cohen's d are computed as well.  All signed quantities use the
        direction ``comparison - baseline``.

        Args:
            metric_name: Name of the metric
            baseline_data: Baseline metric data ('values' list or single 'value')
            comparison_data: Comparison metric data ('values' list or single 'value')

        Returns:
            MetricComparison: Comparison result
        """
        baseline_vals = self._sample_values(baseline_data)
        comparison_vals = self._sample_values(comparison_data)

        baseline_mean = float(np.mean(baseline_vals)) if baseline_vals else 0.0
        comparison_mean = float(np.mean(comparison_vals)) if comparison_vals else 0.0

        absolute_diff = comparison_mean - baseline_mean
        # The relative difference is undefined for a zero baseline; 0 is the
        # placeholder this class has always reported in that case.
        relative_diff = (absolute_diff / baseline_mean * 100) if baseline_mean != 0 else 0.0

        statistical_test = None
        is_significant = False
        confidence_interval = None

        n_base, n_comp = len(baseline_vals), len(comparison_vals)
        if n_base > 1 and n_comp > 1:
            try:
                # Comparison first, so the statistic's sign matches
                # absolute_diff and the effect size (the p-value is symmetric).
                t_stat, p_value = stats.ttest_ind(comparison_vals, baseline_vals)
                significant = bool(p_value < self.significance_threshold)

                dof = n_base + n_comp - 2
                pooled_std = float(np.sqrt(
                    ((n_base - 1) * np.var(baseline_vals, ddof=1) +
                     (n_comp - 1) * np.var(comparison_vals, ddof=1)) / dof
                ))

                # Two-sided confidence interval for the difference of means.
                std_error = pooled_std * np.sqrt(1 / n_base + 1 / n_comp)
                t_critical = stats.t.ppf(1 - (1 - self.confidence_level) / 2, dof)
                margin_error = float(t_critical * std_error)
                interval = (absolute_diff - margin_error, absolute_diff + margin_error)

                # Cohen's d; reported as 0 when both samples have no spread.
                effect_size = absolute_diff / pooled_std if pooled_std != 0 else 0.0

                statistical_test = StatisticalTest(
                    test_name="independent_t_test",
                    statistic=float(t_stat),
                    p_value=float(p_value),
                    is_significant=significant,
                    confidence_level=self.confidence_level,
                    effect_size=effect_size
                )
                # Published only after every step succeeded, so a failure
                # cannot leave a "significant" flag without a backing test.
                is_significant = significant
                confidence_interval = interval

            except Exception as e:
                logger.warning("Statistical test failed for %s: %s", metric_name, e)

        return MetricComparison(
            metric_name=metric_name,
            baseline_value=baseline_mean,
            comparison_value=comparison_mean,
            absolute_difference=absolute_diff,
            relative_difference=relative_diff,
            statistical_test=statistical_test,
            is_significant=is_significant,
            confidence_interval=confidence_interval
        )

    async def _analyze_metric_trend(self, metric_name: str, trend_data: List[Dict[str, Any]]) -> MetricComparison:
        """
        Analyze trend for a specific metric.

        One value per data point is collected (the point's 'value', or the
        mean of its 'values'); points that lack the metric are skipped.  The
        first and last collected values become baseline and comparison.  With
        at least three values a linear regression against the point index is
        run; its slope test decides significance and R² is the effect size.

        Args:
            metric_name: Name of the metric
            trend_data: Trend data points

        Returns:
            MetricComparison: Trend analysis result
        """
        values = []
        for point in trend_data:
            metric_data = point.get('metrics', {}).get(metric_name, {})
            if 'value' in metric_data:
                values.append(metric_data['value'])
            elif 'values' in metric_data:
                values.append(np.mean(metric_data['values']))

        if len(values) < 2:
            return MetricComparison(
                metric_name=metric_name,
                baseline_value=values[0] if values else 0,
                comparison_value=values[-1] if values else 0,
                absolute_difference=0,
                relative_difference=0,
                is_significant=False
            )

        baseline_value = values[0]
        comparison_value = values[-1]
        absolute_diff = comparison_value - baseline_value
        relative_diff = (absolute_diff / baseline_value * 100) if baseline_value != 0 else 0

        statistical_test = None
        is_significant = False

        # Two points always lie exactly on a line (zero residual degrees of
        # freedom), so a slope test only means something from three points on.
        if len(values) >= 3:
            try:
                x = np.arange(len(values))
                y = np.asarray(values, dtype=float)
                slope, _intercept, r_value, p_value, std_err = stats.linregress(x, y)

                # Use linregress's own p-value.  Deriving it from
                # slope / std_err breaks down for a perfect fit, where
                # std_err is 0 and the trend would be reported with p = 1.
                if std_err > 0:
                    t_stat = float(slope / std_err)
                elif slope != 0:
                    t_stat = float(np.copysign(np.inf, slope))
                else:
                    t_stat = 0.0

                significant = bool(p_value < self.significance_threshold)

                statistical_test = StatisticalTest(
                    test_name="linear_regression",
                    statistic=t_stat,
                    p_value=float(p_value),
                    is_significant=significant,
                    confidence_level=self.confidence_level,
                    effect_size=float(r_value ** 2)
                )
                is_significant = significant

            except Exception as e:
                logger.warning("Trend analysis failed for %s: %s", metric_name, e)

        return MetricComparison(
            metric_name=metric_name,
            baseline_value=baseline_value,
            comparison_value=comparison_value,
            absolute_difference=absolute_diff,
            relative_difference=relative_diff,
            statistical_test=statistical_test,
            is_significant=is_significant
        )

    def _validate_result_completeness(self, baseline_results: Dict[str, Any],
                                      comparison_results: Dict[str, Any]) -> List[ValidationError]:
        """
        Validate completeness of comparison results.

        Returns one ``data_error`` per required top-level field ('metrics',
        'timestamp', 'experiment_id') missing from either result set.
        """
        errors = []

        for field in ('metrics', 'timestamp', 'experiment_id'):
            for label, results in (('baseline', baseline_results),
                                   ('comparison', comparison_results)):
                if field not in results:
                    errors.append(ValidationError(
                        message=f"Missing required field in {label}: {field}",
                        error_type="data_error",
                        fix_required=True
                    ))

        return errors

    def _validate_comparison_logic(self, baseline_results: Dict[str, Any],
                                   comparison_results: Dict[str, Any]) -> List[ValidationError]:
        """
        Validate logical consistency of comparison.

        Compares the 'config' of both result sets and returns a
        ``logic_error`` (``fix_required=False``) when the model name differs
        and another when the dataset name or dataset version differs.
        """
        errors = []

        baseline_config = baseline_results.get('config', {})
        comparison_config = comparison_results.get('config', {})

        if baseline_config.get('model_name') != comparison_config.get('model_name'):
            errors.append(ValidationError(
                message="Models are different - comparison may not be valid",
                error_type="logic_error",
                fix_required=False
            ))

        same_dataset = all(
            baseline_config.get(key) == comparison_config.get(key)
            for key in ('dataset_name', 'dataset_version')
        )
        if not same_dataset:
            errors.append(ValidationError(
                message="Datasets are different - comparison may not be valid",
                error_type="logic_error",
                fix_required=False
            ))

        return errors

    def _generate_warnings(self, metric_comparisons: List[MetricComparison]) -> List[ValidationWarning]:
        """
        Generate warnings based on comparison results.

        Per metric, up to three warnings are produced: a very small relative
        difference, a test that ran but was not significant, and an effect
        size whose magnitude exceeds ``LARGE_EFFECT_SIZE``.
        """
        found = []

        for comparison in metric_comparisons:
            test = comparison.statistical_test

            # With a zero baseline the relative difference is only a 0
            # placeholder; do not call a non-zero change "very small".
            relative_undefined = (comparison.baseline_value == 0
                                  and comparison.absolute_difference != 0)
            if (not relative_undefined
                    and abs(comparison.relative_difference) < self.SMALL_RELATIVE_DIFFERENCE):
                found.append(ValidationWarning(
                    message=f"Very small difference detected for {comparison.metric_name}",
                    severity="low",
                    metric_name=comparison.metric_name,
                    recommendation="Consider if this difference is practically significant"
                ))

            if test and not comparison.is_significant:
                found.append(ValidationWarning(
                    message=f"No significant difference for {comparison.metric_name} (p={test.p_value:.3f})",
                    severity="medium",
                    metric_name=comparison.metric_name,
                    recommendation="Consider increasing sample size or effect size"
                ))

            if (test and test.effect_size is not None
                    and abs(test.effect_size) > self.LARGE_EFFECT_SIZE):
                found.append(ValidationWarning(
                    message=f"Large effect size for {comparison.metric_name}",
                    severity="medium",
                    metric_name=comparison.metric_name,
                    recommendation="Verify results are not due to outliers or data issues"
                ))

        return found

    def _generate_trend_warnings(self, metric_comparisons: List[MetricComparison]) -> List[ValidationWarning]:
        """
        Generate trend-specific warnings.

        A warning is produced for every metric whose regression R² is below
        ``WEAK_TREND_R_SQUARED`` (an R² of exactly 0 included).
        """
        found = []

        for comparison in metric_comparisons:
            test = comparison.statistical_test
            # "is not None" rather than truthiness: R² == 0.0 is the weakest
            # possible trend and must be reported.
            if (test and test.effect_size is not None
                    and test.effect_size < self.WEAK_TREND_R_SQUARED):
                found.append(ValidationWarning(
                    message=f"Weak trend correlation for {comparison.metric_name} (R² = {test.effect_size:.3f})",
                    severity="medium",
                    metric_name=comparison.metric_name,
                    recommendation="Consider if trend is meaningful or due to noise"
                ))

        return found

    def _calculate_overall_confidence(self, metric_comparisons: List[MetricComparison],
                                      validation_errors: List[ValidationError]) -> float:
        """
        Calculate overall confidence in comparison results.

        Returns 0.0 when there is any validation error and 0.5 when nothing
        was compared.  Otherwise the score is
        ``0.7 * share of significant metrics + 0.3 * mean |effect size|``
        (the effect-size term capped at 1), clamped to the range [0, 1].
        """
        if validation_errors:
            return 0.0

        if not metric_comparisons:
            return 0.5

        significant_count = sum(1 for comp in metric_comparisons if comp.is_significant)
        significance_ratio = significant_count / len(metric_comparisons)

        # Only the magnitude of an effect says how strong it is: a large
        # negative Cohen's d must not reduce confidence, and effects of
        # opposite sign must not cancel out.  NaN effect sizes are skipped.
        magnitudes = [
            abs(comp.statistical_test.effect_size)
            for comp in metric_comparisons
            if comp.statistical_test
            and comp.statistical_test.effect_size is not None
            and not np.isnan(comp.statistical_test.effect_size)
        ]
        avg_effect_size = float(np.mean(magnitudes)) if magnitudes else 0.0

        overall_confidence = (significance_ratio * 0.7) + (min(avg_effect_size, 1.0) * 0.3)

        return float(min(max(overall_confidence, 0.0), 1.0))

    def _generate_recommendations(self, metric_comparisons: List[MetricComparison],
                                  validation_errors: List[ValidationError],
                                  validation_warnings: List[ValidationWarning]) -> List[str]:
        """
        Generate actionable recommendations.

        ``validation_warnings`` is accepted for interface compatibility; the
        recommendations are derived from the errors and metric comparisons.
        """
        recommendations = []

        if validation_errors:
            recommendations.append("Fix validation errors before proceeding with analysis")

        if any(comp.statistical_test and not comp.is_significant
               for comp in metric_comparisons):
            recommendations.append("Consider increasing sample size for metrics with non-significant results")

        if any(comp.statistical_test
               and comp.statistical_test.effect_size is not None
               and abs(comp.statistical_test.effect_size) > self.LARGE_EFFECT_SIZE
               for comp in metric_comparisons):
            recommendations.append("Verify large effect sizes are not due to data anomalies")

        if metric_comparisons:
            confidence = self._calculate_overall_confidence(metric_comparisons, validation_errors)
            if confidence < self.MIN_COMPARISON_CONFIDENCE:
                recommendations.append("Consider collecting more data to improve statistical power")

        return recommendations

    def _generate_trend_recommendations(self, metric_comparisons: List[MetricComparison],
                                        validation_errors: List[ValidationError]) -> List[str]:
        """
        Generate trend-specific recommendations.

        Recommends fixing errors when there are any, and collecting more data
        points when at least one metric has an R² below
        ``WEAK_TREND_R_SQUARED`` (an R² of exactly 0 included).
        """
        recommendations = []

        if validation_errors:
            recommendations.append("Fix validation errors before proceeding with trend analysis")

        if any(comp.statistical_test
               and comp.statistical_test.effect_size is not None
               and comp.statistical_test.effect_size < self.WEAK_TREND_R_SQUARED
               for comp in metric_comparisons):
            recommendations.append("Consider collecting more data points to strengthen trend analysis")

        return recommendations
|
|
|
|
| |
def create_comparison_validator(significance_threshold: float = 0.05, confidence_level: float = 0.95) -> ComparisonValidator:
    """
    Build a ``ComparisonValidator`` with the given settings.

    Args:
        significance_threshold: P-value threshold below which a result counts as significant
        confidence_level: Confidence level used by the validator's statistical tests

    Returns:
        ComparisonValidator: A new validator configured with both values
    """
    validator = ComparisonValidator(
        significance_threshold=significance_threshold,
        confidence_level=confidence_level,
    )
    return validator
|
|