# NOTE(review): removed web-viewer extraction residue (status text, file size,
# commit hash, and a rendered line-number gutter) that preceded the module code.
import math
import torch
import torch.nn as nn
from .embedding import (
SinusoidPositionalEncoding,
LearnedPositionalEncoding,
)
class STTokenizer(nn.Module):
    """
    Spectro-temporal tokenizer that converts mel-spectrograms into a sequence of tokens.

    Both the temporal and the spectral dimension are tokenized independently
    (each by a strided 1D convolution over the other axis) and the two token
    sequences are concatenated to form the spectro-temporal token sequence.

    Args:
        input_spec_dim (int): Number of frequency bins in the spectrogram.
        input_temp_dim (int): Number of time frames in the spectrogram.
        t_clip (int): Temporal clip size (kernel/stride for temporal tokenization).
        f_clip (int): Spectral clip size (kernel/stride for spectral tokenization).
        embed_dim (int): Dimensionality of each token embedding.
        pre_norm (bool, optional): Whether to apply pre-normalization with LayerNorm. Defaults to False.
        pe_learnable (bool, optional): Whether to use learnable positional encodings. Defaults to False.
    """

    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        t_clip,
        f_clip,
        embed_dim,
        pre_norm=False,
        pe_learnable=False,
    ):
        super().__init__()
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.embed_dim = embed_dim
        self.pre_norm = pre_norm
        self.pe_learnable = pe_learnable

        # Token counts follow the standard conv output-length formula with
        # kernel == stride == clip: floor((L - clip) / clip) + 1.
        # Integer floor division is exact here (all operands are ints).
        self.num_temporal_tokens = (
            (input_temp_dim - t_clip) // t_clip + 1
        )  # e.g., (1280 - 5) // 5 + 1 = 256
        self.num_spectral_tokens = (
            (input_spec_dim - f_clip) // f_clip + 1
        )  # e.g., (128 - 3) // 3 + 1 = 42
        self.num_tokens = (
            self.num_temporal_tokens + self.num_spectral_tokens
        )

        # Temporal tokenizer: convolves across time, treating frequency bins
        # as input channels.
        self.temporal_tokenizer = Tokenizer1D(
            input_spec_dim,
            embed_dim,
            clip_size=t_clip,
            num_clips=self.num_temporal_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )
        # Spectral tokenizer: convolves across frequency, treating time frames
        # as input channels (input is permuted in forward()).
        self.spectral_tokenizer = Tokenizer1D(
            input_temp_dim,
            embed_dim,
            clip_size=f_clip,
            num_clips=self.num_spectral_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )

    def forward(self, x):
        """
        Forward pass of the spectro-temporal tokenizer.

        Args:
            x (torch.Tensor): Input mel-spectrogram of shape (batch_size, freq_bins, time_frames).

        Returns:
            torch.Tensor: Spectro-temporal tokens of shape
                (batch_size, num_temporal_tokens + num_spectral_tokens, embed_dim).
        """
        # Temporal tokenization: frequency bins act as channels.
        temporal_input = x  # shape: (B, F, T)
        temporal_tokens = self.temporal_tokenizer(
            temporal_input
        )  # shape: (B, T/t_clip, embed_dim)

        # Spectral tokenization: swap axes so time frames act as channels.
        spectral_input = x.permute(0, 2, 1)  # shape: (B, T, F)
        spectral_tokens = self.spectral_tokenizer(
            spectral_input
        )  # shape: (B, F/f_clip, embed_dim)

        # Concatenate along the token (sequence) dimension.
        spectro_temporal_tokens = torch.cat(
            (temporal_tokens, spectral_tokens), dim=1
        )  # shape: (B, T/t_clip + F/f_clip, embed_dim)
        return spectro_temporal_tokens
class Tokenizer1D(nn.Module):
    """
    One-dimensional tokenizer for either the temporal or the spectral axis.

    Applies a 1D convolution with kernel size and stride both equal to
    ``clip_size`` (non-overlapping windows), followed by GELU activation,
    positional encoding, and an optional pre-LayerNorm.

    Args:
        input_dim (int): Input channel dimension (frequency bins for temporal
            tokenization, time frames for spectral tokenization).
        token_dim (int): Output token embedding dimension.
        clip_size (int): Window/stride size for tokenization.
        num_clips (int): Number of tokens produced (used to size learned
            positional encodings).
        pre_norm (bool, optional): Whether to apply pre-normalization with LayerNorm. Defaults to False.
        pe_learnable (bool, optional): Whether to use learnable positional encodings. Defaults to False.
    """

    def __init__(
        self,
        input_dim,
        token_dim,
        clip_size,
        num_clips,
        pre_norm=False,
        pe_learnable=False,
    ):
        super().__init__()
        self.conv1d = nn.Conv1d(
            input_dim,
            token_dim,
            clip_size,
            stride=clip_size,
            bias=not pre_norm,  # Disable bias if pre-norm is used (e.g. CLIP)
        )
        self.act = nn.GELU()
        self.pos_encoder = (
            SinusoidPositionalEncoding(token_dim)
            if not pe_learnable
            else LearnedPositionalEncoding(token_dim, num_clips)
        )
        self.norm_pre = nn.LayerNorm(token_dim, eps=1e-6) if pre_norm else nn.Identity()

    def forward(self, x):
        """
        Forward pass of the 1D tokenizer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_dim, length).

        Returns:
            torch.Tensor: Sequence of tokens with shape (batch_size, num_clips, token_dim).
        """
        x = self.conv1d(x)  # (B, input_dim, L) -> (B, token_dim, num_clips)
        x = self.act(x)
        x = x.transpose(1, 2)  # (B, token_dim, num_clips) -> (B, num_clips, token_dim)
        x = self.pos_encoder(x)  # Add position embeddings
        x = self.norm_pre(x)
        return x