File size: 7,212 Bytes
fefd7ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
from math import prod
import torch
from torch import nn
from einops.layers.torch import Rearrange
from einops import rearrange
from typing import List, Optional
from abc import ABC, abstractmethod
class Extractor(ABC):
    """Interface for feature encoders.

    Concrete subclasses encode a raw signal into a sequence of feature
    vectors and must expose an ``embedding_dim`` attribute giving the
    size of each feature vector.
    """

    # Implementers must set this attribute; declared here for type checkers.
    embedding_dim: int

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the input tensor and return the extracted features."""
        ...

    @abstractmethod
    def total_patches(self, time: int) -> int:
        """Return the number of patches produced for an input of length ``time``."""
        ...
class ConvFeatureExtractor(Extractor, nn.Module):
    """
    Convolutional feature encoder for EEG data.

    Computes successive 1D convolutions (with activations) over the time
    dimension of the input signal. This encoder can also use a different
    kernel for each input channel (``depthwise=True``), in which case the
    ``in_channels`` argument is necessary.

    Inspiration from
    https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
    and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py

    Args:
        conv_layers_spec: list of tuples (dim, k, stride) where:
            * dim: number of output channels of the layer (unrelated to EEG channels);
            * k: temporal length of the layer's kernel;
            * stride: temporal stride of the layer's kernel.
        in_channels: int
            Number of input signal channels.
        dropout: float
            Dropout probability applied after every convolution.
        mode: str
            Normalisation mode. Either ``default`` (group norm on the first
            layer only) or ``layer_norm`` (layer norm after every layer).
        conv_bias: bool
            Whether the convolutions use a bias term.
        depthwise: bool
            Perform depthwise convolutions rather than the full convolution.
            Requires each layer's output dim to be a multiple of its input dim.
    """

    class _TransposeLast(nn.Module):
        """Swap the last two axes (channels <-> time)."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x.transpose(-2, -1)

    def __init__(
        self,
        *args,
        conv_layers_spec: list[tuple[int, int, int]],
        in_channels: int = 2,
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        depthwise: bool = False,
        **kwargs,
    ):
        assert mode in {"default", "layer_norm"}
        super().__init__()  # type: ignore

        def block(
            n_in: int,
            n_out: int,
            k: int,
            stride: int,
            is_layer_norm: bool = False,
            is_group_norm: bool = False,
            conv_bias: bool = False,
            depthwise: bool = False,
        ):
            def make_conv():
                if depthwise:
                    # groups=n_in gives one kernel per input channel;
                    # Conv1d requires n_out to be divisible by groups.
                    assert n_out % n_in == 0, (
                        f"Depthwise convolution requires n_out ({n_out}) "
                        f"to be a multiple of n_in ({n_in})"
                    )
                    conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, groups=n_in)
                else:
                    conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"
            if is_layer_norm:
                # LayerNorm normalises the last axis, so move channels last,
                # normalise over them, then move them back.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        ConvFeatureExtractor._TransposeLast(),
                        nn.LayerNorm(n_out, elementwise_affine=True),
                        ConvFeatureExtractor._TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                # n_out groups of size 1: per-channel normalisation over time.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.in_channels = in_channels
        self.depthwise = depthwise
        in_d = in_channels
        conv_layers = []
        for i, cl in enumerate(conv_layers_spec):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl
            conv_layers.append(  # type: ignore
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                    depthwise=self.depthwise,
                )
            )
            in_d = dim
        self.conv_layers_spec = conv_layers_spec
        self.cnn: nn.Module = nn.Sequential(*conv_layers)  # type: ignore
        # Embedding size is the output dim of the last conv layer.
        self.embedding_dim = conv_layers_spec[-1][0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, n_chans, n_times)
                Batched EEG signal.

        Returns:
            local_features: (batch_size, n_times_out, emb_dim)
                Local features extracted from the input signal, time axis
                first so sequence models can consume them directly.
                ``emb_dim`` corresponds to the ``dim`` of the last element
                of ``conv_layers_spec``.
        """
        x = self.cnn(x)
        # (batch, channels, time) -> (batch, time, channels)
        return x.transpose(1, 2)

    def total_patches(self, time: int, device: str = "cuda") -> int:
        """Calculate the number of output time steps for a given input length.

        Args:
            time: Length (in samples) of the input signal.
            device: Unused; kept for backward compatibility. The dummy input
                is always created on the device of the encoder's parameters.
        """
        dummy = torch.zeros(
            (1, self.in_channels, time),
            device=next(self.cnn[0].parameters()).device,
        )
        with torch.no_grad():  # shape probe only, no gradients needed
            out = self.cnn(dummy)
        return out.shape[-1]  # time dimension after all convolutions

    @property
    def receptive_fields(self) -> List[int]:
        """Receptive field (in input samples) of one output sample at each
        layer, from the network input down to the last layer (last entry is
        always 1). Assumes no padding and no dilation.
        """
        rf = 1
        receptive_fields = [rf]
        for _, width, stride in reversed(self.conv_layers_spec):
            rf = (rf - 1) * stride + width  # assumes no padding and no dilation
            receptive_fields.append(rf)
        return list(reversed(receptive_fields))

    def description(self, sfreq: Optional[int] = None, dummy_time: Optional[int] = None) -> str:
        """Human-readable summary of the encoder's temporal geometry.

        Args:
            sfreq: Optional sampling frequency, used to also express
                quantities in seconds / Hz.
            dummy_time: Optional input length used to report the number of
                encoded samples per trial (runs a dummy forward pass).
        """
        dims, _, strides = zip(*self.conv_layers_spec)
        receptive_fields = self.receptive_fields
        rf = receptive_fields[0]
        desc = f"Receptive field: {rf} samples"
        if sfreq is not None:
            desc += f", {rf / sfreq:.2f} seconds"
        ds_factor = prod(strides)
        desc += f" | Downsampled by {ds_factor}"
        if sfreq is not None:
            desc += f", new sfreq: {sfreq / ds_factor:.2f} Hz"
        desc += f" | Overlap of {rf - ds_factor} samples"
        if dummy_time is not None:
            n_times_out = self.total_patches(dummy_time)
            desc += f" | {n_times_out} encoded samples/trial"
        layer_dims = [self.in_channels] + list(dims)
        n_features = [f"{dim}*{rf_}" for dim, rf_ in zip(layer_dims, receptive_fields)]
        # Compute the products directly instead of eval()-ing the strings.
        feature_counts = [dim * rf_ for dim, rf_ in zip(layer_dims, receptive_fields)]
        desc += (
            " | #features/sample at each layer (n_channels*n_times): "
            f"[{', '.join(n_features)}] = {feature_counts}"
        )
        return desc
|