# ALM-2/backend/experiments/comparison_validator.py
# ACA050's picture
# Upload 520 files
# 2ed8996 verified
"""
Comparison Validator for AegisLM Framework
Production-grade validation for experiment comparison results with
statistical significance testing, accuracy verification, and comprehensive reporting.
"""
import numpy as np
from scipy import stats
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from dataclasses import dataclass
from pathlib import Path
import sys
import logging
# Make the directory above this file's folder importable: it is prepended to
# sys.path (once) so that imports resolve relative to it.
current_dir = Path(__file__).parent
backend_dir = current_dir.parent
_backend_path = str(backend_dir)
if _backend_path not in sys.path:
    sys.path.insert(0, _backend_path)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@dataclass
class ValidationWarning:
    """
    A non-fatal finding produced while validating a comparison.

    Attributes:
        message: Human-readable description of the finding.
        severity: One of "low", "medium" or "high".
        metric_name: Metric the warning refers to, if any.
        recommendation: Suggested follow-up action, if any.
    """

    message: str
    severity: str
    metric_name: Optional[str] = None
    recommendation: Optional[str] = None
@dataclass
class ValidationError:
    """
    A problem found while validating a comparison.

    Attributes:
        message: Human-readable description of the problem.
        error_type: One of "data_error", "statistical_error" or "logic_error".
        metric_name: Metric the error refers to, if any.
        fix_required: Whether the problem must be fixed (defaults to True).
    """

    message: str
    error_type: str
    metric_name: Optional[str] = None
    fix_required: bool = True
@dataclass
class StatisticalTest:
    """
    Outcome of a single statistical test.

    Attributes:
        test_name: Identifier of the test that was run.
        statistic: Value of the test statistic.
        p_value: P-value associated with the statistic.
        is_significant: Whether the result was judged significant.
        confidence_level: Confidence level the test was run at.
        effect_size: Effect-size measure for the test, if one was computed.
    """

    test_name: str
    statistic: float
    p_value: float
    is_significant: bool
    confidence_level: float
    effect_size: Optional[float] = None
@dataclass
class MetricComparison:
    """
    Result of comparing one metric between a baseline and a comparison run.

    Attributes:
        metric_name: Name of the compared metric.
        baseline_value: Value representing the baseline side.
        comparison_value: Value representing the comparison side.
        absolute_difference: Difference between the two values.
        relative_difference: Difference expressed relative to the baseline.
        statistical_test: Test backing the comparison, if one was run.
        is_significant: Whether the difference was judged significant.
        confidence_interval: (lower, upper) bounds on the difference, if computed.
    """

    metric_name: str
    baseline_value: float
    comparison_value: float
    absolute_difference: float
    relative_difference: float
    statistical_test: Optional[StatisticalTest] = None
    is_significant: bool = False
    confidence_interval: Optional[Tuple[float, float]] = None
@dataclass
class ComparisonValidationReport:
    """
    Aggregated outcome of a validation run.

    Attributes:
        is_valid: Overall verdict of the validation.
        overall_confidence: Confidence score attached to the verdict.
        validation_errors: Errors collected during validation.
        validation_warnings: Warnings collected during validation.
        metric_comparisons: Per-metric comparison results.
        statistical_tests: Statistical tests that were run.
        validation_timestamp: When the report was produced.
        recommendations: Suggested follow-up actions.
    """

    is_valid: bool
    overall_confidence: float
    validation_errors: List[ValidationError]
    validation_warnings: List[ValidationWarning]
    metric_comparisons: List[MetricComparison]
    statistical_tests: List[StatisticalTest]
    validation_timestamp: datetime
    recommendations: List[str]
class ComparisonValidator:
    """
    Validate experiment comparison results for accuracy and statistical significance.

    Two entry points are provided: ``validate_comparison`` (baseline vs.
    comparison run) and ``validate_trend_analysis`` (a sequence of runs).
    Both return a ``ComparisonValidationReport`` and never raise: unexpected
    failures are converted into an invalid report.

    The thresholds below are class attributes so they can be overridden on a
    subclass or an instance without changing the constructor signature.
    """

    # Minimum overall confidence for a baseline/comparison report to be valid.
    MIN_COMPARISON_CONFIDENCE = 0.7
    # Minimum overall confidence for a trend report to be valid.
    MIN_TREND_CONFIDENCE = 0.6
    # Minimum number of trend data points accepted without an error.
    MIN_TREND_POINTS = 3
    # Relative differences (in percent) below this are flagged as very small.
    SMALL_RELATIVE_DIFFERENCE_PCT = 1.0
    # Effect sizes with a magnitude above this are flagged as large.
    LARGE_EFFECT_SIZE = 0.8
    # R-squared values below this are flagged as a weak trend.
    WEAK_TREND_R_SQUARED = 0.3

    def __init__(self, significance_threshold: float = 0.05, confidence_level: float = 0.95):
        """
        Initialize comparison validator.

        Args:
            significance_threshold: P-value threshold for statistical significance
            confidence_level: Confidence level for statistical tests
        """
        self.significance_threshold = significance_threshold
        self.confidence_level = confidence_level

    async def validate_comparison(self, baseline_results: Dict[str, Any],
                                  comparison_results: Dict[str, Any]) -> "ComparisonValidationReport":
        """
        Validate comparison results for statistical significance and accuracy.

        Args:
            baseline_results: Baseline experiment results
            comparison_results: Comparison experiment results

        Returns:
            ComparisonValidationReport: Comprehensive validation report. If
            anything raises during validation, an invalid report carrying the
            failure message is returned instead.
        """
        validation_errors = []
        validation_warnings = []
        metric_comparisons = []
        statistical_tests = []
        try:
            # Validate result completeness
            validation_errors.extend(
                self._validate_result_completeness(baseline_results, comparison_results)
            )

            # Perform metric comparisons. ``or {}`` tolerates an explicit
            # ``None`` stored under 'metrics'.
            baseline_metrics = baseline_results.get('metrics') or {}
            comparison_metrics = comparison_results.get('metrics') or {}
            for metric_name in baseline_metrics:
                if metric_name not in comparison_metrics:
                    validation_errors.append(ValidationError(
                        message=f"Missing metric in comparison: {metric_name}",
                        error_type="data_error",
                        metric_name=metric_name
                    ))
                    continue
                comparison_result = await self._compare_metric(
                    metric_name,
                    baseline_metrics[metric_name],
                    comparison_metrics[metric_name]
                )
                metric_comparisons.append(comparison_result)
                if comparison_result.statistical_test is not None:
                    statistical_tests.append(comparison_result.statistical_test)

            # Validate overall comparison logic
            validation_errors.extend(
                self._validate_comparison_logic(baseline_results, comparison_results)
            )

            # Generate warnings based on results
            validation_warnings.extend(self._generate_warnings(metric_comparisons))

            overall_confidence = self._calculate_overall_confidence(
                metric_comparisons, validation_errors
            )
            recommendations = self._generate_recommendations(
                metric_comparisons, validation_errors, validation_warnings
            )
            is_valid = bool(
                not validation_errors
                and overall_confidence >= self.MIN_COMPARISON_CONFIDENCE
            )
            return ComparisonValidationReport(
                is_valid=is_valid,
                overall_confidence=overall_confidence,
                validation_errors=validation_errors,
                validation_warnings=validation_warnings,
                metric_comparisons=metric_comparisons,
                statistical_tests=statistical_tests,
                validation_timestamp=datetime.utcnow(),
                recommendations=recommendations
            )
        except Exception as e:
            # Top-level boundary: callers always receive a report.
            logger.error("Comparison validation failed: %s", e, exc_info=True)
            return self._failure_report(f"Validation failed: {e}")

    async def validate_trend_analysis(self, trend_data: List[Dict[str, Any]]) -> "ComparisonValidationReport":
        """
        Validate trend analysis results.

        The set of metrics analysed is taken from the first data point.

        Args:
            trend_data: List of trend data points

        Returns:
            ComparisonValidationReport: Trend validation report. If anything
            raises during validation, an invalid report carrying the failure
            message is returned instead.
        """
        validation_errors = []
        validation_warnings = []
        metric_comparisons = []
        statistical_tests = []
        try:
            # Validate trend data completeness
            if len(trend_data) < self.MIN_TREND_POINTS:
                validation_errors.append(ValidationError(
                    message=(
                        "Insufficient data points for trend analysis "
                        f"(minimum {self.MIN_TREND_POINTS} required)"
                    ),
                    error_type="data_error",
                    fix_required=True
                ))

            # Analyze each metric's trend
            if trend_data:
                metrics = trend_data[0].get('metrics') or {}
                for metric_name in metrics:
                    trend_result = await self._analyze_metric_trend(metric_name, trend_data)
                    metric_comparisons.append(trend_result)
                    if trend_result.statistical_test is not None:
                        statistical_tests.append(trend_result.statistical_test)

            # Generate trend-specific warnings
            validation_warnings.extend(self._generate_trend_warnings(metric_comparisons))

            overall_confidence = self._calculate_overall_confidence(
                metric_comparisons, validation_errors
            )
            recommendations = self._generate_trend_recommendations(
                metric_comparisons, validation_errors
            )
            is_valid = bool(
                not validation_errors
                and overall_confidence >= self.MIN_TREND_CONFIDENCE
            )
            return ComparisonValidationReport(
                is_valid=is_valid,
                overall_confidence=overall_confidence,
                validation_errors=validation_errors,
                validation_warnings=validation_warnings,
                metric_comparisons=metric_comparisons,
                statistical_tests=statistical_tests,
                validation_timestamp=datetime.utcnow(),
                recommendations=recommendations
            )
        except Exception as e:
            # Top-level boundary: callers always receive a report.
            logger.error("Trend validation failed: %s", e, exc_info=True)
            return self._failure_report(f"Trend validation failed: {e}")

    @staticmethod
    def _failure_report(message: str) -> "ComparisonValidationReport":
        """Build the invalid, zero-confidence report returned when validation itself fails."""
        return ComparisonValidationReport(
            is_valid=False,
            overall_confidence=0.0,
            validation_errors=[ValidationError(
                message=message,
                error_type="statistical_error",
                fix_required=True
            )],
            validation_warnings=[],
            metric_comparisons=[],
            statistical_tests=[],
            validation_timestamp=datetime.utcnow(),
            recommendations=["Fix validation errors before proceeding"]
        )

    @staticmethod
    def _extract_values(metric_data: Any) -> List[float]:
        """
        Normalise one metric entry into a flat list of floats.

        Accepted shapes:
            * a dict with a non-empty 'values' sequence,
            * a dict with a single 'value' (used when 'values' is absent or empty),
            * a bare number or a bare sequence of numbers.

        Returns an empty list when nothing usable is present (``None``, an
        empty dict, or a ``None`` value). Non-numeric content raises
        ``ValueError``/``TypeError`` from the float conversion.
        """
        if isinstance(metric_data, dict):
            raw = metric_data.get('values')
            if raw is None or np.size(raw) == 0:
                raw = metric_data.get('value')
        else:
            raw = metric_data
        if raw is None:
            return []
        return np.asarray(raw, dtype=float).ravel().tolist()

    @staticmethod
    def _effect_size_of(comparison: Any) -> Optional[float]:
        """Return the comparison's effect size, or None when no test or no effect size exists."""
        test = comparison.statistical_test
        if test is None:
            return None
        return test.effect_size

    async def _compare_metric(self, metric_name: str, baseline_data: Dict[str, Any],
                              comparison_data: Dict[str, Any]) -> "MetricComparison":
        """
        Compare a single metric between baseline and comparison.

        The means of both samples are compared. When each side has more than
        one value, an independent two-sample t-test is run, together with a
        pooled-variance confidence interval for the mean difference and
        Cohen's d as effect size. A failure inside the statistics is logged
        and leaves the comparison without a test.

        Args:
            metric_name: Name of the metric
            baseline_data: Baseline metric data (dict with 'values'/'value';
                a bare number or sequence is also accepted)
            comparison_data: Comparison metric data (same shapes)

        Returns:
            MetricComparison: Comparison result
        """
        baseline_vals = self._extract_values(baseline_data)
        comparison_vals = self._extract_values(comparison_data)

        # An empty side is represented by a mean of 0.
        baseline_mean = float(np.mean(baseline_vals)) if baseline_vals else 0.0
        comparison_mean = float(np.mean(comparison_vals)) if comparison_vals else 0.0
        absolute_diff = comparison_mean - baseline_mean
        # Relative difference in percent; reported as 0 when the baseline is 0.
        relative_diff = (absolute_diff / baseline_mean * 100) if baseline_mean != 0 else 0.0

        statistical_test = None
        is_significant = False
        confidence_interval = None

        n_base = len(baseline_vals)
        n_comp = len(comparison_vals)
        if n_base > 1 and n_comp > 1:
            try:
                t_stat, p_value = stats.ttest_ind(baseline_vals, comparison_vals)
                significant = bool(p_value < self.significance_threshold)

                # Pooled standard deviation, shared by the confidence interval
                # and by Cohen's d.
                dof = n_base + n_comp - 2
                pooled_std = float(np.sqrt(
                    ((n_base - 1) * np.var(baseline_vals, ddof=1) +
                     (n_comp - 1) * np.var(comparison_vals, ddof=1)) / dof
                ))
                standard_error = pooled_std * np.sqrt(1 / n_base + 1 / n_comp)
                t_critical = stats.t.ppf(1 - (1 - self.confidence_level) / 2, dof)
                margin_error = float(t_critical * standard_error)
                interval = (absolute_diff - margin_error, absolute_diff + margin_error)

                effect_size = absolute_diff / pooled_std if pooled_std != 0 else 0.0

                test = StatisticalTest(
                    test_name="independent_t_test",
                    statistic=float(t_stat),
                    p_value=float(p_value),
                    is_significant=significant,
                    confidence_level=self.confidence_level,
                    effect_size=effect_size
                )
                # Publish the results only once everything succeeded, so a
                # late failure cannot leave a "significant" flag without a test.
                statistical_test = test
                is_significant = significant
                confidence_interval = interval
            except Exception as e:
                logger.warning("Statistical test failed for %s: %s", metric_name, e)

        return MetricComparison(
            metric_name=metric_name,
            baseline_value=baseline_mean,
            comparison_value=comparison_mean,
            absolute_difference=absolute_diff,
            relative_difference=relative_diff,
            statistical_test=statistical_test,
            is_significant=is_significant,
            confidence_interval=confidence_interval
        )

    async def _analyze_metric_trend(self, metric_name: str, trend_data: List[Dict[str, Any]]) -> "MetricComparison":
        """
        Analyze trend for a specific metric.

        Each data point contributes one number: its 'value' if present,
        otherwise the mean of its 'values'. Points without usable data are
        skipped. The first and last numbers become baseline and comparison;
        a linear regression over the sequence index tests whether the slope
        differs from zero, with R-squared stored as effect size.

        Args:
            metric_name: Name of the metric
            trend_data: Trend data points

        Returns:
            MetricComparison: Trend analysis result
        """
        # Extract values over time
        values = []
        for point in trend_data:
            metric_data = (point.get('metrics') or {}).get(metric_name)
            if isinstance(metric_data, dict) and metric_data.get('value') is not None:
                values.append(float(metric_data['value']))
                continue
            samples = self._extract_values(metric_data)
            if samples:  # skip empty entries instead of averaging to NaN
                values.append(float(np.mean(samples)))

        if len(values) < 2:
            return MetricComparison(
                metric_name=metric_name,
                baseline_value=values[0] if values else 0,
                comparison_value=values[-1] if values else 0,
                absolute_difference=0,
                relative_difference=0,
                is_significant=False
            )

        baseline_value = values[0]
        comparison_value = values[-1]
        absolute_diff = comparison_value - baseline_value
        relative_diff = (absolute_diff / baseline_value * 100) if baseline_value != 0 else 0.0

        statistical_test = None
        is_significant = False
        # The slope test has n - 2 degrees of freedom, so it is only defined
        # for three or more points; two points always fit a line exactly.
        if len(values) >= 3:
            try:
                x = np.arange(len(values))
                y = np.asarray(values, dtype=float)
                slope, _intercept, r_value, p_value, std_err = stats.linregress(x, y)

                if std_err > 0:
                    t_stat = float(slope / std_err)
                elif slope != 0:
                    # Zero residual error with a non-zero slope: a perfect
                    # linear trend, i.e. an unbounded t statistic.
                    t_stat = float(np.copysign(np.inf, slope))
                else:
                    t_stat = 0.0

                # linregress already reports the two-sided p-value for the
                # null hypothesis "slope == 0".
                p_val = float(p_value)
                significant = bool(p_val < self.significance_threshold)
                statistical_test = StatisticalTest(
                    test_name="linear_regression",
                    statistic=t_stat,
                    p_value=p_val,
                    is_significant=significant,
                    confidence_level=self.confidence_level,
                    effect_size=float(r_value ** 2)
                )
                is_significant = significant
            except Exception as e:
                logger.warning("Trend analysis failed for %s: %s", metric_name, e)

        return MetricComparison(
            metric_name=metric_name,
            baseline_value=baseline_value,
            comparison_value=comparison_value,
            absolute_difference=absolute_diff,
            relative_difference=relative_diff,
            statistical_test=statistical_test,
            is_significant=is_significant
        )

    def _validate_result_completeness(self, baseline_results: Dict[str, Any],
                                      comparison_results: Dict[str, Any]) -> List["ValidationError"]:
        """
        Validate completeness of comparison results.

        Returns one data error per required top-level field ('metrics',
        'timestamp', 'experiment_id') missing from either result set.
        """
        errors = []
        required_fields = ['metrics', 'timestamp', 'experiment_id']
        sides = (('baseline', baseline_results), ('comparison', comparison_results))
        for field in required_fields:
            for label, results in sides:
                if field not in results:
                    errors.append(ValidationError(
                        message=f"Missing required field in {label}: {field}",
                        error_type="data_error",
                        fix_required=True
                    ))
        return errors

    def _validate_comparison_logic(self, baseline_results: Dict[str, Any],
                                   comparison_results: Dict[str, Any]) -> List["ValidationError"]:
        """
        Validate logical consistency of comparison.

        Compares 'model_name', 'dataset_name' and 'dataset_version' from the
        'config' entry of both result sets and reports a logic error for each
        mismatch (marked ``fix_required=False``).
        """
        errors = []
        # ``or {}`` tolerates an explicit ``None`` stored under 'config'.
        baseline_config = baseline_results.get('config') or {}
        comparison_config = comparison_results.get('config') or {}

        # Check model compatibility
        if baseline_config.get('model_name') != comparison_config.get('model_name'):
            errors.append(ValidationError(
                message="Models are different - comparison may not be valid",
                error_type="logic_error",
                fix_required=False
            ))

        # Check dataset compatibility
        same_dataset = (
            baseline_config.get('dataset_name') == comparison_config.get('dataset_name')
            and baseline_config.get('dataset_version') == comparison_config.get('dataset_version')
        )
        if not same_dataset:
            errors.append(ValidationError(
                message="Datasets are different - comparison may not be valid",
                error_type="logic_error",
                fix_required=False
            ))
        return errors

    def _generate_warnings(self, metric_comparisons: List["MetricComparison"]) -> List["ValidationWarning"]:
        """
        Generate warnings based on comparison results.

        Per metric, up to three warnings are produced: a very small relative
        difference, a non-significant test result, and a large effect size.
        """
        warnings = []
        for comparison in metric_comparisons:
            test = comparison.statistical_test
            effect_size = self._effect_size_of(comparison)

            # Check for small differences
            if abs(comparison.relative_difference) < self.SMALL_RELATIVE_DIFFERENCE_PCT:
                warnings.append(ValidationWarning(
                    message=f"Very small difference detected for {comparison.metric_name}",
                    severity="low",
                    metric_name=comparison.metric_name,
                    recommendation="Consider if this difference is practically significant"
                ))

            # Check for non-significant results
            if test is not None and not comparison.is_significant:
                warnings.append(ValidationWarning(
                    message=f"No significant difference for {comparison.metric_name} (p={test.p_value:.3f})",
                    severity="medium",
                    metric_name=comparison.metric_name,
                    recommendation="Consider increasing sample size or effect size"
                ))

            # Check for large effect sizes
            if effect_size is not None and abs(effect_size) > self.LARGE_EFFECT_SIZE:
                warnings.append(ValidationWarning(
                    message=f"Large effect size for {comparison.metric_name}",
                    severity="medium",
                    metric_name=comparison.metric_name,
                    recommendation="Verify results are not due to outliers or data issues"
                ))
        return warnings

    def _generate_trend_warnings(self, metric_comparisons: List["MetricComparison"]) -> List["ValidationWarning"]:
        """
        Generate trend-specific warnings.

        A warning is produced for every metric whose effect size (R-squared
        for trend results) is below the weak-trend threshold, including an
        R-squared of exactly 0.
        """
        warnings = []
        for comparison in metric_comparisons:
            r_squared = self._effect_size_of(comparison)
            # ``is not None`` rather than truthiness: 0.0 is the weakest trend
            # and must be reported.
            if r_squared is not None and r_squared < self.WEAK_TREND_R_SQUARED:
                warnings.append(ValidationWarning(
                    message=f"Weak trend correlation for {comparison.metric_name} (R² = {r_squared:.3f})",
                    severity="medium",
                    metric_name=comparison.metric_name,
                    recommendation="Consider if trend is meaningful or due to noise"
                ))
        return warnings

    def _calculate_overall_confidence(self, metric_comparisons: List["MetricComparison"],
                                      validation_errors: List["ValidationError"]) -> float:
        """
        Calculate overall confidence in comparison results.

        Returns 0.0 when any validation error exists and 0.5 (neutral) when
        there is nothing to compare. Otherwise the score is
        ``0.7 * share of significant metrics + 0.3 * min(mean |effect size|, 1)``,
        clamped to the range [0, 1].
        """
        if validation_errors:
            return 0.0
        if not metric_comparisons:
            return 0.5  # Neutral confidence

        significant_count = sum(1 for comp in metric_comparisons if comp.is_significant)
        significance_ratio = significant_count / len(metric_comparisons)

        # Use the magnitude of each effect: the sign only encodes direction
        # (comparison below or above baseline) and must not lower confidence.
        # Non-finite effect sizes are ignored.
        magnitudes = []
        for comp in metric_comparisons:
            effect_size = self._effect_size_of(comp)
            if effect_size is not None and np.isfinite(effect_size):
                magnitudes.append(abs(effect_size))
        avg_effect_size = float(np.mean(magnitudes)) if magnitudes else 0.0

        overall_confidence = (significance_ratio * 0.7) + (min(avg_effect_size, 1.0) * 0.3)
        return float(min(max(overall_confidence, 0.0), 1.0))

    def _generate_recommendations(self, metric_comparisons: List["MetricComparison"],
                                  validation_errors: List["ValidationError"],
                                  validation_warnings: List["ValidationWarning"]) -> List[str]:
        """
        Generate actionable recommendations.

        ``validation_warnings`` is accepted for interface compatibility and is
        not consulted.
        """
        recommendations = []
        if validation_errors:
            recommendations.append("Fix validation errors before proceeding with analysis")

        # Statistical recommendations
        if any(comp.statistical_test is not None and not comp.is_significant
               for comp in metric_comparisons):
            recommendations.append(
                "Consider increasing sample size for metrics with non-significant results"
            )

        # Effect size recommendations
        effect_sizes = (self._effect_size_of(comp) for comp in metric_comparisons)
        if any(es is not None and abs(es) > self.LARGE_EFFECT_SIZE for es in effect_sizes):
            recommendations.append("Verify large effect sizes are not due to data anomalies")

        # General recommendations
        if metric_comparisons:
            confidence = self._calculate_overall_confidence(metric_comparisons, validation_errors)
            if confidence < self.MIN_COMPARISON_CONFIDENCE:
                recommendations.append("Consider collecting more data to improve statistical power")
        return recommendations

    def _generate_trend_recommendations(self, metric_comparisons: List["MetricComparison"],
                                        validation_errors: List["ValidationError"]) -> List[str]:
        """
        Generate trend-specific recommendations.

        Recommends more data points when any metric has an effect size
        (R-squared) below the weak-trend threshold, including exactly 0.
        """
        recommendations = []
        if validation_errors:
            recommendations.append("Fix validation errors before proceeding with trend analysis")

        # Trend strength recommendations
        effect_sizes = (self._effect_size_of(comp) for comp in metric_comparisons)
        if any(es is not None and es < self.WEAK_TREND_R_SQUARED for es in effect_sizes):
            recommendations.append("Consider collecting more data points to strengthen trend analysis")
        return recommendations
# Factory function
def create_comparison_validator(significance_threshold: float = 0.05, confidence_level: float = 0.95) -> ComparisonValidator:
    """
    Build a ``ComparisonValidator`` with the given settings.

    Args:
        significance_threshold: P-value threshold for significance
        confidence_level: Confidence level for tests

    Returns:
        ComparisonValidator: Configured validator
    """
    validator = ComparisonValidator(
        significance_threshold=significance_threshold,
        confidence_level=confidence_level,
    )
    return validator