import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from collections import OrderedDict
import os
from .attentions import ResidualBlock1D, APTx
from .quantizer import FSQ


def sequence_mask(max_length, x_lengths):
    """
    Build a boolean padding mask.
    :param max_length: Max length of sequences in the batch
    :param x_lengths: Tensor (batch,) with the length of each sequence
    :return: Bool tensor of shape (batch, max_length) where True marks padded positions and False marks valid ones
    """
    positions = torch.arange(max_length, device=x_lengths.device).expand(len(x_lengths), max_length)
    return positions >= x_lengths.unsqueeze(1)
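
# A minimal, illustrative sketch of the mask convention used throughout this
# file (not executed): for x_lengths = [2, 4] and max_length = 4,
# sequence_mask returns
#   [[False, False, True,  True ],
#    [False, False, False, False]]
# i.e. True marks padded frames, matching the masked_fill(mask, 0) calls below.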


class ConvBlock2D(nn.Module):
    """
    2-D convolutional block that supports:
      • weight-norm wrapping
      • regular or depth-wise-separable conv
      • boolean padding mask (B, 1, H, W)  – keeps padded pixels at 0

    Forward signature
    -----------------
    y = block(x, x_mask=None)

    If x_mask is provided (True = padded), the block applies
    `out = out.masked_fill(mask_expanded, 0)` right *before* the
    non-linearity.  This mirrors the masking strategy used in
    ResidualBlock2D.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int | tuple[int, int] = 3,
            stride: int | tuple[int, int] = 1,
            dilation: int | tuple[int, int] = 1,
            *,
            depthwise: bool = False,
            use_weight_norm: bool = True,
            act: str = "relu",
            dropout: float = 0.1,
            bias: bool = True,
    ):
        super().__init__()

        # ------ util ------ #
        def _make_conv(cin, cout, k, s, d, groups=1):
            # "same" padding for odd kernels; normalise k/d so mixed int/tuple
            # arguments (e.g. tuple kernel with int dilation) also work
            k2 = (k, k) if isinstance(k, int) else k
            d2 = (d, d) if isinstance(d, int) else d
            padding = (d2[0] * (k2[0] // 2), d2[1] * (k2[1] // 2))
            conv = nn.Conv2d(
                cin, cout, k, stride=s, padding=padding,
                dilation=d, groups=groups, bias=bias
            )
            return weight_norm(conv) if use_weight_norm else conv

        # ------ conv path ------ #
        if depthwise:
            self.dw = _make_conv(in_channels, in_channels, kernel_size, stride, dilation,
                                 groups=in_channels)  # depth-wise
            self.pw = _make_conv(in_channels, out_channels, 1, 1, 1)  # point-wise
        else:
            self.conv = _make_conv(in_channels, out_channels, kernel_size, stride, dilation)

        # ------ activation ------ #
        if act.lower() == "gelu":
            self.activation = nn.GELU()
        elif act.lower() == "aptx":
            self.activation = APTx()
        else:
            self.activation = nn.ReLU(inplace=True)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.depthwise = depthwise  # store flag for forward()
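        # 1×1 conv that projects back to a single channel so forward() can return (B, H, W)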
        self.conv_out = nn.Conv2d(out_channels, 1, 1)

    # --------------------------------------------------------------------- #
    def _apply_mask(self, tensor: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor:
        if mask is not None:
            tensor = tensor.masked_fill(mask.expand_as(tensor), 0.0)
        return tensor

    # --------------------------------------------------------------------- #
    def forward(self, x: torch.Tensor, x_mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        x       : (B, H, W)
        x_mask  : (B, 1, W) boolean, True = padding
        """
        # lift to a single-channel 2-D map: (B, 1, H, W)
        x = x.unsqueeze(1)
        if x_mask is not None:
            x_mask = x_mask.unsqueeze(1)  # (B, 1, 1, W), broadcast over channels and height
        if self.depthwise:
            out = self.dw(x)
            out = self._apply_mask(out, x_mask)
            out = self.pw(out)
        else:
            out = self.conv(x)

        out = self._apply_mask(out, x_mask)
        out = self.activation(out)
        out = self.dropout(out)
        out = self.conv_out(out)
        return out.squeeze(1)
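
# Illustrative usage of ConvBlock2D (a sketch of the shapes only; the block is
# normally driven by PreEncoder below):
#   block = ConvBlock2D(1, 64, kernel_size=5, depthwise=True, act="aptx")
#   y = block(x, x_mask)   # x: (B, H, W), x_mask: (B, 1, W) bool -> y: (B, H, W)
# With the default stride of 1, padded frames are zeroed before the activation.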


class PreEncoder(nn.Module):
    def __init__(self, mel_channels, channels, kernel_sizes, fsq_levels=[8, 8, 5, 5, 5], dropout=0.1):
        """
        Spectrogram Pre-Encoder.
        ResNet-based autoencoder with configurable encoder and decoder blocks.

        Parameters:
          - mel_channels (int): number of channels in the input spectrogram.
          - channels (list of ints): channel dimensions for the encoder blocks.
            * The first element is the projected input dimension.
            * The last element is the latent dimension.
          - kernel_sizes (list of ints): kernel size for each ResidualBlock1D.
            Length should be len(channels) - 1. The decoder uses these lists in reverse.
          - fsq_levels (list of ints): quantization levels per FSQ dimension;
            the number of distinct codes is the product of the levels.
          - dropout (float): dropout rate used inside the blocks.
        """
        super().__init__()
        # Project input from mel_channels to channels[0]
        self.proj = nn.Linear(mel_channels, channels[0])
        self.pre = ConvBlock2D(1, channels[0], kernel_size=5, depthwise=True, act="aptx")
        self.quantizer_dim = len(fsq_levels)
        # Encoder: build a sequence of ResidualBlock1D modules
        self.encoder_blocks = nn.ModuleList([
            ResidualBlock1D(channels[i], channels[i + 1], kernel_size=kernel_sizes[i], dropout=dropout, act="taptx",
                            norm="weight")
            for i in range(len(channels) - 1)
        ])

        # Quantization stage: here we use the latent dimension as the last element of channels.
        latent_dim = channels[-1]

        self.q_in_proj = nn.Linear(latent_dim, self.quantizer_dim)
        self.quantizer = FSQ(levels=fsq_levels)
        self.q_out_proj = nn.Linear(self.quantizer_dim, latent_dim)
        self.codebook_size = 8010  # TODO: derive dynamically -- FSQ yields prod(fsq_levels) codes (8000 for the default levels) plus the special tokens below
        self.bos_token_id = 8001
        self.eos_token_id = 8002

        # Decoder: use the reversed lists so that the decoder mirrors the encoder.
        rev_channels = list(reversed(channels))
        rev_kernel_sizes = list(reversed(kernel_sizes))
        self.decoder_blocks = nn.ModuleList([
            ResidualBlock1D(rev_channels[i], rev_channels[i + 1], kernel_size=rev_kernel_sizes[i], dropout=dropout,
                            act="taptx", causal=True, norm="weight")
            for i in range(len(rev_channels) - 1)
        ])
        self.post = ConvBlock2D(1, channels[0], kernel_size=5, depthwise=True, act="aptx")

        # Output projection: map from the decoder’s final channel (channels[0]) back to mel_channels.
        self.out_proj = nn.Linear(channels[0], mel_channels)

    def forward(self, x, x_lengths):
        """
        Forward pass.

        Parameters:
          - x: Tensor of shape (batch, mel_len, mel_channels)
          - x_lengths: Tensor of shape (batch,), integer length of each sequence
        Returns:
          - Reconstructed tensor of shape (batch, mel_len, mel_channels)
        """
        # Project input to channel dimension channels[0]
        x = self.proj(x)  # (batch, mel_len, channels[0])
        # Permute to (batch, channels[0], mel_len) for 1D convolutions.
        x = x.permute(0, 2, 1)

        x_mask = sequence_mask(x.size(2), x_lengths)
        x_mask = x_mask.unsqueeze(1)  # (B, 1, T)
        x = self.pre(x, x_mask)

        # Pass through the encoder blocks
        for block in self.encoder_blocks:
            x = block(x, x_mask=x_mask)

        # Permute back to (batch, mel_len, latent_dim)
        x = x.permute(0, 2, 1)
        x = self.q_in_proj(x)
        xhat, indices = self.quantizer(x)
        x = self.q_out_proj(xhat)
        # Permute for the decoder
        x = x.permute(0, 2, 1)

        # Pass through the decoder blocks
        for block in self.decoder_blocks:
            x = block(x, x_mask=x_mask)

        x = self.post(x, x_mask)
        # Permute back to (batch, mel_len, channels[0])
        x = x.permute(0, 2, 1)
        # Final projection back to mel_channels
        x = self.out_proj(x)

        return x

    def encode(self, x, x_mask=None):
        """
        Encodes the input spectrogram into discrete latent indices.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, mel_len, mel_channels).
          - x_mask: Tensor of shape (batch, mel_len), bool where padded positions are True.
                   (This mask will be passed to each ResidualBlock1D, which is assumed to apply
                   .masked_fill(x_mask, 0) before its activation calls.)
        Returns:
            indices (torch.Tensor): Discrete token indices from the vector quantizer.
        """
        # Project input from mel_channels to channels[0]
        x = self.proj(x)
        # Permute to (batch, channels[0], mel_len) for convolutional operations
        x = x.permute(0, 2, 1)

        if x_mask is None:
            x_mask = torch.zeros((x.size(0), 1, x.size(2)), device=x.device).bool()

        x = self.pre(x, x_mask)

        # Pass through the encoder blocks
        for block in self.encoder_blocks:
            x = block(x, x_mask=x_mask)
        # Permute back to (batch, mel_len, latent_dim)
        x = x.permute(0, 2, 1)
        # Project to the quantizer input dimension (len(fsq_levels))
        x = self.q_in_proj(x)
        # Quantize and obtain indices
        _, indices = self.quantizer(x)
        return indices.long()  # cross-entropy loss downstream expects int64 indices

    def decode(self, indices, x_mask=None, return_hidden=False):
        """
        Decodes discrete latent indices into a reconstructed spectrogram.

        Args:
            indices (torch.Tensor): Discrete token indices from the vector quantizer.

        Returns:
            x (torch.Tensor): Reconstructed spectrogram of shape (batch, mel_len, mel_channels).
        """
        # Convert indices to quantized latent codes (shape: (batch, mel_len, quantizer_dim))
        xhat = self.quantizer.indices_to_codes(indices)
        # Project quantized representation back to latent_dim
        x = self.q_out_proj(xhat)
        # Permute to (batch, latent_dim, mel_len) for convolutional operations
        x = x.permute(0, 2, 1)

        if x_mask is None:
            x_mask = torch.zeros((x.size(0), 1, x.size(2)), device=x.device).bool()

        # Pass through the decoder blocks
        for block in self.decoder_blocks:
            x = block(x, x_mask=x_mask)

        if return_hidden:
            last_hid = x.clone()

        x = self.post(x, x_mask)
        # Permute back to (batch, mel_len, channels[0])
        x = x.permute(0, 2, 1)
        # Project back to the original mel_channels
        x = self.out_proj(x)
        if return_hidden:
            return x, last_hid

        return x
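
# Illustrative round trip (a comment sketch of the public encode/decode contract):
#   tokens = pre_encoder.encode(mels)        # (B, T) ints in [0, prod(fsq_levels))
#   mels_hat = pre_encoder.decode(tokens)    # (B, T, mel_channels)
# bos_token_id / eos_token_id are reserved above the FSQ code range (presumably
# for a downstream sequence model) and are never produced by encode().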


def get_pre_encoder(model_path: str, device: str | torch.device, channels=[384, 512, 768], kernel_sizes=[7, 5, 3], mel_channels=88):
    """
    Loads a Pre-Encoder model from a checkpoint file.

    Assumes the checkpoint was saved with the training script's structure and
    contains the weights under 'model_state_dict'.

    Args:
        model_path (str): Path to the .pth checkpoint file.
        device (str or torch.device): The device to load the model onto ('cpu', 'cuda', etc.).
        channels (list of ints): Encoder channel dimensions used to rebuild the model.
        kernel_sizes (list of ints): Kernel sizes used to rebuild the model.
        mel_channels (int): Number of mel channels the model was trained on.

    Returns:
        model (nn.Module): The loaded PreEncoder instance, moved to the specified
            device and set to eval mode.

    Raises:
        FileNotFoundError: If model_path does not exist.
        KeyError: If the 'model_state_dict' key is missing from the checkpoint.
        RuntimeError: If instantiation or load_state_dict fails (e.g., architecture mismatch).
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Checkpoint file not found: {model_path}")

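    # --- 1. Load Checkpoint ---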
    print(f"Loading checkpoint from: {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) # Load to CPU first

    # --- 2. Instantiate Model ---
    try:
        model = PreEncoder(mel_channels=mel_channels, channels=channels, kernel_sizes=kernel_sizes,
                           dropout=0.0, fsq_levels=[8, 5, 5, 5])
    except Exception as e:
        raise RuntimeError(f"Failed to instantiate PreEncoder with the given config: {e}") from e


    # --- 3. Load Weights ---
    if 'model_state_dict' in checkpoint:
        pretrained_weights = checkpoint['model_state_dict']
        print("Found weights under 'model_state_dict' key.")

        # Optional: Handle 'module.' prefix (if saved using DataParallel/DDP)
        clean_weights = OrderedDict()
        has_module_prefix = False
        for k, v in pretrained_weights.items():
            if k.startswith('module.'):
                has_module_prefix = True
                clean_weights[k[7:]] = v # remove `module.`
            else:
                clean_weights[k] = v
        if has_module_prefix:
            print("Removed 'module.' prefix from weight keys.")
        pretrained_weights = clean_weights # Use the cleaned dictionary

        # Load the weights using strict=True (assumes exact match)
        try:
            model.load_state_dict(pretrained_weights, strict=True)
            print("Successfully loaded model weights.")
        except RuntimeError as e:
            print(f"Error loading state_dict (likely architecture mismatch): {e}")
            raise  # re-raise with the original traceback

    else:
        raise KeyError("Checkpoint missing 'model_state_dict' key containing weights.")

    # --- 4. Final Steps ---
    model.to(device) # Move model to the target device
    model.eval()     # Set model to evaluation mode
    print(f"Model loaded onto {device} and set to evaluation mode.")

    return model
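

if __name__ == "__main__":
    # Minimal smoke test, kept as an illustrative sketch rather than part of the
    # API.  It assumes the sibling ``attentions`` and ``quantizer`` modules are
    # importable (e.g. run as ``python -m <package>.<this_module>``), uses
    # arbitrary shapes, and does not require a trained checkpoint.
    model = PreEncoder(mel_channels=88, channels=[384, 512, 768],
                       kernel_sizes=[7, 5, 3], fsq_levels=[8, 8, 5, 5, 5],
                       dropout=0.0)
    model.eval()

    batch, mel_len = 2, 120
    mels = torch.randn(batch, mel_len, 88)
    lengths = torch.tensor([120, 96])

    with torch.no_grad():
        recon = model(mels, lengths)      # (B, T, mel_channels) reconstruction
        tokens = model.encode(mels)       # (B, T) discrete FSQ indices
        decoded = model.decode(tokens)    # (B, T, mel_channels) from tokens alone

    print(recon.shape, tokens.shape, decoded.shape)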