# Listing metadata from the file viewer: file size 4,765 bytes, revision 8a06b33.
import torch
from .commons import fused_add_tanh_sigmoid_multiply
from packaging import version
# True when the installed PyTorch is 2.1.0 or newer. WaveNet uses this flag to
# pick between torch.nn.utils.parametrizations.weight_norm and the legacy
# torch.nn.utils.weight_norm when wrapping its convolutions.
is_pytorch2_1 = version.parse("2.1.0") <= version.parse(torch.__version__)
class WaveNet(torch.nn.Module):
    """Stack of weight-normalized, dilated 1-D convolutions with gated units.

    Layer ``i`` applies a ``Conv1d`` with dilation ``dilation_rate**i`` that
    maps ``hidden_channels`` to ``2 * hidden_channels`` channels, combines the
    result with an optional conditioning slice through
    ``fused_add_tanh_sigmoid_multiply``, applies dropout, and then a 1x1
    ``Conv1d`` that produces a residual part (added back to the running
    input) and a skip part (accumulated into the output). The last layer only
    produces the skip part.

    Args:
        hidden_channels: Channel count of the input, the residual path and
            the returned tensor.
        kernel_size: Kernel size of the dilated convolutions; must be odd so
            that the symmetric padding keeps the time length unchanged.
        dilation_rate: Base of the per-layer dilation ``dilation_rate**i``.
        n_layers: Number of gated layers.
        gin_channels: Channel count of the conditioning input ``g``. When it
            is 0 no conditioning layer is created.
        p_dropout: Dropout probability applied to the gated activations.
    """

    def __init__(
        self,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate,
        n_layers: int,
        gin_channels: int = 0,
        p_dropout: float = 0,
    ):
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size must be odd for proper padding."

        self.hidden_channels = hidden_channels
        # NOTE(review): stored as a 1-tuple (trailing comma), not an int. Kept
        # unchanged because external code may read this attribute.
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        # Plain tensor attribute (not a registered buffer), so it is not part
        # of the state dict and is not moved by ``Module.to``.
        self.n_channels_tensor = torch.IntTensor([hidden_channels])

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = torch.nn.Dropout(p_dropout)

        if gin_channels:
            # One 1x1 convolution yields the conditioning for every layer at
            # once; forward() slices out 2 * hidden_channels per layer.
            self.cond_layer = self._weight_norm(
                torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
            )

        last_layer = n_layers - 1
        for i in range(n_layers):
            dilation = dilation_rate**i
            # Equals dilation * (kernel_size - 1) / 2 for odd kernel sizes,
            # i.e. the padding that preserves the time length.
            padding = (kernel_size * dilation - dilation) // 2
            self.in_layers.append(
                self._weight_norm(
                    torch.nn.Conv1d(
                        hidden_channels,
                        2 * hidden_channels,
                        kernel_size,
                        dilation=dilation,
                        padding=padding,
                    )
                )
            )

            # The last layer has no residual branch, only the skip output.
            res_skip_channels = (
                hidden_channels if i == last_layer else 2 * hidden_channels
            )
            self.res_skip_layers.append(
                self._weight_norm(
                    torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
                )
            )

    @staticmethod
    def _weight_norm(module):
        """Apply weight normalization to ``module.weight`` and return the module.

        Uses the parametrization-based API when ``is_pytorch2_1`` is true and
        the legacy hook-based ``torch.nn.utils.weight_norm`` otherwise.
        """
        if is_pytorch2_1:
            return torch.nn.utils.parametrizations.weight_norm(module, name="weight")
        return torch.nn.utils.weight_norm(module, name="weight")

    def forward(self, x, x_mask, g=None):
        """Run the gated convolution stack.

        Args:
            x: Input tensor with ``hidden_channels`` channels on dim 1 and
                time on dim 2.
            x_mask: Mask multiplied into the residual path after every layer
                and into the final output (assumed broadcastable against
                ``x`` -- confirm against callers).
            g: Optional conditioning tensor with ``gin_channels`` channels.
                Passing it requires the module to have been built with a
                non-zero ``gin_channels``.

        Returns:
            The accumulated skip outputs multiplied by ``x_mask``; same
            channel count as ``x``.
        """
        output = torch.zeros_like(x)
        if g is not None:
            g = self.cond_layer(g)

        gate_channels = 2 * self.hidden_channels
        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Conditioning slice that belongs to layer i.
                g_l = g[:, i * gate_channels : (i + 1) * gate_channels, :]
            else:
                g_l = 0

            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, self.n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half feeds the residual path, second half the output.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalization from every convolution in the module.

        Layers wrapped through ``torch.nn.utils.parametrizations.weight_norm``
        carry a parametrization rather than the legacy forward pre-hook, and
        ``torch.nn.utils.remove_weight_norm`` raises ``ValueError`` for them.
        Each layer is therefore inspected and the matching removal routine is
        used; in both cases the layer keeps its current effective weight.
        """
        # Local import keeps the module's top-level import block untouched.
        from torch.nn.utils import parametrize

        layers = []
        if self.gin_channels:
            layers.append(self.cond_layer)
        layers.extend(self.in_layers)
        layers.extend(self.res_skip_layers)

        for layer in layers:
            if parametrize.is_parametrized(layer, "weight"):
                parametrize.remove_parametrizations(
                    layer, "weight", leave_parametrized=True
                )
            else:
                torch.nn.utils.remove_weight_norm(layer)
|