File size: 5,707 Bytes
297d76f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
spike_decoder.py — Decode temporal spike trains into feature vectors.

Takes the rich temporal spike patterns from the encoder and extracts
meaningful features that a downstream classifier can use.

Three decoding strategies, each capturing different temporal information:

1. RateDecoder     — baseline, ignores timing (same as Part 1)
2. TemporalDecoder — uses first spike times (when neurons fire)
3. SyncDecoder     — uses synchrony patterns (which neurons fire together)

TBC's insight: timing carries MORE information than rate alone.
We prove this by comparing classifier accuracy across all three decoders.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class RateDecoder:
    """
    Rate-only baseline: averages spikes over time and discards all timing.

    Acts as the control condition that the timing-aware decoders are
    compared against.
    """

    def decode(self, spikes: np.ndarray) -> np.ndarray:
        """
        Collapse one spike train into per-neuron firing rates.

        Args:
            spikes: (n_timesteps, n_neurons) binary array
        Returns:
            (n_neurons,) array with each neuron's mean over the time axis
        """
        time_axis = 0
        firing_rates = spikes.mean(axis=time_axis)
        return firing_rates

    def decode_batch(self, spike_batch: np.ndarray) -> np.ndarray:
        """
        Collapse a batch of spike trains into per-neuron firing rates.

        Args: (N, T, n_neurons) → Returns: (N, n_neurons)
        """
        time_axis = 1  # axis 0 indexes the samples in the batch
        batch_rates = spike_batch.mean(axis=time_axis)
        return batch_rates


class TemporalDecoder:
    """
    Temporal decoder — uses WHEN neurons fire, not just how often.

    Features extracted (each of length n_neurons, concatenated in this order):
    1. First spike time — index of the first spike divided by the number of
       timesteps (earlier spike → smaller value); 1.0 if the neuron never fires
    2. Total spike count — divided by the number of timesteps
    3. Early-window spike density — mean activity before ``early_cutoff``
    4. Late-window spike density — mean activity from ``early_cutoff`` onward
    """

    def __init__(self, n_timesteps: int = 100, early_window: float = 0.3):
        """
        Args:
            n_timesteps: nominal spike-train length, used to place the
                boundary between the early and late windows
            early_window: fraction of ``n_timesteps`` treated as the early window
        """
        self.n_timesteps = n_timesteps
        self.early_cutoff = int(n_timesteps * early_window)

    @staticmethod
    def _window_density(window: np.ndarray) -> np.ndarray:
        """Return per-neuron mean activity of a time window (zeros if empty)."""
        if window.shape[0] == 0:
            # The mean of an empty slice is NaN; an empty window has no spikes,
            # so report a density of zero instead of poisoning the features.
            return np.zeros(window.shape[1], dtype=np.float64)
        return window.mean(axis=0)

    def decode(self, spikes: np.ndarray) -> np.ndarray:
        """
        Args:
            spikes: (n_timesteps, n_neurons)
        Returns:
            features: (4 * n_neurons,) float32 temporal feature vector laid out
                      as [first_spike, spike_count, early_density, late_density]
        """
        n_timesteps, n_neurons = spikes.shape

        if n_timesteps == 0:
            # No timesteps at all: nothing fired, every count/density is zero.
            return np.concatenate([
                np.ones(n_neurons),
                np.zeros(3 * n_neurons),
            ]).astype(np.float32)

        # Feature 1: First spike time (normalized).
        # argmax over a boolean array returns the index of the first True;
        # neurons that never fire keep the sentinel value 1.0.
        fired = spikes > 0
        ever_fired = fired.any(axis=0)
        first_index = fired.argmax(axis=0)
        first_spike = np.where(ever_fired, first_index / n_timesteps, 1.0)

        # Feature 2: Total spike count (normalized)
        spike_count = spikes.sum(axis=0) / n_timesteps

        # Feature 3: Early window spike density
        early_spikes = self._window_density(spikes[:self.early_cutoff])

        # Feature 4: Late window spike density
        late_spikes = self._window_density(spikes[self.early_cutoff:])

        # Concatenate all temporal features
        features = np.concatenate([
            first_spike,
            spike_count,
            early_spikes,
            late_spikes,
        ])

        return features.astype(np.float32)

    def decode_batch(self, spike_batch: np.ndarray) -> np.ndarray:
        """Args: (N, T, n_neurons) → Returns: (N, 4*n_neurons)"""
        return np.array([self.decode(s) for s in spike_batch])


class SyncDecoder:
    """
    Synchrony decoder — captures which neurons fire TOGETHER.

    Synchronous firing between neurons is a key biological signal.
    Two neurons firing at the same timestep = correlated activity = shared input.
    This captures spatial relationships that rate coding misses.

    Implementation: instead of all neuron pairs (too slow), neurons are split
    into ``n_groups`` contiguous groups and one synchrony score is computed
    per group. Every neuron belongs to exactly one group; when the neuron
    count is not divisible by ``n_groups`` the leading groups get one extra
    neuron, and groups may be empty when there are fewer neurons than groups.
    """

    def __init__(self, n_timesteps: int = 100, n_groups: int = 16):
        """
        Args:
            n_timesteps: nominal spike-train length (stored, not used by decode)
            n_groups: number of neuron groups, i.e. number of synchrony features
        Raises:
            ValueError: if ``n_groups`` is less than 1
        """
        if n_groups < 1:
            raise ValueError("n_groups must be a positive integer")
        self.n_timesteps = n_timesteps
        self.n_groups = n_groups

    def decode(self, spikes: np.ndarray) -> np.ndarray:
        """
        Args:
            spikes: (n_timesteps, n_neurons)
        Returns:
            features: (n_neurons + n_groups,) float32
                      rate features + synchrony features per group
        Raises:
            ValueError: if ``spikes`` is not two-dimensional
        """
        if spikes.ndim != 2:
            raise ValueError("spikes must have shape (n_timesteps, n_neurons)")

        # Base rate features
        rate_features = spikes.mean(axis=0)

        # Synchrony features per group. array_split covers every neuron column,
        # so no trailing neurons are dropped when the split is uneven.
        sync_features = np.zeros(self.n_groups)
        groups = np.array_split(spikes, self.n_groups, axis=1)
        for g, group_spikes in enumerate(groups):
            # Total spikes in the group at each timestep: (T,)
            group_activity = group_spikes.sum(axis=1)

            # Synchrony proxy: coefficient of variation (std / mean) of the
            # group's activity over time. Bursty, coincident firing gives a
            # high value; the epsilon keeps silent groups at 0 instead of NaN.
            sync_features[g] = group_activity.std() / (group_activity.mean() + 1e-8)

        features = np.concatenate([rate_features, sync_features])
        return features.astype(np.float32)

    def decode_batch(self, spike_batch: np.ndarray) -> np.ndarray:
        """Args: (N, T, n_neurons) → Returns: (N, n_neurons + n_groups)"""
        return np.array([self.decode(s) for s in spike_batch])


class SpikeClassifier(nn.Module):
    """
    Simple MLP classifier that works with any decoder's output.
    Kept identical across all decoders to isolate the effect of decoding strategy.
    """

    def __init__(self, input_dim: int, n_classes: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.classifier = nn.Linear(128, n_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.fc2(x))
        return self.classifier(x)