Spaces:
Sleeping
Sleeping
File size: 8,107 Bytes
75d43d2 fc7b4a9 75d43d2 fc7b4a9 75d43d2 fc7b4a9 75d43d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import torch
import torch.nn as nn
from pathlib import Path
from .transformer import Transformer
from .tokenizer import STTokenizer
from src.spectttra.feature import FeatureExtractor
class SpecTTTra(nn.Module):
    """
    Spectro-Temporal Transformer (SpecTTTra) encoder for audio spectrograms.

    The model is a three-stage pipeline: an ``STTokenizer`` turns the input
    spectrogram into a sequence of temporal and spectral tokens, dropout is
    applied to that sequence, and a ``Transformer`` encodes the tokens.
    """

    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        embed_dim,
        t_clip,
        f_clip,
        num_heads,
        num_layers,
        pre_norm=False,
        pe_learnable=False,
        pos_drop_rate=0.0,
        attn_drop_rate=0.0,
        proj_drop_rate=0.0,
        mlp_ratio=4.0,
    ):
        """
        Store the hyper-parameters and build the three sub-modules.

        Args:
            input_spec_dim (int): Frequency dimension (F) of the input spectrogram.
            input_temp_dim (int): Temporal dimension (T) of the input spectrogram.
            embed_dim (int): Width of every token embedding.
            t_clip (int): Clip size used when tokenizing along time.
            f_clip (int): Clip size used when tokenizing along frequency.
            num_heads (int): Attention heads per transformer layer.
            num_layers (int): Number of transformer layers.
            pre_norm (bool, optional): Forwarded to the tokenizer as ``pre_norm``.
                Defaults to False.
            pe_learnable (bool, optional): Forwarded to the tokenizer; selects
                learnable positional embeddings. Defaults to False.
            pos_drop_rate (float, optional): Dropout probability applied to the
                tokenizer output. Defaults to 0.0.
            attn_drop_rate (float, optional): Passed to the transformer as
                ``attn_drop``. Defaults to 0.0.
            proj_drop_rate (float, optional): Passed to the transformer as
                ``proj_drop``. Defaults to 0.0.
            mlp_ratio (float, optional): Passed to the transformer as
                ``mlp_ratio``. Defaults to 4.0.
        """
        super().__init__()

        # Keep every constructor argument on the instance so the configuration
        # can be inspected after construction.
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.embed_dim = embed_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.num_heads = num_heads
        self.num_layers = num_layers
        # Normalisation between tokenization and the transformer (CLIP-style).
        self.pre_norm = pre_norm
        # Whether positional encodings are learned.
        self.pe_learnable = pe_learnable
        self.pos_drop_rate = pos_drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate
        self.mlp_ratio = mlp_ratio

        # Stage 1: spectrogram -> spectro-temporal tokens.
        self.st_tokenizer = STTokenizer(
            input_spec_dim,
            input_temp_dim,
            t_clip,
            f_clip,
            embed_dim,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )

        # Stage 2: dropout on the token sequence produced by the tokenizer.
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        # Stage 3: transformer encoder over the token sequence.
        self.transformer = Transformer(
            embed_dim,
            num_heads,
            num_layers,
            attn_drop=attn_drop_rate,
            proj_drop=proj_drop_rate,
            mlp_ratio=mlp_ratio,
        )

    def forward(self, x):
        """
        Encode a batch of spectrograms.

        Args:
            x (torch.Tensor): Input spectrogram, either
                - (B, 1, F, T) with an explicit channel axis, or
                - (B, F, T) without one.

        Returns:
            torch.Tensor: Encoded spectro-temporal tokens of shape
                (B, T/t + F/f, embed_dim).
        """
        # For 4-D input, remove axis 1; ``squeeze(1)`` is a no-op unless that
        # axis has size 1.
        if x.dim() == 4:
            x = x.squeeze(1)

        tokens = self.st_tokenizer(x)
        tokens = self.pos_drop(tokens)
        # Result shape: (B, T/t + F/f, dim)
        return self.transformer(tokens)
def build_spectttra_from_cfg(cfg, device, n_frames=3744):
    """
    Construct the SpecTTTra model and its FeatureExtractor from a configuration.

    Args:
        cfg (SimpleNamespace): Configuration object. Attributes read here:
            - cfg.melspec.n_mels: Number of mel frequency bins, used as the
              model's spectral input dimension.
            - cfg.model: Model hyper-parameters (embed_dim, t_clip, f_clip,
              num_heads, num_layers, pre_norm, pe_learnable, pos_drop_rate,
              attn_drop_rate, proj_drop_rate, mlp_ratio).
            The whole ``cfg`` is also passed to ``FeatureExtractor``.
        device (torch.device): Device the feature extractor and model are
            moved to (e.g., 'cpu' or 'cuda').
        n_frames (int, optional): Temporal input dimension the model is built
            for. Defaults to 3744, the value the pretrained checkpoint expects
            (per the original author's note it was taken from the size-mismatch
            RuntimeError raised when loading that checkpoint). Override it only
            when building a model that will not load those weights.

    Returns:
        tuple:
            FeatureExtractor: Feature extraction module on ``device``.
            SpecTTTra: Constructed SpecTTTra model on ``device``.
    """
    feat_ext = FeatureExtractor(cfg).to(device)

    # The pretrained checkpoint only loads into a model with matching input
    # dimensions, so the defaults must stay aligned with it.
    n_mels = cfg.melspec.n_mels  # expected to be 128 for the pretrained weights
    print(f"[INFO] Initializing SpecTTTra with fixed dimensions: n_mels={n_mels}, n_frames={n_frames}")

    model_cfg = cfg.model
    model = SpecTTTra(
        input_spec_dim=n_mels,
        input_temp_dim=n_frames,
        embed_dim=model_cfg.embed_dim,
        t_clip=model_cfg.t_clip,
        f_clip=model_cfg.f_clip,
        num_heads=model_cfg.num_heads,
        num_layers=model_cfg.num_layers,
        pre_norm=model_cfg.pre_norm,
        pe_learnable=model_cfg.pe_learnable,
        pos_drop_rate=model_cfg.pos_drop_rate,
        attn_drop_rate=model_cfg.attn_drop_rate,
        proj_drop_rate=model_cfg.proj_drop_rate,
        mlp_ratio=model_cfg.mlp_ratio,
    ).to(device)
    return feat_ext, model
def load_frozen_spectttra(model, ckpt_path, device):
    """
    Load pretrained SpecTTTra weights from a checkpoint file into ``model``.

    Checkpoint keys that start with ``"encoder."`` have that prefix removed
    before loading; all other keys are passed through unchanged. Loading is
    non-strict, so key mismatches are reported on stdout instead of raising.

    Args:
        model (torch.nn.Module): Initialized SpecTTTra model to receive the weights.
        ckpt_path (str or Path): Location of the checkpoint file
            (e.g., 'spectttra_frozen.pth').
        device (torch.device): Device the checkpoint tensors are mapped to
            (e.g., 'cpu' or 'cuda').

    Returns:
        torch.nn.Module: The same ``model`` object, with weights loaded and
            switched to evaluation mode.

    Raises:
        FileNotFoundError: If nothing exists at ``ckpt_path``.
    """
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        raise FileNotFoundError(
            f"Pre-trained model not found at {ckpt_path}. "
            "Please download 'pytorch_model.bin', rename to 'spectttra_frozen.pth', "
            "and place it in the correct directory."
        )

    print(f"[INFO] Found SpecTTTra checkpoint at {ckpt_path}. Loading weights...")
    # NOTE(review): torch.load unpickles the file; only use checkpoints from a
    # trusted source (consider weights_only=True where the torch version allows).
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Strip the "encoder." prefix so checkpoint names line up with the model's
    # own parameter names; keys without the prefix are kept as they are.
    prefix = "encoder."
    remapped = {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in checkpoint.items()
    }

    # strict=False: mismatching keys are returned rather than raised.
    missing_keys, unexpected_keys = model.load_state_dict(remapped, strict=False)

    if missing_keys:
        # Parameters the model has but the checkpoint did not provide.
        print(f"[WARNING] Missing keys in model: {missing_keys}")
    if unexpected_keys:
        # Checkpoint entries the model has no slot for (e.g. 'classifier' or
        # 'ft_extractor' keys, which the original author notes are harmless).
        print(f"[INFO] Unused keys in checkpoint: {unexpected_keys}")

    print("[INFO] Successfully loaded pre-trained SpecTTTra weights.")
    model.eval()
    return model
|