# catalyst-n1/sdk/benchmarks/dvs_train.py
"""Surrogate gradient SNN training for DVS128 Gesture benchmark.
Trains a 2-layer feedforward SNN (2048 -> hidden -> 11) using the same
SubtractiveLIF neuron model from shd_train.py.
Usage:
python dvs_train.py --data-dir data/dvs_gesture --epochs 80 --hidden 512
"""
import os
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
sys.path.insert(0, os.path.dirname(__file__))
from dvs_loader import DVSGestureDataset, collate_fn, N_CHANNELS, N_CLASSES
from shd_train import SubtractiveLIF
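# N_CHANNELS (32x32 pixels x 2 polarities = 2048) and N_CLASSES (11 gestures)
# are defined in dvs_loader.py; the sizes quoted in the docstring above assume those values.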
class DVSSNN(nn.Module):
"""2-layer SNN for DVS Gesture classification.
2048 (32x32x2 input) -> hidden (LIF) -> 11 (output integrator)
"""
def __init__(self, n_input=N_CHANNELS, n_hidden=512, n_output=N_CLASSES,
threshold=1.0, leak=0.003):
super().__init__()
self.n_hidden = n_hidden
self.n_output = n_output
        # Bias-free weights: input -> hidden, hidden -> output, and hidden -> hidden recurrence.
        self.fc1 = nn.Linear(n_input, n_hidden, bias=False)
        self.fc2 = nn.Linear(n_hidden, n_output, bias=False)
        self.fc_rec = nn.Linear(n_hidden, n_hidden, bias=False)
        self.lif1 = SubtractiveLIF(n_hidden, threshold=threshold, leak=leak)
        # The non-spiking output integrator leaks at half the hidden-layer rate.
        self.output_leak = leak * 0.5
        # Small-gain Xavier init for the feedforward weights, orthogonal init for the recurrence.
        nn.init.xavier_uniform_(self.fc1.weight, gain=0.1)
        nn.init.xavier_uniform_(self.fc2.weight, gain=0.3)
        nn.init.orthogonal_(self.fc_rec.weight, gain=0.1)
def forward(self, x):
        # x: [batch, T, n_input] time-binned DVS events.
        batch, T, _ = x.shape
        device = x.device
        # Hidden/output membrane potentials, previous hidden spikes, and the running output sum.
        v1 = torch.zeros(batch, self.n_hidden, device=device)
        v2 = torch.zeros(batch, self.n_output, device=device)
        spk1 = torch.zeros(batch, self.n_hidden, device=device)
        out_sum = torch.zeros(batch, self.n_output, device=device)
        for t in range(T):
            # Hidden layer: feedforward drive plus recurrent drive from the previous step's spikes.
            I1 = self.fc1(x[:, t]) + self.fc_rec(spk1)
            v1, spk1 = self.lif1(I1, v1)
            # Output layer: non-spiking integrator with a constant leak, clamped at zero.
            I2 = self.fc2(spk1)
            v2 = v2 + I2 - self.output_leak
            v2 = torch.clamp(v2, min=0.0)
            out_sum = out_sum + v2
        # Logits are the time-averaged output potentials.
        return out_sum / T
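# For reference only: a minimal sketch of the subtractive-reset LIF update this
# script assumes SubtractiveLIF (imported from shd_train.py) performs per step.
# The real implementation, including the surrogate spike gradient, lives in
# shd_train.py; nothing below is used by the training code.
def _reference_lif_step(I, v, threshold=1.0, leak=0.003):
    """One assumed time step: integrate input, apply leak, spike, subtractive reset."""
    v = v + I - leak                # integrate synaptic current with a constant leak
    spk = (v >= threshold).float()  # hard threshold (surrogate gradient in shd_train.py)
    v = v - spk * threshold         # subtractive reset keeps the residual potential
    return v, spk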
def train_epoch(model, loader, optimizer, device):
model.train()
total_loss = 0.0
correct = 0
total = 0
for inputs, labels in loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
output = model(inputs)
loss = F.cross_entropy(output, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item() * inputs.size(0)
correct += (output.argmax(1) == labels).sum().item()
total += inputs.size(0)
return total_loss / total, correct / total
@torch.no_grad()
def evaluate(model, loader, device):
model.eval()
total_loss = 0.0
correct = 0
total = 0
for inputs, labels in loader:
inputs, labels = inputs.to(device), labels.to(device)
output = model(inputs)
loss = F.cross_entropy(output, labels)
total_loss += loss.item() * inputs.size(0)
correct += (output.argmax(1) == labels).sum().item()
total += inputs.size(0)
return total_loss / total, correct / total
def main():
parser = argparse.ArgumentParser(description="Train SNN on DVS Gesture")
parser.add_argument("--data-dir", default="data/dvs_gesture")
parser.add_argument("--epochs", type=int, default=80)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--lr", type=float, default=5e-4)
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--threshold", type=float, default=1.0)
parser.add_argument("--leak", type=float, default=0.003)
parser.add_argument("--dt", type=float, default=10e-3,
help="Time bin width (10ms -> 150 bins for 1.5s)")
parser.add_argument("--duration", type=float, default=1.5)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--save", default="dvs_model.pt")
args = parser.parse_args()
torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
print("Loading DVS Gesture dataset (first load downloads ~1.5GB)...")
train_ds = DVSGestureDataset(args.data_dir, train=True,
dt=args.dt, duration=args.duration)
test_ds = DVSGestureDataset(args.data_dir, train=False,
dt=args.dt, duration=args.duration)
train_loader = DataLoader(
train_ds, batch_size=args.batch_size, shuffle=True,
collate_fn=collate_fn, num_workers=0, pin_memory=True)
test_loader = DataLoader(
test_ds, batch_size=args.batch_size, shuffle=False,
collate_fn=collate_fn, num_workers=0, pin_memory=True)
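    # Assumption about dvs_loader: collate_fn yields dense float batches of shape
    # [batch, time_bins, 2048] plus integer class labels, the layout DVSSNN.forward expects.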
print(f"Train: {len(train_ds)}, Test: {len(test_ds)}, "
f"Time bins: {train_ds.n_bins} (dt={args.dt*1000:.1f}ms)")
model = DVSSNN(
n_hidden=args.hidden,
threshold=args.threshold,
leak=args.leak,
).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model: {N_CHANNELS}->{args.hidden}->{N_CLASSES}, {n_params:,} params")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
best_acc = 0.0
for epoch in range(args.epochs):
train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)
test_loss, test_acc = evaluate(model, test_loader, device)
scheduler.step()
if test_acc > best_acc:
best_acc = test_acc
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'test_acc': test_acc,
'args': vars(args),
}, args.save)
lr = optimizer.param_groups[0]['lr']
print(f"Epoch {epoch+1:3d}/{args.epochs} | "
f"Train: {train_loss:.4f} / {train_acc*100:.1f}% | "
f"Test: {test_loss:.4f} / {test_acc*100:.1f}% | "
f"LR={lr:.2e} | Best={best_acc*100:.1f}%")
print(f"\nDone. Best test accuracy: {best_acc*100:.1f}%")
print(f"Model saved to {args.save}")
if __name__ == "__main__":
main()
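# Example (sketch) of reloading the best checkpoint for evaluation elsewhere.
# Assumes this module and DVSSNN are importable; the keys match what main() saves above.
#
#     ckpt = torch.load("dvs_model.pt", map_location="cpu")
#     model = DVSSNN(n_hidden=ckpt['args']['hidden'],
#                    threshold=ckpt['args']['threshold'],
#                    leak=ckpt['args']['leak'])
#     model.load_state_dict(ckpt['model_state_dict'])
#     model.eval()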