#!/usr/bin/env python3
"""
NQR-SNN Optimization Test Suite
================================
Tests each proposed optimization against the baseline and reports
accuracy + inference speed. Run locally:
python optimization_test.py
Expected runtime: ~10-15 min on a modern CPU.
Results saved to outputs/results/optimization_results.json
"""
import os, sys, time, json, types
import numpy as np
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from nqr_snn import config
from nqr_snn.data.generator import generate_dataset_v2, generate_signal_at_snr, generate_noise_only_at_power, generate_dataset
from nqr_snn.data.dataset import NQRDatasetV2, get_balanced_loader_v2, extract_features_batch
from nqr_snn.snn.model import SpikingClassifier, SPIKINGJELLY_AVAILABLE
from nqr_snn.snn.encoder import DeterministicEncoder
from nqr_snn.snn.ensemble import SNNEnsemble
from nqr_snn.evaluation.metrics import full_report
TRAIN_SIZE, VAL_SIZE, TEST_PER_CLASS, MAX_EPOCHS = 500, 150, 150, 50
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SpikingJelly={SPIKINGJELLY_AVAILABLE} Device={DEVICE}")
def make_test(snr, n=TEST_PER_CLASS, seed=99):
    """Build a balanced test set of ``n`` signal and ``n`` noise-only records.

    All ``n`` positive examples are drawn first (white noise, at the requested
    SNR), followed by ``n`` noise-only examples generated with power 1.0, all
    from a single ``RandomState`` seeded with ``seed``.

    Args:
        snr: SNR level forwarded to ``generate_signal_at_snr``.
        n: Number of examples per class.
        seed: Seed for the ``np.random.RandomState`` shared by both classes.

    Returns:
        ``(signals, labels)``: a complex128 array of the stacked records and
        an array of labels (1 for signal, 0 for noise-only), positives first.
    """
    rng = np.random.RandomState(seed)
    # Element [0] of the generator's return value is the noisy record; the
    # second element is not needed here.
    positives = [generate_signal_at_snr(snr, "white", rng)[0] for _ in range(n)]
    negatives = [generate_noise_only_at_power("white", rng, 1.0) for _ in range(n)]
    labels = [1] * n + [0] * n
    return np.array(positives + negatives, dtype=np.complex128), np.array(labels)
def evaluate(ens, enc, tag=""):
    """Score an ensemble on one freshly generated test set per SNR level.

    For every level in ``config.EVAL_SNR_LEVELS`` a balanced test set is built
    with ``make_test`` (seeded with ``abs(snr) * 100``, so levels that differ
    only in sign share a seed), features are extracted, encoded with ``enc``
    and scored with ``ens.predict``. Only the encode + predict step is timed;
    data generation and feature extraction are excluded.

    Args:
        ens: Object exposing ``predict(x)`` that returns a 2-tuple; only the
            first element (a tensor of scores) is used.
        enc: Object exposing ``encode(features)`` that returns a tensor.
        tag: Prefix printed in front of every progress line.

    Returns:
        Dict with ``mean_acc`` and ``min_acc`` (over SNR levels) and the
        accumulated ``inference_time`` in seconds.
    """
    reports = {}
    total_time = 0.0
    for snr in config.EVAL_SNR_LEVELS:
        signals, labels = make_test(snr, seed=abs(snr) * 100)
        feats = torch.from_numpy(extract_features_batch(signals))

        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        with torch.no_grad():
            x = enc.encode(feats).to(DEVICE)
            mp, _ = ens.predict(x)
        if DEVICE == "cuda":
            # CUDA kernels run asynchronously; wait for them to finish so the
            # elapsed time covers the actual computation.
            torch.cuda.synchronize()
        total_time += time.perf_counter() - t0

        report = full_report(labels, mp.cpu().numpy())
        reports[snr] = report
        print(f" {tag} SNR={snr:4d} | acc={report['accuracy']:.4f} auc={report['auc']:.4f}")

    accs = [rep['accuracy'] for rep in reports.values()]
    mean_acc, min_acc = np.mean(accs), np.min(accs)
    print(f" {tag} MEAN={mean_acc:.4f} MIN={min_acc:.4f} T={total_time:.3f}s")
    return {
        "mean_acc": float(mean_acc),
        "min_acc": float(min_acc),
        "inference_time": float(total_time),
    }
def train_ens(sz, tl, vl, mdir):
    """Train an ensemble with ``config.MODELS_DIR`` temporarily set to ``mdir``.

    ``config.MODELS_DIR`` is overridden for the duration of training and
    checkpoint loading (presumably so the ensemble writes and reads its
    checkpoints under ``mdir`` -- confirm against ``SNNEnsemble``), and is
    restored afterwards even if training or loading raises.

    Args:
        sz: Number of ensemble members.
        tl: Training data loader passed to ``train_all``.
        vl: Validation data loader passed to ``train_all``.
        mdir: Directory to create and use as ``config.MODELS_DIR``.

    Returns:
        ``(ensemble, train_seconds)`` where ``train_seconds`` covers only the
        ``train_all`` call.
    """
    saved_models_dir = config.MODELS_DIR
    config.MODELS_DIR = mdir
    try:
        os.makedirs(mdir, exist_ok=True)
        ens = SNNEnsemble(ensemble_size=sz, device=DEVICE, heterogeneous=True)
        t0 = time.perf_counter()
        ens.train_all(tl, vl, max_epochs=MAX_EPOCHS)
        train_seconds = time.perf_counter() - t0
        # Checkpoints must be loaded while MODELS_DIR still points at mdir.
        ens.load_checkpoints()
    finally:
        # Never leak the override into later experiments.
        config.MODELS_DIR = saved_models_dir
    return ens, train_seconds
# Data
print("\n" + "=" * 70 + "\nGENERATING DATA\n" + "=" * 70)
# Both splits use white noise and fixed seeds (42 for train, 100 for val).
tr = generate_dataset_v2(
    TRAIN_SIZE, config.TRAIN_SNR_RANGE, "white", seed=42
)
vl = generate_dataset_v2(
    VAL_SIZE, config.VAL_SNR_RANGE, "white", seed=100
)
tr_ds = NQRDatasetV2(tr["signals"], tr["labels"])
vl_ds = NQRDatasetV2(vl["signals"], vl["labels"])

# Loaders at both batch sizes are built up front; the 128 variants are only
# consumed by the batch-size experiment later in the script.
tl64 = get_balanced_loader_v2(tr_ds, batch_size=64)
vl64 = get_balanced_loader_v2(vl_ds, batch_size=64, shuffle=False)
tl128 = get_balanced_loader_v2(tr_ds, batch_size=128)
vl128 = get_balanced_loader_v2(vl_ds, batch_size=128, shuffle=False)

# Results of every experiment, keyed by experiment name.
R = {}

# BASELINE
print("\n" + "=" * 70 + "\nBASELINE: 3-member, bs64\n" + "=" * 70)
eb, ttb = train_ens(3, tl64, vl64, "outputs/m_base")
enc = DeterministicEncoder()
baseline_metrics = evaluate(eb, enc, "[BASE]")
baseline_metrics["train_time"] = ttb
R["baseline"] = baseline_metrics
# OPT1: CNN ONCE
print("\n" + "="*70 + "\nOPT1: CNN SINGLE-PASS\n" + "="*70)
if SPIKINGJELLY_AVAILABLE:
    from spikingjelly.activation_based import functional

    def fwd1(self, x_seq, return_per_timestep=False):
        """Forward pass that runs the CNN on the first timestep only.

        The CNN features of ``x_seq[0]`` are reused for every timestep and
        scaled by each timestep's energy relative to the first, then fed
        through the model's normalisation, spiking and linear layers.

        Args:
            x_seq: Input with time on dim 0. Assumed to be
                (T, batch, channels, length) -- TODO confirm; the code sums
                over the last two dims and broadcasts the result against
                (T, batch, features).
            return_per_timestep: If true, return the raw per-timestep output
                of ``fc_out`` instead of the time-averaged sigmoid.

        Returns:
            Per-timestep outputs, or ``sigmoid`` of their mean over dim 0.
        """
        if not self.use_spikingjelly:
            return self.snn_head(self.cnn(x_seq[0]).unsqueeze(0))
        T = x_seq.shape[0]
        fs = self.cnn(x_seq[0])
        # Clear neuron state left over from the previous batch.
        for sl in self._snn_layers:
            functional.reset_net(sl)
        # The energy ratio is only a scaling factor; keep it out of autograd.
        with torch.no_grad():
            e = x_seq.pow(2).sum(dim=(-1, -2))
            r = (e / (e[0:1] + 1e-8)).unsqueeze(-1)
        fm = fs.unsqueeze(0).expand(T, -1, -1) * r
        fm = self.snn_norm(fm)
        x = self.fc1(fm)
        x = self.neuron1(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.neuron2(x)
        x = self.drop2(x)
        x = self.fc_out(x)
        if return_per_timestep:
            return x
        return torch.sigmoid(x.mean(dim=0))

    # Patch the already-trained baseline models in place so the same weights
    # are evaluated with the single-pass forward.
    for m in eb.models:
        m.forward = types.MethodType(fwd1, m)
    try:
        R["cnn_once"] = evaluate(eb, enc, "[CNN1x]")
    finally:
        # Drop the instance-level override so each model falls back to its
        # own class's forward, even if evaluation raised.
        for m in eb.models:
            del m.forward
else:
    # Every non-"features" entry in R must carry accuracy/time keys for the
    # summary table, so reuse the baseline numbers -- as a copy, and say so.
    print(" [CNN1x] skipped (SpikingJelly unavailable); reusing baseline metrics")
    R["cnn_once"] = dict(R["baseline"])
# OPT3: FUSED FEATURES
print("\n" + "="*70 + "\nOPT3: FUSED FEATURES (speed only)\n" + "="*70)


def fused(signals):
    """Compute seven per-sample feature channels for a batch of complex signals.

    Channels, in output order: real part, imaginary part, magnitude,
    ``log1p`` of the FFT magnitude, FFT phase, normalised autocorrelation,
    and standardised instantaneous frequency.

    Args:
        signals: Complex array of shape (N, L).

    Returns:
        Array of shape (N, 7, L). Six channels are float32; the
        instantaneous-frequency channel is computed against ``spread + 1e-8``
        and may promote the stacked result -- compare dtype with
        ``extract_features_batch`` if it matters.
    """
    N, L = signals.shape
    real = signals.real.astype(np.float32)
    imag = signals.imag.astype(np.float32)
    mag = np.abs(signals).astype(np.float32)

    spectrum = np.fft.fft(signals, axis=1)
    log_spec = np.log1p(np.abs(spectrum)).astype(np.float32)
    spec_phase = np.angle(spectrum).astype(np.float32)

    # Autocorrelation via the power spectrum of the signal zero-padded to 2L
    # (fft's ``n`` argument does the padding), keeping lags 0..L-1.
    power = np.abs(np.fft.fft(signals, n=2 * L, axis=1)) ** 2
    ac = np.real(np.fft.ifft(power, axis=1))[:, :L].astype(np.float32)
    # Divide each lag by its number of overlapping samples: L, L-1, ..., 1.
    ac /= np.arange(L, 0, -1, dtype=np.float32)[None, :]
    # Normalise by the zero-lag value. Rows whose zero-lag value is not
    # positive are left untouched, and the division is skipped for them
    # instead of being computed and discarded (which raised warnings).
    zero_lag = ac[:, 0:1].copy()
    np.divide(ac, zero_lag, out=ac, where=zero_lag > 0)

    # Instantaneous frequency from the derivative of the unwrapped phase.
    phase = np.unwrap(np.angle(signals), axis=1)
    inst_freq = np.zeros((N, L), dtype=np.float32)
    inst_freq[:, 1:] = np.diff(phase, axis=1).astype(np.float32) / (
        2 * np.pi * config.SAMPLING_INTERVAL
    )
    inst_freq[:, 0] = inst_freq[:, 1]  # no derivative at sample 0; copy sample 1
    inst_freq = np.clip(inst_freq, -1000, 1000)
    spread = np.std(inst_freq, axis=1, keepdims=True)
    inst_freq = np.where(spread > 0, inst_freq / (spread + 1e-8), inst_freq)

    return np.stack(
        [real, imag, mag, log_spec, spec_phase, ac, inst_freq], axis=1
    )
# Time both feature extractors on the same 400-record batch (200 per class)
# and record the largest element-wise difference between their outputs.
sb, _ = make_test(-35, 200, 77)

# perf_counter is the monotonic, high-resolution clock meant for measuring
# short intervals; time.time() can be too coarse for sub-second runs.
t0 = time.perf_counter()
for _ in range(3):
    fo = extract_features_batch(sb)
to = (time.perf_counter() - t0) / 3

t0 = time.perf_counter()
for _ in range(3):
    ff = fused(sb)
tf = (time.perf_counter() - t0) / 3

md = float(np.max(np.abs(fo - ff)))
print(f" Orig={to:.4f}s Fused={tf:.4f}s Speed={to/max(tf,1e-9):.2f}x Diff={md:.1e}")
R["features"] = {"speedup": to / max(tf, 1e-9), "max_diff": md}
# OPT5: DENOISING
print("\n" + "=" * 70 + "\nOPT5: DENOISING ABLATION\n" + "=" * 70)
from nqr_snn.denoising.selector import DenoisingSelector
from nqr_snn.denoising import denoise_batch

# Temporarily shrink config.VAL_SIZE while generating the small set used to
# pick a denoiser (presumably generate_dataset sizes the "val" split from
# it -- confirm), then put the original value back.
saved_val_size = config.VAL_SIZE
config.VAL_SIZE = 50
probe_set = generate_dataset("low", "white", "val", seed=42)
config.VAL_SIZE = saved_val_size

# Let the selector choose a denoiser from the first 50 noisy/clean pairs.
selector = DenoisingSelector()
chosen_name, chosen_score, do = selector.select(
    probe_set["noisy"][:50], probe_set["clean"][:50]
)
print(f" Denoiser: {chosen_name} R2={chosen_score:.1f}")

# Retrain a 3-member ensemble on data passed through the chosen denoiser.
train_denoised = NQRDatasetV2(tr["signals"], tr["labels"], denoiser=do)
val_denoised = NQRDatasetV2(vl["signals"], vl["labels"], denoiser=do)
train_loader_denoised = get_balanced_loader_v2(train_denoised, 64)
val_loader_denoised = get_balanced_loader_v2(val_denoised, 64, shuffle=False)
ed, _den_train_time = train_ens(
    3, train_loader_denoised, val_loader_denoised, "outputs/m_den"
)
def eval_d(ens, enc, den, tag=""):
    """Score an ensemble per SNR level with a denoising step before features.

    Mirrors ``evaluate`` but passes each test batch through
    ``denoise_batch(den, ...)`` first. The reported time covers denoising plus
    encode + predict. Feature extraction is left untimed, exactly as in
    ``evaluate``, so the two ``inference_time`` values are comparable in the
    summary table.

    Args:
        ens: Object exposing ``predict(x)`` that returns a 2-tuple; only the
            first element (a tensor of scores) is used.
        enc: Object exposing ``encode(features)`` that returns a tensor.
        den: Denoiser handed to ``denoise_batch``.
        tag: Prefix printed in front of every progress line.

    Returns:
        Dict with ``mean_acc``, ``min_acc`` and ``inference_time`` (seconds).
    """
    total_time = 0.0
    reports = {}
    for snr in config.EVAL_SNR_LEVELS:
        signals, labels = make_test(snr, seed=abs(snr) * 100)

        # Denoising is the extra cost this ablation introduces, so time it.
        t0 = time.perf_counter()
        denoised = denoise_batch(den, signals)
        total_time += time.perf_counter() - t0

        feats = torch.from_numpy(extract_features_batch(denoised))

        t0 = time.perf_counter()
        # No gradients are needed for evaluation; this also guarantees the
        # scores can be converted to NumPy below.
        with torch.no_grad():
            x = enc.encode(feats).to(DEVICE)
            mp, _ = ens.predict(x)
        if DEVICE == "cuda":
            # CUDA kernels run asynchronously; wait before reading the clock.
            torch.cuda.synchronize()
        total_time += time.perf_counter() - t0

        report = full_report(labels, mp.cpu().numpy())
        reports[snr] = report
        print(f" {tag} SNR={snr:4d} | acc={report['accuracy']:.4f}")

    accs = [rep['accuracy'] for rep in reports.values()]
    print(f" {tag} MEAN={np.mean(accs):.4f} MIN={np.min(accs):.4f} T={total_time:.3f}s")
    return {
        "mean_acc": float(np.mean(accs)),
        "min_acc": float(np.min(accs)),
        "inference_time": float(total_time),
    }


R["denoised"] = eval_d(ed, enc, do, "[DEN]")
# OPT6: BS128
print("\n" + "=" * 70 + "\nOPT6: BS=128\n" + "=" * 70)
# Same 3-member ensemble as the baseline, trained from the batch-128 loaders.
e128, tt128 = train_ens(3, tl128, vl128, "outputs/m_128")
bs128_metrics = evaluate(e128, enc, "[BS128]")
bs128_metrics["train_time"] = tt128
R["bs128"] = bs128_metrics

# OPT2: ENS5
print("\n" + "=" * 70 + "\nOPT2: ENSEMBLE 5\n" + "=" * 70)
# Five members instead of three, trained from the batch-64 loaders.
e5, tt5 = train_ens(5, tl64, vl64, "outputs/m_ens5")
ens5_metrics = evaluate(e5, enc, "[ENS5]")
ens5_metrics["train_time"] = tt5
R["ens5"] = ens5_metrics
# SUMMARY
print("\n" + "=" * 70)
print("FINAL RESULTS")
print("=" * 70)

# Every configuration is compared against the baseline's mean accuracy and
# accumulated inference time.
base_acc = R["baseline"]["mean_acc"]
base_time = R["baseline"]["inference_time"]

print(f"{'Config':<15} {'MeanAcc':>8} {'MinAcc':>8} {'Time':>7} {'Delta':>7} {'Speed':>7} Verdict")
print("-" * 70)
for name, metrics in R.items():
    # The feature benchmark has no accuracy columns; it is reported below.
    if name != "features":
        acc = metrics["mean_acc"]
        worst = metrics["min_acc"]
        elapsed = metrics["inference_time"]
        delta = acc - base_acc
        speed = base_time / max(elapsed, 1e-9)
        # Accuracy within +/-0.005 of the baseline counts as a pass.
        if acc > base_acc + 0.005:
            verdict = "BETTER"
        elif acc >= base_acc - 0.005:
            verdict = "PASS"
        else:
            verdict = "WORSE"
        print(f"{name:<15} {acc:>7.4f} {worst:>8.4f} {elapsed:>7.3f} {delta:>+6.4f} {speed:>6.2f}x {verdict}")

feature_stats = R["features"]
print(f"\nFeatures: {feature_stats['speedup']:.2f}x speedup (diff={feature_stats['max_diff']:.1e})")

# Persist the complete results dict next to the other outputs.
os.makedirs("outputs/results", exist_ok=True)
with open("outputs/results/optimization_results.json", "w") as results_file:
    json.dump(R, results_file, indent=2, default=str)
print("\nSaved: outputs/results/optimization_results.json")
print("\nDONE. Share the output above and I will commit the winning optimizations.")
|