nb-transformer / examples /validate_calibration.py
valsv's picture
Upload folder using huggingface_hub
ccd282b verified
#!/usr/bin/env python
"""
NB-Transformer P-value Calibration Validation Script
This script validates that the NB-Transformer produces properly calibrated p-values
under the null hypothesis (β = 0, no differential expression). Well-calibrated
p-values should follow a Uniform(0,1) distribution under the null.
The script:
1. Generates null test cases (β = 0)
2. Estimates parameters and computes p-values using Fisher information
3. Creates QQ plots comparing observed vs expected quantiles
4. Performs statistical tests for uniformity (Kolmogorov-Smirnov, Anderson-Darling)
Usage:
python validate_calibration.py --n_tests 10000 --output_dir results/
Expected Results:
- Well-calibrated p-values should follow diagonal line in QQ plot
- K-S and A-D tests should NOT be significant (p > 0.05)
- False positive rate should be ~5% at α = 0.05
"""
import os
import sys
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
from scipy import stats
import warnings
# Optional dependency: the nb-transformer package. The flag records whether the
# import succeeded so later code can bail out instead of hitting a NameError.
try:
    from nb_transformer import load_pretrained_model, validate_calibration, summarize_calibration_results
except ImportError:
    TRANSFORMER_AVAILABLE = False
    print("Warning: nb-transformer not available. Install with: pip install nb-transformer")
else:
    TRANSFORMER_AVAILABLE = True

# Optional dependency: the plotting theme. Plots fall back to default
# matplotlib styling when it is missing.
try:
    from theme_nxn import theme_nxn, get_nxn_palette
except ImportError:
    THEME_AVAILABLE = False
    print("Warning: theme_nxn not available, using default matplotlib styling")
else:
    THEME_AVAILABLE = True
def generate_null_test_data(n_tests: int = 10000, seed: int = 42) -> List[Dict]:
    """
    Simulate negative-binomial count data in which both conditions share one mean.

    Each case draws a log-scale base mean and log-scale dispersion, a random
    number of samples per condition (3-9), per-sample library sizes, and then
    counts for both conditions from the same mean expression (β = 0).

    Args:
        n_tests: Number of null cases to simulate.
        seed: Passed to ``np.random.seed``; note this reseeds NumPy's global
            random state.

    Returns:
        List of test cases with β = 0 (no differential expression). Each case
        is a dict holding the true parameters, raw counts, library sizes,
        log10-transformed values and sample counts for both conditions.
    """
    print(f"Generating {n_tests} null hypothesis test cases (β = 0)...")
    np.random.seed(seed)

    # Log-normal library sizes: with sigma^2 = log(1.09) and the -0.5*sigma^2
    # shift on the log-mean, the distribution's mean is 10000.
    lib_log_mean = np.log(10000) - 0.5*np.log(1.09)
    lib_log_sigma = np.sqrt(np.log(1.09))

    def draw_counts(library_sizes, expression, size_param):
        # One scalar draw per sample, in sample order, so the global random
        # stream is consumed in the same sequence for every case.
        drawn = []
        for depth in library_sizes:
            expected_count = depth * expression
            success_prob = size_param / (size_param + expected_count)
            drawn.append(np.random.negative_binomial(size_param, success_prob))
        return drawn

    def to_model_scale(raw_counts, library_sizes):
        # log10 of counts rescaled to a 1e4 library, plus a pseudocount of 1.
        return [np.log10(1e4 * c / l + 1) for c, l in zip(raw_counts, library_sizes)]

    simulated = []
    for _ in range(n_tests):
        # Sample parameters under the null (β is fixed, not sampled).
        log_mean = np.random.normal(-1.0, 2.0)
        log_dispersion = np.random.normal(-2.0, 1.0)
        null_effect = 0.0

        # Random experimental design (3-9 samples per condition).
        size_a = np.random.randint(3, 10)
        size_b = np.random.randint(3, 10)

        depths_a = np.random.lognormal(lib_log_mean, lib_log_sigma, size_a)
        depths_b = np.random.lognormal(lib_log_mean, lib_log_sigma, size_b)

        # NumPy's negative binomial takes (n, p); n = 1 / dispersion here.
        expression = np.exp(log_mean)
        size_param = 1.0 / np.exp(log_dispersion)

        # Same mean expression for both conditions because β = 0.
        raw_a = draw_counts(depths_a, expression, size_param)
        raw_b = draw_counts(depths_b, expression, size_param)

        simulated.append({
            'mu_true': log_mean,
            'beta_true': null_effect,
            'alpha_true': log_dispersion,
            'counts_1': np.array(raw_a),
            'counts_2': np.array(raw_b),
            'lib_sizes_1': np.array(depths_a),
            'lib_sizes_2': np.array(depths_b),
            'transformed_1': np.array(to_model_scale(raw_a, depths_a)),
            'transformed_2': np.array(to_model_scale(raw_b, depths_b)),
            'n1': size_a,
            'n2': size_b
        })
    return simulated
def compute_transformer_pvalues(model, test_cases: List[Dict]) -> np.ndarray:
    """
    Compute p-values using NB-Transformer predictions and Fisher information.

    For each case the model's parameter estimates are passed through
    ``compute_fisher_weights`` -> ``compute_standard_errors`` ->
    ``compute_wald_statistics`` from ``nb_transformer.inference``.

    If a case raises, a random Uniform(0,1) p-value is substituted so the
    output stays aligned with ``test_cases``. Substituted values are uniform
    by construction and therefore make calibration look better than it is, so
    the number of failures is reported via a ``RuntimeWarning``.

    Args:
        model: Object exposing ``predict_parameters(transformed_1, transformed_2)``
            that returns a mapping with 'mu', 'beta' and 'alpha'.
        test_cases: Cases providing 'transformed_1', 'transformed_2',
            'lib_sizes_1', 'lib_sizes_2', 'n1' and 'n2'.

    Returns:
        Array of p-values for the null hypothesis test H₀: β = 0, one per case.

    Raises:
        ImportError: If ``nb_transformer.inference`` cannot be imported.
    """
    # Imported once, outside the per-case try block: a missing inference module
    # is a setup error and must not be converted into random p-values.
    from nb_transformer.inference import (
        compute_fisher_weights,
        compute_standard_errors,
        compute_wald_statistics,
    )

    print("Computing p-values using NB-Transformer...")
    n_cases = len(test_cases)
    pvalues = []
    failures = []  # (case index, exception) for every case that fell back
    for i, case in enumerate(test_cases):
        if i % 1000 == 0:
            print(f" Processing case {i+1}/{n_cases}...")
        try:
            # Get parameter estimates
            params = model.predict_parameters(case['transformed_1'], case['transformed_2'])
            # Design for the Fisher information: condition 1 coded 0, condition 2 coded 1
            lib_sizes = np.concatenate([case['lib_sizes_1'], case['lib_sizes_2']])
            x_indicators = np.concatenate([np.zeros(case['n1']), np.ones(case['n2'])])
            weights = compute_fisher_weights(
                params['mu'], params['beta'], params['alpha'],
                x_indicators, lib_sizes
            )
            se_beta = compute_standard_errors(x_indicators, weights)
            _, pvalue = compute_wald_statistics(params['beta'], se_beta)
        except Exception as exc:
            # Best-effort fallback (this should be rare); recorded so it is not silent.
            failures.append((i, exc))
            pvalue = np.random.random()
        pvalues.append(pvalue)

    if failures:
        first_index, first_error = failures[0]
        warnings.warn(
            f"{len(failures)} of {n_cases} test cases failed during p-value "
            f"computation and were assigned random Uniform(0,1) p-values; this "
            f"biases the calibration check toward passing. "
            f"First failure (case {first_index}): {first_error!r}",
            RuntimeWarning,
            stacklevel=2,
        )
    return np.array(pvalues)
def create_calibration_plot(pvalues: np.ndarray, output_dir: str):
    """Draw the calibration figure: a uniform QQ plot next to a p-value histogram.

    The figure is written to ``calibration_qq_plot.png`` inside ``output_dir``
    and then displayed with ``plt.show()``.

    Args:
        pvalues: P-values to assess against Uniform(0,1).
        output_dir: Directory the PNG is saved into.
    """
    # First theme palette colour when the theme imported, else matplotlib's default blue.
    color = get_nxn_palette()[0] if THEME_AVAILABLE else '#1f77b4'

    fig, (qq_ax, hist_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: sorted p-values against the plotting positions i / (n + 1).
    n_points = len(pvalues)
    uniform_positions = np.arange(1, n_points+1) / (n_points+1)
    sorted_pvalues = np.sort(pvalues)
    qq_ax.scatter(uniform_positions, sorted_pvalues, alpha=0.6, s=10, color=color)
    qq_ax.plot([0, 1], [0, 1], 'r--', alpha=0.8, linewidth=2, label='Perfect calibration')
    qq_ax.set(
        xlabel='Expected quantiles (Uniform)',
        ylabel='Observed quantiles',
        title='P-value Calibration QQ Plot',
    )
    qq_ax.legend()
    qq_ax.grid(True, alpha=0.3)
    qq_ax.set_xlim(0, 1)
    qq_ax.set_ylim(0, 1)

    # Right panel: density histogram with the Uniform(0,1) density (y = 1) as reference.
    hist_ax.hist(pvalues, bins=50, density=True, alpha=0.7, color=color, edgecolor='white')
    hist_ax.axhline(y=1.0, color='r', linestyle='--', alpha=0.8, linewidth=2, label='Uniform(0,1)')
    hist_ax.set(
        xlabel='P-value',
        ylabel='Density',
        title='P-value Distribution',
    )
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)
    hist_ax.set_xlim(0, 1)

    # No theme-specific styling is applied here beyond the colour chosen above.
    plt.tight_layout()
    figure_path = os.path.join(output_dir, 'calibration_qq_plot.png')
    plt.savefig(figure_path, dpi=300, bbox_inches='tight')
    plt.show()
def print_calibration_summary(calibration_metrics: Dict, n_tests: int):
    """Print a human-readable calibration report to stdout.

    Args:
        calibration_metrics: Mapping read for the keys 'is_calibrated_ks',
            'ks_statistic', 'ks_pvalue', 'is_calibrated_ad', 'ad_statistic',
            'ad_pvalue' and 'pvalues' (the latter compared element-wise
            against 0.05, so it must support array comparison).
        n_tests: Number of null tests, shown in the report header.
    """
    def emit(*lines):
        # One print call per line keeps the output identical to line-by-line printing.
        for line in lines:
            print(line)

    banner = "="*80
    emit(
        "\n" + banner,
        "NB-TRANSFORMER P-VALUE CALIBRATION VALIDATION",
        banner,
    )
    emit(
        "\n📊 TEST DETAILS",
        f" • Number of null tests: {n_tests:,}",
        " • Null hypothesis: β = 0 (no differential expression)",
        " • Expected: p-values ~ Uniform(0,1)",
    )

    emit("\n📈 STATISTICAL TESTS FOR UNIFORMITY")

    # Kolmogorov-Smirnov test
    ks_ok = calibration_metrics['is_calibrated_ks']
    ks_verdict = "✅ PASS" if ks_ok else "❌ FAIL"
    emit(
        " Kolmogorov-Smirnov Test:",
        f" • Statistic: {calibration_metrics['ks_statistic']:.4f}",
        f" • P-value: {calibration_metrics['ks_pvalue']:.4f}",
        f" • Result: {ks_verdict} (should be > 0.05 for good calibration)",
    )

    # Anderson-Darling test
    ad_ok = calibration_metrics['is_calibrated_ad']
    ad_verdict = "✅ PASS" if ad_ok else "❌ FAIL"
    emit(
        "\n Anderson-Darling Test:",
        f" • Statistic: {calibration_metrics['ad_statistic']:.4f}",
        f" • P-value: ~{calibration_metrics['ad_pvalue']:.3f}",
        f" • Result: {ad_verdict} (should be > 0.05 for good calibration)",
    )

    # False positive rate: share of null p-values below the nominal level,
    # flagged when it is 0.01 or more away from that level.
    nominal_level = 0.05
    observed_rate = np.mean(calibration_metrics['pvalues'] < nominal_level)
    rate_gap = abs(observed_rate - nominal_level)
    if rate_gap < 0.01:
        rate_verdict = "✅ GOOD"
    else:
        rate_verdict = "⚠️ CONCERN"
    emit(
        "\n📍 FALSE POSITIVE RATE",
        f" • Observed FPR (α=0.05): {observed_rate:.3f}",
        f" • Expected FPR: {nominal_level:.3f}",
        f" • Difference: {rate_gap:.3f}",
        f" • Assessment: {rate_verdict} (should be ~0.05)",
    )

    # Overall verdict requires both uniformity tests to pass.
    emit("\n🎯 OVERALL CALIBRATION ASSESSMENT")
    if ks_ok and ad_ok:
        emit(
            " Result: ✅ WELL-CALIBRATED",
            " • P-values follow expected uniform distribution under null",
            " • Statistical inference is valid and reliable",
            " • False positive rate is properly controlled",
        )
    else:
        emit(
            " Result: ⚠️ POORLY CALIBRATED",
            " • P-values deviate from uniform distribution",
            " • Statistical inference may be unreliable",
            " • Consider model recalibration",
        )

    emit(
        "\n💡 INTERPRETATION",
        " • QQ plot should follow diagonal line for good calibration",
        " • Histogram should be approximately flat (uniform)",
        " • Statistical tests should NOT be significant (p > 0.05)",
    )
def main():
    """Run the calibration validation end to end from the command line.

    Parses ``--n_tests``, ``--output_dir`` and ``--seed``, then: loads the
    pre-trained model, simulates null test cases, computes p-values, validates
    calibration, plots, prints the summary, and writes
    ``calibration_pvalues.csv`` and ``calibration_summary.txt`` into the
    output directory. Returns early (after printing a message) when
    nb-transformer could not be imported.
    """
    parser = argparse.ArgumentParser(description='Validate NB-Transformer p-value calibration')
    parser.add_argument('--n_tests', type=int, default=10000, help='Number of null test cases')
    parser.add_argument('--output_dir', type=str, default='calibration_results', help='Output directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Check dependencies
    if not TRANSFORMER_AVAILABLE:
        print("❌ nb-transformer not available. Please install: pip install nb-transformer")
        return

    # Load pre-trained model
    print("Loading pre-trained NB-Transformer...")
    model = load_pretrained_model()

    # Generate null test data
    test_cases = generate_null_test_data(args.n_tests, args.seed)

    # Compute p-values
    pvalues = compute_transformer_pvalues(model, test_cases)

    # Validate calibration
    calibration_metrics = validate_calibration(pvalues)

    # Create plots
    create_calibration_plot(pvalues, args.output_dir)

    # Print summary
    print_calibration_summary(calibration_metrics, args.n_tests)

    # Save per-test results; p-values and test cases share the same order
    results_df = pd.DataFrame({
        'test_id': range(len(pvalues)),
        'pvalue': pvalues,
        'mu_true': [case['mu_true'] for case in test_cases],
        'alpha_true': [case['alpha_true'] for case in test_cases],
        'n1': [case['n1'] for case in test_cases],
        'n2': [case['n2'] for case in test_cases]
    })
    results_df.to_csv(os.path.join(args.output_dir, 'calibration_pvalues.csv'), index=False)

    # Save summary. The encoding is explicit so the write does not depend on
    # the platform's locale default, which may not be able to encode
    # non-ASCII text.
    # NOTE(review): the summary text presumably contains non-ASCII characters
    # like the report printed above (β, ✅) - confirm against
    # summarize_calibration_results.
    summary_text = summarize_calibration_results(calibration_metrics)
    with open(os.path.join(args.output_dir, 'calibration_summary.txt'), 'w', encoding='utf-8') as f:
        f.write(summary_text)

    print(f"\n💾 Results saved to {args.output_dir}/")
# Run the validation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()