| |
| """ |
| NQR-SNN Out-of-Distribution Generalization Test |
| ================================================= |
| Tests whether the model has learned NQR physics or just memorized the |
| synthetic generator's parametric family. |
| |
| Run locally: |
| python ood_generalization_test.py |
| |
| Tests 8 realistic domain shifts that REAL NQR signals would have: |
| 1. Parameter extrapolation (signal params outside training bounds) |
| 2. Temperature-induced frequency drift (±10-30 kHz) |
| 3. MAPER contamination (magneto-acoustic ringing) |
| 4. Receiver dead time (first 100-500μs zeroed) |
| 5. Colored/correlated noise (1/f + narrowband RFI) |
| 6. Lorentzian lineshape (not Voigt — the real NQR shape) |
| 7. Multi-polymorph signals (two overlapping lines) |
| 8. Powder-broadened signals (inhomogeneous lineshape) |
| |
| If accuracy drops significantly (<90%) on these tests, the model is |
| overfitting to the synthetic generator and would fail on real hardware data. |
| |
| Expected runtime: ~5-10 min (uses pre-trained model from outputs/models/) |
| """ |
|
|
| import os |
| import sys |
| import time |
| import json |
| import numpy as np |
| import torch |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from nqr_snn import config |
| from nqr_snn.data.dataset import extract_features_batch |
| from nqr_snn.data.generator import generate_noise_only_at_power |
| from nqr_snn.snn.model import SpikingClassifier, SPIKINGJELLY_AVAILABLE |
| from nqr_snn.snn.encoder import DeterministicEncoder |
| from nqr_snn.snn.ensemble import SNNEnsemble |
| from nqr_snn.evaluation.metrics import full_report |
|
|
# Torch device string: "cuda" when torch reports an available GPU, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Number of traces generated for each class (signal-present and noise-only)
# in every test scenario.
N_PER_CLASS = 200
|
|
|
|
| |
| |
| |
|
|
def _time_axis():
    """Return the sample-time vector shared by every synthetic trace.

    The vector has ``config.SIGNAL_LENGTH`` entries, starts at zero and is
    spaced ``config.SAMPLING_INTERVAL`` apart.
    """
    sample_index = np.arange(config.SIGNAL_LENGTH)
    return sample_index * config.SAMPLING_INTERVAL
|
|
|
|
def _add_noise_at_snr(clean, target_snr_db, rng, noise_type="white"):
    """Return ``clean`` plus one noise realisation scaled to a target SNR.

    Args:
        clean: Noise-free trace; its mean squared magnitude is the signal power.
        target_snr_db: Desired signal-to-noise power ratio in decibels.
        rng: Random generator handed to ``generate_noise_sample``.
        noise_type: Noise family name handed to ``generate_noise_sample``.

    Returns:
        ``clean`` with the rescaled noise added. If the drawn noise has no
        positive mean power it is added unscaled instead.
    """
    from nqr_snn.data.noise_sources import generate_noise_sample

    noise = generate_noise_sample(noise_type, rng)

    p_signal = np.mean(np.abs(clean)**2)
    p_noise = np.mean(np.abs(noise)**2)

    # dB -> linear power ratio, then the noise power that yields that ratio.
    snr_linear = 10 ** (target_snr_db / 10)
    wanted_noise_power = p_signal / snr_linear

    if not p_noise > 0:
        # Nothing to rescale against: add the raw draw as-is.
        return clean + noise

    gain = np.sqrt(wanted_noise_power / p_noise)
    return clean + gain * noise
|
|
|
|
| |
def gen_extrapolated_params(n, snr_db=-35, seed=42):
    """Generate noisy traces whose signal parameters lie outside the training ranges.

    Per the original author's notes, training draws A in [0.5, 1.5],
    sigma in [0.001, 0.01], T2 in [0.001, 0.01] and nu in [-150, 150].
    This scenario instead draws A from [0.05, 0.3] or [2.0, 5.0],
    sigma and T2 from [0.02, 0.05], and nu from [-500, -200] or [200, 500].

    Args:
        n: Number of traces to generate.
        snr_db: Target SNR (dB) passed to ``_add_noise_at_snr``.
        seed: Seed for the ``np.random.RandomState`` driving all draws.

    Returns:
        complex128 array holding one noisy trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    def draw_trace():
        # Both candidate values are drawn before one is picked, so every
        # trace consumes the same number of random draws.
        amp_small = rng.uniform(0.05, 0.3)
        amp_large = rng.uniform(2.0, 5.0)
        amplitude = rng.choice([amp_small, amp_large])
        gauss_width = rng.uniform(0.02, 0.05)
        decay_time = rng.uniform(0.02, 0.05)
        freq_negative = rng.uniform(-500, -200)
        freq_positive = rng.uniform(200, 500)
        freq = rng.choice([freq_negative, freq_positive])
        phase0 = rng.uniform(-np.pi, np.pi)

        exponent = (-t**2/(2*gauss_width**2) - t/decay_time
                    + 1j*(2*np.pi*freq*t + phase0))
        return _add_noise_at_snr(amplitude * np.exp(exponent), snr_db, rng)

    traces = [draw_trace() for _ in range(n)]
    return np.array(traces, dtype=np.complex128)
|
|
|
|
| |
def gen_temp_drift(n, snr_db=-35, seed=42):
    """Generate noisy traces with a large frequency offset of random sign.

    Each trace uses the in-range amplitude, Gaussian width and T2 draws,
    but its frequency is a base value in [-150, 150] plus a shift of
    magnitude 10000-30000 with random sign. The original author's notes
    motivate this as temperature drift (TNT quoted at ~800 Hz/°C, i.e.
    about ±28 kHz over -20 to +50 °C).

    Args:
        n: Number of traces to generate.
        snr_db: Target SNR (dB) passed to ``_add_noise_at_snr``.
        seed: Seed for the ``np.random.RandomState`` driving all draws.

    Returns:
        complex128 array holding one noisy trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    traces = []
    for _ in range(n):
        amplitude = rng.uniform(0.5, 1.5)
        gauss_width = rng.uniform(0.001, 0.01)
        decay_time = rng.uniform(0.001, 0.01)

        base_freq = rng.uniform(-150, 150)
        direction = rng.choice([-1, 1])
        magnitude = rng.uniform(10000, 30000)
        freq = base_freq + direction * magnitude

        phase0 = rng.uniform(-np.pi, np.pi)
        exponent = (-t**2/(2*gauss_width**2) - t/decay_time
                    + 1j*(2*np.pi*freq*t + phase0))
        traces.append(_add_noise_at_snr(amplitude * np.exp(exponent), snr_db, rng))

    return np.array(traces, dtype=np.complex128)
|
|
|
|
| |
def gen_maper(n, snr_db=-35, seed=42):
    """Generate noisy traces contaminated by a strong decaying ringing term.

    An in-range signal is summed with a coherent, exponentially decaying
    oscillation whose amplitude is 5-50x the signal amplitude. The original
    author describes this as magneto-acoustic/piezoelectric ringing (MAPER)
    caused by metal objects in the RF field.

    Note that the SNR is computed against the sum of signal and ringing,
    because that sum is what is handed to ``_add_noise_at_snr``.

    Args:
        n: Number of traces to generate.
        snr_db: Target SNR (dB) passed to ``_add_noise_at_snr``.
        seed: Seed for the ``np.random.RandomState`` driving all draws.

    Returns:
        complex128 array holding one noisy trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    def draw_trace():
        # In-range signal component.
        amplitude = rng.uniform(0.5, 1.5)
        gauss_width = rng.uniform(0.001, 0.01)
        decay_time = rng.uniform(0.001, 0.01)
        freq = rng.uniform(-150, 150)
        phase0 = rng.uniform(-np.pi, np.pi)
        exponent = (-t**2/(2*gauss_width**2) - t/decay_time
                    + 1j*(2*np.pi*freq*t + phase0))
        signal = amplitude * np.exp(exponent)

        # Ringing artifact, scaled relative to the signal amplitude.
        ring_amp = rng.uniform(5, 50) * amplitude
        ring_freq = rng.uniform(500, 5000)
        ring_tau = rng.uniform(0.0005, 0.005)
        ring_phase = rng.uniform(-np.pi, np.pi)
        ring_envelope = ring_amp * np.exp(-t/ring_tau)
        ringing = ring_envelope * np.exp(1j*(2*np.pi*ring_freq*t + ring_phase))

        return _add_noise_at_snr(signal + ringing, snr_db, rng)

    return np.array([draw_trace() for _ in range(n)], dtype=np.complex128)
|
|
|
|
| |
def gen_dead_time(n, snr_db=-35, seed=42):
    """Generate noisy traces whose leading samples carry no signal.

    The first ``rng.randint(6, 28)`` samples of the clean signal (6 to 27,
    since the upper bound is exclusive) are set to zero before noise is
    added. The original author describes this as receiver dead time of
    roughly 100-500 μs after the TX pulse.

    The SNR is computed against the already-blanked signal.

    Args:
        n: Number of traces to generate.
        snr_db: Target SNR (dB) passed to ``_add_noise_at_snr``.
        seed: Seed for the ``np.random.RandomState`` driving all draws.

    Returns:
        complex128 array holding one noisy trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    traces = []
    for _ in range(n):
        amplitude = rng.uniform(0.5, 1.5)
        gauss_width = rng.uniform(0.001, 0.01)
        decay_time = rng.uniform(0.001, 0.01)
        freq = rng.uniform(-150, 150)
        phase0 = rng.uniform(-np.pi, np.pi)
        exponent = (-t**2/(2*gauss_width**2) - t/decay_time
                    + 1j*(2*np.pi*freq*t + phase0))
        signal = amplitude * np.exp(exponent)

        # Blank the start of the signal; the noise added below still covers it.
        n_blanked = rng.randint(6, 28)
        signal[:n_blanked] = 0.0

        traces.append(_add_noise_at_snr(signal, snr_db, rng))

    return np.array(traces, dtype=np.complex128)
|
|
|
|
| |
def gen_colored_noise(n, snr_db=-35, seed=42):
    """Generate in-range signals buried in structured (non-white) noise.

    The noise for each trace is built from three parts:
      - complex white Gaussian noise shaped in the frequency domain by
        ``1 / (|f|**(alpha/2) + 0.1)`` with alpha drawn from [0.5, 1.5];
      - 5-9 constant-amplitude complex tones at random frequencies in
        [-5000, 5000] (narrowband interference);
      - 2-4 short Gaussian bursts of 1-9 samples at random positions.
    The combined noise is then rescaled so that the trace hits ``snr_db``.

    Args:
        n: Number of traces to generate.
        snr_db: Target signal-to-noise power ratio in dB.
        seed: Seed for the ``np.random.RandomState`` driving all draws.

    Returns:
        complex128 array holding one noisy trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    length = config.SIGNAL_LENGTH

    # The frequency grid involves no random draws, so build it once.
    fft_freqs = np.fft.fftfreq(length, d=config.SAMPLING_INTERVAL)
    off_dc = fft_freqs != 0
    abs_off_dc = np.abs(fft_freqs[off_dc])

    snr_linear = 10 ** (snr_db / 10)

    traces = []
    for _ in range(n):
        amplitude = rng.uniform(0.5, 1.5)
        gauss_width = rng.uniform(0.001, 0.01)
        decay_time = rng.uniform(0.001, 0.01)
        freq = rng.uniform(-150, 150)
        phase0 = rng.uniform(-np.pi, np.pi)
        exponent = (-t**2/(2*gauss_width**2) - t/decay_time
                    + 1j*(2*np.pi*freq*t + phase0))
        signal = amplitude * np.exp(exponent)

        # Spectrally shaped background; the DC bin keeps unit gain.
        seed_noise = rng.randn(length) + 1j * rng.randn(length)
        alpha = rng.uniform(0.5, 1.5)
        shaping = np.ones(length)
        shaping[off_dc] = 1.0 / (abs_off_dc ** (alpha/2) + 0.1)
        noise = np.fft.ifft(np.fft.fft(seed_noise) * shaping)

        # Narrowband interferers.
        tone_count = rng.randint(5, 10)
        for _tone in range(tone_count):
            tone_freq = rng.uniform(-5000, 5000)
            tone_amp = rng.uniform(0.5, 3.0)
            tone_phase = rng.uniform(-np.pi, np.pi)
            noise += tone_amp * np.exp(1j * (2*np.pi*tone_freq*t + tone_phase))

        # Impulsive bursts, clipped at the end of the trace.
        burst_count = rng.randint(2, 5)
        for _burst in range(burst_count):
            start = rng.randint(0, length)
            width = rng.randint(1, 10)
            burst_amp = rng.uniform(5, 20)
            stop = min(start + width, length)
            span = stop - start
            noise[start:stop] += burst_amp * (rng.randn(span) + 1j*rng.randn(span))

        # Rescale the combined noise to reach the requested SNR.
        p_signal = np.mean(np.abs(signal)**2)
        p_noise = np.mean(np.abs(noise)**2)
        wanted_noise_power = p_signal / snr_linear
        if p_noise > 0:
            mixed = signal + np.sqrt(wanted_noise_power / p_noise) * noise
        else:
            mixed = signal + noise
        traces.append(mixed.astype(np.complex128))

    return np.array(traces, dtype=np.complex128)
|
|
|
|
| |
def gen_lorentzian(n, snr_db=-35, seed=42):
    """Generate noisy traces with a purely exponential decay envelope.

    The Gaussian factor used by the other generators is dropped, leaving
    ``A * exp(-t/T2)`` times the complex oscillation. The original author's
    notes describe training data as Voigt-shaped (Gaussian × Lorentzian),
    so this probes sensitivity to the envelope shape.

    Args:
        n: Number of traces to generate.
        snr_db: Target SNR (dB) passed to ``_add_noise_at_snr``.
        seed: Seed for the ``np.random.RandomState`` driving all draws.

    Returns:
        complex128 array holding one noisy trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    def draw_trace():
        amplitude = rng.uniform(0.5, 1.5)
        decay_time = rng.uniform(0.001, 0.01)
        freq = rng.uniform(-150, 150)
        phase0 = rng.uniform(-np.pi, np.pi)
        exponent = -t/decay_time + 1j*(2*np.pi*freq*t + phase0)
        return _add_noise_at_snr(amplitude * np.exp(exponent), snr_db, rng)

    return np.array([draw_trace() for _ in range(n)], dtype=np.complex128)
|
|
|
|
| |
def gen_multiline(n, snr_db=-35, seed=42):
    """Generate noisy traces made of two superimposed spectral lines.

    The first line has amplitude in [0.3, 1.0] and frequency in [-150, 0].
    The second has amplitude in [0.2, 0.8], T2 in [0.002, 0.015] and
    frequency in [50, 300]. The original author's notes motivate this with
    multiple nitrogen sites / polymorphs (RDX, TNT) and state that training
    uses single-line signals only.

    Args:
        n: Number of traces to generate.
        snr_db: Target SNR (dB) passed to ``_add_noise_at_snr``.
        seed: Seed for the ``np.random.RandomState`` driving all draws.

    Returns:
        complex128 array holding one noisy trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    def line(amp, width, decay, freq, phase):
        """Evaluate one damped complex oscillation on the shared time axis."""
        return amp * np.exp(-t**2/(2*width**2) - t/decay + 1j*(2*np.pi*freq*t + phase))

    traces = []
    for _ in range(n):
        # Draw order (amplitude, T2, frequency, width, phase) is kept for
        # each line so a given seed always yields the same traces.
        amp_a = rng.uniform(0.3, 1.0)
        decay_a = rng.uniform(0.001, 0.01)
        freq_a = rng.uniform(-150, 0)
        width_a = rng.uniform(0.001, 0.01)
        phase_a = rng.uniform(-np.pi, np.pi)
        first = line(amp_a, width_a, decay_a, freq_a, phase_a)

        amp_b = rng.uniform(0.2, 0.8)
        decay_b = rng.uniform(0.002, 0.015)
        freq_b = rng.uniform(50, 300)
        width_b = rng.uniform(0.001, 0.01)
        phase_b = rng.uniform(-np.pi, np.pi)
        second = line(amp_b, width_b, decay_b, freq_b, phase_b)

        traces.append(_add_noise_at_snr(first + second, snr_db, rng))

    return np.array(traces, dtype=np.complex128)
|
|
|
|
| |
def gen_powder_broadened(n, snr_db=-35, seed=42):
    """Generate noisy traces built from many slightly detuned components.

    Each trace sums 20-49 components that share envelope and phase but sit
    at ``nu_center + offset``. Offsets are
    ``freq_spread * (Beta(2, 5) - 0.3)``, which is skewed and so gives an
    asymmetric line. The original author describes this as powder
    broadening in polycrystalline samples.

    Args:
        n: Number of traces to generate.
        snr_db: Target SNR (dB) passed to ``_add_noise_at_snr``.
        seed: Seed for the ``np.random.RandomState`` driving all draws.

    Returns:
        complex128 array holding one noisy trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    traces = []
    for _ in range(n):
        total_amp = rng.uniform(0.5, 1.5)
        decay_time = rng.uniform(0.001, 0.01)
        gauss_width = rng.uniform(0.001, 0.01)
        center_freq = rng.uniform(-150, 150)
        phase0 = rng.uniform(-np.pi, np.pi)

        component_count = rng.randint(20, 50)
        spread = rng.uniform(100, 500)
        detunings = spread * (rng.beta(2, 5, component_count) - 0.3)

        # Nominal share of the total amplitude carried by each component.
        nominal_amp = total_amp / component_count

        summed = np.zeros(config.SIGNAL_LENGTH, dtype=np.complex128)
        for detuning in detunings:
            freq = center_freq + detuning
            weight = nominal_amp * rng.uniform(0.5, 1.5)
            summed += weight * np.exp(
                -t**2/(2*gauss_width**2) - t/decay_time + 1j*(2*np.pi*freq*t + phase0)
            )

        traces.append(_add_noise_at_snr(summed, snr_db, rng))

    return np.array(traces, dtype=np.complex128)
|
|
|
|
| |
| |
| |
|
|
def gen_noise_white(n, seed=99):
    """Generate ``n`` noise-only traces of type "white" at power 1.0.

    Each trace comes from ``generate_noise_only_at_power``; all draws use
    one ``np.random.RandomState`` seeded with ``seed``.

    Returns:
        complex128 array holding one noise trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    traces = []
    for _ in range(n):
        traces.append(generate_noise_only_at_power("white", rng, 1.0))
    return np.array(traces, dtype=np.complex128)
|
|
def gen_noise_colored(n, seed=99):
    """Generate noise-only traces matching the colored-noise test scenario.

    The noise is built from the same three components that
    ``gen_colored_noise`` adds to its signals:
      - complex white Gaussian noise shaped by ``1 / (|f|**(alpha/2) + 0.1)``
        with alpha drawn from [0.5, 1.5];
      - 5-9 complex tones at random frequencies in [-5000, 5000];
      - 2-4 short Gaussian bursts of 1-9 samples at random positions.
    Each trace is then normalised to unit mean power.

    The impulsive bursts were previously missing here. The noise-only class
    then differed from the signal-present class by more than the presence of
    a signal, so a classifier could separate the two by detecting bursts.

    Args:
        n: Number of traces to generate.
        seed: Seed for the ``np.random.RandomState`` driving all draws.

    Returns:
        complex128 array holding one noise trace per requested sample.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    L = config.SIGNAL_LENGTH

    # The frequency grid does not depend on any random draw; build it once.
    freqs = np.fft.fftfreq(L, d=config.SAMPLING_INTERVAL)
    nonzero = freqs != 0
    abs_freqs = np.abs(freqs[nonzero])

    signals = []
    for _ in range(n):
        # Spectrally shaped background; the DC bin keeps unit gain.
        white = rng.randn(L) + 1j * rng.randn(L)
        alpha = rng.uniform(0.5, 1.5)
        color_filter = np.ones(L)
        color_filter[nonzero] = 1.0 / (abs_freqs ** (alpha/2) + 0.1)
        colored = np.fft.ifft(np.fft.fft(white) * color_filter)

        # Narrowband interferers.
        n_rfi = rng.randint(5, 10)
        for _rfi in range(n_rfi):
            rfi_freq = rng.uniform(-5000, 5000)
            rfi_amp = rng.uniform(0.5, 3.0)
            rfi_phase = rng.uniform(-np.pi, np.pi)
            colored += rfi_amp * np.exp(1j * (2*np.pi*rfi_freq*t + rfi_phase))

        # Impulsive bursts, with the same parameters as gen_colored_noise,
        # clipped at the end of the trace.
        n_impulses = rng.randint(2, 5)
        for _imp in range(n_impulses):
            pos = rng.randint(0, L)
            width = rng.randint(1, 10)
            amp = rng.uniform(5, 20)
            stop = min(pos + width, L)
            span = stop - pos
            colored[pos:stop] += amp * (rng.randn(span) + 1j*rng.randn(span))

        # Normalise to unit mean power (skipped for an all-zero trace).
        power = np.mean(np.abs(colored)**2)
        if power > 0:
            colored = colored * np.sqrt(1.0 / power)
        signals.append(colored.astype(np.complex128))

    return np.array(signals, dtype=np.complex128)
|
|
|
|
| |
| |
| |
|
|
def _score_batch(ensemble, encoder, signals, labels):
    """Run features -> spike encoding -> ensemble inference -> metrics on one batch.

    Returns:
        ``(report, probs)``: the dict from ``full_report`` and the ensemble's
        mean predicted probabilities as a NumPy array.
    """
    feats = torch.from_numpy(extract_features_batch(signals))
    x_seq = encoder.encode(feats).to(DEVICE)
    with torch.no_grad():
        mean_p, _ = ensemble.predict(x_seq)
    probs = mean_p.cpu().numpy()
    return full_report(labels, probs), probs


def _format_row(label, rep):
    """Format one line of the metrics table for a report dict."""
    return (f"{label:<45} {rep['accuracy']:>6.3f} {rep['auc']:>7.3f} {rep['f1']:>7.3f} "
            f"{rep['tpr']:>7.3f} {rep['tnr']:>7.3f}")


def main():
    """Evaluate the trained ensemble on every OOD scenario and report a verdict.

    Steps:
      1. Load the ensemble checkpoints, exiting with status 1 if
         ``config.MODELS_DIR`` does not exist.
      2. For each scenario, score ``N_PER_CLASS`` signal-present traces
         against ``N_PER_CLASS`` noise-only traces.
      3. Score an in-distribution control built with
         ``generate_signal_at_snr``.
      4. Compare each scenario's accuracy with the control, print a verdict
         and write all results to
         ``outputs/results/ood_generalization_results.json``.
    """
    # Single source of truth for the SNR used by the banner, the scenarios
    # and the control.
    snr_db = -35

    print("="*80)
    print("NQR-SNN OUT-OF-DISTRIBUTION GENERALIZATION TEST")
    print("="*80)
    print(f"Device: {DEVICE} | SpikingJelly: {SPIKINGJELLY_AVAILABLE}")
    print(f"Test samples: {N_PER_CLASS} per class per scenario")
    print(f"SNR: {snr_db} dB (mid-range challenge)")
    print()

    print("Loading ensemble from outputs/models/ ...")
    if not os.path.exists(config.MODELS_DIR):
        print("ERROR: No trained models found. Run the pipeline first:")
        print(" python run_full_pipeline.py --quick")
        sys.exit(1)

    ensemble = SNNEnsemble(ensemble_size=config.ENSEMBLE_SIZE, device=DEVICE, heterogeneous=True)
    ensemble.load_checkpoints()
    encoder = DeterministicEncoder()
    print(f" Loaded {len(ensemble.models)} models\n")

    tests = [
        ("1_param_extrap", "Parameter Extrapolation (outside training bounds)", gen_extrapolated_params),
        ("2_temp_drift", "Temperature Frequency Drift (±10-30 kHz)", gen_temp_drift),
        ("3_maper", "MAPER Contamination (5-50x ringing)", gen_maper),
        ("4_dead_time", "Receiver Dead Time (first 6-28 samples zeroed)", gen_dead_time),
        ("5_colored_noise", "Colored Noise + RFI (1/f + narrowband)", gen_colored_noise),
        ("6_lorentzian", "Pure Lorentzian (no Gaussian envelope)", gen_lorentzian),
        ("7_multiline", "Multi-Polymorph (2 overlapping NQR lines)", gen_multiline),
        ("8_powder", "Powder Broadened (non-Voigt asymmetric)", gen_powder_broadened),
    ]

    # Every batch is laid out as [signal-present..., noise-only...].
    labels = np.array([1]*N_PER_CLASS + [0]*N_PER_CLASS)

    results = {}
    print(f"{'Test':<45} {'Acc':>7} {'AUC':>7} {'F1':>7} {'TPR':>7} {'TNR':>7}")
    print("-"*80)

    for test_id, test_name, gen_fn in tests:
        signals_present = gen_fn(N_PER_CLASS, snr_db=snr_db, seed=42)

        # The colored-noise scenario is paired with colored negatives so the
        # two classes share a noise family; all others use white noise.
        if "colored" in test_id:
            signals_noise = gen_noise_colored(N_PER_CLASS, seed=99)
        else:
            signals_noise = gen_noise_white(N_PER_CLASS, seed=99)

        all_signals = np.concatenate([signals_present, signals_noise], axis=0)
        rep, probs = _score_batch(ensemble, encoder, all_signals, labels)

        results[test_id] = {
            "name": test_name,
            "accuracy": rep["accuracy"],
            "auc": rep["auc"],
            "f1": rep["f1"],
            "tpr": rep["tpr"],
            "tnr": rep["tnr"],
            "mean_prob_signal": float(np.mean(probs[:N_PER_CLASS])),
            "mean_prob_noise": float(np.mean(probs[N_PER_CLASS:])),
        }
        print(_format_row(test_name, rep))

    # In-distribution control: the reference accuracy for the drop figures.
    print("-"*80)
    from nqr_snn.data.generator import generate_signal_at_snr
    rng = np.random.RandomState(42)
    id_signals = np.array(
        [generate_signal_at_snr(snr_db, "white", rng)[0] for _ in range(N_PER_CLASS)],
        dtype=np.complex128,
    )
    id_noise = gen_noise_white(N_PER_CLASS, seed=99)
    id_all = np.concatenate([id_signals, id_noise])
    id_rep, _ = _score_batch(ensemble, encoder, id_all, labels)
    results["0_in_distribution"] = {"name": "IN-DISTRIBUTION CONTROL", **id_rep}
    print(_format_row("IN-DISTRIBUTION CONTROL (reference)", id_rep))

    print("\n" + "="*80)
    print("GENERALIZATION VERDICT")
    print("="*80)
    id_acc = results["0_in_distribution"]["accuracy"]
    n_fail = 0
    for test_id, r in sorted(results.items()):
        if test_id == "0_in_distribution":
            continue
        acc = r["accuracy"]
        drop = id_acc - acc
        # Both FAIL and WEAK count towards the ">5% drop" tally below.
        if drop > 0.10:
            verdict = "FAIL (>10% drop)"
            n_fail += 1
        elif drop > 0.05:
            verdict = "WEAK (5-10% drop)"
            n_fail += 1
        else:
            verdict = "OK (<5% drop)"
        print(f" {r['name']:<50} acc={acc:.3f} drop={drop:+.3f} {verdict}")

    print(f"\n In-distribution accuracy: {id_acc:.3f}")
    # Use the real scenario count so the summary stays right if tests change.
    print(f" Tests with >5% accuracy drop: {n_fail}/{len(tests)}")

    if n_fail >= 4:
        print("\n CONCLUSION: MODEL IS OVERFITTING TO THE SYNTHETIC GENERATOR.")
        print(" It has learned the parametric family, not NQR physics.")
        print(" Real-world deployment would likely fail.")
    elif n_fail >= 2:
        print("\n CONCLUSION: PARTIAL GENERALIZATION.")
        print(" Model handles some domain shifts but fails on others.")
        print(" Needs augmentation or more realistic training data.")
    else:
        print("\n CONCLUSION: GOOD GENERALIZATION.")
        print(" Model appears to have learned signal-vs-noise discrimination broadly.")

    os.makedirs("outputs/results", exist_ok=True)
    with open("outputs/results/ood_generalization_results.json", "w", encoding="utf-8") as f:
        # default=str stringifies anything json cannot encode natively
        # (e.g. non-float NumPy values inside the control report).
        json.dump(results, f, indent=2, default=str)
    print(f"\n Results saved: outputs/results/ood_generalization_results.json")


if __name__ == "__main__":
    main()
|
|