"""
A/B Testing Framework for Cognexa ML Service

Provides a complete A/B testing system for:
- Feature experiments (UI variants, algorithm variants)
- Model comparison (comparing prediction model versions)
- Notification strategies (channel, timing, content)
- Recommendation variants (algorithm A vs B)

Implements:
- User bucketing (consistent hash-based assignment)
- Statistical significance testing
- Early stopping (O'Brien-Fleming bounds)
- Multiple comparison correction (Bonferroni)
- Experiment lifecycle management
"""

from __future__ import annotations

import hashlib
import json
import logging
import uuid
from dataclasses import dataclass, asdict, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from scipy import stats as scipy_stats

from statistical_analysis import compare_groups, _power_analyzer

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------


class ExperimentStatus(str, Enum):
    DRAFT = "draft"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    STOPPED_EARLY = "stopped_early"


class ExperimentType(str, Enum):
    FEATURE_FLAG = "feature_flag"  # on/off split
    MODEL_COMPARISON = "model_comparison"
    NOTIFICATION = "notification"
    UI_VARIANT = "ui_variant"
    RECOMMENDATION = "recommendation"


class AllocationStrategy(str, Enum):
    HASH = "hash"  # deterministic, consistent
    RANDOM = "random"


# ---------------------------------------------------------------------------
# Data Structures
# ---------------------------------------------------------------------------


@dataclass
class Variant:
    """A single variant (arm) in an experiment."""

    variant_id: str
    name: str  # e.g. "control", "treatment_a"
    description: str
    allocation_percent: float  # 0-100, used as a relative weight when bucketing
    config: Dict[str, Any] = field(default_factory=dict)
    is_control: bool = False


@dataclass
class Experiment:
    """A/B experiment definition."""

    experiment_id: str
    name: str
    description: str
    experiment_type: str  # ExperimentType
    status: str  # ExperimentStatus
    variants: List[Variant]
    primary_metric: str  # e.g. "task_completion_rate"
    secondary_metrics: List[str]
    allocation_strategy: str  # AllocationStrategy
    traffic_percent: float  # 0-100 fraction of users to include
    min_sample_size: int  # required per variant before analysis
    alpha: float  # significance level (family-wise, before Bonferroni)
    target_power: float
    created_at: str
    started_at: Optional[str]
    ended_at: Optional[str]
    owner: str
    tags: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ExperimentAssignment:
    """Records which variant a user is assigned to."""

    assignment_id: str
    experiment_id: str
    user_id: str
    variant_id: str
    variant_name: str
    assigned_at: str
    is_exposed: bool = True  # first exposure recorded


@dataclass
class MetricObservation:
    """A single metric measurement from an experiment participant."""

    observation_id: str
    experiment_id: str
    user_id: str
    variant_id: str
    metric_name: str
    value: float
    recorded_at: str


@dataclass
class VariantResult:
    """Statistical result for a single variant."""

    variant_id: str
    variant_name: str
    n_observations: int
    mean: float
    std: float
    median: float
    is_control: bool


@dataclass
class ExperimentResult:
    """Full statistical analysis result for an experiment."""

    experiment_id: str
    experiment_name: str
    primary_metric: str
    status: str
    control_result: VariantResult
    treatment_results: List[VariantResult]
    # Significance
    p_value: float
    is_significant: bool
    alpha: float  # per-comparison alpha actually used (Bonferroni-corrected)
    # Effect size
    effect_size: float
    relative_uplift: float  # % improvement over control
    # Power
    current_power: float
    recommended_sample_size: int
    is_adequately_powered: bool
    # Decision
    winner: Optional[str]  # variant_name of winner, None if no winner yet
    recommendation: str
    analyzed_at: str


@dataclass
class EarlyStoppingDecision:
    """Result of early stopping evaluation."""

    should_stop: bool
    reason: str
    current_p_value: float
    obrien_fleming_threshold: float
    current_n: int
    planned_n: int
    interim_fraction: float


# ---------------------------------------------------------------------------
# User Bucketing
# ---------------------------------------------------------------------------


class UserBucketizer:
    """Consistent hash-based user assignment to experiment variants."""

    def assign_variant(
        self,
        user_id: str,
        experiment_id: str,
        variants: List[Variant],
        traffic_percent: float = 100.0,
        strategy: str = AllocationStrategy.HASH,
    ) -> Optional[Variant]:
        """
        Assign a user to a variant.

        With the HASH strategy the bucket is derived from
        md5("<user_id>:<experiment_id>"), so a user always lands in the same
        bucket for a given experiment. The RANDOM strategy draws a fresh
        uniform bucket on every call.

        Args:
            user_id: User being assigned.
            experiment_id: Experiment the assignment is for.
            variants: Candidate variants; their ``allocation_percent`` values
                are treated as relative weights within the traffic slice.
            traffic_percent: Percentage (0-100) of users included at all.
            strategy: An ``AllocationStrategy`` value.

        Returns:
            The chosen variant, or None if the user falls outside the
            experiment's traffic slice or ``variants`` is empty.
        """
        if not variants:
            return None

        if strategy == AllocationStrategy.RANDOM:
            bucket = np.random.uniform(0, 100)
        else:
            # Deterministic: hash(user_id + experiment_id)
            key = f"{user_id}:{experiment_id}"
            h = int(hashlib.md5(key.encode()).hexdigest(), 16)
            bucket = (h % 10000) / 100.0  # 0.00 - 99.99

        # Check traffic inclusion
        if bucket >= traffic_percent:
            return None

        # Normalize allocation within the traffic slice. If the weights sum to
        # zero (or less), fall back to an even split instead of dividing by 0.
        weights = [v.allocation_percent for v in variants]
        total_alloc = sum(weights)
        if total_alloc <= 0:
            weights = [1.0] * len(variants)
            total_alloc = float(len(variants))

        cumulative = 0.0
        for variant, weight in zip(variants, weights):
            cumulative += (weight / total_alloc) * traffic_percent
            if bucket < cumulative:
                return variant
        return variants[-1]  # fallback for floating-point rounding at the edge


# ---------------------------------------------------------------------------
# Experiment Storage
# ---------------------------------------------------------------------------


class ExperimentStore:
    """Filesystem-backed store for experiments, assignments and observations.

    All records are held in memory; each save rewrites the corresponding JSON
    file in full.
    """

    def __init__(self, data_dir: str = "data/ab_testing"):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.experiments_file = self.data_dir / "experiments.json"
        self.assignments_file = self.data_dir / "assignments.json"
        self.observations_file = self.data_dir / "observations.json"

        self._experiments: Dict[str, Experiment] = self._load_experiments()
        self._assignments: List[ExperimentAssignment] = self._load_assignments()
        self._observations: List[MetricObservation] = self._load_observations()

    @staticmethod
    def _write_json(path: Path, payload: Any, indent: Optional[int] = None) -> None:
        """Serialize ``payload`` to ``path`` atomically.

        Data is written to a sibling ``.tmp`` file which then replaces the
        target. If serialization fails part-way, the previous file stays
        intact (a truncated file would make the loaders below return empty
        data, and the next save would then wipe every stored record).
        """
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w") as f:
            json.dump(payload, f, indent=indent)
        tmp_path.replace(path)

    # -- Experiments ----------------------------------------------------------

    def _load_experiments(self) -> Dict[str, Experiment]:
        if not self.experiments_file.exists():
            return {}
        try:
            with open(self.experiments_file) as f:
                data = json.load(f)
            return {
                eid: Experiment(
                    **{k: v for k, v in exp.items() if k != "variants"},
                    variants=[Variant(**v) for v in exp.get("variants", [])],
                )
                for eid, exp in data.items()
            }
        except Exception as e:
            logger.warning("Could not load experiments: %s", e)
            return {}

    def save_experiment(self, experiment: Experiment) -> None:
        self._experiments[experiment.experiment_id] = experiment
        self._write_json(
            self.experiments_file,
            {eid: asdict(exp) for eid, exp in self._experiments.items()},
            indent=2,
        )

    def get_experiment(self, experiment_id: str) -> Optional[Experiment]:
        return self._experiments.get(experiment_id)

    def list_experiments(self, status: Optional[str] = None) -> List[Experiment]:
        """Return experiments (optionally filtered by status), newest first."""
        exps = list(self._experiments.values())
        if status:
            exps = [e for e in exps if e.status == status]
        return sorted(exps, key=lambda e: e.created_at, reverse=True)

    # -- Assignments ----------------------------------------------------------

    def _load_assignments(self) -> List[ExperimentAssignment]:
        if not self.assignments_file.exists():
            return []
        try:
            with open(self.assignments_file) as f:
                return [ExperimentAssignment(**a) for a in json.load(f)]
        except Exception as e:
            logger.warning("Could not load assignments: %s", e)
            return []

    def save_assignment(self, assignment: ExperimentAssignment) -> None:
        self._assignments.append(assignment)
        # Keep only the most recent 100k assignments.
        if len(self._assignments) > 100000:
            self._assignments = self._assignments[-100000:]
        self._write_json(
            self.assignments_file, [asdict(a) for a in self._assignments]
        )

    def get_user_assignment(
        self, user_id: str, experiment_id: str
    ) -> Optional[ExperimentAssignment]:
        """Return the most recent assignment of ``user_id`` in an experiment."""
        for a in reversed(self._assignments):
            if a.user_id == user_id and a.experiment_id == experiment_id:
                return a
        return None

    # -- Observations ---------------------------------------------------------

    def _load_observations(self) -> List[MetricObservation]:
        if not self.observations_file.exists():
            return []
        try:
            with open(self.observations_file) as f:
                return [MetricObservation(**o) for o in json.load(f)]
        except Exception as e:
            logger.warning("Could not load observations: %s", e)
            return []

    def save_observation(self, obs: MetricObservation) -> None:
        self._observations.append(obs)
        # Keep only the most recent 500k observations.
        if len(self._observations) > 500000:
            self._observations = self._observations[-500000:]
        self._write_json(
            self.observations_file, [asdict(o) for o in self._observations]
        )

    def get_observations(
        self,
        experiment_id: str,
        metric_name: str,
    ) -> Dict[str, List[float]]:
        """Return dict: variant_id -> list of metric values."""
        result: Dict[str, List[float]] = {}
        for obs in self._observations:
            if obs.experiment_id == experiment_id and obs.metric_name == metric_name:
                result.setdefault(obs.variant_id, []).append(obs.value)
        return result


# ---------------------------------------------------------------------------
# Statistical Analyzer
# ---------------------------------------------------------------------------


def _smallest_arm_size(
    variants: List[Variant], obs_by_variant: Dict[str, List[float]]
) -> int:
    """Return the observation count of the least-populated variant (0 if none).

    The power of a two-sample comparison is limited by its smaller group, so
    the smallest arm, not the average arm, decides whether there is enough data.
    """
    if not variants:
        return 0
    return min(len(obs_by_variant.get(v.variant_id, [])) for v in variants)


class ExperimentAnalyzer:
    """Analyzes experiment results and computes statistical significance."""

    def analyze(
        self,
        experiment: Experiment,
        observations: Dict[str, Dict[str, List[float]]],  # metric -> variant_id -> values
    ) -> ExperimentResult:
        """Compare the control against the best-performing treatment.

        "Best" means the treatment with the highest mean on the primary
        metric. Because the best of k treatments is picked before testing,
        significance uses a Bonferroni-corrected alpha (``experiment.alpha / k``).
        A winner is declared only when the result is significant and every arm
        has reached both the power-analysis sample size and
        ``experiment.min_sample_size``.
        """
        primary = experiment.primary_metric
        obs_by_variant = observations.get(primary, {})

        control_variant = next(
            (v for v in experiment.variants if v.is_control), experiment.variants[0]
        )
        control_values = obs_by_variant.get(control_variant.variant_id, [])
        # Exclude the control by identity as well as by flag: when no variant is
        # flagged, the fallback control (variants[0]) must not also be a treatment.
        treatment_variants = [
            v
            for v in experiment.variants
            if v is not control_variant and not v.is_control
        ]

        def _summarize(v: Variant) -> VariantResult:
            vals = obs_by_variant.get(v.variant_id, [])
            return VariantResult(
                variant_id=v.variant_id,
                variant_name=v.name,
                n_observations=len(vals),
                mean=round(float(np.mean(vals)), 4) if vals else 0.0,
                std=round(float(np.std(vals, ddof=1)), 4) if len(vals) > 1 else 0.0,
                median=round(float(np.median(vals)), 4) if vals else 0.0,
                is_control=v.is_control,
            )

        ctrl_result = _summarize(control_variant)
        treat_results = [_summarize(v) for v in treatment_variants]

        # Primary comparison: control vs best treatment
        best_treat = max(treat_results, key=lambda r: r.mean) if treat_results else None
        best_values = obs_by_variant.get(best_treat.variant_id, []) if best_treat else []

        # Bonferroni correction for selecting the best of k treatments.
        alpha = experiment.alpha / max(1, len(treatment_variants))

        p_value = 1.0
        effect_size = 0.0
        if len(control_values) >= 2 and len(best_values) >= 2:
            test_result = compare_groups(control_values, best_values, test="auto", alpha=alpha)
            p_value = test_result["p_value"]
            effect_size = test_result["effect_size"]

        is_significant = p_value < alpha

        # Divide by |control mean| so the sign of the uplift reflects the
        # direction of the change even for metrics whose mean is negative.
        relative_uplift = 0.0
        if best_treat and ctrl_result.mean != 0:
            relative_uplift = (
                (best_treat.mean - ctrl_result.mean) / abs(ctrl_result.mean) * 100
            )

        # Power
        per_variant_n = _smallest_arm_size(experiment.variants, obs_by_variant)
        power_result = _power_analyzer.compute_sample_size(
            effect_size=max(0.01, abs(effect_size)),
            alpha=alpha,
            power=experiment.target_power,
        )
        # Never consider the experiment powered before the configured minimum.
        required_n = max(power_result.required_sample_size, experiment.min_sample_size)
        is_powered = per_variant_n >= required_n

        # Winner
        winner = None
        if is_significant and is_powered and best_treat and relative_uplift > 0:
            winner = best_treat.variant_name
        elif is_significant and is_powered and best_treat and relative_uplift < 0:
            winner = control_variant.name

        recommendation = self._recommendation(
            is_significant, is_powered, winner, per_variant_n, required_n, relative_uplift
        )

        return ExperimentResult(
            experiment_id=experiment.experiment_id,
            experiment_name=experiment.name,
            primary_metric=primary,
            status=experiment.status,
            control_result=ctrl_result,
            treatment_results=treat_results,
            p_value=round(p_value, 6),
            is_significant=is_significant,
            alpha=alpha,
            effect_size=round(effect_size, 4),
            relative_uplift=round(relative_uplift, 2),
            current_power=round(power_result.current_power, 4),
            recommended_sample_size=required_n,
            is_adequately_powered=is_powered,
            winner=winner,
            recommendation=recommendation,
            analyzed_at=datetime.utcnow().isoformat(),
        )

    def check_early_stopping(
        self,
        current_p: float,
        current_n: int,
        planned_n: int,
        alpha: float = 0.05,
    ) -> EarlyStoppingDecision:
        """O'Brien-Fleming group-sequential boundary for early stopping.

        At information fraction t = current_n / planned_n the nominal z
        boundary is z_{alpha/2} / sqrt(t); the equivalent two-sided p-value
        threshold is 2 * (1 - Phi(z_{alpha/2} / sqrt(t))). Early looks
        therefore need a far smaller p-value than the final analysis.
        """
        if planned_n <= 0 or current_n <= 0:
            return EarlyStoppingDecision(False, "Insufficient data", 1.0, alpha, 0, planned_n, 0.0)

        fraction = min(1.0, current_n / planned_n)

        z_alpha = scipy_stats.norm.ppf(1 - alpha / 2)
        if fraction < 0.01:
            threshold = 1e-6  # essentially never stop very early
        else:
            obf_z = z_alpha / np.sqrt(fraction)
            threshold = 2 * scipy_stats.norm.sf(obf_z)  # two-tailed p-value threshold

        should_stop = current_p < threshold
        reason = (
            f"Early stopping: p={current_p:.4f} < OBF threshold={threshold:.4f}"
            if should_stop
            else f"Continue: p={current_p:.4f} >= OBF threshold={threshold:.4f} at {fraction:.0%} interim"
        )

        return EarlyStoppingDecision(
            should_stop=should_stop,
            reason=reason,
            current_p_value=round(current_p, 6),
            obrien_fleming_threshold=round(threshold, 6),
            current_n=current_n,
            planned_n=planned_n,
            interim_fraction=round(fraction, 4),
        )

    def _recommendation(
        self,
        significant: bool,
        powered: bool,
        winner: Optional[str],
        current_n: int,
        required_n: int,
        uplift: float,
    ) -> str:
        """Build a human-readable next step from the analysis flags."""
        if not powered:
            return (
                f"Continue collecting data. Need {required_n} per variant, "
                f"currently at {current_n}."
            )
        if not significant:
            return "No significant difference. Consider continuing or stopping (null hypothesis)."
        if winner:
            return (
                f"Ship variant '{winner}' - statistically significant "
                f"({'↑' if uplift > 0 else '↓'}{abs(uplift):.1f}% vs control)."
            )
        return "Re-examine experiment design."


# ---------------------------------------------------------------------------
# A/B Testing Manager (main API)
# ---------------------------------------------------------------------------


class ABTestingManager:
    """High-level manager for A/B experiments."""

    def __init__(self):
        self.store = ExperimentStore()
        self.bucketizer = UserBucketizer()
        self.analyzer = ExperimentAnalyzer()

    def create_experiment(
        self,
        name: str,
        description: str,
        primary_metric: str,
        variants: List[Dict[str, Any]],
        experiment_type: str = ExperimentType.FEATURE_FLAG,
        secondary_metrics: Optional[List[str]] = None,
        traffic_percent: float = 100.0,
        min_sample_size: int = 100,
        alpha: float = 0.05,
        target_power: float = 0.80,
        owner: str = "system",
        tags: Optional[List[str]] = None,
    ) -> Experiment:
        """Create and persist a new experiment in DRAFT status.

        Each entry in ``variants`` needs a ``"name"``; ``description``,
        ``allocation_percent`` (default: even split), ``config`` and
        ``is_control`` are optional. If no variant is flagged as control, the
        first one becomes the control.

        Raises:
            ValueError: if ``variants`` is empty.
        """
        if not variants:
            raise ValueError("An experiment needs at least one variant")

        experiment_id = str(uuid.uuid4())
        parsed_variants = [
            Variant(
                variant_id=str(uuid.uuid4()),
                name=v["name"],
                description=v.get("description", ""),
                allocation_percent=v.get("allocation_percent", 100.0 / len(variants)),
                config=v.get("config", {}),
                is_control=v.get("is_control", False),
            )
            for v in variants
        ]
        # If no control marked, mark first as control
        if not any(v.is_control for v in parsed_variants):
            parsed_variants[0].is_control = True

        exp = Experiment(
            experiment_id=experiment_id,
            name=name,
            description=description,
            experiment_type=experiment_type,
            status=ExperimentStatus.DRAFT,
            variants=parsed_variants,
            primary_metric=primary_metric,
            secondary_metrics=secondary_metrics or [],
            allocation_strategy=AllocationStrategy.HASH,
            traffic_percent=traffic_percent,
            min_sample_size=min_sample_size,
            alpha=alpha,
            target_power=target_power,
            created_at=datetime.utcnow().isoformat(),
            started_at=None,
            ended_at=None,
            owner=owner,
            tags=tags or [],
        )
        self.store.save_experiment(exp)
        logger.info("Created experiment %s (%s)", name, experiment_id)
        return exp

    def start_experiment(self, experiment_id: str) -> bool:
        """Mark an experiment RUNNING. Returns False if it does not exist."""
        exp = self.store.get_experiment(experiment_id)
        if not exp:
            return False
        exp.status = ExperimentStatus.RUNNING
        exp.started_at = datetime.utcnow().isoformat()
        self.store.save_experiment(exp)
        return True

    def stop_experiment(self, experiment_id: str, early: bool = False) -> bool:
        """Mark an experiment COMPLETED (or STOPPED_EARLY when ``early``)."""
        exp = self.store.get_experiment(experiment_id)
        if not exp:
            return False
        exp.status = ExperimentStatus.STOPPED_EARLY if early else ExperimentStatus.COMPLETED
        exp.ended_at = datetime.utcnow().isoformat()
        self.store.save_experiment(exp)
        return True

    def get_variant_for_user(
        self, user_id: str, experiment_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Assign (or retrieve cached assignment) for a user in an experiment.
        Returns dict with variant info or None if not in experiment.
        Only RUNNING experiments hand out variants.
        """
        exp = self.store.get_experiment(experiment_id)
        if not exp or exp.status != ExperimentStatus.RUNNING:
            return None

        # Check existing assignment
        existing = self.store.get_user_assignment(user_id, experiment_id)
        if existing:
            variant = next((v for v in exp.variants if v.variant_id == existing.variant_id), None)
            if variant:
                return {"variant_id": variant.variant_id, "variant_name": variant.name,
                        "config": variant.config, "is_control": variant.is_control}

        # New assignment
        variant = self.bucketizer.assign_variant(
            user_id, experiment_id, exp.variants, exp.traffic_percent, exp.allocation_strategy
        )
        if variant is None:
            return None

        assignment = ExperimentAssignment(
            assignment_id=str(uuid.uuid4()),
            experiment_id=experiment_id,
            user_id=user_id,
            variant_id=variant.variant_id,
            variant_name=variant.name,
            assigned_at=datetime.utcnow().isoformat(),
        )
        self.store.save_assignment(assignment)
        return {"variant_id": variant.variant_id, "variant_name": variant.name,
                "config": variant.config, "is_control": variant.is_control}

    def record_metric(
        self,
        experiment_id: str,
        user_id: str,
        metric_name: str,
        value: float,
    ) -> bool:
        """Record a metric observation for a user in an experiment.

        Returns False (and records nothing) if the user has no assignment.
        """
        assignment = self.store.get_user_assignment(user_id, experiment_id)
        if not assignment:
            return False
        obs = MetricObservation(
            observation_id=str(uuid.uuid4()),
            experiment_id=experiment_id,
            user_id=user_id,
            variant_id=assignment.variant_id,
            metric_name=metric_name,
            value=value,
            recorded_at=datetime.utcnow().isoformat(),
        )
        self.store.save_observation(obs)
        return True

    def analyze_experiment(self, experiment_id: str) -> Optional[Dict[str, Any]]:
        """Run statistical analysis on an experiment.

        Returns the ``ExperimentResult`` fields as a dict plus an
        ``"early_stopping"`` entry, or None if the experiment does not exist.
        Early stopping uses the smallest arm size, with ``min_sample_size`` as
        the planned n, and the same (Bonferroni-corrected) alpha as the
        significance test.
        """
        exp = self.store.get_experiment(experiment_id)
        if not exp:
            return None

        all_metrics = {exp.primary_metric} | set(exp.secondary_metrics)
        observations: Dict[str, Dict[str, List[float]]] = {}
        for metric in all_metrics:
            observations[metric] = self.store.get_observations(experiment_id, metric)

        result = self.analyzer.analyze(exp, observations)

        # Check early stopping
        n_per_variant = _smallest_arm_size(
            exp.variants, observations.get(exp.primary_metric, {})
        )
        stopping = self.analyzer.check_early_stopping(
            result.p_value, n_per_variant, exp.min_sample_size, result.alpha
        )

        return {
            **asdict(result),
            "early_stopping": asdict(stopping),
        }

    def list_experiments(self, status: Optional[str] = None) -> List[Dict[str, Any]]:
        exps = self.store.list_experiments(status)
        return [asdict(e) for e in exps]


# ---------------------------------------------------------------------------
# Singleton
# ---------------------------------------------------------------------------

_manager_instance: Optional[ABTestingManager] = None


def get_ab_manager() -> ABTestingManager:
    """Return the process-wide ABTestingManager, creating it on first use."""
    global _manager_instance
    if _manager_instance is None:
        _manager_instance = ABTestingManager()
    return _manager_instance