File size: 7,212 Bytes
fefd7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
from math import prod

import torch
from torch import nn

from einops.layers.torch import Rearrange
from einops import rearrange

from typing import List, Optional

from abc import ABC, abstractmethod

class Extractor(ABC):
    """Interface for feature extractors that map raw signals to embeddings.

    Concrete implementations are expected to expose an ``embedding_dim``
    attribute giving the size of the feature dimension they produce.
    """

    # Declared (not assigned) so that implementers provide this attribute.
    embedding_dim: int

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the extractor on ``x`` and return the extracted features."""
        ...

    @abstractmethod
    def total_patches(self, time: int) -> int:
        """Return how many patches an input with time dimension ``time`` yields."""
        ...


class ConvFeatureExtractor(Extractor, nn.Module):
    """
    Convolutional feature encoder for EEG data.

    Computes successive 1D convolutions (with GELU activations) over the time
    dimension of the signal. This encoder can also use different kernels for
    each input channel (``depthwise``), which is why the ``in_channels``
    argument is necessary.

    Inspiration from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
    and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py

    Args:
        conv_layers_spec: list of tuples (dim, k, stride) where:
            * dim: number of output channels of the layer (unrelated to EEG channels);
            * k: temporal length of the layer's kernel;
            * stride: temporal stride of the layer's kernel.

        in_channels: int
            Number of input signal channels.
        dropout: float
            Dropout probability applied after each convolution.
        mode: str
            Normalisation mode. Either ``default`` (group norm on the first
            layer only) or ``layer_norm`` (layer norm after every layer).
        conv_bias: bool
            Whether the convolutions use a bias term.
        depthwise: bool
            Perform depthwise convolutions rather than the full convolution.
    """

    def __init__(
        self,
        *args,
        conv_layers_spec: list[tuple[int, int, int]],
        in_channels: int = 2,
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        depthwise: bool = False,
        **kwargs,
    ):
        assert mode in {"default", "layer_norm"}
        super().__init__()  # type: ignore

        def block(
            n_in: int,
            n_out: int,
            k: int,
            stride: int,
            is_layer_norm: bool = False,
            is_group_norm: bool = False,
            conv_bias: bool = False,
            depthwise: bool = False,  # default aligned with the constructor default
        ):
            """Build one conv -> dropout -> (optional norm) -> GELU stage."""

            def make_conv():
                if depthwise:
                    # groups=n_in gives each input channel its own kernels,
                    # which requires n_out to be a multiple of n_in.
                    assert n_out % n_in == 0, (
                        f"Depthwise convolution requires the output channels ({n_out}) "
                        f"to be a multiple of the input channels ({n_in})"
                    )
                    conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, groups=n_in)
                else:
                    conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)

                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                # LayerNorm normalises over the last dim, so transpose
                # channels/time around it and back.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        Rearrange("... channels time -> ... time channels"),
                        nn.LayerNorm(n_out, elementwise_affine=True),
                        Rearrange("... time channels -> ... channels time"),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                # n_out groups of n_out channels: one group per channel.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.in_channels = in_channels
        self.depthwise = depthwise
        in_d = in_channels
        conv_layers: list[nn.Module] = []
        for i, cl in enumerate(conv_layers_spec):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl
            conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                    depthwise=self.depthwise,
                )
            )
            in_d = dim
        self.conv_layers_spec = conv_layers_spec
        self.cnn: nn.Module = nn.Sequential(*conv_layers)
        # Feature dimension produced by the last convolutional layer.
        self.embedding_dim = conv_layers_spec[-1][0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, n_chans, n_times)
                Batched EEG signal.

        Returns:
            local_features: (batch_size, n_times_out, emb_dim)
                Local features extracted from the signal, with the time and
                channel axes swapped relative to the raw CNN output.
                ``emb_dim`` corresponds to the ``dim`` of the last element of
                ``conv_layers_spec``.
        """
        x = self.cnn(x)
        x = rearrange(x, "batch_size n_channels n_time -> batch_size n_time n_channels")
        return x

    def total_patches(self, time: int, device: str = "cuda") -> int:
        """Calculate the number of output time steps for a given input length.

        Probes the CNN with a zero dummy input; no gradients are recorded.

        Args:
            time: Input length in samples.
            device: Unused; kept for backward compatibility. The dummy input
                is placed on the device of the model parameters instead.

        Returns:
            Number of encoded time steps (patches) the CNN produces.
        """
        model_device = next(self.cnn.parameters()).device
        with torch.no_grad():  # shape probe only; avoid building an autograd graph
            out = self.cnn(torch.zeros((1, self.in_channels, time), device=model_device))
        return out.shape[-1]  # CNN output is (batch, channels, time)

    @property
    def receptive_fields(self) -> List[int]:
        """Receptive field (in input samples) at the input of each layer.

        Index 0 is the receptive field of the full stack; the last entry is
        always 1 (a single output sample).
        """
        rf = 1
        receptive_fields = [rf]
        for _, width, stride in reversed(self.conv_layers_spec):
            rf = (rf - 1) * stride + width  # assumes no padding and no dilation
            receptive_fields.append(rf)
        return list(reversed(receptive_fields))

    def description(self, sfreq: Optional[int] = None, dummy_time: Optional[int] = None) -> str:
        """Build a human-readable summary of the encoder's temporal behaviour.

        Args:
            sfreq: Optional sampling frequency, used to also express the
                receptive field and downsampling in seconds / Hz.
            dummy_time: Optional input length; if given, the number of encoded
                samples per trial of that length is reported as well.

        Returns:
            A one-line description string.
        """
        dims, _, strides = zip(*self.conv_layers_spec)
        receptive_fields = self.receptive_fields
        rf = receptive_fields[0]
        desc = f"Receptive field: {rf} samples"
        if sfreq is not None:
            desc += f", {rf / sfreq:.2f} seconds"

        ds_factor = prod(strides)
        desc += f" | Downsampled by {ds_factor}"
        if sfreq is not None:
            desc += f", new sfreq: {sfreq / ds_factor:.2f} Hz"
        desc += f" | Overlap of {rf - ds_factor} samples"
        if dummy_time is not None:
            n_times_out = self.total_patches(dummy_time)
            desc += f" | {n_times_out} encoded samples/trial"

        # Compute the per-layer feature counts directly instead of eval()-ing
        # the formatted strings.
        layer_dims = [self.in_channels] + list(dims)
        feature_terms = [
            f"{dim}*{rf}" for dim, rf in zip(layer_dims, receptive_fields)
        ]
        feature_counts = [dim * rf for dim, rf in zip(layer_dims, receptive_fields)]
        desc += f" | #features/sample at each layer (n_channels*n_times): [{', '.join(feature_terms)}] = {feature_counts}"
        return desc