|
|
|
|
|
""" |
|
|
NB-Transformer P-value Calibration Validation Script |
|
|
|
|
|
This script validates that the NB-Transformer produces properly calibrated p-values |
|
|
under the null hypothesis (β = 0, no differential expression). Well-calibrated |
|
|
p-values should follow a Uniform(0,1) distribution under the null. |
|
|
|
|
|
The script: |
|
|
1. Generates null test cases (β = 0) |
|
|
2. Estimates parameters and computes p-values using Fisher information |
|
|
3. Creates QQ plots comparing observed vs expected quantiles |
|
|
4. Performs statistical tests for uniformity (Kolmogorov-Smirnov, Anderson-Darling) |
|
|
|
|
|
Usage: |
|
|
python validate_calibration.py --n_tests 10000 --output_dir results/ |
|
|
|
|
|
Expected Results: |
|
|
- Well-calibrated p-values should follow diagonal line in QQ plot |
|
|
- K-S and A-D tests should NOT be significant (p > 0.05) |
|
|
- False positive rate should be ~5% at α = 0.05 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
from typing import Dict, List, Tuple |
|
|
from scipy import stats |
|
|
import warnings |
|
|
|
|
|
|
|
|
# Optional dependency: the NB-Transformer model and its calibration helpers.
# When the package is missing we set a flag instead of crashing; main()
# checks TRANSFORMER_AVAILABLE and aborts with an install hint.
try:
    from nb_transformer import load_pretrained_model, validate_calibration, summarize_calibration_results
    TRANSFORMER_AVAILABLE = True
except ImportError:
    TRANSFORMER_AVAILABLE = False
    print("Warning: nb-transformer not available. Install with: pip install nb-transformer")
|
|
|
|
|
|
|
|
# Optional dependency: project plotting theme. Purely cosmetic — plots fall
# back to default matplotlib styling when it is absent.
try:
    from theme_nxn import theme_nxn, get_nxn_palette
    THEME_AVAILABLE = True
except ImportError:
    THEME_AVAILABLE = False
    print("Warning: theme_nxn not available, using default matplotlib styling")
|
|
|
|
|
|
|
|
def _simulate_group_counts(mean_expr: float, dispersion: float, lib_sizes: np.ndarray) -> List[int]:
    """Draw one negative-binomial count per library size for a single group.

    Uses numpy's (r, p) parameterization: r = 1/dispersion successes and
    p = r / (r + mean_count), which gives E[count] = lib_size * mean_expr.
    """
    r = 1.0 / dispersion  # invariant across samples; hoisted out of the loop
    counts = []
    for lib_size in lib_sizes:
        mean_count = lib_size * mean_expr
        p = r / (r + mean_count)
        counts.append(np.random.negative_binomial(r, p))
    return counts


def generate_null_test_data(n_tests: int = 10000, seed: int = 42) -> List[Dict]:
    """
    Generate test cases under the null hypothesis (β = 0).

    Each case draws random true parameters (log-mean expression, log-dispersion),
    random group sizes and library sizes, then simulates negative-binomial counts
    for two groups with identical mean expression (no differential expression).

    Args:
        n_tests: Number of independent null test cases to generate.
        seed: Seed for numpy's global RNG, for reproducibility.

    Returns:
        List of test-case dicts with true parameters, raw counts, library
        sizes, group sizes, and log10(CP10K + 1)-transformed expression.
    """
    print(f"Generating {n_tests} null hypothesis test cases (β = 0)...")

    np.random.seed(seed)
    test_cases = []

    for _ in range(n_tests):
        # True parameters; β = 0 enforces the null (no group effect).
        mu_true = np.random.normal(-1.0, 2.0)
        alpha_true = np.random.normal(-2.0, 1.0)
        beta_true = 0.0

        # Small, possibly unequal group sizes typical of DE experiments.
        n1 = np.random.randint(3, 10)
        n2 = np.random.randint(3, 10)

        # Log-normal library sizes: the -0.5*log(1.09) shift makes the
        # arithmetic mean of the log-normal equal 10,000.
        lib_sizes_1 = np.random.lognormal(np.log(10000) - 0.5*np.log(1.09),
                                          np.sqrt(np.log(1.09)), n1)
        lib_sizes_2 = np.random.lognormal(np.log(10000) - 0.5*np.log(1.09),
                                          np.sqrt(np.log(1.09)), n2)

        mean_expr = np.exp(mu_true)
        dispersion = np.exp(alpha_true)

        # Both groups share the same mean expression since β = 0.
        counts_1 = _simulate_group_counts(mean_expr, dispersion, lib_sizes_1)
        counts_2 = _simulate_group_counts(mean_expr, dispersion, lib_sizes_2)

        # log10(CP10K + 1) transform — the model's expected input scale.
        transformed_1 = [np.log10(1e4 * c / l + 1) for c, l in zip(counts_1, lib_sizes_1)]
        transformed_2 = [np.log10(1e4 * c / l + 1) for c, l in zip(counts_2, lib_sizes_2)]

        test_cases.append({
            'mu_true': mu_true,
            'beta_true': beta_true,
            'alpha_true': alpha_true,
            'counts_1': np.array(counts_1),
            'counts_2': np.array(counts_2),
            'lib_sizes_1': np.array(lib_sizes_1),
            'lib_sizes_2': np.array(lib_sizes_2),
            'transformed_1': np.array(transformed_1),
            'transformed_2': np.array(transformed_2),
            'n1': n1,
            'n2': n2
        })

    return test_cases
|
|
|
|
|
|
|
|
def compute_transformer_pvalues(model, test_cases: List[Dict]) -> np.ndarray:
    """
    Compute p-values using NB-Transformer predictions and Fisher information.

    For each test case, the model predicts NB-GLM parameters from the
    transformed expression values; a Wald test of H₀: β = 0 is then built
    from the Fisher-information standard error of β.

    Cases that fail (model or inference errors) yield np.nan instead of a
    fabricated value, so failures cannot masquerade as calibrated p-values.

    Args:
        model: Fitted NB-Transformer exposing predict_parameters().
        test_cases: Output of generate_null_test_data().

    Returns:
        Array of p-values aligned with test_cases; np.nan marks failed cases.
    """
    print("Computing p-values using NB-Transformer...")

    pvalues = []
    n_failed = 0

    for i, case in enumerate(test_cases):
        if i % 1000 == 0:
            print(f"  Processing case {i+1}/{len(test_cases)}...")

        try:
            # Predict (mu, beta, alpha) from the transformed expression values.
            params = model.predict_parameters(case['transformed_1'], case['transformed_2'])

            # Design for the Wald test: per-sample library sizes and the group
            # indicator x (0 for group 1, 1 for group 2), concatenated in order.
            lib_sizes = np.concatenate([case['lib_sizes_1'], case['lib_sizes_2']])
            x_indicators = np.concatenate([np.zeros(case['n1']), np.ones(case['n2'])])

            # Local import: only needed when the transformer path is active.
            from nb_transformer.inference import compute_fisher_weights, compute_standard_errors, compute_wald_statistics

            weights = compute_fisher_weights(
                params['mu'], params['beta'], params['alpha'],
                x_indicators, lib_sizes
            )

            se_beta = compute_standard_errors(x_indicators, weights)
            _, pvalue = compute_wald_statistics(params['beta'], se_beta)

            pvalues.append(pvalue)

        except Exception:
            # BUGFIX: the previous code substituted np.random.random() here,
            # which would make failed cases look perfectly calibrated and
            # silently corrupt the validation. Record NaN instead.
            n_failed += 1
            pvalues.append(np.nan)

    if n_failed:
        warnings.warn(f"{n_failed}/{len(test_cases)} test cases failed; "
                      f"their p-values are reported as NaN.")

    return np.array(pvalues)
|
|
|
|
|
|
|
|
def create_calibration_plot(pvalues: np.ndarray, output_dir: str):
    """Create QQ plot and histogram assessing p-value calibration.

    Saves 'calibration_qq_plot.png' (300 dpi) into output_dir and shows the
    figure interactively when a display is available.

    Args:
        pvalues: P-values from the null-hypothesis tests.
        output_dir: Existing directory to write the figure into.
    """
    # Use the project palette when the theme package is installed; otherwise
    # fall back to default matplotlib blue.
    if THEME_AVAILABLE:
        palette = get_nxn_palette()
        color = palette[0]
    else:
        color = '#1f77b4'

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: QQ plot of sorted p-values against Uniform(0,1) plotting
    # positions i/(n+1), which avoid the 0 and 1 endpoints.
    n = len(pvalues)
    expected_quantiles = np.arange(1, n+1) / (n+1)
    observed_quantiles = np.sort(pvalues)

    ax1.scatter(expected_quantiles, observed_quantiles, alpha=0.6, s=10, color=color)
    ax1.plot([0, 1], [0, 1], 'r--', alpha=0.8, linewidth=2, label='Perfect calibration')
    ax1.set_xlabel('Expected quantiles (Uniform)')
    ax1.set_ylabel('Observed quantiles')
    ax1.set_title('P-value Calibration QQ Plot')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(0, 1)
    ax1.set_ylim(0, 1)

    # Right panel: density histogram; a calibrated null is flat at density 1
    # (the Uniform(0,1) reference line).
    ax2.hist(pvalues, bins=50, density=True, alpha=0.7, color=color, edgecolor='white')
    ax2.axhline(y=1.0, color='r', linestyle='--', alpha=0.8, linewidth=2, label='Uniform(0,1)')
    ax2.set_xlabel('P-value')
    ax2.set_ylabel('Density')
    ax2.set_title('P-value Distribution')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(0, 1)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'calibration_qq_plot.png'), dpi=300, bbox_inches='tight')
    plt.show()
    # BUGFIX: close the figure so repeated calls (and the earlier dead
    # "if THEME_AVAILABLE: pass" branch, now removed) don't leak open figures.
    plt.close(fig)
|
|
|
|
|
|
|
|
def print_calibration_summary(calibration_metrics: Dict, n_tests: int):
    """Pretty-print the calibration validation results to stdout.

    Args:
        calibration_metrics: Dict from validate_calibration(); must contain
            the K-S / A-D statistics, p-values, calibration flags, and the
            raw 'pvalues' array.
        n_tests: Number of null test cases that were evaluated.
    """
    # Derive every verdict string up front so the printing below is linear.
    ks_ok = calibration_metrics['is_calibrated_ks']
    ad_ok = calibration_metrics['is_calibrated_ad']
    ks_verdict = "✅ PASS" if ks_ok else "❌ FAIL"
    ad_verdict = "✅ PASS" if ad_ok else "❌ FAIL"

    alpha = 0.05
    observed_fpr = np.mean(calibration_metrics['pvalues'] < alpha)
    fpr_gap = abs(observed_fpr - alpha)
    fpr_verdict = "✅ GOOD" if fpr_gap < 0.01 else "⚠️ CONCERN"

    calibrated = ks_ok and ad_ok
    overall_verdict = "✅ WELL-CALIBRATED" if calibrated else "⚠️ POORLY CALIBRATED"

    banner = "=" * 80
    print("\n" + banner)
    print("NB-TRANSFORMER P-VALUE CALIBRATION VALIDATION")
    print(banner)

    print("\n📊 TEST DETAILS")
    print(f"  • Number of null tests: {n_tests:,}")
    print("  • Null hypothesis: β = 0 (no differential expression)")
    print("  • Expected: p-values ~ Uniform(0,1)")

    print("\n📈 STATISTICAL TESTS FOR UNIFORMITY")

    print("  Kolmogorov-Smirnov Test:")
    print(f"  • Statistic: {calibration_metrics['ks_statistic']:.4f}")
    print(f"  • P-value: {calibration_metrics['ks_pvalue']:.4f}")
    print(f"  • Result: {ks_verdict} (should be > 0.05 for good calibration)")

    print("\n  Anderson-Darling Test:")
    print(f"  • Statistic: {calibration_metrics['ad_statistic']:.4f}")
    print(f"  • P-value: ~{calibration_metrics['ad_pvalue']:.3f}")
    print(f"  • Result: {ad_verdict} (should be > 0.05 for good calibration)")

    print("\n📍 FALSE POSITIVE RATE")
    print(f"  • Observed FPR (α=0.05): {observed_fpr:.3f}")
    print(f"  • Expected FPR: {alpha:.3f}")
    print(f"  • Difference: {fpr_gap:.3f}")
    print(f"  • Assessment: {fpr_verdict} (should be ~0.05)")

    print("\n🎯 OVERALL CALIBRATION ASSESSMENT")
    print(f"  Result: {overall_verdict}")

    if calibrated:
        print("  • P-values follow expected uniform distribution under null")
        print("  • Statistical inference is valid and reliable")
        print("  • False positive rate is properly controlled")
    else:
        print("  • P-values deviate from uniform distribution")
        print("  • Statistical inference may be unreliable")
        print("  • Consider model recalibration")

    print("\n💡 INTERPRETATION")
    print("  • QQ plot should follow diagonal line for good calibration")
    print("  • Histogram should be approximately flat (uniform)")
    print("  • Statistical tests should NOT be significant (p > 0.05)")
|
|
|
|
|
|
|
|
def main():
    """Run the full calibration validation pipeline.

    Parses CLI arguments, generates null test cases, computes p-values with
    the pre-trained NB-Transformer, and writes the QQ plot, per-test p-value
    CSV, and a text summary into the output directory.

    Exits with status 1 when the nb-transformer package is not installed.
    """
    parser = argparse.ArgumentParser(description='Validate NB-Transformer p-value calibration')
    parser.add_argument('--n_tests', type=int, default=10000, help='Number of null test cases')
    parser.add_argument('--output_dir', type=str, default='calibration_results', help='Output directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if not TRANSFORMER_AVAILABLE:
        print("❌ nb-transformer not available. Please install: pip install nb-transformer")
        # BUGFIX: exit nonzero so CI / shell pipelines see the failure
        # (previously a bare return reported success via exit code 0).
        sys.exit(1)

    print("Loading pre-trained NB-Transformer...")
    model = load_pretrained_model()

    # 1. Simulate data under H₀: β = 0.
    test_cases = generate_null_test_data(args.n_tests, args.seed)

    # 2. Wald-test p-values from the transformer's parameter estimates.
    pvalues = compute_transformer_pvalues(model, test_cases)

    # 3. Uniformity tests (K-S, A-D) on the null p-values.
    calibration_metrics = validate_calibration(pvalues)

    # 4. QQ plot + histogram.
    create_calibration_plot(pvalues, args.output_dir)

    # 5. Human-readable console summary.
    print_calibration_summary(calibration_metrics, args.n_tests)

    # Persist per-test results for downstream analysis.
    results_df = pd.DataFrame({
        'test_id': range(len(pvalues)),
        'pvalue': pvalues,
        'mu_true': [case['mu_true'] for case in test_cases],
        'alpha_true': [case['alpha_true'] for case in test_cases],
        'n1': [case['n1'] for case in test_cases],
        'n2': [case['n2'] for case in test_cases]
    })

    results_df.to_csv(os.path.join(args.output_dir, 'calibration_pvalues.csv'), index=False)

    summary_text = summarize_calibration_results(calibration_metrics)
    with open(os.path.join(args.output_dir, 'calibration_summary.txt'), 'w') as f:
        f.write(summary_text)

    print(f"\n💾 Results saved to {args.output_dir}/")
|
|
|
|
|
|
|
|
# Script entry point: run the pipeline only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()