| |
| """ |
NQR-SNN Optimization Test Suite
================================
Tests each proposed optimization against the baseline and reports
accuracy and inference speed. Run locally:

    python optimization_test.py

Expected runtime: ~10-15 min on a modern CPU.
Results are saved to outputs/results/optimization_results.json
| """ |
| import os, sys, time, json, types |
| import numpy as np |
| import torch |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from nqr_snn import config |
| from nqr_snn.data.generator import generate_dataset_v2, generate_signal_at_snr, generate_noise_only_at_power, generate_dataset |
| from nqr_snn.data.dataset import NQRDatasetV2, get_balanced_loader_v2, extract_features_batch |
| from nqr_snn.snn.model import SpikingClassifier, SPIKINGJELLY_AVAILABLE |
| from nqr_snn.snn.encoder import DeterministicEncoder |
| from nqr_snn.snn.ensemble import SNNEnsemble |
| from nqr_snn.evaluation.metrics import full_report |
|
|
# Dataset sizes and epoch budget shared by every experiment in this script.
TRAIN_SIZE = 500
VAL_SIZE = 150
TEST_PER_CLASS = 150
MAX_EPOCHS = 50

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"SpikingJelly={SPIKINGJELLY_AVAILABLE} Device={DEVICE}")
|
|
def make_test(snr, n=TEST_PER_CLASS, seed=99):
    """Build a class-balanced test set: ``n`` signal traces then ``n`` noise traces.

    Args:
        snr: SNR value forwarded to ``generate_signal_at_snr``.
        n: number of traces generated for each of the two classes.
        seed: seed for the ``numpy.random.RandomState`` shared by both classes.

    Returns:
        ``(signals, labels)`` where ``signals`` is a ``complex128`` array and
        ``labels`` holds ``1`` for the first ``n`` rows and ``0`` for the rest.
    """
    rng = np.random.RandomState(seed)

    traces = []
    # Positives are drawn first; both classes consume the same generator, so
    # the draw order determines the exact samples produced.
    for _ in range(n):
        noisy, _unused = generate_signal_at_snr(snr, "white", rng)
        traces.append(noisy)
    traces.extend(
        generate_noise_only_at_power("white", rng, 1.0) for _ in range(n)
    )

    labels = [1] * n + [0] * n
    return np.array(traces, dtype=np.complex128), np.array(labels)
|
|
def evaluate(ens, enc, tag=""):
    """Score an ensemble at every SNR listed in ``config.EVAL_SNR_LEVELS``.

    For each SNR a fresh test set is generated with ``make_test`` (seeded with
    ``abs(snr) * 100``), features are extracted, encoded and passed to
    ``ens.predict``. One line per SNR and a summary line are printed.

    Args:
        ens: object whose ``predict(x)`` returns a 2-tuple; the first element
            is moved to the CPU and handed to ``full_report`` as the scores.
        enc: object whose ``encode(features)`` returns a tensor.
        tag: prefix inserted into every printed line.

    Returns:
        dict with ``mean_acc`` and ``min_acc`` (over SNR levels) and
        ``inference_time``: seconds spent in encoding + prediction only.
        Data generation and feature extraction are not timed.
    """
    reports = {}
    elapsed = 0.0
    on_cuda = str(DEVICE).startswith("cuda")

    for snr in config.EVAL_SNR_LEVELS:
        signals, labels = make_test(snr, seed=abs(snr) * 100)
        feats = torch.from_numpy(extract_features_batch(signals))

        # perf_counter is monotonic and high resolution, unlike time.time().
        start = time.perf_counter()
        with torch.no_grad():
            x = enc.encode(feats).to(DEVICE)
            probs, _ = ens.predict(x)
        if on_cuda:
            # CUDA kernels are launched asynchronously; wait for them so the
            # measured interval covers the actual computation.
            torch.cuda.synchronize()
        elapsed += time.perf_counter() - start

        report = full_report(labels, probs.cpu().numpy())
        reports[snr] = report
        print(f" {tag} SNR={snr:4d} | acc={report['accuracy']:.4f} auc={report['auc']:.4f}")

    accuracies = [rep["accuracy"] for rep in reports.values()]
    mean_acc, min_acc = np.mean(accuracies), np.min(accuracies)
    print(f" {tag} MEAN={mean_acc:.4f} MIN={min_acc:.4f} T={elapsed:.3f}s")
    return {
        "mean_acc": float(mean_acc),
        "min_acc": float(min_acc),
        "inference_time": float(elapsed),
    }
|
|
def train_ens(sz, tl, vl, mdir):
    """Train an ``SNNEnsemble`` with ``config.MODELS_DIR`` redirected to ``mdir``.

    ``config.MODELS_DIR`` is overridden for the duration of training and
    checkpoint loading and is always restored afterwards, even on failure.
    Presumably the ensemble reads that setting to locate its checkpoints --
    confirm against ``SNNEnsemble``.

    Args:
        sz: ensemble size passed as ``ensemble_size``.
        tl: training loader passed to ``train_all``.
        vl: validation loader passed to ``train_all``.
        mdir: directory to create and use as ``config.MODELS_DIR``.

    Returns:
        ``(ensemble, train_seconds)`` where ``train_seconds`` is the time
        spent inside ``train_all`` only.
    """
    previous_dir = config.MODELS_DIR
    config.MODELS_DIR = mdir
    try:
        os.makedirs(mdir, exist_ok=True)
        ensemble = SNNEnsemble(ensemble_size=sz, device=DEVICE, heterogeneous=True)

        start = time.perf_counter()
        ensemble.train_all(tl, vl, max_epochs=MAX_EPOCHS)
        train_seconds = time.perf_counter() - start

        # Must run while MODELS_DIR still points at mdir.
        ensemble.load_checkpoints()
    finally:
        # Never leak the override into later experiments.
        config.MODELS_DIR = previous_dir
    return ensemble, train_seconds
|
|
| |
# Training and validation data are generated once and reused by every
# experiment below.
print(f"\n{'=' * 70}\nGENERATING DATA\n{'=' * 70}")
tr = generate_dataset_v2(TRAIN_SIZE, config.TRAIN_SNR_RANGE, "white", seed=42)
vl = generate_dataset_v2(VAL_SIZE, config.VAL_SNR_RANGE, "white", seed=100)
tr_ds = NQRDatasetV2(tr["signals"], tr["labels"])
vl_ds = NQRDatasetV2(vl["signals"], vl["labels"])

# One (train, val) loader pair per batch size under test; validation loaders
# are created with shuffle=False.
tl64, vl64 = (
    get_balanced_loader_v2(tr_ds, batch_size=64),
    get_balanced_loader_v2(vl_ds, batch_size=64, shuffle=False),
)
tl128, vl128 = (
    get_balanced_loader_v2(tr_ds, batch_size=128),
    get_balanced_loader_v2(vl_ds, batch_size=128, shuffle=False),
)

# Results of every experiment, keyed by experiment name.
R = {}
|
|
| |
# BASELINE: three-member ensemble trained with batch size 64. All later
# experiments are compared against R["baseline"].
print(f"\n{'=' * 70}\nBASELINE: 3-member, bs64\n{'=' * 70}")
eb, ttb = train_ens(3, tl64, vl64, "outputs/m_base")
enc = DeterministicEncoder()
baseline_metrics = evaluate(eb, enc, "[BASE]")
baseline_metrics["train_time"] = ttb
R["baseline"] = baseline_metrics
|
|
| |
# OPT1: evaluate the already-trained baseline ensemble with a forward pass
# that runs the CNN on the first timestep only.
print(f"\n{'=' * 70}\nOPT1: CNN SINGLE-PASS\n{'=' * 70}")
if SPIKINGJELLY_AVAILABLE:
    from spikingjelly.activation_based import functional

    def fwd1(self, x_seq, return_per_timestep=False):
        """Forward pass that calls ``self.cnn`` once, on ``x_seq[0]``.

        The single CNN output is repeated for every timestep and scaled by
        that timestep's input energy relative to timestep 0, then fed
        through the spiking head.

        Assumes ``x_seq`` is laid out as (T, batch, channels, length) and
        ``self.cnn`` returns (batch, features) -- TODO confirm against
        ``SpikingClassifier``.

        Returns per-timestep outputs of ``fc_out`` when
        ``return_per_timestep`` is true, otherwise the sigmoid of their mean
        over the time axis.
        """
        if not self.use_spikingjelly:
            return self.snn_head(self.cnn(x_seq[0]).unsqueeze(0))

        T = x_seq.shape[0]
        fs = self.cnn(x_seq[0])
        for sl in self._snn_layers:
            functional.reset_net(sl)

        # Energy ratio is a pure scaling factor; it is computed without grad.
        with torch.no_grad():
            e = x_seq.pow(2).sum(dim=(-1, -2))
            r = (e / (e[0:1] + 1e-8)).unsqueeze(-1)

        fm = fs.unsqueeze(0).expand(T, -1, -1) * r
        fm = self.snn_norm(fm)
        x = self.fc1(fm)
        x = self.neuron1(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.neuron2(x)
        x = self.drop2(x)
        x = self.fc_out(x)
        if return_per_timestep:
            return x
        return torch.sigmoid(x.mean(dim=0))

    # NOTE(review): every ensemble member is patched; this presumes they all
    # expose the attributes fwd1 reads -- verify for heterogeneous ensembles.
    for model in eb.models:
        model.forward = types.MethodType(fwd1, model)
    try:
        R["cnn_once"] = evaluate(eb, enc, "[CNN1x]")
    finally:
        # Remove the instance-level override so each model falls back to the
        # forward defined by its own class, even if evaluation failed.
        for model in eb.models:
            vars(model).pop("forward", None)
else:
    # Without SpikingJelly the experiment cannot run; report a copy of the
    # baseline numbers so the summary table still has a row.
    R["cnn_once"] = dict(R["baseline"])
|
|
| |
print(f"\n{'=' * 70}\nOPT3: FUSED FEATURES (speed only)\n{'=' * 70}")


def fused(signals):
    """Compute seven float32 feature channels for a batch of complex signals.

    Args:
        signals: 2-D complex array, one signal per row.

    Returns:
        Array stacked along axis 1 in this order: real part, imaginary part,
        magnitude, log(1 + |FFT|), FFT phase, normalised autocorrelation,
        normalised instantaneous frequency.
    """
    n_signals, n_samples = signals.shape

    # Time-domain channels.
    real_part = signals.real.astype(np.float32)
    imag_part = signals.imag.astype(np.float32)
    magnitude = np.abs(signals).astype(np.float32)

    # Frequency-domain channels.
    spectrum = np.fft.fft(signals, axis=1)
    log_spectrum = np.log1p(np.abs(spectrum)).astype(np.float32)
    spectrum_phase = np.angle(spectrum).astype(np.float32)

    # Autocorrelation from the power spectrum of the signal zero-padded to
    # twice its length, keeping the first n_samples lags.
    padded = np.zeros((n_signals, 2 * n_samples), dtype=signals.dtype)
    padded[:, :n_samples] = signals
    power = np.abs(np.fft.fft(padded, axis=1)) ** 2
    autocorr = np.real(np.fft.ifft(power, axis=1))[:, :n_samples].astype(np.float32)
    # Divide each lag by n_samples - lag, then by the lag-0 value when positive.
    overlap = np.arange(n_samples, 0, -1, dtype=np.float32)[None, :]
    autocorr /= overlap
    zero_lag = autocorr[:, 0:1]
    autocorr = np.where(zero_lag > 0, autocorr / zero_lag, autocorr)

    # Instantaneous frequency from the unwrapped phase difference.
    unwrapped = np.unwrap(np.angle(signals), axis=1)
    inst_freq = np.zeros((n_signals, n_samples), dtype=np.float32)
    inst_freq[:, 1:] = np.diff(unwrapped, axis=1).astype(np.float32) / (2 * np.pi * config.SAMPLING_INTERVAL)
    # The first sample has no predecessor; reuse the second sample's value.
    inst_freq[:, 0] = inst_freq[:, 1]
    inst_freq = np.clip(inst_freq, -1000, 1000)
    spread = np.std(inst_freq, axis=1, keepdims=True)
    inst_freq = np.where(spread > 0, inst_freq / (spread + 1e-8), inst_freq)

    channels = [
        real_part,
        imag_part,
        magnitude,
        log_spectrum,
        spectrum_phase,
        autocorr,
        inst_freq,
    ]
    return np.stack(channels, axis=1)
# Time the reference extractor against fused() on the same 400-trace batch
# and record the largest element-wise difference between their outputs.
sb, _ = make_test(-35, 200, 77)
n_repeats = 3

# perf_counter is monotonic and high resolution, which matters for runs this
# short; time.time() can jump and is coarse on some platforms.
start = time.perf_counter()
for _ in range(n_repeats):
    fo = extract_features_batch(sb)
to = (time.perf_counter() - start) / n_repeats

start = time.perf_counter()
for _ in range(n_repeats):
    ff = fused(sb)
tf = (time.perf_counter() - start) / n_repeats

speedup = to / max(tf, 1e-9)
md = float(np.max(np.abs(fo - ff)))
print(f" Orig={to:.4f}s Fused={tf:.4f}s Speed={speedup:.2f}x Diff={md:.1e}")
R["features"] = {"speedup": speedup, "max_diff": md}
|
|
| |
# OPT5: pick a denoiser on a small sample, then train a three-member
# ensemble on datasets built with that denoiser.
print(f"\n{'=' * 70}\nOPT5: DENOISING ABLATION\n{'=' * 70}")
from nqr_snn.denoising.selector import DenoisingSelector
from nqr_snn.denoising import denoise_batch

# config.VAL_SIZE is lowered only for this generate_dataset call (presumably
# the generator reads it -- confirm) and is restored even if the call raises.
ov = config.VAL_SIZE
config.VAL_SIZE = 50
try:
    dd = generate_dataset("low", "white", "val", seed=42)
finally:
    config.VAL_SIZE = ov

sel = DenoisingSelector()
dn, dr, do = sel.select(dd["noisy"][:50], dd["clean"][:50])
print(f" Denoiser: {dn} R2={dr:.1f}")

trd = NQRDatasetV2(tr["signals"], tr["labels"], denoiser=do)
vld = NQRDatasetV2(vl["signals"], vl["labels"], denoiser=do)
ed, _ = train_ens(
    3,
    get_balanced_loader_v2(trd, 64),
    get_balanced_loader_v2(vld, 64, shuffle=False),
    "outputs/m_den",
)
def eval_d(ens, enc, den, tag=""):
    """Score an ensemble on denoised test data at every evaluation SNR.

    Mirrors ``evaluate`` but runs ``denoise_batch(den, signals)`` before
    feature extraction.

    NOTE(review): unlike ``evaluate``, the timed region here also covers
    denoising and feature extraction, so ``inference_time`` is not directly
    comparable between the two.

    Args:
        ens: object whose ``predict(x)`` returns a 2-tuple; the first element
            is handed to ``full_report`` as the scores.
        enc: object whose ``encode(features)`` returns a tensor.
        den: denoiser forwarded to ``denoise_batch``.
        tag: prefix inserted into every printed line.

    Returns:
        dict with ``mean_acc``, ``min_acc`` and ``inference_time`` (seconds).
    """
    reports = {}
    elapsed = 0.0
    on_cuda = str(DEVICE).startswith("cuda")

    for snr in config.EVAL_SNR_LEVELS:
        signals, labels = make_test(snr, seed=abs(snr) * 100)

        start = time.perf_counter()
        denoised = denoise_batch(den, signals)
        feats = extract_features_batch(denoised)
        # Inference only: no autograd graph is needed, same as evaluate().
        with torch.no_grad():
            x = enc.encode(torch.from_numpy(feats)).to(DEVICE)
            probs, _ = ens.predict(x)
        if on_cuda:
            # Wait for asynchronously launched kernels before stopping the clock.
            torch.cuda.synchronize()
        elapsed += time.perf_counter() - start

        report = full_report(labels, probs.cpu().numpy())
        reports[snr] = report
        print(f" {tag} SNR={snr:4d} | acc={report['accuracy']:.4f}")

    accuracies = [rep["accuracy"] for rep in reports.values()]
    print(f" {tag} MEAN={np.mean(accuracies):.4f} MIN={np.min(accuracies):.4f} T={elapsed:.3f}s")
    return {
        "mean_acc": float(np.mean(accuracies)),
        "min_acc": float(np.min(accuracies)),
        "inference_time": float(elapsed),
    }


R["denoised"] = eval_d(ed, enc, do, "[DEN]")
|
|
| |
# OPT6: same three-member ensemble, trained with the batch-size-128 loaders.
print(f"\n{'=' * 70}\nOPT6: BS=128\n{'=' * 70}")
e128, tt128 = train_ens(3, tl128, vl128, "outputs/m_128")
bs128_metrics = evaluate(e128, enc, "[BS128]")
bs128_metrics["train_time"] = tt128
R["bs128"] = bs128_metrics
|
|
| |
# OPT2: five-member ensemble trained with the batch-size-64 loaders.
print(f"\n{'=' * 70}\nOPT2: ENSEMBLE 5\n{'=' * 70}")
e5, tt5 = train_ens(5, tl64, vl64, "outputs/m_ens5")
ens5_metrics = evaluate(e5, enc, "[ENS5]")
ens5_metrics["train_time"] = tt5
R["ens5"] = ens5_metrics
|
|
| |
# Summary table: one row per experiment, compared against the baseline run.
print("\n" + "=" * 70, "FINAL RESULTS", "=" * 70, sep="\n")
ba = R["baseline"]["mean_acc"]
bt = R["baseline"]["inference_time"]
print(f"{'Config':<15} {'MeanAcc':>8} {'MinAcc':>8} {'Time':>7} {'Delta':>7} {'Speed':>7} Verdict")
print("-" * 70)
for name, row in R.items():
    # The "features" entry carries only speed-up numbers; it is printed
    # separately after the table.
    if name == "features":
        continue

    acc = row["mean_acc"]
    worst = row["min_acc"]
    secs = row["inference_time"]
    delta = acc - ba
    speed = bt / max(secs, 1e-9)

    # Mean accuracy within 0.005 of the baseline counts as a pass.
    if acc > ba + 0.005:
        verdict = "BETTER"
    elif acc >= ba - 0.005:
        verdict = "PASS"
    else:
        verdict = "WORSE"
    print(f"{name:<15} {acc:>7.4f} {worst:>8.4f} {secs:>7.3f} {delta:>+6.4f} {speed:>6.2f}x {verdict}")

fe = R["features"]
print(f"\nFeatures: {fe['speedup']:.2f}x speedup (diff={fe['max_diff']:.1e})")
|
|
# Persist every experiment's metrics as JSON; values json cannot serialise
# natively are written via str().
results_path = "outputs/results/optimization_results.json"
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "w") as f:
    json.dump(R, f, indent=2, default=str)
print(f"\nSaved: {results_path}")
print("\nDONE. Share the output above and I will commit the winning optimizations.")
|
|