shd-snn-benchmark / shd_train.py
mrwabbit's picture
Upload shd_train.py with huggingface_hub
faada08 verified
"""Surrogate gradient SNN training for the SHD benchmark.
Trains a recurrent SNN (700 -> hidden -> 20) using backpropagation through
time with a fast-sigmoid surrogate gradient.
Supports two neuron models:
- LIF: multiplicative decay (v = beta * v + (1-beta) * I). Default.
- adLIF: Adaptive LIF with Symplectic Euler discretization.
Updates adaptation BEFORE threshold computation for richer temporal dynamics.
Published: 95.81% on SHD (SE-adLIF, 2025).
Hardware mapping (CUBA neuron, P22A):
decay_u = round(alpha * 4096) (12-bit fractional)
Usage:
python shd_train.py --data-dir data/shd --epochs 200 --hidden 512
python shd_train.py --neuron-type adlif --dropout 0.15 --epochs 200
"""
import os
import sys
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# Add benchmarks dir to path for shd_loader import
sys.path.insert(0, os.path.dirname(__file__))
from shd_loader import SHDDataset, collate_fn, N_CHANNELS, N_CLASSES
# ---------------------------------------------------------------------------
# Surrogate gradient
# ---------------------------------------------------------------------------
class SurrogateSpikeFunction(torch.autograd.Function):
    """Heaviside spike nonlinearity with a fast-sigmoid surrogate gradient.

    Forward:  ``spike = 1`` where ``x >= 0`` and ``0`` elsewhere.
    Backward: the (zero almost everywhere) Heaviside derivative is replaced
    by ``1 / (scale * |x| + 1) ** 2``.

    The spike tensor is emitted in the dtype of ``x`` rather than always
    float32, so the function can be used unchanged inside networks running
    in double or half precision.
    """

    # Steepness of the surrogate: larger values concentrate the gradient
    # around the threshold crossing (x == 0).
    scale = 25.0

    @staticmethod
    def forward(ctx, x):
        """Return a 0/1 spike tensor with the same shape and dtype as ``x``."""
        ctx.save_for_backward(x)
        return (x >= 0).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the incoming gradient by the fast-sigmoid surrogate."""
        x, = ctx.saved_tensors
        slope = SurrogateSpikeFunction.scale
        return grad_output / (slope * torch.abs(x) + 1.0) ** 2


# Functional alias used by the neuron models.
surrogate_spike = SurrogateSpikeFunction.apply
# ---------------------------------------------------------------------------
# Neuron model — multiplicative decay LIF (maps to CUBA hardware neuron)
# ---------------------------------------------------------------------------
class LIFNeuron(nn.Module):
    """Leaky integrate-and-fire layer with a per-neuron multiplicative decay.

    One timestep:
        v     = beta * v_prev + (1 - beta) * I    # leak + scaled input
        spike = Heaviside(v - threshold)          # surrogate gradient in backward
        v     = v * (1 - spike)                   # hard reset to 0 on spike

    ``beta`` is stored as a logit (``beta_raw``) and squashed through a
    sigmoid, which keeps the effective decay inside (0, 1).

    Hardware mapping (CUBA neuron, P22A), as noted by the original author:
        decay_u = round(beta * 4096)   (12-bit fractional)

    Args:
        size: number of neurons in the layer.
        beta_init: initial decay factor; ``beta_raw`` is set to its logit.
        threshold: firing threshold.
        learn_beta: if True ``beta_raw`` is a trainable parameter, otherwise
            it is registered as a (non-trainable) buffer.
    """

    def __init__(self, size, beta_init=0.95, threshold=1.0, learn_beta=True):
        super().__init__()
        self.size = size
        self.threshold = threshold

        # logit(beta_init), so that sigmoid(beta_raw) == beta_init at start.
        beta_logit = torch.full((size,), np.log(beta_init / (1.0 - beta_init)))
        if learn_beta:
            self.beta_raw = nn.Parameter(beta_logit)
        else:
            self.register_buffer('beta_raw', beta_logit)

    @property
    def beta(self):
        """Effective per-neuron decay factor, sigmoid(beta_raw)."""
        return torch.sigmoid(self.beta_raw)

    def forward(self, input_current, v_prev):
        """Advance one timestep.

        Returns:
            (v, spikes): membrane potential after the reset and the spike
            tensor produced at this step.
        """
        decay = self.beta
        membrane = decay * v_prev + (1.0 - decay) * input_current
        fired = surrogate_spike(membrane - self.threshold)
        # Neurons that fired are reset to 0; the others keep their potential.
        return membrane * (1.0 - fired), fired
# ---------------------------------------------------------------------------
# Adaptive LIF neuron — Symplectic Euler discretization
# ---------------------------------------------------------------------------
class AdaptiveLIFNeuron(nn.Module):
    """Adaptive LIF layer using a Symplectic-Euler (SE) update order.

    The adaptation variable is advanced first, so the threshold applied at a
    timestep already reflects the previous step's spike.

    One timestep:
        a     = rho * a_prev + spike_prev          # 1. adaptation
        theta = threshold_base + beta_a * a        # 2. adaptive threshold
        v     = alpha * v_prev + (1 - alpha) * I   # 3. membrane update
        spike = Heaviside(v - theta)               # 4. spike decision
        v     = v * (1 - spike)                    # 5. hard reset

    All four per-neuron quantities are trainable. ``alpha`` and ``rho`` are
    stored as logits (sigmoid-mapped), ``beta_a`` through an inverse softplus
    (kept positive), and ``threshold_base`` directly.

    Hardware note from the original author: adaptation is training-only; only
    alpha deploys to the CUBA hardware, as decay_v = round(alpha * 4096).

    Args:
        size: number of neurons in the layer.
        alpha_init: initial membrane decay.
        rho_init: initial adaptation decay.
        beta_a_init: initial adaptation strength.
        threshold: initial value of the base threshold.
    """

    def __init__(self, size, alpha_init=0.90, rho_init=0.85, beta_a_init=1.8,
                 threshold=1.0):
        super().__init__()
        self.size = size

        def per_neuron(value):
            # One trainable scalar per neuron, all starting at `value`.
            return nn.Parameter(torch.full((size,), value))

        self.threshold_base = per_neuron(threshold)
        # logit(alpha_init): sigmoid(alpha_raw) == alpha_init at start.
        self.alpha_raw = per_neuron(np.log(alpha_init / (1.0 - alpha_init)))
        # logit(rho_init): sigmoid(rho_raw) == rho_init at start.
        self.rho_raw = per_neuron(np.log(rho_init / (1.0 - rho_init)))
        # softplus^{-1}(beta_a_init) = log(exp(beta_a_init) - 1)
        self.beta_a_raw = per_neuron(np.log(np.exp(beta_a_init) - 1.0))

    @property
    def alpha(self):
        """Effective per-neuron membrane decay, sigmoid(alpha_raw)."""
        return torch.sigmoid(self.alpha_raw)

    def forward(self, input_current, v_prev, a_prev, spike_prev):
        """Advance one timestep.

        Returns:
            (v, spikes, a_new): membrane potential after the reset, spikes
            emitted at this step, and the updated adaptation variable.
        """
        # Adaptation goes first (SE ordering), driven by last step's spikes.
        adaptation = torch.sigmoid(self.rho_raw) * a_prev + spike_prev
        theta = self.threshold_base + F.softplus(self.beta_a_raw) * adaptation

        decay = self.alpha
        membrane = decay * v_prev + (1.0 - decay) * input_current
        fired = surrogate_spike(membrane - theta)
        # Neurons that fired are reset to 0; the others keep their potential.
        return membrane * (1.0 - fired), fired, adaptation
# ---------------------------------------------------------------------------
# Event-drop data augmentation
# ---------------------------------------------------------------------------
def event_drop_augment(spikes_batch, drop_time_prob=0.1, drop_neuron_prob=0.05):
    """Zero out random time bins or random input channels of a spike batch.

    With probability 0.5 (drawn from Python's ``random`` module) the time
    axis is masked, otherwise the channel axis. A single mask is sampled and
    broadcast over the batch, so every sample loses the same bins/channels.

    Args:
        spikes_batch: tensor of shape (B, T, C).
        drop_time_prob: per-bin drop probability for the drop-by-time branch.
        drop_neuron_prob: per-channel drop probability for the drop-by-neuron
            branch.

    Returns:
        ``spikes_batch`` multiplied by the 0/1 keep mask (same shape).
    """
    drop_time = random.random() < 0.5
    _, n_bins, n_channels = spikes_batch.shape

    if drop_time:
        # One keep/drop decision per time bin.
        mask_shape, drop_prob = (1, n_bins, 1), drop_time_prob
    else:
        # One keep/drop decision per input channel.
        mask_shape, drop_prob = (1, 1, n_channels), drop_neuron_prob

    keep = torch.rand(*mask_shape, device=spikes_batch.device) > drop_prob
    return spikes_batch * keep.float()
# ---------------------------------------------------------------------------
# SNN model
# ---------------------------------------------------------------------------
class SHDSNN(nn.Module):
    """Recurrent spiking network for SHD classification.

    Topology: ``n_input`` spike channels -> ``n_hidden`` recurrent LIF/adLIF
    neurons -> ``n_output`` leaky-integrator readout units that never spike.
    The class scores are the readout membrane potentials averaged over all
    timesteps.

    Args:
        n_input: number of input channels.
        n_hidden: number of hidden (recurrent) neurons.
        n_output: number of readout units / classes.
        beta_hidden: initial decay of the hidden LIF neurons ('lif' only).
        beta_out: initial decay of the readout integrator.
        threshold: firing threshold handed to the neuron models.
        dropout: dropout probability applied to hidden spikes in training.
        neuron_type: 'adlif' selects AdaptiveLIFNeuron for the hidden layer;
            any other value selects LIFNeuron.
        alpha_init, rho_init, beta_a_init: adLIF initial values ('adlif' only).
    """

    def __init__(self, n_input=N_CHANNELS, n_hidden=256, n_output=N_CLASSES,
                 beta_hidden=0.95, beta_out=0.9, threshold=1.0, dropout=0.3,
                 neuron_type='lif', alpha_init=0.90, rho_init=0.85,
                 beta_a_init=1.8):
        super().__init__()
        self.n_hidden = n_hidden
        self.n_output = n_output
        self.dropout_p = dropout
        self.neuron_type = neuron_type

        # Bias-free synapses: input->hidden, hidden->readout, hidden->hidden.
        self.fc1 = nn.Linear(n_input, n_hidden, bias=False)
        self.fc2 = nn.Linear(n_hidden, n_output, bias=False)
        self.fc_rec = nn.Linear(n_hidden, n_hidden, bias=False)

        # Hidden-layer neuron model.
        if neuron_type != 'adlif':
            self.lif1 = LIFNeuron(n_hidden, beta_init=beta_hidden,
                                  threshold=threshold, learn_beta=True)
        else:
            self.lif1 = AdaptiveLIFNeuron(
                n_hidden, alpha_init=alpha_init, rho_init=rho_init,
                beta_a_init=beta_a_init, threshold=threshold)

        # The readout is always a plain LIF; forward() only uses its decay.
        self.lif2 = LIFNeuron(n_output, beta_init=beta_out,
                              threshold=threshold, learn_beta=True)

        self.dropout = nn.Dropout(p=dropout)

        # Small-gain initialisation: Xavier for the feed-forward paths,
        # orthogonal for the recurrent matrix.
        for feedforward in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(feedforward.weight, gain=0.5)
        nn.init.orthogonal_(self.fc_rec.weight, gain=0.2)

    def forward(self, x):
        """Run the network over every timestep of ``x``.

        Args:
            x: (batch, T, n_input) dense spike input.

        Returns:
            (batch, n_output) readout membrane potential averaged over T.
        """
        batch, n_steps, _ = x.shape
        adaptive = self.neuron_type == 'adlif'

        def zeros(width):
            return torch.zeros(batch, width, device=x.device)

        hidden_v = zeros(self.n_hidden)
        readout_v = zeros(self.n_output)
        hidden_spk = zeros(self.n_hidden)
        readout_acc = zeros(self.n_output)
        # Only the adaptive neuron carries an adaptation state.
        adaptation = zeros(self.n_hidden) if adaptive else None

        for step in range(n_steps):
            # Hidden drive = feed-forward input + last step's recurrent spikes.
            drive = self.fc1(x[:, step]) + self.fc_rec(hidden_spk)
            if adaptive:
                hidden_v, hidden_spk, adaptation = self.lif1(
                    drive, hidden_v, adaptation, hidden_spk)
            else:
                hidden_v, hidden_spk = self.lif1(drive, hidden_v)

            # Dropout acts on the spikes fed to the readout only; the
            # recurrent path above sees the undropped spikes.
            if self.training:
                readout_in = self.fc2(self.dropout(hidden_spk))
            else:
                readout_in = self.fc2(hidden_spk)

            # Non-spiking readout: leaky integration without threshold/reset.
            leak = self.lif2.beta
            readout_v = leak * readout_v + (1.0 - leak) * readout_in
            readout_acc = readout_acc + readout_v

        return readout_acc / n_steps
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
def train_epoch(model, loader, optimizer, device, use_event_drop=False,
                label_smoothing=0.0):
    """Train ``model`` for one pass over ``loader``.

    Each batch: optional event-drop augmentation, forward pass, cross-entropy
    loss, backward pass, gradient-norm clipping at 1.0, optimizer step.

    Args:
        model: network mapping a batch of inputs to class logits.
        loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over the model's parameters.
        device: device the batches are moved to.
        use_event_drop: apply ``event_drop_augment`` to every batch.
        label_smoothing: forwarded to ``F.cross_entropy``.

    Returns:
        (mean loss per sample, accuracy), both measured on the training-mode
        outputs used for the update.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        if use_event_drop:
            # Batch-level augmentation: one mask shared by the whole batch.
            inputs = event_drop_augment(inputs)
        n_batch = inputs.size(0)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = F.cross_entropy(logits, labels, label_smoothing=label_smoothing)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch mean is exact
        # even when the last batch is smaller.
        loss_sum += loss.item() * n_batch
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += n_batch

    return loss_sum / n_seen, n_correct / n_seen
@torch.no_grad()
def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` in eval mode, without gradient tracking.

    Args:
        model: network mapping a batch of inputs to class logits.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        (mean cross-entropy per sample, accuracy).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        n_batch = inputs.size(0)

        # Batch-mean loss weighted by batch size -> exact per-sample mean.
        loss_sum += F.cross_entropy(logits, labels).item() * n_batch
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += n_batch

    return loss_sum / n_seen, n_correct / n_seen
def main():
    """Command-line entry point: train an SHDSNN on SHD and keep the best model.

    Steps: parse options, seed the RNGs, build the train/test loaders and the
    model, then run ``--epochs`` epochs of AdamW with a cosine LR schedule.
    After each epoch the model is evaluated on the test split, and whenever
    the test accuracy improves a checkpoint (epoch, state dict, accuracy,
    parsed args) is written to ``--save``.

    Event-drop augmentation is switched on automatically for
    ``--neuron-type adlif`` unless ``--event-drop`` or ``--no-event-drop``
    is given explicitly.
    """
    parser = argparse.ArgumentParser(description="Train SNN on SHD benchmark")
    parser.add_argument("--data-dir", default="data/shd")
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--hidden", type=int, default=512)
    parser.add_argument("--threshold", type=float, default=1.0)
    parser.add_argument("--beta-hidden", type=float, default=0.95,
                        help="Initial membrane decay factor for hidden layer")
    parser.add_argument("--beta-out", type=float, default=0.9,
                        help="Initial membrane decay factor for output layer")
    parser.add_argument("--dropout", type=float, default=0.3)
    parser.add_argument("--dt", type=float, default=4e-3,
                        help="Time bin width in seconds (4ms -> 250 bins)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save", default="shd_model.pt")
    parser.add_argument("--no-recurrent", action="store_true",
                        help="Disable recurrent hidden connection")
    parser.add_argument("--neuron-type", choices=["lif", "adlif"], default="lif",
                        help="Neuron model: lif (standard) or adlif (adaptive, SE)")
    parser.add_argument("--alpha-init", type=float, default=0.90,
                        help="Initial membrane decay for adLIF (default: 0.90)")
    parser.add_argument("--rho-init", type=float, default=0.85,
                        help="Initial adaptation decay for adLIF (default: 0.85)")
    parser.add_argument("--beta-a-init", type=float, default=1.8,
                        help="Initial adaptation strength for adLIF (default: 1.8)")
    # Tri-state flag: None means "not specified" and triggers the adlif
    # auto-enable below; the two switches force it on or off.
    parser.add_argument("--event-drop", action="store_true", default=None,
                        help="Enable event-drop augmentation (auto-enabled for adlif)")
    parser.add_argument("--no-event-drop", dest="event_drop",
                        action="store_false", default=None,
                        help="Disable event-drop augmentation "
                             "(overrides the adlif auto-enable)")
    parser.add_argument("--label-smoothing", type=float, default=0.0,
                        help="Label smoothing factor (0.0=off, 0.1=recommended)")
    args = parser.parse_args()

    # Auto-enable event-drop for adLIF if not explicitly set
    if args.event_drop is None:
        args.event_drop = (args.neuron_type == 'adlif')

    # Seed every RNG the script draws from. Python's `random` picks the
    # event-drop branch, so it must be seeded for runs to be repeatable.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Dataset
    print("Loading SHD dataset...")
    train_ds = SHDDataset(args.data_dir, "train", dt=args.dt)
    test_ds = SHDDataset(args.data_dir, "test", dt=args.dt)
    # Pinned host memory only helps host->GPU copies; skip it on CPU runs.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        collate_fn=collate_fn, num_workers=0, pin_memory=pin_memory)
    test_loader = DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=collate_fn, num_workers=0, pin_memory=pin_memory)
    print(f"Train: {len(train_ds)}, Test: {len(test_ds)}, "
          f"Time bins: {train_ds.n_bins} (dt={args.dt*1000:.1f}ms)")

    # Model
    model = SHDSNN(
        n_hidden=args.hidden,
        threshold=args.threshold,
        beta_hidden=args.beta_hidden,
        beta_out=args.beta_out,
        dropout=args.dropout,
        neuron_type=args.neuron_type,
        alpha_init=args.alpha_init,
        rho_init=args.rho_init,
        beta_a_init=args.beta_a_init,
    ).to(device)

    if args.no_recurrent:
        # Recurrence is disabled by zeroing and freezing the recurrent weights.
        model.fc_rec.weight.data.zero_()
        model.fc_rec.weight.requires_grad = False

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    neuron_info = args.neuron_type.upper()
    if args.neuron_type == 'adlif':
        neuron_info += f" (alpha={args.alpha_init}, rho={args.rho_init}, beta_a={args.beta_a_init})"
    print(f"Model: {N_CHANNELS}->{args.hidden}->{N_CLASSES}, "
          f"{n_params:,} params ({neuron_info}, "
          f"recurrent={'off' if args.no_recurrent else 'on'}, "
          f"dropout={args.dropout}, event_drop={args.event_drop})")

    # Optimizer with weight decay
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs,
                                                           eta_min=1e-5)

    best_acc = 0.0
    for epoch in range(args.epochs):
        # Learning rate in effect for this epoch; read it before the
        # scheduler advances so the log matches the update that was applied.
        lr = optimizer.param_groups[0]['lr']

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, device,
                                            use_event_drop=args.event_drop,
                                            label_smoothing=args.label_smoothing)
        test_loss, test_acc = evaluate(model, test_loader, device)
        scheduler.step()

        # Checkpoint only on a strict improvement of the test accuracy.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'test_acc': test_acc,
                'args': vars(args),
            }, args.save)

        print(f"Epoch {epoch+1:3d}/{args.epochs} | "
              f"Train: {train_loss:.4f} / {train_acc*100:.1f}% | "
              f"Test: {test_loss:.4f} / {test_acc*100:.1f}% | "
              f"LR={lr:.2e} | Best={best_acc*100:.1f}%")

    print(f"\nDone. Best test accuracy: {best_acc*100:.1f}%")
    print(f"Model saved to {args.save}")


if __name__ == "__main__":
    main()