| | """SHD (Spiking Heidelberg Digits) dataset loader for neuromorphic benchmarks.
|
| |
|
| | Downloads HDF5 files from Zenodo, converts variable-length spike events
|
| | to fixed-size dense binary tensors suitable for PyTorch training.
|
| |
|
| | 700 input channels (cochlea model), 20 classes (digits 0-9 in German+English).
|
| | """
|
| |
|
| | import os
|
| | import urllib.request
|
| | import gzip
|
| | import shutil
|
| | import numpy as np
|
| |
|
| | try:
|
| | import h5py
|
| | except ImportError:
|
| | raise ImportError("h5py required: pip install h5py")
|
| |
|
| | try:
|
| | import torch
|
| | from torch.utils.data import Dataset
|
| | except ImportError:
|
| | raise ImportError("PyTorch required: pip install torch")
|
| |
|
| |
|
| | SHD_URLS = {
|
| | "train": "https://compneuro.net/datasets/shd_train.h5.gz",
|
| | "test": "https://compneuro.net/datasets/shd_test.h5.gz",
|
| | }
|
| |
|
| | N_CHANNELS = 700
|
| | N_CLASSES = 20
|
| |
|
| |
|
| | def download_shd(data_dir="data/shd"):
|
| | """Download SHD train/test HDF5 files from Zenodo if not present."""
|
| | os.makedirs(data_dir, exist_ok=True)
|
| |
|
| | for split, url in SHD_URLS.items():
|
| | h5_path = os.path.join(data_dir, f"shd_{split}.h5")
|
| | gz_path = h5_path + ".gz"
|
| |
|
| | if os.path.exists(h5_path):
|
| | continue
|
| |
|
| | print(f"Downloading SHD {split} set from {url} ...")
|
| | try:
|
| | urllib.request.urlretrieve(url, gz_path)
|
| | except Exception as e:
|
| | raise RuntimeError(
|
| | f"Failed to download {url}: {e}\n"
|
| | f"Download manually from https://zenodo.org/records/4319560 "
|
| | f"and place shd_train.h5 / shd_test.h5 in {data_dir}/")
|
| |
|
| | print(f"Extracting {gz_path} ...")
|
| | with gzip.open(gz_path, 'rb') as f_in:
|
| | with open(h5_path, 'wb') as f_out:
|
| | shutil.copyfileobj(f_in, f_out)
|
| | os.remove(gz_path)
|
| | print(f" Saved to {h5_path}")
|
| |
|
| | return data_dir
|
| |
|
| |
|
| | def spikes_to_dense(times, units, n_channels=N_CHANNELS, dt=4e-3, max_time=1.0):
|
| | """Convert spike event lists to a dense binary tensor.
|
| |
|
| | Args:
|
| | times: array of spike times in seconds
|
| | units: array of channel indices (0 to n_channels-1)
|
| | n_channels: number of input channels (700 for SHD)
|
| | dt: time bin width in seconds (4ms -> 250 bins)
|
| | max_time: maximum time window in seconds
|
| |
|
| | Returns:
|
| | dense: (T, n_channels) float32 array with 1.0 at spike locations
|
| | """
|
| | n_bins = int(max_time / dt)
|
| | dense = np.zeros((n_bins, n_channels), dtype=np.float32)
|
| |
|
| | if not times:
|
| | return dense
|
| |
|
| | bin_indices = np.clip((times / dt).astype(int), 0, n_bins - 1)
|
| | unit_indices = np.clip(units.astype(int), 0, n_channels - 1)
|
| | dense[bin_indices, unit_indices] = 1.0
|
| | return dense
|
| |
|
| |
|
| | class SHDDataset(Dataset):
|
| | """PyTorch Dataset for Spiking Heidelberg Digits.
|
| |
|
| | Each sample is converted to a dense binary tensor (T, 700) on first access.
|
| | """
|
| |
|
| | def __init__(self, data_dir="data/shd", split="train", dt=4e-3, max_time=1.0):
|
| | h5_path = os.path.join(data_dir, f"shd_{split}.h5")
|
| | if not os.path.exists(h5_path):
|
| | download_shd(data_dir)
|
| |
|
| | with h5py.File(h5_path, 'r') as f:
|
| | self.times = [np.array(t) for t in f['spikes']['times']]
|
| | self.units = [np.array(u) for u in f['spikes']['units']]
|
| | self.labels = np.array(f['labels'])
|
| |
|
| | self.dt = dt
|
| | self.max_time = max_time
|
| | self.n_bins = int(max_time / dt)
|
| |
|
| | def __len__(self):
|
| | return len(self.labels)
|
| |
|
| | def __getitem__(self, idx):
|
| | dense = spikes_to_dense(
|
| | self.times[idx], self.units[idx],
|
| | dt=self.dt, max_time=self.max_time,
|
| | )
|
| | return torch.from_numpy(dense), int(self.labels[idx])
|
| |
|
| |
|
| | def collate_fn(batch):
|
| | """Collate with uniform time length (all samples use same max_time)."""
|
| | inputs, labels = zip(*batch)
|
| | return torch.stack(inputs), torch.tensor(labels, dtype=torch.long)
|
| |
|