#!/usr/bin/env python3
"""
NQR-SNN Out-of-Distribution Generalization Test
=================================================
Tests whether the model has learned NQR physics or just memorized
the synthetic generator's parametric family.

Run locally:
    python ood_generalization_test.py

Tests 8 realistic domain shifts that REAL NQR signals would have:
  1. Parameter extrapolation (signal params outside training bounds)
  2. Temperature-induced frequency drift (±10-30 kHz)
  3. MAPER contamination (magneto-acoustic ringing)
  4. Receiver dead time (first 100-500μs zeroed)
  5. Colored/correlated noise (1/f + narrowband RFI)
  6. Lorentzian lineshape (not Voigt — the real NQR shape)
  7. Multi-polymorph signals (two overlapping lines)
  8. Powder-broadened signals (inhomogeneous lineshape)

Each scenario is scored against an in-distribution control run. A scenario
whose accuracy falls more than 5 percentage points below the control is
counted as a generalization failure; if many scenarios fail, the model is
overfitting to the synthetic generator and would fail on real hardware data.

Expected runtime: ~5-10 min (uses pre-trained model from outputs/models/)
"""

import os
import sys
import time
import json
import numpy as np
import torch

# Make the sibling ``nqr_snn`` package importable when run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from nqr_snn import config
from nqr_snn.data.dataset import extract_features_batch
from nqr_snn.data.generator import generate_noise_only_at_power
from nqr_snn.snn.model import SpikingClassifier, SPIKINGJELLY_AVAILABLE
from nqr_snn.snn.encoder import DeterministicEncoder
from nqr_snn.snn.ensemble import SNNEnsemble
from nqr_snn.evaluation.metrics import full_report

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_PER_CLASS = 200  # signals per class per test
TEST_SNR_DB = -35  # SNR (dB) used for every scenario run by main()


# ═══════════════════════════════════════════════════════════════════
# SHARED HELPERS
# ═══════════════════════════════════════════════════════════════════

def _time_axis():
    """Return the sample times: sample index times ``config.SAMPLING_INTERVAL``."""
    return np.arange(config.SIGNAL_LENGTH) * config.SAMPLING_INTERVAL


def _voigt_fid(t, A, sigma, T2, nu, phi):
    """Return a complex FID with Gaussian × exponential decay envelope.

    Args:
        t: time axis (array).
        A: amplitude.
        sigma: Gaussian envelope width (same units as ``t``).
        T2: exponential decay constant (same units as ``t``).
        nu: frequency offset (cycles per unit of ``t``).
        phi: initial phase in radians.
    """
    return A * np.exp(-t**2/(2*sigma**2) - t/T2 + 1j*(2*np.pi*nu*t + phi))


def _mix_at_snr(clean, noise_raw, target_snr_db):
    """Return ``clean`` plus ``noise_raw`` rescaled to hit ``target_snr_db``.

    SNR is defined as mean |clean|² over mean |scaled noise|². If the raw
    noise has zero power it cannot be rescaled and is added unchanged.
    """
    signal_power = np.mean(np.abs(clean)**2)
    noise_power = np.mean(np.abs(noise_raw)**2)
    target_linear = 10 ** (target_snr_db / 10)
    desired_noise_power = signal_power / target_linear
    if noise_power > 0:
        noise_scale = np.sqrt(desired_noise_power / noise_power)
        return clean + noise_scale * noise_raw
    return clean + noise_raw


def _add_noise_at_snr(clean, target_snr_db, rng, noise_type="white"):
    """Add noise to achieve target SNR.

    Draws one noise record of ``noise_type`` from the project noise library
    using ``rng`` and mixes it into ``clean`` at ``target_snr_db``.
    """
    from nqr_snn.data.noise_sources import generate_noise_sample
    noise_raw = generate_noise_sample(noise_type, rng)
    return _mix_at_snr(clean, noise_raw, target_snr_db)


def _colored_noise_sample(rng, t, freqs):
    """Draw one unscaled colored-noise record with RFI tones and impulsive bursts.

    Used for BOTH classes of the colored-noise scenario so that the only
    systematic difference between them is the presence of the NQR signal.

    Args:
        rng: ``np.random.RandomState`` supplying all randomness.
        t: time axis; its length sets the record length.
        freqs: ``np.fft.fftfreq`` bins for that record length.

    Returns:
        Complex array of the same length as ``t``.
    """
    L = len(t)

    # Colored noise: shape white noise by a 1/f^(alpha/2) amplitude filter
    # (power ∝ 1/f^alpha); the +0.1 bounds the gain near DC.
    white = rng.randn(L) + 1j * rng.randn(L)
    alpha = rng.uniform(0.5, 1.5)  # pink to brown noise
    color_filter = np.ones(L)
    nonzero = freqs != 0
    color_filter[nonzero] = 1.0 / (np.abs(freqs[nonzero]) ** (alpha/2) + 0.1)
    colored = np.fft.ifft(np.fft.fft(white) * color_filter)

    # Narrowband RFI: 5-10 interferers (randint's upper bound is exclusive).
    n_rfi = rng.randint(5, 11)
    for _ in range(n_rfi):
        rfi_freq = rng.uniform(-5000, 5000)
        rfi_amp = rng.uniform(0.5, 3.0)
        rfi_phase = rng.uniform(-np.pi, np.pi)
        colored += rfi_amp * np.exp(1j * (2*np.pi*rfi_freq*t + rfi_phase))

    # Impulsive bursts: 2-5 per record, each clipped at the end of the record.
    n_impulses = rng.randint(2, 6)
    for _ in range(n_impulses):
        pos = rng.randint(0, L)
        width = rng.randint(1, 10)
        amp = rng.uniform(5, 20)
        span = min(width, L - pos)
        colored[pos:pos + span] += amp * (rng.randn(span) + 1j*rng.randn(span))

    return colored


# ═══════════════════════════════════════════════════════════════════
# SIGNAL GENERATORS — Realistic OOD Scenarios
# ═══════════════════════════════════════════════════════════════════

# --- TEST 1: Parameter Extrapolation ---
def gen_extrapolated_params(n, snr_db=-35, seed=42):
    """Signal params OUTSIDE training ranges.

    Training uses: A=[0.5,1.5], sigma=[0.001,0.01], T2=[0.001,0.01], nu=[-150,150]
    This uses:     A=[0.05,0.3] or [2.0,5.0], sigma=[0.02,0.05], T2=[0.02,0.05],
                   nu=[-500,-200] or [200,500]

    Returns a complex128 array of ``n`` noisy records.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    signals = []
    for _ in range(n):
        # Parameters outside training bounds
        A = rng.choice([rng.uniform(0.05, 0.3), rng.uniform(2.0, 5.0)])
        sigma = rng.uniform(0.02, 0.05)  # wider than training [0.001, 0.01]
        T2 = rng.uniform(0.02, 0.05)     # longer than training
        nu = rng.choice([rng.uniform(-500, -200), rng.uniform(200, 500)])  # beyond ±150
        phi = rng.uniform(-np.pi, np.pi)

        clean = _voigt_fid(t, A, sigma, T2, nu, phi)
        signals.append(_add_noise_at_snr(clean, snr_db, rng))
    return np.array(signals, dtype=np.complex128)


# --- TEST 2: Temperature Frequency Drift ---
def gen_temp_drift(n, snr_db=-35, seed=42):
    """Simulate temperature-induced frequency shift (±10-30 kHz).

    Real TNT shifts ~800 Hz/°C. Over -20 to +50°C range = ±28 kHz.
    Training uses nu=[-150, 150] Hz. This adds ±10000-30000 Hz shift.

    Returns a complex128 array of ``n`` noisy records.
    """
    # NOTE(review): if the sampling interval is ~18 μs (as the dead-time
    # comment in this file assumes), Nyquist is ~27.8 kHz and the largest
    # shifts alias back into band — confirm this is intended.
    rng = np.random.RandomState(seed)
    t = _time_axis()
    signals = []
    for _ in range(n):
        A = rng.uniform(0.5, 1.5)
        sigma = rng.uniform(0.001, 0.01)
        T2 = rng.uniform(0.001, 0.01)
        # Normal frequency PLUS large temperature drift
        nu_base = rng.uniform(-150, 150)
        temp_shift = rng.choice([-1, 1]) * rng.uniform(10000, 30000)  # ±10-30 kHz
        nu = nu_base + temp_shift
        phi = rng.uniform(-np.pi, np.pi)

        clean = _voigt_fid(t, A, sigma, T2, nu, phi)
        signals.append(_add_noise_at_snr(clean, snr_db, rng))
    return np.array(signals, dtype=np.complex128)


# --- TEST 3: MAPER Contamination ---
def gen_maper(n, snr_db=-35, seed=42):
    """Add Magneto-Acoustic/PiezoElectric Ringing (MAPER).

    MAPER is the #1 real-world artifact: coherent exponentially-decaying
    oscillation caused by metal objects in the RF field. Can be 10-100x
    stronger than the NQR signal.

    Returns a complex128 array of ``n`` noisy records.
    """
    # NOTE(review): the noise level is set against (NQR + MAPER) power, so the
    # NQR component alone sits well below ``snr_db``; and the noise-only class
    # used in main() contains no ringing. Confirm both are intended.
    rng = np.random.RandomState(seed)
    t = _time_axis()
    signals = []
    for _ in range(n):
        # Normal NQR signal
        A = rng.uniform(0.5, 1.5)
        sigma = rng.uniform(0.001, 0.01)
        T2 = rng.uniform(0.001, 0.01)
        nu = rng.uniform(-150, 150)
        phi = rng.uniform(-np.pi, np.pi)
        clean = _voigt_fid(t, A, sigma, T2, nu, phi)

        # MAPER: exponentially decaying sinusoid at different frequency
        maper_amp = rng.uniform(5, 50) * A      # 5-50x signal amplitude
        maper_freq = rng.uniform(500, 5000)     # different from NQR freq
        maper_tau = rng.uniform(0.0005, 0.005)  # 0.5-5 ms decay
        maper_phase = rng.uniform(-np.pi, np.pi)
        maper = maper_amp * np.exp(-t/maper_tau) * np.exp(1j*(2*np.pi*maper_freq*t + maper_phase))

        signals.append(_add_noise_at_snr(clean + maper, snr_db, rng))
    return np.array(signals, dtype=np.complex128)


# --- TEST 4: Receiver Dead Time ---
def gen_dead_time(n, snr_db=-35, seed=42):
    """Zero out first 100-500μs (receiver recovery after TX pulse).

    Real NQR receivers cannot capture signal during dead time.
    Training signals start at t=0 which is physically unrealizable.

    Only the clean signal is blanked; noise is added afterwards and therefore
    still fills the dead-time samples.

    Returns a complex128 array of ``n`` noisy records.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    signals = []
    for _ in range(n):
        A = rng.uniform(0.5, 1.5)
        sigma = rng.uniform(0.001, 0.01)
        T2 = rng.uniform(0.001, 0.01)
        nu = rng.uniform(-150, 150)
        phi = rng.uniform(-np.pi, np.pi)
        clean = _voigt_fid(t, A, sigma, T2, nu, phi)

        # Zero out dead time: 100-500 μs at 18μs sampling = 6-28 samples.
        # randint's upper bound is exclusive, hence 29 to include 28.
        dead_samples = rng.randint(6, 29)
        clean[:dead_samples] = 0.0

        signals.append(_add_noise_at_snr(clean, snr_db, rng))
    return np.array(signals, dtype=np.complex128)


# --- TEST 5: Colored Noise + RFI ---
def gen_colored_noise(n, snr_db=-35, seed=42):
    """Replace white Gaussian with realistic colored noise:
    - 1/f spectral shape from probe coil
    - Narrowband RFI spikes (power line harmonics)
    - Impulsive noise bursts

    Returns a complex128 array of ``n`` noisy records.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    # Frequency bins do not depend on the sample, so compute them once.
    freqs = np.fft.fftfreq(config.SIGNAL_LENGTH, d=config.SAMPLING_INTERVAL)
    signals = []
    for _ in range(n):
        A = rng.uniform(0.5, 1.5)
        sigma = rng.uniform(0.001, 0.01)
        T2 = rng.uniform(0.001, 0.01)
        nu = rng.uniform(-150, 150)
        phi = rng.uniform(-np.pi, np.pi)
        clean = _voigt_fid(t, A, sigma, T2, nu, phi)

        colored = _colored_noise_sample(rng, t, freqs)

        # Scale colored noise to achieve target SNR
        signals.append(_mix_at_snr(clean, colored, snr_db))
    return np.array(signals, dtype=np.complex128)


# --- TEST 6: Pure Lorentzian Lineshape (not Voigt) ---
def gen_lorentzian(n, snr_db=-35, seed=42):
    """Pure Lorentzian decay (no Gaussian component).

    Real NQR FID for single crystals is purely Lorentzian: exp(-t/T2).
    Training uses Voigt (Gaussian × Lorentzian). This tests if the model
    learned "exponential decay" or specifically "Voigt shape."

    Returns a complex128 array of ``n`` noisy records.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    signals = []
    for _ in range(n):
        A = rng.uniform(0.5, 1.5)
        T2 = rng.uniform(0.001, 0.01)
        nu = rng.uniform(-150, 150)
        phi = rng.uniform(-np.pi, np.pi)

        # Pure Lorentzian: NO Gaussian envelope (sigma → ∞)
        clean = A * np.exp(-t/T2 + 1j*(2*np.pi*nu*t + phi))
        signals.append(_add_noise_at_snr(clean, snr_db, rng))
    return np.array(signals, dtype=np.complex128)


# --- TEST 7: Multi-Polymorph (two overlapping lines) ---
def gen_multiline(n, snr_db=-35, seed=42):
    """Two NQR lines from different polymorphs/sites.

    RDX has 3 inequivalent nitrogen sites. TNT has α/β forms.
    Training uses single-line signals only.

    Returns a complex128 array of ``n`` noisy records.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    signals = []
    for _ in range(n):
        # Line 1
        A1 = rng.uniform(0.3, 1.0)
        T2_1 = rng.uniform(0.001, 0.01)
        nu1 = rng.uniform(-150, 0)
        sigma1 = rng.uniform(0.001, 0.01)
        phi1 = rng.uniform(-np.pi, np.pi)
        line1 = _voigt_fid(t, A1, sigma1, T2_1, nu1, phi1)

        # Line 2 (different freq, amplitude, T2)
        A2 = rng.uniform(0.2, 0.8)
        T2_2 = rng.uniform(0.002, 0.015)
        nu2 = rng.uniform(50, 300)  # separated from line 1
        sigma2 = rng.uniform(0.001, 0.01)
        phi2 = rng.uniform(-np.pi, np.pi)
        line2 = _voigt_fid(t, A2, sigma2, T2_2, nu2, phi2)

        signals.append(_add_noise_at_snr(line1 + line2, snr_db, rng))
    return np.array(signals, dtype=np.complex128)


# --- TEST 8: Powder Broadening (non-Voigt lineshape) ---
def gen_powder_broadened(n, snr_db=-35, seed=42):
    """Simulate powder broadening: sum of many crystallites at slightly
    different frequencies. Creates a non-Voigt, asymmetric lineshape.

    Real explosives are polycrystalline powders, not single crystals.

    Returns a complex128 array of ``n`` noisy records.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    signals = []
    for _ in range(n):
        A_total = rng.uniform(0.5, 1.5)
        T2 = rng.uniform(0.001, 0.01)
        sigma = rng.uniform(0.001, 0.01)
        nu_center = rng.uniform(-150, 150)
        phi = rng.uniform(-np.pi, np.pi)

        # Sum 20-50 crystallite orientations with frequency spread
        # (randint's upper bound is exclusive, hence 51 to include 50).
        n_crystallites = rng.randint(20, 51)
        freq_spread = rng.uniform(100, 500)  # Hz spread from powder
        # Asymmetric distribution (not Gaussian — more like Pake doublet)
        offsets = freq_spread * (rng.beta(2, 5, n_crystallites) - 0.3)

        clean = np.zeros(config.SIGNAL_LENGTH, dtype=np.complex128)
        for offset in offsets:
            # Each crystallite gets ±50% amplitude jitter around an equal share.
            amp_i = A_total / n_crystallites * rng.uniform(0.5, 1.5)
            clean += _voigt_fid(t, amp_i, sigma, T2, nu_center + offset, phi)

        signals.append(_add_noise_at_snr(clean, snr_db, rng))
    return np.array(signals, dtype=np.complex128)


# ═══════════════════════════════════════════════════════════════════
# NOISE-ONLY GENERATORS (matching power for fair comparison)
# ═══════════════════════════════════════════════════════════════════

def gen_noise_white(n, seed=99):
    """Standard white noise (same as training).

    Returns a complex128 array of ``n`` records produced by
    ``generate_noise_only_at_power("white", rng, 1.0)``.
    """
    rng = np.random.RandomState(seed)
    return np.array(
        [generate_noise_only_at_power("white", rng, 1.0) for _ in range(n)],
        dtype=np.complex128,
    )


def gen_noise_colored(n, seed=99):
    """Colored noise matching the colored test scenario.

    Draws from the same noise model as ``gen_colored_noise`` (1/f shaping,
    RFI tones AND impulsive bursts) so the noise-only class differs from the
    signal-present class only by the absence of the NQR line. Each record is
    normalized to unit mean power.

    Returns a complex128 array of ``n`` records.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    freqs = np.fft.fftfreq(config.SIGNAL_LENGTH, d=config.SAMPLING_INTERVAL)
    signals = []
    for _ in range(n):
        colored = _colored_noise_sample(rng, t, freqs)
        power = np.mean(np.abs(colored)**2)
        if power > 0:
            colored = colored * np.sqrt(1.0 / power)
        signals.append(colored)
    return np.array(signals, dtype=np.complex128)


# ═══════════════════════════════════════════════════════════════════
# MAIN TEST RUNNER
# ═══════════════════════════════════════════════════════════════════

def _evaluate(ensemble, encoder, signals_present, signals_noise):
    """Score the ensemble on one signal-present / noise-only pair of sets.

    Signal-present records are labelled 1 and placed first; noise-only
    records are labelled 0 and placed after them.

    Returns:
        ``(report, probs)`` where ``report`` is the value returned by
        ``full_report`` and ``probs`` is the ensemble's mean prediction as a
        NumPy array in the same order as the concatenated inputs.
    """
    all_signals = np.concatenate([signals_present, signals_noise], axis=0)
    all_labels = np.array([1]*len(signals_present) + [0]*len(signals_noise))

    feats = extract_features_batch(all_signals)
    x_seq = encoder.encode(torch.from_numpy(feats)).to(DEVICE)
    with torch.no_grad():
        mean_p, _ = ensemble.predict(x_seq)
    probs = mean_p.cpu().numpy()
    return full_report(all_labels, probs), probs


def main():
    """Run all OOD scenarios plus the in-distribution control and save results.

    Prints a per-scenario metrics table and a verdict based on the accuracy
    drop relative to the control, then writes every result to
    ``outputs/results/ood_generalization_results.json``. Exits with status 1
    if ``config.MODELS_DIR`` does not exist.
    """
    print("="*80)
    print("NQR-SNN OUT-OF-DISTRIBUTION GENERALIZATION TEST")
    print("="*80)
    print(f"Device: {DEVICE} | SpikingJelly: {SPIKINGJELLY_AVAILABLE}")
    print(f"Test samples: {N_PER_CLASS} per class per scenario")
    print(f"SNR: {TEST_SNR_DB} dB (mid-range challenge)")
    print()

    # Load model
    print("Loading ensemble from outputs/models/ ...")
    if not os.path.exists(config.MODELS_DIR):
        print("ERROR: No trained models found. Run the pipeline first:")
        print(" python run_full_pipeline.py --quick")
        sys.exit(1)

    ensemble = SNNEnsemble(ensemble_size=config.ENSEMBLE_SIZE, device=DEVICE, heterogeneous=True)
    ensemble.load_checkpoints()
    encoder = DeterministicEncoder()
    print(f" Loaded {len(ensemble.models)} models\n")

    # Define test scenarios
    tests = [
        ("1_param_extrap", "Parameter Extrapolation (outside training bounds)", gen_extrapolated_params),
        ("2_temp_drift", "Temperature Frequency Drift (±10-30 kHz)", gen_temp_drift),
        ("3_maper", "MAPER Contamination (5-50x ringing)", gen_maper),
        ("4_dead_time", "Receiver Dead Time (first 6-28 samples zeroed)", gen_dead_time),
        ("5_colored_noise", "Colored Noise + RFI (1/f + narrowband)", gen_colored_noise),
        ("6_lorentzian", "Pure Lorentzian (no Gaussian envelope)", gen_lorentzian),
        ("7_multiline", "Multi-Polymorph (2 overlapping NQR lines)", gen_multiline),
        ("8_powder", "Powder Broadened (non-Voigt asymmetric)", gen_powder_broadened),
    ]

    results = {}
    print(f"{'Test':<45} {'Acc':>7} {'AUC':>7} {'F1':>7} {'TPR':>7} {'TNR':>7}")
    print("-"*80)

    for test_id, test_name, gen_fn in tests:
        # Generate signal-present (label=1)
        signals_present = gen_fn(N_PER_CLASS, snr_db=TEST_SNR_DB, seed=42)

        # Generate noise-only (label=0) — use colored noise for colored test, white otherwise
        if "colored" in test_id:
            signals_noise = gen_noise_colored(N_PER_CLASS, seed=99)
        else:
            signals_noise = gen_noise_white(N_PER_CLASS, seed=99)

        rep, mean_p_np = _evaluate(ensemble, encoder, signals_present, signals_noise)

        results[test_id] = {
            "name": test_name,
            "accuracy": rep["accuracy"],
            "auc": rep["auc"],
            "f1": rep["f1"],
            "tpr": rep["tpr"],
            "tnr": rep["tnr"],
            "mean_prob_signal": float(np.mean(mean_p_np[:N_PER_CLASS])),
            "mean_prob_noise": float(np.mean(mean_p_np[N_PER_CLASS:])),
        }

        print(f"{test_name:<45} {rep['accuracy']:>6.3f} {rep['auc']:>7.3f} {rep['f1']:>7.3f} "
              f"{rep['tpr']:>7.3f} {rep['tnr']:>7.3f}")

    # Also run the IN-DISTRIBUTION control test
    print("-"*80)
    from nqr_snn.data.generator import generate_signal_at_snr
    rng = np.random.RandomState(42)
    id_signals = np.array(
        [generate_signal_at_snr(TEST_SNR_DB, "white", rng)[0] for _ in range(N_PER_CLASS)],
        dtype=np.complex128,
    )
    id_noise = gen_noise_white(N_PER_CLASS, seed=99)
    id_rep, _ = _evaluate(ensemble, encoder, id_signals, id_noise)
    results["0_in_distribution"] = {"name": "IN-DISTRIBUTION CONTROL", **id_rep}
    print(f"{'IN-DISTRIBUTION CONTROL (reference)':<45} {id_rep['accuracy']:>6.3f} {id_rep['auc']:>7.3f} "
          f"{id_rep['f1']:>7.3f} {id_rep['tpr']:>7.3f} {id_rep['tnr']:>7.3f}")

    # Summary
    print("\n" + "="*80)
    print("GENERALIZATION VERDICT")
    print("="*80)

    id_acc = results["0_in_distribution"]["accuracy"]
    n_fail = 0  # counts both FAIL and WEAK, i.e. every drop above 5 points
    for test_id, r in sorted(results.items()):
        if test_id == "0_in_distribution":
            continue
        acc = r["accuracy"]
        drop = id_acc - acc
        if drop > 0.10:
            verdict = "FAIL (>10% drop)"
            n_fail += 1
        elif drop > 0.05:
            verdict = "WEAK (5-10% drop)"
            n_fail += 1
        else:
            verdict = "OK (<5% drop)"
        print(f" {r['name']:<50} acc={acc:.3f} drop={drop:+.3f} {verdict}")

    print(f"\n In-distribution accuracy: {id_acc:.3f}")
    print(f" Tests with >5% accuracy drop: {n_fail}/{len(tests)}")

    if n_fail >= 4:
        print("\n CONCLUSION: MODEL IS OVERFITTING TO THE SYNTHETIC GENERATOR.")
        print(" It has learned the parametric family, not NQR physics.")
        print(" Real-world deployment would likely fail.")
    elif n_fail >= 2:
        print("\n CONCLUSION: PARTIAL GENERALIZATION.")
        print(" Model handles some domain shifts but fails on others.")
        print(" Needs augmentation or more realistic training data.")
    else:
        print("\n CONCLUSION: GOOD GENERALIZATION.")
        print(" Model appears to have learned signal-vs-noise discrimination broadly.")

    # Save
    os.makedirs("outputs/results", exist_ok=True)
    with open("outputs/results/ood_generalization_results.json", "w") as f:
        # default=str is a fallback for any non-JSON-native values in the report.
        json.dump(results, f, indent=2, default=str)
    print(f"\n Results saved: outputs/results/ood_generalization_results.json")


if __name__ == "__main__":
    main()