import copy

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.rnn import LSTM
from torch.nn.modules.normalization import LayerNorm
from torch.nn.utils import weight_norm, remove_weight_norm

from latent_diffusion.modules.nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, checkpoint


def exists(val):
    return val is not None


class TransformerEncoderLayer(Module):
    def __init__(self, d_model, nhead, hidden_size, dim_feedforward, context_dim, dropout, activation="gelu"):
        super(TransformerEncoderLayer, self).__init__()
        # hidden_size and dim_feedforward are accepted for interface compatibility but are not used here.
        # Despite its name, self_attn is used for cross-attention: queries come from src,
        # while keys/values are projected from the conditioning context via to_k/to_v.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the position-wise feed-forward block
        self.dropout = Dropout(dropout)
        self.linear_in = Linear(d_model, d_model * 4)
        self.linear_out = Linear(d_model * 4, d_model)

        self.to_k = nn.Sequential(
            Linear(context_dim, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            Linear(d_model, d_model, bias=False))
        self.to_v = nn.Sequential(
            Linear(context_dim, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            Linear(d_model, d_model, bias=False))

        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src, context=None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer, shape (seq_len, batch, d_model) (required).
            context: the conditioning sequence that src cross-attends to,
                shape (context_len, batch, context_dim) (optional).
        """
        # Cross-attention: queries from src, keys/values projected from context.
        # A pure self-attention variant would be: src2 = self.self_attn(src, src, src)[0]
        src2 = self.self_attn(src, self.to_k(context), self.to_v(context))[0]
        src = src + self.dropout1(src2)
        # Position-wise feed-forward block with a residual connection.
        src2 = self.linear_out(self.dropout(self.activation(self.linear_in(src))))
        src = src + self.dropout2(src2)
        return src
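
# Illustrative usage (assumed sizes, not from the original code). nn.MultiheadAttention
# defaults to sequence-first layout, so src is (seq_len, batch, d_model) and context is
# (context_len, batch, context_dim):
#
#   layer = TransformerEncoderLayer(d_model=64, nhead=8, hidden_size=128,
#                                   dim_feedforward=256, context_dim=32, dropout=0.1)
#   out = layer(torch.randn(10, 2, 64), context=torch.randn(5, 2, 32))
#   # out.shape == (10, 2, 64)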


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class SingleTransformer(nn.Module):
    """
    Container module for a single Transformer layer.

    args:
        input_size: int, dimension of the input feature. The input should have shape (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state of the feed-forward part.
        context_dim: int, dimension of the conditioning context.
        dropout: float, dropout ratio.
    """

    def __init__(self, input_size, hidden_size, context_dim, dropout):
        super(SingleTransformer, self).__init__()
        self.transformer = TransformerEncoderLayer(d_model=input_size, nhead=8, hidden_size=hidden_size,
                                                   dim_feedforward=hidden_size * 2, context_dim=context_dim, dropout=dropout)
        # Projects the timestep/condition embedding to a per-channel scale and shift (FiLM-style).
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            Linear(
                input_size,
                2 * input_size,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(p=dropout),
            weight_norm(conv_nd(1, input_size, input_size, 3, padding=1))
        )

    def forward(self, x, emb, context):
        # Run the block under gradient checkpointing to trade compute for memory.
        return checkpoint(self._forward, (x, emb, context), self.parameters(), True)

    def _forward(self, x, emb, context=None):
        # input shape: batch, seq, dim
        # TransformerEncoderLayer expects (seq, batch, dim), so permute in and out.
        x = self.transformer(x.permute(1, 0, 2).contiguous(),
                             context.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
        # FiLM-style modulation with the timestep/condition embedding.
        emb_out = self.emb_layers(emb).unsqueeze(1)
        scale, shift = torch.chunk(emb_out, 2, dim=-1)
        x = x * (1 + scale) + shift
        x = self.out_layers(x.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()
        return x
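
# Illustrative usage (assumed sizes): x is (batch, seq, input_size), emb is a
# (batch, input_size) timestep/condition embedding, and context is
# (batch, context_len, context_dim):
#
#   block = SingleTransformer(input_size=64, hidden_size=128, context_dim=32, dropout=0.1)
#   y = block(torch.randn(2, 10, 64), torch.randn(2, 64), torch.randn(2, 5, 32))
#   # y.shape == (2, 10, 64)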


# dual-path transformer
class DPT(nn.Module):
    """
    Deep dual-path transformer.

    args:
        input_size: int, dimension of the input feature. The input should have shape
                    (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state.
        output_size: int, dimension of the output size.
        num_layers: int, number of stacked Transformer layers. Default is 1.
        context_dim: int, dimension of the conditioning context. Default is 512.
        dropout: float, dropout ratio. Default is 0.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, context_dim=512, dropout=0):
        super(DPT, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # dual-path transformer
        # Only the middle half of the layers (num_layers//4 <= i < (num_layers*3)//4)
        # cross-attend to the external context; the other layers attend to their own
        # input, so their key/value projections are built with input_size instead.
        self.row_transformer = nn.ModuleList([])
        self.col_transformer = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        for i in range(num_layers):
            self.row_transformer.append(SingleTransformer(input_size, hidden_size,
                                                          context_dim if i >= num_layers // 4 and i < (num_layers * 3) // 4 else input_size, dropout))
            self.col_transformer.append(SingleTransformer(input_size, hidden_size,
                                                          context_dim if i >= num_layers // 4 and i < (num_layers * 3) // 4 else input_size, dropout))
            self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))

        # output layer
        self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))

    def forward(self, x, emb, context=None):
        # input shape: batch, N, dim1, dim2
        # apply transformer on dim1 first and then dim2
        # output shape: B, output_size, dim1, dim2
        batch_size, _, dim1, dim2 = x.shape
        output = x
        for i in range(self.num_layers):
            use_context = i >= self.num_layers // 4 and i < (self.num_layers * 3) // 4

            # intra-segment (row) processing along dim1
            row_input = output.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1)  # B*dim2, dim1, N
            row_output = self.row_transformer[i](row_input, emb.repeat_interleave(dim2, 0),
                                                 context.repeat_interleave(dim2, 0) if use_context else row_input)  # B*dim2, dim1, H
            row_output = row_output.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()  # B, N, dim1, dim2
            row_output = self.row_norm[i](row_output)
            output = output + row_output

            # inter-segment (column) processing along dim2
            col_input = output.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1)  # B*dim1, dim2, N
            col_output = self.col_transformer[i](col_input, emb.repeat_interleave(dim1, 0),
                                                 context.repeat_interleave(dim1, 0) if use_context else col_input)  # B*dim1, dim2, H
            col_output = col_output.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()  # B, N, dim1, dim2
            col_output = self.col_norm[i](col_output)
            output = output + col_output

        # output = F.layer_norm(output, output.shape[1:])
        output = self.output(output)  # B, output_size, dim1, dim2

        return output
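
# Illustrative usage (assumed sizes): the input is a segmented feature tensor of
# shape (batch, input_size, dim1, dim2), emb is a (batch, input_size) embedding,
# and context is (batch, context_len, context_dim):
#
#   dpt = DPT(input_size=64, hidden_size=128, output_size=64,
#             num_layers=4, context_dim=32, dropout=0.1)
#   y = dpt(torch.randn(2, 64, 16, 7), torch.randn(2, 64), torch.randn(2, 5, 32))
#   # y.shape == (2, 64, 16, 7)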


# base module for deep DPT
class DPTNet(nn.Module):
    def __init__(self,
                 input_dim=256,
                 feature_dim=1024,
                 hidden_dim=512,
                 layer=8,
                 segment_size=64,
                 context_dim=512,
                 extra_film_condition_dim=None,
                 dropout=0,
                 predict_xstart=False,
                 learn_sigma=False,
                 end2end=False,
                 ):
        super(DPTNet, self).__init__()

        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.layer = layer
        self.segment_size = segment_size
        self.predict_xstart = predict_xstart
        self.learn_sigma = learn_sigma
        self.extra_film_condition_dim = extra_film_condition_dim

        self.eps = 1e-8

        self.time_embed = nn.Sequential(
            Linear(feature_dim, feature_dim),
            nn.SiLU(),
            Linear(feature_dim, feature_dim),
        )

        if extra_film_condition_dim is not None:
            self.film_emb = nn.Linear(self.extra_film_condition_dim, feature_dim)
            print(
                "+ Use extra condition on UNet channel using FiLM. Extra condition dimension is %s. "
                % self.extra_film_condition_dim
            )
        else:
            self.film_emb = None

        # self.context_embed = nn.Sequential(
        #     Linear(context_dim, feature_dim),
        #     nn.SiLU(),
        #     Linear(feature_dim, feature_dim),
        # )

        # bottleneck
        self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, bias=False)
        self.BN_gn = nn.GroupNorm(1, self.feature_dim, eps=1e-8)

        # DPT model
        self.DPT = DPT(self.feature_dim, self.hidden_dim, self.feature_dim,
                       num_layers=layer, context_dim=context_dim, dropout=dropout)
        self.output_gn = nn.GroupNorm(1, self.feature_dim, eps=1e-8)
        self.output = nn.Conv1d(self.feature_dim, self.input_dim * 2 if learn_sigma else self.input_dim, 1)

    def pad_segment(self, input, segment_size):
        # input is the features: (B, N, T)
        # Zero-pad the time axis so it can be cut into 50%-overlapping segments
        # of length segment_size.
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
        if rest > 0:
            pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
            input = torch.cat([input, pad], 2)

        pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type())
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

    def split_feature(self, input, segment_size):
        # split the feature into chunks of segment size
        # input is the features: (B, N, T)
        input, rest = self.pad_segment(input, segment_size)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        # Build two half-shifted views and interleave them, giving a tensor of
        # shape (B, N, segment_size, n_segments).
        segments1 = input[:, :, :-segment_stride].contiguous().view(batch_size, dim, -1, segment_size)
        segments2 = input[:, :, segment_stride:].contiguous().view(batch_size, dim, -1, segment_size)
        segments = torch.cat([segments1, segments2], 3).view(batch_size, dim, -1, segment_size).transpose(2, 3)

        return segments.contiguous(), rest

    def merge_feature(self, input, rest):
        # merge the split features back into the full utterance by overlap-and-add
        # input is the features: (B, N, L, K)
        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2)  # B, N, K, L

        input1 = input[:, :, :, :segment_size].contiguous().view(batch_size, dim, -1)[:, :, segment_stride:]
        input2 = input[:, :, :, segment_size:].contiguous().view(batch_size, dim, -1)[:, :, :-segment_stride]

        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]

        return output.contiguous()  # B, N, T
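
    # Illustrative round trip (assumed sizes): split_feature cuts a (B, N, T)
    # tensor into 50%-overlapping segments and merge_feature overlap-adds them
    # back; since every frame is covered by exactly two segments, the result
    # equals 2 * the original:
    #
    #   net = DPTNet(input_dim=128, feature_dim=64, hidden_dim=32, layer=4, segment_size=16)
    #   feats = torch.randn(2, 64, 100)
    #   segs, rest = net.split_feature(feats, 16)   # segs: (2, 64, 16, 14), rest: 4
    #   rec = net.merge_feature(segs, rest)         # rec == 2 * feats, shape (2, 64, 100)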

    def compress(self, x):
        # Flatten a (B, C, T, F) latent into (B, C*F, T) so the 1-D bottleneck
        # and the DPT can treat channel x frequency as a single feature axis.
        batchsize, channel, time, freq = x.size()
        x = x.permute(0, 1, 3, 2)
        x = x.reshape(batchsize, -1, time)
        return x

    def decompress(self, x):
        # Inverse of compress(); note the hard-coded 8-channel x 16-bin latent
        # layout, i.e. it assumes C*F == 128.
        batchsize, _, time = x.size()
        x = x.reshape(batchsize, 8, 16, time)
        return x.permute(0, 1, 3, 2)
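
    # Illustrative round trip (assumes the hard-coded 8 x 16 latent layout and a
    # DPTNet instance `net` as in the example above):
    #
    #   z = torch.randn(2, 8, 100, 16)              # (B, C, T, F)
    #   flat = net.compress(z)                      # (2, 128, 100)
    #   assert torch.equal(net.decompress(flat), z)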

    def forward(self, x, timesteps, context=None, y=None, autoencoder=None):
        if autoencoder is not None:
            x = autoencoder.encode(x)
            x = x.unsqueeze(1)

        x = self.compress(x)  # Added by haohe

        batch_size, input_dim, seq_length = x.shape
        # input: (B, D, T)

        t_emb = timestep_embedding(timesteps, self.feature_dim, repeat_only=False)
        emb = self.time_embed(t_emb)
        # emb = torch.cat([emb, self.context_embed(context).mean(1)], -1)
        if self.film_emb is not None:
            # emb = torch.cat([emb, self.film_emb(y)], dim=-1)
            emb = emb + self.film_emb(y)

        x = self.BN(x)  # (B, D, L) --> (B, N, L)
        x = self.BN_gn(x)

        # split the encoder output into overlapped, longer segments
        x, enc_rest = self.split_feature(x, self.segment_size)  # B, N, L, K: L is the segment_size

        # pass to DPT
        out = self.DPT(x, emb, context).view(batch_size, self.feature_dim, self.segment_size, -1)  # B, N, L, K

        # overlap-and-add of the outputs
        out = self.merge_feature(out, enc_rest)  # B, N, T

        # out = F.relu(out)
        out = self.output_gn(out)
        out = self.output(out)

        if autoencoder is not None:
            out = autoencoder.decode(out)
            if self.predict_xstart:
                out = torch.tanh(out)
            return out
        else:
            if self.learn_sigma:
                out = out.view(batch_size, 2, input_dim, seq_length)
            else:
                out = out.view(batch_size, 1, input_dim, seq_length)
            if self.predict_xstart:
                out = torch.tanh(out)
            else:
                out = F.layer_norm(out, out.shape[1:])
            out = out.squeeze(1)
            out = self.decompress(out)  # Added by haohe
            return out
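

if __name__ == "__main__":
    # Minimal shape check, not part of the original module. The sizes below are
    # assumptions chosen to match decompress()'s hard-coded 8 x 16 layout
    # (so input_dim = 8 * 16 = 128); they are not the training configuration.
    model = DPTNet(input_dim=128, feature_dim=256, hidden_dim=128,
                   layer=4, segment_size=16, context_dim=512)
    x = torch.randn(2, 8, 100, 16)      # (B, C, T, F) latent
    t = torch.randint(0, 1000, (2,))    # diffusion timesteps
    ctx = torch.randn(2, 20, 512)       # (B, L_ctx, context_dim) conditioning
    out = model(x, t, context=ctx)
    print(out.shape)                    # expected: torch.Size([2, 8, 100, 16])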