# bach-or-bot/src/spectttra/tokenizer.py
# Author: Acelle Krislette Rosales
# Initial commit: Added application code (fc7b4a9)
import math
import torch
import torch.nn as nn
from .embedding import (
SinusoidPositionalEncoding,
LearnedPositionalEncoding,
)
class STTokenizer(nn.Module):
    """
    Spectro-temporal tokenizer turning a mel-spectrogram into a token sequence.

    The time axis and the frequency axis are tokenized independently by two
    ``Tokenizer1D`` modules, and the two resulting token sequences are
    concatenated along the token dimension.

    Args:
        input_spec_dim (int): Number of frequency bins in the spectrogram.
        input_temp_dim (int): Number of time frames in the spectrogram.
        t_clip (int): Clip size (window == stride) for temporal tokenization.
        f_clip (int): Clip size (window == stride) for spectral tokenization.
        embed_dim (int): Dimensionality of each token embedding.
        pre_norm (bool, optional): Apply LayerNorm after positional encoding
            inside each tokenizer. Defaults to False.
        pe_learnable (bool, optional): Use learned (instead of sinusoidal)
            positional encodings. Defaults to False.
    """

    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        t_clip,
        f_clip,
        embed_dim,
        pre_norm=False,
        pe_learnable=False,
    ):
        super().__init__()
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.embed_dim = embed_dim
        self.pre_norm = pre_norm
        self.pe_learnable = pe_learnable

        # Token counts follow the standard strided-window formula:
        # floor((length - clip) / clip + 1), e.g. floor((1280 - 5) / 5 + 1) = 256
        # on the temporal axis and floor((128 - 3) / 3 + 1) = 42 on the spectral axis.
        self.num_temporal_tokens = self._count_clips(input_temp_dim, t_clip)
        self.num_spectral_tokens = self._count_clips(input_spec_dim, f_clip)
        self.num_tokens = self.num_temporal_tokens + self.num_spectral_tokens

        # One 1D tokenizer per axis. The temporal tokenizer convolves over
        # time (channels = frequency bins); the spectral tokenizer convolves
        # over frequency (channels = time frames).
        self.temporal_tokenizer = Tokenizer1D(
            input_spec_dim,
            embed_dim,
            clip_size=t_clip,
            num_clips=self.num_temporal_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )
        self.spectral_tokenizer = Tokenizer1D(
            input_temp_dim,
            embed_dim,
            clip_size=f_clip,
            num_clips=self.num_spectral_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )

    @staticmethod
    def _count_clips(length, clip):
        """Number of non-overlapping clips of size ``clip`` fitting in ``length``."""
        return math.floor((length - clip) / clip + 1)

    def forward(self, x):
        """
        Tokenize a batch of mel-spectrograms.

        Args:
            x (torch.Tensor): Mel-spectrogram of shape (batch, freq_bins, time_frames).

        Returns:
            torch.Tensor: Tokens of shape
            (batch, num_temporal_tokens + num_spectral_tokens, embed_dim),
            temporal tokens first.
        """
        # Time-axis tokens: convolve over time with freq bins as channels.
        time_tokens = self.temporal_tokenizer(x)  # (B, T/t_clip, dim)

        # Frequency-axis tokens: swap axes so time frames become channels.
        freq_tokens = self.spectral_tokenizer(x.transpose(1, 2))  # (B, F/f_clip, dim)

        # Temporal tokens precede spectral tokens in the output sequence.
        return torch.cat((time_tokens, freq_tokens), dim=1)
class Tokenizer1D(nn.Module):
    """
    One-dimensional tokenizer for a single axis (temporal or spectral).

    Pipeline: strided 1D convolution (window == stride == ``clip_size``),
    GELU activation, positional encoding, then optional LayerNorm.

    Args:
        input_dim (int): Channel count of the input (frequency bins for the
            temporal axis, time frames for the spectral axis).
        token_dim (int): Output token embedding dimension.
        clip_size (int): Window and stride size for tokenization.
        num_clips (int): Number of tokens produced (needed by the learned
            positional encoding).
        pre_norm (bool, optional): Apply LayerNorm after positional encoding
            and drop the conv bias (CLIP-style). Defaults to False.
        pe_learnable (bool, optional): Use learned positional encodings
            instead of sinusoidal ones. Defaults to False.
    """

    def __init__(
        self,
        input_dim,
        token_dim,
        clip_size,
        num_clips,
        pre_norm=False,
        pe_learnable=False,
    ):
        super().__init__()
        # Window equals stride, so clips are non-overlapping. The bias is
        # dropped when a pre-norm follows (as in CLIP-style stems).
        self.conv1d = nn.Conv1d(
            input_dim,
            token_dim,
            clip_size,
            stride=clip_size,
            bias=not pre_norm,
        )
        self.act = nn.GELU()
        if pe_learnable:
            self.pos_encoder = LearnedPositionalEncoding(token_dim, num_clips)
        else:
            self.pos_encoder = SinusoidPositionalEncoding(token_dim)
        if pre_norm:
            self.norm_pre = nn.LayerNorm(token_dim, eps=1e-6)
        else:
            self.norm_pre = nn.Identity()

    def forward(self, x):
        """
        Tokenize one axis of the input.

        Args:
            x (torch.Tensor): Input of shape (batch, input_dim, length).

        Returns:
            torch.Tensor: Tokens of shape (batch, num_clips, token_dim).
        """
        tokens = self.act(self.conv1d(x))   # (B, input_dim, L) -> (B, token_dim, L/clip)
        tokens = tokens.transpose(1, 2)     # -> (B, L/clip, token_dim)
        tokens = self.pos_encoder(tokens)   # add position information
        return self.norm_pre(tokens)        # LayerNorm when pre_norm, else identity