import copy
from typing import Any, Optional

import numpy as np
import torch
from torch import nn

from .audio_extractor import Extractor
from .pos_embed import (
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
    get_binaural_pos_embed,
)
from .types import TransformerEncoderCFG, TransformerLayerCFG
from .utils import calculate_padding_mask, get_timestamps, normalize


class WavJEPA(nn.Module):
    """Joint-Embedding Predictive Architecture (JEPA) for raw audio.

    This implementation is inspired by:

    * I-JEPA: http://arxiv.org/abs/2301.08243
    * data2vec 2.0: http://arxiv.org/abs/2212.07525
    """

    teacher_encoder: nn.Module
    sample_rate: int = 16000
    process_audio_seconds: float = 2.01
    in_channels: int = 1
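    # With these defaults a segment spans int(16000 * 2.01) = 32160 samples,
    # which the feature extractor is expected to map to total_patches = 200
    # patch embeddings (both values are fixed in __init__ below).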

    def __init__(
        self,
        feature_extractor: Extractor,
        transformer_encoder_layers_cfg: TransformerLayerCFG,
        transformer_encoder_cfg: TransformerEncoderCFG,
        transformer_decoder_layers_cfg: TransformerLayerCFG,
        transformer_decoder_cfg: TransformerEncoderCFG,
        size: str = "base",
        **kwargs: Any,
    ):
        super().__init__(**kwargs)

        self.is_spectrogram = False
        self.target_length = int(self.sample_rate * self.process_audio_seconds)
        self.extract_audio = feature_extractor
        self.total_patches = 200
        self.feature_norms: nn.Module = nn.LayerNorm(self.extract_audio.embedding_dim)

        self.n_encoder_heads = transformer_encoder_layers_cfg["nhead"]
        self.encoder_embedding_dim = transformer_encoder_layers_cfg["d_model"]
        self.n_decoder_heads = transformer_decoder_layers_cfg["nhead"]
        self.decoder_embedding_dim = transformer_decoder_layers_cfg["d_model"]

        encoder_layer = nn.TransformerEncoderLayer(**transformer_encoder_layers_cfg, activation=nn.GELU())
        self.encoder = nn.TransformerEncoder(
            encoder_layer, norm=nn.LayerNorm(self.encoder_embedding_dim), **transformer_encoder_cfg
        )
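
        # Features leaving the extractor are projected to the encoder width
        # only when the two dimensionalities differ; otherwise no mapper is
        # created and they pass through unchanged.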
        self.post_extraction_mapper: Optional[nn.Module] = (
            nn.Linear(feature_extractor.embedding_dim, self.encoder_embedding_dim)
            if feature_extractor.embedding_dim != self.encoder_embedding_dim
            else None
        )

        decoder_layer = nn.TransformerEncoderLayer(**transformer_decoder_layers_cfg, activation=nn.GELU())
        self.decoder = nn.TransformerEncoder(
            decoder_layer, norm=nn.LayerNorm(self.decoder_embedding_dim), **transformer_decoder_cfg
        )
        self.decoder_to_encoder_mapper = nn.Linear(self.decoder_embedding_dim, self.encoder_embedding_dim, bias=True)
        self.encoder_to_decoder_mapper = nn.Linear(self.encoder_embedding_dim, self.decoder_embedding_dim)
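
        # Learnable embedding that stands in for masked (target) patches at
        # the decoder input.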
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_embedding_dim))
        self.pos_encoding_encoder = self._get_pos_embed_params(self.encoder_embedding_dim)
        self.pos_encoding_decoder = self._get_pos_embed_params(self.decoder_embedding_dim)
        self.output_steps = self.extract_audio.total_patches(self.target_length) // self.in_channels

        self._init_teacher()

    def _get_pos_embed_params(self, embedding_dim: int) -> nn.Parameter:
        """Build and return the frozen sinusoidal positional-embedding parameter."""
        pos_embed = nn.Parameter(
            torch.zeros(1, self.total_patches, embedding_dim),
            requires_grad=False,
        )
        if self.is_spectrogram:
            pos_embed_data = get_2d_sincos_pos_embed(
                embedding_dim, self.extract_audio.grid_size, cls_token_num=0
            )
        elif self.in_channels == 2 and self.total_patches == 400:
            pos_embed_data = get_binaural_pos_embed(
                embedding_dim, time_steps=self.total_patches // self.in_channels
            )
        elif self.in_channels in (1, 2) and self.total_patches == 200:
            positions = np.arange(self.total_patches, dtype=np.float64)
            pos_embed_data = get_1d_sincos_pos_embed_from_grid(embedding_dim, positions)
        else:
            raise NotImplementedError(
                f"No positional embedding for in_channels={self.in_channels} "
                f"with total_patches={self.total_patches}"
            )
        pos_embed.data.copy_(torch.from_numpy(pos_embed_data).float().unsqueeze(0))
        return pos_embed

    def _init_teacher(self) -> None:
        """Initialize the teacher as a frozen copy of the student encoder."""
        self.teacher_encoder = copy.deepcopy(self.encoder)
        self.teacher_encoder.requires_grad_(False)
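        # Only initialization is shown here; in the papers cited above the
        # teacher is kept up to date as an EMA of the student weights.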

    @torch.inference_mode()
    def _get_segment_representation(self, audio: torch.Tensor, padding_mask: torch.Tensor):
        """Encode one fixed-length audio segment into contextual features."""
        local_features = self.extract_audio(audio)
        local_features = self.feature_norms(local_features)
        if self.post_extraction_mapper is not None:
            local_features = self.post_extraction_mapper(local_features)
        local_features = local_features + self.pos_encoding_encoder

        contextual_features = self.encoder(local_features, src_key_padding_mask=padding_mask)
        return contextual_features

    @torch.inference_mode()
    def get_audio_representation(self, audio: torch.Tensor):
        """Embed a batch of waveforms; returns embeddings and their timestamps."""
        if audio.ndim != 3:
            raise ValueError(
                "audio input tensor must be 3D with shape (n_sounds, n_channels, num_samples)"
            )
        B = audio.shape[0]
        input_audio_len = audio.shape[-1]

        # Right-pad so the length is a whole multiple of target_length; this
        # is 0 when the input already divides evenly, so no empty segment is
        # appended in that case.
        pad_frames = -input_audio_len % self.target_length
        if pad_frames > 0:
            audio = torch.nn.functional.pad(audio, (0, pad_frames), mode="constant")

        embeddings = []
        padding_mask, cut_off = calculate_padding_mask(
            pad_frames=pad_frames,
            total_frames=audio.shape[-1],
            sr=self.sample_rate,
            output_steps=self.total_patches,
            process_seconds=self.target_length // self.sample_rate,
            device=audio.device,
            B=B,
        )
        mask_idx = 0
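
        # Encode the padded audio one target_length segment at a time,
        # sliding the matching window of the padding mask along with it.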
        for i in range(audio.shape[-1] // self.target_length):
            segment = audio[..., i * self.target_length : (i + 1) * self.target_length]
            mask = padding_mask[..., mask_idx : mask_idx + self.output_steps]
            embedding = self._get_segment_representation(normalize(segment), mask)
            mask_idx = mask_idx + self.output_steps
            embeddings.append(embedding)

        # Concatenate segment embeddings along the time axis and drop the
        # frames that correspond to padding.
        x = torch.cat(embeddings, dim=1)
        x = x[:, :cut_off, :]
        ts = get_timestamps(self.sample_rate, B, input_audio_len, x)
        return x, ts
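

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model. The toy extractor
    # below is a hypothetical stand-in for this package's real `Extractor`;
    # it only mimics the interface WavJEPA relies on (an `embedding_dim`
    # attribute, a `total_patches()` method, and a forward pass returning
    # (batch, total_patches, embedding_dim)). The config dicts assume that
    # TransformerLayerCFG / TransformerEncoderCFG are mappings of keyword
    # arguments for nn.TransformerEncoderLayer / nn.TransformerEncoder.
    class _ToyExtractor(nn.Module):
        embedding_dim = 64

        def total_patches(self, num_samples: int) -> int:
            return 200

        def forward(self, audio: torch.Tensor) -> torch.Tensor:
            # (B, C, num_samples) -> (B, 200, embedding_dim), random stand-in.
            return torch.randn(audio.shape[0], 200, self.embedding_dim)

    layer_cfg = {"d_model": 64, "nhead": 4, "dim_feedforward": 128, "batch_first": True}
    encoder_cfg = {"num_layers": 2}
    model = WavJEPA(
        feature_extractor=_ToyExtractor(),
        transformer_encoder_layers_cfg=layer_cfg,
        transformer_encoder_cfg=encoder_cfg,
        transformer_decoder_layers_cfg=layer_cfg,
        transformer_decoder_cfg=encoder_cfg,
    )
    waveforms = torch.randn(2, 1, 16000 * 3)  # two 3-second mono clips
    features, timestamps = model.get_audio_representation(waveforms)
    print(features.shape)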