mrwabbit commited on
Commit
faada08
·
verified ·
1 Parent(s): 7ee06fa

Upload shd_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. shd_train.py +449 -0
shd_train.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Surrogate gradient SNN training for the SHD benchmark.
2
+
3
+ Trains a recurrent SNN (700 -> hidden -> 20) using backpropagation through
4
+ time with a fast-sigmoid surrogate gradient.
5
+
6
+ Supports two neuron models:
7
+ - LIF: multiplicative decay (v = beta * v + (1-beta) * I). Default.
8
+ - adLIF: Adaptive LIF with Symplectic Euler discretization.
9
+ Updates adaptation BEFORE threshold computation for richer temporal dynamics.
10
+ Published: 95.81% on SHD (SE-adLIF, 2025).
11
+
12
+ Hardware mapping (CUBA neuron, P22A):
13
+ decay_u = round(alpha * 4096) (12-bit fractional)
14
+
15
+ Usage:
16
+ python shd_train.py --data-dir data/shd --epochs 200 --hidden 512
17
+ python shd_train.py --neuron-type adlif --dropout 0.15 --epochs 200
18
+ """
19
+
20
+ import os
21
+ import sys
22
+ import random
23
+ import argparse
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch.utils.data import DataLoader
29
+
30
+ # Add benchmarks dir to path for shd_loader import
31
+ sys.path.insert(0, os.path.dirname(__file__))
32
+ from shd_loader import SHDDataset, collate_fn, N_CHANNELS, N_CLASSES
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Surrogate gradient
37
+ # ---------------------------------------------------------------------------
38
+
39
+ class SurrogateSpikeFunction(torch.autograd.Function):
40
+ """Heaviside forward, fast-sigmoid backward (surrogate gradient)."""
41
+
42
+ @staticmethod
43
+ def forward(ctx, x):
44
+ ctx.save_for_backward(x)
45
+ return (x >= 0).float()
46
+
47
+ @staticmethod
48
+ def backward(ctx, grad_output):
49
+ x, = ctx.saved_tensors
50
+ # Fast sigmoid surrogate: 1 / (1 + scale*|x|)^2
51
+ scale = 25.0
52
+ grad = grad_output / (scale * torch.abs(x) + 1.0) ** 2
53
+ return grad
54
+
55
+
56
+ surrogate_spike = SurrogateSpikeFunction.apply
57
+
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Neuron model — multiplicative decay LIF (maps to CUBA hardware neuron)
61
+ # ---------------------------------------------------------------------------
62
+
63
+ class LIFNeuron(nn.Module):
64
+ """Leaky Integrate-and-Fire with multiplicative (exponential) decay.
65
+
66
+ Dynamics per timestep:
67
+ v = beta * v_prev + (1 - beta) * I # exponential decay + scaled input
68
+ spike = Heaviside(v - threshold) # surrogate in backward
69
+ v = v * (1 - spike) # hard reset
70
+
71
+ Hardware mapping (CUBA neuron, P22A):
72
+ decay_u = round(beta * 4096) (12-bit fractional)
73
+ """
74
+
75
+ def __init__(self, size, beta_init=0.95, threshold=1.0, learn_beta=True):
76
+ super().__init__()
77
+ self.size = size
78
+ self.threshold = threshold
79
+ # Learnable time constant via sigmoid-mapped beta
80
+ if learn_beta:
81
+ # Initialize so sigmoid(x) = beta_init
82
+ init_val = np.log(beta_init / (1.0 - beta_init))
83
+ self.beta_raw = nn.Parameter(torch.full((size,), init_val))
84
+ else:
85
+ self.register_buffer('beta_raw',
86
+ torch.full((size,), np.log(beta_init / (1.0 - beta_init))))
87
+
88
+ @property
89
+ def beta(self):
90
+ return torch.sigmoid(self.beta_raw)
91
+
92
+ def forward(self, input_current, v_prev):
93
+ beta = self.beta
94
+ v = beta * v_prev + (1.0 - beta) * input_current
95
+ spikes = surrogate_spike(v - self.threshold)
96
+ v = v * (1.0 - spikes) # hard reset to 0
97
+ return v, spikes
98
+
99
+
100
+ # ---------------------------------------------------------------------------
101
+ # Adaptive LIF neuron — Symplectic Euler discretization
102
+ # ---------------------------------------------------------------------------
103
+
104
+ class AdaptiveLIFNeuron(nn.Module):
105
+ """Adaptive LIF with Symplectic Euler (SE) discretization.
106
+
107
+ Key: adaptation is updated BEFORE threshold computation, so the neuron
108
+ can anticipate its own spike — greatly improves temporal coding.
109
+
110
+ Dynamics per timestep (SE order):
111
+ a = rho * a_prev + spike_prev # 1. adaptation update FIRST
112
+ theta = threshold_base + beta_a * a # 2. adaptive threshold
113
+ v = alpha * v_prev + (1-alpha) * I # 3. membrane update
114
+ spike = Heaviside(v - theta) # 4. spike decision
115
+ v = v * (1 - spike) # 5. hard reset
116
+
117
+ Hardware note: adaptation is training-only. Only alpha (membrane decay)
118
+ deploys to CUBA hardware as decay_v = round(alpha * 4096).
119
+ """
120
+
121
+ def __init__(self, size, alpha_init=0.90, rho_init=0.85, beta_a_init=1.8,
122
+ threshold=1.0):
123
+ super().__init__()
124
+ self.size = size
125
+ self.threshold_base = nn.Parameter(torch.full((size,), threshold))
126
+
127
+ # Membrane decay (learnable via sigmoid)
128
+ init_alpha = np.log(alpha_init / (1.0 - alpha_init))
129
+ self.alpha_raw = nn.Parameter(torch.full((size,), init_alpha))
130
+
131
+ # Adaptation decay (learnable via sigmoid)
132
+ init_rho = np.log(rho_init / (1.0 - rho_init))
133
+ self.rho_raw = nn.Parameter(torch.full((size,), init_rho))
134
+
135
+ # Adaptation strength (learnable, softplus to keep positive)
136
+ # softplus^{-1}(beta_a_init) = log(exp(beta_a_init) - 1)
137
+ init_beta_a = np.log(np.exp(beta_a_init) - 1.0)
138
+ self.beta_a_raw = nn.Parameter(torch.full((size,), init_beta_a))
139
+
140
+ @property
141
+ def alpha(self):
142
+ return torch.sigmoid(self.alpha_raw)
143
+
144
+ def forward(self, input_current, v_prev, a_prev, spike_prev):
145
+ alpha = torch.sigmoid(self.alpha_raw)
146
+ rho = torch.sigmoid(self.rho_raw)
147
+ beta_a = F.softplus(self.beta_a_raw)
148
+
149
+ # SE discretization: adaptation FIRST
150
+ a_new = rho * a_prev + spike_prev
151
+ theta = self.threshold_base + beta_a * a_new
152
+
153
+ # Membrane dynamics
154
+ v = alpha * v_prev + (1.0 - alpha) * input_current
155
+ spikes = surrogate_spike(v - theta)
156
+ v = v * (1.0 - spikes) # hard reset
157
+
158
+ return v, spikes, a_new
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Event-drop data augmentation
163
+ # ---------------------------------------------------------------------------
164
+
165
+ def event_drop_augment(spikes_batch, drop_time_prob=0.1, drop_neuron_prob=0.05):
166
+ """Randomly drop entire time bins or channels for regularization.
167
+
168
+ Operates on full batch (B, T, C) for efficiency. ~1% accuracy boost.
169
+ """
170
+ if random.random() < 0.5:
171
+ # Drop-by-time: zero out random time bins (shared across batch)
172
+ B, T, C = spikes_batch.shape
173
+ mask = (torch.rand(1, T, 1, device=spikes_batch.device)
174
+ > drop_time_prob).float()
175
+ return spikes_batch * mask
176
+ else:
177
+ # Drop-by-neuron: zero out random input channels (shared across batch)
178
+ B, T, C = spikes_batch.shape
179
+ mask = (torch.rand(1, 1, C, device=spikes_batch.device)
180
+ > drop_neuron_prob).float()
181
+ return spikes_batch * mask
182
+
183
+
184
+ # ---------------------------------------------------------------------------
185
+ # SNN model
186
+ # ---------------------------------------------------------------------------
187
+
188
+ class SHDSNN(nn.Module):
189
+ """Recurrent SNN for SHD classification.
190
+
191
+ 700 (input spikes) -> hidden (recurrent LIF/adLIF) -> 20 (non-spiking readout)
192
+ Readout: time-summed membrane potential of output layer -> softmax.
193
+ """
194
+
195
+ def __init__(self, n_input=N_CHANNELS, n_hidden=256, n_output=N_CLASSES,
196
+ beta_hidden=0.95, beta_out=0.9, threshold=1.0, dropout=0.3,
197
+ neuron_type='lif', alpha_init=0.90, rho_init=0.85,
198
+ beta_a_init=1.8):
199
+ super().__init__()
200
+ self.n_hidden = n_hidden
201
+ self.n_output = n_output
202
+ self.dropout_p = dropout
203
+ self.neuron_type = neuron_type
204
+
205
+ # Synaptic weight matrices
206
+ self.fc1 = nn.Linear(n_input, n_hidden, bias=False)
207
+ self.fc2 = nn.Linear(n_hidden, n_output, bias=False)
208
+
209
+ # Recurrent connection in hidden layer
210
+ self.fc_rec = nn.Linear(n_hidden, n_hidden, bias=False)
211
+
212
+ # Hidden layer neuron
213
+ if neuron_type == 'adlif':
214
+ self.lif1 = AdaptiveLIFNeuron(
215
+ n_hidden, alpha_init=alpha_init, rho_init=rho_init,
216
+ beta_a_init=beta_a_init, threshold=threshold)
217
+ else:
218
+ self.lif1 = LIFNeuron(n_hidden, beta_init=beta_hidden,
219
+ threshold=threshold, learn_beta=True)
220
+
221
+ # Output layer always standard LIF (readout doesn't need adaptation)
222
+ self.lif2 = LIFNeuron(n_output, beta_init=beta_out,
223
+ threshold=threshold, learn_beta=True)
224
+
225
+ # Dropout for regularization
226
+ self.dropout = nn.Dropout(p=dropout)
227
+
228
+ # Weight init
229
+ nn.init.xavier_uniform_(self.fc1.weight, gain=0.5)
230
+ nn.init.xavier_uniform_(self.fc2.weight, gain=0.5)
231
+ nn.init.orthogonal_(self.fc_rec.weight, gain=0.2)
232
+
233
+ def forward(self, x):
234
+ """Forward pass unrolled through T timesteps.
235
+
236
+ Args:
237
+ x: (batch, T, n_input) dense spike input
238
+
239
+ Returns:
240
+ output: (batch, n_output) averaged membrane for classification
241
+ """
242
+ batch, T, _ = x.shape
243
+ device = x.device
244
+
245
+ v1 = torch.zeros(batch, self.n_hidden, device=device)
246
+ v2 = torch.zeros(batch, self.n_output, device=device)
247
+ spk1 = torch.zeros(batch, self.n_hidden, device=device)
248
+
249
+ out_sum = torch.zeros(batch, self.n_output, device=device)
250
+
251
+ # adLIF needs adaptation state
252
+ if self.neuron_type == 'adlif':
253
+ a1 = torch.zeros(batch, self.n_hidden, device=device)
254
+
255
+ for t in range(T):
256
+ # Hidden layer: feedforward + recurrent
257
+ I1 = self.fc1(x[:, t]) + self.fc_rec(spk1)
258
+
259
+ if self.neuron_type == 'adlif':
260
+ v1, spk1, a1 = self.lif1(I1, v1, a1, spk1)
261
+ else:
262
+ v1, spk1 = self.lif1(I1, v1)
263
+
264
+ # Apply dropout to hidden spikes
265
+ spk1_drop = self.dropout(spk1) if self.training else spk1
266
+
267
+ # Output layer (non-spiking readout: integrate with decay)
268
+ I2 = self.fc2(spk1_drop)
269
+ beta_out = self.lif2.beta
270
+ v2 = beta_out * v2 + (1.0 - beta_out) * I2
271
+
272
+ out_sum = out_sum + v2
273
+
274
+ # Normalize by timesteps
275
+ return out_sum / T
276
+
277
+
278
+ # ---------------------------------------------------------------------------
279
+ # Training loop
280
+ # ---------------------------------------------------------------------------
281
+
282
+ def train_epoch(model, loader, optimizer, device, use_event_drop=False,
283
+ label_smoothing=0.0):
284
+ model.train()
285
+ total_loss = 0.0
286
+ correct = 0
287
+ total = 0
288
+
289
+ for inputs, labels in loader:
290
+ inputs, labels = inputs.to(device), labels.to(device)
291
+
292
+ # Event-drop augmentation (batch-level for efficiency)
293
+ if use_event_drop:
294
+ inputs = event_drop_augment(inputs)
295
+
296
+ optimizer.zero_grad()
297
+ output = model(inputs)
298
+ loss = F.cross_entropy(output, labels, label_smoothing=label_smoothing)
299
+ loss.backward()
300
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
301
+ optimizer.step()
302
+
303
+ total_loss += loss.item() * inputs.size(0)
304
+ correct += (output.argmax(1) == labels).sum().item()
305
+ total += inputs.size(0)
306
+
307
+ return total_loss / total, correct / total
308
+
309
+
310
+ @torch.no_grad()
311
+ def evaluate(model, loader, device):
312
+ model.eval()
313
+ total_loss = 0.0
314
+ correct = 0
315
+ total = 0
316
+
317
+ for inputs, labels in loader:
318
+ inputs, labels = inputs.to(device), labels.to(device)
319
+
320
+ output = model(inputs)
321
+ loss = F.cross_entropy(output, labels)
322
+
323
+ total_loss += loss.item() * inputs.size(0)
324
+ correct += (output.argmax(1) == labels).sum().item()
325
+ total += inputs.size(0)
326
+
327
+ return total_loss / total, correct / total
328
+
329
+
330
+ def main():
331
+ parser = argparse.ArgumentParser(description="Train SNN on SHD benchmark")
332
+ parser.add_argument("--data-dir", default="data/shd")
333
+ parser.add_argument("--epochs", type=int, default=200)
334
+ parser.add_argument("--batch-size", type=int, default=128)
335
+ parser.add_argument("--lr", type=float, default=1e-3)
336
+ parser.add_argument("--weight-decay", type=float, default=1e-4)
337
+ parser.add_argument("--hidden", type=int, default=512)
338
+ parser.add_argument("--threshold", type=float, default=1.0)
339
+ parser.add_argument("--beta-hidden", type=float, default=0.95,
340
+ help="Initial membrane decay factor for hidden layer")
341
+ parser.add_argument("--beta-out", type=float, default=0.9,
342
+ help="Initial membrane decay factor for output layer")
343
+ parser.add_argument("--dropout", type=float, default=0.3)
344
+ parser.add_argument("--dt", type=float, default=4e-3,
345
+ help="Time bin width in seconds (4ms -> 250 bins)")
346
+ parser.add_argument("--seed", type=int, default=42)
347
+ parser.add_argument("--save", default="shd_model.pt")
348
+ parser.add_argument("--no-recurrent", action="store_true",
349
+ help="Disable recurrent hidden connection")
350
+ parser.add_argument("--neuron-type", choices=["lif", "adlif"], default="lif",
351
+ help="Neuron model: lif (standard) or adlif (adaptive, SE)")
352
+ parser.add_argument("--alpha-init", type=float, default=0.90,
353
+ help="Initial membrane decay for adLIF (default: 0.90)")
354
+ parser.add_argument("--rho-init", type=float, default=0.85,
355
+ help="Initial adaptation decay for adLIF (default: 0.85)")
356
+ parser.add_argument("--beta-a-init", type=float, default=1.8,
357
+ help="Initial adaptation strength for adLIF (default: 1.8)")
358
+ parser.add_argument("--event-drop", action="store_true", default=None,
359
+ help="Enable event-drop augmentation (auto-enabled for adlif)")
360
+ parser.add_argument("--label-smoothing", type=float, default=0.0,
361
+ help="Label smoothing factor (0.0=off, 0.1=recommended)")
362
+ args = parser.parse_args()
363
+
364
+ # Auto-enable event-drop for adLIF if not explicitly set
365
+ if args.event_drop is None:
366
+ args.event_drop = (args.neuron_type == 'adlif')
367
+
368
+ torch.manual_seed(args.seed)
369
+ np.random.seed(args.seed)
370
+
371
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
372
+ print(f"Device: {device}")
373
+
374
+ # Dataset
375
+ print("Loading SHD dataset...")
376
+ train_ds = SHDDataset(args.data_dir, "train", dt=args.dt)
377
+ test_ds = SHDDataset(args.data_dir, "test", dt=args.dt)
378
+
379
+ train_loader = DataLoader(
380
+ train_ds, batch_size=args.batch_size, shuffle=True,
381
+ collate_fn=collate_fn, num_workers=0, pin_memory=True)
382
+ test_loader = DataLoader(
383
+ test_ds, batch_size=args.batch_size, shuffle=False,
384
+ collate_fn=collate_fn, num_workers=0, pin_memory=True)
385
+
386
+ print(f"Train: {len(train_ds)}, Test: {len(test_ds)}, "
387
+ f"Time bins: {train_ds.n_bins} (dt={args.dt*1000:.1f}ms)")
388
+
389
+ # Model
390
+ model = SHDSNN(
391
+ n_hidden=args.hidden,
392
+ threshold=args.threshold,
393
+ beta_hidden=args.beta_hidden,
394
+ beta_out=args.beta_out,
395
+ dropout=args.dropout,
396
+ neuron_type=args.neuron_type,
397
+ alpha_init=args.alpha_init,
398
+ rho_init=args.rho_init,
399
+ beta_a_init=args.beta_a_init,
400
+ ).to(device)
401
+
402
+ if args.no_recurrent:
403
+ model.fc_rec.weight.data.zero_()
404
+ model.fc_rec.weight.requires_grad = False
405
+
406
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
407
+ neuron_info = args.neuron_type.upper()
408
+ if args.neuron_type == 'adlif':
409
+ neuron_info += f" (alpha={args.alpha_init}, rho={args.rho_init}, beta_a={args.beta_a_init})"
410
+ print(f"Model: {N_CHANNELS}->{args.hidden}->{N_CLASSES}, "
411
+ f"{n_params:,} params ({neuron_info}, "
412
+ f"recurrent={'off' if args.no_recurrent else 'on'}, "
413
+ f"dropout={args.dropout}, event_drop={args.event_drop})")
414
+
415
+ # Optimizer with weight decay
416
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
417
+ weight_decay=args.weight_decay)
418
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs,
419
+ eta_min=1e-5)
420
+
421
+ best_acc = 0.0
422
+ for epoch in range(args.epochs):
423
+ train_loss, train_acc = train_epoch(model, train_loader, optimizer, device,
424
+ use_event_drop=args.event_drop,
425
+ label_smoothing=args.label_smoothing)
426
+ test_loss, test_acc = evaluate(model, test_loader, device)
427
+ scheduler.step()
428
+
429
+ if test_acc > best_acc:
430
+ best_acc = test_acc
431
+ torch.save({
432
+ 'epoch': epoch,
433
+ 'model_state_dict': model.state_dict(),
434
+ 'test_acc': test_acc,
435
+ 'args': vars(args),
436
+ }, args.save)
437
+
438
+ lr = optimizer.param_groups[0]['lr']
439
+ print(f"Epoch {epoch+1:3d}/{args.epochs} | "
440
+ f"Train: {train_loss:.4f} / {train_acc*100:.1f}% | "
441
+ f"Test: {test_loss:.4f} / {test_acc*100:.1f}% | "
442
+ f"LR={lr:.2e} | Best={best_acc*100:.1f}%")
443
+
444
+ print(f"\nDone. Best test accuracy: {best_acc*100:.1f}%")
445
+ print(f"Model saved to {args.save}")
446
+
447
+
448
+ if __name__ == "__main__":
449
+ main()