# SPG_ML / ab_testing.py
# meetmendapara's picture
# Initial commit for ML space
# df31aa1
"""
A/B Testing Framework for Cognexa ML Service
Provides a complete A/B testing system for:
- Feature experiments (UI variants, algorithm variants)
- Model comparison (comparing prediction model versions)
- Notification strategies (channel, timing, content)
- Recommendation variants (algorithm A vs B)
Implements:
- User bucketing (consistent hash-based assignment)
- Statistical significance testing
- Early stopping (O'Brien-Fleming bounds)
- Multiple comparison correction (Bonferroni)
- Experiment lifecycle management
"""
from __future__ import annotations
import hashlib
import json
import logging
import uuid
from dataclasses import dataclass, asdict, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from scipy import stats as scipy_stats
from statistical_analysis import compare_groups, _power_analyzer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class ExperimentStatus(str, Enum):
    """Lifecycle states an experiment can be in.

    Members subclass ``str``, so each one compares equal to its raw string
    value (for example ``ExperimentStatus.RUNNING == "running"``).
    """

    DRAFT = "draft"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    STOPPED_EARLY = "stopped_early"
class ExperimentType(str, Enum):
    """Kinds of experiment the framework distinguishes.

    Members subclass ``str`` and compare equal to their raw string values.
    """

    # Simple on/off split.
    FEATURE_FLAG = "feature_flag"
    MODEL_COMPARISON = "model_comparison"
    NOTIFICATION = "notification"
    UI_VARIANT = "ui_variant"
    RECOMMENDATION = "recommendation"
class AllocationStrategy(str, Enum):
    """How users are mapped onto buckets.

    Members subclass ``str`` and compare equal to their raw string values.
    """

    # Deterministic and consistent: the same inputs always give the same bucket.
    HASH = "hash"
    RANDOM = "random"
# ---------------------------------------------------------------------------
# Data Structures
# ---------------------------------------------------------------------------
@dataclass
class Variant:
    """A single variant (arm) in an experiment.

    Attributes:
        variant_id: Identifier of the variant.
        name: Label of the arm, e.g. "control" or "treatment_a".
        description: Free-text description of the arm.
        allocation_percent: Share of the experiment's traffic, on a 0-100 scale.
        config: Arbitrary per-variant settings; a fresh empty dict by default.
        is_control: True when this arm is the control.
    """

    variant_id: str
    name: str
    description: str
    allocation_percent: float
    config: Dict[str, Any] = field(default_factory=dict)
    is_control: bool = False
@dataclass
class Experiment:
    """Definition of an A/B experiment.

    Attributes:
        experiment_id: Identifier of the experiment.
        name: Short experiment name.
        description: Free-text description.
        experiment_type: An ``ExperimentType`` value.
        status: An ``ExperimentStatus`` value.
        variants: The arms of the experiment.
        primary_metric: Name of the main metric, e.g. "task_completion_rate".
        secondary_metrics: Names of additional metrics.
        allocation_strategy: An ``AllocationStrategy`` value.
        traffic_percent: Share of users to include, on a 0-100 scale.
        min_sample_size: Sample size required per variant before analysis.
        alpha: Significance level.
        target_power: Statistical power the experiment is aiming for.
        created_at: Creation timestamp, stored as a string.
        started_at: Start timestamp as a string, or None if not started.
        ended_at: End timestamp as a string, or None if not ended.
        owner: Owner of the experiment.
        tags: Free-form tags; a fresh empty list by default.
        metadata: Extra key/value data; a fresh empty dict by default.
    """

    experiment_id: str
    name: str
    description: str
    experiment_type: str
    status: str
    variants: List[Variant]
    primary_metric: str
    secondary_metrics: List[str]
    allocation_strategy: str
    traffic_percent: float
    min_sample_size: int
    alpha: float
    target_power: float
    created_at: str
    started_at: Optional[str]
    ended_at: Optional[str]
    owner: str
    tags: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ExperimentAssignment:
    """Record of which variant a user was assigned to.

    Attributes:
        assignment_id: Identifier of this assignment record.
        experiment_id: Experiment the assignment belongs to.
        user_id: The assigned user.
        variant_id: Identifier of the assigned variant.
        variant_name: Name of the assigned variant.
        assigned_at: Assignment timestamp, stored as a string.
        is_exposed: Whether the first exposure has been recorded; True by default.
    """

    assignment_id: str
    experiment_id: str
    user_id: str
    variant_id: str
    variant_name: str
    assigned_at: str
    is_exposed: bool = True
@dataclass
class MetricObservation:
    """One metric measurement taken from an experiment participant.

    Attributes:
        observation_id: Identifier of this observation.
        experiment_id: Experiment the observation belongs to.
        user_id: The user the value was measured for.
        variant_id: Variant the user was in.
        metric_name: Name of the measured metric.
        value: The measured value.
        recorded_at: Recording timestamp, stored as a string.
    """

    observation_id: str
    experiment_id: str
    user_id: str
    variant_id: str
    metric_name: str
    value: float
    recorded_at: str
@dataclass
class VariantResult:
    """Summary statistics for one variant.

    Attributes:
        variant_id: Identifier of the variant.
        variant_name: Name of the variant.
        n_observations: Number of observed values.
        mean: Mean of the observed values.
        std: Standard deviation of the observed values.
        median: Median of the observed values.
        is_control: True when the variant is the control arm.
    """

    variant_id: str
    variant_name: str
    n_observations: int
    mean: float
    std: float
    median: float
    is_control: bool
@dataclass
class ExperimentResult:
    """Complete statistical analysis result for an experiment.

    Attributes:
        experiment_id: Identifier of the analyzed experiment.
        experiment_name: Name of the analyzed experiment.
        primary_metric: Metric the analysis was run on.
        status: Experiment status at analysis time.
        control_result: Summary statistics of the control arm.
        treatment_results: Summary statistics of each treatment arm.
        p_value: P-value of the primary comparison.
        is_significant: Whether the comparison was judged significant.
        alpha: Significance level of the experiment.
        effect_size: Effect size of the primary comparison.
        relative_uplift: Percentage improvement over control.
        current_power: Power reported by the power analysis.
        recommended_sample_size: Sample size the power analysis asks for.
        is_adequately_powered: Whether enough data has been collected.
        winner: ``variant_name`` of the winner, or None if there is no winner yet.
        recommendation: Human-readable next step.
        analyzed_at: Analysis timestamp, stored as a string.
    """

    experiment_id: str
    experiment_name: str
    primary_metric: str
    status: str
    control_result: VariantResult
    treatment_results: List[VariantResult]

    # --- significance ---
    p_value: float
    is_significant: bool
    alpha: float

    # --- effect size ---
    effect_size: float
    relative_uplift: float

    # --- power ---
    current_power: float
    recommended_sample_size: int
    is_adequately_powered: bool

    # --- decision ---
    winner: Optional[str]
    recommendation: str
    analyzed_at: str
@dataclass
class EarlyStoppingDecision:
    """Outcome of an early-stopping evaluation.

    Attributes:
        should_stop: True when the experiment may be stopped now.
        reason: Human-readable explanation of the decision.
        current_p_value: P-value the decision was based on.
        obrien_fleming_threshold: P-value threshold applied at this look.
        current_n: Sample size at this look.
        planned_n: Planned final sample size.
        interim_fraction: Fraction of the planned sample collected so far.
    """

    should_stop: bool
    reason: str
    current_p_value: float
    obrien_fleming_threshold: float
    current_n: int
    planned_n: int
    interim_fraction: float
# ---------------------------------------------------------------------------
# User Bucketing
# ---------------------------------------------------------------------------
class UserBucketizer:
    """Consistent hash-based user assignment to experiment variants."""

    def assign_variant(
        self,
        user_id: str,
        experiment_id: str,
        variants: List[Variant],
        traffic_percent: float = 100.0,
        strategy: str = AllocationStrategy.HASH,
    ) -> Optional[Variant]:
        """
        Assign a user to a variant.

        A bucket in [0, 100) is derived for the user. Buckets at or above
        ``traffic_percent`` are outside the experiment; the range
        [0, traffic_percent) is split among ``variants`` in proportion to
        their ``allocation_percent`` (normalized by their sum).

        Args:
            user_id: User to assign.
            experiment_id: Experiment being assigned; part of the hash key, so
                the same user lands in independent buckets across experiments.
            variants: Candidate arms.
            traffic_percent: Share of users (0-100) included in the experiment.
            strategy: ``AllocationStrategy.RANDOM`` draws a fresh uniform
                bucket on every call; any other value uses the deterministic
                hash of ``user_id`` and ``experiment_id``.

        Returns:
            The chosen variant, or None if the user falls outside the
            experiment traffic, ``variants`` is empty, or the allocation
            percentages do not sum to a positive number.
        """
        # Nothing can be assigned without arms; previously this surfaced as an
        # IndexError / ZeroDivisionError further down.
        if not variants:
            logger.warning(
                "No variants configured for experiment %s; skipping assignment",
                experiment_id,
            )
            return None
        total_alloc = sum(v.allocation_percent for v in variants)
        if total_alloc <= 0:
            logger.warning(
                "Variant allocations for experiment %s sum to %s; skipping assignment",
                experiment_id, total_alloc,
            )
            return None

        if strategy == AllocationStrategy.RANDOM:
            bucket = np.random.uniform(0, 100)
        else:
            # Deterministic: hash(user_id + experiment_id).
            # MD5 is used purely to spread users evenly, not for security.
            key = f"{user_id}:{experiment_id}"
            h = int(hashlib.md5(key.encode()).hexdigest(), 16)
            bucket = (h % 10000) / 100.0  # 0.00 - 99.99

        # Check traffic inclusion
        if bucket >= traffic_percent:
            return None

        # Walk the cumulative allocation, normalized within the traffic slice.
        cumulative = 0.0
        for variant in variants:
            cumulative += (variant.allocation_percent / total_alloc) * traffic_percent
            if bucket < cumulative:
                return variant
        # Floating-point rounding can leave the last edge a hair below
        # traffic_percent; such buckets belong to the final variant.
        return variants[-1]
# ---------------------------------------------------------------------------
# Experiment Storage
# ---------------------------------------------------------------------------
class ExperimentStore:
    """Filesystem-backed store for experiments, assignments and observations.

    All three collections are kept in memory and mirrored to JSON files under
    ``data_dir``. Each ``save_*`` call rewrites the corresponding file in
    full. Writes go to a temporary file that is then renamed over the target,
    so a failed write leaves the previous file contents intact instead of a
    truncated file.
    """

    # Caps on the assignment / observation lists; oldest entries are dropped.
    _MAX_ASSIGNMENTS = 100000
    _MAX_OBSERVATIONS = 500000

    def __init__(self, data_dir: str = "data/ab_testing"):
        """Create ``data_dir`` if needed and load whatever is already stored."""
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.experiments_file = self.data_dir / "experiments.json"
        self.assignments_file = self.data_dir / "assignments.json"
        self.observations_file = self.data_dir / "observations.json"
        self._experiments: Dict[str, Experiment] = self._load_experiments()
        self._assignments: List[ExperimentAssignment] = self._load_assignments()
        self._observations: List[MetricObservation] = self._load_observations()

    @staticmethod
    def _write_json(path: Path, payload: Any, **dump_kwargs: Any) -> None:
        """Serialize ``payload`` to ``path`` without ever exposing a partial file.

        The data is written to a sibling ``<name>.tmp`` file which replaces
        ``path`` only after ``json.dump`` has completed. Any exception from
        serialization or I/O propagates to the caller.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, **dump_kwargs)
            tmp_path.replace(path)
        finally:
            # Present only if the write failed before the rename; discard it.
            if tmp_path.exists():
                tmp_path.unlink()

    # -- Experiments ----------------------------------------------------------
    def _load_experiments(self) -> Dict[str, Experiment]:
        """Read experiments from disk; return {} if missing or unreadable."""
        if not self.experiments_file.exists():
            return {}
        try:
            with open(self.experiments_file, encoding="utf-8") as f:
                data = json.load(f)
            return {
                eid: Experiment(
                    **{k: v for k, v in exp.items() if k != "variants"},
                    variants=[Variant(**v) for v in exp.get("variants", [])],
                )
                for eid, exp in data.items()
            }
        except Exception as e:
            # Best effort: a broken file must not prevent the store from starting.
            logger.warning("Could not load experiments: %s", e)
            return {}

    def save_experiment(self, experiment: Experiment):
        """Insert or replace ``experiment`` and persist all experiments."""
        self._experiments[experiment.experiment_id] = experiment
        self._write_json(
            self.experiments_file,
            {eid: asdict(exp) for eid, exp in self._experiments.items()},
            indent=2,
        )

    def get_experiment(self, experiment_id: str) -> Optional[Experiment]:
        """Return the experiment with this id, or None if unknown."""
        return self._experiments.get(experiment_id)

    def list_experiments(self, status: Optional[str] = None) -> List[Experiment]:
        """Return experiments, newest ``created_at`` first, optionally filtered by status."""
        exps = list(self._experiments.values())
        if status:
            exps = [e for e in exps if e.status == status]
        return sorted(exps, key=lambda e: e.created_at, reverse=True)

    # -- Assignments ----------------------------------------------------------
    def _load_assignments(self) -> List[ExperimentAssignment]:
        """Read assignments from disk; return [] if missing or unreadable."""
        if not self.assignments_file.exists():
            return []
        try:
            with open(self.assignments_file, encoding="utf-8") as f:
                return [ExperimentAssignment(**a) for a in json.load(f)]
        except Exception as e:
            logger.warning("Could not load assignments: %s", e)
            return []

    def save_assignment(self, assignment: ExperimentAssignment):
        """Append ``assignment`` (keeping only the newest entries) and persist."""
        self._assignments.append(assignment)
        if len(self._assignments) > self._MAX_ASSIGNMENTS:
            self._assignments = self._assignments[-self._MAX_ASSIGNMENTS:]
        self._write_json(
            self.assignments_file, [asdict(a) for a in self._assignments]
        )

    def get_user_assignment(
        self, user_id: str, experiment_id: str
    ) -> Optional[ExperimentAssignment]:
        """Return the most recently saved assignment for this user and experiment."""
        # Scan newest-first so a later re-assignment wins over an older one.
        for a in reversed(self._assignments):
            if a.user_id == user_id and a.experiment_id == experiment_id:
                return a
        return None

    # -- Observations ----------------------------------------------------------
    def _load_observations(self) -> List[MetricObservation]:
        """Read observations from disk; return [] if missing or unreadable."""
        if not self.observations_file.exists():
            return []
        try:
            with open(self.observations_file, encoding="utf-8") as f:
                return [MetricObservation(**o) for o in json.load(f)]
        except Exception as e:
            logger.warning("Could not load observations: %s", e)
            return []

    def save_observation(self, obs: MetricObservation):
        """Append ``obs`` (keeping only the newest entries) and persist."""
        self._observations.append(obs)
        if len(self._observations) > self._MAX_OBSERVATIONS:
            self._observations = self._observations[-self._MAX_OBSERVATIONS:]
        self._write_json(
            self.observations_file, [asdict(o) for o in self._observations]
        )

    def get_observations(
        self,
        experiment_id: str,
        metric_name: str,
    ) -> Dict[str, List[float]]:
        """Return dict: variant_id -> list of metric values."""
        result: Dict[str, List[float]] = {}
        for obs in self._observations:
            if obs.experiment_id == experiment_id and obs.metric_name == metric_name:
                result.setdefault(obs.variant_id, []).append(obs.value)
        return result
# ---------------------------------------------------------------------------
# Statistical Analyzer
# ---------------------------------------------------------------------------
class ExperimentAnalyzer:
    """Analyzes experiment results and computes statistical significance."""

    def analyze(
        self,
        experiment: Experiment,
        observations: Dict[str, Dict[str, List[float]]],  # metric -> variant_id -> values
    ) -> ExperimentResult:
        """Compare the control arm with the best treatment on the primary metric.

        Args:
            experiment: The experiment definition. The first variant flagged
                ``is_control`` is the control; if none is flagged, the first
                variant is used.
            observations: Mapping ``metric name -> variant_id -> values``.
                Only the entry for ``experiment.primary_metric`` is read.

        Returns:
            An ``ExperimentResult`` with per-variant summaries, the p-value
            and effect size of control vs. the treatment with the highest
            mean, a power assessment and a textual recommendation. When
            either side has fewer than two values no test is run and the
            p-value is reported as 1.0.

        Raises:
            IndexError: If ``experiment.variants`` is empty.
        """
        primary = experiment.primary_metric
        obs_by_variant = observations.get(primary, {})

        control_variant = next(
            (v for v in experiment.variants if v.is_control), experiment.variants[0]
        )
        control_values = obs_by_variant.get(control_variant.variant_id, [])
        # Exclude the chosen control explicitly: when no variant is flagged as
        # control, variants[0] is the fallback control and must not also be
        # analyzed as a treatment (it would be compared with itself).
        treatment_variants = [
            v for v in experiment.variants
            if not v.is_control and v is not control_variant
        ]

        def _summarize(v: Variant, is_control: bool) -> VariantResult:
            vals = obs_by_variant.get(v.variant_id, [])
            return VariantResult(
                variant_id=v.variant_id,
                variant_name=v.name,
                n_observations=len(vals),
                mean=round(float(np.mean(vals)), 4) if vals else 0.0,
                std=round(float(np.std(vals, ddof=1)), 4) if len(vals) > 1 else 0.0,
                median=round(float(np.median(vals)), 4) if vals else 0.0,
                is_control=is_control,
            )

        ctrl_result = _summarize(control_variant, True)
        treat_results = [_summarize(v, False) for v in treatment_variants]

        # Primary comparison: control vs best treatment. Arms without data
        # report a placeholder mean of 0.0, so prefer arms that have data.
        candidates = [r for r in treat_results if r.n_observations > 0] or treat_results
        best_treat = max(candidates, key=lambda r: r.mean) if candidates else None
        best_values = obs_by_variant.get(best_treat.variant_id, []) if best_treat else []

        p_value = 1.0
        effect_size = 0.0
        if len(control_values) >= 2 and len(best_values) >= 2:
            test_result = compare_groups(
                control_values, best_values, test="auto", alpha=experiment.alpha
            )
            # Cast to built-in floats so NumPy scalars do not leak into the result.
            p_value = float(test_result["p_value"])
            effect_size = float(test_result["effect_size"])

        # Bonferroni correction: the best of k treatments is tested, so the
        # decision threshold is alpha / k. With a single treatment this is
        # the plain alpha.
        n_comparisons = max(1, len(treat_results))
        is_significant = bool(p_value < experiment.alpha / n_comparisons)

        # Divide by the magnitude of the control mean so that a higher
        # treatment mean is a positive uplift even for negative-valued metrics.
        if best_treat is not None and ctrl_result.mean != 0:
            relative_uplift = (
                (best_treat.mean - ctrl_result.mean) / abs(ctrl_result.mean) * 100
            )
        else:
            relative_uplift = 0.0

        # Power
        total_n = sum(len(v) for v in obs_by_variant.values())
        per_variant_n = max(1, total_n // max(1, len(experiment.variants)))
        power_result = _power_analyzer.compute_sample_size(
            effect_size=max(0.01, abs(effect_size)),
            alpha=experiment.alpha,
            power=experiment.target_power,
        )
        is_powered = bool(per_variant_n >= power_result.required_sample_size)

        # Winner: best treatment on positive uplift, control on negative.
        winner = None
        if is_significant and is_powered and best_treat:
            if relative_uplift > 0:
                winner = best_treat.variant_name
            elif relative_uplift < 0:
                winner = control_variant.name

        recommendation = self._recommendation(
            is_significant, is_powered, winner, per_variant_n,
            power_result.required_sample_size, relative_uplift
        )
        return ExperimentResult(
            experiment_id=experiment.experiment_id,
            experiment_name=experiment.name,
            primary_metric=primary,
            status=experiment.status,
            control_result=ctrl_result,
            treatment_results=treat_results,
            p_value=round(p_value, 6),
            is_significant=is_significant,
            alpha=experiment.alpha,
            effect_size=round(effect_size, 4),
            relative_uplift=round(relative_uplift, 2),
            current_power=round(power_result.current_power, 4),
            recommended_sample_size=power_result.required_sample_size,
            is_adequately_powered=is_powered,
            winner=winner,
            recommendation=recommendation,
            analyzed_at=datetime.utcnow().isoformat(),
        )

    def check_early_stopping(
        self,
        current_p: float,
        current_n: int,
        planned_n: int,
        alpha: float = 0.05,
    ) -> EarlyStoppingDecision:
        """O'Brien-Fleming-style threshold for early stopping.

        Args:
            current_p: P-value observed at this interim look.
            current_n: Sample size collected so far.
            planned_n: Planned final sample size.
            alpha: Overall two-sided significance level.

        Returns:
            An ``EarlyStoppingDecision``. Stopping is advised when
            ``current_p`` is below the threshold for the current information
            fraction ``t = min(1, current_n / planned_n)``. If either sample
            size is not positive, a "do not stop" decision is returned.
        """
        if planned_n <= 0 or current_n <= 0:
            return EarlyStoppingDecision(False, "Insufficient data", 1.0, alpha, 0, planned_n, 0.0)

        fraction = min(1.0, current_n / planned_n)
        # O'Brien-Fleming-type boundary expressed as a p-value threshold:
        #   threshold(t) = 2 * (1 - Phi(z_{alpha/2} / sqrt(t)))
        # It is very small for early looks and equals alpha at t = 1.
        z_alpha = scipy_stats.norm.ppf(1 - alpha / 2)
        if fraction < 0.01:
            threshold = 1e-6  # essentially never stop very early
        else:
            obf_z = z_alpha / np.sqrt(fraction)
            threshold = float(2 * scipy_stats.norm.sf(obf_z))  # two-tailed p-value threshold

        # bool(): comparing against NumPy floats yields numpy.bool_, which is
        # not a Python bool and is rejected by the json module.
        should_stop = bool(current_p < threshold)
        reason = (
            f"Early stopping: p={current_p:.4f} < OBF threshold={threshold:.4f}"
            if should_stop
            else f"Continue: p={current_p:.4f} >= OBF threshold={threshold:.4f} at {fraction:.0%} interim"
        )
        return EarlyStoppingDecision(
            should_stop=should_stop,
            reason=reason,
            current_p_value=round(current_p, 6),
            obrien_fleming_threshold=round(threshold, 6),
            current_n=current_n,
            planned_n=planned_n,
            interim_fraction=round(fraction, 4),
        )

    def _recommendation(
        self,
        significant: bool,
        powered: bool,
        winner: Optional[str],
        current_n: int,
        required_n: int,
        uplift: float,
    ) -> str:
        """Turn the analysis flags into a one-line recommendation.

        Power is checked first, then significance, then whether a winner was
        identified.
        """
        if not powered:
            return (
                f"Continue collecting data. Need {required_n} per variant, "
                f"currently at {current_n}."
            )
        if not significant:
            return "No significant difference. Consider continuing or stopping (null hypothesis)."
        if winner:
            return (
                f"Ship variant '{winner}' - statistically significant "
                f"({'↑' if uplift > 0 else '↓'}{abs(uplift):.1f}% vs control)."
            )
        return "Re-examine experiment design."
# ---------------------------------------------------------------------------
# A/B Testing Manager (main API)
# ---------------------------------------------------------------------------
class ABTestingManager:
    """High-level manager for A/B experiments.

    Ties together an ``ExperimentStore`` (persistence), a ``UserBucketizer``
    (assignment) and an ``ExperimentAnalyzer`` (statistics).
    """

    def __init__(self):
        """Create the store (with its default data directory), bucketizer and analyzer."""
        self.store = ExperimentStore()
        self.bucketizer = UserBucketizer()
        self.analyzer = ExperimentAnalyzer()

    @staticmethod
    def _variant_payload(variant: Variant) -> Dict[str, Any]:
        """Build the dict returned to callers for an assigned variant."""
        return {
            "variant_id": variant.variant_id,
            "variant_name": variant.name,
            "config": variant.config,
            "is_control": variant.is_control,
        }

    def create_experiment(
        self,
        name: str,
        description: str,
        primary_metric: str,
        variants: List[Dict[str, Any]],
        experiment_type: str = ExperimentType.FEATURE_FLAG,
        secondary_metrics: Optional[List[str]] = None,
        traffic_percent: float = 100.0,
        min_sample_size: int = 100,
        alpha: float = 0.05,
        target_power: float = 0.80,
        owner: str = "system",
        tags: Optional[List[str]] = None,
    ) -> Experiment:
        """Create and persist a new experiment in DRAFT status.

        Args:
            name: Experiment name.
            description: Free-text description.
            primary_metric: Name of the metric the analysis is run on.
            variants: One dict per arm. ``"name"`` is required;
                ``"description"``, ``"allocation_percent"`` (default: an equal
                split), ``"config"`` and ``"is_control"`` are optional. If no
                arm is flagged as control, the first one becomes the control.
            experiment_type: An ``ExperimentType`` value.
            secondary_metrics: Additional metric names.
            traffic_percent: Share of users (0-100) to include.
            min_sample_size: Sample size required per variant.
            alpha: Significance level.
            target_power: Desired statistical power.
            owner: Owner of the experiment.
            tags: Free-form tags.

        Returns:
            The stored ``Experiment``; ids are freshly generated UUIDs and the
            allocation strategy is always hash-based.

        Raises:
            ValueError: If ``variants`` is empty.
            KeyError: If a variant dict has no ``"name"``.
        """
        # Fail with a clear message rather than an IndexError when the
        # default control is picked below.
        if not variants:
            raise ValueError("An experiment requires at least one variant.")

        experiment_id = str(uuid.uuid4())
        equal_share = 100.0 / len(variants)
        parsed_variants = [
            Variant(
                variant_id=str(uuid.uuid4()),
                name=v["name"],
                description=v.get("description", ""),
                allocation_percent=v.get("allocation_percent", equal_share),
                config=v.get("config", {}),
                is_control=v.get("is_control", False),
            )
            for v in variants
        ]
        # If no control marked, mark first as control
        if not any(v.is_control for v in parsed_variants):
            parsed_variants[0].is_control = True

        exp = Experiment(
            experiment_id=experiment_id,
            name=name,
            description=description,
            experiment_type=experiment_type,
            status=ExperimentStatus.DRAFT,
            variants=parsed_variants,
            primary_metric=primary_metric,
            secondary_metrics=secondary_metrics or [],
            allocation_strategy=AllocationStrategy.HASH,
            traffic_percent=traffic_percent,
            min_sample_size=min_sample_size,
            alpha=alpha,
            target_power=target_power,
            created_at=datetime.utcnow().isoformat(),
            started_at=None,
            ended_at=None,
            owner=owner,
            tags=tags or [],
        )
        self.store.save_experiment(exp)
        logger.info("Created experiment %s (%s)", name, experiment_id)
        return exp

    def start_experiment(self, experiment_id: str) -> bool:
        """Mark the experiment RUNNING and stamp ``started_at``.

        Returns:
            False if the experiment does not exist, True otherwise.
        """
        exp = self.store.get_experiment(experiment_id)
        if not exp:
            return False
        exp.status = ExperimentStatus.RUNNING
        exp.started_at = datetime.utcnow().isoformat()
        self.store.save_experiment(exp)
        return True

    def stop_experiment(self, experiment_id: str, early: bool = False) -> bool:
        """Mark the experiment COMPLETED (or STOPPED_EARLY) and stamp ``ended_at``.

        Returns:
            False if the experiment does not exist, True otherwise.
        """
        exp = self.store.get_experiment(experiment_id)
        if not exp:
            return False
        exp.status = ExperimentStatus.STOPPED_EARLY if early else ExperimentStatus.COMPLETED
        exp.ended_at = datetime.utcnow().isoformat()
        self.store.save_experiment(exp)
        return True

    def get_variant_for_user(
        self, user_id: str, experiment_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Assign (or retrieve cached assignment) for a user in an experiment.

        Returns dict with variant info (``variant_id``, ``variant_name``,
        ``config``, ``is_control``) or None if the experiment is unknown, not
        running, or the user is not in the experiment traffic.
        """
        exp = self.store.get_experiment(experiment_id)
        if not exp or exp.status != ExperimentStatus.RUNNING:
            return None

        # Reuse a stored assignment as long as its variant still exists.
        existing = self.store.get_user_assignment(user_id, experiment_id)
        if existing:
            variant = next(
                (v for v in exp.variants if v.variant_id == existing.variant_id), None
            )
            if variant:
                return self._variant_payload(variant)

        # New assignment
        variant = self.bucketizer.assign_variant(
            user_id, experiment_id, exp.variants,
            exp.traffic_percent, exp.allocation_strategy
        )
        if variant is None:
            return None

        assignment = ExperimentAssignment(
            assignment_id=str(uuid.uuid4()),
            experiment_id=experiment_id,
            user_id=user_id,
            variant_id=variant.variant_id,
            variant_name=variant.name,
            assigned_at=datetime.utcnow().isoformat(),
        )
        self.store.save_assignment(assignment)
        return self._variant_payload(variant)

    def record_metric(
        self,
        experiment_id: str,
        user_id: str,
        metric_name: str,
        value: float,
    ) -> bool:
        """Record a metric observation for a user in an experiment.

        The observation is attributed to the variant of the user's stored
        assignment.

        Returns:
            False if the user has no assignment for the experiment, True once
            the observation has been saved.
        """
        assignment = self.store.get_user_assignment(user_id, experiment_id)
        if not assignment:
            return False
        obs = MetricObservation(
            observation_id=str(uuid.uuid4()),
            experiment_id=experiment_id,
            user_id=user_id,
            variant_id=assignment.variant_id,
            metric_name=metric_name,
            value=value,
            recorded_at=datetime.utcnow().isoformat(),
        )
        self.store.save_observation(obs)
        return True

    def analyze_experiment(self, experiment_id: str) -> Optional[Dict[str, Any]]:
        """Run statistical analysis on an experiment.

        Returns:
            None if the experiment does not exist; otherwise the fields of the
            ``ExperimentResult`` as a dict, plus an ``"early_stopping"`` entry
            holding the fields of the ``EarlyStoppingDecision``.
        """
        exp = self.store.get_experiment(experiment_id)
        if not exp:
            return None

        # Collect observations as metric -> variant_id -> values.
        all_metrics = {exp.primary_metric} | set(exp.secondary_metrics)
        observations: Dict[str, Dict[str, List[float]]] = {
            metric: self.store.get_observations(experiment_id, metric)
            for metric in all_metrics
        }
        result = self.analyzer.analyze(exp, observations)

        # Check early stopping against the average per-variant sample size,
        # with min_sample_size as the planned sample size.
        total_n = sum(
            len(v) for v in observations.get(exp.primary_metric, {}).values()
        )
        n_per_variant = max(1, total_n // max(1, len(exp.variants)))
        stopping = self.analyzer.check_early_stopping(
            result.p_value, n_per_variant,
            exp.min_sample_size, exp.alpha
        )
        return {
            **asdict(result),
            "early_stopping": asdict(stopping),
        }

    def list_experiments(self, status: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return stored experiments as dicts, optionally filtered by status."""
        exps = self.store.list_experiments(status)
        return [asdict(e) for e in exps]
# ---------------------------------------------------------------------------
# Singleton
# ---------------------------------------------------------------------------
_manager_instance: Optional[ABTestingManager] = None


def get_ab_manager() -> ABTestingManager:
    """Return the module-level ``ABTestingManager``, creating it on first use."""
    global _manager_instance
    if _manager_instance is not None:
        return _manager_instance
    _manager_instance = ABTestingManager()
    return _manager_instance