import math

import torch
import torch.nn as nn

from .embedding import (
    SinusoidPositionalEncoding,
    LearnedPositionalEncoding,
)


class STTokenizer(nn.Module):
    """
    Turn a mel-spectrogram into one sequence of spectro-temporal tokens.

    Two independent 1D tokenizers are applied: one slides along the time
    axis, the other along the frequency axis. Their outputs are concatenated
    along the token dimension, temporal tokens first.

    Args:
        input_spec_dim (int): Number of frequency bins in the spectrogram.
        input_temp_dim (int): Number of time frames in the spectrogram.
        t_clip (int): Temporal clip size (window and stride along time).
        f_clip (int): Spectral clip size (window and stride along frequency).
        embed_dim (int): Dimensionality of each token embedding.
        pre_norm (bool, optional): Whether each tokenizer applies a LayerNorm
            to its tokens. Defaults to False.
        pe_learnable (bool, optional): Whether to use learnable positional
            encodings instead of sinusoidal ones. Defaults to False.
    """

    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        t_clip,
        f_clip,
        embed_dim,
        pre_norm=False,
        pe_learnable=False,
    ):
        super().__init__()

        # Keep the configuration around for introspection.
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.embed_dim = embed_dim
        self.pre_norm = pre_norm
        self.pe_learnable = pe_learnable

        # Token counts per axis, e.g. 1280 frames / clip 5 -> 256 tokens,
        # 128 bins / clip 3 -> 42 tokens.
        self.num_temporal_tokens = self._count_clips(input_temp_dim, t_clip)
        self.num_spectral_tokens = self._count_clips(input_spec_dim, f_clip)
        self.num_tokens = self.num_temporal_tokens + self.num_spectral_tokens

        # Options shared by both per-axis tokenizers.
        shared_options = dict(pre_norm=pre_norm, pe_learnable=pe_learnable)

        # The temporal tokenizer treats frequency bins as channels and
        # strides over time.
        self.temporal_tokenizer = Tokenizer1D(
            input_dim=input_spec_dim,
            token_dim=embed_dim,
            clip_size=t_clip,
            num_clips=self.num_temporal_tokens,
            **shared_options,
        )
        # The spectral tokenizer treats time frames as channels and strides
        # over frequency.
        self.spectral_tokenizer = Tokenizer1D(
            input_dim=input_temp_dim,
            token_dim=embed_dim,
            clip_size=f_clip,
            num_clips=self.num_spectral_tokens,
            **shared_options,
        )

    @staticmethod
    def _count_clips(length, clip):
        """Return floor((length - clip) / clip + 1).

        This is the output length of an unpadded 1D convolution whose kernel
        size and stride both equal ``clip``.
        """
        return math.floor((length - clip) / clip + 1)

    def forward(self, x):
        """
        Tokenize a mel-spectrogram along both of its axes.

        Args:
            x (torch.Tensor): Input mel-spectrogram of shape
                (batch_size, freq_bins, time_frames).

        Returns:
            torch.Tensor: Spectro-temporal tokens of shape
                (batch_size, num_temporal_tokens + num_spectral_tokens, embed_dim).
        """
        # (B, F, T) -> (B, T/t, dim)
        time_tokens = self.temporal_tokenizer(x)

        # Swap the last two axes so frequency becomes the strided axis:
        # (B, F, T) -> (B, T, F) -> (B, F/f, dim)
        freq_tokens = self.spectral_tokenizer(x.permute(0, 2, 1))

        # Stack along the token axis: (B, T/t + F/f, dim)
        return torch.cat([time_tokens, freq_tokens], dim=1)


class Tokenizer1D(nn.Module):
    """
    Tokenizer for a single axis (temporal or spectral) of a spectrogram.

    A 1D convolution whose kernel size and stride both equal the clip size
    produces one embedding per clip. The embeddings then pass through GELU,
    a positional encoder, and an optional LayerNorm.

    Args:
        input_dim (int): Number of input channels (frequency bins for the
            temporal tokenizer, time frames for the spectral one).
        token_dim (int): Output token embedding dimension.
        clip_size (int): Window/stride size for tokenization.
        num_clips (int): Number of tokens produced; only passed to the
            learnable positional encoding.
        pre_norm (bool, optional): Whether to apply a LayerNorm to the
            tokens. Defaults to False.
        pe_learnable (bool, optional): Whether to use learnable positional
            encodings. Defaults to False.
    """

    def __init__(
        self,
        input_dim,
        token_dim,
        clip_size,
        num_clips,
        pre_norm=False,
        pe_learnable=False,
    ):
        super().__init__()

        self.conv1d = nn.Conv1d(
            in_channels=input_dim,
            out_channels=token_dim,
            kernel_size=clip_size,
            stride=clip_size,
            # Disable bias if pre-norm is used (e.g. CLIP)
            bias=not pre_norm,
        )
        self.act = nn.GELU()

        if pe_learnable:
            self.pos_encoder = LearnedPositionalEncoding(token_dim, num_clips)
        else:
            self.pos_encoder = SinusoidPositionalEncoding(token_dim)

        if pre_norm:
            self.norm_pre = nn.LayerNorm(token_dim, eps=1e-6)
        else:
            self.norm_pre = nn.Identity()

    def forward(self, x):
        """
        Convert an input tensor into a sequence of position-encoded tokens.

        Args:
            x (torch.Tensor): Input tensor of shape
                (batch_size, input_dim, length).

        Returns:
            torch.Tensor: Sequence of tokens with shape
                (batch_size, num_clips, token_dim).
        """
        # (B, input_dim, length) -> (B, token_dim, num_clips)
        clip_features = self.act(self.conv1d(x))

        # Put the token axis before the embedding axis:
        # (B, token_dim, num_clips) -> (B, num_clips, token_dim)
        tokens = clip_features.transpose(1, 2)

        # Add position embeddings, then normalize (identity unless pre_norm).
        return self.norm_pre(self.pos_encoder(tokens))