KD099 commited on
Commit
3faeada
Β·
verified Β·
1 Parent(s): 96f5e21

v3.2: wire neuron search results into ENSEMBLE_CONFIGS, use trained encoder at eval, consolidate denoise_batch

Browse files
Files changed (1) hide show
  1. run_full_pipeline.py +59 -36
run_full_pipeline.py CHANGED
@@ -1,24 +1,25 @@
1
  #!/usr/bin/env python3
2
  """
3
- NQR-SNN Framework v3.1 β€” Full Pipeline
4
  =======================================
5
  Complete end-to-end pipeline with all 6 stages:
6
 
7
  1. Denoiser Selection β€” LM vs SSA vs Wavelet (gate on low-SNR white)
8
- 2. Parallel Neuron Search β€” 5 neuron Γ— 3 surrogate = 15 configs via ProcessPoolExecutor
9
  3. Dataset Generation β€” SNR-controlled signals, denoised, 7-channel feature extraction
10
  4. Ensemble Training β€” N-member heterogeneous CNN+SNN ensemble with TET loss
11
  5. SNR-Level Evaluation β€” Accuracy / AUC / F1 at each dB level with denoising
12
  6. Noise Injection Stress β€” Gaussian / S&P / RFI / weight / OOD frequency shift
13
 
14
- v3.1 fixes: dead neuron bug (threshold+LayerNorm), differentiable encoder, dropout.
 
 
 
15
 
16
  Usage:
17
- python run_full_pipeline.py --quick # fast demo (~90s CPU)
18
- python run_full_pipeline.py # full run
19
- python run_full_pipeline.py --skip_denoise # skip denoiser stage
20
- python run_full_pipeline.py --skip_neuron_search # skip neuron search, use LIF+ATan
21
- python run_full_pipeline.py --search_workers 8 # force 8 parallel workers
22
  """
23
 
24
  import os
@@ -50,23 +51,17 @@ from nqr_snn.evaluation.plots import (
50
  plot_noise_degradation, plot_ood_uncertainty,
51
  )
52
  from nqr_snn.denoising.selector import DenoisingSelector
 
53
  from nqr_snn.noise_injection.stress_test import run_stress_test
54
 
55
 
56
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
57
  # Helper: apply denoiser to a batch of complex signals
 
58
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
59
  def denoise_batch(signals: np.ndarray, denoiser) -> np.ndarray:
60
  """Apply denoiser to every signal in a (N, 1024) complex128 array."""
61
- out = np.zeros_like(signals)
62
- for i in range(len(signals)):
63
- if hasattr(denoiser, 'fit_one'): # LMDenoiser
64
- out[i] = denoiser.fit_one(signals[i])
65
- elif hasattr(denoiser, 'denoise'): # SSA / Wavelet
66
- out[i] = denoiser.denoise(signals[i])
67
- else:
68
- out[i] = signals[i]
69
- return out
70
 
71
 
72
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@@ -89,10 +84,9 @@ def make_snr_test_set(target_snr_db, n_per_class=200, seed=99):
89
  # Helper: signals β†’ features tensor (optionally denoise first)
90
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
91
  def signals_to_features(signals, denoiser=None):
92
- """(N,1024) complex128 β†’ (N,7,1024) float32 tensor β€” v3.0 vectorized."""
93
  if denoiser is not None:
94
  signals = denoise_batch(signals, denoiser)
95
- # v3.0: Use vectorized batch extraction
96
  feats = extract_features_batch(signals)
97
  return torch.from_numpy(feats)
98
 
@@ -102,7 +96,7 @@ def signals_to_features(signals, denoiser=None):
102
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
103
  def main():
104
  parser = argparse.ArgumentParser(
105
- description="NQR-SNN v2.0 β€” full pipeline: denoise β†’ neuron search β†’ train β†’ evaluate β†’ stress test"
106
  )
107
  parser.add_argument("--quick", action="store_true",
108
  help="Quick demo mode (~90 s on CPU)")
@@ -119,7 +113,7 @@ def main():
119
  parser.add_argument("--search_epochs", type=int, default=None)
120
  args = parser.parse_args()
121
 
122
- # ── Resolve sizes depending on mode ──
123
  quick = args.quick
124
  train_size = args.train_size or (500 if quick else 3000)
125
  val_size = args.val_size or (150 if quick else 1000)
@@ -133,15 +127,14 @@ def main():
133
  device = "cuda" if torch.cuda.is_available() else "cpu"
134
 
135
  print("=" * 80)
136
- print("NQR-SNN FRAMEWORK v3.1 β€” FULL PIPELINE (SNN BUG FIXED)")
137
  print("=" * 80)
138
- print(f" v3.1 fixes: dead neuron (threshold+LayerNorm), differentiable encoder, dropout")
139
  print(f" Device : {device}")
140
  print(f" Mode : {'QUICK' if quick else 'FULL'}")
141
  print(f" Train / Val : {train_size} / {val_size} per class")
142
  print(f" Ensemble members : {ensemble_size}")
143
  print(f" Max epochs : {max_epochs}")
144
- print(f" Search epochs : {search_epochs} ({len(config.NEURON_MODELS)}Γ—{len(config.SURROGATE_FUNCTIONS)} configs)")
145
  print(f" Test per SNR : {n_test}")
146
  print(f" Steps : denoise={'ON' if not args.skip_denoise else 'OFF'} | "
147
  f"search={'ON' if not args.skip_neuron_search else 'OFF'} | "
@@ -159,7 +152,7 @@ def main():
159
 
160
  if not args.skip_denoise:
161
  t1 = time.time()
162
- print(f"\n[1/6] DENOISER SELECTION (LM β†’ SSA β†’ Wavelet on {denoise_samples} low-SNR white samples)")
163
 
164
  old_vs = config.VAL_SIZE
165
  config.VAL_SIZE = denoise_samples
@@ -172,7 +165,7 @@ def main():
172
  den_data["clean"][:denoise_samples],
173
  )
174
  step_times["denoiser"] = time.time() - t1
175
- print(f" βœ“ Selected: {denoiser_name} (RΒ²={denoiser_r2:.1f}) [{step_times['denoiser']:.1f}s]")
176
  else:
177
  print("\n[1/6] DENOISER SELECTION β€” skipped")
178
 
@@ -210,12 +203,40 @@ def main():
210
 
211
  step_times["search"] = time.time() - t2
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  # Heatmap
214
  csv = os.path.join(config.RESULTS_DIR, "neuron_search.csv")
215
  if os.path.exists(csv):
216
  plot_neuron_search_heatmap(csv)
217
 
218
- print(f" βœ“ Winner: {neuron_model} + {surrogate_fn} "
219
  f"(val_acc={search_result['val_accuracy']:.4f}) "
220
  f"[wall {step_times['search']:.1f}s]")
221
  else:
@@ -240,7 +261,7 @@ def main():
240
  val_loader = get_balanced_loader_v2(vl_ds, batch_size=config.BATCH_SIZE, shuffle=False)
241
 
242
  step_times["datagen"] = time.time() - t3
243
- print(f" βœ“ {len(tr_ds)} train, {len(vl_ds)} val "
244
  f"SNR=[{tr['snr_dbs'].min():.0f},{tr['snr_dbs'].max():.0f}]dB "
245
  f"[{step_times['datagen']:.1f}s]")
246
 
@@ -248,25 +269,26 @@ def main():
248
  # STEP 4 Ensemble Training
249
  # ──────────────────────────────────────────────────────────────────
250
  t4 = time.time()
251
- print(f"\n[4/6] ENSEMBLE TRAINING ({ensemble_size} members Γ— {max_epochs} epochs, "
252
  f"neuron={neuron_model}+{surrogate_fn})")
253
 
254
  ensemble = SNNEnsemble(
255
  neuron_model=neuron_model, surrogate_fn=surrogate_fn,
256
  ensemble_size=ensemble_size, device=device,
257
- heterogeneous=True, # v3.0: diverse neuron/surrogate/tau
258
  )
259
  ensemble.train_all(train_loader, val_loader, max_epochs=max_epochs)
260
  ensemble.load_checkpoints()
261
 
262
  step_times["train"] = time.time() - t4
263
- print(f" βœ“ Training done [{step_times['train']:.1f}s]")
264
 
265
  # ──────────────────────────────────────────────────────────────────
266
  # STEP 5 SNR-Level Evaluation
267
  # ──────────────────────────────────────────────────────────────────
268
  t5 = time.time()
269
- encoder = DeterministicEncoder()
 
270
  snr_rows = []
271
 
272
  print(f"\n[5/6] SNR EVALUATION ({n_test}/class per level, "
@@ -285,7 +307,7 @@ def main():
285
  rep["ensemble_std"] = float(std_p.cpu().mean())
286
  snr_rows.append(rep)
287
 
288
- print(f" SNR={snr_db:4d} dB β”‚ acc={rep['accuracy']:.4f} "
289
  f"auc={rep['auc']:.4f} f1={rep['f1']:.4f} "
290
  f"tpr={rep['tpr']:.4f} tnr={rep['tnr']:.4f}")
291
 
@@ -322,7 +344,7 @@ def main():
322
  os.path.join(config.PLOTS_DIR, "ood_uncertainty.png"))
323
 
324
  step_times["stress"] = time.time() - t6
325
- print(f" βœ“ Stress test done [{step_times['stress']:.1f}s]")
326
  else:
327
  print("\n[6/6] STRESS TEST β€” skipped")
328
 
@@ -331,7 +353,7 @@ def main():
331
  # ──────────────────────────────────────────────────────────────────
332
  total = time.time() - t0
333
  print("\n" + "=" * 80)
334
- print("PIPELINE COMPLETE β€” RESULTS SUMMARY (v3.1 FIXED)")
335
  print("=" * 80)
336
  print(f" Neuron / surrogate : {neuron_model} + {surrogate_fn}")
337
  print(f" Denoiser : {denoiser_name}")
@@ -348,7 +370,8 @@ def main():
348
  row = snr_df[snr_df["snr_db"] == db]
349
  if len(row):
350
  a = row["accuracy"].values[0]
351
- print(f" {'βœ…' if a >= tgt else '⚠️ '} SNR={db} dB : {a:.4f} (target β‰₯{tgt})")
 
352
 
353
  print(f"\n Step timings:")
354
  for k, v in step_times.items():
 
1
  #!/usr/bin/env python3
2
  """
3
+ NQR-SNN Framework v3.2 β€” Full Pipeline
4
  =======================================
5
  Complete end-to-end pipeline with all 6 stages:
6
 
7
  1. Denoiser Selection β€” LM vs SSA vs Wavelet (gate on low-SNR white)
8
+ 2. Parallel Neuron Search β€” 5 neuron x 3 surrogate = 15 configs via ProcessPoolExecutor
9
  3. Dataset Generation β€” SNR-controlled signals, denoised, 7-channel feature extraction
10
  4. Ensemble Training β€” N-member heterogeneous CNN+SNN ensemble with TET loss
11
  5. SNR-Level Evaluation β€” Accuracy / AUC / F1 at each dB level with denoising
12
  6. Noise Injection Stress β€” Gaussian / S&P / RFI / weight / OOD frequency shift
13
 
14
+ v3.2 fixes:
15
+ - Neuron search results now wire into ENSEMBLE_CONFIGS (was dead code)
16
+ - Uses trained encoder at evaluation (matches training)
17
+ - Consolidated denoise_batch helper (no more inline loops)
18
 
19
  Usage:
20
+ python run_full_pipeline.py --quick # fast demo (~90s CPU)
21
+ python run_full_pipeline.py # full run
22
+ python run_full_pipeline.py --skip_denoise # skip denoiser stage
 
 
23
  """
24
 
25
  import os
 
51
  plot_noise_degradation, plot_ood_uncertainty,
52
  )
53
  from nqr_snn.denoising.selector import DenoisingSelector
54
+ from nqr_snn.denoising import denoise_batch as _denoise_batch
55
  from nqr_snn.noise_injection.stress_test import run_stress_test
56
 
57
 
58
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
59
  # Helper: apply denoiser to a batch of complex signals
60
+ # v3.2: Now delegates to nqr_snn.denoising.denoise_batch (was duplicated)
61
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
62
  def denoise_batch(signals: np.ndarray, denoiser) -> np.ndarray:
63
  """Apply denoiser to every signal in a (N, 1024) complex128 array."""
64
+ return _denoise_batch(denoiser, signals)
 
 
 
 
 
 
 
 
65
 
66
 
67
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 
84
  # Helper: signals β†’ features tensor (optionally denoise first)
85
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
86
  def signals_to_features(signals, denoiser=None):
87
+ """(N,1024) complex128 -> (N,7,1024) float32 tensor β€” vectorized."""
88
  if denoiser is not None:
89
  signals = denoise_batch(signals, denoiser)
 
90
  feats = extract_features_batch(signals)
91
  return torch.from_numpy(feats)
92
 
 
96
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
97
  def main():
98
  parser = argparse.ArgumentParser(
99
+ description="NQR-SNN v3.2 β€” full pipeline"
100
  )
101
  parser.add_argument("--quick", action="store_true",
102
  help="Quick demo mode (~90 s on CPU)")
 
113
  parser.add_argument("--search_epochs", type=int, default=None)
114
  args = parser.parse_args()
115
 
116
+ # -- Resolve sizes depending on mode --
117
  quick = args.quick
118
  train_size = args.train_size or (500 if quick else 3000)
119
  val_size = args.val_size or (150 if quick else 1000)
 
127
  device = "cuda" if torch.cuda.is_available() else "cpu"
128
 
129
  print("=" * 80)
130
+ print("NQR-SNN FRAMEWORK v3.2 β€” FULL PIPELINE")
131
  print("=" * 80)
 
132
  print(f" Device : {device}")
133
  print(f" Mode : {'QUICK' if quick else 'FULL'}")
134
  print(f" Train / Val : {train_size} / {val_size} per class")
135
  print(f" Ensemble members : {ensemble_size}")
136
  print(f" Max epochs : {max_epochs}")
137
+ print(f" Search epochs : {search_epochs} ({len(config.NEURON_MODELS)}x{len(config.SURROGATE_FUNCTIONS)} configs)")
138
  print(f" Test per SNR : {n_test}")
139
  print(f" Steps : denoise={'ON' if not args.skip_denoise else 'OFF'} | "
140
  f"search={'ON' if not args.skip_neuron_search else 'OFF'} | "
 
152
 
153
  if not args.skip_denoise:
154
  t1 = time.time()
155
+ print(f"\n[1/6] DENOISER SELECTION (LM -> SSA -> Wavelet on {denoise_samples} low-SNR white samples)")
156
 
157
  old_vs = config.VAL_SIZE
158
  config.VAL_SIZE = denoise_samples
 
165
  den_data["clean"][:denoise_samples],
166
  )
167
  step_times["denoiser"] = time.time() - t1
168
+ print(f" Done: {denoiser_name} (R2={denoiser_r2:.1f}) [{step_times['denoiser']:.1f}s]")
169
  else:
170
  print("\n[1/6] DENOISER SELECTION β€” skipped")
171
 
 
203
 
204
  step_times["search"] = time.time() - t2
205
 
206
+ # v3.2 FIX: Wire neuron search results into ENSEMBLE_CONFIGS
207
+ # Previously the search result was stored but ENSEMBLE_CONFIGS stayed hardcoded,
208
+ # making the entire search stage dead code. Now the top-K diverse configs from
209
+ # the search populate the ensemble, ensuring the search actually affects training.
210
+ search_df = search_result.get("all_results")
211
+ if search_df is not None and len(search_df) > 0:
212
+ valid = search_df[search_df["val_accuracy"] > 0].sort_values(
213
+ "val_accuracy", ascending=False
214
+ )
215
+ if len(valid) >= 3:
216
+ new_configs = []
217
+ for _, row in valid.iterrows():
218
+ new_configs.append((
219
+ row["neuron_model"],
220
+ row["surrogate_fn"],
221
+ config.SNN_TAU, # default tau; vary below
222
+ ))
223
+ # Add tau diversity: cycle through [1.5, 2.0, 2.5, 3.0] for top configs
224
+ tau_cycle = [1.5, 2.0, 2.5, 3.0]
225
+ ensemble_configs = []
226
+ for i, (nm, sf, _) in enumerate(new_configs):
227
+ tau = tau_cycle[i % len(tau_cycle)]
228
+ ensemble_configs.append((nm, sf, tau))
229
+ if len(ensemble_configs) >= config.ENSEMBLE_SIZE:
230
+ break
231
+ config.ENSEMBLE_CONFIGS = ensemble_configs
232
+ print(f" Updated ENSEMBLE_CONFIGS from search: {len(ensemble_configs)} configs")
233
+
234
  # Heatmap
235
  csv = os.path.join(config.RESULTS_DIR, "neuron_search.csv")
236
  if os.path.exists(csv):
237
  plot_neuron_search_heatmap(csv)
238
 
239
+ print(f" Winner: {neuron_model} + {surrogate_fn} "
240
  f"(val_acc={search_result['val_accuracy']:.4f}) "
241
  f"[wall {step_times['search']:.1f}s]")
242
  else:
 
261
  val_loader = get_balanced_loader_v2(vl_ds, batch_size=config.BATCH_SIZE, shuffle=False)
262
 
263
  step_times["datagen"] = time.time() - t3
264
+ print(f" {len(tr_ds)} train, {len(vl_ds)} val "
265
  f"SNR=[{tr['snr_dbs'].min():.0f},{tr['snr_dbs'].max():.0f}]dB "
266
  f"[{step_times['datagen']:.1f}s]")
267
 
 
269
  # STEP 4 Ensemble Training
270
  # ──────────────────────────────────────────────────────────────────
271
  t4 = time.time()
272
+ print(f"\n[4/6] ENSEMBLE TRAINING ({ensemble_size} members x {max_epochs} epochs, "
273
  f"neuron={neuron_model}+{surrogate_fn})")
274
 
275
  ensemble = SNNEnsemble(
276
  neuron_model=neuron_model, surrogate_fn=surrogate_fn,
277
  ensemble_size=ensemble_size, device=device,
278
+ heterogeneous=True,
279
  )
280
  ensemble.train_all(train_loader, val_loader, max_epochs=max_epochs)
281
  ensemble.load_checkpoints()
282
 
283
  step_times["train"] = time.time() - t4
284
+ print(f" Training done [{step_times['train']:.1f}s]")
285
 
286
  # ──────────────────────────────────────────────────────────────────
287
  # STEP 5 SNR-Level Evaluation
288
  # ──────────────────────────────────────────────────────────────────
289
  t5 = time.time()
290
+ # v3.2: Use ensemble's encoder (LearnableTemporalEncoder with trained weights)
291
+ encoder = ensemble.encoder
292
  snr_rows = []
293
 
294
  print(f"\n[5/6] SNR EVALUATION ({n_test}/class per level, "
 
307
  rep["ensemble_std"] = float(std_p.cpu().mean())
308
  snr_rows.append(rep)
309
 
310
+ print(f" SNR={snr_db:4d} dB | acc={rep['accuracy']:.4f} "
311
  f"auc={rep['auc']:.4f} f1={rep['f1']:.4f} "
312
  f"tpr={rep['tpr']:.4f} tnr={rep['tnr']:.4f}")
313
 
 
344
  os.path.join(config.PLOTS_DIR, "ood_uncertainty.png"))
345
 
346
  step_times["stress"] = time.time() - t6
347
+ print(f" Stress test done [{step_times['stress']:.1f}s]")
348
  else:
349
  print("\n[6/6] STRESS TEST β€” skipped")
350
 
 
353
  # ──────────────────────────────────────────────────────────────────
354
  total = time.time() - t0
355
  print("\n" + "=" * 80)
356
+ print("PIPELINE COMPLETE β€” RESULTS SUMMARY (v3.2)")
357
  print("=" * 80)
358
  print(f" Neuron / surrogate : {neuron_model} + {surrogate_fn}")
359
  print(f" Denoiser : {denoiser_name}")
 
370
  row = snr_df[snr_df["snr_db"] == db]
371
  if len(row):
372
  a = row["accuracy"].values[0]
373
+ status = "PASS" if a >= tgt else "FAIL"
374
+ print(f" [{status}] SNR={db} dB : {a:.4f} (target >={tgt})")
375
 
376
  print(f"\n Step timings:")
377
  for k, v in step_times.items():