|
|
from math import prod |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from einops.layers.torch import Rearrange |
|
|
from einops import rearrange |
|
|
|
|
|
from typing import List, Optional |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
|
|
|
class Extractor(ABC):
    """Common interface implemented by every encoder.

    A concrete encoder must implement :meth:`forward` and
    :meth:`total_patches`; the class cannot be instantiated until both are
    overridden.
    """

    # Annotation only: no value is assigned here, concrete encoders set it.
    embedding_dim: int

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder on ``x`` and return the resulting tensor."""

    @abstractmethod
    def total_patches(self, time: int) -> int:
        """Return the total number of patches for an input with ``time`` time steps."""
|
|
|
|
|
|
|
|
class ConvFeatureExtractor(Extractor, nn.Module):
    """
    Convolutional feature encoder for EEG data.

    Computes successive unpadded 1D convolutions over the time dimension of
    the input signal. Every convolution is followed by dropout, an optional
    normalisation layer and a GELU activation. The first convolution consumes
    ``in_channels`` input channels, therefore the ``in_channels`` argument is
    necessary!

    Inspiration from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
    and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py

    Args:
        conv_layers_spec: list of tuples (dim, k, stride) where:
            * dim: number of output channels of the layer (unrelated to EEG channels);
            * k: temporal length of the layer's kernel;
            * stride: temporal stride of the layer's kernel.

        in_channels: int
            Number of channels of the input signal.
        dropout: float
            Dropout probability applied after every convolution.
        mode: str
            Normalisation mode. Either ``default`` (group norm after the first
            convolution only) or ``layer_norm`` (layer norm over the channel
            axis after every convolution).
        conv_bias: bool
            Whether the convolutions carry a bias term.
        depthwise: bool
            Perform depthwise convolutions rather than the full convolution.
            Each layer's ``dim`` must then be a multiple of its input channels.
        *args, **kwargs:
            Accepted and ignored.
    """

    def __init__(
        self,
        *args,
        conv_layers_spec: list[tuple[int, int, int]],
        in_channels: int = 2,
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        depthwise: bool = False,
        **kwargs,
    ):
        assert mode in {"default", "layer_norm"}, (
            f"mode must be 'default' or 'layer_norm', got {mode!r}"
        )
        super().__init__()

        self.in_channels = in_channels
        self.depthwise = depthwise

        in_d = in_channels
        conv_layers = []
        for i, cl in enumerate(conv_layers_spec):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl
            conv_layers.append(
                self._conv_block(
                    in_d,
                    dim,
                    k,
                    stride,
                    dropout=dropout,
                    is_layer_norm=mode == "layer_norm",
                    # In "default" mode only the very first layer is normalised.
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                    depthwise=self.depthwise,
                )
            )
            # The next layer consumes what this one produced.
            in_d = dim

        self.conv_layers_spec = conv_layers_spec
        self.cnn: nn.Module = nn.Sequential(*conv_layers)
        self.embedding_dim = conv_layers_spec[-1][0]

    @staticmethod
    def _conv_block(
        n_in: int,
        n_out: int,
        k: int,
        stride: int,
        *,
        dropout: float,
        is_layer_norm: bool,
        is_group_norm: bool,
        conv_bias: bool,
        depthwise: bool,
    ) -> nn.Sequential:
        """Build one ``conv -> dropout -> [norm] -> GELU`` stage.

        The sub-module order (and the nested ``Sequential`` wrapping the layer
        norm) is kept stable so that ``state_dict`` keys do not change.
        """
        assert not (
            is_layer_norm and is_group_norm
        ), "layer norm and group norm are exclusive"

        if depthwise:
            assert n_out % n_in == 0, f"For depthwise signals we can not have non-multipler of {n_out} and {n_in}"
            conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, groups=n_in)
        else:
            conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
        # Only the weight is re-initialised; a bias (if any) keeps its default init.
        nn.init.kaiming_normal_(conv.weight)

        layers = [conv, nn.Dropout(p=dropout)]
        if is_layer_norm:
            # LayerNorm normalises the last axis, so channels are moved last
            # for the normalisation and moved back afterwards.
            layers.append(
                nn.Sequential(
                    Rearrange("... channels time -> ... time channels"),
                    nn.LayerNorm(n_out, elementwise_affine=True),
                    Rearrange("... time channels -> ... channels time"),
                )
            )
        elif is_group_norm:
            # One group per channel.
            layers.append(nn.GroupNorm(n_out, n_out, affine=True))
        layers.append(nn.GELU())
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, n_chans, n_times)
                Batched EEG signal.

        Returns:
            local_features: (batch_size, n_times_out, emb_dim)
                Local features extracted from the input signal, with the
                time axis before the feature axis.
                ``emb_dim`` corresponds to the ``dim`` of the last element of
                ``conv_layers_spec``.
        """
        x = self.cnn(x)
        # The CNN is channels-first; callers receive time-major features.
        x = rearrange(x, "batch_size n_channels n_time -> batch_size n_time n_channels")
        return x

    def total_patches(self, time: int, device: str = "cuda") -> int:
        """Calculate the number of output time steps for a given input length.

        The value is derived from ``conv_layers_spec`` using the output-length
        formula of an unpadded, undilated convolution,
        ``(length - k) // stride + 1`` per layer. No tensor is allocated and
        no forward pass is run, so the module's state and the random number
        generator are left untouched.

        Args:
            time: number of time samples of the input.
            device: unused; kept so existing callers keep working.

        Returns:
            Number of time steps produced by :meth:`forward`.

        Raises:
            RuntimeError: if the signal becomes shorter than a layer's kernel,
                which is also the condition under which the convolution
                itself would fail.
        """
        n_times = time
        for _, kernel, stride in self.conv_layers_spec:
            if n_times < kernel:
                raise RuntimeError(
                    f"Kernel size ({kernel}) can't be greater than the input "
                    f"length ({n_times}) for an input of {time} time samples"
                )
            n_times = (n_times - kernel) // stride + 1
        return n_times

    @property
    def receptive_fields(self) -> List[int]:
        """Receptive field, in samples, of one output step at every depth.

        The list has ``len(conv_layers_spec) + 1`` entries: entry ``i`` is the
        number of samples at the input of layer ``i`` that influence a single
        output step. The first entry is therefore the receptive field on the
        raw input and the last entry is always ``1``.
        """
        rf = 1
        receptive_fields = [rf]
        # Walk from the output back to the input, widening the field per layer.
        for _, width, stride in reversed(self.conv_layers_spec):
            rf = (rf - 1) * stride + width
            receptive_fields.append(rf)
        return list(reversed(receptive_fields))

    def description(self, sfreq: Optional[int] = None, dummy_time: Optional[int] = None) -> str:
        """Summarise the encoder geometry as a single human-readable line.

        Args:
            sfreq: sampling frequency of the input; when given, the receptive
                field is also reported in seconds and the downsampled
                frequency is included.
            dummy_time: input length in samples; when given, the number of
                encoded samples for such an input is included.

        Returns:
            The description string.
        """
        dims, _, strides = zip(*self.conv_layers_spec)
        receptive_fields = self.receptive_fields
        rf = receptive_fields[0]
        desc = f"Receptive field: {rf} samples"
        if sfreq is not None:
            desc += f", {rf / sfreq:.2f} seconds"

        ds_factor = prod(strides)
        desc += f" | Downsampled by {ds_factor}"
        if sfreq is not None:
            desc += f", new sfreq: {sfreq / ds_factor:.2f} Hz"
        desc += f" | Overlap of {rf - ds_factor} samples"
        if dummy_time is not None:
            n_times_out = self.total_patches(dummy_time)
            desc += f" | {n_times_out} encoded samples/trial"

        # Channel count at the input of each layer, plus the final output dim.
        layer_channels = [self.in_channels] + list(dims)
        n_features = [
            f"{dim}*{rf}" for dim, rf in zip(layer_channels, receptive_fields)
        ]
        # Computed arithmetically instead of eval()-ing the strings above.
        n_feature_values = [
            dim * rf for dim, rf in zip(layer_channels, receptive_fields)
        ]
        desc += f" | #features/sample at each layer (n_channels*n_times): [{', '.join(n_features)}] = {n_feature_values}"
        return desc
|
|
|
|
|
|
|
|
|
|
|
|