# RE-USE / models /codec_module_time_d4.py
# szuweifu's picture
# Upload 11 files
# d78f08c verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
def get_padding_2d(kernel_size, dilation=(1, 1)):
    """
    Compute the (height, width) padding for a dilated 2D convolution.

    Each axis is padded by ``(kernel * dilation - dilation) / 2`` truncated to
    an integer. For odd kernel sizes at stride 1 this keeps the spatial size
    of the convolution output equal to its input.
    Args:
    - kernel_size (tuple): Kernel size as (height, width).
    - dilation (tuple, optional): Dilation rate as (height, width). Defaults to (1, 1).
    Returns:
    - tuple: Padding as (height, width).
    """
    def _axis_padding(kernel, rate):
        # Half of the dilated kernel extent beyond its centre tap.
        return int((kernel * rate - rate) / 2)

    pad_h = _axis_padding(kernel_size[0], dilation[0])
    pad_w = _axis_padding(kernel_size[1], dilation[1])
    return (pad_h, pad_w)
class SPConvTranspose2d(nn.Module):
    """
    Sub-pixel style upsampling along the last axis of a 4D tensor.

    A stride-1 convolution produces ``out_channels * r`` channels; these are
    split into ``r`` groups that are interleaved along the last dimension, so
    the last dimension of the output is ``r`` times the convolution's width.
    Args:
    - in_channels (int): Number of input channels.
    - out_channels (int): Number of output channels after regrouping.
    - kernel_size (tuple): Kernel size of the internal convolution.
    - r (int, optional): Upsampling factor for the last axis. Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        super().__init__()
        # One zero column on each side of the last axis; with a width-3 kernel
        # (as used by the callers in this file) the conv keeps that axis length.
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * r,
            kernel_size=kernel_size,
            stride=(1, 1),
        )
        self.r = r

    def forward(self, x):
        feats = self.conv(self.pad1(x))
        batch, total_channels, height, width = feats.shape
        channels = total_channels // self.r

        # [B, r*C, H, W] -> [B, r, C, H, W] -> [B, C, H, W, r]
        grouped = feats.view((batch, self.r, channels, height, width))
        interleaved = grouped.permute(0, 2, 3, 4, 1).contiguous()

        # Merge the trailing (W, r) axes: [B, C, H, W * r]
        return interleaved.view((batch, channels, height, -1))
class DenseBlock(nn.Module):
    """
    Densely connected stack of dilated convolutions.

    Layer ``i`` takes the channel-wise concatenation of the block input and
    all previous layer outputs (``hid_feature * (i + 1)`` channels) and emits
    ``hid_feature`` channels. The dilation along dim 2 doubles at every layer
    (1, 2, 4, ...), while dim 3 is never dilated.
    Args:
    - cfg (dict): Configuration; ``cfg['model_cfg']['hid_feature']`` is read.
    - kernel_size (tuple, optional): Kernel size of each convolution. Defaults to (3, 3).
    - depth (int, optional): Number of dense layers. Defaults to 4.
    """

    def __init__(self, cfg, kernel_size=(3, 3), depth=4):
        super().__init__()
        self.cfg = cfg
        self.depth = depth
        self.dense_block = nn.ModuleList()
        self.hid_feature = cfg['model_cfg']['hid_feature']

        channels = self.hid_feature
        for layer_idx in range(depth):
            dilation = (2 ** layer_idx, 1)
            conv = nn.Conv2d(
                channels * (layer_idx + 1),
                channels,
                kernel_size,
                dilation=dilation,
                padding=get_padding_2d(kernel_size, dilation),
            )
            norm = nn.InstanceNorm2d(channels, affine=True)
            activation = nn.PReLU(channels)
            self.dense_block.append(nn.Sequential(conv, norm, activation))

    def forward(self, x):
        # `stacked` accumulates every feature map produced so far (newest
        # first); only the last layer's output is returned.
        stacked = x
        latest = x
        for layer_idx in range(self.depth):
            latest = self.dense_block[layer_idx](stacked)
            stacked = torch.cat([latest, stacked], dim=1)
        return latest
class DenseEncoder(nn.Module):
    """
    Encoder made of a 1x1 input convolution, a DenseBlock, and a strided
    convolution that downsamples both trailing axes.

    Args:
    - cfg (dict): Configuration; ``cfg['model_cfg']['input_channel']`` and
      ``cfg['model_cfg']['hid_feature']`` are read.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        model_cfg = cfg['model_cfg']
        self.input_channel = model_cfg['input_channel']
        self.hid_feature = model_cfg['hid_feature']
        hidden = self.hid_feature

        # Pointwise projection from the input channels to the hidden width.
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(self.input_channel, hidden, (1, 1)),
            nn.InstanceNorm2d(hidden, affine=True),
            nn.PReLU(hidden),
        )

        self.dense_block = DenseBlock(cfg, depth=4)

        # Unpadded (1, 3) kernel with stride (4, 2): dim 2 is subsampled by 4
        # and dim 3 by roughly 2.
        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(hidden, hidden, (1, 3), stride=(4, 2)),
            nn.InstanceNorm2d(hidden, affine=True),
            nn.PReLU(hidden),
        )

    def forward(self, x):
        # [batch, input_channel, time, freq] -> [batch, hid_feature, time, freq]
        projected = self.dense_conv_1(x)
        # [batch, hid_feature, time, freq]
        dense_out = self.dense_block(projected)
        # [batch, hid_feature, (time - 1) // 4 + 1, (freq - 3) // 2 + 1]
        return self.dense_conv_2(dense_out)
class MagDecoder(nn.Module):
    """
    Decoder head for magnitude: DenseBlock, 2x upsampling of the last axis,
    4x upsampling of dim 2, then a 1x1 projection to ``output_channel``.

    Args:
    - cfg (dict): Configuration; reads ``hid_feature``, ``output_channel`` and
      ``beta`` from ``cfg['model_cfg']`` and ``n_fft`` from ``cfg['stft_cfg']``.
      ``n_fft`` and ``beta`` are stored on the module but not used by
      ``forward``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        self.n_fft = cfg['stft_cfg']['n_fft']
        self.beta = cfg['model_cfg']['beta']
        hidden = self.hid_feature

        # Upsamples the last axis by 2.
        self.up_conv1 = nn.Sequential(
            SPConvTranspose2d(hidden, hidden, (1, 3), 2),
            nn.InstanceNorm2d(hidden, affine=True),
            nn.PReLU(hidden),
        )
        # Upsamples by 4; applied to a transposed tensor in forward() so that
        # it acts on dim 2 of the original layout.
        self.up_conv2 = nn.Sequential(
            SPConvTranspose2d(hidden, hidden, (1, 3), 4),
            nn.InstanceNorm2d(hidden, affine=True),
            nn.PReLU(hidden),
        )
        self.final_conv = nn.Conv2d(hidden, self.output_channel, (1, 1))

    def forward(self, x):
        feats = self.dense_block(x)
        feats = self.up_conv1(feats)
        # Swap the two trailing axes so the last-axis upsampler works on
        # dim 2, then restore the original axis order.
        swapped = feats.transpose(2, 3)
        feats = self.up_conv2(swapped).transpose(2, 3)
        return self.final_conv(feats)
class PhaseDecoder(nn.Module):
    """
    Decoder head for phase: DenseBlock, 2x upsampling of the last axis,
    4x upsampling of dim 2, then two 1x1 convolutions whose outputs are
    combined with ``atan2`` into an angle.

    Args:
    - cfg (dict): Configuration; reads ``hid_feature`` and ``output_channel``
      from ``cfg['model_cfg']``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        hidden = self.hid_feature

        # Upsamples the last axis by 2.
        self.up_conv1 = nn.Sequential(
            SPConvTranspose2d(hidden, hidden, (1, 3), 2),
            nn.InstanceNorm2d(hidden, affine=True),
            nn.PReLU(hidden),
        )
        # Upsamples by 4; applied to a transposed tensor in forward() so that
        # it acts on dim 2 of the original layout.
        self.up_conv2 = nn.Sequential(
            SPConvTranspose2d(hidden, hidden, (1, 3), 4),
            nn.InstanceNorm2d(hidden, affine=True),
            nn.PReLU(hidden),
        )
        # Separate pointwise heads for the "real" and "imaginary" components.
        self.phase_conv_r = nn.Conv2d(hidden, self.output_channel, (1, 1))
        self.phase_conv_i = nn.Conv2d(hidden, self.output_channel, (1, 1))

    def forward(self, x):
        feats = self.dense_block(x)
        feats = self.up_conv1(feats)
        # Swap the two trailing axes so the last-axis upsampler works on
        # dim 2, then restore the original axis order.
        swapped = feats.transpose(2, 3)
        feats = self.up_conv2(swapped).transpose(2, 3)

        real_part = self.phase_conv_r(feats)
        imag_part = self.phase_conv_i(feats)
        # Element-wise angle of (real, imag), in radians.
        return torch.atan2(imag_part, real_part)