# AudioGAN / generator.py
# Author: SeaSky1027 — commit 8e60cc8 ("Add CLAP & HiFiGAN")
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
class SpectralNorm:
    """Forward-pre-hook implementing spectral normalization of a weight.

    On every forward pass the hook runs one step of power iteration to
    estimate the largest singular value (sigma) of the flattened weight
    ``<name>_orig`` and rebinds ``module.<name>`` to ``weight / sigma``.
    The running left-singular-vector estimate is kept in ``<name>_u`` and
    the last sigma in ``<name>_sv``.
    """

    def __init__(self, name):
        # Name of the weight attribute to normalize (e.g. 'weight').
        self.name = name

    def compute_weight(self, module):
        """Run one power-iteration step; return (weight / sigma, u, sigma)."""
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        size = weight.size()
        # Flatten to 2-D: rows = dim 0 (output channels), cols = the rest.
        weight_mat = weight.contiguous().view(size[0], -1)
        # Power iteration runs under no_grad: u and v are treated as
        # constants so gradients only flow through `weight` below.
        with torch.no_grad():
            v = weight_mat.t() @ u
            v = v / v.norm()
            u = weight_mat @ v
            u = u / u.norm()
        # sigma ≈ largest singular value; computed OUTSIDE no_grad so the
        # division below is differentiated w.r.t. the original weight.
        sigma = u @ weight_mat @ v
        weight_sn = weight / sigma
        return weight_sn, u, sigma

    @staticmethod
    def apply(module, name):
        """Rewire `module` so `name` is spectrally normalized each forward.

        Moves the original parameter to ``<name>_orig``, registers buffers
        for the normalized weight, the power-iteration vector ``<name>_u``
        and the sigma estimate ``<name>_sv``, then installs the hook.
        """
        fn = SpectralNorm(name)
        weight = getattr(module, name)
        # Remove the plain parameter so `name` can be re-bound as a buffer.
        del module._parameters[name]
        module.register_parameter(name + '_orig', weight)
        input_size = weight.size(0)
        # Random initial estimate of the left singular vector.
        u = weight.new_empty(input_size).normal_()
        module.register_buffer(name, weight)
        module.register_buffer(name + '_u', u)
        # Scalar buffer holding the latest singular-value estimate.
        module.register_buffer(name + '_sv', torch.ones(1).squeeze())
        module.register_forward_pre_hook(fn)
        return fn

    def __call__(self, module, input):
        # Invoked by PyTorch before each forward: refresh the normalized
        # weight and the power-iteration state on the module.
        weight_sn, u, sigma = self.compute_weight(module)
        setattr(module, self.name, weight_sn)
        setattr(module, self.name + '_u', u)
        setattr(module, self.name + '_sv', sigma)
def spectral_norm(module, name='weight'):
    """Attach spectral normalization to `module`'s `name` parameter.

    Installs a :class:`SpectralNorm` forward-pre-hook and returns the
    (modified in place) module so calls can be chained.
    """
    SpectralNorm.apply(module, name)
    return module
def spectral_init(module, gain=1):
    """Xavier-initialize `module`'s weight, zero its bias (if any),
    then wrap it with spectral normalization."""
    init.xavier_uniform_(module.weight, gain)
    bias = module.bias
    if bias is not None:
        bias.data.zero_()
    return spectral_norm(module)
class ConditionalNorm(nn.Module):
def __init__(self, in_channel, condition_dim):
super().__init__()
self.bn = nn.BatchNorm2d(in_channel, affine=False)
self.linear1 = nn.Linear(condition_dim, in_channel)
self.linear2 = nn.Linear(condition_dim, in_channel)
def forward(self, input, condition):
out = self.bn(input)
gamma, beta = self.linear1(condition), self.linear2(condition)
gamma = gamma.unsqueeze(2).unsqueeze(3)
beta = beta.unsqueeze(2).unsqueeze(3)
out = gamma * out + beta
return out
class ConvBlock(nn.Module):
    """Residual conv block with optional conditional BN and up/downsampling.

    Main path: [CBN ->] act [-> 2x upsample] -> conv1 [-> CBN] -> act
    -> conv2 [-> 2x avg-pool]. The skip path mirrors the resampling and
    uses a 1x1 projection whenever channels or resolution change.
    """

    def __init__(self, in_channel, out_channel, kernel_size=[3, 3],
                 padding=1, stride=1, condition_dim=None, bn=True,
                 activation=F.relu, upsample=True, downsample=False):
        super().__init__()
        gain = 2 ** 0.5
        use_bias = not bn  # conv bias is redundant when followed by BN
        self.conv1 = spectral_init(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride,
                      padding, bias=use_bias),
            gain=gain)
        self.conv2 = spectral_init(
            nn.Conv2d(out_channel, out_channel, kernel_size, stride,
                      padding, bias=use_bias),
            gain=gain)
        # A 1x1 projection is needed whenever the residual branch changes
        # channel count or spatial resolution.
        self.skip_proj = (in_channel != out_channel) or upsample or downsample
        if self.skip_proj:
            self.conv_skip = spectral_init(
                nn.Conv2d(in_channel, out_channel, 1, 1, 0))
        self.upsample = upsample
        self.downsample = downsample
        self.activation = activation
        self.bn = bn
        if bn:
            self.norm1 = ConditionalNorm(in_channel, condition_dim)
            self.norm2 = ConditionalNorm(out_channel, condition_dim)

    def forward(self, input, condition=None, condition1=None):
        out = input
        if self.bn:
            out = self.norm1(out, condition)
        out = self.activation(out)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        if self.bn:
            out = self.norm2(out, condition)
        out = self.activation(out)
        out = self.conv2(out)
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        # Identity skip when shape is unchanged; otherwise resample + 1x1.
        if not self.skip_proj:
            return out + input
        skip = input
        if self.upsample:
            skip = F.interpolate(skip, scale_factor=2, mode='nearest')
        skip = self.conv_skip(skip)
        if self.downsample:
            skip = F.avg_pool2d(skip, 2)
        return out + skip
class SelfAttention(nn.Module):
    """SAGAN-style self-attention over the flattened freq*time positions.

    Output is a gated residual: ``gamma * attended + input`` with gamma
    learned and initialized to 0 so the block starts as identity.
    """

    def __init__(self, in_channel, embed_dim, gain=2 ** 0.5):
        super().__init__()
        self.query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                   gain=gain)
        self.key = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                 gain=gain)
        self.value = spectral_init(nn.Conv1d(in_channel, in_channel, 1),
                                   gain=gain)
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input):  # [bsz, channel, freq, time]
        shape = input.shape
        flat = input.view(shape[0], shape[1], -1)   # [bsz, C, N] with N = freq*time
        q = self.query(flat).permute(0, 2, 1)       # [bsz, N, embed]
        k = self.key(flat)                          # [bsz, embed, N]
        v = self.value(flat)                        # [bsz, C, N]
        energy = torch.bmm(q, k)                    # [bsz, N, N]
        # Softmax over dim 1, paired with bmm(value, map) as in SAGAN.
        attention_map = F.softmax(energy, 1)
        out = torch.bmm(v, attention_map).view(*shape)
        return (self.gamma * out + input, attention_map)
class CrossAttention(nn.Module):
    """Cross-attention from an input (spectrogram or sentence vector) onto a
    condition (sentence vector or word embeddings).

    Output is a gated residual ``gamma * attended + input`` with gamma
    learned and initialized to 0, so the block starts as identity.

    Fixes vs. the previous version:
      * The attended output was permuted twice (the permutes cancelled), so
        the final ``reshape(input_shape)`` folded a [bsz, num, C] tensor
        into [bsz, C, w, h], scrambling channel/position assignment. Now
        reshaped from the correctly laid-out [bsz, C, num] tensor.
      * The unconditional ``.squeeze()`` could drop a size-1 batch dim;
        ``reshape(input_shape)`` already restores the exact input shape.
      * The per-sample Python loop building the padding mask is replaced by
        the vectorized form (same semantics, matches Spec_Attention).
    """

    def __init__(self, in_channel, cond_channel, embed_dim, gain=2 ** 0.5):
        super().__init__()
        self.key = spectral_init(nn.Conv1d(cond_channel, embed_dim, 1),
                                 gain=gain)
        self.value = spectral_init(nn.Conv1d(cond_channel, in_channel, 1),
                                   gain=gain)
        self.query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                   gain=gain)
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input, condition, sequence_lengths=None):
        # input : mel [bsz, channel, freq, time] or sentence [bsz, channel]
        # condition : sentence [bsz, channel] or word [bsz, word_num, channel]
        input_shape = input.shape
        if input.dim() == 4:  # mel [bsz, channel, freq, time]
            batch_size, c, w, h = input.shape
            num = w * h
            x = input.reshape([batch_size, c, num])   # [bsz, C, num]
        elif input.dim() == 2:  # sentence [bsz, channel]
            batch_size, c = input.shape
            num = 1
            x = input.unsqueeze(2)                    # [bsz, C, 1]
        if condition.dim() == 2:  # sentence [bsz, channel]
            condition = condition.unsqueeze(2)        # [bsz, C_cond, 1]
        else:  # word [bsz, word_num, channel]
            condition = condition.permute(0, 2, 1)    # [bsz, C_cond, cond_num]
        query = self.query(x).permute(0, 2, 1)        # [bsz, num, embed]
        key = self.key(condition)                     # [bsz, embed, cond_num]
        value = self.value(condition).permute(0, 2, 1)  # [bsz, cond_num, C]
        attention_map = torch.bmm(query, key)         # [bsz, num, cond_num]
        if sequence_lengths is not None:  # condition is word embedding
            total_len = condition.shape[2]
            # Vectorized padding mask: True where position >= sample length.
            lengths = torch.as_tensor(sequence_lengths,
                                      device=condition.device)
            mask = torch.arange(total_len,
                                device=condition.device)[None, None, :]
            mask = mask >= lengths[:, None, None]     # [bsz, 1, cond_num]
            # Large negative bias so masked positions vanish after softmax.
            attention_map = attention_map + mask * (-1e9)
        attention_map = F.softmax(attention_map, dim=-1)  # [bsz, num, cond_num]
        out = torch.bmm(attention_map, value).permute(0, 2, 1)  # [bsz, C, num]
        out = out.reshape(input_shape)
        out = self.gamma * out + input
        return out, attention_map
class Spec_Attention(nn.Module):
    """Axis-factorized spectrogram attention.

    Pools the input over time and over frequency to get per-axis summaries,
    builds a freq-map and a time-map against the condition (or against the
    other axis when self-attending), and combines them into a single weight
    map applied to the values. Output is the gated residual
    ``gamma * attended + input`` with gamma learned, initialized to 0.
    """

    def __init__(self, in_channel, cond_channel=None, embed_dim=64, gain=2 ** 0.5):
        super().__init__()
        # Self-attention case: the condition comes from the input itself.
        if cond_channel is None:
            cond_channel = in_channel
        self.f_query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                     gain=gain)
        self.t_key = spectral_init(nn.Conv1d(cond_channel, embed_dim, 1),
                                   gain=gain)
        self.t_query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                     gain=gain)
        self.f_key = spectral_init(nn.Conv1d(cond_channel, embed_dim, 1),
                                   gain=gain)
        self.value = spectral_init(nn.Conv1d(cond_channel, in_channel, 1),
                                   gain=gain)
        # Learned residual gate, 0 at init so the block starts as identity.
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input, condition=None, sequence_lengths=None):
        # input : mel [bsz, channel, freq, time]
        # condition : sentence [bsz, channel] or word [bsz, word_num, channel]
        batch_size, c, f, t = input.shape
        # Axis-pooled summaries of the spectrogram.
        freq_embedding = input.mean(dim=3)  # [bsz, channel, freq]
        time_embedding = input.mean(dim=2)  # [bsz, channel, time]
        if condition is not None:
            if len(condition.shape) == 2:  # sentence [bsz, channel]
                condition = condition.unsqueeze(2)  # [bsz, channel, 1]
            else:  # word [bsz, word_num, channel]
                condition = condition.permute(0, 2, 1)  # [bsz, channel, cond_num]
            t_condition = condition
            f_condition = condition
        else:
            # Self-attention: each axis attends to the other axis' summary.
            t_condition = time_embedding
            f_condition = freq_embedding
        # Frequency queries vs. time/condition keys.
        f_query = self.f_query(freq_embedding).permute(0, 2, 1)  # [bsz, freq, channel]
        t_key = self.t_key(t_condition)  # [bsz, channel, time] or [bsz, channel, cond_num]
        freq_cond_map = torch.bmm(f_query, t_key)  # [bsz, freq, time] or [bsz, freq, cond_num]
        # Time queries vs. frequency/condition keys.
        t_query = self.t_query(time_embedding).permute(0, 2, 1)  # [bsz, time, channel]
        f_key = self.f_key(f_condition)  # [bsz, channel, freq] or [bsz, channel, cond_num]
        time_cond_map = torch.bmm(t_query, f_key)  # [bsz, time, freq] or [bsz, time, cond_num]
        if sequence_lengths is not None:  # condition is word embedding
            total_len = condition.shape[2]
            # Padding mask: True where word index >= this sample's length.
            mask = torch.arange(total_len, device=condition.device)[None, None, :]
            mask = mask >= sequence_lengths[:, None, None]
            # Large negative bias so padded words vanish after softmax.
            freq_cond_map = freq_cond_map + mask * (-1e9)
            time_cond_map = time_cond_map + mask * (-1e9)
        freq_cond_map = F.softmax(freq_cond_map, dim=-1)  # [bsz, freq, time] or [bsz, freq, cond_num]
        time_cond_map = F.softmax(time_cond_map, dim=-1)  # [bsz, time, freq] or [bsz, time, cond_num]
        if condition is None:
            # Self-attention: combine both maps into one per-position weight.
            freq_time_embedding = input.reshape([batch_size, c, f*t])  # [bsz, channel, freq*time]
            weight_map = torch.add(freq_cond_map, time_cond_map.permute(0, 2, 1)).reshape([batch_size, f*t]).unsqueeze(-1)  # [bsz, freq*time, 1]
            value = self.value(freq_time_embedding).permute(0, 2, 1)  # [bsz, freq*time, channel]
            out = torch.mul(value, weight_map).permute(0, 2, 1).reshape(batch_size, c, f, t)  # [bsz, channel, freq, time]
        else:
            # Conditioned: broadcast each axis map over the other axis, then
            # attend each (freq, time) cell over the condition tokens.
            freq_cond_map = torch.tile(freq_cond_map.unsqueeze(2), [1, 1, t, 1])  # [bsz, freq, time, cond_num]
            time_cond_map = torch.tile(time_cond_map.unsqueeze(1), [1, f, 1, 1])  # [bsz, freq, time, cond_num]
            weight_map = torch.add(freq_cond_map, time_cond_map).reshape([batch_size, f*t, -1])  # [bsz, freq*time, cond_num]
            value = self.value(condition).permute(0, 2, 1)  # [bsz, cond_num, channel]
            out = torch.bmm(weight_map, value).permute(0, 2, 1).reshape(batch_size, c, f, t)  # [bsz, channel, freq, time]
        out = self.gamma * out + input
        return out, weight_map
class Multi_Triple_Attention(nn.Module):
    """Multi-head block mixing self-, word- and sentence-attention.

    Each head runs the attention types named in `attention_list`
    ("self" / "word" / "sentence"), concatenates their outputs along the
    channel axis, fuses them with a 1x1 conv, and adds a gated residual.
    The per-head results are concatenated and fused again by `self.out`,
    with a final outer residual onto the input.

    Fix vs. the previous version: the per-head gammas were stored in a
    plain Python list, so they were never registered as module parameters
    (excluded from `parameters()`/the optimizer, not moved by `.to()`,
    absent from `state_dict`). They now live in an `nn.ParameterList`.
    """

    def __init__(self, in_channel, sentence_embed_dim=768, word_embed_dim=768,
                 embed_dim=64, reverse=False, gain=2 ** 0.5, n_heads=2,
                 attention_list="self,word,sentence", spec_attention=False):
        super().__init__()
        self.reverse = reverse
        self.n_heads = n_heads
        self.attention_list = attention_list.split(",")
        # `spec_attention` swaps the attention implementation per type.
        if "self" in self.attention_list:
            if spec_attention:
                self.self_attention_modules = nn.ModuleList(
                    [Spec_Attention(in_channel, embed_dim=embed_dim)
                     for _ in range(self.n_heads)])
            else:
                self.self_attention_modules = nn.ModuleList(
                    [SelfAttention(in_channel, embed_dim=embed_dim)
                     for _ in range(self.n_heads)])
        if "word" in self.attention_list:
            if spec_attention:
                self.cross_attention_for_word_modules = nn.ModuleList(
                    [Spec_Attention(in_channel, cond_channel=word_embed_dim,
                                    embed_dim=embed_dim)
                     for _ in range(self.n_heads)])
            else:
                self.cross_attention_for_word_modules = nn.ModuleList(
                    [CrossAttention(in_channel, cond_channel=word_embed_dim,
                                    embed_dim=embed_dim)
                     for _ in range(self.n_heads)])
        if "sentence" in self.attention_list:
            if spec_attention:
                self.cross_attention_for_sent_modules = nn.ModuleList(
                    [Spec_Attention(in_channel, cond_channel=sentence_embed_dim,
                                    embed_dim=embed_dim)
                     for _ in range(self.n_heads)])
            else:
                self.cross_attention_for_sent_modules = nn.ModuleList(
                    [CrossAttention(in_channel, cond_channel=sentence_embed_dim,
                                    embed_dim=embed_dim)
                     for _ in range(self.n_heads)])
        # Per-head residual gates, 0 at init (identity start). ParameterList
        # ensures proper registration; indexing behaves like the old list.
        self.gamma = nn.ParameterList(
            [nn.Parameter(torch.tensor(0.0)) for _ in range(self.n_heads)])
        # Fuses the concatenated per-type outputs back to `in_channel`.
        self.conv_for_attention = spectral_init(
            nn.Conv1d(in_channel * len(self.attention_list), in_channel, 1),
            gain=gain)
        # Fuses the concatenated per-head outputs back to `in_channel`.
        self.out = spectral_init(
            nn.Conv1d(in_channel * self.n_heads, in_channel, 1), gain=gain)

    def forward(self, input, sentence_embedding, word_embedding, sequence_lengths):
        batch_size, c, f, t = input.shape
        x = input
        result = []
        for head in range(self.n_heads):
            out_list = []
            if "self" in self.attention_list:
                x_self, attention_map = self.self_attention_modules[head](x)
                out_list.append(x_self)
            if "word" in self.attention_list:
                x_word, attention_map = self.cross_attention_for_word_modules[head](
                    x, word_embedding, sequence_lengths)
                out_list.append(x_word)
            if "sentence" in self.attention_list:
                x_sent, attention_map = self.cross_attention_for_sent_modules[head](
                    x, sentence_embedding)
                out_list.append(x_sent)
            # Concatenate the attention outputs channel-wise and fuse.
            out = torch.cat(out_list, dim=1)
            out = self.conv_for_attention(
                out.reshape([batch_size, c * len(out_list), f * t])
            ).reshape([batch_size, c, f, t])
            # Gated residual per head.
            out = self.gamma[head] * out + x
            result.append(out)
        x = torch.cat(result, dim=1)
        x = self.out(
            x.reshape([batch_size, c * self.n_heads, f * t])
        ).reshape([batch_size, c, f, t])
        x = input + x
        return x
class Generator(nn.Module):
    """Text-conditioned spectrogram generator.

    Maps a noise vector plus sentence/word text embeddings to a
    [bsz, 1, 64, 1024] mel-spectrogram-shaped output, via a stack of
    conditional residual ConvBlocks interleaved with Multi_Triple_Attention
    blocks.
    """

    def __init__(self, model_config=None):
        super().__init__()
        if model_config is None:
            # NOTE(review): 'g_chaneel' is a (misspelled) external config
            # key — callers' configs use this spelling, so it must stay.
            model_config = {
                "noise_dim": 128,
                "g_chaneel": 128,
                "n_heads": 10,
                "sentence_embed_dim": 512,
                "word_embed_dim": 768,
                "attention_list": ["self,word,sentence", "word,sentence", "sentence"],
                "spec_attention": True,
            }
        self.noise_dim = model_config['noise_dim']
        self.channel = model_config['g_chaneel']
        self.n_heads = model_config['n_heads']
        self.sentence_embed_dim = model_config['sentence_embed_dim']
        self.word_embed_dim = model_config['word_embed_dim']
        # One attention-type spec per attention stage (3 stages below).
        self.attention_list = model_config['attention_list']
        self.spec_attention = model_config['spec_attention']
        # Per-stage channel widths, narrowing as resolution grows.
        channel_list = [self.channel, self.channel, self.channel//2, self.channel//2, self.channel//4, self.channel//4, self.channel//4, self.channel//8, self.channel//8]
        # Projects the noise vector to the initial 2x32 feature map.
        self.lin_code = spectral_init(nn.Linear(self.noise_dim, channel_list[0] * 2 * 32))
        self.conv1 = ConvBlock(channel_list[0], channel_list[1], condition_dim=self.sentence_embed_dim)
        self.conv2 = ConvBlock(channel_list[1], channel_list[2], condition_dim=self.sentence_embed_dim)
        self.multi_triple_attention_1 = Multi_Triple_Attention(channel_list[2],
                                                               sentence_embed_dim=self.sentence_embed_dim,
                                                               word_embed_dim=self.word_embed_dim,
                                                               embed_dim=channel_list[2],
                                                               reverse=False,
                                                               n_heads=self.n_heads,
                                                               attention_list=self.attention_list[0],
                                                               spec_attention=self.spec_attention)
        self.conv3 = ConvBlock(channel_list[2], channel_list[3], condition_dim=self.sentence_embed_dim)
        # upsample=False keeps resolution while changing channels.
        self.conv4 = ConvBlock(channel_list[3], channel_list[4], condition_dim=self.sentence_embed_dim, upsample=False)
        self.multi_triple_attention_2 = Multi_Triple_Attention(channel_list[4],
                                                               sentence_embed_dim=self.sentence_embed_dim,
                                                               word_embed_dim=self.word_embed_dim,
                                                               embed_dim=channel_list[4],
                                                               reverse=False,
                                                               n_heads=self.n_heads,
                                                               attention_list=self.attention_list[1],
                                                               spec_attention=self.spec_attention)
        self.conv5 = ConvBlock(channel_list[4], channel_list[5], condition_dim=self.sentence_embed_dim)
        self.conv6 = ConvBlock(channel_list[5], channel_list[6], condition_dim=self.sentence_embed_dim, upsample=False)
        self.multi_triple_attention_3 = Multi_Triple_Attention(channel_list[6],
                                                               sentence_embed_dim=self.sentence_embed_dim,
                                                               word_embed_dim=self.word_embed_dim,
                                                               embed_dim=channel_list[6],
                                                               reverse=False,
                                                               n_heads=self.n_heads,
                                                               attention_list=self.attention_list[2],
                                                               spec_attention=self.spec_attention)
        self.conv7 = ConvBlock(channel_list[6], channel_list[7], condition_dim=self.sentence_embed_dim)
        self.bn = nn.BatchNorm2d(channel_list[8])
        # 1x1 conv collapsing channels to a single spectrogram channel.
        self.colorize = spectral_init(nn.Conv1d(channel_list[8], 1, 1))

    def forward(self, z, sentence_embedding, word_embedding, sequence_lengths):
        """Generate a spectrogram from noise `z` conditioned on text.

        z: [bsz, noise_dim]; sentence_embedding: [bsz, sentence_embed_dim];
        word_embedding / sequence_lengths are forwarded to the attention
        blocks. Returns [bsz, 1, 64, 1024].
        """
        batch_size = z.shape[0]
        x = self.lin_code(z)
        x = x.view(-1, self.channel, 2, 32)  # [bsz, C, 2, 32]
        # Shape comments below use C = g_chaneel (128 by default); ConvBlocks
        # upsample 2x unless constructed with upsample=False.
        x = self.conv1(x, sentence_embedding)  # [bsz, C, 4, 64]
        x = self.conv2(x, sentence_embedding)  # [bsz, C/2, 8, 128]
        x = self.multi_triple_attention_1(x, sentence_embedding, word_embedding, sequence_lengths)  # [bsz, C/2, 8, 128]
        x = self.conv3(x, sentence_embedding)  # [bsz, C/2, 16, 256]
        x = self.conv4(x, sentence_embedding)  # [bsz, C/4, 16, 256]
        x = self.multi_triple_attention_2(x, sentence_embedding, word_embedding, sequence_lengths)  # [bsz, C/4, 16, 256]
        x = self.conv5(x, sentence_embedding)  # [bsz, C/4, 32, 512]
        x = self.conv6(x, sentence_embedding)  # [bsz, C/4, 32, 512]
        x = self.multi_triple_attention_3(x, sentence_embedding, word_embedding, sequence_lengths)  # [bsz, C/4, 32, 512]
        x = self.conv7(x, sentence_embedding)  # [bsz, C/8, 64, 1024]
        x = self.bn(x)  # [bsz, C/8, 64, 1024] (BN keeps shape)
        x = F.relu(x)
        # Flatten spatial dims for the 1x1 Conv1d, then restore 2-D layout.
        x = self.colorize(x.reshape([batch_size, -1, 64*1024])).reshape([batch_size, 1, 64, 1024])  # [bsz, 1, 64, 1024]
        return x