Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from torch.nn import init | |
| from torch.nn import functional as F | |
class SpectralNorm:
    """Forward-pre-hook object implementing spectral normalization.

    On every forward pass one power-iteration step refines the running
    estimate ``u`` of the leading left singular vector of the weight, and
    the module attribute ``name`` is replaced by ``weight_orig / sigma``.
    """

    def __init__(self, name):
        # Name of the weight attribute being normalized (e.g. 'weight').
        self.name = name

    def compute_weight(self, module):
        """Return (normalized weight, updated u, sigma) for ``module``."""
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        size = weight.size()
        # Flatten to a 2-D matrix: rows = output channels, cols = the rest.
        weight_mat = weight.contiguous().view(size[0], -1)
        # One power-iteration step; done without grad so only the final
        # sigma (recomputed below) participates in backprop through weight.
        with torch.no_grad():
            v = weight_mat.t() @ u
            v = v / v.norm()
            u = weight_mat @ v
            u = u / u.norm()
        # sigma ~ largest singular value; computed with grad enabled so the
        # division is differentiated through `weight`.
        sigma = u @ weight_mat @ v
        weight_sn = weight / sigma
        return weight_sn, u, sigma

    @staticmethod
    def apply(module, name):
        """Rewire ``module`` so attribute ``name`` is spectrally normalized."""
        fn = SpectralNorm(name)
        weight = getattr(module, name)
        # Move the raw weight to '<name>_orig'; '<name>' becomes a buffer
        # refreshed by the pre-hook on every forward.
        del module._parameters[name]
        module.register_parameter(name + '_orig', weight)
        input_size = weight.size(0)
        u = weight.new_empty(input_size).normal_()
        module.register_buffer(name, weight)
        module.register_buffer(name + '_u', u)
        # '<name>_sv' tracks the current sigma estimate (scalar buffer).
        module.register_buffer(name + '_sv', torch.ones(1).squeeze())
        module.register_forward_pre_hook(fn)
        return fn

    def __call__(self, module, input):
        # Runs before every forward: refresh normalized weight and buffers.
        weight_sn, u, sigma = self.compute_weight(module)
        setattr(module, self.name, weight_sn)
        setattr(module, self.name + '_u', u)
        setattr(module, self.name + '_sv', sigma)
def spectral_norm(module, name='weight'):
    """Attach spectral normalization to ``module``'s ``name`` parameter
    and hand the module back for chaining."""
    SpectralNorm.apply(module, name)
    return module
def spectral_init(module, gain=1):
    """Xavier-initialize ``module.weight``, zero its bias (when present),
    then wrap the module with spectral normalization.

    Uses ``init.zeros_`` rather than mutating ``module.bias.data`` directly,
    matching the in-place ``torch.nn.init`` idiom used for the weight.
    """
    init.xavier_uniform_(module.weight, gain)
    if module.bias is not None:
        init.zeros_(module.bias)
    return spectral_norm(module)
class ConditionalNorm(nn.Module):
    """Batch normalization whose affine scale/shift are predicted per-sample
    from a conditioning vector (conditional BN, BigGAN style)."""

    def __init__(self, in_channel, condition_dim):
        super().__init__()
        # Plain BN without its own affine parameters; the two linears below
        # supply gamma and beta instead.
        self.bn = nn.BatchNorm2d(in_channel, affine=False)
        self.linear1 = nn.Linear(condition_dim, in_channel)
        self.linear2 = nn.Linear(condition_dim, in_channel)

    def forward(self, input, condition):
        normalized = self.bn(input)
        # Predicted per-channel scale/shift, broadcast over the spatial dims.
        scale = self.linear1(condition)[:, :, None, None]
        shift = self.linear2(condition)[:, :, None, None]
        return scale * normalized + shift
class ConvBlock(nn.Module):
    """Residual conv block with optional up/down-sampling and conditional BN.

    Structure: (CBN -> act -> [up] -> conv1 -> CBN -> act -> conv2 -> [down])
    added to a skip path, which gets a 1x1 projection whenever the channel
    count changes or the block resamples.
    """

    def __init__(self, in_channel, out_channel, kernel_size=(3, 3),
                 padding=1, stride=1, condition_dim=None, bn=True,
                 activation=F.relu, upsample=True, downsample=False):
        # kernel_size default changed from a list to a tuple: a mutable
        # default argument is an anti-pattern, and nn.Conv2d accepts either.
        super().__init__()
        gain = 2 ** 0.5
        # No conv bias when batch norm follows (the norm's shift absorbs it).
        self.conv1 = spectral_init(nn.Conv2d(in_channel, out_channel,
                                             kernel_size, stride, padding,
                                             bias=not bn),
                                   gain=gain)
        self.conv2 = spectral_init(nn.Conv2d(out_channel, out_channel,
                                             kernel_size, stride, padding,
                                             bias=not bn),
                                   gain=gain)
        self.skip_proj = False
        if in_channel != out_channel or upsample or downsample:
            self.conv_skip = spectral_init(nn.Conv2d(in_channel, out_channel,
                                                     1, 1, 0))
            self.skip_proj = True
        self.upsample = upsample
        self.downsample = downsample
        self.activation = activation
        self.bn = bn
        if bn:
            self.norm1 = ConditionalNorm(in_channel, condition_dim)
            self.norm2 = ConditionalNorm(out_channel, condition_dim)

    def forward(self, input, condition=None, condition1=None):
        # NOTE(review): `condition1` is never used; kept only so existing
        # callers passing it positionally do not break.
        out = input
        if self.bn:
            out = self.norm1(out, condition)
        out = self.activation(out)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        if self.bn:
            out = self.norm2(out, condition)
        out = self.activation(out)
        out = self.conv2(out)
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        if self.skip_proj:
            # Resample the skip path the same way as the main path so the
            # residual sum is shape-compatible.
            skip = input
            if self.upsample:
                skip = F.interpolate(skip, scale_factor=2, mode='nearest')
            skip = self.conv_skip(skip)
            if self.downsample:
                skip = F.avg_pool2d(skip, 2)
        else:
            skip = input
        return out + skip
class SelfAttention(nn.Module):
    """Non-local self-attention over the flattened freq*time grid
    (SAGAN-style), with a learned residual gate ``gamma`` starting at 0."""

    def __init__(self, in_channel, embed_dim, gain=2 ** 0.5):
        super().__init__()
        self.query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                   gain=gain)
        self.key = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                 gain=gain)
        self.value = spectral_init(nn.Conv1d(in_channel, in_channel, 1),
                                   gain=gain)
        # Residual gate: output starts as the identity mapping.
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input):
        # input: [bsz, channel, freq, time]
        original_shape = input.shape
        batch, channels = original_shape[0], original_shape[1]
        flat = input.view(batch, channels, -1)  # [bsz, channel, freq*time]
        q = self.query(flat).permute(0, 2, 1)   # [bsz, N, embed]
        k = self.key(flat)                      # [bsz, embed, N]
        v = self.value(flat)                    # [bsz, channel, N]
        energy = torch.bmm(q, k)                # [bsz, N, N]
        # Softmax over dim 1 (the index summed by bmm(v, attn) below), so
        # each output position mixes values with weights summing to one.
        attention_map = F.softmax(energy, 1)
        attended = torch.bmm(v, attention_map).view(*original_shape)
        return (self.gamma * attended + input, attention_map)
class CrossAttention(nn.Module):
    """Cross-attention from a mel/sentence query onto a sentence/word
    condition, gated into a residual connection by ``gamma`` (init 0)."""

    def __init__(self, in_channel, cond_channel, embed_dim, gain=2 ** 0.5):
        super().__init__()
        self.key = spectral_init(nn.Conv1d(cond_channel, embed_dim, 1),
                                 gain=gain)
        self.value = spectral_init(nn.Conv1d(cond_channel, in_channel, 1),
                                   gain=gain)
        self.query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                   gain=gain)
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input, condition, sequence_lengths=None):
        # input : mel [bsz, channel, freq, time] or sentence [bsz, channel]
        # condition : sentence [bsz, channel] or word [bsz, word_num, channel]
        # sequence_lengths: assumed to be a 1-D tensor of word counts per
        # sample — TODO confirm against the caller.
        input_shape = input.shape
        if input.dim() == 4:  # mel [bsz, channel, freq, time]
            batch_size, c, w, h = input.shape
            num = w * h
            x = input.reshape(batch_size, c, num)  # [bsz, channel, input_num]
        elif input.dim() == 2:  # sentence [bsz, channel]
            batch_size, c = input.shape
            num = 1
            x = input.unsqueeze(2)  # [bsz, channel, 1]
        if condition.dim() == 2:  # sentence [bsz, channel]
            condition = condition.unsqueeze(2)  # [bsz, channel, 1]
        else:  # word [bsz, word_num, channel]
            condition = condition.permute(0, 2, 1)  # [bsz, channel, cond_num]
        query = self.query(x).permute(0, 2, 1)       # [bsz, input_num, embed]
        key = self.key(condition)                    # [bsz, embed, cond_num]
        value = self.value(condition).permute(0, 2, 1)  # [bsz, cond_num, channel]
        attention_map = torch.bmm(query, key)        # [bsz, input_num, cond_num]
        if sequence_lengths is not None:  # condition is word embedding
            # Vectorized padding mask (replaces a per-sample Python loop):
            # position j is masked when j >= sequence_lengths[i].
            total_len = condition.shape[2]
            mask = torch.arange(total_len, device=condition.device)[None, None, :]
            mask = mask >= sequence_lengths[:, None, None]
            attention_map = attention_map + mask * (-1e9)
        attention_map = F.softmax(attention_map, dim=-1)
        # [bsz, channel, input_num]; reshape directly back to the input
        # layout. (The previous extra permute before the reshape scrambled
        # channels against spatial positions for 4-D input, and the blanket
        # .squeeze() broke batch size 1 — both fixed.)
        out = torch.bmm(attention_map, value).permute(0, 2, 1)
        out = out.reshape(input_shape)
        out = self.gamma * out + input
        return out, attention_map
class Spec_Attention(nn.Module):
    """Spectrogram attention with separate frequency-axis and time-axis
    query/key pairs.

    With ``condition=None`` it acts as a self-gating over the spectrogram
    (element-wise weighting of the value map); with a sentence/word
    condition it attends from both axes onto the condition tokens.
    Output is gated into a residual connection by ``gamma`` (init 0).
    """

    def __init__(self, in_channel, cond_channel=None, embed_dim=64, gain=2 ** 0.5):
        super().__init__()
        # Default to self-attention when no conditioning channel is given.
        if cond_channel is None:
            cond_channel = in_channel
        # freq-query attends against t_key; time-query against f_key.
        self.f_query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                     gain=gain)
        self.t_key = spectral_init(nn.Conv1d(cond_channel, embed_dim, 1),
                                   gain=gain)
        self.t_query = spectral_init(nn.Conv1d(in_channel, embed_dim, 1),
                                     gain=gain)
        self.f_key = spectral_init(nn.Conv1d(cond_channel, embed_dim, 1),
                                   gain=gain)
        self.value = spectral_init(nn.Conv1d(cond_channel, in_channel, 1),
                                   gain=gain)
        # Residual gate: the block starts as the identity mapping.
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input, condition=None, sequence_lengths=None):
        # input : mel [bsz, channel, freq, time]
        # condition : sentence [bsz, channel] or word [bsz, word_num, channel]
        batch_size, c, f, t = input.shape
        # Axis-pooled summaries of the spectrogram.
        freq_embedding = input.mean(dim=3)  # [bsz, channel, freq]
        time_embedding = input.mean(dim=2)  # [bsz, channel, time]
        if condition is not None:
            if len(condition.shape) == 2:  # sentence [bsz, channel]
                condition = condition.unsqueeze(2)  # [bsz, channel, 1]
            else:  # word [bsz, word_num, channel]
                condition = condition.permute(0, 2, 1)  # [bsz, channel, cond_num]
            t_condition = condition
            f_condition = condition
        else:
            # Unconditional: each axis attends against the other axis's
            # pooled embedding.
            t_condition = time_embedding
            f_condition = freq_embedding
        f_query = self.f_query(freq_embedding).permute(0, 2, 1)  # [bsz, freq, embed]
        t_key = self.t_key(t_condition)  # [bsz, embed, time] or [bsz, embed, cond_num]
        freq_cond_map = torch.bmm(f_query, t_key)  # [bsz, freq, time] or [bsz, freq, cond_num]
        t_query = self.t_query(time_embedding).permute(0, 2, 1)  # [bsz, time, embed]
        f_key = self.f_key(f_condition)  # [bsz, embed, freq] or [bsz, embed, cond_num]
        time_cond_map = torch.bmm(t_query, f_key)  # [bsz, time, freq] or [bsz, time, cond_num]
        if sequence_lengths is not None:  # condition is word embedding
            # Padding mask over condition tokens, broadcast over the query
            # axis: position j masked when j >= sequence_lengths[i].
            total_len = condition.shape[2]
            mask = torch.arange(total_len, device=condition.device)[None, None, :]
            mask = mask >= sequence_lengths[:, None, None]
            freq_cond_map = freq_cond_map + mask * (-1e9)
            time_cond_map = time_cond_map + mask * (-1e9)
        freq_cond_map = F.softmax(freq_cond_map, dim=-1)  # [bsz, freq, time] or [bsz, freq, cond_num]
        time_cond_map = F.softmax(time_cond_map, dim=-1)  # [bsz, time, freq] or [bsz, time, cond_num]
        if condition is None:
            # Self mode: combine the two axis maps into a per-position
            # scalar weight and gate the value map element-wise.
            freq_time_embedding = input.reshape([batch_size, c, f*t])  # [bsz, channel, freq*time]
            weight_map = torch.add(freq_cond_map, time_cond_map.permute(0, 2, 1)).reshape([batch_size, f*t]).unsqueeze(-1)  # [bsz, freq*time, 1]
            value = self.value(freq_time_embedding).permute(0, 2, 1)  # [bsz, freq*time, channel]
            out = torch.mul(value, weight_map).permute(0, 2, 1).reshape(batch_size, c, f, t)  # [bsz, channel, freq, time]
        else:
            # Conditional mode: broadcast the two axis maps over the full
            # (freq, time) grid and attend over condition tokens.
            freq_cond_map = torch.tile(freq_cond_map.unsqueeze(2), [1, 1, t, 1])  # [bsz, freq, time, cond_num]
            time_cond_map = torch.tile(time_cond_map.unsqueeze(1), [1, f, 1, 1])  # [bsz, freq, time, cond_num]
            weight_map = torch.add(freq_cond_map, time_cond_map).reshape([batch_size, f*t, -1])  # [bsz, freq*time, cond_num]
            value = self.value(condition).permute(0, 2, 1)  # [bsz, cond_num, channel]
            out = torch.bmm(weight_map, value).permute(0, 2, 1).reshape(batch_size, c, f, t)  # [bsz, channel, freq, time]
        out = self.gamma * out + input
        return out, weight_map
class Multi_Triple_Attention(nn.Module):
    """Multi-head combination of self-, word- and sentence-attention.

    Each head runs the attentions named in ``attention_list`` (comma
    separated: any of "self", "word", "sentence"), concatenates their
    outputs along channels, projects back to ``in_channel``, and adds a
    gated residual. Head outputs are concatenated and projected once more.
    """

    def __init__(self, in_channel, sentence_embed_dim=768, word_embed_dim=768, embed_dim=64, reverse=False, gain=2 ** 0.5, n_heads=2, attention_list="self,word,sentence", spec_attention=False):
        super().__init__()
        self.reverse = reverse
        self.n_heads = n_heads
        self.attention_list = attention_list.split(",")
        if "self" in self.attention_list:
            if spec_attention:
                self.self_attention_modules = nn.ModuleList([Spec_Attention(in_channel, embed_dim=embed_dim) for _ in range(self.n_heads)])
            else:
                self.self_attention_modules = nn.ModuleList([SelfAttention(in_channel, embed_dim=embed_dim) for _ in range(self.n_heads)])
        if "word" in self.attention_list:
            if spec_attention:
                self.cross_attention_for_word_modules = nn.ModuleList([Spec_Attention(in_channel, cond_channel=word_embed_dim, embed_dim=embed_dim) for _ in range(self.n_heads)])
            else:
                self.cross_attention_for_word_modules = nn.ModuleList([CrossAttention(in_channel, cond_channel=word_embed_dim, embed_dim=embed_dim) for _ in range(self.n_heads)])
        if "sentence" in self.attention_list:
            if spec_attention:
                self.cross_attention_for_sent_modules = nn.ModuleList([Spec_Attention(in_channel, cond_channel=sentence_embed_dim, embed_dim=embed_dim) for _ in range(self.n_heads)])
            else:
                self.cross_attention_for_sent_modules = nn.ModuleList([CrossAttention(in_channel, cond_channel=sentence_embed_dim, embed_dim=embed_dim) for _ in range(self.n_heads)])
        # BUG FIX: these per-head residual gates were held in a plain Python
        # list, so they were never registered with the module — excluded
        # from parameters()/the optimizer, not moved by .to(device), and
        # missing from state_dict. nn.ParameterList registers them properly.
        self.gamma = nn.ParameterList([nn.Parameter(torch.tensor(0.0)) for _ in range(self.n_heads)])
        self.conv_for_attention = spectral_init(nn.Conv1d(in_channel * len(self.attention_list), in_channel, 1), gain=gain)
        self.out = spectral_init(nn.Conv1d(in_channel * self.n_heads, in_channel, 1), gain=gain)

    def forward(self, input, sentence_embedding, word_embedding, sequence_lengths):
        batch_size, c, f, t = input.shape
        x = input
        result = []
        for head in range(self.n_heads):
            out_list = []
            # Each branch returns (features, attention_map); only the
            # features are consumed here.
            if "self" in self.attention_list:
                x_self, attention_map = self.self_attention_modules[head](x)
                out_list.append(x_self)
            if "word" in self.attention_list:
                x_word, attention_map = self.cross_attention_for_word_modules[head](x, word_embedding, sequence_lengths)
                out_list.append(x_word)
            if "sentence" in self.attention_list:
                x_sent, attention_map = self.cross_attention_for_sent_modules[head](x, sentence_embedding)
                out_list.append(x_sent)
            # Concatenate branch outputs on channels, project back down.
            out = torch.cat(out_list, dim=1)
            out = self.conv_for_attention(out.reshape([batch_size, c*len(out_list), f*t])).reshape([batch_size, c, f, t])
            out = self.gamma[head] * out + x
            result.append(out)
        # Merge the heads with a final 1x1 projection plus outer residual.
        x = torch.cat(result, dim=1)
        x = self.out(x.reshape([batch_size, c * self.n_heads, f*t])).reshape([batch_size, c, f, t])
        x = input + x
        return x
class Generator(nn.Module):
    """Text-conditioned spectrogram generator.

    Noise is projected to a (channel, 2, 32) seed map, then alternately
    upsampled by ConvBlocks (conditioned on the sentence embedding) and
    refined by Multi_Triple_Attention stages (conditioned on sentence and
    word embeddings), ending in a 1-channel (64, 1024) output.
    """

    def __init__(self, model_config=None):
        super().__init__()
        if model_config is None:
            # NOTE(review): 'g_chaneel' is a typo but is the key external
            # configs use — it must stay spelled this way for compatibility.
            model_config = {
                "noise_dim":128,
                "g_chaneel":128,
                "n_heads":10,
                "sentence_embed_dim":512,
                "word_embed_dim":768,
                "attention_list":["self,word,sentence", "word,sentence", "sentence"],
                "spec_attention":True,
            }
        self.noise_dim = model_config['noise_dim']
        self.channel = model_config['g_chaneel']
        self.n_heads = model_config['n_heads']
        self.sentence_embed_dim = model_config['sentence_embed_dim']
        self.word_embed_dim = model_config['word_embed_dim']
        # One comma-separated attention spec per attention stage below.
        self.attention_list = model_config['attention_list']
        self.spec_attention = model_config['spec_attention']
        # Per-stage channel widths, tapering from `channel` to `channel//8`.
        channel_list = [self.channel, self.channel, self.channel//2, self.channel//2, self.channel//4, self.channel//4, self.channel//4, self.channel//8, self.channel//8]
        # Projects noise to the flattened (channel, 2, 32) seed map.
        self.lin_code = spectral_init(nn.Linear(self.noise_dim, channel_list[0] * 2 * 32))
        self.conv1 = ConvBlock(channel_list[0], channel_list[1], condition_dim=self.sentence_embed_dim)
        self.conv2 = ConvBlock(channel_list[1], channel_list[2], condition_dim=self.sentence_embed_dim)
        self.multi_triple_attention_1 = Multi_Triple_Attention(channel_list[2],
                                                               sentence_embed_dim=self.sentence_embed_dim,
                                                               word_embed_dim=self.word_embed_dim,
                                                               embed_dim=channel_list[2],
                                                               reverse=False,
                                                               n_heads=self.n_heads,
                                                               attention_list=self.attention_list[0],
                                                               spec_attention=self.spec_attention)
        self.conv3 = ConvBlock(channel_list[2], channel_list[3], condition_dim=self.sentence_embed_dim)
        # upsample=False: keeps the 16x256 resolution for this stage.
        self.conv4 = ConvBlock(channel_list[3], channel_list[4], condition_dim=self.sentence_embed_dim, upsample=False)
        self.multi_triple_attention_2 = Multi_Triple_Attention(channel_list[4],
                                                               sentence_embed_dim=self.sentence_embed_dim,
                                                               word_embed_dim=self.word_embed_dim,
                                                               embed_dim=channel_list[4],
                                                               reverse=False,
                                                               n_heads=self.n_heads,
                                                               attention_list=self.attention_list[1],
                                                               spec_attention=self.spec_attention)
        self.conv5 = ConvBlock(channel_list[4], channel_list[5], condition_dim=self.sentence_embed_dim)
        self.conv6 = ConvBlock(channel_list[5], channel_list[6], condition_dim=self.sentence_embed_dim, upsample=False)
        self.multi_triple_attention_3 = Multi_Triple_Attention(channel_list[6],
                                                               sentence_embed_dim=self.sentence_embed_dim,
                                                               word_embed_dim=self.word_embed_dim,
                                                               embed_dim=channel_list[6],
                                                               reverse=False,
                                                               n_heads=self.n_heads,
                                                               attention_list=self.attention_list[2],
                                                               spec_attention=self.spec_attention)
        self.conv7 = ConvBlock(channel_list[6], channel_list[7], condition_dim=self.sentence_embed_dim)
        # channel_list[8] == channel_list[7] (both channel//8), so this BN
        # matches conv7's output width.
        self.bn = nn.BatchNorm2d(channel_list[8])
        # 1x1 projection down to a single spectrogram channel.
        self.colorize = spectral_init(nn.Conv1d(channel_list[8], 1, 1))

    def forward(self, z, sentence_embedding, word_embedding, sequence_lengths):
        # z: [bsz, noise_dim]; returns a [bsz, 1, 64, 1024] spectrogram.
        batch_size = z.shape[0]
        x = self.lin_code(z)
        x = x.view(-1, self.channel, 2, 32) # [bsz, c, 2, 32]
        x = self.conv1(x, sentence_embedding) # [bsz, c, 4, 64]
        x = self.conv2(x, sentence_embedding) # [bsz, c, 8, 128]
        x = self.multi_triple_attention_1(x, sentence_embedding, word_embedding, sequence_lengths) # [bsz, c, 8, 128]
        x = self.conv3(x, sentence_embedding) # [bsz, c, 16, 256]
        x = self.conv4(x, sentence_embedding) # [bsz, c, 16, 256]
        x = self.multi_triple_attention_2(x, sentence_embedding, word_embedding, sequence_lengths) # [bsz, c, 16, 256]
        x = self.conv5(x, sentence_embedding) # [bsz, c, 32, 512]
        x = self.conv6(x, sentence_embedding) # [bsz, c, 32, 512]
        x = self.multi_triple_attention_3(x, sentence_embedding, word_embedding, sequence_lengths) # [bsz, c, 32, 512]
        x = self.conv7(x, sentence_embedding) # [bsz, c, 64, 1024]
        x = self.bn(x) # [bsz, c // 8, 64, 1024]
        x = F.relu(x)
        # Flatten spatial dims so the 1x1 Conv1d can project channels to 1.
        x = self.colorize(x.reshape([batch_size, -1, 64*1024])).reshape([batch_size, 1, 64, 1024]) # [bsz, 1, 64, 1024]
        return x