nqr-snn-framework / ood_generalization_test.py
KD099's picture
Add ood_generalization_test.py — tests model against realistic OOD scenarios to expose overfitting
e93a3de verified
#!/usr/bin/env python3
"""
NQR-SNN Out-of-Distribution Generalization Test
=================================================
Tests whether the model has learned NQR physics or just memorized the
synthetic generator's parametric family.
Run locally:
python ood_generalization_test.py
Tests 8 realistic domain shifts that REAL NQR signals would have:
1. Parameter extrapolation (signal params outside training bounds)
2. Temperature-induced frequency drift (±10-30 kHz)
3. MAPER contamination (magneto-acoustic ringing)
4. Receiver dead time (first 100-500μs zeroed)
5. Colored/correlated noise (1/f + narrowband RFI)
6. Lorentzian lineshape (not Voigt — the real NQR shape)
7. Multi-polymorph signals (two overlapping lines)
8. Powder-broadened signals (inhomogeneous lineshape)
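If accuracy drops well below the in-distribution control on these tests
(a drop of more than 5% is reported as WEAK, more than 10% as FAIL), the model is
likely overfitting to the synthetic generator and would likely fail on real hardware data.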
Expected runtime: ~5-10 min (uses pre-trained model from outputs/models/)
"""
import os
import sys
import time
import json
import numpy as np
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from nqr_snn import config
from nqr_snn.data.dataset import extract_features_batch
from nqr_snn.data.generator import generate_noise_only_at_power
from nqr_snn.snn.model import SpikingClassifier, SPIKINGJELLY_AVAILABLE
from nqr_snn.snn.encoder import DeterministicEncoder
from nqr_snn.snn.ensemble import SNNEnsemble
from nqr_snn.evaluation.metrics import full_report
# Run inference on the GPU when PyTorch reports CUDA support, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Number of traces generated for EACH class (signal-present and noise-only)
# in every scenario, including the in-distribution control.
N_PER_CLASS = 200
# ═══════════════════════════════════════════════════════════════════
# SIGNAL GENERATORS — Realistic OOD Scenarios
# ═══════════════════════════════════════════════════════════════════
def _time_axis():
    """Return the sample timestamps shared by every generator in this file.

    The values are ``k * config.SAMPLING_INTERVAL`` for
    ``k = 0 .. config.SIGNAL_LENGTH - 1``. Callers multiply them by frequencies
    expressed in Hz, so they are presumably in seconds (confirm against config).
    """
    dt = config.SAMPLING_INTERVAL
    sample_index = np.arange(config.SIGNAL_LENGTH)
    return sample_index * dt
def _add_noise_at_snr(clean, target_snr_db, rng, noise_type="white"):
    """Add one noise realisation to ``clean`` at a requested SNR.

    The noise comes from the project's ``generate_noise_sample(noise_type, rng)``.
    It is rescaled so that ``mean(|clean|^2) / mean(|noise|^2)`` equals
    ``10 ** (target_snr_db / 10)``. If the raw noise has no positive power,
    it is added unscaled.

    Args:
        clean: complex noise-free trace (assumed to match the noise length).
        target_snr_db: desired SNR in dB, measured on mean power over the whole trace.
        rng: ``np.random.RandomState`` forwarded to the noise generator.
        noise_type: noise family name forwarded to ``generate_noise_sample``.

    Returns:
        ``clean`` plus scaled noise.
    """
    from nqr_snn.data.noise_sources import generate_noise_sample

    noise = generate_noise_sample(noise_type, rng)
    p_signal = np.mean(np.abs(clean) ** 2)
    p_noise = np.mean(np.abs(noise) ** 2)
    # Guard clause: a zero-power noise draw cannot be rescaled.
    if not p_noise > 0:
        return clean + noise
    p_wanted = p_signal / 10 ** (target_snr_db / 10)
    return clean + np.sqrt(p_wanted / p_noise) * noise
# --- TEST 1: Parameter Extrapolation ---
def gen_extrapolated_params(n, snr_db=-35, seed=42):
    """Signal-present traces whose Voigt parameters lie OUTSIDE the training ranges.

    Training ranges (as documented by the author):
        A=[0.5,1.5], sigma=[0.001,0.01], T2=[0.001,0.01], nu=[-150,150] Hz.
    Ranges used here:
        A from [0.05,0.3] or [2.0,5.0]; sigma and T2 from [0.02,0.05];
        nu from [-500,-200] or [200,500] Hz.

    Note: ``_add_noise_at_snr`` scales the noise to the clean trace's power, so the
    amplitude excursion does not change the SNR. It only rescales the whole noisy
    trace, and it matters only if the features are scale-sensitive.

    Args:
        n: number of traces.
        snr_db: target SNR in dB.
        seed: seed for ``np.random.RandomState``.

    Returns:
        complex128 array with one noisy trace per row.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    def _one_trace():
        # Both candidate ranges are sampled before one is picked. Keep this order so
        # the RNG stream, and therefore the test set, stays reproducible.
        amp = rng.choice([rng.uniform(0.05, 0.3), rng.uniform(2.0, 5.0)])
        gauss_width = rng.uniform(0.02, 0.05)
        t2 = rng.uniform(0.02, 0.05)
        freq = rng.choice([rng.uniform(-500, -200), rng.uniform(200, 500)])
        phase = rng.uniform(-np.pi, np.pi)
        envelope = -t**2 / (2 * gauss_width**2) - t / t2
        fid = amp * np.exp(envelope + 1j * (2 * np.pi * freq * t + phase))
        return _add_noise_at_snr(fid, snr_db, rng)

    return np.array([_one_trace() for _ in range(n)], dtype=np.complex128)
# --- TEST 2: Temperature Frequency Drift ---
def gen_temp_drift(n, snr_db=-35, seed=42):
    """Signal-present traces with a large temperature-like frequency offset.

    Per the author's figures, TNT shifts ~800 Hz/degC, giving about +/-28 kHz over
    -20..+50 degC. Each trace draws an in-band frequency from [-150, 150] Hz (the
    training range) and adds a drift of random sign and magnitude in [10, 30] kHz.
    Amplitude, Gaussian width and T2 use the training ranges.

    NOTE(review): the complex baseband spans +/- 1/(2*config.SAMPLING_INTERVAL).
    If that interval is 18 us (~+/-27.8 kHz), drifts near 30 kHz alias back into
    the band with the opposite sign. Confirm this is intended.

    Returns:
        complex128 array with one noisy trace per row.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    traces = []
    for _ in range(n):
        amp = rng.uniform(0.5, 1.5)
        gauss_width = rng.uniform(0.001, 0.01)
        t2 = rng.uniform(0.001, 0.01)
        in_band_freq = rng.uniform(-150, 150)
        direction = rng.choice([-1, 1])
        drift_hz = rng.uniform(10000, 30000)
        freq = in_band_freq + direction * drift_hz
        phase = rng.uniform(-np.pi, np.pi)
        decay = -t**2 / (2 * gauss_width**2) - t / t2
        oscillation = 1j * (2 * np.pi * freq * t + phase)
        traces.append(_add_noise_at_snr(amp * np.exp(decay + oscillation), snr_db, rng))
    return np.array(traces, dtype=np.complex128)
# --- TEST 3: MAPER Contamination ---
def gen_maper(n, snr_db=-35, seed=42):
    """Signal-present traces contaminated by MAPER ringing.

    Magneto-acoustic / piezoelectric ringing (MAPER) is modelled as a coherent,
    exponentially decaying complex sinusoid at 500-5000 Hz. Its amplitude is 5-50x
    the NQR amplitude and its decay constant is 0.5-5 ms.

    ``snr_db`` is the SNR of the NQR line itself. The noise is scaled against the
    power of the clean NQR component only, and the ringing is added afterwards as
    interference. If the much stronger ringing were folded into the reference power,
    the noise would be inflated and the NQR line buried far below the requested SNR.
    The scenario would then mostly measure whether the ringing is detectable.

    NOTE(review): main() pairs these traces with white noise-only traces that
    contain no MAPER. A detector that flags any strong ringing as "signal" can
    therefore still score well here; consider contaminating the negatives too.

    Args:
        n: number of traces.
        snr_db: target NQR-to-noise SNR in dB.
        seed: seed for ``np.random.RandomState``.

    Returns:
        complex128 array with one trace per row.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    signals = []
    for _ in range(n):
        # NQR line drawn from the training ranges.
        A = rng.uniform(0.5, 1.5)
        sigma = rng.uniform(0.001, 0.01)
        T2 = rng.uniform(0.001, 0.01)
        nu = rng.uniform(-150, 150)
        phi = rng.uniform(-np.pi, np.pi)
        clean = A * np.exp(-t**2/(2*sigma**2) - t/T2 + 1j*(2*np.pi*nu*t + phi))
        # MAPER: decaying sinusoid at a frequency outside the NQR band.
        maper_amp = rng.uniform(5, 50) * A
        maper_freq = rng.uniform(500, 5000)
        maper_tau = rng.uniform(0.0005, 0.005)
        maper_phase = rng.uniform(-np.pi, np.pi)
        maper = maper_amp * np.exp(-t/maper_tau) * np.exp(1j*(2*np.pi*maper_freq*t + maper_phase))
        # Noise is referenced to the NQR component only; MAPER is interference.
        noisy = _add_noise_at_snr(clean, snr_db, rng) + maper
        signals.append(noisy)
    return np.array(signals, dtype=np.complex128)
# --- TEST 4: Receiver Dead Time ---
def gen_dead_time(n, snr_db=-35, seed=42, dead_time_range=(100e-6, 500e-6)):
    """Signal-present traces whose first samples are blanked, mimicking receiver dead time.

    A real NQR receiver cannot record while it recovers from the transmit pulse,
    whereas training traces start at t=0. Here the clean FID is zeroed for a dead
    time drawn uniformly from ``dead_time_range`` (seconds, default 100-500 us).
    That time is converted to a sample count by rounding up against
    ``config.SAMPLING_INTERVAL``, which is assumed to be in seconds, consistent
    with the Hz frequencies used with ``_time_axis``. At an 18 us interval this
    gives 6-28 samples inclusive. Noise is added afterwards, so the blanked samples
    still carry noise.

    Args:
        n: number of traces.
        snr_db: target SNR in dB, computed on the blanked clean trace.
        seed: seed for ``np.random.RandomState``.
        dead_time_range: ``(low, high)`` dead time in seconds.

    Returns:
        complex128 array with one noisy trace per row.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    dead_lo, dead_hi = dead_time_range
    signals = []
    for _ in range(n):
        A = rng.uniform(0.5, 1.5)
        sigma = rng.uniform(0.001, 0.01)
        T2 = rng.uniform(0.001, 0.01)
        nu = rng.uniform(-150, 150)
        phi = rng.uniform(-np.pi, np.pi)
        clean = A * np.exp(-t**2/(2*sigma**2) - t/T2 + 1j*(2*np.pi*nu*t + phi))
        # Drawing the dead time in seconds keeps the window correct for any sampling
        # interval. An integer randint(6, 28) would exclude 28, because its upper
        # bound is exclusive. Any partially blanked sample is blanked fully (ceil),
        # and the count is clamped to the trace length.
        dead_samples = int(np.ceil(rng.uniform(dead_lo, dead_hi) / config.SAMPLING_INTERVAL))
        dead_samples = min(dead_samples, len(clean))
        clean[:dead_samples] = 0.0
        noisy = _add_noise_at_snr(clean, snr_db, rng)
        signals.append(noisy)
    return np.array(signals, dtype=np.complex128)
# --- TEST 5: Colored Noise + RFI ---
def gen_colored_noise(n, snr_db=-35, seed=42):
    """Signal-present traces in realistic non-white noise instead of white Gaussian noise.

    The noise model has three parts:
      * complex white Gaussian noise shaped by 1/(|f|^(alpha/2) + 0.1),
        with alpha ~ U[0.5, 1.5] and the DC bin left at unit gain;
      * 5-9 narrowband RFI tones in [-5, 5] kHz (``randint(5, 10)``);
      * 2-4 impulsive bursts of 1-9 samples each (``randint(2, 5)``, ``randint(1, 10)``).

    The combined noise is scaled so that the clean-to-noise mean power ratio equals
    ``snr_db``. It is added unscaled if it has no positive power.

    Returns:
        complex128 array with one noisy trace per row.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    n_samples = config.SIGNAL_LENGTH
    out = []
    for _ in range(n):
        amp = rng.uniform(0.5, 1.5)
        gauss_width = rng.uniform(0.001, 0.01)
        t2 = rng.uniform(0.001, 0.01)
        freq = rng.uniform(-150, 150)
        phase = rng.uniform(-np.pi, np.pi)
        clean = amp * np.exp(-t**2/(2*gauss_width**2) - t/t2 + 1j*(2*np.pi*freq*t + phase))

        # Spectrally shape complex white noise (real part drawn before imaginary part).
        re_part = rng.randn(n_samples)
        noise = re_part + 1j * rng.randn(n_samples)
        bin_freqs = np.fft.fftfreq(n_samples, d=config.SAMPLING_INTERVAL)
        alpha = rng.uniform(0.5, 1.5)
        gain = np.ones(n_samples)
        mask = bin_freqs != 0
        gain[mask] = 1.0 / (np.abs(bin_freqs[mask]) ** (alpha/2) + 0.1)
        noise = np.fft.ifft(np.fft.fft(noise) * gain)

        # Narrowband RFI: 5..9 continuous-wave interferers (frequency, amplitude, phase).
        for _ in range(rng.randint(5, 10)):
            f_rfi = rng.uniform(-5000, 5000)
            a_rfi = rng.uniform(0.5, 3.0)
            p_rfi = rng.uniform(-np.pi, np.pi)
            noise += a_rfi * np.exp(1j * (2*np.pi*f_rfi*t + p_rfi))

        # Impulsive bursts: 2..4 bursts, truncated at the end of the trace.
        for _ in range(rng.randint(2, 5)):
            start = rng.randint(0, n_samples)
            width = rng.randint(1, 10)
            burst_amp = rng.uniform(5, 20)
            stop = min(start + width, n_samples)
            count = stop - start
            noise[start:stop] += burst_amp * (rng.randn(count) + 1j*rng.randn(count))

        # Scale the composite noise to the requested SNR.
        p_sig = np.mean(np.abs(clean)**2)
        p_noise = np.mean(np.abs(noise)**2)
        if p_noise > 0:
            p_wanted = p_sig / 10 ** (snr_db / 10)
            noisy = clean + np.sqrt(p_wanted / p_noise) * noise
        else:
            noisy = clean + noise
        out.append(noisy.astype(np.complex128))
    return np.array(out, dtype=np.complex128)
# --- TEST 6: Pure Lorentzian Lineshape (not Voigt) ---
def gen_lorentzian(n, snr_db=-35, seed=42):
    """Signal-present traces with a pure exponential (Lorentzian) decay.

    Each FID is ``A * exp(-t/T2 + i(2*pi*nu*t + phi))``, with no Gaussian envelope
    (the sigma -> infinity limit of the training Voigt shape). According to the
    author, single-crystal NQR FIDs are purely Lorentzian. This probes whether the
    model learned "decaying oscillation" or specifically the Voigt envelope.
    A, T2 and nu are drawn from the training ranges.

    Returns:
        complex128 array with one noisy trace per row.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    def _draw():
        amp = rng.uniform(0.5, 1.5)
        t2 = rng.uniform(0.001, 0.01)
        freq = rng.uniform(-150, 150)
        phase = rng.uniform(-np.pi, np.pi)
        fid = amp * np.exp(-t / t2 + 1j * (2 * np.pi * freq * t + phase))
        return _add_noise_at_snr(fid, snr_db, rng)

    return np.array([_draw() for _ in range(n)], dtype=np.complex128)
# --- TEST 7: Multi-Polymorph (two overlapping lines) ---
def gen_multiline(n, snr_db=-35, seed=42):
    """Signal-present traces containing two superposed Voigt lines.

    This models multiple sites or polymorphs: per the author, RDX has three
    inequivalent nitrogen sites and TNT has alpha/beta forms. Training uses
    single-line signals only.

    Line 1: A in [0.3, 1.0], nu in [-150, 0] Hz.
    Line 2: A in [0.2, 0.8], nu in [50, 300] Hz, so at least 50 Hz from line 1.
    Line 2 can reach beyond the +/-150 Hz training range. Each line has its own
    T2, Gaussian width and phase.

    Returns:
        complex128 array with one noisy trace per row.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()

    def _voigt_line(amp, gauss_width, t2, freq, phase):
        return amp * np.exp(-t**2 / (2 * gauss_width**2) - t / t2
                            + 1j * (2 * np.pi * freq * t + phase))

    traces = []
    for _ in range(n):
        # Per line, the draw order is amplitude, T2, frequency, width, phase.
        # It must stay fixed for reproducibility.
        a_low = rng.uniform(0.3, 1.0)
        t2_low = rng.uniform(0.001, 0.01)
        f_low = rng.uniform(-150, 0)
        w_low = rng.uniform(0.001, 0.01)
        ph_low = rng.uniform(-np.pi, np.pi)

        a_high = rng.uniform(0.2, 0.8)
        t2_high = rng.uniform(0.002, 0.015)
        f_high = rng.uniform(50, 300)
        w_high = rng.uniform(0.001, 0.01)
        ph_high = rng.uniform(-np.pi, np.pi)

        mixture = (_voigt_line(a_low, w_low, t2_low, f_low, ph_low)
                   + _voigt_line(a_high, w_high, t2_high, f_high, ph_high))
        traces.append(_add_noise_at_snr(mixture, snr_db, rng))
    return np.array(traces, dtype=np.complex128)
# --- TEST 8: Powder Broadening (non-Voigt lineshape) ---
def gen_powder_broadened(n, snr_db=-35, seed=42):
    """Signal-present traces built as a sum of many slightly detuned crystallite FIDs.

    Each trace sums 20-49 Voigt components (``randint(20, 50)``). They share one
    envelope (sigma, T2) and one phase, and their frequency offsets are
    ``spread * (Beta(2, 5) - 0.3)`` with spread in [100, 500] Hz, which gives a
    skewed, non-Voigt lineshape. Each component has weight
    ``A_total / n_crystallites * U[0.5, 1.5]``. Per the author, real explosives
    are polycrystalline powders, not single crystals.

    Returns:
        complex128 array with one noisy trace per row.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    traces = []
    for _ in range(n):
        total_amp = rng.uniform(0.5, 1.5)
        t2 = rng.uniform(0.001, 0.01)
        gauss_width = rng.uniform(0.001, 0.01)
        center = rng.uniform(-150, 150)
        phase = rng.uniform(-np.pi, np.pi)
        n_cryst = rng.randint(20, 50)
        spread = rng.uniform(100, 500)
        offsets = spread * (rng.beta(2, 5, n_cryst) - 0.3)

        # The decay envelope is the same for every crystallite, so compute it once.
        envelope = -t**2 / (2 * gauss_width**2) - t / t2
        fid = np.zeros(config.SIGNAL_LENGTH, dtype=np.complex128)
        for offset in offsets:
            weight = total_amp / n_cryst * rng.uniform(0.5, 1.5)
            fid += weight * np.exp(envelope + 1j * (2 * np.pi * (center + offset) * t + phase))
        traces.append(_add_noise_at_snr(fid, snr_db, rng))
    return np.array(traces, dtype=np.complex128)
# ═══════════════════════════════════════════════════════════════════
# NOISE-ONLY GENERATORS (matching power for fair comparison)
# ═══════════════════════════════════════════════════════════════════
def gen_noise_white(n, seed=99):
    """Noise-only traces from the project's white-noise generator.

    Each trace comes from ``generate_noise_only_at_power("white", rng, 1.0)``,
    presumably white noise at unit mean power. This is the same generator family
    used for training negatives.

    Returns:
        complex128 array with one trace per row.
    """
    rng = np.random.RandomState(seed)
    batch = []
    for _ in range(n):
        batch.append(generate_noise_only_at_power("white", rng, 1.0))
    return np.array(batch, dtype=np.complex128)
def gen_noise_colored(n, seed=99):
    """Noise-only traces drawn from the same noise model as gen_colored_noise.

    Each trace has three parts, and the result is normalised to unit mean power:
      * complex white Gaussian noise shaped by 1/(|f|^(alpha/2) + 0.1),
        with alpha ~ U[0.5, 1.5] and the DC bin left at unit gain;
      * 5-9 narrowband RFI tones in [-5, 5] kHz (``randint(5, 10)``);
      * 2-4 impulsive bursts of 1-9 samples (``randint(2, 5)``, ``randint(1, 10)``).

    The bursts must be present here too. The signal-present colored scenario adds
    them to its noise, so leaving them out of the negatives would let a detector
    score by spotting impulses rather than an NQR line.

    Args:
        n: number of traces.
        seed: seed for ``np.random.RandomState``.

    Returns:
        complex128 array with one trace per row.
    """
    rng = np.random.RandomState(seed)
    t = _time_axis()
    L = config.SIGNAL_LENGTH
    # The FFT bin frequencies do not depend on the draw, so compute them once.
    freqs = np.fft.fftfreq(L, d=config.SAMPLING_INTERVAL)
    nonzero = freqs != 0
    signals = []
    for _ in range(n):
        white = rng.randn(L) + 1j * rng.randn(L)
        alpha = rng.uniform(0.5, 1.5)
        color_filter = np.ones(L)
        color_filter[nonzero] = 1.0 / (np.abs(freqs[nonzero]) ** (alpha/2) + 0.1)
        colored = np.fft.ifft(np.fft.fft(white) * color_filter)
        # Narrowband RFI tones (amplitude, frequency, phase drawn left to right).
        n_rfi = rng.randint(5, 10)
        for _ in range(n_rfi):
            colored += rng.uniform(0.5, 3.0) * np.exp(1j*(2*np.pi*rng.uniform(-5000, 5000)*t + rng.uniform(-np.pi, np.pi)))
        # Impulsive bursts, matching the signal-present colored scenario.
        n_impulses = rng.randint(2, 5)
        for _ in range(n_impulses):
            pos = rng.randint(0, L)
            width = rng.randint(1, 10)
            amp = rng.uniform(5, 20)
            stop = min(pos + width, L)
            colored[pos:stop] += amp * (rng.randn(stop - pos) + 1j * rng.randn(stop - pos))
        power = np.mean(np.abs(colored)**2)
        if power > 0:
            colored = colored * np.sqrt(1.0 / power)
        signals.append(colored.astype(np.complex128))
    return np.array(signals, dtype=np.complex128)
# ═══════════════════════════════════════════════════════════════════
# MAIN TEST RUNNER
# ═══════════════════════════════════════════════════════════════════
def _predict_probs(ensemble, encoder, signals):
    """Return the ensemble's mean predicted probability per trace as a NumPy array.

    This is the shared scoring pipeline for every scenario:
    ``extract_features_batch`` -> ``encoder.encode`` -> ``ensemble.predict``.
    The ensemble's second output (the per-trace spread) is not used here.
    """
    feats = torch.from_numpy(extract_features_batch(signals))
    x_seq = encoder.encode(feats).to(DEVICE)
    with torch.no_grad():
        mean_p, _ = ensemble.predict(x_seq)
    return mean_p.cpu().numpy()


def _json_default(obj):
    """Serialise NumPy values as native JSON numbers/lists; stringify anything else."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)


def main():
    """Run every OOD scenario plus an in-distribution control and print a verdict.

    For each scenario, N_PER_CLASS signal-present traces (label 1) from the scenario
    generator are paired with N_PER_CLASS noise-only traces (label 0). The colored-noise
    scenario uses colored noise-only traces; all others use white. The pairs are scored
    with the saved ensemble.

    Accuracy drops versus the in-distribution control are classed as FAIL (>10%),
    WEAK (>5%) or OK. Per-scenario metrics are written to
    outputs/results/ood_generalization_results.json.

    NOTE(review): the noise-only traces are presumably at unit power, while
    signal-present traces are not renormalised after noise is added. Unless
    extract_features_batch normalises scale, absolute power may act as a shortcut
    feature. Confirm.

    Exits with status 1 if config.MODELS_DIR is missing or the ensemble has no models.
    """
    snr_db = -35
    print("="*80)
    print("NQR-SNN OUT-OF-DISTRIBUTION GENERALIZATION TEST")
    print("="*80)
    print(f"Device: {DEVICE} | SpikingJelly: {SPIKINGJELLY_AVAILABLE}")
    print(f"Test samples: {N_PER_CLASS} per class per scenario")
    print(f"SNR: {snr_db} dB (mid-range challenge)")
    print()

    # Load model
    print(f"Loading ensemble from {config.MODELS_DIR} ...")
    if not os.path.exists(config.MODELS_DIR):
        print("ERROR: No trained models found. Run the pipeline first:")
        print(" python run_full_pipeline.py --quick")
        sys.exit(1)
    ensemble = SNNEnsemble(ensemble_size=config.ENSEMBLE_SIZE, device=DEVICE, heterogeneous=True)
    ensemble.load_checkpoints()
    # An existing but empty/unusable models directory must not silently
    # "evaluate" an ensemble with no members.
    if len(ensemble.models) == 0:
        print(f"ERROR: No models could be loaded from {config.MODELS_DIR}. Run the pipeline first:")
        print(" python run_full_pipeline.py --quick")
        sys.exit(1)
    encoder = DeterministicEncoder()
    print(f" Loaded {len(ensemble.models)} models\n")

    # Define test scenarios
    tests = [
        ("1_param_extrap", "Parameter Extrapolation (outside training bounds)", gen_extrapolated_params),
        ("2_temp_drift", "Temperature Frequency Drift (±10-30 kHz)", gen_temp_drift),
        ("3_maper", "MAPER Contamination (5-50x ringing)", gen_maper),
        ("4_dead_time", "Receiver Dead Time (first 6-28 samples zeroed)", gen_dead_time),
        ("5_colored_noise", "Colored Noise + RFI (1/f + narrowband)", gen_colored_noise),
        ("6_lorentzian", "Pure Lorentzian (no Gaussian envelope)", gen_lorentzian),
        ("7_multiline", "Multi-Polymorph (2 overlapping NQR lines)", gen_multiline),
        ("8_powder", "Powder Broadened (non-Voigt asymmetric)", gen_powder_broadened),
    ]
    control_label = "IN-DISTRIBUTION CONTROL (reference)"

    # Size the name column to the longest label so numeric columns stay aligned.
    name_w = max(45, len(control_label), *(len(name) for _, name, _ in tests))
    header = f"{'Test':<{name_w}} {'Acc':>7} {'AUC':>7} {'F1':>7} {'TPR':>7} {'TNR':>7}"
    rule = "-" * max(80, len(header))

    def _row(name, rep):
        return (f"{name:<{name_w}} {rep['accuracy']:>7.3f} {rep['auc']:>7.3f} "
                f"{rep['f1']:>7.3f} {rep['tpr']:>7.3f} {rep['tnr']:>7.3f}")

    # Signal-present rows first (label 1), then noise-only rows (label 0).
    labels = np.array([1]*N_PER_CLASS + [0]*N_PER_CLASS)

    results = {}
    print(header)
    print(rule)
    for test_id, test_name, gen_fn in tests:
        signals_present = gen_fn(N_PER_CLASS, snr_db=snr_db, seed=42)
        # Colored negatives for the colored-noise scenario, white otherwise.
        if "colored" in test_id:
            signals_noise = gen_noise_colored(N_PER_CLASS, seed=99)
        else:
            signals_noise = gen_noise_white(N_PER_CLASS, seed=99)
        probs = _predict_probs(ensemble, encoder,
                               np.concatenate([signals_present, signals_noise], axis=0))
        rep = full_report(labels, probs)
        results[test_id] = {
            "name": test_name,
            "accuracy": rep["accuracy"],
            "auc": rep["auc"],
            "f1": rep["f1"],
            "tpr": rep["tpr"],
            "tnr": rep["tnr"],
            "mean_prob_signal": float(np.mean(probs[:N_PER_CLASS])),
            "mean_prob_noise": float(np.mean(probs[N_PER_CLASS:])),
        }
        print(_row(test_name, rep))

    # In-distribution control, using the project's own training-style generator.
    print(rule)
    from nqr_snn.data.generator import generate_signal_at_snr
    rng = np.random.RandomState(42)
    id_signals = np.array([generate_signal_at_snr(snr_db, "white", rng)[0] for _ in range(N_PER_CLASS)],
                          dtype=np.complex128)
    id_noise = gen_noise_white(N_PER_CLASS, seed=99)
    id_probs = _predict_probs(ensemble, encoder, np.concatenate([id_signals, id_noise]))
    id_rep = full_report(labels, id_probs)
    results["0_in_distribution"] = {"name": "IN-DISTRIBUTION CONTROL", **id_rep}
    print(_row(control_label, id_rep))

    # Summary
    print("\n" + "="*80)
    print("GENERALIZATION VERDICT")
    print("="*80)
    id_acc = results["0_in_distribution"]["accuracy"]
    n_fail = 0
    for test_id, r in sorted(results.items()):
        if test_id == "0_in_distribution":
            continue
        acc = r["accuracy"]
        drop = id_acc - acc
        if drop > 0.10:
            verdict = "FAIL (>10% drop)"
            n_fail += 1
        elif drop > 0.05:
            verdict = "WEAK (5-10% drop)"
            n_fail += 1
        else:
            verdict = "OK (<5% drop)"
        print(f" {r['name']:<50} acc={acc:.3f} drop={drop:+.3f} {verdict}")
    print(f"\n In-distribution accuracy: {id_acc:.3f}")
    print(f" Tests with >5% accuracy drop: {n_fail}/{len(tests)}")
    if n_fail >= 4:
        print("\n CONCLUSION: MODEL IS OVERFITTING TO THE SYNTHETIC GENERATOR.")
        print(" It has learned the parametric family, not NQR physics.")
        print(" Real-world deployment would likely fail.")
    elif n_fail >= 2:
        print("\n CONCLUSION: PARTIAL GENERALIZATION.")
        print(" Model handles some domain shifts but fails on others.")
        print(" Needs augmentation or more realistic training data.")
    else:
        print("\n CONCLUSION: GOOD GENERALIZATION.")
        print(" Model appears to have learned signal-vs-noise discrimination broadly.")

    # Save: NumPy scalars/arrays become JSON numbers/lists rather than strings.
    os.makedirs("outputs/results", exist_ok=True)
    with open("outputs/results/ood_generalization_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=_json_default)
    print(f"\n Results saved: outputs/results/ood_generalization_results.json")


if __name__ == "__main__":
    main()