Spaces:
Running
Running
| """ | |
| A/B Testing Framework for Model Performance Comparison | |
| Implements statistical testing, multi-armed bandit optimization, and comprehensive experiment tracking | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import pandas as pd | |
| import logging | |
| from typing import Dict, List, Tuple, Optional, Any, Union, Callable | |
| from dataclasses import dataclass, field | |
| from datetime import datetime, timedelta | |
| import json | |
| import sqlite3 | |
| from pathlib import Path | |
| import hashlib | |
| import uuid | |
| from enum import Enum | |
| import scipy.stats as stats | |
| from scipy.stats import chi2_contingency, mannwhitneyu, ttest_ind | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from plotly.subplots import make_subplots | |
| # Multi-armed bandit algorithms | |
| from typing import Protocol | |
| import math | |
| import random | |
| logger = logging.getLogger(__name__) | |
class ExperimentStatus(Enum):
    """Lifecycle states of an experiment.

    The string value of a member is what ``create_experiment`` writes to the
    ``status`` column of the ``experiments`` table.
    """

    DRAFT = "draft"          # created but not started
    RUNNING = "running"      # currently serving traffic
    PAUSED = "paused"
    COMPLETED = "completed"  # set by stop_experiment
    CANCELLED = "cancelled"
class BanditAlgorithm(Enum):
    """Identifiers for the traffic-allocation (multi-armed bandit) algorithms.

    NOTE(review): ``start_experiment`` only instantiates a strategy for the
    first three members; ``LINUCB`` has no implementation in this file.
    """

    EPSILON_GREEDY = "epsilon_greedy"
    UCB1 = "ucb1"
    THOMPSON_SAMPLING = "thompson_sampling"
    LINUCB = "linucb"
@dataclass
class ModelVariant:
    """A/B test model variant definition.

    Declared as a dataclass so that the annotated fields become constructor
    arguments and every ``field(default_factory=dict)`` yields a fresh dict
    per instance (without the decorator those defaults are raw ``Field``
    objects and the class has no generated ``__init__``).
    """

    variant_id: str  # key used in ExperimentConfig.traffic_split
    model: nn.Module
    model_path: Optional[str] = None
    description: str = ""
    hyperparameters: Dict[str, Any] = field(default_factory=dict)
    preprocessing_config: Dict[str, Any] = field(default_factory=dict)
    postprocessing_config: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ExperimentMetric:
    """Experiment metric definition.

    Declared as a dataclass so the annotated fields form the constructor.
    """

    name: str  # key looked up in each recorded result's ``metrics`` dict
    description: str
    metric_type: str  # "accuracy", "latency", "throughput", "memory", "custom"
    target_value: Optional[float] = None  # used by record_result to scale the bandit reward
    higher_is_better: bool = True
    weight: float = 1.0
    threshold: Optional[float] = None
@dataclass
class ExperimentConfig:
    """A/B test experiment configuration.

    Declared as a dataclass so the annotated fields form the constructor and
    ``bandit_config`` gets a fresh dict per instance.
    """

    experiment_id: str
    name: str
    description: str
    variants: List[ModelVariant]
    metrics: List[ExperimentMetric]  # the first entry is treated as the primary metric
    traffic_split: Dict[str, float]  # variant_id -> traffic percentage
    min_sample_size: int = 1000
    max_duration_days: int = 30
    confidence_level: float = 0.95
    minimum_detectable_effect: float = 0.05
    bandit_algorithm: BanditAlgorithm = BanditAlgorithm.EPSILON_GREEDY
    bandit_config: Dict[str, Any] = field(default_factory=dict)
    auto_stop_winning_threshold: float = 0.99
    auto_stop_losing_threshold: float = 0.01
@dataclass
class TestResult:
    """Individual test result.

    Declared as a dataclass so the annotated fields form the constructor and
    ``metadata`` gets a fresh dict per instance.
    """

    result_id: str
    experiment_id: str
    variant_id: str
    user_id: Optional[str]
    session_id: str
    timestamp: datetime
    metrics: Dict[str, float]  # metric name -> observed value
    metadata: Dict[str, Any] = field(default_factory=dict)
    processing_time_ms: float = 0.0
@dataclass
class StatisticalTestResult:
    """Outcome of one pairwise statistical comparison between two variants.

    Declared as a dataclass because ``_run_statistical_tests`` constructs it
    with keyword arguments, which requires a generated ``__init__``.
    """

    test_name: str
    statistic: float
    p_value: float
    confidence_interval: Tuple[float, float]  # interval for the mean difference
    effect_size: float  # Cohen's d as computed by the caller
    is_significant: bool
    power: float
    interpretation: str
@dataclass
class ExperimentSummary:
    """Experiment summary with statistical analysis.

    Declared as a dataclass because ``run_statistical_analysis`` constructs it
    with keyword arguments, which requires a generated ``__init__``.
    """

    experiment_id: str
    status: ExperimentStatus
    start_time: datetime
    end_time: Optional[datetime]
    total_samples: int
    variant_samples: Dict[str, int]  # variant_id -> number of recorded results
    variant_metrics: Dict[str, Dict[str, float]]
    statistical_tests: Dict[str, List[StatisticalTestResult]]  # keyed by metric name
    confidence_intervals: Dict[str, Dict[str, Tuple[float, float]]]  # metric -> variant -> (low, high)
    best_variant: Optional[str]
    recommendation: str
    # Filled from BanditStrategy.get_variant_stats(), which returns one stats
    # dict per variant (empty when no bandit strategy is active).
    bandit_performance: Dict[str, Dict[str, float]]
class BanditStrategy:
    """Shared bookkeeping for multi-armed bandit strategies.

    Tracks, per variant, how many rewards were recorded and their running sum.
    Subclasses supply the selection policy by overriding ``select_variant``.
    """

    def __init__(self, variants: List[str], config: Dict[str, Any]):
        self.variants = variants
        self.config = config
        self.reset()

    def reset(self):
        """Zero every per-variant counter and the global pull count."""
        self.counts = dict.fromkeys(self.variants, 0)
        self.rewards = dict.fromkeys(self.variants, 0.0)
        self.total_count = 0

    def select_variant(self) -> str:
        """Return the variant to serve next; must be overridden."""
        raise NotImplementedError

    def update_reward(self, variant: str, reward: float):
        """Record one observed ``reward`` for ``variant``.

        Raises:
            KeyError: if ``variant`` is not one of the configured variants.
        """
        self.counts[variant] += 1
        self.rewards[variant] += reward
        self.total_count += 1

    def get_variant_stats(self) -> Dict[str, Dict[str, float]]:
        """Return count, total/average reward and observed share per variant."""
        # Guard against division by zero before any reward has been recorded.
        denominator = max(self.total_count, 1)
        summary = {}
        for name in self.variants:
            pulls = self.counts[name]
            accumulated = self.rewards[name]
            summary[name] = {
                'count': pulls,
                'total_reward': accumulated,
                'average_reward': accumulated / pulls if pulls > 0 else 0.0,
                'selection_probability': pulls / denominator,
            }
        return summary
class EpsilonGreedyBandit(BanditStrategy):
    """Epsilon-greedy bandit algorithm.

    With probability ``epsilon`` a variant is picked uniformly at random
    (explore); otherwise the variant with the best average reward so far is
    served (exploit). ``epsilon`` is multiplied by ``decay_rate`` after every
    recorded reward.

    Config keys:
        epsilon: initial exploration probability (default 0.1).
        decay_rate: multiplicative decay applied per reward (default 0.99).
    """

    def __init__(self, variants: List[str], config: Dict[str, Any]):
        super().__init__(variants, config)
        self.epsilon = config.get('epsilon', 0.1)
        self.decay_rate = config.get('decay_rate', 0.99)

    def select_variant(self) -> str:
        """Pick the next variant to serve."""
        if random.random() < self.epsilon:
            # Explore: random selection
            return random.choice(self.variants)

        # Exploit: compare only variants that have data. Starting the search
        # from a baseline of 0.0 (as before) made variants[0] win whenever
        # every average was <= 0, even if another variant was clearly better.
        played = [v for v in self.variants if self.counts[v] > 0]
        if not played:
            return self.variants[0]
        # max() keeps the first of several tied variants, matching list order.
        return max(played, key=lambda v: self.rewards[v] / self.counts[v])

    def update_reward(self, variant: str, reward: float):
        """Record the reward, then shrink the exploration probability."""
        super().update_reward(variant, reward)
        # Decay epsilon over time
        self.epsilon *= self.decay_rate
class UCB1Bandit(BanditStrategy):
    """Upper Confidence Bound (UCB1) selection policy.

    Config keys:
        exploration_factor: multiplier inside the square-root bonus term
            (default 2.0).
    """

    def __init__(self, variants: List[str], config: Dict[str, Any]):
        super().__init__(variants, config)
        self.exploration_factor = config.get('exploration_factor', 2.0)

    def select_variant(self) -> str:
        """Return an untried variant if any, else the one with the top UCB score."""
        # Every variant is served once before scores are compared.
        for candidate in self.variants:
            if self.counts[candidate] == 0:
                return candidate

        log_total = math.log(self.total_count)

        def upper_bound(name: str) -> float:
            # Average reward plus an exploration bonus that shrinks as the
            # variant accumulates pulls.
            pulls = self.counts[name]
            bonus = math.sqrt((self.exploration_factor * log_total) / pulls)
            return self.rewards[name] / pulls + bonus

        # max() returns the first variant among ties, in list order.
        return max(self.variants, key=upper_bound)
class ThompsonSamplingBandit(BanditStrategy):
    """Thompson Sampling bandit algorithm with a Beta posterior per variant.

    Each variant starts from a Beta(1, 1) prior. ``alpha`` accumulates reward
    mass ("successes") and ``beta`` accumulates ``1 - reward`` ("failures").
    """

    def __init__(self, variants: List[str], config: Dict[str, Any]):
        # The base constructor calls reset(), which creates the priors.
        super().__init__(variants, config)

    def reset(self):
        """Reset counters and restore the Beta(1, 1) prior for every variant.

        Previously the priors were only created in ``__init__``, so calling
        ``reset()`` cleared the counters but kept the learned posterior.
        """
        super().reset()
        self.alpha = {variant: 1.0 for variant in self.variants}  # Success parameter
        self.beta = {variant: 1.0 for variant in self.variants}  # Failure parameter

    def select_variant(self) -> str:
        """Sample each variant's posterior and return the highest draw."""
        samples = {
            variant: np.random.beta(self.alpha[variant], self.beta[variant])
            for variant in self.variants
        }
        return max(samples, key=samples.get)

    def update_reward(self, variant: str, reward: float):
        """Record ``reward`` and update the variant's Beta posterior.

        The raw reward still feeds the base-class statistics. For the
        posterior it is clamped to [0, 1]: an out-of-range reward would push
        ``alpha`` or ``beta`` to a non-positive value, which makes the next
        ``np.random.beta`` call raise.
        """
        super().update_reward(variant, reward)
        bounded = min(max(reward, 0.0), 1.0)
        self.alpha[variant] += bounded
        self.beta[variant] += 1.0 - bounded
class ABTestingFramework:
    """
    A/B testing framework for comparing model variants.

    Experiments and recorded results are persisted in a SQLite file; the set
    of running experiments and their bandit strategies is kept in memory.
    """

    def __init__(self, db_path: str = "ab_testing.db"):
        # In-memory registries, keyed by experiment ID.
        self.active_experiments = {}
        self.bandit_strategies = {}
        # Location of the SQLite file; tables are created on construction.
        self.db_path = db_path
        self._init_database()
| def _init_database(self): | |
| """Initialize SQLite database for experiment tracking""" | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| # Experiments table | |
| cursor.execute(""" | |
| CREATE TABLE IF NOT EXISTS experiments ( | |
| experiment_id TEXT PRIMARY KEY, | |
| name TEXT NOT NULL, | |
| description TEXT, | |
| config TEXT, | |
| status TEXT, | |
| start_time TIMESTAMP, | |
| end_time TIMESTAMP, | |
| created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP | |
| ) | |
| """) | |
| # Test results table | |
| cursor.execute(""" | |
| CREATE TABLE IF NOT EXISTS test_results ( | |
| result_id TEXT PRIMARY KEY, | |
| experiment_id TEXT, | |
| variant_id TEXT, | |
| user_id TEXT, | |
| session_id TEXT, | |
| timestamp TIMESTAMP, | |
| metrics TEXT, | |
| metadata TEXT, | |
| processing_time_ms REAL, | |
| FOREIGN KEY (experiment_id) REFERENCES experiments (experiment_id) | |
| ) | |
| """) | |
| # Statistical results table | |
| cursor.execute(""" | |
| CREATE TABLE IF NOT EXISTS statistical_results ( | |
| id INTEGER PRIMARY KEY AUTOINCREMENT, | |
| experiment_id TEXT, | |
| test_name TEXT, | |
| variant_a TEXT, | |
| variant_b TEXT, | |
| statistic REAL, | |
| p_value REAL, | |
| effect_size REAL, | |
| confidence_interval TEXT, | |
| is_significant BOOLEAN, | |
| created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
| FOREIGN KEY (experiment_id) REFERENCES experiments (experiment_id) | |
| ) | |
| """) | |
| conn.commit() | |
| conn.close() | |
def create_experiment(self, config: ExperimentConfig) -> str:
    """Validate ``config`` and persist it as a new experiment in DRAFT state.

    Args:
        config: Experiment configuration to store.

    Returns:
        The experiment ID taken from ``config``.

    Raises:
        ValueError: if the configuration fails validation.
        sqlite3.Error: if the row cannot be inserted (e.g. duplicate ID).
    """
    # Validate configuration
    self._validate_experiment_config(config)

    # NOTE(review): default=str turns every non-JSON value (variants, metrics,
    # the bandit enum) into its str() form, so the stored config cannot be
    # rebuilt into an ExperimentConfig when it is loaded back.
    serialized_config = json.dumps(config.__dict__, default=str)
    # Same text the sqlite3 default datetime adapter produced; passed as a
    # string because that adapter is deprecated since Python 3.12.
    created = datetime.now().isoformat(sep=" ")

    # Store experiment in database; close the connection even if the insert fails.
    conn = sqlite3.connect(self.db_path)
    try:
        conn.execute("""
            INSERT INTO experiments (experiment_id, name, description, config, status, start_time)
            VALUES (?, ?, ?, ?, ?, ?)
        """, (
            config.experiment_id,
            config.name,
            config.description,
            serialized_config,
            ExperimentStatus.DRAFT.value,
            created,
        ))
        conn.commit()
    finally:
        conn.close()

    logger.info(f"Created experiment {config.experiment_id}: {config.name}")
    return config.experiment_id
def start_experiment(self, experiment_id: str) -> bool:
    """Start an A/B test experiment.

    Loads the stored configuration, creates the bandit strategy selected by
    ``config.bandit_algorithm``, marks the experiment RUNNING and registers it
    as active.

    Args:
        experiment_id: Experiment ID

    Returns:
        True when the experiment was started, False when it was not found.
    """
    # Load experiment configuration
    config = self._load_experiment_config(experiment_id)
    if not config:
        logger.error(f"Experiment {experiment_id} not found")
        return False

    # Initialize bandit strategy
    variant_ids = [v.variant_id for v in config.variants]
    strategy_classes = {
        BanditAlgorithm.EPSILON_GREEDY: EpsilonGreedyBandit,
        BanditAlgorithm.UCB1: UCB1Bandit,
        BanditAlgorithm.THOMPSON_SAMPLING: ThompsonSamplingBandit,
    }
    strategy_cls = strategy_classes.get(config.bandit_algorithm)
    if strategy_cls is not None:
        self.bandit_strategies[experiment_id] = strategy_cls(variant_ids, config.bandit_config)
    else:
        # Previously an algorithm without an implementation (e.g. LINUCB) was
        # skipped silently; make the fallback to the static split visible.
        logger.warning(
            f"No bandit strategy available for {config.bandit_algorithm}; "
            f"experiment {experiment_id} will use the static traffic split"
        )

    # Update experiment status
    self._update_experiment_status(experiment_id, ExperimentStatus.RUNNING)
    # Store active experiment
    self.active_experiments[experiment_id] = config
    logger.info(f"Started experiment {experiment_id}")
    return True
def assign_variant(self, experiment_id: str, user_id: Optional[str] = None) -> Optional[str]:
    """
    Assign a variant to a user for the experiment

    When a bandit strategy is registered for the experiment it decides the
    variant and ``user_id`` is ignored. Otherwise the static traffic split is
    used: a given ``user_id`` is always mapped to the same variant, while
    anonymous calls are assigned at random.

    Args:
        experiment_id: Experiment ID
        user_id: Optional user ID for consistent assignment
    Returns:
        Assigned variant ID or None if experiment not found
    """
    if experiment_id not in self.active_experiments:
        logger.warning(f"Experiment {experiment_id} is not active")
        return None
    config = self.active_experiments[experiment_id]

    # Use bandit algorithm if enabled
    if experiment_id in self.bandit_strategies:
        return self.bandit_strategies[experiment_id].select_variant()

    # Use static traffic split
    variant_ids = list(config.traffic_split)
    if not variant_ids:
        return None

    if user_id:
        # Map the user to a fixed point in [0, 1) derived from a hash. The old
        # code seeded the stdlib ``random`` module but then drew from
        # ``np.random``, so the assignment was neither deterministic nor free
        # of side effects on the global random state.
        digest = hashlib.md5(f"{experiment_id}_{user_id}".encode()).hexdigest()
        point = (int(digest, 16) % 10**6) / 10**6
    else:
        point = random.random()

    # Walk the cumulative distribution until the point falls in a bucket.
    cumulative = 0.0
    for variant_id in variant_ids:
        cumulative += config.traffic_split[variant_id]
        if point < cumulative:
            return variant_id
    # Rounding can leave the total marginally below 1.0; use the last bucket.
    return variant_ids[-1]
def record_result(
    self,
    experiment_id: str,
    variant_id: str,
    metrics: Dict[str, float],
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    processing_time_ms: float = 0.0
) -> str:
    """
    Record a test result

    The result is written to the ``test_results`` table. If a bandit strategy
    is active for the experiment, it is also rewarded based on the primary
    (first) metric.

    Args:
        experiment_id: Experiment ID
        variant_id: Variant ID
        metrics: Dictionary of metric values
        user_id: Optional user ID
        session_id: Optional session ID (a random UUID is generated if omitted)
        metadata: Optional metadata
        processing_time_ms: Optional processing time to store with the result
            (previously always stored as 0.0)
    Returns:
        Result ID
    """
    result_id = str(uuid.uuid4())
    timestamp = datetime.now()
    if session_id is None:
        session_id = str(uuid.uuid4())
    if metadata is None:
        metadata = {}

    # Store result in database; close the connection even if the insert fails.
    conn = sqlite3.connect(self.db_path)
    try:
        conn.execute("""
            INSERT INTO test_results
            (result_id, experiment_id, variant_id, user_id, session_id, timestamp, metrics, metadata, processing_time_ms)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            result_id,
            experiment_id,
            variant_id,
            user_id,
            session_id,
            # Same text the (now deprecated) default datetime adapter produced.
            timestamp.isoformat(sep=" "),
            json.dumps(metrics),
            json.dumps(metadata),
            processing_time_ms,
        ))
        conn.commit()
    finally:
        conn.close()

    # Update bandit strategy if applicable
    strategy = self.bandit_strategies.get(experiment_id)
    config = self.active_experiments.get(experiment_id)
    if strategy is not None and config is not None:
        # Calculate reward based on primary metric
        primary_metric = config.metrics[0]  # Assume first metric is primary
        if primary_metric.name in metrics:
            reward = metrics[primary_metric.name]
            target = primary_metric.target_value
            if primary_metric.higher_is_better:
                # Normalize reward to at most 1.0 when a target is given
                if target:
                    reward = min(reward / target, 1.0)
            elif target:
                # Lower is better: reaching or beating the target earns the
                # full reward, larger values earn proportionally less.
                reward = min(target / reward, 1.0) if reward > 0 else 1.0
            else:
                # Lower is better without a target: negate so that the bandit,
                # which maximises reward, prefers smaller metric values.
                reward = -reward
            strategy.update_reward(variant_id, reward)

    logger.debug(f"Recorded result {result_id} for experiment {experiment_id}, variant {variant_id}")
    return result_id
def run_statistical_analysis(self, experiment_id: str) -> ExperimentSummary:
    """
    Run comprehensive statistical analysis on experiment results

    Args:
        experiment_id: Experiment ID
    Returns:
        Experiment summary with statistical analysis (an empty summary when
        no results have been recorded)
    Raises:
        ValueError: if results exist but no configuration is stored for the
            experiment
    """
    # Load experiment data
    config = self._load_experiment_config(experiment_id)
    results_df = self._load_experiment_results(experiment_id)
    if results_df.empty:
        logger.warning(f"No results found for experiment {experiment_id}")
        return self._create_empty_summary(experiment_id)
    if config is None:
        # Previously this surfaced later as an AttributeError on None.
        raise ValueError(f"No configuration stored for experiment {experiment_id}")

    # Calculate basic statistics
    variant_samples = results_df['variant_id'].value_counts().to_dict()
    total_samples = len(results_df)

    # Calculate variant metrics
    variant_metrics = {
        variant_id: self._calculate_variant_metrics(
            results_df[results_df['variant_id'] == variant_id], config.metrics
        )
        for variant_id in variant_samples
    }

    # Run statistical tests
    statistical_tests = self._run_statistical_tests(results_df, config.metrics)
    # Calculate confidence intervals
    confidence_intervals = self._calculate_confidence_intervals(results_df, config.metrics, config.confidence_level)
    # Determine best variant
    best_variant = self._determine_best_variant(variant_metrics, config.metrics)
    # Generate recommendation
    recommendation = self._generate_recommendation(statistical_tests, variant_metrics, config)

    # Get bandit performance
    bandit_performance = {}
    if experiment_id in self.bandit_strategies:
        bandit_performance = self.bandit_strategies[experiment_id].get_variant_stats()

    # ExperimentConfig declares no start_time, so the old fallback was always
    # "now" and the measured duration was always zero. Use the earliest
    # recorded result as the start of the experiment instead.
    start_time = getattr(config, 'start_time', None)
    if start_time is None:
        earliest = results_df['timestamp'].min()
        start_time = datetime.now() if pd.isna(earliest) else earliest.to_pydatetime()

    # Create summary
    return ExperimentSummary(
        experiment_id=experiment_id,
        status=self._get_experiment_status(experiment_id),
        start_time=start_time,
        end_time=None,
        total_samples=total_samples,
        variant_samples=variant_samples,
        variant_metrics=variant_metrics,
        statistical_tests=statistical_tests,
        confidence_intervals=confidence_intervals,
        best_variant=best_variant,
        recommendation=recommendation,
        bandit_performance=bandit_performance
    )
def stop_experiment(self, experiment_id: str, reason: str = "Manual stop") -> bool:
    """Mark an active experiment COMPLETED and drop its in-memory state.

    Args:
        experiment_id: Experiment ID
        reason: Free-text reason, written to the log only

    Returns:
        True when the experiment was stopped, False when it was not active.
    """
    is_active = experiment_id in self.active_experiments
    if not is_active:
        logger.warning(f"Experiment {experiment_id} is not active")
        return False

    # Persist the new status first, then forget the experiment and its
    # bandit strategy (if one was created).
    self._update_experiment_status(experiment_id, ExperimentStatus.COMPLETED)
    self.active_experiments.pop(experiment_id)
    self.bandit_strategies.pop(experiment_id, None)

    logger.info(f"Stopped experiment {experiment_id}. Reason: {reason}")
    return True
def auto_check_experiments(self):
    """Analyse every active experiment and stop those meeting a stop condition."""
    # Snapshot the IDs: stop_experiment removes entries while we iterate.
    active_ids = tuple(self.active_experiments)
    for exp_id in active_ids:
        exp_config = self.active_experiments[exp_id]
        analysis = self.run_statistical_analysis(exp_id)
        stop_now, why = self._check_auto_stop_conditions(analysis, exp_config)
        if not stop_now:
            continue
        self.stop_experiment(exp_id, why)
def generate_report(self, experiment_id: str, output_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Build the full report for one experiment.

    Args:
        experiment_id: Experiment ID
        output_path: When given (and non-empty), the report is also passed to
            ``_save_report`` with this path
    Returns:
        Report dictionary with the summary, detailed analysis, visualizations,
        recommendations and next steps
    """
    summary = self.run_statistical_analysis(experiment_id)
    charts = self._create_visualizations(experiment_id, summary)
    analysis = self._generate_detailed_analysis(summary)
    advice = self._generate_detailed_recommendations(summary)
    follow_ups = self._suggest_next_steps(summary)

    report = {
        'experiment_summary': summary,
        'detailed_analysis': analysis,
        'visualizations': charts,
        'recommendations': advice,
        'next_steps': follow_ups,
    }

    if output_path:
        self._save_report(report, output_path)
    return report
| def _validate_experiment_config(self, config: ExperimentConfig): | |
| """Validate experiment configuration""" | |
| # Check traffic split sums to 1.0 | |
| total_traffic = sum(config.traffic_split.values()) | |
| if abs(total_traffic - 1.0) > 0.001: | |
| raise ValueError(f"Traffic split must sum to 1.0, got {total_traffic}") | |
| # Check all variants have traffic allocation | |
| variant_ids = {v.variant_id for v in config.variants} | |
| traffic_variants = set(config.traffic_split.keys()) | |
| if variant_ids != traffic_variants: | |
| raise ValueError("Variant IDs in variants and traffic_split must match") | |
| # Validate metrics | |
| if not config.metrics: | |
| raise ValueError("At least one metric must be defined") | |
| def _load_experiment_config(self, experiment_id: str) -> Optional[ExperimentConfig]: | |
| """Load experiment configuration from database""" | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| cursor.execute("SELECT config FROM experiments WHERE experiment_id = ?", (experiment_id,)) | |
| result = cursor.fetchone() | |
| conn.close() | |
| if result: | |
| config_dict = json.loads(result[0]) | |
| # Would need to reconstruct ExperimentConfig from dict | |
| # This is simplified for the example | |
| return config_dict | |
| return None | |
| def _load_experiment_results(self, experiment_id: str) -> pd.DataFrame: | |
| """Load experiment results from database""" | |
| conn = sqlite3.connect(self.db_path) | |
| query = """ | |
| SELECT result_id, variant_id, user_id, session_id, timestamp, | |
| metrics, metadata, processing_time_ms | |
| FROM test_results | |
| WHERE experiment_id = ? | |
| ORDER BY timestamp | |
| """ | |
| df = pd.read_sql_query(query, conn, params=(experiment_id,)) | |
| conn.close() | |
| # Parse JSON columns | |
| if not df.empty: | |
| df['metrics'] = df['metrics'].apply(json.loads) | |
| df['metadata'] = df['metadata'].apply(json.loads) | |
| df['timestamp'] = pd.to_datetime(df['timestamp']) | |
| return df | |
| def _calculate_variant_metrics(self, variant_data: pd.DataFrame, metrics: List[ExperimentMetric]) -> Dict[str, float]: | |
| """Calculate metrics for a variant""" | |
| result = {} | |
| for metric in metrics: | |
| metric_values = [] | |
| for _, row in variant_data.iterrows(): | |
| if metric.name in row['metrics']: | |
| metric_values.append(row['metrics'][metric.name]) | |
| if metric_values: | |
| result[f"{metric.name}_mean"] = np.mean(metric_values) | |
| result[f"{metric.name}_std"] = np.std(metric_values) | |
| result[f"{metric.name}_count"] = len(metric_values) | |
| result[f"{metric.name}_median"] = np.median(metric_values) | |
| if metric.metric_type == "accuracy": | |
| result[f"{metric.name}_min"] = np.min(metric_values) | |
| result[f"{metric.name}_max"] = np.max(metric_values) | |
| return result | |
def _run_statistical_tests(
    self,
    results_df: pd.DataFrame,
    metrics: List[ExperimentMetric],
    confidence_level: float = 0.95
) -> Dict[str, List[StatisticalTestResult]]:
    """Run pairwise statistical tests comparing variants on every metric.

    A pair is only tested when both variants have more than 10 values for the
    metric. Metrics of type "accuracy" use an independent two-sample t-test;
    all other metric types use a two-sided Mann-Whitney U test.

    Args:
        results_df: Result rows with ``variant_id`` and ``metrics`` columns.
        metrics: Metric definitions to compare.
        confidence_level: Significance is declared when
            ``p_value < 1 - confidence_level`` (default 0.95, the value that
            was previously hard-coded).

    Returns:
        Dict mapping each metric name to the list of pairwise test results
        (possibly empty).
    """
    tests = {}
    variants = results_df['variant_id'].unique()
    alpha = 1 - confidence_level
    variant_column = results_df['variant_id'].tolist()
    metrics_column = results_df['metrics'].tolist()

    for metric in metrics:
        tests[metric.name] = []

        # Collect the values of this metric per variant in one pass instead of
        # re-scanning the whole frame for every pair of variants.
        samples = {variant: [] for variant in variants}
        for variant, recorded in zip(variant_column, metrics_column):
            if metric.name in recorded:
                samples[variant].append(recorded[metric.name])

        # Pairwise comparisons between variants
        for i, variant_a in enumerate(variants):
            for variant_b in variants[i + 1:]:
                values_a = samples[variant_a]
                values_b = samples[variant_b]
                n_a, n_b = len(values_a), len(values_b)
                if n_a <= 10 or n_b <= 10:  # Minimum sample size
                    continue

                # Choose appropriate test based on metric type
                if metric.metric_type == "accuracy":
                    statistic, p_value = ttest_ind(values_a, values_b)
                    test_name = "t-test"
                else:
                    # Use Mann-Whitney U test for non-parametric data
                    statistic, p_value = mannwhitneyu(values_a, values_b, alternative='two-sided')
                    test_name = "Mann-Whitney U"

                # Effect size (Cohen's d). The pooled variance weights each
                # sample variance by n - 1, so the variances must be the
                # unbiased ones (ddof=1); the population variance was used
                # before, which understated the pooled spread.
                pooled_var = ((n_a - 1) * np.var(values_a, ddof=1) +
                              (n_b - 1) * np.var(values_b, ddof=1)) / (n_a + n_b - 2)
                pooled_std = np.sqrt(pooled_var)
                if pooled_std > 0:
                    cohens_d = float((np.mean(values_a) - np.mean(values_b)) / pooled_std)
                else:
                    cohens_d = 0.0

                # Calculate confidence interval
                conf_int = self._calculate_mean_difference_ci(values_a, values_b)

                # Determine significance (a NaN p-value is never significant)
                is_significant = bool(p_value < alpha)

                # Calculate statistical power (simplified)
                power = self._calculate_statistical_power(n_a, n_b, cohens_d, alpha)

                # Generate interpretation
                interpretation = self._interpret_test_result(
                    variant_a, variant_b, statistic, p_value, cohens_d, is_significant
                )

                tests[metric.name].append(StatisticalTestResult(
                    test_name=f"{test_name}_{variant_a}_vs_{variant_b}",
                    statistic=float(statistic),
                    p_value=float(p_value),
                    confidence_interval=conf_int,
                    effect_size=cohens_d,
                    is_significant=is_significant,
                    power=power,
                    interpretation=interpretation
                ))
    return tests
| def _calculate_confidence_intervals( | |
| self, | |
| results_df: pd.DataFrame, | |
| metrics: List[ExperimentMetric], | |
| confidence_level: float | |
| ) -> Dict[str, Dict[str, Tuple[float, float]]]: | |
| """Calculate confidence intervals for metrics by variant""" | |
| intervals = {} | |
| variants = results_df['variant_id'].unique() | |
| for metric in metrics: | |
| intervals[metric.name] = {} | |
| for variant in variants: | |
| values = [] | |
| for _, row in results_df.iterrows(): | |
| if row['variant_id'] == variant and metric.name in row['metrics']: | |
| values.append(row['metrics'][metric.name]) | |
| if len(values) > 1: | |
| mean = np.mean(values) | |
| sem = stats.sem(values) # Standard error of mean | |
| h = sem * stats.t.ppf((1 + confidence_level) / 2., len(values) - 1) | |
| intervals[metric.name][variant] = (mean - h, mean + h) | |
| else: | |
| intervals[metric.name][variant] = (0.0, 0.0) | |
| return intervals | |
| def _determine_best_variant(self, variant_metrics: Dict[str, Dict[str, float]], metrics: List[ExperimentMetric]) -> Optional[str]: | |
| """Determine the best performing variant""" | |
| if not variant_metrics: | |
| return None | |
| # Use weighted scoring based on primary metric | |
| primary_metric = metrics[0] # Assume first metric is primary | |
| metric_key = f"{primary_metric.name}_mean" | |
| best_variant = None | |
| best_score = float('-inf') if primary_metric.higher_is_better else float('inf') | |
| for variant_id, variant_data in variant_metrics.items(): | |
| if metric_key in variant_data: | |
| score = variant_data[metric_key] | |
| if primary_metric.higher_is_better and score > best_score: | |
| best_score = score | |
| best_variant = variant_id | |
| elif not primary_metric.higher_is_better and score < best_score: | |
| best_score = score | |
| best_variant = variant_id | |
| return best_variant | |
| def _generate_recommendation( | |
| self, | |
| statistical_tests: Dict[str, List[StatisticalTestResult]], | |
| variant_metrics: Dict[str, Dict[str, float]], | |
| config: ExperimentConfig | |
| ) -> str: | |
| """Generate experiment recommendation""" | |
| # Count significant results | |
| significant_tests = 0 | |
| total_tests = 0 | |
| for metric_tests in statistical_tests.values(): | |
| for test in metric_tests: | |
| total_tests += 1 | |
| if test.is_significant: | |
| significant_tests += 1 | |
| if total_tests == 0: | |
| return "Insufficient data for recommendation" | |
| significance_ratio = significant_tests / total_tests | |
| if significance_ratio > 0.5: | |
| return "Strong evidence of variant differences. Recommend deploying best variant." | |
| elif significance_ratio > 0.2: | |
| return "Some evidence of variant differences. Consider extending experiment." | |
| else: | |
| return "No strong evidence of variant differences. Current variant can be maintained." | |
| def _check_auto_stop_conditions(self, summary: ExperimentSummary, config: ExperimentConfig) -> Tuple[bool, str]: | |
| """Check if experiment should be automatically stopped""" | |
| # Check minimum sample size | |
| if summary.total_samples < config.min_sample_size: | |
| return False, "" | |
| # Check duration | |
| if summary.start_time: | |
| duration = datetime.now() - summary.start_time | |
| if duration.days >= config.max_duration_days: | |
| return True, f"Maximum duration reached ({config.max_duration_days} days)" | |
| # Check for winning variant | |
| if summary.statistical_tests: | |
| # Simplified check for clear winner | |
| significant_wins = 0 | |
| total_comparisons = 0 | |
| for metric_tests in summary.statistical_tests.values(): | |
| for test in metric_tests: | |
| total_comparisons += 1 | |
| if test.is_significant and test.p_value < config.auto_stop_winning_threshold: | |
| significant_wins += 1 | |
| if total_comparisons > 0 and significant_wins / total_comparisons > 0.8: | |
| return True, "Clear winning variant detected" | |
| return False, "" | |
| def _create_visualizations(self, experiment_id: str, summary: ExperimentSummary) -> Dict[str, str]: | |
| """Create visualizations for experiment report""" | |
| visualizations = {} | |
| # This would create actual plots and return their paths/base64 encodings | |
| # For now, returning placeholder paths | |
| visualizations['metrics_comparison'] = f"metrics_comparison_{experiment_id}.png" | |
| visualizations['confidence_intervals'] = f"confidence_intervals_{experiment_id}.png" | |
| visualizations['statistical_significance'] = f"statistical_significance_{experiment_id}.png" | |
| visualizations['bandit_performance'] = f"bandit_performance_{experiment_id}.png" | |
| return visualizations | |
| def _generate_detailed_analysis(self, summary: ExperimentSummary) -> Dict[str, Any]: | |
| """Generate detailed analysis""" | |
| return { | |
| 'sample_size_analysis': self._analyze_sample_sizes(summary), | |
| 'effect_size_analysis': self._analyze_effect_sizes(summary), | |
| 'statistical_power_analysis': self._analyze_statistical_power(summary), | |
| 'practical_significance': self._analyze_practical_significance(summary) | |
| } | |
| def _generate_detailed_recommendations(self, summary: ExperimentSummary) -> List[str]: | |
| """Generate detailed recommendations""" | |
| recommendations = [] | |
| # Sample size recommendations | |
| if summary.total_samples < 1000: | |
| recommendations.append("Consider collecting more data for increased statistical power") | |
| # Effect size recommendations | |
| # ... implementation based on effect sizes | |
| # Business impact recommendations | |
| recommendations.append("Evaluate business impact beyond statistical significance") | |
| return recommendations | |
| def _suggest_next_steps(self, summary: ExperimentSummary) -> List[str]: | |
| """Suggest next steps based on results""" | |
| next_steps = [] | |
| if summary.best_variant: | |
| next_steps.append(f"Consider deploying {summary.best_variant} as the new default") | |
| next_steps.append("Monitor performance in production environment") | |
| next_steps.append("Plan follow-up experiments to optimize further") | |
| return next_steps | |
| def _calculate_mean_difference_ci(self, values_a: List[float], values_b: List[float]) -> Tuple[float, float]: | |
| """Calculate confidence interval for mean difference""" | |
| mean_a, mean_b = np.mean(values_a), np.mean(values_b) | |
| var_a, var_b = np.var(values_a, ddof=1), np.var(values_b, ddof=1) | |
| n_a, n_b = len(values_a), len(values_b) | |
| # Pooled standard error | |
| se_diff = np.sqrt(var_a / n_a + var_b / n_b) | |
| # Degrees of freedom (Welch's formula) | |
| df = (var_a / n_a + var_b / n_b) ** 2 / ((var_a / n_a) ** 2 / (n_a - 1) + (var_b / n_b) ** 2 / (n_b - 1)) | |
| # Critical value for 95% confidence | |
| t_crit = stats.t.ppf(0.975, df) | |
| mean_diff = mean_a - mean_b | |
| margin_of_error = t_crit * se_diff | |
| return (mean_diff - margin_of_error, mean_diff + margin_of_error) | |
| def _calculate_statistical_power(self, n1: int, n2: int, effect_size: float, alpha: float) -> float: | |
| """Calculate statistical power (simplified)""" | |
| # This is a simplified power calculation | |
| # In practice, you'd use more sophisticated methods | |
| total_n = n1 + n2 | |
| if total_n < 20: | |
| return 0.2 | |
| elif total_n < 100: | |
| return 0.5 | |
| elif abs(effect_size) > 0.5: | |
| return 0.8 | |
| else: | |
| return 0.6 | |
| def _interpret_test_result( | |
| self, | |
| variant_a: str, | |
| variant_b: str, | |
| statistic: float, | |
| p_value: float, | |
| effect_size: float, | |
| is_significant: bool | |
| ) -> str: | |
| """Interpret statistical test result""" | |
| if is_significant: | |
| if abs(effect_size) > 0.8: | |
| magnitude = "large" | |
| elif abs(effect_size) > 0.5: | |
| magnitude = "medium" | |
| else: | |
| magnitude = "small" | |
| direction = "outperforms" if effect_size > 0 else "underperforms" | |
| return f"{variant_a} {direction} {variant_b} with a {magnitude} effect size (p={p_value:.3f})" | |
| else: | |
| return f"No significant difference between {variant_a} and {variant_b} (p={p_value:.3f})" | |
| def _analyze_sample_sizes(self, summary: ExperimentSummary) -> Dict[str, Any]: | |
| """Analyze sample sizes""" | |
| return {"analysis": "Sample size analysis would be implemented here"} | |
| def _analyze_effect_sizes(self, summary: ExperimentSummary) -> Dict[str, Any]: | |
| """Analyze effect sizes""" | |
| return {"analysis": "Effect size analysis would be implemented here"} | |
| def _analyze_statistical_power(self, summary: ExperimentSummary) -> Dict[str, Any]: | |
| """Analyze statistical power""" | |
| return {"analysis": "Statistical power analysis would be implemented here"} | |
| def _analyze_practical_significance(self, summary: ExperimentSummary) -> Dict[str, Any]: | |
| """Analyze practical significance""" | |
| return {"analysis": "Practical significance analysis would be implemented here"} | |
def _update_experiment_status(self, experiment_id: str, status: ExperimentStatus):
    """Persist a new status for an experiment in the SQLite database.

    ``end_time`` is written on every call: the current local time when the
    new status is ``COMPLETED``, otherwise NULL.

    Args:
        experiment_id: Identifier of the experiment row to update.
        status: New status; its ``.value`` is what gets stored.

    Raises:
        sqlite3.Error: If the update fails. The transaction is rolled back
            and the connection is still closed.
    """
    # Bind the timestamp as an explicit ISO string. This is the same text
    # the implicit sqlite3 datetime adapter produced, without relying on
    # that adapter (deprecated since Python 3.12).
    end_time = (
        datetime.now().isoformat(" ")
        if status == ExperimentStatus.COMPLETED
        else None
    )

    conn = sqlite3.connect(self.db_path)
    try:
        # The connection context manager commits on success and rolls back
        # on error; it does not close the connection, hence the finally.
        with conn:
            conn.execute(
                "UPDATE experiments SET status = ?, end_time = ? WHERE experiment_id = ?",
                (status.value, end_time, experiment_id),
            )
    finally:
        conn.close()
def _get_experiment_status(self, experiment_id: str) -> ExperimentStatus:
    """Look up the stored status of an experiment.

    Args:
        experiment_id: Identifier of the experiment row to read.

    Returns:
        The stored status, or ``ExperimentStatus.DRAFT`` when no row
        matches ``experiment_id``.

    Raises:
        sqlite3.Error: If the query fails; the connection is still closed.
        ValueError: If the stored value is not a valid ``ExperimentStatus``.
    """
    conn = sqlite3.connect(self.db_path)
    try:
        row = conn.execute(
            "SELECT status FROM experiments WHERE experiment_id = ?",
            (experiment_id,),
        ).fetchone()
    finally:
        # Close even when the query raises, so the handle never leaks.
        conn.close()

    if row:
        return ExperimentStatus(row[0])
    # Unknown experiments are reported as drafts.
    return ExperimentStatus.DRAFT
def _create_empty_summary(self, experiment_id: str) -> ExperimentSummary:
    """Build a summary for an experiment that has no recorded data.

    Args:
        experiment_id: Identifier to stamp on the summary.

    Returns:
        An ``ExperimentSummary`` in ``DRAFT`` status, started "now", with
        zero samples, empty result containers, no best variant and the
        recommendation text "No data available".
    """
    # Every result container starts out as its own fresh empty dict.
    empty_fields = {
        name: {}
        for name in (
            "variant_samples",
            "variant_metrics",
            "statistical_tests",
            "confidence_intervals",
            "bandit_performance",
        )
    }
    return ExperimentSummary(
        experiment_id=experiment_id,
        status=ExperimentStatus.DRAFT,
        start_time=datetime.now(),
        end_time=None,
        total_samples=0,
        best_variant=None,
        recommendation="No data available",
        **empty_fields,
    )
def _save_report(self, report: Dict[str, Any], output_path: str):
    """Write an experiment report to disk as indented JSON.

    Missing parent directories are created first. Values that ``json``
    cannot serialise natively are converted with ``str``.

    Args:
        report: Report payload to serialise.
        output_path: Destination file path; overwritten if it exists.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(destination, 'w') as handle:
        json.dump(report, handle, indent=2, default=str)
    logger.info(f"Report saved to {output_path}")
if __name__ == "__main__":
    # Demo run: compare two placeholder models on simulated metrics,
    # from experiment creation through analysis and report generation.
    logging.basicConfig(level=logging.INFO)

    framework = ABTestingFramework()

    # Placeholder networks standing in for real models.
    model_a = nn.Sequential(nn.Linear(100, 1))
    model_b = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 1))

    # Describe the experiment: two variants, two metrics, an even split.
    demo_variants = [
        ModelVariant(variant_id="baseline", model=model_a, description="Baseline model"),
        ModelVariant(variant_id="enhanced", model=model_b, description="Enhanced model with hidden layer"),
    ]
    demo_metrics = [
        ExperimentMetric(name="accuracy", description="Detection accuracy", metric_type="accuracy", higher_is_better=True),
        ExperimentMetric(name="latency", description="Inference latency", metric_type="latency", higher_is_better=False),
    ]
    config = ExperimentConfig(
        experiment_id="morph_detection_comparison_001",
        name="MorphGuard Model Comparison",
        description="Comparing baseline vs enhanced model for morph detection",
        variants=demo_variants,
        metrics=demo_metrics,
        traffic_split={"baseline": 0.5, "enhanced": 0.5},
        bandit_algorithm=BanditAlgorithm.THOMPSON_SAMPLING,
    )

    # Register the experiment and open it for traffic.
    experiment_id = framework.create_experiment(config)
    framework.start_experiment(experiment_id)

    # Simulate 1000 requests; the enhanced model is slightly more accurate
    # and slightly faster on average.
    for i in range(1000):
        variant = framework.assign_variant(experiment_id)
        if variant == "enhanced":
            accuracy_mean, latency_mean = 0.92, 45
        else:
            accuracy_mean, latency_mean = 0.89, 50
        # Draw accuracy before latency to keep the random stream order.
        accuracy = np.random.normal(accuracy_mean, 0.02)
        latency = np.random.normal(latency_mean, 5)
        framework.record_result(
            experiment_id=experiment_id,
            variant_id=variant,
            metrics={"accuracy": accuracy, "latency": latency},
            user_id=f"user_{i}",
            session_id=f"session_{i}",
        )

    # Analyse the collected results and print the headline findings.
    summary = framework.run_statistical_analysis(experiment_id)
    print(f"Best variant: {summary.best_variant}")
    print(f"Recommendation: {summary.recommendation}")

    # Produce the full report; its return value is not needed here.
    framework.generate_report(experiment_id)
    print("Report generated successfully")