File size: 2,963 Bytes
e4cdd5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""DVS128 Gesture dataset loader for neuromorphic benchmarks.



Uses the `tonic` library for event camera data loading and transforms.

128x128 pixels x 2 polarities -> downsampled to 32x32 = 2048 input channels.

11 gesture classes.



Requires: pip install tonic

"""

import os
import numpy as np

try:
    import torch
    from torch.utils.data import Dataset
except ImportError:
    raise ImportError("PyTorch required: pip install torch")

try:
    import tonic
    import tonic.transforms as transforms
except ImportError:
    raise ImportError("tonic required: pip install tonic")


N_CHANNELS = 2048  # 32x32x2 (downsampled from 128x128x2)
N_CLASSES = 11      # gesture classes
SENSOR_SIZE = (128, 128, 2)
DS_FACTOR = 4       # downsample 128->32
DS_SIZE = (32, 32, 2)


def get_dvs_transform(dt=10e-3, duration=1.5):
    """Build the tonic transform pipeline: spatial downsample -> bin to frames.

    Args:
        dt: Time-bin width in seconds (default 10 ms).
        duration: Total clip duration in seconds (default 1.5 s).

    Returns:
        A ``tonic.transforms.Compose`` that downsamples 128x128 events to
        32x32 (factor ``DS_FACTOR``) and bins them into ``round(duration / dt)``
        time frames.
    """
    # round() instead of int(): int() truncates toward zero, so float error in
    # duration / dt (e.g. int(0.3 / 1e-3) == 299) would silently drop a bin.
    n_bins = round(duration / dt)
    return transforms.Compose([
        transforms.Downsample(spatial_factor=1.0 / DS_FACTOR),
        transforms.ToFrame(
            sensor_size=DS_SIZE,
            n_time_bins=n_bins,
        ),
    ])


class DVSGestureDataset(Dataset):
    """PyTorch Dataset wrapper for DVS128 Gesture.

    Each sample is converted by the tonic transform pipeline to a dense
    binary frame tensor of shape (n_bins, 2048), clipped or zero-padded
    to exactly ``round(duration / dt)`` time bins.
    """

    def __init__(self, data_dir="data/dvs_gesture", train=True, dt=10e-3, duration=1.5):
        """Load (downloading/caching if needed) one split of DVS128 Gesture.

        Args:
            data_dir: Directory the dataset is stored in.
            train: Select the train split if True, else the test split.
            dt: Time-bin width in seconds.
            duration: Clip duration in seconds.
        """
        transform = get_dvs_transform(dt=dt, duration=duration)

        self._tonic_ds = tonic.datasets.DVSGesture(
            save_to=data_dir,
            train=train,
            transform=transform,
        )

        # round() matches get_dvs_transform's bin count; int() would truncate
        # float-division error (e.g. int(0.3 / 1e-3) == 299, not 300) and
        # make n_bins disagree with the transform's actual frame count.
        self.n_bins = round(duration / dt)
        self.dt = dt
        self.duration = duration

    def __len__(self):
        return len(self._tonic_ds)

    def __getitem__(self, idx):
        """Return ``(frames, label)`` with frames of shape (n_bins, 2048).

        frames is a binary float32 torch.Tensor; label is a plain int.
        """
        frames, label = self._tonic_ds[idx]
        # tonic emits (T, C, H, W) frame stacks; flatten everything after the
        # time axis to (T, C*H*W) = (T, 2048) for the downsampled sensor.
        # (The original had two identical branches for ndim 3 and 4 — one
        # general reshape covers both.)
        frames = np.asarray(frames, dtype=np.float32)
        if frames.ndim > 2:
            frames = frames.reshape(frames.shape[0], -1)

        # Enforce exactly n_bins time steps: clip long clips, zero-pad short.
        if frames.shape[0] > self.n_bins:
            frames = frames[:self.n_bins]
        elif frames.shape[0] < self.n_bins:
            pad = np.zeros((self.n_bins - frames.shape[0], frames.shape[1]), dtype=np.float32)
            frames = np.concatenate([frames, pad], axis=0)

        # Binarize: any nonzero event count in a bin counts as one spike.
        frames = (frames > 0).astype(np.float32)

        return torch.from_numpy(frames), int(label)


def collate_fn(batch):
    """Stack per-sample ``(frames, label)`` pairs into batched tensors.

    Returns a ``(B, T, C)`` float tensor of frames and a ``(B,)`` long
    tensor of labels.
    """
    frame_batch = torch.stack([frames for frames, _ in batch])
    label_batch = torch.tensor([label for _, label in batch], dtype=torch.long)
    return frame_batch, label_batch