GokseninYuksel commited on
Commit
4642fbb
·
verified ·
1 Parent(s): 493cb1f

Upload model

Browse files
Files changed (9) hide show
  1. audio_extractor.py +235 -0
  2. config.json +53 -0
  3. configuration_wavjepa_nat.py +83 -0
  4. model.py +181 -0
  5. model.safetensors +3 -0
  6. modeling_wavjepa_nat.py +34 -0
  7. pos_embed.py +267 -0
  8. types.py +51 -0
  9. utils.py +31 -0
audio_extractor.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import prod
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from einops.layers.torch import Rearrange
7
+ from einops import rearrange
8
+
9
+ from typing import List, Optional
10
+
11
+ from abc import ABC, abstractmethod
12
+
13
class Extractor(ABC):
    """Abstract base class for audio feature extractors.

    Implementers turn a batched waveform into a sequence of patch embeddings
    and must expose ``embedding_dim``, the size of the feature dimension of
    ``forward``'s output.  Concrete subclasses in this file also mix in
    ``nn.Module`` (see ``ConvChannelFeatureExtractor``).
    """

    # Just declare that implementers should have this attribute; it is never
    # assigned here (size of the output feature dimension).
    embedding_dim: int

    @abstractmethod
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Forward pass through the encoder."""
        pass

    @abstractmethod
    def total_patches(self, time: int) -> int:
        """Returns the total patches given the time dimension of the input."""
        pass
28
+
29
+
30
+
31
class ConvChannelFeatureExtractor(Extractor, nn.Module):
    """
    Convolutional feature encoder for the audio data.

    Computes successive 1D convolutions (with activations) over the time
    dimension of the audio signal. This encoder uses different kernels for
    each audio channel (unless ``share_weights_over_channels`` is set), so the
    ``in_channels`` argument is necessary.

    Inspiration from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
    and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py

    Args:
        conv_layers_spec: list of tuples (dim, k, stride) where:
            * dim: number of output channels of the layer (unrelated to audio channels);
            * k: temporal length of the layer's kernel;
            * stride: temporal stride of the layer's kernel.
        in_channels: Number of audio channels.
        dropout: Dropout probability applied after every convolution.
        mode: Normalisation mode. Either ``default`` (GroupNorm after the first
            layer only) or ``layer_norm`` (LayerNorm after every layer).
        conv_bias: Whether the convolutions carry a bias term.
        depthwise: Perform depthwise convolutions rather than the full
            convolution (each layer's output dim must then be a multiple of
            its input dim).
        share_weights_over_channels: Use one convolution stack for all audio
            channels instead of an independent stack per channel.
    """

    def __init__(
        self,
        *args,
        conv_layers_spec: list[tuple[int, int, int]],
        in_channels : int = 2,
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        depthwise : bool = False,
        share_weights_over_channels : bool = False,
        **kwargs,
    ):
        assert mode in {"default", "layer_norm"}
        super().__init__()  # type: ignore

        def block(
            n_in : int,
            n_out : int,
            k : int,
            stride : int,
            is_layer_norm : bool = False,
            is_group_norm : bool = False,
            conv_bias : bool = False,
            depthwise : bool = True,
        ):
            # One layer: Conv1d -> Dropout -> (optional norm) -> GELU.

            def make_conv():
                if depthwise:
                    assert n_out % n_in == 0, f"For depthwise signals we can not have non-multipler of {n_out} and {n_in}"
                    conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, groups = n_in)
                else:
                    conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)

                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        # LayerNorm normalises the last dim, so move channels
                        # there and back around it.
                        Rearrange("... channels time -> ... time channels"),
                        nn.LayerNorm(n_out, elementwise_affine=True),
                        Rearrange("... time channels -> ... channels time"),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.in_channels = in_channels
        self.depthwise = depthwise
        self.conv_layers_spec = conv_layers_spec
        self.cnns = nn.ModuleList()

        def build_stack() -> nn.Module:
            # Build one full convolution stack following conv_layers_spec.
            # (This loop was previously duplicated verbatim in both the
            # shared-weights and per-channel branches.)
            in_d = 1
            conv_layers = []
            for i, cl in enumerate(conv_layers_spec):
                assert len(cl) == 3, "invalid conv definition: " + str(cl)
                (dim, k, stride) = cl
                conv_layers.append(
                    block(
                        in_d,
                        dim,
                        k,
                        stride,
                        is_layer_norm=mode == "layer_norm",
                        is_group_norm=mode == "default" and i == 0,
                        conv_bias=conv_bias,
                        depthwise=self.depthwise,
                    )
                )
                in_d = dim
            return nn.Sequential(*conv_layers)

        # One shared stack, or one independent stack per audio channel.
        n_stacks = 1 if share_weights_over_channels else self.in_channels
        for _ in range(n_stacks):
            self.cnns.append(build_stack())

        self.embedding_dim = self.conv_layers_spec[-1][0]
        self.weight_sharing = share_weights_over_channels

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, n_chans, n_times)
                Batched audio signal.

        Returns:
            local_features: (batch_size, n_chans * n_times_out, emb_dim)
                Local features extracted from the audio signal, channel-major
                along the token axis.  ``emb_dim`` corresponds to the ``dim``
                of the last element of ``conv_layers_spec``.
                (Docstring fixed: the code stacks per-channel outputs and
                flattens (channel, time) into one axis with features last.)
        """
        out = []
        for channel_index in range(self.in_channels):
            # If we are sharing weights over the channels, use one CNN for all
            # the channel dimensions.
            module = self.cnns[0] if self.weight_sharing else self.cnns[channel_index]
            processed = module(x[:, [channel_index], ...])
            processed = rearrange(processed, "batch_size n_channels n_time -> batch_size n_time n_channels")
            out.append(processed)
        processed = torch.stack(out, dim = 1)
        processed = torch.flatten(processed, start_dim = 1, end_dim = 2)
        return processed

    def total_patches(self, time: int) -> int:
        """Calculate the number of output tokens for a given input length.

        Runs a zero dummy input through the stack, so this is exact but not
        free; call it once at setup time, not per batch.
        """
        x = torch.zeros((1, self.in_channels, time))
        processed = self.forward(x)
        return processed.shape[1]  # token (channel*time) dimension size

    @property
    def receptive_fields(self) -> List[int]:
        """Per-layer receptive field in input samples, outermost first."""
        rf = 1
        receptive_fields = [rf]
        # Walk the stack backwards: rf_l = (rf_{l+1} - 1) * stride + width.
        for _, width, stride in reversed(self.conv_layers_spec):
            rf = (rf - 1) * stride + width  # assumes no padding and no dilation
            receptive_fields.append(rf)
        return list(reversed(receptive_fields))

    def description(self, sfreq : Optional[int] = None, dummy_time : Optional[int] = None) -> str:
        """Human-readable summary of the encoder's temporal geometry."""
        dims, _, strides = zip(*self.conv_layers_spec)
        receptive_fields = self.receptive_fields
        rf = receptive_fields[0]
        desc = f"Receptive field: {rf} samples"
        if sfreq is not None:
            desc += f", {rf / sfreq:.2f} seconds"

        ds_factor = prod(strides)
        desc += f" | Downsampled by {ds_factor}"
        if sfreq is not None:
            desc += f", new sfreq: {sfreq / ds_factor:.2f} Hz"
        desc += f" | Overlap of {rf - ds_factor} samples"
        if dummy_time is not None:
            n_times_out = self.total_patches(dummy_time)
            desc += f" | {n_times_out} encoded samples/trial"

        # Compute the products directly instead of eval()-ing the display
        # strings (same output, no dynamic evaluation).
        layer_dims = [self.in_channels] + list(dims)
        n_features = [f"{dim}*{rf}" for dim, rf in zip(layer_dims, receptive_fields)]
        n_feature_counts = [dim * rf for dim, rf in zip(layer_dims, receptive_fields)]
        desc += f" | #features/sample at each layer (n_channels*n_times): [{', '.join(n_features)}] = {n_feature_counts}"
        return desc
config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "WavJEPANatModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_wavjepa_nat.WavJEPANatConfig",
7
+ "AutoModel": "modeling_wavjepa_nat.WavJEPANatModel"
8
+ },
9
+ "decoder_cfg": {
10
+ "enable_nested_tensor": false,
11
+ "mask_check": true,
12
+ "num_layers": 12
13
+ },
14
+ "decoder_layers_cfg": {
15
+ "activation": "gelu",
16
+ "batch_first": true,
17
+ "bias": true,
18
+ "d_model": 384,
19
+ "dim_feedforward": 1536,
20
+ "dropout": 0.0,
21
+ "layer_norm_eps": 1e-06,
22
+ "nhead": 12,
23
+ "norm_first": true
24
+ },
25
+ "dtype": "float32",
26
+ "encoder_cfg": {
27
+ "enable_nested_tensor": false,
28
+ "mask_check": true,
29
+ "num_layers": 12
30
+ },
31
+ "encoder_layers_cfg": {
32
+ "activation": "gelu",
33
+ "batch_first": true,
34
+ "bias": true,
35
+ "d_model": 768,
36
+ "dim_feedforward": 3072,
37
+ "dropout": 0.0,
38
+ "layer_norm_eps": 1e-06,
39
+ "nhead": 12,
40
+ "norm_first": true
41
+ },
42
+ "extractor_config": {
43
+ "conv_bias": false,
44
+ "conv_layers_spec": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)]",
45
+ "depthwise": false,
46
+ "dropout": 0.0,
47
+ "in_channels": 2,
48
+ "mode": "default",
49
+ "share_weights_over_channels": false
50
+ },
51
+ "model_type": "wavjepa-nat-base",
52
+ "transformers_version": "4.57.1"
53
+ }
configuration_wavjepa_nat.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from .types import TransformerLayerCFG, TransformerEncoderCFG
3
+
4
+
5
+
6
class WavJEPANatConfig(PretrainedConfig):
    """Hugging Face configuration for the WavJEPA-Nat base model.

    Flat keyword arguments are folded into the TypedDict configs
    (``TransformerLayerCFG`` / ``TransformerEncoderCFG``) consumed by the
    model, so the whole config round-trips through plain JSON.
    """

    model_type = "wavjepa-nat-base"
    model_size = "base"
    # Number of audio channels fed to the feature extractor (binaural input).
    in_channels: int = 2

    def __init__(
        self,
        extractor_layers_spec: str = "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)]",
        extractor_dropout : float = 0.0,
        extractor_mode : str = "default",
        extractor_conv_bias : bool = False,
        extractor_depthwise: bool = False,
        encoder_d_model: int = 768,
        encoder_nhead : int = 12,
        encoder_batch_first = True,
        encoder_norm_first = True,
        encoder_bias = True,
        encoder_mlp_ratio = 4.0,
        encoder_dropout = 0.0,
        encoder_num_layers: int = 12,
        encoder_enable_nested_tensor = False,
        encoder_mask_check = True,
        encoder_activation = 'gelu',
        decoder_d_model: int = 384,
        decoder_nhead : int = 12,
        decoder_batch_first = True,
        decoder_norm_first = True,
        decoder_bias = True,
        decoder_mlp_ratio = 4.0,
        decoder_dropout = 0.0,
        decoder_num_layers: int = 12,
        decoder_enable_nested_tensor = False,
        decoder_mask_check = True,
        decoder_activation = 'gelu',
        **kwargs
    ):
        # Stack-level settings (layer count, nested-tensor fast path).
        self.encoder_cfg = TransformerEncoderCFG.create(
            num_layers = encoder_num_layers,
            enable_nested_tensor = encoder_enable_nested_tensor,
            mask_check = encoder_mask_check,
        )
        self.decoder_cfg = TransformerEncoderCFG.create(
            num_layers = decoder_num_layers,
            enable_nested_tensor = decoder_enable_nested_tensor,
            mask_check = decoder_mask_check,
        )
        # Per-layer settings; dim_feedforward is derived inside `create`
        # as d_model * mlp_ratio.
        self.encoder_layers_cfg = TransformerLayerCFG.create(
            d_model = encoder_d_model,
            nhead = encoder_nhead,
            batch_first = encoder_batch_first,
            norm_first = encoder_norm_first,
            bias = encoder_bias,
            mlp_ratio = encoder_mlp_ratio,
            dropout = encoder_dropout,
            activation = encoder_activation,
            layer_norm_eps = 1e-6
        )
        self.decoder_layers_cfg = TransformerLayerCFG.create(
            d_model = decoder_d_model,
            nhead = decoder_nhead,
            batch_first = decoder_batch_first,
            norm_first = decoder_norm_first,
            bias = decoder_bias,
            mlp_ratio = decoder_mlp_ratio,
            dropout = decoder_dropout,
            activation = decoder_activation,
            layer_norm_eps = 1e-6
        )
        # NOTE(review): conv_layers_spec stays a Python *expression string*
        # and is eval()-ed when the model is built (modeling_wavjepa_nat.py)
        # — confirm this is acceptable for untrusted checkpoints.
        self.extractor_config = dict(
            conv_layers_spec = extractor_layers_spec,
            in_channels = self.in_channels,
            dropout = extractor_dropout,
            mode = extractor_mode,
            conv_bias = extractor_conv_bias,
            depthwise = extractor_depthwise,
            share_weights_over_channels=False)

        super().__init__(**kwargs)
model.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import numpy as np
3
+
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from einops import repeat, rearrange
10
+
11
+ from .pos_embed import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed, get_binaural_pos_embed
12
+ from .audio_extractor import Extractor
13
+ from .types import TransformerLayerCFG, TransformerEncoderCFG
14
+ from .utils import normalize, calculate_padding_mask, get_timestamps
15
+
16
class WavJEPANat(nn.Module):
    """
    Joint-Embedding Predictive Architecture (JEPA).

    This implementation is inspired by:
    * I-JEPA http://arxiv.org/abs/2301.08243
    * Data2vec 2.0 http://arxiv.org/abs/2212.07525
    """

    teacher_encoder: nn.Module
    # Fixed preprocessing constants: 16 kHz input, processed in ~2.01 s segments.
    sample_rate : int = 16000
    process_audio_seconds : float = 2.01
    in_channels : int = 2

    def __init__(
        self,
        feature_extractor: Extractor,
        transformer_encoder_layers_cfg : TransformerLayerCFG,
        transformer_encoder_cfg : TransformerEncoderCFG,
        transformer_decoder_layers_cfg : TransformerLayerCFG,
        transformer_decoder_cfg : TransformerEncoderCFG,
        size : str = "base",
        **kwargs : dict[str, Any],
    ):
        """
        Args:
            feature_extractor: Waveform -> patch-embedding front end.
            transformer_encoder_layers_cfg / transformer_encoder_cfg:
                Per-layer and stack-level kwargs for the context encoder.
            transformer_decoder_layers_cfg / transformer_decoder_cfg:
                Per-layer and stack-level kwargs for the predictor/decoder.
            size: Model-size tag (informational).
        """
        super().__init__(**kwargs)

        self.is_spectrogram = False
        self.target_length = int(self.sample_rate * self.process_audio_seconds)
        self.extract_audio = feature_extractor
        # NOTE(review): token count hard-coded to 400 (200 per channel);
        # assumes the extractor yields 200 time steps per target_length
        # segment — confirm against the conv spec.
        self.total_patches = 400
        self.output_steps = self.total_patches // self.in_channels
        self.feature_norms : nn.Module = nn.LayerNorm(self.extract_audio.embedding_dim)

        self.n_encoder_heads = transformer_encoder_layers_cfg["nhead"]
        self.encoder_embedding_dim = transformer_encoder_layers_cfg["d_model"]
        self.n_decoder_heads = transformer_decoder_layers_cfg["nhead"]
        self.decoder_embedding_dim = transformer_decoder_layers_cfg["d_model"]

        encoder_layer = nn.TransformerEncoderLayer(**transformer_encoder_layers_cfg)
        self.encoder = nn.TransformerEncoder(encoder_layer, norm = nn.LayerNorm(self.encoder_embedding_dim), **transformer_encoder_cfg)
        # Project extractor features up to the encoder width only when widths differ.
        self.post_extraction_mapper : Optional[nn.Module] = (
            nn.Linear(feature_extractor.embedding_dim, self.encoder_embedding_dim)
            if feature_extractor.embedding_dim != self.encoder_embedding_dim
            else None
        )
        decoder_layer = nn.TransformerEncoderLayer(**transformer_decoder_layers_cfg)
        self.decoder = nn.TransformerEncoder(decoder_layer, norm = nn.LayerNorm(self.decoder_embedding_dim), **transformer_decoder_cfg)
        self.decoder_to_encoder_mapper = nn.Linear(self.decoder_embedding_dim, self.encoder_embedding_dim, bias=True)
        self.encoder_to_decoder_mapper = nn.Linear(self.encoder_embedding_dim, self.decoder_embedding_dim)

        # Learnable mask token; batch/time singleton dims kept for broadcasting.
        self.mask_token = nn.Parameter(
            torch.zeros(1, 1, self.decoder_embedding_dim, requires_grad=True)
        )
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.pos_encoding_encoder = self._get_pos_embed_params(self.encoder_embedding_dim)
        self.pos_encoding_decoder = self._get_pos_embed_params(self.decoder_embedding_dim)
        self._init_teacher()

    def _get_pos_embed_params(self, embedding_dim):
        """Build the (frozen) positional-embedding parameter for one tower."""
        pos_embed = nn.Parameter(
            torch.zeros(
                1,
                self.total_patches,
                embedding_dim,
            ),
            requires_grad=False,  # fixed sin-cos table, not learned
        )
        positions = np.arange(self.total_patches, dtype=np.float64)
        if self.is_spectrogram:
            # If it is a spectrogram, we use 2d sincos embeddings.
            pos_embed_data = get_2d_sincos_pos_embed(
                embedding_dim, self.extract_audio.grid_size, cls_token_num=0
            )
        # TODO! Remove this total_patches special-casing later.
        elif not self.is_spectrogram and self.in_channels == 2 and (self.total_patches == 400):
            # 1D sincos time embedding, with the channel identity encoded on
            # the second half of the feature dimension.
            pos_embed_data = get_binaural_pos_embed(embedding_dim, time_steps=self.total_patches // self.in_channels)
        elif not self.is_spectrogram and self.in_channels == 2 and (self.total_patches == 200):
            # Use plain 1D pos embeddings for a channel-mixing feature extractor.
            pos_embed_data = get_1d_sincos_pos_embed_from_grid(
                embedding_dim,
                positions,
            )
        elif not self.is_spectrogram and self.in_channels == 1 and (self.total_patches == 200):
            # If it is plain mono audio, we use 1d sincos embeddings.
            pos_embed_data = get_1d_sincos_pos_embed_from_grid(
                embedding_dim,
                positions,
            )
        else:
            raise Exception(f"Not implemented for more in_channels, {self.in_channels}, {self.total_patches}")
        pos_embed.data.copy_(torch.from_numpy(pos_embed_data).float().unsqueeze(0))
        return pos_embed

    def _init_teacher(self):
        """EMA-style teacher: frozen deep copy of the context encoder."""
        self.teacher_encoder = copy.deepcopy(self.encoder)
        self.teacher_encoder.requires_grad_(False)

    @torch.inference_mode()
    def _get_segment_representation(self, audio : torch.Tensor, padding_mask : torch.Tensor):
        """Encode one normalized target_length segment into contextual features.

        NOTE(review): calls self.eval() as a side effect, flipping the whole
        module out of train mode — confirm callers expect this.
        """
        self.eval()
        local_features = self.extract_audio(audio)
        local_features = self.feature_norms(local_features)
        if self.post_extraction_mapper:
            local_features = self.post_extraction_mapper(local_features)
        local_features = local_features + self.pos_encoding_encoder
        # Encoder forward; padding tokens are masked out of attention.
        contextual_features = self.encoder(local_features, src_key_padding_mask = padding_mask)
        return contextual_features

    @torch.inference_mode()
    def get_audio_representation(self, audio : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed a waveform batch segment-by-segment.

        Args:
            audio: (n_sounds, n_channels, num_samples) waveform batch.

        Returns:
            x: (n_sounds, n_channels, n_frames, emb_dim) embeddings, padding
               frames cropped.
            ts: (n_sounds, n_frames) timestamps in milliseconds.
        """
        B = audio.shape[0]
        input_audio_len = audio.shape[-1]
        if audio.ndim != 3:
            # Fixed: the message previously said "2D" while the check requires 3D.
            raise ValueError(
                "audio input tensor must be 3D with shape (n_sounds, n_channels, num_samples)"
            )
        cur_frames = audio.shape[-1]
        # NOTE(review): when cur_frames is already a multiple of target_length
        # this pads one full extra segment (pad_frames == target_length); the
        # padding mask / cut_off below are expected to compensate — confirm.
        pad_frames = self.target_length - (cur_frames % self.target_length)

        if pad_frames > 0:
            # Zero-pad only the trailing time dimension.
            pad_arg = (0, pad_frames)
            audio = torch.nn.functional.pad(audio, pad_arg, mode="constant")
        embeddings = []
        # NOTE(review): process_seconds uses integer division (2.01 s -> 2),
        # dropping the fractional 0.01 s — confirm calculate_padding_mask
        # expects whole seconds.
        padding_mask, cut_off = calculate_padding_mask(pad_frames = pad_frames,
                                                       total_frames = audio.shape[-1],
                                                       sr = self.sample_rate,
                                                       output_steps = self.output_steps,
                                                       process_seconds = self.target_length // self.sample_rate,
                                                       device = audio.device,
                                                       B = B)
        # (Removed dead code: a torch.masked.masked_tensor was built here and
        # immediately overwritten in the loop below without ever being read.)
        mask_idx = 0
        # Process the (padded) waveform one target_length segment at a time.
        for i in range(audio.shape[-1] // self.target_length):
            segment = audio[..., i * self.target_length : (i + 1) * self.target_length]
            mask = padding_mask[..., mask_idx : mask_idx + self.output_steps]
            with torch.no_grad():
                # We do not include padding tokens in the mean and std calculation.
                mask = repeat(mask, "B E -> B (C E)", C = self.in_channels)
                embedding = self._get_segment_representation(
                    normalize(segment),
                    mask
                )
            embedding = rearrange(embedding, "B (C S) E -> B C S E", C = self.in_channels)
            mask_idx = mask_idx + self.output_steps
            embeddings.append(embedding)

        x = torch.concatenate(embeddings, axis = 2)
        # Crop the frames that correspond purely to padding.
        x = x[:, :, :cut_off, :]
        ts = get_timestamps(self.sample_rate, B, input_audio_len, x)
        assert ts.shape[-1] == x.shape[2]
        return x, ts
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a37a5c507ca1d75d093895c840e3b225ca41076d017bbafd0be3850ccf97267
3
+ size 800875480
modeling_wavjepa_nat.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+
3
+ from .model import WavJEPANat
4
+ from .configuration_wavjepa_nat import WavJEPANatConfig
5
+ from .audio_extractor import ConvChannelFeatureExtractor
6
+ import torch
7
+ from typing import Union
8
+
9
class WavJEPANatModel(PreTrainedModel):
    """Hugging Face wrapper around :class:`WavJEPANat` for AutoModel loading."""

    config_class = WavJEPANatConfig

    def __init__(self, config):
        super().__init__(config)

        extractor_cfg = config.extractor_config
        # SECURITY NOTE: conv_layers_spec is stored in the config as a Python
        # expression string (e.g. "[(512, 10, 5)] + [(512, 3, 2)] * 4 + ...")
        # and eval()-ed here.  eval() on a downloaded config is arbitrary code
        # execution — only load trusted checkpoints.  ast.literal_eval cannot
        # be substituted because the spec uses `+` / `*` operators.
        conv_layers_spec = eval(extractor_cfg['conv_layers_spec'])

        self.model = WavJEPANat(
            feature_extractor = ConvChannelFeatureExtractor(
                conv_layers_spec = conv_layers_spec,
                in_channels = extractor_cfg['in_channels'],
                dropout = extractor_cfg['dropout'],
                mode = extractor_cfg['mode'],
                conv_bias = extractor_cfg['conv_bias'],
                depthwise = extractor_cfg['depthwise'],
                share_weights_over_channels = extractor_cfg['share_weights_over_channels']
            ),
            transformer_encoder_layers_cfg = config.encoder_layers_cfg,
            transformer_encoder_cfg = config.encoder_cfg,
            transformer_decoder_layers_cfg = config.decoder_layers_cfg,
            transformer_decoder_cfg = config.decoder_cfg,
            size = config.model_size,
        )

    def forward(self, tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (embeddings, timestamps_ms) for a batch of waveforms.

        Fixed: the original annotation ``Union[Tensor, Tensor]`` collapses to
        a single ``Tensor``; the wrapped call returns a 2-tuple.
        """
        return self.model.get_audio_representation(tensor)
34
+
pos_embed.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+
11
+ # https://github.com/facebookresearch/AudioMAE/blob/main/util/pos_embed.py
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ # --------------------------------------------------------
17
+ # 2D sine-cosine position embedding
18
+ # References:
19
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
20
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
21
+ # --------------------------------------------------------
22
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token_num):
    """
    grid_size: int of the grid height and width, or an (H, W) pair
    cls_token_num: number of zero rows prepended for class tokens
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [cls_token_num+grid_size*grid_size, embed_dim]
    """
    # BUG FIX: the original used `grid_size is int`, an identity comparison
    # against the type object that is always False, so integer inputs fell
    # through to the tuple branch and failed on grid_size[0].
    if isinstance(grid_size, int):
        gH = grid_size
        gW = grid_size
    else:
        gH = grid_size[0]
        gW = grid_size[1]
    grid_h = np.arange(gH, dtype=np.float64)
    grid_w = np.arange(gW, dtype=np.float64)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, gH, gW])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    # Prepend one zero row per class token.
    for _ in range(cls_token_num):
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
44
+
45
+
46
def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """
    Build a 2D sin-cos positional embedding for a rectangular (H, W) grid.

    Returns [H*W, embed_dim], or [1 + H*W, embed_dim] with a zero row
    prepended when ``cls_token`` is True.
    """
    height, width = grid_size[0], grid_size[1]
    ys = np.arange(height, dtype=np.float64)
    xs = np.arange(width, dtype=np.float64)
    # meshgrid with w first, matching the non-flexible variant.
    grid = np.stack(np.meshgrid(xs, ys), axis=0).reshape([2, 1, height, width])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
62
+
63
+
64
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode grid[0] (h) on the first half of the features, grid[1] (w) on the second."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
73
+
74
+
75
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded, any shape; flattened to (M,)
    out: (M, embed_dim) — sin terms on the first half, cos on the second
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder: 1 / 10000^(i / (D/2)), i = 0..D/2-1.
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / half)  # (D/2,)
    angles = np.outer(pos.reshape(-1), freqs)                          # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)    # (M, D)
94
+
95
+
96
def get_1d_sincos_pos_embed(embed_dim, length):
    """
    Create 1D sinusoidal positional embeddings for positions 0..length-1.

    Args:
        embed_dim: embedding dimension (must be even)
        length: sequence length

    Returns:
        pos_embed: [length, embed_dim] — sin terms first, cos terms second
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Same frequency ladder as the grid variant: 1 / 10000^(i / (D/2)).
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / half)   # (D/2,)
    angles = np.outer(np.arange(length, dtype=np.float64), freqs)       # (length, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)     # (length, D)
121
+
122
def get_binaural_pos_embed(embed_dim, time_steps=100):
    """
    Create positional embeddings for binaural audio.
    Same time encoding, different channel encoding.

    Args:
        embed_dim: embedding dimension (must be even)
        time_steps: number of time steps per channel

    Returns:
        pos_embed: [2*time_steps, embed_dim] - for concatenated L+R channels
        (left-channel rows first, then right-channel rows)
    """
    assert embed_dim % 2 == 0

    # Time dimension encoding (same for both channels), on the first half
    # of the feature dimension.
    time_embed = get_1d_sincos_pos_embed(embed_dim // 2, time_steps)

    # Channel dimension encoding (different for L and R), on the second half.
    channel_embed_left = np.zeros((time_steps, embed_dim // 2))  # Left channel = 0
    # Right tag = the position-0 row of a length-1 sincos table, i.e.
    # sin(0)=0 on its first half and cos(0)=1 on its second half.
    channel_embed_right = get_1d_sincos_pos_embed(embed_dim // 2, 1)
    channel_embed_right = np.tile(channel_embed_right, (time_steps, 1))

    # Combine time and channel embeddings
    left_pos_embed = np.concatenate([time_embed, channel_embed_left], axis=1)
    right_pos_embed = np.concatenate([time_embed, channel_embed_right], axis=1)

    # Concatenate left and right channel embeddings along the token axis
    binaural_pos_embed = np.concatenate([left_pos_embed, right_pos_embed], axis=0)

    return binaural_pos_embed
152
+
153
+ # --------------------------------------------------------
154
+ # Interpolate position embeddings for high-resolution
155
+ # References:
156
+ # DeiT: https://github.com/facebookresearch/deit
157
+ # --------------------------------------------------------
158
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a square 2D pos-embed from a checkpoint to match ``model``.

    Mutates ``checkpoint_model["pos_embed"]`` in place.  Extra (cls/dist)
    tokens are carried over unchanged; only the grid tokens are bicubically
    interpolated.  Assumes both grids are square (sizes taken as sqrt of the
    patch counts).
    """
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # (1, H, W, C) -> (1, C, H, W) for F.interpolate
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            # back to (1, H*W, C) token layout
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed
189
+
190
+
191
def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):
    """Interpolate a checkpoint's 2D pos-embed between explicit (H, W) sizes.

    Unlike :func:`interpolate_pos_embed`, the source and target grid sizes are
    passed in rather than inferred from the patch count — image -> audio grids
    are generally not square.  Mutates ``checkpoint_model["pos_embed"]`` in
    place.
    """
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size[0], orig_size[1], new_size[0], new_size[1])
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # (1, H, W, C) -> (1, C, H, W) for F.interpolate
            pos_tokens = pos_tokens.reshape(
                -1, orig_size[0], orig_size[1], embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size[0], new_size[1]),
                mode="bicubic",
                align_corners=False,
            )
            # back to (1, H*W, C) token layout
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed
222
+
223
+
224
def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):
    """Crop a checkpoint's pos-embed along the time axis to a smaller grid.

    Despite the name, no interpolation happens: the grid tokens are reshaped
    to (H, W, C) and truncated to ``new_size[1]`` along the width/time axis —
    assumes only the time dimension differs and shrinks.  The cls token
    (row 0) is preserved.  Mutates ``checkpoint_model["pos_embed"]`` in place.
    """
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size[0], orig_size[1], new_size[0], new_size[1])
            )
            # split off the cls token; only the grid tokens are cropped
            cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)
            pos_tokens = pos_embed_checkpoint[:, 1:, :]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size[0], orig_size[1], embedding_size
            )
            pos_tokens = pos_tokens[:, :, : new_size[1], :]  # assume only time diff
            pos_tokens = pos_tokens.flatten(1, 2)
            new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed
248
+
249
+
250
def interpolate_patch_embed_audio(
    model,
    checkpoint_model,
    orig_channel,
    new_channel=1,
    kernel_size=(16, 16),
    stride=(16, 16),
    padding=(0, 0),
):
    """Collapse a multi-channel patch-embed conv weight to one input channel.

    Sums the checkpoint's ``patch_embed.proj.weight`` over its input-channel
    axis (e.g. aggregating RGB -> mono for audio spectrograms) and writes the
    result back in place.  ``kernel_size``, ``stride`` and ``padding`` are
    accepted but never read in this implementation.
    """
    if orig_channel != new_channel:
        if "patch_embed.proj.weight" in checkpoint_model:
            # aggregate 3 channels in rgb ckpt to 1 channel for audio
            new_proj_weight = torch.nn.Parameter(
                torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze(
                    1
                )
            )
            checkpoint_model["patch_embed.proj.weight"] = new_proj_weight
types.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict
2
+ from torch import nn
3
+
4
+
5
class TransformerLayerCFG(TypedDict):
    """Keyword-argument bundle for constructing a ``nn.TransformerEncoderLayer``.

    NOTE(review): attaching a classmethod to a TypedDict is non-standard;
    it works at runtime (the class call builds a plain dict) but static
    type checkers will flag it — consider moving ``create`` to a
    module-level factory function.
    """

    d_model : int           # embedding dimension of the layer
    nhead : int             # number of attention heads
    batch_first : bool      # True -> tensors are (batch, seq, feature)
    norm_first : bool       # True -> pre-norm transformer block
    bias : bool             # bias terms in linear/norm sublayers
    dim_feedforward : int   # hidden width of the feed-forward sublayer
    dropout : float         # dropout probability
    activation : nn.Module  # activation module for the feed-forward sublayer
    layer_norm_eps : float  # epsilon for the layer-norm denominators

    @classmethod
    def create(cls,
               d_model : int = 768,
               nhead : int = 12,
               batch_first : bool = True,
               norm_first : bool = False,
               bias : bool = True,
               mlp_ratio : float = 4.0,
               dropout : float = 0.0,
               # NOTE(review): a single nn.GELU() instance is shared across
               # calls (mutable default); harmless here since GELU is stateless.
               activation : nn.Module = nn.GELU(),
               layer_norm_eps : float = 1e-6) -> 'TransformerLayerCFG':
        """Build a config dict, deriving ``dim_feedforward`` as ``int(d_model * mlp_ratio)``."""
        return TransformerLayerCFG(d_model = d_model,
                                   nhead = nhead,
                                   batch_first = batch_first,
                                   norm_first = norm_first,
                                   bias = bias,
                                   dim_feedforward = int(d_model * mlp_ratio),
                                   dropout = dropout,
                                   activation = activation,
                                   layer_norm_eps = layer_norm_eps)
36
+
37
+
38
+ # Norm needs to be defined by the user!
39
class TransformerEncoderCFG(TypedDict):
    """Keyword-argument bundle for constructing a ``nn.TransformerEncoder``.

    The final ``norm`` module is deliberately not part of this config and
    must be supplied by the user at construction time.

    NOTE(review): as with ``TransformerLayerCFG``, a classmethod on a
    TypedDict is non-standard but functional at runtime.
    """

    num_layers : int              # number of stacked encoder layers
    enable_nested_tensor: bool    # pass-through to nn.TransformerEncoder
    mask_check: bool              # pass-through to nn.TransformerEncoder

    @classmethod
    def create(cls,
               num_layers : int = 12,
               enable_nested_tensor: bool = False,
               mask_check: bool = True) -> 'TransformerEncoderCFG':
        """Build an encoder config dict with sensible defaults."""
        return TransformerEncoderCFG(num_layers=num_layers,
                                     enable_nested_tensor = enable_nested_tensor,
                                     mask_check = mask_check)
utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def get_timestamps(sample_rate, B, input_audio_len, x):
    """Build per-frame start timestamps, in milliseconds, for a framed signal.

    Args:
        sample_rate: sampling rate of the raw audio, in Hz.
        B: batch size; the single timestamp row is replicated ``B`` times.
        input_audio_len: length of the raw audio, in samples.
        x: framed representation; only ``x.shape[2]`` (frame count) is read.

    Returns:
        Float tensor of shape ``(B, x.shape[2])`` where entry ``i`` is
        ``i * step`` with ``step`` the frame duration in milliseconds.
    """
    num_frames = x.shape[2]
    duration_ms = input_audio_len / sample_rate * 1000  # total audio length, ms
    step = duration_ms / num_frames  # ms covered by one frame
    # Vectorized: torch.arange replaces the former per-index Python loop.
    ts = torch.arange(num_frames, dtype=torch.float32).unsqueeze(0) * step
    return ts.repeat(B, 1)
11
+
12
def normalize(audio):
    """Standardize ``audio`` to zero mean / unit variance over its last two dims.

    Statistics are computed per leading-batch element; a small epsilon keeps
    the division stable when the signal is (near-)constant.
    """
    mu = audio.mean(dim=(-2, -1), keepdim=True)
    sigma = audio.std(dim=(-2, -1), keepdim=True)
    return (audio - mu) / (sigma + 1e-5)
17
+
18
def calculate_padding_mask(pad_frames, total_frames, sr, output_steps, process_seconds, device, B):
    """Build a boolean mask marking the padded tail of the model's output steps.

    The audio is processed in ``process_seconds``-long chunks, each producing
    ``output_steps`` output frames; the last ``pad_frames`` raw samples are
    padding and their corresponding output steps are flagged True.

    Args:
        pad_frames: number of padded raw samples at the end of the audio.
        total_frames: total raw-sample length of the (padded) audio.
        sr: sampling rate in Hz.
        output_steps: output frames produced per chunk.
        process_seconds: chunk duration in seconds.
        device: device for the returned mask.
        B: batch size.

    Returns:
        Tuple of ``(mask, valid_steps)`` where ``mask`` is a ``(B,
        num_chunks * output_steps)`` bool tensor (True = padding) and
        ``valid_steps`` is the count of non-padded output steps.
    """
    # Number of whole process_seconds-long chunks in the audio.
    num_chunks = int((total_frames / sr) / process_seconds)
    total_output_steps = output_steps * num_chunks

    # Convert padded seconds into output steps via the output frame rate.
    output_rate = int(output_steps / process_seconds)
    padded_steps = int((pad_frames / sr) * output_rate)
    valid_steps = total_output_steps - padded_steps

    mask = torch.zeros((B, total_output_steps), dtype=torch.bool, device=device)
    mask[..., valid_steps:] = True
    return mask, valid_steps