MusicLSTMDemo / MQGAN /preencoder.py
ZDisket
fix preenc
5a8156a
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from collections import OrderedDict
import os
from .attentions import ResidualBlock1D, APTx
from .quantizer import FSQ
def sequence_mask(max_length, x_lengths):
    """
    Make a bool padding mask from per-sequence lengths.

    :param max_length: Length of the time axis of the mask.
    :param x_lengths: Tensor (batch,) indicating sequence lengths
    :return: Bool tensor size (batch, max_length) where True is padded
             (position >= length) and False is valid
    """
    # Build the position index directly on the lengths' device rather than
    # allocating on the CPU and copying it over on every call.
    positions = torch.arange(max_length, device=x_lengths.device)
    # (1, T) >= (B, 1) broadcasts to (B, T).
    return positions.unsqueeze(0) >= x_lengths.unsqueeze(1)
class ConvBlock2D(nn.Module):
"""
2-D convolutional block that supports:
• weight-norm wrapping
• regular or depth-wise-separable conv
• boolean padding mask (B, 1, H, W) – keeps padded pixels at 0
Forward signature
-----------------
y = block(x, x_mask=None)
If x_mask is provided (True = padded), the block applies
`out = out.masked_fill(mask_expanded, 0)` right *before* the
non-linearity. This mirrors the masking strategy used in
ResidualBlock2D.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int] = 3,
stride: int | tuple[int, int] = 1,
dilation: int | tuple[int, int] = 1,
*,
depthwise: bool = False,
use_weight_norm: bool = True,
act: str = "relu",
dropout: float = 0.1,
bias: bool = True,
):
super().__init__()
# ------ util ------ #
def _make_conv(cin, cout, k, s, d, groups=1):
padding = (
d * (k // 2) if isinstance(k, int)
else (d[0] * (k[0] // 2), d[1] * (k[1] // 2))
)
conv = nn.Conv2d(
cin, cout, k, stride=s, padding=padding,
dilation=d, groups=groups, bias=bias
)
return weight_norm(conv) if use_weight_norm else conv
# ------ conv path ------ #
if depthwise:
self.dw = _make_conv(in_channels, in_channels, kernel_size, stride, dilation,
groups=in_channels) # depth-wise
self.pw = _make_conv(in_channels, out_channels, 1, 1, 1) # point-wise
else:
self.conv = _make_conv(in_channels, out_channels, kernel_size, stride, dilation)
# ------ activation ------ #
if act.lower() == "gelu":
self.activation = nn.GELU()
elif act.lower() == "aptx":
self.activation = APTx()
else:
self.activation = nn.ReLU(inplace=True)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.depthwise = depthwise # store flag for forward()
self.conv_out = nn.Conv2d(out_channels, 1, 1)
# --------------------------------------------------------------------- #
def _apply_mask(self, tensor: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor:
if mask is not None:
tensor = tensor.masked_fill(mask.expand_as(tensor), 0.0)
return tensor
# --------------------------------------------------------------------- #
def forward(self, x: torch.Tensor, x_mask: torch.Tensor | None = None) -> torch.Tensor:
"""
x : (B, Cin, H, W)
x_mask : (B, 1, H, W) boolean, True = padding
"""
# (B, H, W)
x = x.unsqueeze(1)
x_mask = x_mask.unsqueeze(1)
if self.depthwise:
out = self.dw(x)
out = self._apply_mask(out, x_mask)
out = self.pw(out)
else:
out = self.conv(x)
out = self._apply_mask(out, x_mask)
out = self.activation(out)
out = self.dropout(out)
out = self.conv_out(out)
return out.squeeze(1)
class PreEncoder(nn.Module):
    def __init__(self, mel_channels, channels, kernel_sizes, fsq_levels=(8, 8, 5, 5, 5), dropout=0.1):
        """
        Spectrogram Pre-Encoder.

        Convolutional/residual autoencoder with an FSQ bottleneck and
        configurable encoder and decoder blocks.

        Parameters:
        - mel_channels (int): number of channels in the input spectrogram.
        - channels (list of ints): list of channel dimensions for encoder blocks.
            * The first element is the projected input dimension.
            * The last element is the latent dimension.
        - kernel_sizes (list of ints): list of kernel sizes for each ResidualBlock1D.
            Length should be len(channels) - 1. The decoder uses the reversed lists.
            If kernel_sizes is longer, the encoder uses its leading entries and
            the decoder the trailing entries (reversed), so the two no longer mirror.
        - fsq_levels (sequence of ints): quantisation levels passed to FSQ; its
            length is the quantizer (code) dimension.
        - dropout (float): dropout rate for every ResidualBlock1D. (The pre/post
            ConvBlock2D layers keep their own default dropout.)
        """
        super().__init__()
        # Fresh list per instance: the default is an immutable tuple so it can
        # never be shared/mutated across instances, and FSQ still gets a list.
        fsq_levels = list(fsq_levels)

        # Project input from mel_channels to channels[0]
        self.proj = nn.Linear(mel_channels, channels[0])
        self.pre = ConvBlock2D(1, channels[0], kernel_size=5, depthwise=True, act="aptx")
        self.quantizer_dim = len(fsq_levels)

        # Encoder: build a sequence of ResidualBlock1D modules
        self.encoder_blocks = nn.ModuleList([
            ResidualBlock1D(channels[i], channels[i + 1], kernel_size=kernel_sizes[i], dropout=dropout,
                            act="taptx", norm="weight")
            for i in range(len(channels) - 1)
        ])

        # Quantization stage: the latent dimension is the last element of channels.
        latent_dim = channels[-1]
        self.q_in_proj = nn.Linear(latent_dim, self.quantizer_dim)
        self.quantizer = FSQ(levels=fsq_levels)
        self.q_out_proj = nn.Linear(self.quantizer_dim, latent_dim)

        # NOTE(review): hard-coded rather than derived from fsq_levels. For the
        # default levels the product is 8*8*5*5*5 = 8000, so the special-token
        # ids below sit just above that range; with other levels (e.g.
        # (8, 5, 5, 5) -> 1000) these values stay the same. Presumably
        # downstream embeddings are sized from codebook_size — confirm before
        # computing it dynamically.
        self.codebook_size = 8010  # TODO: dyn calculate this
        self.bos_token_id = 8001
        self.eos_token_id = 8002

        # Decoder: use the reversed lists so that the decoder mirrors the encoder.
        rev_channels = list(reversed(channels))
        rev_kernel_sizes = list(reversed(kernel_sizes))
        self.decoder_blocks = nn.ModuleList([
            ResidualBlock1D(rev_channels[i], rev_channels[i + 1], kernel_size=rev_kernel_sizes[i], dropout=dropout,
                            act="taptx", causal=True, norm="weight")
            for i in range(len(rev_channels) - 1)
        ])
        self.post = ConvBlock2D(1, channels[0], kernel_size=5, depthwise=True, act="aptx")

        # Output projection: map from the decoder's final channel (channels[0]) back to mel_channels.
        self.out_proj = nn.Linear(channels[0], mel_channels)

    # ------------------------------------------------------------------ #
    # Shared pipeline stages (used by forward, encode and decode)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _no_padding_mask(batch, length, device):
        """All-False (nothing padded) bool mask of shape (batch, 1, length)."""
        return torch.zeros((batch, 1, length), dtype=torch.bool, device=device)

    def _encode_latent(self, x, x_mask):
        """
        Map a spectrogram to continuous, pre-quantisation codes.

        x: (batch, mel_len, mel_channels); x_mask: (batch, 1, mel_len) bool,
        True = padded. Returns (batch, mel_len, quantizer_dim).
        """
        x = self.proj(x)          # (batch, mel_len, channels[0])
        x = x.permute(0, 2, 1)    # (batch, channels[0], mel_len) for convolutions
        x = self.pre(x, x_mask)
        for block in self.encoder_blocks:
            x = block(x, x_mask=x_mask)
        x = x.permute(0, 2, 1)    # (batch, mel_len, latent_dim)
        return self.q_in_proj(x)

    def _decode_hidden(self, xhat, x_mask):
        """
        Run the decoder trunk on quantised codes.

        xhat: (batch, mel_len, quantizer_dim). Returns the decoder output
        before `post`/`out_proj`, shape (batch, channels[0], mel_len).
        """
        x = self.q_out_proj(xhat)  # (batch, mel_len, latent_dim)
        x = x.permute(0, 2, 1)     # (batch, latent_dim, mel_len)
        for block in self.decoder_blocks:
            x = block(x, x_mask=x_mask)
        return x

    def _project_out(self, x, x_mask):
        """Apply `post` then map (batch, channels[0], mel_len) -> (batch, mel_len, mel_channels)."""
        x = self.post(x, x_mask)
        x = x.permute(0, 2, 1)
        return self.out_proj(x)

    # ------------------------------------------------------------------ #
    def forward(self, x, x_lengths):
        """
        Forward pass (encode -> quantize -> decode).

        Parameters:
        - x: Tensor of shape (batch, mel_len, mel_channels)
        - x_lengths: (batch,), int lengths of each sequence
        Returns:
        - Reconstructed tensor of shape (batch, mel_len, mel_channels)
        """
        x_mask = sequence_mask(x.size(1), x_lengths).unsqueeze(1)  # (B, 1, T), True = padded
        latent = self._encode_latent(x, x_mask)
        # Use the quantizer's continuous output (not the indices) so the
        # decoder reconstruction stays differentiable w.r.t. the encoder.
        xhat, _ = self.quantizer(latent)
        hidden = self._decode_hidden(xhat, x_mask)
        return self._project_out(hidden, x_mask)

    def encode(self, x, x_mask=None):
        """
        Encodes the input spectrogram into discrete latent indices.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, mel_len, mel_channels).
            x_mask (torch.Tensor, optional): bool tensor of shape (batch, 1, mel_len)
                where padded positions are True. Defaults to "nothing padded".
        Returns:
            indices (torch.Tensor): int64 token indices from the quantizer.
        """
        if x_mask is None:
            x_mask = self._no_padding_mask(x.size(0), x.size(1), x.device)
        _, indices = self.quantizer(self._encode_latent(x, x_mask))
        # Cast to int64: cross-entropy targets downstream must be Long.
        return indices.long()

    def decode(self, indices, x_mask=None, return_hidden=False):
        """
        Decodes discrete latent indices into a reconstructed spectrogram.

        Args:
            indices (torch.Tensor): Token indices from the quantizer.
            x_mask (torch.Tensor, optional): bool tensor of shape (batch, 1, mel_len),
                True = padded. Defaults to "nothing padded".
            return_hidden (bool): also return the decoder-trunk output
                (batch, channels[0], mel_len) taken before `post`.
        Returns:
            x (torch.Tensor): Reconstructed spectrogram of shape (batch, mel_len, mel_channels),
            or the tuple (x, hidden) when return_hidden is True.
        """
        # Convert indices to quantized codes: (batch, mel_len, quantizer_dim)
        xhat = self.quantizer.indices_to_codes(indices)
        if x_mask is None:
            x_mask = self._no_padding_mask(xhat.size(0), xhat.size(1), xhat.device)
        hidden = self._decode_hidden(xhat, x_mask)
        last_hid = hidden.clone() if return_hidden else None
        x = self._project_out(hidden, x_mask)
        if return_hidden:
            return x, last_hid
        return x
def get_pre_encoder(model_path: str, device: str or torch.device, channels = [384, 512, 768], kernel_sizes=[7, 5, 3], mel_channels=88):
"""
Loads a Pre-Encoder model from a checkpoint file.
Assumes the checkpoint was saved with the training script's structure,
containing 'model_state_dict' and 'args' (or a compatible dict).
Args:
model_path (str): Path to the .pth checkpoint file.
device (str or torch.device): The device to load the model onto ('cpu', 'cuda', etc.).
Returns:
tuple: A tuple containing:
- model (nn.Module): The loaded ResNetAutoencoder1D model instance,
moved to the specified device and set to eval mode.
- model_args (argparse.Namespace or dict): The configuration arguments
used to initialize the model,
loaded from the checkpoint.
Raises:
FileNotFoundError: If the model_path does not exist.
KeyError: If essential keys ('args', 'model_state_dict') are missing
from the checkpoint.
RuntimeError: If load_state_dict fails (e.g., architecture mismatch).
ImportError: If the ResNetAutoencoder1D class cannot be imported/found.
"""
if not os.path.isfile(model_path):
raise FileNotFoundError(f"Checkpoint file not found: {model_path}")
print(f"Loading checkpoint from: {model_path}")
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) # Load to CPU first
# --- 2. Instantiate Model ---
try:
model = PreEncoder(mel_channels=mel_channels, channels=channels, kernel_sizes=kernel_sizes,
dropout=0.0, fsq_levels=[8, 5, 5, 5])
except NameError:
raise ImportError("ResNetAutoencoder1D class definition not found. Ensure model.py is accessible or the class is defined.")
except Exception as e:
raise RuntimeError(f"Failed to instantiate model with loaded config: {e}")
# --- 3. Load Weights ---
if 'model_state_dict' in checkpoint:
pretrained_weights = checkpoint['model_state_dict']
print("Found weights under 'model_state_dict' key.")
# Optional: Handle 'module.' prefix (if saved using DataParallel/DDP)
clean_weights = OrderedDict()
has_module_prefix = False
for k, v in pretrained_weights.items():
if k.startswith('module.'):
has_module_prefix = True
clean_weights[k[7:]] = v # remove `module.`
else:
clean_weights[k] = v
if has_module_prefix:
print("Removed 'module.' prefix from weight keys.")
pretrained_weights = clean_weights # Use the cleaned dictionary
# Load the weights using strict=True (assumes exact match)
try:
model.load_state_dict(pretrained_weights, strict=True)
print("Successfully loaded model weights.")
except RuntimeError as e:
print(f"Error loading state_dict (likely architecture mismatch): {e}")
raise e # Re-raise the error
else:
raise KeyError(f"Checkpoint missing 'model_state_dict' key containing weights.")
# --- 4. Final Steps ---
model.to(device) # Move model to the target device
model.eval() # Set model to evaluation mode
print(f"Model loaded onto {device} and set to evaluation mode.")
return model