nqr-snn-framework / optimization_test.py
KD099's picture
Add optimization_test.py — run locally to validate all optimizations before committing
57dbe46 verified
#!/usr/bin/env python3
"""
NQR-SNN Optimization Test Suite
================================
Tests each proposed optimization against the baseline and reports
accuracy + inference speed. Run locally:
python optimization_test.py
Expected runtime: ~10-15 min on a modern CPU.
Results saved to outputs/results/optimization_results.json
"""
import os, sys, time, json, types
import numpy as np
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from nqr_snn import config
from nqr_snn.data.generator import generate_dataset_v2, generate_signal_at_snr, generate_noise_only_at_power, generate_dataset
from nqr_snn.data.dataset import NQRDatasetV2, get_balanced_loader_v2, extract_features_batch
from nqr_snn.snn.model import SpikingClassifier, SPIKINGJELLY_AVAILABLE
from nqr_snn.snn.encoder import DeterministicEncoder
from nqr_snn.snn.ensemble import SNNEnsemble
from nqr_snn.evaluation.metrics import full_report
# Experiment scale: training / validation dataset sizes, test samples per
# class (see make_test), and the epoch cap passed to SNNEnsemble.train_all.
TRAIN_SIZE = 500
VAL_SIZE = 150
TEST_PER_CLASS = 150
MAX_EPOCHS = 50
_cuda_ok = torch.cuda.is_available()
DEVICE = "cuda" if _cuda_ok else "cpu"
print(f"SpikingJelly={SPIKINGJELLY_AVAILABLE} Device={DEVICE}")
def make_test(snr, n=TEST_PER_CLASS, seed=99):
    """Build a seeded, balanced test set at one SNR level.

    The first ``n`` rows are signal-plus-white-noise samples from
    ``generate_signal_at_snr`` (label 1); the next ``n`` rows are white
    noise-only samples at power 1.0 (label 0). A fresh ``RandomState(seed)``
    drives both, so the same arguments always give the same set.

    Returns:
        (signals as a complex128 array, labels as an int array)
    """
    rng = np.random.RandomState(seed)
    # Draw order matters for reproducibility: all signals first, then all noise.
    positives = [generate_signal_at_snr(snr, "white", rng)[0] for _ in range(n)]
    negatives = [generate_noise_only_at_power("white", rng, 1.0) for _ in range(n)]
    signals = np.array(positives + negatives, dtype=np.complex128)
    labels = np.array([1] * n + [0] * n)
    return signals, labels
def evaluate(ens, enc, tag=""):
    """Evaluate an ensemble at every SNR level in ``config.EVAL_SNR_LEVELS``.

    For each level a seeded, balanced test set is built with ``make_test``,
    turned into features with ``extract_features_batch``, encoded by ``enc``
    and scored with ``ens.predict``. Only encode + predict is timed; feature
    extraction happens before the timer starts.

    Args:
        ens: ensemble whose ``predict(x)`` returns a pair; the first element
            is passed to ``full_report`` as the prediction scores.
        enc: encoder exposing ``encode(features_tensor)``.
        tag: label prefixed to every printed line.

    Returns:
        dict with ``mean_acc`` / ``min_acc`` (across SNR levels) and the summed
        ``inference_time`` in seconds.
    """
    res, tt = {}, 0.0
    for snr in config.EVAL_SNR_LEVELS:
        signals, labels = make_test(snr, seed=abs(snr) * 100)
        feats = torch.from_numpy(extract_features_batch(signals))
        t0 = time.perf_counter()  # monotonic, high-resolution: right clock for durations
        with torch.no_grad():
            x = enc.encode(feats).to(DEVICE)
            mp, _ = ens.predict(x)
        if DEVICE == "cuda":
            # CUDA kernels launch asynchronously; wait so the timer covers the real work.
            torch.cuda.synchronize()
        tt += time.perf_counter() - t0
        r = full_report(labels, mp.cpu().numpy())
        res[snr] = r
        print(f" {tag} SNR={snr:4d} | acc={r['accuracy']:.4f} auc={r['auc']:.4f}")
    accs = [v["accuracy"] for v in res.values()]
    ma, mi = np.mean(accs), np.min(accs)
    print(f" {tag} MEAN={ma:.4f} MIN={mi:.4f} T={tt:.3f}s")
    return {"mean_acc": float(ma), "min_acc": float(mi), "inference_time": float(tt)}
def train_ens(sz, tl, vl, mdir):
    """Train a heterogeneous ``SNNEnsemble`` of ``sz`` members and reload its checkpoints.

    ``config.MODELS_DIR`` is pointed at ``mdir`` for the duration of training and
    ``load_checkpoints()`` (presumably the directory SNNEnsemble saves to / loads
    from — TODO confirm), then restored even if either step raises, so a failed
    run cannot leak its directory into later experiments.

    Args:
        sz: ensemble size.
        tl, vl: training / validation loaders passed to ``train_all``.
        mdir: checkpoint directory; created if missing.

    Returns:
        (trained ensemble, training wall-clock seconds)
    """
    old_models_dir = config.MODELS_DIR
    config.MODELS_DIR = mdir
    try:
        os.makedirs(mdir, exist_ok=True)
        ens = SNNEnsemble(ensemble_size=sz, device=DEVICE, heterogeneous=True)
        t0 = time.perf_counter()
        ens.train_all(tl, vl, max_epochs=MAX_EPOCHS)
        train_time = time.perf_counter() - t0
        ens.load_checkpoints()
    finally:
        config.MODELS_DIR = old_models_dir
    return ens, train_time
# ------------------------------------------------------------------- Data
_rule = "=" * 70
print(f"\n{_rule}\nGENERATING DATA\n{_rule}")
tr = generate_dataset_v2(TRAIN_SIZE, config.TRAIN_SNR_RANGE, "white", seed=42)
vl = generate_dataset_v2(VAL_SIZE, config.VAL_SNR_RANGE, "white", seed=100)
tr_ds, vl_ds = (NQRDatasetV2(split["signals"], split["labels"]) for split in (tr, vl))
# Balanced loaders at two batch sizes, created in the order
# train-64, val-64, train-128, val-128.
_loaders = {
    bs: (
        get_balanced_loader_v2(tr_ds, batch_size=bs),
        get_balanced_loader_v2(vl_ds, batch_size=bs, shuffle=False),
    )
    for bs in (64, 128)
}
(tl64, vl64), (tl128, vl128) = _loaders[64], _loaders[128]
R = {}  # experiment name -> metrics dict; dumped to JSON at the end

# --------------------------------------------------------------- Baseline
print(f"\n{_rule}\nBASELINE: 3-member, bs64\n{_rule}")
eb, ttb = train_ens(3, tl64, vl64, "outputs/m_base")
enc = DeterministicEncoder()
baseline_metrics = evaluate(eb, enc, "[BASE]")
R["baseline"] = {**baseline_metrics, "train_time": ttb}
# OPT1: CNN ONCE
print("\n" + "="*70 + "\nOPT1: CNN SINGLE-PASS\n" + "="*70)
if SPIKINGJELLY_AVAILABLE:
    from spikingjelly.activation_based import functional

    def fwd1(self, x_seq, return_per_timestep=False):
        """Patched forward: run the CNN once, on timestep 0, and reuse it for all T steps.

        Each timestep's copy of the CNN features is scaled by that timestep's
        input energy (sum of squares over the last two dims) relative to
        timestep 0. Assumes ``x_seq`` is time-major ``(T, B, ...)`` and that
        ``self.cnn`` returns ``(B, F)`` — TODO confirm against
        ``SpikingClassifier.forward``.

        Returns per-timestep logits when ``return_per_timestep`` is true,
        otherwise the sigmoid of the time-averaged logits.
        """
        if not self.use_spikingjelly:
            return self.snn_head(self.cnn(x_seq[0]).unsqueeze(0))
        T = x_seq.shape[0]
        fs = self.cnn(x_seq[0])  # the single CNN pass this optimization is about
        for sl in self._snn_layers:
            functional.reset_net(sl)  # clear stateful neuron state before a new sequence
        with torch.no_grad():
            # Energy ratio per timestep vs. timestep 0; epsilon guards /0.
            e = x_seq.pow(2).sum(dim=(-1, -2))
            r = (e / (e[0:1] + 1e-8)).unsqueeze(-1)
        fm = fs.unsqueeze(0).expand(T, -1, -1) * r
        fm = self.snn_norm(fm)
        x = self.fc1(fm)
        x = self.neuron1(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.neuron2(x)
        x = self.drop2(x)
        x = self.fc_out(x)
        if return_per_timestep:
            return x
        return torch.sigmoid(x.mean(dim=0))

    # Instance attributes shadow the class's forward for these models only.
    for m in eb.models:
        m.forward = types.MethodType(fwd1, m)
    try:
        R["cnn_once"] = evaluate(eb, enc, "[CNN1x]")
    finally:
        # Removing the instance attribute restores the class's own forward,
        # even if evaluation raised.
        for m in eb.models:
            del m.forward
else:
    # Nothing was measured: copy (don't alias) the baseline and flag it in the JSON.
    print(" SpikingJelly unavailable: OPT1 skipped (baseline metrics copied)")
    R["cnn_once"] = {**R["baseline"], "skipped": True}
# OPT3: FUSED FEATURES
print("\n" + "="*70 + "\nOPT3: FUSED FEATURES (speed only)\n" + "="*70)


def fused(signals, sampling_interval=None):
    """Single-function 7-channel feature stack, benchmarked against extract_features_batch.

    Used in this script only to time it against ``extract_features_batch`` and
    report the maximum element-wise difference between the two.

    Args:
        signals: 2-D array of N signals x L samples (complex in this script).
        sampling_interval: sample spacing used to scale the instantaneous
            frequency channel; ``None`` means ``config.SAMPLING_INTERVAL``.

    Returns:
        float32 array of shape (N, 7, L). Channels in order: real, imag,
        magnitude, log1p(|FFT|), FFT phase, lag-0-normalized autocorrelation,
        std-normalized instantaneous frequency.
    """
    if sampling_interval is None:
        sampling_interval = config.SAMPLING_INTERVAL
    N, L = signals.shape
    re_part = signals.real.astype(np.float32)
    im_part = signals.imag.astype(np.float32)
    mag = np.abs(signals).astype(np.float32)

    spec = np.fft.fft(signals, axis=1)
    log_spec = np.log1p(np.abs(spec)).astype(np.float32)
    phase = np.angle(spec).astype(np.float32)

    # Linear (non-circular) autocorrelation via the FFT of a 2L zero-padded copy;
    # keep lags 0..L-1 and divide lag k by its L-k overlapping terms.
    padded = np.zeros((N, 2 * L), dtype=signals.dtype)
    padded[:, :L] = signals
    power = np.abs(np.fft.fft(padded, axis=1)) ** 2
    acorr = np.real(np.fft.ifft(power, axis=1))[:, :L].astype(np.float32)
    acorr /= np.arange(L, 0, -1, dtype=np.float32)[None, :]
    # Normalize by lag 0 only where it is positive. np.divide(where=...) skips the
    # other rows instead of computing 0/0 (RuntimeWarning) and discarding it.
    lag0 = acorr[:, 0:1]
    acorr = np.divide(acorr, lag0, out=acorr.copy(), where=lag0 > 0)

    # Instantaneous frequency from the unwrapped-phase derivative; sample 0
    # copies sample 1 so the channel keeps length L.
    unwrapped = np.unwrap(np.angle(signals), axis=1)
    ifreq = np.zeros((N, L), dtype=np.float32)
    ifreq[:, 1:] = np.diff(unwrapped, axis=1).astype(np.float32) / (2 * np.pi * sampling_interval)
    ifreq[:, 0] = ifreq[:, 1]
    ifreq = np.clip(ifreq, -1000, 1000)
    std = np.std(ifreq, axis=1, keepdims=True)
    ifreq = np.where(std > 0, ifreq / (std + 1e-8), ifreq)

    return np.stack([re_part, im_part, mag, log_spec, phase, acorr, ifreq], axis=1)
sb, _ = make_test(-35, 200, 77)  # 200 signal + 200 noise samples, fixed seed


def _mean_runtime(fn, arg, repeats=3):
    """Call ``fn(arg)`` ``repeats`` times; return (mean seconds per call, last result)."""
    out = None
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = fn(arg)
    return (time.perf_counter() - t0) / repeats, out


to, fo = _mean_runtime(extract_features_batch, sb)
tf, ff = _mean_runtime(fused, sb)
speedup = to / max(tf, 1e-9)
if np.shape(fo) == np.shape(ff):
    md = float(np.max(np.abs(fo - ff)))
else:
    # A shape mismatch means the fused path is not a drop-in replacement;
    # don't let broadcasting produce (or crash on) a meaningless difference.
    print(f" WARNING: shape mismatch orig={np.shape(fo)} fused={np.shape(ff)}")
    md = float("inf")
print(f" Orig={to:.4f}s Fused={tf:.4f}s Speed={speedup:.2f}x Diff={md:.1e}")
R["features"] = {"speedup": speedup, "max_diff": md}
# OPT5: DENOISING
print("\n" + "="*70 + "\nOPT5: DENOISING ABLATION\n" + "="*70)
from nqr_snn.denoising.selector import DenoisingSelector
from nqr_snn.denoising import denoise_batch

# Shrink config.VAL_SIZE while building the denoiser-selection set (presumably
# generate_dataset(..., "val", ...) sizes the split from it — TODO confirm), and
# restore it even if generation raises.
ov = config.VAL_SIZE
config.VAL_SIZE = 50
try:
    dd = generate_dataset("low", "white", "val", seed=42)
finally:
    config.VAL_SIZE = ov

sel = DenoisingSelector()
# Unpacked as (name, score, denoiser) per how the three values are used below.
dn, dr, do = sel.select(dd["noisy"][:50], dd["clean"][:50])
print(f" Denoiser: {dn} R2={dr:.1f}")
trd = NQRDatasetV2(tr["signals"], tr["labels"], denoiser=do)
vld = NQRDatasetV2(vl["signals"], vl["labels"], denoiser=do)
# Batch size passed by keyword, matching every other get_balanced_loader_v2 call.
ed, _ = train_ens(
    3,
    get_balanced_loader_v2(trd, batch_size=64),
    get_balanced_loader_v2(vld, batch_size=64, shuffle=False),
    "outputs/m_den",
)
def eval_d(ens, enc, den, tag=""):
    """Evaluate an ensemble on denoised inputs at every SNR in ``config.EVAL_SNR_LEVELS``.

    Same flow as ``evaluate``, but each test batch is first passed through
    ``denoise_batch(den, ...)``. The timed region is denoising plus
    encode + predict. Feature extraction is excluded, as it is in ``evaluate``,
    so ``inference_time`` differs from the baseline only by the denoiser's cost.
    Inference runs under ``torch.no_grad()`` so no autograd graph is built.

    Returns:
        dict with ``mean_acc`` / ``min_acc`` across SNR levels and the summed
        ``inference_time`` in seconds.
    """
    res, tt = {}, 0.0
    for snr in config.EVAL_SNR_LEVELS:
        signals, labels = make_test(snr, seed=abs(snr) * 100)
        t0 = time.perf_counter()
        denoised = denoise_batch(den, signals)
        tt += time.perf_counter() - t0
        feats = torch.from_numpy(extract_features_batch(denoised))
        t0 = time.perf_counter()
        with torch.no_grad():
            x = enc.encode(feats).to(DEVICE)
            mp, _ = ens.predict(x)
        if DEVICE == "cuda":
            # Wait for asynchronous CUDA kernels so the timer covers the real work.
            torch.cuda.synchronize()
        tt += time.perf_counter() - t0
        r = full_report(labels, mp.cpu().numpy())
        res[snr] = r
        print(f" {tag} SNR={snr:4d} | acc={r['accuracy']:.4f} auc={r['auc']:.4f}")
    accs = [v["accuracy"] for v in res.values()]
    ma, mi = float(np.mean(accs)), float(np.min(accs))
    print(f" {tag} MEAN={ma:.4f} MIN={mi:.4f} T={tt:.3f}s")
    return {"mean_acc": ma, "min_acc": mi, "inference_time": float(tt)}
R["denoised"] = eval_d(ed, enc, do, "[DEN]")

_bar = "=" * 70

# OPT6: same 3-member ensemble, trained with batch size 128
print(f"\n{_bar}\nOPT6: BS=128\n{_bar}")
e128, tt128 = train_ens(3, tl128, vl128, "outputs/m_128")
bs128_metrics = evaluate(e128, enc, "[BS128]")
bs128_metrics["train_time"] = tt128
R["bs128"] = bs128_metrics

# OPT2: 5-member ensemble at batch size 64
print(f"\n{_bar}\nOPT2: ENSEMBLE 5\n{_bar}")
e5, tt5 = train_ens(5, tl64, vl64, "outputs/m_ens5")
ens5_metrics = evaluate(e5, enc, "[ENS5]")
ens5_metrics["train_time"] = tt5
R["ens5"] = ens5_metrics
# SUMMARY
print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)
TOL = 0.005  # absolute mean-accuracy band treated as "no change" vs. baseline
RESULTS_PATH = "outputs/results/optimization_results.json"
ba = R["baseline"]["mean_acc"]
bt = R["baseline"]["inference_time"]
# Row field widths below match these header widths (15/8/8/7/7/7) so columns align.
print(f"{'Config':<15} {'MeanAcc':>8} {'MinAcc':>8} {'Time':>7} {'Delta':>7} {'Speed':>7} Verdict")
print("-"*70)
for n, r in R.items():
    if n == "features":
        continue  # the feature benchmark has no accuracy; reported separately below
    a, mi, t = r["mean_acc"], r["min_acc"], r["inference_time"]
    d = a - ba
    sp = bt / max(t, 1e-9)
    if a > ba + TOL:
        v = "BETTER"
    elif a >= ba - TOL:
        v = "PASS"
    else:
        v = "WORSE"
    print(f"{n:<15} {a:>8.4f} {mi:>8.4f} {t:>7.3f} {d:>+7.4f} {sp:>6.2f}x {v}")
fe = R["features"]
print(f"\nFeatures: {fe['speedup']:.2f}x speedup (diff={fe['max_diff']:.1e})")
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)
with open(RESULTS_PATH, "w", encoding="utf-8") as f:
    json.dump(R, f, indent=2, default=str)
print(f"\nSaved: {RESULTS_PATH}")
print("\nDONE. Share the output above and I will commit the winning optimizations.")