Update modeling_wavtokenizer.py
Browse files- modeling_wavtokenizer.py +8 -9
modeling_wavtokenizer.py
CHANGED
|
@@ -39,28 +39,27 @@ def convert_audio(wav, sr, target_sr, target_channels=1):
|
|
| 39 |
|
| 40 |
|
| 41 |
# =============================================================================
|
| 42 |
-
# Weight-Normalized Conv1d (
|
| 43 |
# =============================================================================
|
| 44 |
|
| 45 |
class WNConv1d(nn.Module):
|
| 46 |
-
"""Weight-normalized Conv1d
|
| 47 |
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
| 48 |
super().__init__()
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
)
|
| 52 |
|
| 53 |
def forward(self, x):
|
| 54 |
return self.conv(x)
|
| 55 |
|
| 56 |
|
| 57 |
class WNConvTranspose1d(nn.Module):
|
| 58 |
-
"""Weight-normalized ConvTranspose1d."""
|
| 59 |
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True):
|
| 60 |
super().__init__()
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
)
|
| 64 |
|
| 65 |
def forward(self, x):
|
| 66 |
return self.convtr(x)
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
# =============================================================================
|
| 42 |
+
# Weight-Normalized Conv1d (using parametrizations API to match checkpoint)
|
| 43 |
# =============================================================================
|
| 44 |
|
| 45 |
class WNConv1d(nn.Module):
|
| 46 |
+
"""Weight-normalized Conv1d using parametrizations API to match checkpoint structure."""
|
| 47 |
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
| 48 |
super().__init__()
|
| 49 |
+
conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
|
| 50 |
+
# Use parametrizations API (PyTorch 2.0+) to match checkpoint naming
|
| 51 |
+
self.conv = nn.utils.parametrizations.weight_norm(conv)
|
| 52 |
|
| 53 |
def forward(self, x):
|
| 54 |
return self.conv(x)
|
| 55 |
|
| 56 |
|
| 57 |
class WNConvTranspose1d(nn.Module):
|
| 58 |
+
"""Weight-normalized ConvTranspose1d using parametrizations API."""
|
| 59 |
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True):
|
| 60 |
super().__init__()
|
| 61 |
+
convtr = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias)
|
| 62 |
+
self.convtr = nn.utils.parametrizations.weight_norm(convtr)
|
|
|
|
| 63 |
|
| 64 |
def forward(self, x):
|
| 65 |
return self.convtr(x)
|