import torch
import torch.nn as nn
from pathlib import Path

from .transformer import Transformer
from .tokenizer import STTokenizer
from src.spectttra.feature import FeatureExtractor


class SpecTTTra(nn.Module):
    """
    SpecTTTra: A Spectro-Temporal Transformer model for audio representation learning.

    This model first tokenizes the input spectrogram into temporal and spectral tokens,
    then processes them with a Transformer encoder to capture spectro-temporal
    dependencies.
    """

    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        embed_dim,
        t_clip,
        f_clip,
        num_heads,
        num_layers,
        pre_norm=False,
        pe_learnable=False,
        pos_drop_rate=0.0,
        attn_drop_rate=0.0,
        proj_drop_rate=0.0,
        mlp_ratio=4.0,
    ):
        """
        Initialize the SpecTTTra model.

        Args:
            input_spec_dim (int): Input spectrogram frequency dimension (F).
            input_temp_dim (int): Input spectrogram temporal dimension (T).
            embed_dim (int): Embedding dimension for tokens.
            t_clip (int): Temporal clip size for tokenization.
            f_clip (int): Spectral clip size for tokenization.
            num_heads (int): Number of attention heads in the transformer.
            num_layers (int): Number of transformer layers.
            pre_norm (bool, optional): Whether to apply pre-normalization. Defaults to False.
            pe_learnable (bool, optional): If True, use learnable positional embeddings.
                Defaults to False.
            pos_drop_rate (float, optional): Dropout rate for positional embeddings.
                Defaults to 0.0.
            attn_drop_rate (float, optional): Dropout rate for attention. Defaults to 0.0.
            proj_drop_rate (float, optional): Dropout rate for projection layers.
                Defaults to 0.0.
            mlp_ratio (float, optional): Expansion ratio for MLP hidden dimension.
                Defaults to 4.0.
        """
        super().__init__()
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.embed_dim = embed_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pre_norm = pre_norm  # Applied after tokenization, before the transformer (as in CLIP)
        self.pe_learnable = pe_learnable  # Learned positional encoding
        self.pos_drop_rate = pos_drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate
        self.mlp_ratio = mlp_ratio

        # Tokenizer for spectro-temporal features
        self.st_tokenizer = STTokenizer(
            input_spec_dim,
            input_temp_dim,
            t_clip,
            f_clip,
            embed_dim,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )

        # Dropout applied after tokenization
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        # Transformer encoder
        self.transformer = Transformer(
            embed_dim,
            num_heads,
            num_layers,
            attn_drop=self.attn_drop_rate,
            proj_drop=self.proj_drop_rate,
            mlp_ratio=self.mlp_ratio,
        )

    def forward(self, x):
        """
        Forward pass of SpecTTTra.

        Args:
            x (torch.Tensor): Input spectrogram of shape
                - (B, 1, F, T) if a channel dimension exists
                - (B, F, T) otherwise

        Returns:
            torch.Tensor: Transformer-encoded spectro-temporal tokens of shape
                (B, T/t_clip + F/f_clip, embed_dim).
        """
        # Squeeze the channel dimension if it exists
        if x.dim() == 4:
            x = x.squeeze(1)

        # Spectro-temporal tokenization
        spectro_temporal_tokens = self.st_tokenizer(x)

        # Positional dropout
        spectro_temporal_tokens = self.pos_drop(spectro_temporal_tokens)

        # Transformer
        output = self.transformer(spectro_temporal_tokens)  # (B, T/t_clip + F/f_clip, embed_dim)
        return output
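
# Usage sketch (illustrative only): n_mels=128 and n_frames=3744 match the fixed
# dimensions used in build_spectttra_from_cfg below, but embed_dim, t_clip, f_clip,
# num_heads, and num_layers here are hypothetical, not the checkpoint's actual
# hyperparameters. Token count follows the forward() docstring: T/t_clip + F/f_clip.
#
#   model = SpecTTTra(
#       input_spec_dim=128, input_temp_dim=3744, embed_dim=384,
#       t_clip=3, f_clip=1, num_heads=6, num_layers=12,
#   )
#   spec = torch.randn(2, 1, 128, 3744)  # (B, 1, F, T) mel-spectrogram batch
#   tokens = model(spec)                 # (2, 3744/3 + 128/1, 384) == (2, 1376, 384)
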
def build_spectttra_from_cfg(cfg, device):
    """
    Construct the SpecTTTra model and its associated FeatureExtractor from a given
    configuration.

    Args:
        cfg (SimpleNamespace): Configuration object containing model and feature
            extraction parameters. Expected attributes include:
                - cfg.melspec.n_mels: Number of mel frequency bins.
                - cfg.model: Model-specific parameters (e.g., embed_dim, t_clip,
                  f_clip, etc.).
        device (torch.device): The device on which the model and feature extractor
            will be allocated (e.g., 'cpu' or 'cuda').

    Returns:
        tuple:
            FeatureExtractor: Initialized feature extraction module moved to the
                specified device.
            SpecTTTra: Constructed SpecTTTra model moved to the specified device.
    """
    feat_ext = FeatureExtractor(cfg).to(device)

    # The pre-trained checkpoint expects specific, fixed input dimensions. They are
    # hardcoded here so the model architecture matches the checkpoint weights exactly.
    n_mels = cfg.melspec.n_mels  # expected to be 128
    n_frames = 3744  # number of temporal frames the checkpoint was trained with

    print(
        f"[INFO] Initializing SpecTTTra with fixed dimensions: "
        f"n_mels={n_mels}, n_frames={n_frames}"
    )

    model_cfg = cfg.model
    model = SpecTTTra(
        input_spec_dim=n_mels,
        input_temp_dim=n_frames,
        embed_dim=model_cfg.embed_dim,
        t_clip=model_cfg.t_clip,
        f_clip=model_cfg.f_clip,
        num_heads=model_cfg.num_heads,
        num_layers=model_cfg.num_layers,
        pre_norm=model_cfg.pre_norm,
        pe_learnable=model_cfg.pe_learnable,
        pos_drop_rate=model_cfg.pos_drop_rate,
        attn_drop_rate=model_cfg.attn_drop_rate,
        proj_drop_rate=model_cfg.proj_drop_rate,
        mlp_ratio=model_cfg.mlp_ratio,
    ).to(device)

    return feat_ext, model


def load_frozen_spectttra(model, ckpt_path, device):
    """
    Load pretrained SpecTTTra weights from a frozen checkpoint file.

    Args:
        model (torch.nn.Module): An initialized SpecTTTra model instance to load
            weights into.
        ckpt_path (str or Path): Path to the pretrained model checkpoint file
            (e.g., 'spectttra_frozen.pth').
        device (torch.device): The device to map the loaded weights to
            (e.g., 'cpu' or 'cuda').

    Returns:
        torch.nn.Module: The SpecTTTra model with loaded pretrained weights, set to
            evaluation mode.

    Raises:
        FileNotFoundError: If the specified checkpoint file does not exist at
            `ckpt_path`.
    """
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        raise FileNotFoundError(
            f"Pre-trained model not found at {ckpt_path}. "
            "Please download 'pytorch_model.bin', rename it to 'spectttra_frozen.pth', "
            "and place it in the correct directory."
        )

    print(f"[INFO] Found SpecTTTra checkpoint at {ckpt_path}. Loading weights...")
    state = torch.load(ckpt_path, map_location=device)

    # The checkpoint stores the encoder weights under an "encoder." prefix; strip it
    # so the keys line up with this module's state_dict.
    new_state_dict = {}
    for k, v in state.items():
        if k.startswith("encoder."):
            new_state_dict[k[len("encoder."):]] = v
        else:
            new_state_dict[k] = v

    # strict=False tolerates keys that belong to other parts of the checkpoint
    # (e.g., a classifier head) rather than to the encoder itself.
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    if missing_keys:
        # A few missing keys can appear if this SpecTTTra class differs slightly from
        # the one used at training time; the core weights should still load.
        print(f"[WARNING] Missing keys in model: {missing_keys}")
    if unexpected_keys:
        # Seeing 'classifier' or 'ft_extractor' keys here is normal and safe.
        print(f"[INFO] Unused keys in checkpoint: {unexpected_keys}")

    print("[INFO] Successfully loaded pre-trained SpecTTTra weights.")
    model.eval()
    return model
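

# End-to-end usage sketch (illustrative; assumes `cfg` also carries whatever melspec
# fields FeatureExtractor reads beyond n_mels, and the cfg.model values below are
# hypothetical placeholders, not the checkpoint's actual hyperparameters):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       melspec=SimpleNamespace(n_mels=128),
#       model=SimpleNamespace(
#           embed_dim=384, t_clip=3, f_clip=1, num_heads=6, num_layers=12,
#           pre_norm=True, pe_learnable=False, pos_drop_rate=0.0,
#           attn_drop_rate=0.0, proj_drop_rate=0.0, mlp_ratio=4.0,
#       ),
#   )
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   feat_ext, model = build_spectttra_from_cfg(cfg, device)
#   model = load_frozen_spectttra(model, "spectttra_frozen.pth", device)
#   with torch.no_grad():
#       melspec = feat_ext(waveform)  # waveform: (B, num_samples) audio tensor
#       tokens = model(melspec)       # (B, n_frames/t_clip + n_mels/f_clip, embed_dim)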