import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils import weight_norm from collections import OrderedDict import os from .attentions import ResidualBlock1D, APTx from .quantizer import FSQ def sequence_mask(max_length, x_lengths): """ Make a bool sequence mask :param max_length: Max length of sequences :param x_lengths: Tensor (batch,) indicating sequence lengths :return: Bool tensor size (batch, max_length) where True is padded and False is valid """ mask = torch.arange(max_length).expand(len(x_lengths), max_length).to(x_lengths.device) mask = mask >= x_lengths.unsqueeze(1) return mask class ConvBlock2D(nn.Module): """ 2-D convolutional block that supports: • weight-norm wrapping • regular or depth-wise-separable conv • boolean padding mask (B, 1, H, W) – keeps padded pixels at 0 Forward signature ----------------- y = block(x, x_mask=None) If x_mask is provided (True = padded), the block applies `out = out.masked_fill(mask_expanded, 0)` right *before* the non-linearity. This mirrors the masking strategy used in ResidualBlock2D. """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int | tuple[int, int] = 3, stride: int | tuple[int, int] = 1, dilation: int | tuple[int, int] = 1, *, depthwise: bool = False, use_weight_norm: bool = True, act: str = "relu", dropout: float = 0.1, bias: bool = True, ): super().__init__() # ------ util ------ # def _make_conv(cin, cout, k, s, d, groups=1): padding = ( d * (k // 2) if isinstance(k, int) else (d[0] * (k[0] // 2), d[1] * (k[1] // 2)) ) conv = nn.Conv2d( cin, cout, k, stride=s, padding=padding, dilation=d, groups=groups, bias=bias ) return weight_norm(conv) if use_weight_norm else conv # ------ conv path ------ # if depthwise: self.dw = _make_conv(in_channels, in_channels, kernel_size, stride, dilation, groups=in_channels) # depth-wise self.pw = _make_conv(in_channels, out_channels, 1, 1, 1) # point-wise else: self.conv = _make_conv(in_channels, out_channels, kernel_size, stride, dilation) # ------ activation ------ # if act.lower() == "gelu": self.activation = nn.GELU() elif act.lower() == "aptx": self.activation = APTx() else: self.activation = nn.ReLU(inplace=True) self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() self.depthwise = depthwise # store flag for forward() self.conv_out = nn.Conv2d(out_channels, 1, 1) # --------------------------------------------------------------------- # def _apply_mask(self, tensor: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor: if mask is not None: tensor = tensor.masked_fill(mask.expand_as(tensor), 0.0) return tensor # --------------------------------------------------------------------- # def forward(self, x: torch.Tensor, x_mask: torch.Tensor | None = None) -> torch.Tensor: """ x : (B, Cin, H, W) x_mask : (B, 1, H, W) boolean, True = padding """ # (B, H, W) x = x.unsqueeze(1) x_mask = x_mask.unsqueeze(1) if self.depthwise: out = self.dw(x) out = self._apply_mask(out, x_mask) out = self.pw(out) else: out = self.conv(x) out = self._apply_mask(out, x_mask) out = self.activation(out) out = self.dropout(out) out = self.conv_out(out) return out.squeeze(1) class PreEncoder(nn.Module): def __init__(self, mel_channels, channels, kernel_sizes, fsq_levels=[8, 8, 5, 5, 5], dropout=0.1): """ Spectrogram Pre-Encoder. ResNet-based autoencoder with configurable encoder and decoder blocks. Parameters: - mel_channels (int): number of channels in the input spectrogram. - channels (list of ints): list of channel dimensions for encoder blocks. * The first element is the projected input dimension. * The last element is the latent dimension. - kernel_sizes (list of ints): list of kernel sizes for each ResidualBlock1D. Length should be len(channels) - 1. The decoder will use these lists in reverse. """ super(PreEncoder, self).__init__() # Project input from mel_channels to channels[0] self.proj = nn.Linear(mel_channels, channels[0]) self.pre = ConvBlock2D(1, channels[0], kernel_size=5, depthwise=True, act="aptx") self.quantizer_dim = len(fsq_levels) # Encoder: build a sequence of ResidualBlock1D modules self.encoder_blocks = nn.ModuleList([ ResidualBlock1D(channels[i], channels[i + 1], kernel_size=kernel_sizes[i], dropout=dropout, act="taptx", norm="weight") for i in range(len(channels) - 1) ]) # Quantization stage: here we use the latent dimension as the last element of channels. latent_dim = channels[-1] self.q_in_proj = nn.Linear(latent_dim, self.quantizer_dim) self.quantizer = FSQ(levels=fsq_levels) self.q_out_proj = nn.Linear(self.quantizer_dim, latent_dim) self.codebook_size = 8010 # TODO: dyn calculate this self.bos_token_id = 8001 self.eos_token_id = 8002 # Decoder: use the reversed lists so that the decoder mirrors the encoder. rev_channels = list(reversed(channels)) rev_kernel_sizes = list(reversed(kernel_sizes)) self.decoder_blocks = nn.ModuleList([ ResidualBlock1D(rev_channels[i], rev_channels[i + 1], kernel_size=rev_kernel_sizes[i], dropout=dropout, act="taptx", causal=True, norm="weight") for i in range(len(rev_channels) - 1) ]) self.post = ConvBlock2D(1, channels[0], kernel_size=5, depthwise=True, act="aptx") # Output projection: map from the decoder’s final channel (channels[0]) back to mel_channels. self.out_proj = nn.Linear(channels[0], mel_channels) def forward(self, x, x_lengths): """ Forward pass. Parameters: - x: Tensor of shape (batch, mel_len, mel_channels) - x_lengths: (batch,), int lengths of each thing Returns: - Reconstructed tensor of shape (batch, mel_len, mel_channels) """ # Project input to channel dimension channels[0] x = self.proj(x) # (batch, mel_len, channels[0]) # Permute to (batch, channels[0], mel_len) for 1D convolutions. x = x.permute(0, 2, 1) x_mask = sequence_mask(x.size(2), x_lengths) x_mask = x_mask.unsqueeze(1) # (B, 1, T) x = self.pre(x, x_mask) # Pass through the encoder blocks for block in self.encoder_blocks: x = block(x, x_mask=x_mask) # Permute back to (batch, mel_len, latent_dim) x = x.permute(0, 2, 1) x = self.q_in_proj(x) xhat, indices = self.quantizer(x) x = self.q_out_proj(xhat) # Permute for the decoder x = x.permute(0, 2, 1) # Pass through the decoder blocks for block in self.decoder_blocks: x = block(x, x_mask=x_mask) x = self.post(x, x_mask) # Permute back to (batch, mel_len, channels[0]) x = x.permute(0, 2, 1) # Final projection back to mel_channels x = self.out_proj(x) return x def encode(self, x, x_mask=None): """ Encodes the input spectrogram into discrete latent indices. Args: x (torch.Tensor): Input tensor of shape (batch, mel_len, mel_channels). - x_mask: Tensor of shape (batch, mel_len), bool where padded positions are True. (This mask will be passed to each ResidualBlock1D, which is assumed to apply .masked_fill(x_mask, 0) before its activation calls.) Returns: indices (torch.Tensor): Discrete token indices from the vector quantizer. """ # Project input to latent_dim x = self.proj(x) # Permute to (batch, latent_dim, mel_len) for convolutional operations x = x.permute(0, 2, 1) if x_mask is None: x_mask = torch.zeros((x.size(0), 1, x.size(2)), device=x.device).bool() x = self.pre(x, x_mask) # Pass through the encoder blocks for block in self.encoder_blocks: x = block(x, x_mask=x_mask) # Permute back to (batch, mel_len, latent_dim) x = x.permute(0, 2, 1) # Project to quantizer input dimension (e.g. 4) x = self.q_in_proj(x) # Quantize and obtain indices _, indices = self.quantizer(x) return indices.long() # otherwise cross entropy loss bitches later def decode(self, indices, x_mask=None, return_hidden=False): """ Decodes discrete latent indices into a reconstructed spectrogram. Args: indices (torch.Tensor): Discrete token indices from the vector quantizer. Returns: x (torch.Tensor): Reconstructed spectrogram of shape (batch, mel_len, mel_channels). """ # Convert indices to quantized latent codes (shape: (batch, mel_len, 4)) xhat = self.quantizer.indices_to_codes(indices) # Project quantized representation back to latent_dim x = self.q_out_proj(xhat) # Permute to (batch, latent_dim, mel_len) for convolutional operations x = x.permute(0, 2, 1) if x_mask is None: x_mask = torch.zeros((x.size(0), 1, x.size(2)), device=x.device).bool() # Pass through the decoder blocks for block in self.decoder_blocks: x = block(x, x_mask=x_mask) if return_hidden: last_hid = x.clone() x = self.post(x, x_mask) # Permute back to (batch, mel_len, latent_dim) x = x.permute(0, 2, 1) # Project back to the original mel_channels x = self.out_proj(x) if return_hidden: return x, last_hid return x def get_pre_encoder(model_path: str, device: str or torch.device, channels = [384, 512, 768], kernel_sizes=[7, 5, 3], mel_channels=88): """ Loads a Pre-Encoder model from a checkpoint file. Assumes the checkpoint was saved with the training script's structure, containing 'model_state_dict' and 'args' (or a compatible dict). Args: model_path (str): Path to the .pth checkpoint file. device (str or torch.device): The device to load the model onto ('cpu', 'cuda', etc.). Returns: tuple: A tuple containing: - model (nn.Module): The loaded ResNetAutoencoder1D model instance, moved to the specified device and set to eval mode. - model_args (argparse.Namespace or dict): The configuration arguments used to initialize the model, loaded from the checkpoint. Raises: FileNotFoundError: If the model_path does not exist. KeyError: If essential keys ('args', 'model_state_dict') are missing from the checkpoint. RuntimeError: If load_state_dict fails (e.g., architecture mismatch). ImportError: If the ResNetAutoencoder1D class cannot be imported/found. """ if not os.path.isfile(model_path): raise FileNotFoundError(f"Checkpoint file not found: {model_path}") print(f"Loading checkpoint from: {model_path}") checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) # Load to CPU first # --- 2. Instantiate Model --- try: model = PreEncoder(mel_channels=mel_channels, channels=channels, kernel_sizes=kernel_sizes, dropout=0.0, fsq_levels=[8, 5, 5, 5]) except NameError: raise ImportError("ResNetAutoencoder1D class definition not found. Ensure model.py is accessible or the class is defined.") except Exception as e: raise RuntimeError(f"Failed to instantiate model with loaded config: {e}") # --- 3. Load Weights --- if 'model_state_dict' in checkpoint: pretrained_weights = checkpoint['model_state_dict'] print("Found weights under 'model_state_dict' key.") # Optional: Handle 'module.' prefix (if saved using DataParallel/DDP) clean_weights = OrderedDict() has_module_prefix = False for k, v in pretrained_weights.items(): if k.startswith('module.'): has_module_prefix = True clean_weights[k[7:]] = v # remove `module.` else: clean_weights[k] = v if has_module_prefix: print("Removed 'module.' prefix from weight keys.") pretrained_weights = clean_weights # Use the cleaned dictionary # Load the weights using strict=True (assumes exact match) try: model.load_state_dict(pretrained_weights, strict=True) print("Successfully loaded model weights.") except RuntimeError as e: print(f"Error loading state_dict (likely architecture mismatch): {e}") raise e # Re-raise the error else: raise KeyError(f"Checkpoint missing 'model_state_dict' key containing weights.") # --- 4. Final Steps --- model.to(device) # Move model to the target device model.eval() # Set model to evaluation mode print(f"Model loaded onto {device} and set to evaluation mode.") return model