"""Surrogate gradient SNN training for DVS128 Gesture benchmark. Trains a 2-layer feedforward SNN (2048 -> hidden -> 11) using the same SubtractiveLIF neuron model from shd_train.py. Usage: python dvs_train.py --data-dir data/dvs_gesture --epochs 80 --hidden 512 """ import os import sys import argparse import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader sys.path.insert(0, os.path.dirname(__file__)) from dvs_loader import DVSGestureDataset, collate_fn, N_CHANNELS, N_CLASSES from shd_train import SubtractiveLIF, surrogate_spike class DVSSNN(nn.Module): """2-layer SNN for DVS Gesture classification. 2048 (32x32x2 input) -> hidden (LIF) -> 11 (output integrator) """ def __init__(self, n_input=N_CHANNELS, n_hidden=512, n_output=N_CLASSES, threshold=1.0, leak=0.003): super().__init__() self.n_hidden = n_hidden self.n_output = n_output self.fc1 = nn.Linear(n_input, n_hidden, bias=False) self.fc2 = nn.Linear(n_hidden, n_output, bias=False) self.fc_rec = nn.Linear(n_hidden, n_hidden, bias=False) self.lif1 = SubtractiveLIF(n_hidden, threshold=threshold, leak=leak) self.output_leak = leak * 0.5 nn.init.xavier_uniform_(self.fc1.weight, gain=0.1) nn.init.xavier_uniform_(self.fc2.weight, gain=0.3) nn.init.orthogonal_(self.fc_rec.weight, gain=0.1) def forward(self, x): batch, T, _ = x.shape device = x.device v1 = torch.zeros(batch, self.n_hidden, device=device) v2 = torch.zeros(batch, self.n_output, device=device) spk1 = torch.zeros(batch, self.n_hidden, device=device) out_sum = torch.zeros(batch, self.n_output, device=device) for t in range(T): I1 = self.fc1(x[:, t]) + self.fc_rec(spk1) v1, spk1 = self.lif1(I1, v1) I2 = self.fc2(spk1) v2 = v2 + I2 - self.output_leak v2 = torch.clamp(v2, min=0.0) out_sum = out_sum + v2 return out_sum / T def train_epoch(model, loader, optimizer, device): model.train() total_loss = 0.0 correct = 0 total = 0 for inputs, labels in loader: inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() output = model(inputs) loss = F.cross_entropy(output, labels) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() total_loss += loss.item() * inputs.size(0) correct += (output.argmax(1) == labels).sum().item() total += inputs.size(0) return total_loss / total, correct / total @torch.no_grad() def evaluate(model, loader, device): model.eval() total_loss = 0.0 correct = 0 total = 0 for inputs, labels in loader: inputs, labels = inputs.to(device), labels.to(device) output = model(inputs) loss = F.cross_entropy(output, labels) total_loss += loss.item() * inputs.size(0) correct += (output.argmax(1) == labels).sum().item() total += inputs.size(0) return total_loss / total, correct / total def main(): parser = argparse.ArgumentParser(description="Train SNN on DVS Gesture") parser.add_argument("--data-dir", default="data/dvs_gesture") parser.add_argument("--epochs", type=int, default=80) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--lr", type=float, default=5e-4) parser.add_argument("--hidden", type=int, default=512) parser.add_argument("--threshold", type=float, default=1.0) parser.add_argument("--leak", type=float, default=0.003) parser.add_argument("--dt", type=float, default=10e-3, help="Time bin width (10ms -> 150 bins for 1.5s)") parser.add_argument("--duration", type=float, default=1.5) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--save", default="dvs_model.pt") args = parser.parse_args() torch.manual_seed(args.seed) np.random.seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") print("Loading DVS Gesture dataset (first load downloads ~1.5GB)...") train_ds = DVSGestureDataset(args.data_dir, train=True, dt=args.dt, duration=args.duration) test_ds = DVSGestureDataset(args.data_dir, train=False, dt=args.dt, duration=args.duration) train_loader = DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0, pin_memory=True) test_loader = DataLoader( test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0, pin_memory=True) print(f"Train: {len(train_ds)}, Test: {len(test_ds)}, " f"Time bins: {train_ds.n_bins} (dt={args.dt*1000:.1f}ms)") model = DVSSNN( n_hidden=args.hidden, threshold=args.threshold, leak=args.leak, ).to(device) n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Model: {N_CHANNELS}->{args.hidden}->{N_CLASSES}, {n_params:,} params") optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) best_acc = 0.0 for epoch in range(args.epochs): train_loss, train_acc = train_epoch(model, train_loader, optimizer, device) test_loss, test_acc = evaluate(model, test_loader, device) scheduler.step() if test_acc > best_acc: best_acc = test_acc torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'test_acc': test_acc, 'args': vars(args), }, args.save) lr = optimizer.param_groups[0]['lr'] print(f"Epoch {epoch+1:3d}/{args.epochs} | " f"Train: {train_loss:.4f} / {train_acc*100:.1f}% | " f"Test: {test_loss:.4f} / {test_acc*100:.1f}% | " f"LR={lr:.2e} | Best={best_acc*100:.1f}%") print(f"\nDone. Best test accuracy: {best_acc*100:.1f}%") print(f"Model saved to {args.save}") if __name__ == "__main__": main()