#!/usr/bin/env python
"""
NB-Transformer P-value Calibration Validation Script

This script validates that the NB-Transformer produces properly calibrated
p-values under the null hypothesis (β = 0, no differential expression).
Well-calibrated p-values should follow a Uniform(0,1) distribution under
the null.

The script:
1. Generates null test cases (β = 0)
2. Estimates parameters and computes p-values using Fisher information
3. Creates QQ plots comparing observed vs expected quantiles
4. Performs statistical tests for uniformity (Kolmogorov-Smirnov, Anderson-Darling)

Usage:
    python validate_calibration.py --n_tests 10000 --output_dir results/

Expected Results:
- Well-calibrated p-values should follow diagonal line in QQ plot
- K-S and A-D tests should NOT be significant (p > 0.05)
- False positive rate should be ~5% at α = 0.05
"""

import os
import sys
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
from scipy import stats
import warnings

# Import nb-transformer
try:
    from nb_transformer import load_pretrained_model, validate_calibration, summarize_calibration_results
    TRANSFORMER_AVAILABLE = True
except ImportError:
    TRANSFORMER_AVAILABLE = False
    print("Warning: nb-transformer not available. Install with: pip install nb-transformer")

# Import plotting theme
try:
    from theme_nxn import theme_nxn, get_nxn_palette
    THEME_AVAILABLE = True
except ImportError:
    THEME_AVAILABLE = False
    print("Warning: theme_nxn not available, using default matplotlib styling")

# Library-size lognormal parameters: mean 10,000 with variance factor 1.09.
# The -0.5*log(1.09) shift makes the lognormal's *mean* (not median) equal 10,000.
_LIB_SIZE_LOG_MEAN = np.log(10000) - 0.5 * np.log(1.09)
_LIB_SIZE_LOG_SD = np.sqrt(np.log(1.09))


def _simulate_nb_counts(lib_sizes: np.ndarray, mean_expr: float, dispersion: float) -> List[int]:
    """
    Draw one negative-binomial count per library size.

    Uses the (r, p) parameterization where r = 1/dispersion and
    p = r / (r + mean_count), so E[count] = lib_size * mean_expr.

    Args:
        lib_sizes: Per-sample library sizes.
        mean_expr: Mean expression level (linear scale).
        dispersion: NB dispersion (linear scale).

    Returns:
        List of simulated counts, one per library size.
    """
    r = 1.0 / dispersion
    counts = []
    for lib_size in lib_sizes:
        mean_count = lib_size * mean_expr
        p = r / (r + mean_count)
        counts.append(np.random.negative_binomial(r, p))
    return counts


def generate_null_test_data(n_tests: int = 10000, seed: int = 42) -> List[Dict]:
    """
    Generate test cases under null hypothesis (β = 0).

    Each case samples a base mean and dispersion, a random two-condition
    design (3-9 samples per condition), lognormal library sizes, and NB
    counts with the SAME mean expression in both conditions (β = 0).
    Counts are also log10-CP10K transformed for the transformer input.

    Args:
        n_tests: Number of null test cases to generate.
        seed: Random seed for reproducibility (legacy np.random API is
            kept so existing runs reproduce).

    Returns:
        List of test cases with β = 0 (no differential expression).
    """
    print(f"Generating {n_tests} null hypothesis test cases (β = 0)...")
    np.random.seed(seed)

    test_cases = []
    for i in range(n_tests):
        # Sample parameters under null
        mu_true = np.random.normal(-1.0, 2.0)    # Base mean (log scale)
        alpha_true = np.random.normal(-2.0, 1.0) # Dispersion (log scale)
        beta_true = 0.0                          # NULL HYPOTHESIS: no differential expression

        # Random experimental design (3-9 samples per condition)
        n1 = np.random.randint(3, 10)
        n2 = np.random.randint(3, 10)

        # Sample library sizes
        lib_sizes_1 = np.random.lognormal(_LIB_SIZE_LOG_MEAN, _LIB_SIZE_LOG_SD, n1)
        lib_sizes_2 = np.random.lognormal(_LIB_SIZE_LOG_MEAN, _LIB_SIZE_LOG_SD, n2)

        # Generate counts under null (same mean expression in both conditions)
        mean_expr = np.exp(mu_true)
        dispersion = np.exp(alpha_true)

        # Both conditions share the same mean expression (β = 0)
        counts_1 = _simulate_nb_counts(lib_sizes_1, mean_expr, dispersion)
        counts_2 = _simulate_nb_counts(lib_sizes_2, mean_expr, dispersion)

        # Transform data for transformer: log10(CP10K + 1)
        transformed_1 = [np.log10(1e4 * c / l + 1) for c, l in zip(counts_1, lib_sizes_1)]
        transformed_2 = [np.log10(1e4 * c / l + 1) for c, l in zip(counts_2, lib_sizes_2)]

        test_cases.append({
            'mu_true': mu_true,
            'beta_true': beta_true,  # Always 0 under null
            'alpha_true': alpha_true,
            'counts_1': np.array(counts_1),
            'counts_2': np.array(counts_2),
            'lib_sizes_1': np.array(lib_sizes_1),
            'lib_sizes_2': np.array(lib_sizes_2),
            'transformed_1': np.array(transformed_1),
            'transformed_2': np.array(transformed_2),
            'n1': n1,
            'n2': n2
        })

    return test_cases


def compute_transformer_pvalues(model, test_cases: List[Dict]) -> np.ndarray:
    """
    Compute p-values using NB-Transformer predictions and Fisher information.

    For each case, the model estimates (mu, beta, alpha); a Wald statistic
    for H₀: β = 0 is then formed from the Fisher-information standard error.

    Args:
        model: Pre-trained NB-Transformer with a `predict_parameters` method.
        test_cases: Output of `generate_null_test_data`.

    Returns:
        Array of p-values for the null hypothesis test H₀: β = 0.
        (Return type fixed: this has always been an ndarray, not List[float].)
    """
    print("Computing p-values using NB-Transformer...")

    # Hoisted out of the loop: this import is loop-invariant.
    from nb_transformer.inference import compute_fisher_weights, compute_standard_errors, compute_wald_statistics

    pvalues = []
    for i, case in enumerate(test_cases):
        if i % 1000 == 0:
            print(f"  Processing case {i+1}/{len(test_cases)}...")

        try:
            # Get parameter estimates
            params = model.predict_parameters(case['transformed_1'], case['transformed_2'])

            # Prepare data for Fisher information calculation
            counts = np.concatenate([case['counts_1'], case['counts_2']])
            lib_sizes = np.concatenate([case['lib_sizes_1'], case['lib_sizes_2']])
            x_indicators = np.concatenate([np.zeros(case['n1']), np.ones(case['n2'])])

            # Compute Fisher information and p-value
            weights = compute_fisher_weights(
                params['mu'], params['beta'], params['alpha'],
                x_indicators, lib_sizes
            )
            se_beta = compute_standard_errors(x_indicators, weights)
            wald_stat, pvalue = compute_wald_statistics(params['beta'], se_beta)
            pvalues.append(pvalue)
        except Exception as e:
            # If computation fails, fall back to a random p-value so the
            # calibration summary stays the right length — but warn loudly,
            # since silent substitution would mask systematic failures.
            warnings.warn(f"P-value computation failed for case {i}: {e}; "
                          f"substituting a random p-value.")
            pvalues.append(np.random.random())

    return np.array(pvalues)


def create_calibration_plot(pvalues: np.ndarray, output_dir: str):
    """
    Create QQ plot and histogram for p-value calibration assessment.

    Saves `calibration_qq_plot.png` into `output_dir` and shows the figure.

    Args:
        pvalues: Array of p-values from the null simulation.
        output_dir: Directory for the saved figure (must already exist).
    """
    if THEME_AVAILABLE:
        palette = get_nxn_palette()
        color = palette[0]
    else:
        color = '#1f77b4'

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # QQ plot: expected i/(n+1) quantiles vs sorted observed p-values
    n = len(pvalues)
    expected_quantiles = np.arange(1, n+1) / (n+1)
    observed_quantiles = np.sort(pvalues)

    ax1.scatter(expected_quantiles, observed_quantiles, alpha=0.6, s=10, color=color)
    ax1.plot([0, 1], [0, 1], 'r--', alpha=0.8, linewidth=2, label='Perfect calibration')
    ax1.set_xlabel('Expected quantiles (Uniform)')
    ax1.set_ylabel('Observed quantiles')
    ax1.set_title('P-value Calibration QQ Plot')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(0, 1)
    ax1.set_ylim(0, 1)

    # Histogram: should be flat at density 1 under good calibration
    ax2.hist(pvalues, bins=50, density=True, alpha=0.7, color=color, edgecolor='white')
    ax2.axhline(y=1.0, color='r', linestyle='--', alpha=0.8, linewidth=2, label='Uniform(0,1)')
    ax2.set_xlabel('P-value')
    ax2.set_ylabel('Density')
    ax2.set_title('P-value Distribution')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(0, 1)

    if THEME_AVAILABLE:
        pass  # Custom theme would be applied here

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'calibration_qq_plot.png'), dpi=300, bbox_inches='tight')
    plt.show()


def print_calibration_summary(calibration_metrics: Dict, n_tests: int):
    """
    Print summary of calibration results.

    Args:
        calibration_metrics: Dict from `validate_calibration` containing
            KS/AD statistics, p-values, pass flags, and the raw p-values.
        n_tests: Number of null tests performed (display only).
    """
    print("\n" + "="*80)
    print("NB-TRANSFORMER P-VALUE CALIBRATION VALIDATION")
    print("="*80)

    print(f"\n📊 TEST DETAILS")
    print(f"   • Number of null tests: {n_tests:,}")
    print(f"   • Null hypothesis: β = 0 (no differential expression)")
    print(f"   • Expected: p-values ~ Uniform(0,1)")

    print(f"\n📈 STATISTICAL TESTS FOR UNIFORMITY")

    # Kolmogorov-Smirnov test
    ks_result = "✅ PASS" if calibration_metrics['is_calibrated_ks'] else "❌ FAIL"
    print(f"   Kolmogorov-Smirnov Test:")
    print(f"   • Statistic: {calibration_metrics['ks_statistic']:.4f}")
    print(f"   • P-value: {calibration_metrics['ks_pvalue']:.4f}")
    print(f"   • Result: {ks_result} (should be > 0.05 for good calibration)")

    # Anderson-Darling test
    ad_result = "✅ PASS" if calibration_metrics['is_calibrated_ad'] else "❌ FAIL"
    print(f"\n   Anderson-Darling Test:")
    print(f"   • Statistic: {calibration_metrics['ad_statistic']:.4f}")
    print(f"   • P-value: ~{calibration_metrics['ad_pvalue']:.3f}")
    print(f"   • Result: {ad_result} (should be > 0.05 for good calibration)")

    # False positive rate at the conventional α = 0.05 level
    alpha_level = 0.05
    fpr = np.mean(calibration_metrics['pvalues'] < alpha_level)
    fpr_expected = alpha_level
    fpr_result = "✅ GOOD" if abs(fpr - fpr_expected) < 0.01 else "⚠️ CONCERN"

    print(f"\n📍 FALSE POSITIVE RATE")
    print(f"   • Observed FPR (α=0.05): {fpr:.3f}")
    print(f"   • Expected FPR: {fpr_expected:.3f}")
    print(f"   • Difference: {abs(fpr - fpr_expected):.3f}")
    print(f"   • Assessment: {fpr_result} (should be ~0.05)")

    # Overall calibration assessment: both uniformity tests must pass
    overall_calibrated = calibration_metrics['is_calibrated_ks'] and calibration_metrics['is_calibrated_ad']
    overall_result = "✅ WELL-CALIBRATED" if overall_calibrated else "⚠️ POORLY CALIBRATED"

    print(f"\n🎯 OVERALL CALIBRATION ASSESSMENT")
    print(f"   Result: {overall_result}")

    if overall_calibrated:
        print(f"   • P-values follow expected uniform distribution under null")
        print(f"   • Statistical inference is valid and reliable")
        print(f"   • False positive rate is properly controlled")
    else:
        print(f"   • P-values deviate from uniform distribution")
        print(f"   • Statistical inference may be unreliable")
        print(f"   • Consider model recalibration")

    print(f"\n💡 INTERPRETATION")
    print(f"   • QQ plot should follow diagonal line for good calibration")
    print(f"   • Histogram should be approximately flat (uniform)")
    print(f"   • Statistical tests should NOT be significant (p > 0.05)")


def main():
    """CLI entry point: simulate null data, compute p-values, validate and report."""
    parser = argparse.ArgumentParser(description='Validate NB-Transformer p-value calibration')
    parser.add_argument('--n_tests', type=int, default=10000, help='Number of null test cases')
    parser.add_argument('--output_dir', type=str, default='calibration_results', help='Output directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Check dependencies before doing any work
    if not TRANSFORMER_AVAILABLE:
        print("❌ nb-transformer not available. Please install: pip install nb-transformer")
        return

    # Load pre-trained model
    print("Loading pre-trained NB-Transformer...")
    model = load_pretrained_model()

    # Generate null test data
    test_cases = generate_null_test_data(args.n_tests, args.seed)

    # Compute p-values
    pvalues = compute_transformer_pvalues(model, test_cases)

    # Validate calibration
    calibration_metrics = validate_calibration(pvalues)

    # Create plots
    create_calibration_plot(pvalues, args.output_dir)

    # Print summary
    print_calibration_summary(calibration_metrics, args.n_tests)

    # Save per-test results for downstream analysis
    results_df = pd.DataFrame({
        'test_id': range(len(pvalues)),
        'pvalue': pvalues,
        'mu_true': [case['mu_true'] for case in test_cases],
        'alpha_true': [case['alpha_true'] for case in test_cases],
        'n1': [case['n1'] for case in test_cases],
        'n2': [case['n2'] for case in test_cases]
    })
    results_df.to_csv(os.path.join(args.output_dir, 'calibration_pvalues.csv'), index=False)

    # Save summary
    summary_text = summarize_calibration_results(calibration_metrics)
    with open(os.path.join(args.output_dir, 'calibration_summary.txt'), 'w') as f:
        f.write(summary_text)

    print(f"\n💾 Results saved to {args.output_dir}/")


if __name__ == '__main__':
    main()