NeMo / nemo /collections /tts /modules /spectrogram_enhancer.py
camenduru's picture
thanks to NVIDIA ❤
7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# The following is largely based on code from https://github.com/lucidrains/stylegan2-pytorch
import math
from functools import partial
from math import log2
from typing import List
import torch
import torch.nn.functional as F
from einops import rearrange
from kornia.filters import filter2d
from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
class Blur(torch.nn.Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer("f", f)
def forward(self, x):
f = self.f
f = f[None, None, :] * f[None, :, None]
return filter2d(x, f, normalized=True)
class EqualLinear(torch.nn.Module):
def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn(out_dim, in_dim))
if bias:
self.bias = torch.nn.Parameter(torch.zeros(out_dim))
self.lr_mul = lr_mul
def forward(self, input):
return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
class StyleMapping(torch.nn.Module):
def __init__(self, emb, depth, lr_mul=0.1):
super().__init__()
layers = []
for _ in range(depth):
layers.extend([EqualLinear(emb, emb, lr_mul), torch.nn.LeakyReLU(0.2, inplace=True)])
self.net = torch.nn.Sequential(*layers)
def forward(self, x):
x = F.normalize(x, dim=1)
return self.net(x)
class RGBBlock(torch.nn.Module):
def __init__(self, latent_dim, input_channel, upsample, channels=3):
super().__init__()
self.input_channel = input_channel
self.to_style = torch.nn.Linear(latent_dim, input_channel)
out_filters = channels
self.conv = Conv2DModulated(input_channel, out_filters, 1, demod=False)
self.upsample = (
torch.nn.Sequential(torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), Blur(),)
if upsample
else None
)
def forward(self, x, prev_rgb, istyle):
style = self.to_style(istyle)
x = self.conv(x, style)
if prev_rgb is not None:
x = x + prev_rgb
if self.upsample is not None:
x = self.upsample(x)
return x
class Conv2DModulated(torch.nn.Module):
"""
Modulated convolution.
For details refer to [1]
[1] Karras et. al. - Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958)
"""
def __init__(
self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps=1e-8, **kwargs,
):
super().__init__()
self.filters = out_chan
self.demod = demod
self.kernel = kernel
self.stride = stride
self.dilation = dilation
self.weight = torch.nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
self.eps = eps
torch.nn.init.kaiming_normal_(self.weight, a=0, mode="fan_in", nonlinearity="leaky_relu")
def _get_same_padding(self, size, kernel, dilation, stride):
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
def forward(self, x, y):
b, c, h, w = x.shape
w1 = y[:, None, :, None, None]
w2 = self.weight[None, :, :, :, :]
weights = w2 * (w1 + 1)
if self.demod:
d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
weights = weights * d
x = x.reshape(1, -1, h, w)
_, _, *ws = weights.shape
weights = weights.reshape(b * self.filters, *ws)
padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
x = F.conv2d(x, weights, padding=padding, groups=b)
x = x.reshape(-1, self.filters, h, w)
return x
class GeneratorBlock(torch.nn.Module):
def __init__(
self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, channels=1,
):
super().__init__()
self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) if upsample else None
self.to_style1 = torch.nn.Linear(latent_dim, input_channels)
self.to_noise1 = torch.nn.Linear(1, filters)
self.conv1 = Conv2DModulated(input_channels, filters, 3)
self.to_style2 = torch.nn.Linear(latent_dim, filters)
self.to_noise2 = torch.nn.Linear(1, filters)
self.conv2 = Conv2DModulated(filters, filters, 3)
self.activation = torch.nn.LeakyReLU(0.2, inplace=True)
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, channels)
def forward(self, x, prev_rgb, istyle, inoise):
if self.upsample is not None:
x = self.upsample(x)
inoise = inoise[:, : x.shape[2], : x.shape[3], :]
noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
style1 = self.to_style1(istyle)
x = self.conv1(x, style1)
x = self.activation(x + noise1)
style2 = self.to_style2(istyle)
x = self.conv2(x, style2)
x = self.activation(x + noise2)
rgb = self.to_rgb(x, prev_rgb, istyle)
return x, rgb
class DiscriminatorBlock(torch.nn.Module):
def __init__(self, input_channels, filters, downsample=True):
super().__init__()
self.conv_res = torch.nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.net = torch.nn.Sequential(
torch.nn.Conv2d(input_channels, filters, 3, padding=1),
torch.nn.LeakyReLU(0.2, inplace=True),
torch.nn.Conv2d(filters, filters, 3, padding=1),
torch.nn.LeakyReLU(0.2, inplace=True),
)
self.downsample = (
torch.nn.Sequential(Blur(), torch.nn.Conv2d(filters, filters, 3, padding=1, stride=2))
if downsample
else None
)
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if self.downsample is not None:
x = self.downsample(x)
x = (x + res) * (1 / math.sqrt(2))
return x
class Generator(torch.nn.Module):
def __init__(
self, n_bands, latent_dim, style_depth, network_capacity=16, channels=1, fmap_max=512, start_from_zero=True
):
super().__init__()
self.image_size = n_bands
self.latent_dim = latent_dim
self.num_layers = int(log2(n_bands) - 1)
self.style_depth = style_depth
self.style_mapping = StyleMapping(self.latent_dim, self.style_depth, lr_mul=0.1)
filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]
set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))
init_channels = filters[0]
filters = [init_channels, *filters]
in_out_pairs = zip(filters[:-1], filters[1:])
self.initial_conv = torch.nn.Conv2d(filters[0], filters[0], 3, padding=1)
self.blocks = torch.nn.ModuleList([])
for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
not_first = ind != 0
not_last = ind != (self.num_layers - 1)
block = GeneratorBlock(
latent_dim, in_chan, out_chan, upsample=not_first, upsample_rgb=not_last, channels=channels,
)
self.blocks.append(block)
for m in self.modules():
if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
torch.nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in", nonlinearity="leaky_relu")
for block in self.blocks:
torch.nn.init.zeros_(block.to_noise1.weight)
torch.nn.init.zeros_(block.to_noise1.bias)
torch.nn.init.zeros_(block.to_noise2.weight)
torch.nn.init.zeros_(block.to_noise2.bias)
initial_block_size = n_bands // self.upsample_factor, 1
self.initial_block = torch.nn.Parameter(
torch.randn((1, init_channels, *initial_block_size)), requires_grad=False
)
if start_from_zero:
self.initial_block.data.zero_()
def add_scaled_condition(self, target: torch.Tensor, condition: torch.Tensor, condition_lengths: torch.Tensor):
*_, target_height, _ = target.shape
*_, height, _ = condition.shape
scale = height // target_height
# scale appropriately
condition = F.interpolate(condition, size=target.shape[-2:], mode="bilinear")
# add and mask
result = (target + condition) / 2
result = mask_sequence_tensor(result, (condition_lengths / scale).ceil().long())
return result
@property
def upsample_factor(self):
return 2 ** sum(1 for block in self.blocks if block.upsample)
def forward(self, condition: torch.Tensor, lengths: torch.Tensor, ws: List[torch.Tensor], noise: torch.Tensor):
batch_size, _, _, max_length = condition.shape
x = self.initial_block.expand(batch_size, -1, -1, max_length // self.upsample_factor)
rgb = None
x = self.initial_conv(x)
for style, block in zip(ws, self.blocks):
x, rgb = block(x, rgb, style, noise)
x = self.add_scaled_condition(x, condition, lengths)
rgb = self.add_scaled_condition(rgb, condition, lengths)
return rgb
class Discriminator(torch.nn.Module):
def __init__(
self, n_bands, network_capacity=16, channels=1, fmap_max=512,
):
super().__init__()
num_layers = int(log2(n_bands) - 1)
num_init_filters = channels
blocks = []
filters = [num_init_filters] + [(network_capacity * 4) * (2 ** i) for i in range(num_layers + 1)]
set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))
chan_in_out = list(zip(filters[:-1], filters[1:]))
blocks = []
for ind, (in_chan, out_chan) in enumerate(chan_in_out):
is_not_last = ind != (len(chan_in_out) - 1)
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
blocks.append(block)
self.blocks = torch.nn.ModuleList(blocks)
channel_last = filters[-1]
latent_dim = channel_last
self.final_conv = torch.nn.Conv2d(channel_last, channel_last, 3, padding=1)
self.to_logit = torch.nn.Linear(latent_dim, 1)
for m in self.modules():
if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
torch.nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in", nonlinearity="leaky_relu")
def forward(self, x, condition: torch.Tensor, lengths: torch.Tensor):
for block in self.blocks:
x = block(x)
scale = condition.shape[-1] // x.shape[-1]
x = mask_sequence_tensor(x, (lengths / scale).ceil().long())
x = self.final_conv(x)
scale = condition.shape[-1] // x.shape[-1]
x = mask_sequence_tensor(x, (lengths / scale).ceil().long())
x = x.mean(axis=-2)
x = (x / rearrange(lengths / scale, "b -> b 1 1")).sum(axis=-1)
x = self.to_logit(x)
return x.squeeze()