klemenk commited on
Commit
0ef180e
·
verified ·
1 Parent(s): 505e173

Update modeling_wavtokenizer.py

Browse files
Files changed (1) hide show
  1. modeling_wavtokenizer.py +8 -9
modeling_wavtokenizer.py CHANGED
@@ -39,28 +39,27 @@ def convert_audio(wav, sr, target_sr, target_channels=1):
39
 
40
 
41
  # =============================================================================
42
- # Weight-Normalized Conv1d (matching checkpoint's weight_g/weight_v structure)
43
  # =============================================================================
44
 
45
  class WNConv1d(nn.Module):
46
- """Weight-normalized Conv1d matching checkpoint structure with weight_g/weight_v."""
47
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
48
  super().__init__()
49
- self.conv = nn.utils.weight_norm(
50
- nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
51
- )
52
 
53
  def forward(self, x):
54
  return self.conv(x)
55
 
56
 
57
  class WNConvTranspose1d(nn.Module):
58
- """Weight-normalized ConvTranspose1d."""
59
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True):
60
  super().__init__()
61
- self.convtr = nn.utils.weight_norm(
62
- nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias)
63
- )
64
 
65
  def forward(self, x):
66
  return self.convtr(x)
 
39
 
40
 
41
  # =============================================================================
42
+ # Weight-Normalized Conv1d (using parametrizations API to match checkpoint)
43
  # =============================================================================
44
 
45
  class WNConv1d(nn.Module):
46
+ """Weight-normalized Conv1d using parametrizations API to match checkpoint structure."""
47
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
48
  super().__init__()
49
+ conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
50
+ # Use parametrizations API (PyTorch 2.0+) to match checkpoint naming
51
+ self.conv = nn.utils.parametrizations.weight_norm(conv)
52
 
53
  def forward(self, x):
54
  return self.conv(x)
55
 
56
 
57
  class WNConvTranspose1d(nn.Module):
58
+ """Weight-normalized ConvTranspose1d using parametrizations API."""
59
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True):
60
  super().__init__()
61
+ convtr = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias)
62
+ self.convtr = nn.utils.parametrizations.weight_norm(convtr)
 
63
 
64
  def forward(self, x):
65
  return self.convtr(x)