#!/usr/bin/env python3
"""
NQR-SNN Optimization Test Suite
================================
Tests each proposed optimization against the baseline and reports
accuracy + inference speed.

Run locally:
    python optimization_test.py

Expected runtime: ~10-15 min on a modern CPU.
Results saved to outputs/results/optimization_results.json
"""
import json
import os
import sys
import time
import types

import numpy as np
import torch

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from nqr_snn import config
from nqr_snn.data.generator import (
    generate_dataset_v2,
    generate_signal_at_snr,
    generate_noise_only_at_power,
    generate_dataset,
)
from nqr_snn.data.dataset import NQRDatasetV2, get_balanced_loader_v2, extract_features_batch
from nqr_snn.snn.model import SpikingClassifier, SPIKINGJELLY_AVAILABLE
from nqr_snn.snn.encoder import DeterministicEncoder
from nqr_snn.snn.ensemble import SNNEnsemble
from nqr_snn.evaluation.metrics import full_report

TRAIN_SIZE, VAL_SIZE, TEST_PER_CLASS, MAX_EPOCHS = 500, 150, 150, 50
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SpikingJelly={SPIKINGJELLY_AVAILABLE} Device={DEVICE}")


def banner(title):
    """Print a section header framed by two 70-character '=' rules."""
    print("\n" + "=" * 70 + "\n" + title + "\n" + "=" * 70)


def _sync_device():
    """Wait for queued CUDA work so a following timer read is accurate.

    CUDA kernel launches are asynchronous; reading the clock right after
    ``ens.predict`` would otherwise measure only launch overhead. No-op on CPU.
    """
    if DEVICE == "cuda":
        torch.cuda.synchronize()


def make_test(snr, n=TEST_PER_CLASS, seed=99):
    """Build a balanced, reproducible white-noise test set at one SNR.

    Returns ``(signals, labels)``: ``n`` signals from
    ``generate_signal_at_snr`` labelled 1, followed by ``n`` records from
    ``generate_noise_only_at_power`` labelled 0. All draws come from one
    ``RandomState(seed)`` in that order, so the same seed gives the same set.
    """
    rng = np.random.RandomState(seed)
    signals, labels = [], []
    for _ in range(n):
        noisy, _ = generate_signal_at_snr(snr, "white", rng)
        signals.append(noisy)
        labels.append(1)
    for _ in range(n):
        signals.append(generate_noise_only_at_power("white", rng, 1.0))
        labels.append(0)
    return np.array(signals, dtype=np.complex128), np.array(labels)


def _timed_predict(ens, enc, feats):
    """Encode ``feats`` and run the ensemble; return ``(probs, seconds)``.

    Only encoding + prediction are timed, with gradients disabled and the
    device synchronised, so every configuration is measured the same way.
    """
    t0 = time.perf_counter()
    with torch.no_grad():
        x = enc.encode(feats).to(DEVICE)
        mp, _ = ens.predict(x)
    _sync_device()
    return mp, time.perf_counter() - t0


def evaluate(ens, enc, tag=""):
    """Evaluate ``ens`` at every SNR in ``config.EVAL_SNR_LEVELS``.

    Feature extraction runs outside the timer. ``inference_time`` is the
    summed encode+predict time over all SNR levels.

    Returns a dict with ``mean_acc``, ``min_acc`` and ``inference_time``.
    """
    res, tt = {}, 0.0
    for snr in config.EVAL_SNR_LEVELS:
        s, labels = make_test(snr, seed=abs(snr) * 100)
        f = torch.from_numpy(extract_features_batch(s))
        mp, dt = _timed_predict(ens, enc, f)
        tt += dt
        r = full_report(labels, mp.cpu().numpy())
        res[snr] = r
        print(f" {tag} SNR={snr:4d} | acc={r['accuracy']:.4f} auc={r['auc']:.4f}")
    a = [v["accuracy"] for v in res.values()]
    ma, mi = np.mean(a), np.min(a)
    print(f" {tag} MEAN={ma:.4f} MIN={mi:.4f} T={tt:.3f}s")
    return {"mean_acc": float(ma), "min_acc": float(mi), "inference_time": float(tt)}


def train_ens(sz, tl, vl, mdir):
    """Train an ``sz``-member heterogeneous ensemble, then reload its checkpoints.

    ``config.MODELS_DIR`` is pointed at ``mdir`` while training and loading.
    Presumably SNNEnsemble reads it for checkpoint paths; confirm. It is always
    restored afterwards, even if training raises.

    Returns ``(ensemble, train_seconds)``.
    """
    old_dir = config.MODELS_DIR
    config.MODELS_DIR = mdir
    try:
        os.makedirs(mdir, exist_ok=True)
        e = SNNEnsemble(ensemble_size=sz, device=DEVICE, heterogeneous=True)
        t0 = time.perf_counter()
        e.train_all(tl, vl, max_epochs=MAX_EPOCHS)
        tt = time.perf_counter() - t0
        e.load_checkpoints()
    finally:
        config.MODELS_DIR = old_dir
    return e, tt


# Data
banner("GENERATING DATA")
tr = generate_dataset_v2(TRAIN_SIZE, config.TRAIN_SNR_RANGE, "white", seed=42)
vl = generate_dataset_v2(VAL_SIZE, config.VAL_SNR_RANGE, "white", seed=100)
tr_ds = NQRDatasetV2(tr["signals"], tr["labels"])
vl_ds = NQRDatasetV2(vl["signals"], vl["labels"])
tl64 = get_balanced_loader_v2(tr_ds, batch_size=64)
vl64 = get_balanced_loader_v2(vl_ds, batch_size=64, shuffle=False)
tl128 = get_balanced_loader_v2(tr_ds, batch_size=128)
vl128 = get_balanced_loader_v2(vl_ds, batch_size=128, shuffle=False)
R = {}

# BASELINE
banner("BASELINE: 3-member, bs64")
eb, ttb = train_ens(3, tl64, vl64, "outputs/m_base")
enc = DeterministicEncoder()
R["baseline"] = {**evaluate(eb, enc, "[BASE]"), "train_time": ttb}

# OPT1: CNN ONCE
banner("OPT1: CNN SINGLE-PASS")
if SPIKINGJELLY_AVAILABLE:
    from spikingjelly.activation_based import functional

    def fwd1(self, x_seq, return_per_timestep=False):
        """Forward pass that runs the CNN only on the first timestep.

        The CNN features of ``x_seq[0]`` are reused for all T timesteps. Each
        timestep's copy is scaled by that timestep's energy relative to
        timestep 0. The result then goes through the model's snn_norm ->
        fc1/neuron1/drop1 -> fc2/neuron2/drop2 -> fc_out stack.
        Assumes ``x_seq`` is (T, batch, ...) and ``self.cnn`` returns
        (batch, features); TODO confirm against SpikingClassifier.
        """
        if not self.use_spikingjelly:
            return self.snn_head(self.cnn(x_seq[0]).unsqueeze(0))
        T = x_seq.shape[0]
        fs = self.cnn(x_seq[0])
        for sl in self._snn_layers:
            functional.reset_net(sl)
        with torch.no_grad():
            # Energy per (timestep, sample) over the last two axes, relative to t=0.
            e = x_seq.pow(2).sum(dim=(-1, -2))
            r = (e / (e[0:1] + 1e-8)).unsqueeze(-1)
        fm = fs.unsqueeze(0).expand(T, -1, -1) * r
        fm = self.snn_norm(fm)
        x = self.fc1(fm)
        x = self.neuron1(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.neuron2(x)
        x = self.drop2(x)
        x = self.fc_out(x)
        if return_per_timestep:
            return x
        return torch.sigmoid(x.mean(dim=0))

    try:
        for m in eb.models:
            m.forward = types.MethodType(fwd1, m)
        R["cnn_once"] = evaluate(eb, enc, "[CNN1x]")
    finally:
        # Remove the instance-level override so each model's own class
        # forward is used again, even if evaluation failed part-way.
        for m in eb.models:
            m.__dict__.pop("forward", None)
else:
    R["cnn_once"] = dict(R["baseline"])

# OPT3: FUSED FEATURES
banner("OPT3: FUSED FEATURES (speed only)")


def fused(signals):
    """Single-function rewrite of ``extract_features_batch`` for speed comparison.

    Returns a float32 (N, 7, L) stack with these channels: real, imag,
    magnitude, log1p |FFT|, FFT phase, normalised autocorrelation (via a
    zero-padded FFT) and clipped, std-normalised instantaneous frequency.
    """
    N, L = signals.shape
    r = signals.real.astype(np.float32)
    i = signals.imag.astype(np.float32)
    m = np.abs(signals).astype(np.float32)
    fv = np.fft.fft(signals, axis=1)
    lf = np.log1p(np.abs(fv)).astype(np.float32)
    fp = np.angle(fv).astype(np.float32)
    # Zero-pad to 2L so the FFT-based autocorrelation is linear, not circular.
    p = np.zeros((N, 2 * L), dtype=signals.dtype)
    p[:, :L] = signals
    pp = np.abs(np.fft.fft(p, axis=1)) ** 2
    ac = np.real(np.fft.ifft(pp, axis=1))[:, :L].astype(np.float32)
    lc = np.arange(L, 0, -1, dtype=np.float32)[None, :]
    ac /= lc
    an = ac[:, 0:1]
    ac = np.where(an > 0, ac / an, ac)
    pu = np.unwrap(np.angle(signals), axis=1)
    inf = np.zeros((N, L), dtype=np.float32)
    inf[:, 1:] = np.diff(pu, axis=1).astype(np.float32) / (2 * np.pi * config.SAMPLING_INTERVAL)
    inf[:, 0] = inf[:, 1]
    inf = np.clip(inf, -1000, 1000)
    s = np.std(inf, axis=1, keepdims=True)
    inf = np.where(s > 0, inf / (s + 1e-8), inf)
    return np.stack([r, i, m, lf, fp, ac, inf], axis=1)


sb, _ = make_test(-35, 200, 77)
t0 = time.perf_counter()
for _ in range(3):
    fo = extract_features_batch(sb)
to = (time.perf_counter() - t0) / 3
t0 = time.perf_counter()
for _ in range(3):
    ff = fused(sb)
tf = (time.perf_counter() - t0) / 3
md = float(np.max(np.abs(fo - ff)))
print(f" Orig={to:.4f}s Fused={tf:.4f}s Speed={to/max(tf,1e-9):.2f}x Diff={md:.1e}")
R["features"] = {"speedup": to / max(tf, 1e-9), "max_diff": md}

# OPT5: DENOISING
banner("OPT5: DENOISING ABLATION")
from nqr_snn.denoising.selector import DenoisingSelector
from nqr_snn.denoising import denoise_batch

# Temporarily shrink the validation size for the denoiser-selection set,
# restoring the global setting even if generation fails.
ov = config.VAL_SIZE
config.VAL_SIZE = 50
try:
    dd = generate_dataset("low", "white", "val", seed=42)
finally:
    config.VAL_SIZE = ov
sel = DenoisingSelector()
dn, dr, do = sel.select(dd["noisy"][:50], dd["clean"][:50])
print(f" Denoiser: {dn} R2={dr:.1f}")
trd = NQRDatasetV2(tr["signals"], tr["labels"], denoiser=do)
vld = NQRDatasetV2(vl["signals"], vl["labels"], denoiser=do)
ed, _ = train_ens(
    3,
    get_balanced_loader_v2(trd, 64),
    get_balanced_loader_v2(vld, 64, shuffle=False),
    "outputs/m_den",
)


def eval_d(ens, enc, den, tag=""):
    """Like ``evaluate``, but passes every test set through denoiser ``den`` first.

    The timer covers denoising plus encode+predict. Feature extraction is
    excluded, as in ``evaluate``, so ``inference_time`` stays comparable.

    Returns a dict with ``mean_acc``, ``min_acc`` and ``inference_time``.
    """
    tt = 0.0
    res = {}
    for snr in config.EVAL_SNR_LEVELS:
        s, labels = make_test(snr, seed=abs(snr) * 100)
        t0 = time.perf_counter()
        sd = denoise_batch(den, s)
        tt += time.perf_counter() - t0
        f = torch.from_numpy(extract_features_batch(sd))
        mp, dt = _timed_predict(ens, enc, f)
        tt += dt
        r = full_report(labels, mp.cpu().numpy())
        res[snr] = r
        print(f" {tag} SNR={snr:4d} | acc={r['accuracy']:.4f}")
    a = [v["accuracy"] for v in res.values()]
    print(f" {tag} MEAN={np.mean(a):.4f} MIN={np.min(a):.4f} T={tt:.3f}s")
    return {"mean_acc": float(np.mean(a)), "min_acc": float(np.min(a)), "inference_time": float(tt)}


R["denoised"] = eval_d(ed, enc, do, "[DEN]")

# OPT6: BS128
banner("OPT6: BS=128")
e128, tt128 = train_ens(3, tl128, vl128, "outputs/m_128")
R["bs128"] = {**evaluate(e128, enc, "[BS128]"), "train_time": tt128}

# OPT2: ENS5
banner("OPT2: ENSEMBLE 5")
e5, tt5 = train_ens(5, tl64, vl64, "outputs/m_ens5")
R["ens5"] = {**evaluate(e5, enc, "[ENS5]"), "train_time": tt5}

# SUMMARY
banner("FINAL RESULTS")
ba = R["baseline"]["mean_acc"]
bt = R["baseline"]["inference_time"]
print(f"{'Config':<15} {'MeanAcc':>8} {'MinAcc':>8} {'Time':>7} {'Delta':>7} {'Speed':>7} Verdict")
print("-" * 70)
for n, r in R.items():
    if n == "features":
        continue
    a, mi, t = r["mean_acc"], r["min_acc"], r["inference_time"]
    d = a - ba
    sp = bt / max(t, 1e-9)
    # Within +/-0.5 pp of baseline counts as PASS.
    if a > ba + 0.005:
        v = "BETTER"
    elif a >= ba - 0.005:
        v = "PASS"
    else:
        v = "WORSE"
    # Value widths match the header widths above so columns line up.
    print(f"{n:<15} {a:>8.4f} {mi:>8.4f} {t:>7.3f} {d:>+7.4f} {sp:>6.2f}x {v}")
fe = R["features"]
print(f"\nFeatures: {fe['speedup']:.2f}x speedup (diff={fe['max_diff']:.1e})")
os.makedirs("outputs/results", exist_ok=True)
with open("outputs/results/optimization_results.json", "w") as f:
    json.dump(R, f, indent=2, default=str)
print("\nSaved: outputs/results/optimization_results.json")
print("\nDONE. Share the output above and I will commit the winning optimizations.")