# tinytroupe/experimentation/statistical_tests.py

import numpy as np
import scipy.stats as stats
from typing import Any, Dict, Optional, Union

from tinytroupe.experimentation import logger


class StatisticalTester:
    """
    A class to perform statistical tests on experiment results. A single control experiment is defined,
    and one or more treatments are compared against it. The class supports several statistical tests,
    including t-tests, Mann-Whitney U tests, ANOVA, chi-square, and Kolmogorov-Smirnov. The user
    specifies the type of test to run and the significance level; the results of the tests are
    returned in a structured, JSON-serializable format.
    """

    def __init__(self, control_experiment_data: Dict[str, Dict[str, list]],
                 treatments_experiment_data: Dict[str, Dict[str, list]],
                 results_key: Optional[str] = None):
        """
        Initialize with experiment results.

        Args:
            control_experiment_data (dict): Dictionary containing the control experiment results, with the
                experiment ID as key and a dict of metric names to lists of values as value.
                e.g., {"control_exp": {"metric1": [0.1, 0.2], "metric2": [0.3, 0.4], ...}}
            treatments_experiment_data (dict): Dictionary containing treatment experiment results, with
                experiment IDs as keys and dicts of metric names to lists of values as values.
                e.g., {"exp1": {"metric1": [0.1, 0.2], "metric2": [0.3, 0.4]},
                       "exp2": {"metric1": [0.5, 0.6], "metric2": [0.7, 0.8]}, ...}
            results_key (str, optional): If provided, metrics are read from this nested key of each
                experiment's data rather than from the top level.
                e.g., {"exp1": {"results": {"metric1": [0.1, 0.2], "metric2": [0.3, 0.4]}}}
        """

        # If results_key is provided, use it to extract the relevant data from the control and treatment data,
        # e.g., {"exp1": {"results": {"metric1": [0.1, 0.2], "metric2": [0.3, 0.4]}}}
        if results_key:
            control_experiment_data = {k: v[results_key] for k, v in control_experiment_data.items()}
            treatments_experiment_data = {k: v[results_key] for k, v in treatments_experiment_data.items()}

        self.control_experiment_data = control_experiment_data
        self.treatments_experiment_data = treatments_experiment_data

        # Validate input data
        self._validate_input_data()
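
    # A minimal usage sketch; the experiment IDs, metric name, and values below are
    # hypothetical and only illustrate the expected input shape:
    #
    #   control = {"control": {"satisfaction": [3.1, 3.4, 2.9, 3.0]}}
    #   treatments = {"exp_a": {"satisfaction": [3.8, 4.0, 3.6, 3.9]}}
    #   tester = StatisticalTester(control, treatments)
    #   results = tester.run_test(test_type="welch_t_test", alpha=0.05)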

    def _validate_input_data(self):
        """Validate the input data formats and structure."""
        # Check that control and treatments are dictionaries
        if not isinstance(self.control_experiment_data, dict):
            raise TypeError("Control experiment data must be a dictionary")
        if not isinstance(self.treatments_experiment_data, dict):
            raise TypeError("Treatments experiment data must be a dictionary")

        # Check that control has at least one experiment
        if not self.control_experiment_data:
            raise ValueError("Control experiment data cannot be empty")

        # Check only one control
        if len(self.control_experiment_data) > 1:
            raise ValueError("Only one control experiment is allowed")

        # Validate control experiment structure
        for control_id, control_metrics in self.control_experiment_data.items():
            if not isinstance(control_metrics, dict):
                raise TypeError(f"Metrics for control experiment '{control_id}' must be a dictionary")

            # Check that the metrics dictionary is not empty
            if not control_metrics:
                raise ValueError(f"Control experiment '{control_id}' has no metrics")

            # Validate that metric values are lists
            for metric, values in control_metrics.items():
                if not isinstance(values, list):
                    raise TypeError(f"Values for metric '{metric}' in control experiment '{control_id}' must be a list")

        # Check treatments have at least one experiment
        if not self.treatments_experiment_data:
            raise ValueError("Treatments experiment data cannot be empty")

        # Collect all control metrics once, for the overlap checks below
        all_control_metrics = set()
        for control_metrics in self.control_experiment_data.values():
            all_control_metrics.update(control_metrics.keys())

        # Validate treatment experiment structure
        for treatment_id, treatment_data in self.treatments_experiment_data.items():
            if not isinstance(treatment_data, dict):
                raise TypeError(f"Data for treatment '{treatment_id}' must be a dictionary")

            # Check that the metrics dictionary is not empty
            if not treatment_data:
                raise ValueError(f"Treatment '{treatment_id}' has no metrics")

            # Check if there's any overlap between control and treatment metrics
            common_metrics = all_control_metrics.intersection(set(treatment_data.keys()))
            if not common_metrics:
                logger.warning(f"Treatment '{treatment_id}' has no metrics in common with any control experiment")

            # Check that treatment metrics are lists
            for metric, values in treatment_data.items():
                if not isinstance(values, list):
                    raise TypeError(f"Values for metric '{metric}' in treatment '{treatment_id}' must be a list")
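
    # For illustration, hypothetical inputs this validation rejects:
    #
    #   StatisticalTester({"c1": {"m": [1.0]}, "c2": {"m": [1.0]}}, {"t": {"m": [1.0]}})
    #       -> ValueError: Only one control experiment is allowed
    #   StatisticalTester({"c": {"m": 0.5}}, {"t": {"m": [0.5]}})
    #       -> TypeError: Values for metric 'm' in control experiment 'c' must be a list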

    def run_test(self,
                 test_type: str = "welch_t_test",
                 alpha: float = 0.05,
                 **kwargs) -> Dict[str, Dict[str, Any]]:
        """
        Run the specified statistical test on the control and treatments data.

        Args:
            test_type (str): Type of statistical test to run.
                Options: 't_test', 'welch_t_test', 'mann_whitney', 'anova', 'chi_square', 'ks_test'
            alpha (float): Significance level, defaults to 0.05.
            **kwargs: Additional arguments for specific test types.

        Returns:
            dict: Dictionary containing the results of the statistical tests for each treatment (vs the one control).
                Each key is a treatment ID and each value is a dictionary of per-metric test results.
        """
        supported_tests = {
            't_test': self._run_t_test,
            'welch_t_test': self._run_welch_t_test,
            'mann_whitney': self._run_mann_whitney,
            'anova': self._run_anova,
            'chi_square': self._run_chi_square,
            'ks_test': self._run_ks_test
        }

        if test_type not in supported_tests:
            raise ValueError(f"Unsupported test type: {test_type}. Supported types: {list(supported_tests.keys())}")

        results = {}
        for control_id, control_data in self.control_experiment_data.items():
            # Get all metrics from the control data
            metrics = set(control_data.keys())
            for treatment_id, treatment_data in self.treatments_experiment_data.items():
                results[treatment_id] = {}

                for metric in metrics:
                    # Skip metrics not in treatment data
                    if metric not in treatment_data:
                        logger.warning(f"Metric '{metric}' not found in treatment '{treatment_id}'")
                        continue

                    control_values = control_data[metric]
                    treatment_values = treatment_data[metric]

                    # Skip if either control or treatment has no values
                    if len(control_values) == 0 or len(treatment_values) == 0:
                        logger.warning(f"Skipping metric '{metric}' for treatment '{treatment_id}' due to empty values")
                        continue

                    # Run the selected test and convert to JSON-serializable types
                    test_result = supported_tests[test_type](control_values, treatment_values, alpha, **kwargs)
                    results[treatment_id][metric] = convert_to_serializable(test_result)

        return results
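
    # The returned structure, sketched for one hypothetical treatment and metric (the exact
    # keys depend on the chosen test; these are from the Welch t-test branch):
    #
    #   {
    #       "exp_a": {
    #           "satisfaction": {
    #               "test_type": "Welch t-test (unequal variance)",
    #               "control_mean": ...,
    #               "treatment_mean": ...,
    #               "mean_difference": ...,
    #               "p_value": ...,
    #               "significant": ...,
    #               "effect_size": ...,
    #               ...
    #           }
    #       }
    #   }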

    def _run_t_test(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run Student's t-test (equal variance assumed)."""
        # Convert to numpy arrays for calculations
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Calculate basic statistics
        control_mean = np.mean(control)
        treatment_mean = np.mean(treatment)
        mean_diff = treatment_mean - control_mean

        # Run the t-test (treatment first, so the sign of t matches mean_difference)
        t_stat, p_value = stats.ttest_ind(treatment, control, equal_var=True)

        # Calculate confidence interval
        control_std = np.std(control, ddof=1)
        treatment_std = np.std(treatment, ddof=1)
        pooled_std = np.sqrt(((len(control) - 1) * control_std**2 +
                              (len(treatment) - 1) * treatment_std**2) /
                             (len(control) + len(treatment) - 2))

        se = pooled_std * np.sqrt(1/len(control) + 1/len(treatment))
        critical_value = stats.t.ppf(1 - alpha/2, len(control) + len(treatment) - 2)
        margin_error = critical_value * se
        ci_lower = mean_diff - margin_error
        ci_upper = mean_diff + margin_error

        # Determine if the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'Student t-test (equal variance)',
            'control_mean': control_mean,
            'treatment_mean': treatment_mean,
            'mean_difference': mean_diff,
            'percent_change': (mean_diff / control_mean * 100) if control_mean != 0 else float('inf'),
            't_statistic': t_stat,
            'p_value': p_value,
            'confidence_interval': (ci_lower, ci_upper),
            'confidence_level': 1 - alpha,
            'significant': significant,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'control_std': control_std,
            'treatment_std': treatment_std,
            'effect_size': cohen_d(control, treatment)
        }
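
    # On SciPy >= 1.10 the hand-computed interval above can be cross-checked against the
    # library's own confidence interval for the mean difference (a sketch, not a guarantee
    # of bit-identical output):
    #
    #   res = stats.ttest_ind(treatment, control, equal_var=True)
    #   low, high = res.confidence_interval(confidence_level=1 - alpha)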

    def _run_welch_t_test(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run Welch's t-test (unequal variance)."""
        # Convert to numpy arrays for calculations
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Calculate basic statistics
        control_mean = np.mean(control)
        treatment_mean = np.mean(treatment)
        mean_diff = treatment_mean - control_mean

        # Run Welch's t-test (treatment first, so the sign of t matches mean_difference)
        t_stat, p_value = stats.ttest_ind(treatment, control, equal_var=False)

        # Calculate confidence interval (for Welch's t-test)
        control_var = np.var(control, ddof=1)
        treatment_var = np.var(treatment, ddof=1)

        # Calculate effective degrees of freedom (Welch-Satterthwaite equation)
        v_num = (control_var/len(control) + treatment_var/len(treatment))**2
        v_denom = (control_var/len(control))**2/(len(control)-1) + (treatment_var/len(treatment))**2/(len(treatment)-1)
        df = v_num / v_denom if v_denom > 0 else float('inf')

        se = np.sqrt(control_var/len(control) + treatment_var/len(treatment))
        critical_value = stats.t.ppf(1 - alpha/2, df)
        margin_error = critical_value * se
        ci_lower = mean_diff - margin_error
        ci_upper = mean_diff + margin_error

        control_std = np.std(control, ddof=1)
        treatment_std = np.std(treatment, ddof=1)

        # Determine if the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'Welch t-test (unequal variance)',
            'control_mean': control_mean,
            'treatment_mean': treatment_mean,
            'mean_difference': mean_diff,
            'percent_change': (mean_diff / control_mean * 100) if control_mean != 0 else float('inf'),
            't_statistic': t_stat,
            'p_value': p_value,
            'confidence_interval': (ci_lower, ci_upper),
            'confidence_level': 1 - alpha,
            'significant': significant,
            'degrees_of_freedom': df,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'control_std': control_std,
            'treatment_std': treatment_std,
            'effect_size': cohen_d(control, treatment)
        }
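
    # The effective degrees of freedom above follow the Welch-Satterthwaite equation:
    #
    #   df = (s1^2/n1 + s2^2/n2)^2 / [ (s1^2/n1)^2/(n1 - 1) + (s2^2/n2)^2/(n2 - 1) ]
    #
    # which reduces to the Student t-test's n1 + n2 - 2 when the variances and sample
    # sizes are equal.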

    def _run_mann_whitney(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run Mann-Whitney U test (non-parametric test)."""
        # Convert to numpy arrays
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Calculate basic statistics
        control_median = np.median(control)
        treatment_median = np.median(treatment)
        median_diff = treatment_median - control_median

        # Run the Mann-Whitney U test
        u_stat, p_value = stats.mannwhitneyu(control, treatment, alternative='two-sided')

        # Calculate common language effect size
        # (probability that a randomly selected value from treatment is greater than one from control)
        count = 0
        for tc in treatment:
            for cc in control:
                if tc > cc:
                    count += 1
        cles = count / (len(treatment) * len(control))

        # Calculate approximate confidence interval using bootstrap (requires SciPy >= 1.7)
        try:
            from scipy.stats import bootstrap

            # Samples arrive in the same order as `data` below, so return treatment minus
            # control to match the sign of median_diff
            def median_diff_func(c, t):
                return np.median(t) - np.median(c)

            res = bootstrap((control, treatment), median_diff_func,
                            confidence_level=1-alpha,
                            n_resamples=1000,
                            random_state=42)
            ci_lower, ci_upper = res.confidence_interval
        except ImportError:
            # If bootstrap is not available, return None for the confidence interval
            ci_lower, ci_upper = None, None
            logger.warning("SciPy bootstrap not available, skipping confidence interval calculation")

        # Determine if the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'Mann-Whitney U test',
            'control_median': control_median,
            'treatment_median': treatment_median,
            'median_difference': median_diff,
            'percent_change': (median_diff / control_median * 100) if control_median != 0 else float('inf'),
            'u_statistic': u_stat,
            'p_value': p_value,
            'confidence_interval': (ci_lower, ci_upper) if ci_lower is not None else None,
            'confidence_level': 1 - alpha,
            'significant': significant,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'effect_size': cles
        }
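
    # The O(n*m) loop above has a direct vectorized equivalent (a sketch; ties contribute
    # zero in both versions):
    #
    #   cles = float((treatment[:, None] > control[None, :]).mean())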

    def _run_anova(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run one-way ANOVA test."""
        # ANOVA typically involves multiple groups, but it can still be run with just two
        # Convert to numpy arrays
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Run one-way ANOVA
        f_stat, p_value = stats.f_oneway(control, treatment)

        # Calculate effect size (eta-squared)
        total_values = np.concatenate([control, treatment])
        grand_mean = np.mean(total_values)

        ss_total = np.sum((total_values - grand_mean) ** 2)
        ss_between = (len(control) * (np.mean(control) - grand_mean) ** 2 +
                      len(treatment) * (np.mean(treatment) - grand_mean) ** 2)

        eta_squared = ss_between / ss_total if ss_total > 0 else 0

        # Determine if the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'One-way ANOVA',
            'f_statistic': f_stat,
            'p_value': p_value,
            'significant': significant,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'effect_size': eta_squared,
            'effect_size_type': 'eta_squared'
        }
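
    # With exactly two groups, one-way ANOVA is equivalent to the equal-variance t-test:
    # F equals t squared and the two-sided p-values coincide. A quick check (sketch):
    #
    #   t, p_t = stats.ttest_ind(control, treatment, equal_var=True)
    #   f, p_f = stats.f_oneway(control, treatment)
    #   assert np.isclose(f, t**2) and np.isclose(p_t, p_f)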

    def _run_chi_square(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """Run chi-square test for categorical data."""
        # For chi-square, the values are assumed to represent counts in different categories
        # Convert to numpy arrays
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Check that the arrays are the same length (same number of categories)
        if len(control) != len(treatment):
            raise ValueError("Control and treatment must have the same number of categories for chi-square test")

        # Run chi-square test
        contingency_table = np.vstack([control, treatment])
        chi2_stat, p_value, dof, expected = stats.chi2_contingency(contingency_table)

        # Calculate Cramer's V as effect size
        n = np.sum(contingency_table)
        min_dim = min(contingency_table.shape) - 1
        cramers_v = np.sqrt(chi2_stat / (n * min_dim)) if n * min_dim > 0 else 0

        # Determine if the result is significant
        significant = p_value < alpha

        return {
            'test_type': 'Chi-square test',
            'chi2_statistic': chi2_stat,
            'p_value': p_value,
            'degrees_of_freedom': dof,
            'significant': significant,
            'effect_size': cramers_v,
            'effect_size_type': 'cramers_v'
        }
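
    # Here each list holds per-category counts, not raw observations. A hypothetical example
    # with three answer categories ("yes" / "no" / "maybe"):
    #
    #   control_values   = [30, 15, 5]    # counts per category in the control
    #   treatment_values = [20, 25, 5]    # counts per category in the treatment
    #
    # which yields the 2x3 contingency table passed to stats.chi2_contingency.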

    def check_assumptions(self, metric: str) -> Dict[str, Dict[str, Any]]:
        """
        Check statistical assumptions for the given metric across all treatments.

        Args:
            metric (str): The metric to check assumptions for.

        Returns:
            dict: Dictionary with results of assumption checks for each treatment.
        """
        # There is exactly one control experiment (enforced by _validate_input_data);
        # its metrics are nested under the experiment ID.
        control_metrics = next(iter(self.control_experiment_data.values()))
        if metric not in control_metrics:
            raise ValueError(f"Metric '{metric}' not found in control data")

        results = {}
        control_values = np.array(control_metrics[metric], dtype=float)

        # Check normality of control
        control_shapiro = stats.shapiro(control_values)
        control_normality = {
            'test': 'Shapiro-Wilk',
            'statistic': control_shapiro[0],
            'p_value': control_shapiro[1],
            'normal': control_shapiro[1] >= 0.05
        }

        for treatment_id, treatment_data in self.treatments_experiment_data.items():
            if metric not in treatment_data:
                logger.warning(f"Metric '{metric}' not found in treatment '{treatment_id}'")
                continue

            treatment_values = np.array(treatment_data[metric], dtype=float)

            # Check normality of treatment
            treatment_shapiro = stats.shapiro(treatment_values)
            treatment_normality = {
                'test': 'Shapiro-Wilk',
                'statistic': treatment_shapiro[0],
                'p_value': treatment_shapiro[1],
                'normal': treatment_shapiro[1] >= 0.05
            }

            # Check homogeneity of variance
            levene_test = stats.levene(control_values, treatment_values)
            variance_homogeneity = {
                'test': 'Levene',
                'statistic': levene_test[0],
                'p_value': levene_test[1],
                'equal_variance': levene_test[1] >= 0.05
            }

            # Store results and convert to JSON-serializable types
            results[treatment_id] = convert_to_serializable({
                'control_normality': control_normality,
                'treatment_normality': treatment_normality,
                'variance_homogeneity': variance_homogeneity,
                'recommended_test': self._recommend_test(control_normality['normal'],
                                                         treatment_normality['normal'],
                                                         variance_homogeneity['equal_variance'])
            })

        return results
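
    # A sketch of how assumption checking pairs with run_test (hypothetical metric name):
    #
    #   checks = tester.check_assumptions("satisfaction")
    #   recommended = checks["exp_a"]["recommended_test"]
    #   results = tester.run_test(test_type=recommended)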

    def _recommend_test(self, control_normal: bool, treatment_normal: bool, equal_variance: bool) -> str:
        """Recommend a statistical test based on assumption checks."""
        if control_normal and treatment_normal:
            if equal_variance:
                return 't_test'
            else:
                return 'welch_t_test'
        else:
            return 'mann_whitney'

    def _run_ks_test(self, control_values: list, treatment_values: list, alpha: float, **kwargs) -> Dict[str, Any]:
        """
        Run Kolmogorov-Smirnov test to compare distributions.

        This test compares the empirical cumulative distribution functions (ECDFs) of two samples
        to determine whether they come from the same distribution. It is particularly useful for:
        - categorical responses (e.g., "Yes"/"No"/"Maybe") once converted to ordinal values;
        - continuous data where you want to compare entire distributions, not just means;
        - detecting differences in distribution shape, spread, or location.
        """
        # Convert to numpy arrays
        control = np.array(control_values, dtype=float)
        treatment = np.array(treatment_values, dtype=float)

        # Calculate basic statistics
        control_median = np.median(control)
        treatment_median = np.median(treatment)
        control_mean = np.mean(control)
        treatment_mean = np.mean(treatment)

        # Run the Kolmogorov-Smirnov test
        ks_stat, p_value = stats.ks_2samp(control, treatment)

        # Calculate distribution characteristics
        control_std = np.std(control, ddof=1)
        treatment_std = np.std(treatment, ddof=1)

        # Use the KS statistic itself as the effect size: it ranges from 0 (identical
        # distributions) to 1 (completely separated distributions)
        effect_size = ks_stat

        # Additional distribution comparison metric: an overlap coefficient computed from
        # histograms of the two samples
        try:
            # Create shared bin edges spanning both samples
            combined_range = np.linspace(
                min(np.min(control), np.min(treatment)),
                max(np.max(control), np.max(treatment)),
                50
            )
            control_hist, _ = np.histogram(control, bins=combined_range, density=True)
            treatment_hist, _ = np.histogram(treatment, bins=combined_range, density=True)

            # Calculate overlap (intersection over union-like metric)
            overlap = np.sum(np.minimum(control_hist, treatment_hist)) / np.sum(np.maximum(control_hist, treatment_hist))
            overlap = overlap if not np.isnan(overlap) else 0.0
        except Exception:
            overlap = None

        # Calculate percentile differences for additional insight
        percentiles = [25, 50, 75, 90, 95]
        percentile_diffs = {}
        for p in percentiles:
            control_p = np.percentile(control, p)
            treatment_p = np.percentile(treatment, p)
            percentile_diffs[f"p{p}_diff"] = treatment_p - control_p

        # Determine significance
        significant = p_value < alpha

        return {
            'test_type': 'Kolmogorov-Smirnov test',
            'control_mean': control_mean,
            'treatment_mean': treatment_mean,
            'control_median': control_median,
            'treatment_median': treatment_median,
            'control_std': control_std,
            'treatment_std': treatment_std,
            'ks_statistic': ks_stat,
            'p_value': p_value,
            'significant': significant,
            'control_sample_size': len(control),
            'treatment_sample_size': len(treatment),
            'effect_size': effect_size,
            'overlap_coefficient': overlap,
            'percentile_differences': percentile_diffs,
            'interpretation': self._interpret_ks_result(ks_stat, significant),
            'confidence_level': 1 - alpha
        }
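
    # As the docstring notes, categorical responses can be fed to this test after an ordinal
    # encoding. A sketch with hypothetical labels and ordering:
    #
    #   order = {"No": 0, "Maybe": 1, "Yes": 2}
    #   control_values = [order[r] for r in control_responses]
    #   treatment_values = [order[r] for r in treatment_responses]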

    def _interpret_ks_result(self, ks_stat: float, significant: bool) -> str:
        """Provide an interpretation of KS test results."""
        if not significant:
            return "No significant difference between distributions"

        if ks_stat < 0.1:
            return "Very small difference between distributions"
        elif ks_stat < 0.25:
            return "Small difference between distributions"
        elif ks_stat < 0.5:
            return "Moderate difference between distributions"
        else:
            return "Large difference between distributions"


def cohen_d(x: Union[list, np.ndarray], y: Union[list, np.ndarray]) -> float:
    """
    Calculate Cohen's d effect size for two samples.

    Args:
        x: First sample
        y: Second sample

    Returns:
        float: Cohen's d effect size. Positive values mean the second sample's mean is
            larger than the first's.
    """
    nx = len(x)
    ny = len(y)

    # Convert to numpy arrays
    x = np.array(x, dtype=float)
    y = np.array(y, dtype=float)

    # Calculate means
    mx = np.mean(x)
    my = np.mean(y)

    # Calculate standard deviations
    sx = np.std(x, ddof=1)
    sy = np.std(y, ddof=1)

    # Pooled standard deviation
    pooled_sd = np.sqrt(((nx - 1) * sx**2 + (ny - 1) * sy**2) / (nx + ny - 2))

    # Cohen's d
    return (my - mx) / pooled_sd if pooled_sd > 0 else 0
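
# A quick worked example (hypothetical values): for x = [1, 2, 3] and y = [2, 3, 4], both
# samples have standard deviation 1, so the pooled sd is 1 and cohen_d(x, y) == 1.0.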


def convert_to_serializable(obj):
    """
    Recursively convert NumPy types to native Python types so that JSON serialization works.

    Args:
        obj: Any object that might contain NumPy types

    Returns:
        Object with NumPy types converted to Python native types
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, (np.number, np.bool_)):
        return obj.item()
    elif isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_serializable(i) for i in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_to_serializable(i) for i in obj)
    else:
        return obj
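

# A minimal smoke-test sketch. The experiment IDs, metric name, and synthetic values are
# hypothetical and serve only to exercise the API end to end.
if __name__ == "__main__":
    import json

    rng = np.random.default_rng(0)
    control = {"control": {"score": rng.normal(0.0, 1.0, 30).tolist()}}
    treatments = {"exp_a": {"score": rng.normal(0.5, 1.0, 30).tolist()}}

    tester = StatisticalTester(control, treatments)
    results = tester.run_test(test_type="welch_t_test", alpha=0.05)

    # run_test already converts NumPy scalars to native types, so this serializes directly
    print(json.dumps(results, indent=2))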