# catalyst-n1/sdk/benchmarks/shd_loader.py
# Provenance: "Initial upload: Catalyst N1 open source neuromorphic processor RTL"
# (commit e4cdd5f, uploaded by mrwabbit)
"""SHD (Spiking Heidelberg Digits) dataset loader for neuromorphic benchmarks.
Downloads HDF5 files from Zenodo, converts variable-length spike events
to fixed-size dense binary tensors suitable for PyTorch training.
700 input channels (cochlea model), 20 classes (digits 0-9 in German+English).
"""
import os
import urllib.request
import gzip
import shutil
import numpy as np
try:
import h5py
except ImportError:
raise ImportError("h5py required: pip install h5py")
try:
import torch
from torch.utils.data import Dataset
except ImportError:
raise ImportError("PyTorch required: pip install torch")
# Official gzipped HDF5 archives for SHD; canonical record:
# https://zenodo.org/records/4319560
SHD_URLS = {
    "train": "https://compneuro.net/datasets/shd_train.h5.gz",
    "test": "https://compneuro.net/datasets/shd_test.h5.gz",
}
N_CHANNELS = 700  # SHD cochlea channels
N_CLASSES = 20  # spoken digits 0-9 in German + English


def download_shd(data_dir="data/shd"):
    """Download and extract SHD train/test HDF5 files if not already present.

    Args:
        data_dir: directory where shd_train.h5 / shd_test.h5 are stored.

    Returns:
        data_dir (for call chaining).

    Raises:
        RuntimeError: when a download fails; the message includes manual
            download instructions.
    """
    os.makedirs(data_dir, exist_ok=True)
    for split, url in SHD_URLS.items():
        h5_path = os.path.join(data_dir, f"shd_{split}.h5")
        gz_path = h5_path + ".gz"
        if os.path.exists(h5_path):
            continue  # already downloaded and extracted
        print(f"Downloading SHD {split} set from {url} ...")
        try:
            urllib.request.urlretrieve(url, gz_path)
        except Exception as e:
            # Remove a partially downloaded archive so a retry starts clean.
            if os.path.exists(gz_path):
                os.remove(gz_path)
            raise RuntimeError(
                f"Failed to download {url}: {e}\n"
                f"Download manually from https://zenodo.org/records/4319560 "
                f"and place shd_train.h5 / shd_test.h5 in {data_dir}/") from e
        print(f"Extracting {gz_path} ...")
        try:
            with gzip.open(gz_path, 'rb') as f_in, open(h5_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
        except Exception:
            # Don't leave a truncated .h5 behind: the existence check above
            # would mistake it for a valid file on the next run.
            if os.path.exists(h5_path):
                os.remove(h5_path)
            raise
        os.remove(gz_path)
        print(f" Saved to {h5_path}")
    return data_dir


def spikes_to_dense(times, units, n_channels=N_CHANNELS, dt=4e-3, max_time=1.0):
    """Convert spike event lists to a dense binary tensor.

    Args:
        times: array-like of spike times in seconds.
        units: array-like of channel indices (0 to n_channels-1).
        n_channels: number of input channels (700 for SHD).
        dt: time bin width in seconds (4 ms -> 250 bins for a 1 s window).
        max_time: maximum time window in seconds; later spikes are clipped
            into the last bin.

    Returns:
        dense: (T, n_channels) float32 array with 1.0 at spike locations.
    """
    n_bins = int(max_time / dt)
    dense = np.zeros((n_bins, n_channels), dtype=np.float32)
    # Bug fix: `if not times` raises "truth value ... is ambiguous" for any
    # numpy array with more than one element. Test emptiness via .size;
    # np.asarray additionally allows plain Python lists as input.
    times = np.asarray(times)
    units = np.asarray(units)
    if times.size == 0:
        return dense
    bin_indices = np.clip((times / dt).astype(int), 0, n_bins - 1)
    unit_indices = np.clip(units.astype(int), 0, n_channels - 1)
    dense[bin_indices, unit_indices] = 1.0  # fancy-index scatter of spikes
    return dense
class SHDDataset(Dataset):
    """Spiking Heidelberg Digits as a PyTorch ``Dataset``.

    Spike events are held in memory as ragged per-sample arrays and
    densified lazily: each ``__getitem__`` call produces a binary
    (T, 700) float tensor plus an integer class label.
    """

    def __init__(self, data_dir="data/shd", split="train", dt=4e-3, max_time=1.0):
        h5_path = os.path.join(data_dir, f"shd_{split}.h5")
        # Fetch the dataset on demand when the HDF5 file is missing.
        if not os.path.exists(h5_path):
            download_shd(data_dir)
        with h5py.File(h5_path, 'r') as f:
            spikes = f['spikes']
            self.times = [np.array(t) for t in spikes['times']]
            self.units = [np.array(u) for u in spikes['units']]
            self.labels = np.array(f['labels'])
        self.dt = dt
        self.max_time = max_time
        self.n_bins = int(max_time / dt)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = spikes_to_dense(self.times[idx], self.units[idx],
                                 dt=self.dt, max_time=self.max_time)
        return torch.from_numpy(sample), int(self.labels[idx])
def collate_fn(batch):
    """Stack a batch of (dense_tensor, label) pairs into batched tensors.

    Every sample shares the same time dimension (fixed max_time), so a
    plain stack suffices; labels become a long tensor for loss functions.
    """
    inputs = torch.stack([sample for sample, _ in batch])
    labels = torch.tensor([label for _, label in batch], dtype=torch.long)
    return inputs, labels