Coverage for tinytroupe/experimentation/statistical_tests.py: 0% (236 statements)

import numpy as np
import scipy.stats as stats
from typing import Dict, Union, Any, Optional

from tinytroupe.experimentation import logger

class StatisticalTester:
    """
    A class to perform statistical tests on experiment results. A single control is defined, and one or
    more treatments are compared against it. The class supports several statistical tests: Student's and
    Welch's t-tests, the Mann-Whitney U test, one-way ANOVA, the chi-square test, and the
    Kolmogorov-Smirnov test. The user can specify the type of test to run, the significance level, and
    the specific metrics to analyze. The results of the tests are returned in a structured format.
    """

    def __init__(self, control_experiment_data: Dict[str, Dict[str, list]],
                 treatments_experiment_data: Dict[str, Dict[str, list]],
                 results_key: Optional[str] = None):
        """
        Initialize with experiment results.

        Args:
            control_experiment_data (dict): Dictionary containing the control experiment results, with a
                single experiment ID as key and a dict of metric names to lists of values as its value.
                e.g., {"control_exp": {"metric1": [0.1, 0.2], "metric2": [0.3, 0.4], ...}}
            treatments_experiment_data (dict): Dictionary containing treatment experiment results, with
                experiment IDs as keys and dicts of metric names to lists of values as values.
                e.g., {"exp1": {"metric1": [0.1, 0.2], "metric2": [0.3, 0.4]},
                       "exp2": {"metric1": [0.5, 0.6], "metric2": [0.7, 0.8]}, ...}
            results_key (str, optional): If provided, the metrics dict is extracted from under this key
                in each experiment's data.
        """

        # if results_key is provided, use it to extract the relevant data from the control and treatment data
        # e.g., with results_key="results": {"exp1": {"results": {"metric1": [0.1, 0.2], "metric2": [0.3, 0.4]}}}
        if results_key:
            control_experiment_data = {k: v[results_key] for k, v in control_experiment_data.items()}
            treatments_experiment_data = {k: v[results_key] for k, v in treatments_experiment_data.items()}

        self.control_experiment_data = control_experiment_data
        self.treatments_experiment_data = treatments_experiment_data

        # Validate input data
        self._validate_input_data()
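
    # Example (hypothetical data; a minimal sketch of the expected input shape):
    #
    #   tester = StatisticalTester(
    #       control_experiment_data={"control": {"accuracy": [0.71, 0.69, 0.72]}},
    #       treatments_experiment_data={"variant_a": {"accuracy": [0.75, 0.78, 0.74]}},
    #   )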

    def _validate_input_data(self):
        """Validate the input data format and structure."""
        # Check that control and treatments are dictionaries
        if not isinstance(self.control_experiment_data, dict):
            raise TypeError("Control experiment data must be a dictionary")
        if not isinstance(self.treatments_experiment_data, dict):
            raise TypeError("Treatments experiment data must be a dictionary")

        # Check that control has at least one experiment
        if not self.control_experiment_data:
            raise ValueError("Control experiment data cannot be empty")

        # Check that there is only one control
        if len(self.control_experiment_data) > 1:
            raise ValueError("Only one control experiment is allowed")

        # Validate control experiment structure
        for control_id, control_metrics in self.control_experiment_data.items():
            if not isinstance(control_metrics, dict):
                raise TypeError(f"Metrics for control experiment '{control_id}' must be a dictionary")

            # Check that the metrics dictionary is not empty
            if not control_metrics:
                raise ValueError(f"Control experiment '{control_id}' has no metrics")

            # Validate that metric values are lists
            for metric, values in control_metrics.items():
                if not isinstance(values, list):
                    raise TypeError(f"Values for metric '{metric}' in control experiment '{control_id}' must be a list")

        # Check that treatments have at least one experiment
        if not self.treatments_experiment_data:
            raise ValueError("Treatments experiment data cannot be empty")

        # Validate treatment experiment structure
        for treatment_id, treatment_data in self.treatments_experiment_data.items():
            if not isinstance(treatment_data, dict):
                raise TypeError(f"Data for treatment '{treatment_id}' must be a dictionary")

            # Check that the metrics dictionary is not empty
            if not treatment_data:
                raise ValueError(f"Treatment '{treatment_id}' has no metrics")

            # Get all control metrics for overlap checking
            all_control_metrics = set()
            for control_metrics in self.control_experiment_data.values():
                all_control_metrics.update(control_metrics.keys())

            # Check whether there is any overlap between control and treatment metrics
            common_metrics = all_control_metrics.intersection(set(treatment_data.keys()))
            if not common_metrics:
                logger.warning(f"Treatment '{treatment_id}' has no metrics in common with any control experiment")

            # Check that treatment metric values are lists
            for metric, values in treatment_data.items():
                if not isinstance(values, list):
                    raise TypeError(f"Values for metric '{metric}' in treatment '{treatment_id}' must be a list")

    def run_test(self,
                 test_type: str = "welch_t_test",
                 alpha: float = 0.05,
                 **kwargs) -> Dict[str, Dict[str, Any]]:
        """
        Run the specified statistical test on the control and treatments data.

        Args:
            test_type (str): Type of statistical test to run.
                Options: 't_test', 'welch_t_test', 'mann_whitney', 'anova', 'chi_square', 'ks_test'.
            alpha (float): Significance level, defaults to 0.05.
            **kwargs: Additional arguments for specific test types.

        Returns:
            dict: Dictionary containing the results of the statistical tests for each treatment (versus
                the single control). Each key is a treatment ID and each value is a dictionary mapping
                metric names to test results.
        """
        supported_tests = {
            't_test': self._run_t_test,
            'welch_t_test': self._run_welch_t_test,
            'mann_whitney': self._run_mann_whitney,
            'anova': self._run_anova,
            'chi_square': self._run_chi_square,
            'ks_test': self._run_ks_test
        }

        if test_type not in supported_tests:
            raise ValueError(f"Unsupported test type: {test_type}. Supported types: {list(supported_tests.keys())}")

        results = {}
        for control_id, control_data in self.control_experiment_data.items():
            # Get all metrics from the control data
            metrics = set(control_data.keys())
            for treatment_id, treatment_data in self.treatments_experiment_data.items():
                results[treatment_id] = {}

                for metric in metrics:
                    # Skip metrics not present in the treatment data
                    if metric not in treatment_data:
                        logger.warning(f"Metric '{metric}' not found in treatment '{treatment_id}'")
                        continue

                    control_values = control_data[metric]
                    treatment_values = treatment_data[metric]

                    # Skip if either control or treatment has no values
                    if len(control_values) == 0 or len(treatment_values) == 0:
                        logger.warning(f"Skipping metric '{metric}' for treatment '{treatment_id}' due to empty values")
                        continue

                    # Run the selected test and convert the result to JSON-serializable types
                    test_result = supported_tests[test_type](control_values, treatment_values, alpha, **kwargs)
                    results[treatment_id][metric] = convert_to_serializable(test_result)

        return results
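
    # Example (continuing the hypothetical tester sketched above; illustrative, not real output):
    #
    #   results = tester.run_test(test_type="welch_t_test", alpha=0.05)
    #   results["variant_a"]["accuracy"]["p_value"]       # float
    #   results["variant_a"]["accuracy"]["significant"]   # bool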

    def _run_t_test(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run Student's t-test (equal variance assumed)."""
        # Convert to numpy arrays for calculations
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Calculate basic statistics
        control_mean = np.mean(control)
        treatment_mean = np.mean(treatment)
        mean_diff = treatment_mean - control_mean

        # Run the t-test
        t_stat, p_value = stats.ttest_ind(control, treatment, equal_var=True)

        # Calculate the confidence interval
        control_std = np.std(control, ddof=1)
        treatment_std = np.std(treatment, ddof=1)
        pooled_std = np.sqrt(((len(control) - 1) * control_std**2 +
                              (len(treatment) - 1) * treatment_std**2) /
                             (len(control) + len(treatment) - 2))

        se = pooled_std * np.sqrt(1/len(control) + 1/len(treatment))
        critical_value = stats.t.ppf(1 - alpha/2, len(control) + len(treatment) - 2)
        margin_error = critical_value * se
        ci_lower = mean_diff - margin_error
        ci_upper = mean_diff + margin_error
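        # For reference, this is the standard pooled-variance interval:
        #   s_p = sqrt(((n1-1)*s1^2 + (n2-1)*s2^2) / (n1+n2-2))
        #   CI  = (m2 - m1) +/- t_{1-alpha/2, n1+n2-2} * s_p * sqrt(1/n1 + 1/n2)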

        # Determine whether the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'Student t-test (equal variance)',
            'control_mean': control_mean,
            'treatment_mean': treatment_mean,
            'mean_difference': mean_diff,
            'percent_change': (mean_diff / control_mean * 100) if control_mean != 0 else float('inf'),
            't_statistic': t_stat,
            'p_value': p_value,
            'confidence_interval': (ci_lower, ci_upper),
            'confidence_level': 1 - alpha,
            'significant': significant,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'control_std': control_std,
            'treatment_std': treatment_std,
            'effect_size': cohen_d(control, treatment)
        }

    def _run_welch_t_test(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run Welch's t-test (unequal variance)."""
        # Convert to numpy arrays for calculations
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Calculate basic statistics
        control_mean = np.mean(control)
        treatment_mean = np.mean(treatment)
        mean_diff = treatment_mean - control_mean

        # Run Welch's t-test
        t_stat, p_value = stats.ttest_ind(control, treatment, equal_var=False)

        # Calculate the confidence interval (for Welch's t-test)
        control_var = np.var(control, ddof=1)
        treatment_var = np.var(treatment, ddof=1)

        # Calculate effective degrees of freedom (Welch-Satterthwaite equation)
        v_num = (control_var/len(control) + treatment_var/len(treatment))**2
        v_denom = (control_var/len(control))**2/(len(control)-1) + (treatment_var/len(treatment))**2/(len(treatment)-1)
        df = v_num / v_denom if v_denom > 0 else float('inf')
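        # i.e., df = (s1^2/n1 + s2^2/n2)^2 / ((s1^2/n1)^2/(n1-1) + (s2^2/n2)^2/(n2-1)),
        # the standard Welch-Satterthwaite approximation.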

        se = np.sqrt(control_var/len(control) + treatment_var/len(treatment))
        critical_value = stats.t.ppf(1 - alpha/2, df)
        margin_error = critical_value * se
        ci_lower = mean_diff - margin_error
        ci_upper = mean_diff + margin_error

        control_std = np.std(control, ddof=1)
        treatment_std = np.std(treatment, ddof=1)

        # Determine whether the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'Welch t-test (unequal variance)',
            'control_mean': control_mean,
            'treatment_mean': treatment_mean,
            'mean_difference': mean_diff,
            'percent_change': (mean_diff / control_mean * 100) if control_mean != 0 else float('inf'),
            't_statistic': t_stat,
            'p_value': p_value,
            'confidence_interval': (ci_lower, ci_upper),
            'confidence_level': 1 - alpha,
            'significant': significant,
            'degrees_of_freedom': df,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'control_std': control_std,
            'treatment_std': treatment_std,
            'effect_size': cohen_d(control, treatment)
        }

    def _run_mann_whitney(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run the Mann-Whitney U test (non-parametric)."""
        # Convert to numpy arrays
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Calculate basic statistics
        control_median = np.median(control)
        treatment_median = np.median(treatment)
        median_diff = treatment_median - control_median

        # Run the Mann-Whitney U test
        u_stat, p_value = stats.mannwhitneyu(control, treatment, alternative='two-sided')

        # Calculate the common language effect size
        # (probability that a randomly selected value from treatment is greater than one from control)
        count = 0
        for tc in treatment:
            for cc in control:
                if tc > cc:
                    count += 1
        cles = count / (len(treatment) * len(control))
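        # Note: this counts strict inequalities only, so exactly tied pairs contribute zero;
        # a CLES of 0.5 indicates no tendency in either direction.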

        # Calculate an approximate confidence interval for the median difference using bootstrap
        try:
            from scipy.stats import bootstrap

            def median_diff_func(x, y):
                return np.median(x) - np.median(y)

            res = bootstrap((control, treatment), median_diff_func,
                            confidence_level=1-alpha,
                            n_resamples=1000,
                            random_state=42)
            ci_lower, ci_upper = res.confidence_interval
        except Exception as e:
            # Bootstrap may be unavailable (older SciPy) or fail on degenerate samples;
            # in that case return None for the confidence interval
            ci_lower, ci_upper = None, None
            logger.warning(f"Bootstrap confidence interval could not be computed: {e}")

        # Determine whether the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'Mann-Whitney U test',
            'control_median': control_median,
            'treatment_median': treatment_median,
            'median_difference': median_diff,
            'percent_change': (median_diff / control_median * 100) if control_median != 0 else float('inf'),
            'u_statistic': u_stat,
            'p_value': p_value,
            'confidence_interval': (ci_lower, ci_upper) if ci_lower is not None else None,
            'confidence_level': 1 - alpha,
            'significant': significant,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'effect_size': cles
        }

    def _run_anova(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run a one-way ANOVA test."""
        # ANOVA typically involves multiple groups, but it can also be run with just two
        # Convert to numpy arrays
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Run one-way ANOVA
        f_stat, p_value = stats.f_oneway(control, treatment)

        # Calculate effect size (eta-squared)
        total_values = np.concatenate([control, treatment])
        grand_mean = np.mean(total_values)

        ss_total = np.sum((total_values - grand_mean) ** 2)
        ss_between = (len(control) * (np.mean(control) - grand_mean) ** 2 +
                      len(treatment) * (np.mean(treatment) - grand_mean) ** 2)

        eta_squared = ss_between / ss_total if ss_total > 0 else 0
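        # Eta-squared is the proportion of total variance explained by group membership
        # (SS_between / SS_total); rough conventional benchmarks are 0.01 small,
        # 0.06 medium, 0.14 large.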

        # Determine whether the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'One-way ANOVA',
            'f_statistic': f_stat,
            'p_value': p_value,
            'significant': significant,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'effect_size': eta_squared,
            'effect_size_type': 'eta_squared'
        }

    def _run_chi_square(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run a chi-square test for categorical data."""
        # For chi-square, the values are assumed to represent counts in different categories
        # Convert to numpy arrays
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Check that the arrays have the same length (same number of categories)
        if len(control) != len(treatment):
            raise ValueError("Control and treatment must have the same number of categories for chi-square test")
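        # Example (hypothetical counts): control_values=[30, 20] and treatment_values=[25, 25]
        # form the 2 x 2 contingency table [[30, 20], [25, 25]] below.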

        # Run the chi-square test
        contingency_table = np.vstack([control, treatment])
        chi2_stat, p_value, dof, expected = stats.chi2_contingency(contingency_table)

        # Calculate Cramer's V as the effect size
        n = np.sum(contingency_table)
        min_dim = min(contingency_table.shape) - 1
        cramers_v = np.sqrt(chi2_stat / (n * min_dim)) if n * min_dim > 0 else 0
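        # Cramer's V rescales chi-square to [0, 1]; for a 2 x k table min_dim is 1,
        # so V = sqrt(chi2 / n).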

        # Determine whether the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'Chi-square test',
            'chi2_statistic': chi2_stat,
            'p_value': p_value,
            'degrees_of_freedom': dof,
            'significant': significant,
            'effect_size': cramers_v,
            'effect_size_type': 'cramers_v'
        }

    def check_assumptions(self, metric: str) -> Dict[str, Dict[str, Any]]:
        """
        Check statistical assumptions for the given metric across all treatments.

        Args:
            metric (str): The metric to check assumptions for.

        Returns:
            dict: Dictionary with the results of the assumption checks for each treatment.
        """
        # There is exactly one control experiment (enforced by _validate_input_data),
        # so look the metric up inside its metrics dict
        control_metrics = next(iter(self.control_experiment_data.values()))
        if metric not in control_metrics:
            raise ValueError(f"Metric '{metric}' not found in control data")

        results = {}
        control_values = np.array(control_metrics[metric], dtype=float)

        # Check normality of the control
        control_shapiro = stats.shapiro(control_values)
        control_normality = {
            'test': 'Shapiro-Wilk',
            'statistic': control_shapiro[0],
            'p_value': control_shapiro[1],
            'normal': control_shapiro[1] >= 0.05
        }

        for treatment_id, treatment_data in self.treatments_experiment_data.items():
            if metric not in treatment_data:
                logger.warning(f"Metric '{metric}' not found in treatment '{treatment_id}'")
                continue

            treatment_values = np.array(treatment_data[metric], dtype=float)

            # Check normality of the treatment
            treatment_shapiro = stats.shapiro(treatment_values)
            treatment_normality = {
                'test': 'Shapiro-Wilk',
                'statistic': treatment_shapiro[0],
                'p_value': treatment_shapiro[1],
                'normal': treatment_shapiro[1] >= 0.05
            }

            # Check homogeneity of variance
            levene_test = stats.levene(control_values, treatment_values)
            variance_homogeneity = {
                'test': 'Levene',
                'statistic': levene_test[0],
                'p_value': levene_test[1],
                'equal_variance': levene_test[1] >= 0.05
            }

            # Store the results, converted to JSON-serializable types
            results[treatment_id] = convert_to_serializable({
                'control_normality': control_normality,
                'treatment_normality': treatment_normality,
                'variance_homogeneity': variance_homogeneity,
                'recommended_test': self._recommend_test(control_normality['normal'],
                                                         treatment_normality['normal'],
                                                         variance_homogeneity['equal_variance'])
            })

        return results
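
    # Example (continuing the hypothetical tester; a sketch of the assumption-driven flow):
    #
    #   checks = tester.check_assumptions("accuracy")
    #   recommended = checks["variant_a"]["recommended_test"]   # e.g., 'welch_t_test'
    #   results = tester.run_test(test_type=recommended)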

    def _recommend_test(self, control_normal: bool, treatment_normal: bool, equal_variance: bool) -> str:
        """Recommend a statistical test based on the assumption checks."""
        if control_normal and treatment_normal:
            if equal_variance:
                return 't_test'
            else:
                return 'welch_t_test'
        else:
            return 'mann_whitney'

    def _run_ks_test(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """
        Run the Kolmogorov-Smirnov test to compare distributions.

        This test compares the empirical cumulative distribution functions (ECDFs) of two samples
        to determine whether they come from the same distribution. It is particularly useful for:
        - Categorical responses (e.g., "Yes"/"No"/"Maybe") once converted to ordinal values
        - Continuous data where you want to compare entire distributions, not just means
        - Detecting differences in distribution shape, spread, or location
        """
        # Convert to numpy arrays
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Calculate basic statistics
        control_median = np.median(control)
        treatment_median = np.median(treatment)
        control_mean = np.mean(control)
        treatment_mean = np.mean(treatment)

        # Run the Kolmogorov-Smirnov test
        ks_stat, p_value = stats.ks_2samp(control, treatment)

        # Calculate distribution characteristics
        control_std = np.std(control, ddof=1)
        treatment_std = np.std(treatment, ddof=1)

        # Use the KS statistic itself as the effect size measure: it ranges from
        # 0 (identical distributions) to 1 (completely separated distributions)
        effect_size = ks_stat

        # Additional distribution comparison metrics:
        # calculate an overlap coefficient (area under the minimum of the two PDFs)
        try:
            # Create histograms over a common range for the overlap calculation
            combined_range = np.linspace(
                min(np.min(control), np.min(treatment)),
                max(np.max(control), np.max(treatment)),
                50
            )
            control_hist, _ = np.histogram(control, bins=combined_range, density=True)
            treatment_hist, _ = np.histogram(treatment, bins=combined_range, density=True)

            # Calculate the overlap (an intersection-over-union-like metric)
            overlap = np.sum(np.minimum(control_hist, treatment_hist)) / np.sum(np.maximum(control_hist, treatment_hist))
            overlap = overlap if not np.isnan(overlap) else 0.0
        except Exception:
            overlap = None

        # Calculate percentile differences for additional insight
        percentiles = [25, 50, 75, 90, 95]
        percentile_diffs = {}
        for p in percentiles:
            control_p = np.percentile(control, p)
            treatment_p = np.percentile(treatment, p)
            percentile_diffs[f"p{p}_diff"] = treatment_p - control_p

        # Determine significance
        significant = p_value < alpha

        return {
            'test_type': 'Kolmogorov-Smirnov test',
            'control_mean': control_mean,
            'treatment_mean': treatment_mean,
            'control_median': control_median,
            'treatment_median': treatment_median,
            'control_std': control_std,
            'treatment_std': treatment_std,
            'ks_statistic': ks_stat,
            'p_value': p_value,
            'significant': significant,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'effect_size': effect_size,
            'overlap_coefficient': overlap,
            'percentile_differences': percentile_diffs,
            'interpretation': self._interpret_ks_result(ks_stat, significant),
            'confidence_level': 1 - alpha
        }

    def _interpret_ks_result(self, ks_stat: float, significant: bool) -> str:
        """Provide an interpretation of the KS test result."""
        if not significant:
            return "No significant difference between distributions"

        if ks_stat < 0.1:
            return "Very small difference between distributions"
        elif ks_stat < 0.25:
            return "Small difference between distributions"
        elif ks_stat < 0.5:
            return "Moderate difference between distributions"
        else:
            return "Large difference between distributions"

def cohen_d(x: Union[list, np.ndarray], y: Union[list, np.ndarray]) -> float:
    """
    Calculate Cohen's d effect size for two samples.

    Args:
        x: First sample
        y: Second sample

    Returns:
        float: Cohen's d effect size (positive when the mean of y exceeds the mean of x)
    """
    nx = len(x)
    ny = len(y)

    # Convert to numpy arrays
    x = np.array(x, dtype=float)
    y = np.array(y, dtype=float)

    # Calculate means
    mx = np.mean(x)
    my = np.mean(y)

    # Calculate standard deviations
    sx = np.std(x, ddof=1)
    sy = np.std(y, ddof=1)

    # Pooled standard deviation
    pooled_sd = np.sqrt(((nx - 1) * sx**2 + (ny - 1) * sy**2) / (nx + ny - 2))

    # Cohen's d
    return (my - mx) / pooled_sd if pooled_sd > 0 else 0
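
# Rough conventional benchmarks for |d| (Cohen, 1988): 0.2 small, 0.5 medium, 0.8 large.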

def convert_to_serializable(obj):
    """
    Recursively convert NumPy types to native Python types so that JSON serialization works.

    Args:
        obj: Any object that might contain NumPy types

    Returns:
        The object with NumPy types converted to native Python types
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, (np.number, np.bool_)):
        return obj.item()
    elif isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_serializable(i) for i in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_to_serializable(i) for i in obj)
    else:
        return obj
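
# A minimal smoke-test sketch with made-up numbers (illustrative only; the experiment IDs,
# metric name, and values below are hypothetical).
if __name__ == "__main__":
    tester = StatisticalTester(
        control_experiment_data={"control": {"accuracy": [0.70, 0.72, 0.68, 0.71, 0.69]}},
        treatments_experiment_data={"variant_a": {"accuracy": [0.74, 0.77, 0.73, 0.76, 0.75]}},
    )
    for test in ("welch_t_test", "mann_whitney", "ks_test"):
        outcome = tester.run_test(test_type=test)
        stats_for_metric = outcome["variant_a"]["accuracy"]
        print(f"{test}: p={stats_for_metric['p_value']:.4f}, "
              f"significant={stats_for_metric['significant']}")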