import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
class CasualCT(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 leaky=False,
                 p=0,
                 groups=1):
        '''
        transposed conv-dropout-bn-relu; upsamples the sequence by 2x
        '''
        super(CasualCT, self).__init__()
        padding = 0
        kernel_size = 2
        stride = 2
        in_channels = in_channels * groups
        out_channels = out_channels * groups
        self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
                                       kernel_size=kernel_size, stride=stride, padding=padding,
                                       groups=groups)
        self.norm = nn.BatchNorm1d(out_channels)
        self.dropout = nn.Dropout(p=p)
        if leaky:
            self.relu = nn.LeakyReLU(negative_slope=0.2)
        else:
            self.relu = nn.ReLU()

    def forward(self, x, **kwargs):
        out = self.norm(self.dropout(self.conv(x)))
        return self.relu(out)
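
# Illustrative sketch (not part of the original module; `_demo_casual_ct` is a
# hypothetical helper): with kernel_size=2 and stride=2, ConvTranspose1d maps a
# length-T input to length 2*T, so CasualCT acts as a 2x temporal upsampler.
def _demo_casual_ct():
    layer = CasualCT(in_channels=16, out_channels=8)
    x = torch.randn(8, 16, 10)    # (B, C_in, T)
    y = layer(x)
    assert y.shape == (8, 8, 20)  # (B, C_out, 2 * T)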
class CasualConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 leaky=False,
                 p=0,
                 groups=1,
                 downsample=False):
        '''
        conv-dropout-bn-relu; causal when stride is 1, 2x downsampling otherwise
        '''
        super(CasualConv, self).__init__()
        padding = 0
        kernel_size = 2
        stride = 1
        self.downsample = downsample
        if self.downsample:
            kernel_size = 2
            stride = 2
        in_channels = in_channels * groups
        out_channels = out_channels * groups
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding,
                              groups=groups)
        self.norm = nn.BatchNorm1d(out_channels)
        self.dropout = nn.Dropout(p=p)
        if leaky:
            self.relu = nn.LeakyReLU(negative_slope=0.2)
        else:
            self.relu = nn.ReLU()

    def forward(self, x, pre_state=None):
        if not self.downsample:
            # left-pad with the cached last frame of the previous chunk (or a
            # zero frame), so the kernel only sees current and past frames
            if pre_state is not None:
                x = torch.cat([pre_state, x], dim=-1)
            else:
                zeros = torch.zeros([x.shape[0], x.shape[1], 1], device=x.device)
                x = torch.cat([zeros, x], dim=-1)
        out = self.norm(self.dropout(self.conv(x)))
        return self.relu(out)
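
# Illustrative sketch (hypothetical helper, not part of the original module):
# in the non-downsampling mode, the one-frame left pad lets a kernel_size-2,
# stride-1 convolution preserve length while remaining causal.
def _demo_casual_conv():
    layer = CasualConv(in_channels=16, out_channels=16)
    x = torch.randn(4, 16, 10)
    y = layer(x)                                # zero history: length preserved
    assert y.shape == (4, 16, 10)
    hist = torch.randn(4, 16, 1)                # last frame of a previous chunk
    assert layer(x, pre_state=hist).shape == (4, 16, 10)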
class ConvNormRelu(nn.Module):
    '''
    (B, C_in, T) -> (B, C_out, T')
    note: some kernel-size/stride combinations do not give exactly T' = T / stride
    # TODO: there might be some problems with residual
    '''
    def __init__(self,
                 in_channels,
                 out_channels,
                 leaky=False,
                 sample='none',
                 p=0,
                 groups=1,
                 residual=False,
                 norm='bn'):
        '''
        conv-dropout-bn-relu
        '''
        super(ConvNormRelu, self).__init__()
        self.residual = residual
        self.norm_type = norm
        padding = 1
        if sample == 'none':
            kernel_size = 3
            stride = 1
        elif sample == 'one':
            padding = 0
            kernel_size = stride = 1
        else:
            kernel_size = 4
            stride = 2
        if self.residual:
            if sample == 'down':
                self.residual_layer = nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding)
            elif sample == 'up':
                self.residual_layer = nn.ConvTranspose1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding)
            else:
                if in_channels == out_channels:
                    self.residual_layer = nn.Identity()
                else:
                    self.residual_layer = nn.Sequential(
                        nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding
                        )
                    )
        in_channels = in_channels * groups
        out_channels = out_channels * groups
        if sample == 'up':
            self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
                                           kernel_size=kernel_size, stride=stride, padding=padding,
                                           groups=groups)
        else:
            self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding,
                                  groups=groups)
        self.norm = nn.BatchNorm1d(out_channels)
        self.dropout = nn.Dropout(p=p)
        if leaky:
            self.relu = nn.LeakyReLU(negative_slope=0.2)
        else:
            self.relu = nn.ReLU()

    def forward(self, x, **kwargs):
        out = self.norm(self.dropout(self.conv(x)))
        if self.residual:
            residual = self.residual_layer(x)
            out += residual
        return self.relu(out)
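
# Illustrative sketch (hypothetical helper, not part of the original module):
# `sample` picks kernel/stride so that 'down' halves the length (k=4, s=2, p=1),
# 'up' doubles it via ConvTranspose1d, and 'none' preserves it (k=3, s=1, p=1).
def _demo_conv_norm_relu():
    x = torch.randn(4, 16, 32)
    assert ConvNormRelu(16, 32, sample='down')(x).shape == (4, 32, 16)
    assert ConvNormRelu(16, 32, sample='up')(x).shape == (4, 32, 64)
    assert ConvNormRelu(16, 32, sample='none')(x).shape == (4, 32, 32)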
class Res_CNR_Stack(nn.Module):
    def __init__(self,
                 channels,
                 layers,
                 sample='none',
                 leaky=False,
                 casual=False,
                 ):
        super(Res_CNR_Stack, self).__init__()
        if casual:
            kernel_size = 1
            padding = 0
            conv = CasualConv
        else:
            kernel_size = 3
            padding = 1
            conv = ConvNormRelu
        if sample == 'one':
            kernel_size = 1
            padding = 0
        self._layers = nn.ModuleList()
        for i in range(layers):
            if casual:
                # CasualConv does not take a `sample` argument
                self._layers.append(conv(channels, channels, leaky=leaky))
            else:
                self._layers.append(conv(channels, channels, leaky=leaky, sample=sample))
        self.conv = nn.Conv1d(channels, channels, kernel_size, 1, padding)
        self.norm = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x, pre_state=None):
        cur_state = []
        h = x
        for i in range(len(self._layers)):
            # cache the last frame each layer sees, for streaming inference
            cur_state.append(h[..., -1:])
            h = self._layers[i](h, pre_state=pre_state[i] if pre_state is not None else None)
        h = self.norm(self.conv(h))
        return self.relu(h + x), cur_state
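
# Illustrative sketch of the streaming interface (hypothetical helper; my
# reading of the cached `cur_state`): feeding the returned per-layer states
# back in as `pre_state` lets a stack continue across chunk boundaries.
def _demo_res_cnr_stack():
    stack = Res_CNR_Stack(16, layers=2, casual=True)
    a, b = torch.randn(4, 16, 8), torch.randn(4, 16, 8)
    _, state = stack(a)                 # first chunk; cache per-layer frames
    out, _ = stack(b, pre_state=state)  # second chunk continues seamlessly
    assert out.shape == (4, 16, 8)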
class ExponentialMovingAverage(nn.Module):
    """Maintains an exponential moving average for a value.

    This module keeps track of a hidden exponential moving average that is
    initialized as a vector of zeros and then bias-corrected to give the
    average, so the estimate is not biased towards the zero initialization.
    Reference: Adam (https://arxiv.org/pdf/1412.6980.pdf).

    Initially:
        hidden_0 = 0
    Then iteratively:
        hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay)
        average_i = hidden_i / (1 - decay^i)
    """

    def __init__(self, init_value, decay):
        super().__init__()
        self.decay = decay
        self.counter = 0
        self.register_buffer("hidden", torch.zeros_like(init_value))

    def forward(self, value):
        self.counter += 1
        self.hidden.sub_((self.hidden - value) * (1 - self.decay))
        average = self.hidden / (1 - self.decay ** self.counter)
        return average
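
# Illustrative numeric check (hypothetical helper): the division by
# (1 - decay**i) is the Adam-style bias correction; without it, early averages
# would be pulled towards the zero initialization.
def _demo_ema():
    ema = ExponentialMovingAverage(torch.zeros(1), decay=0.9)
    avg = ema(torch.ones(1))
    # hidden_1 = 0 - (0 - 1) * (1 - 0.9) = 0.1; average_1 = 0.1 / (1 - 0.9) = 1.0
    assert torch.allclose(avg, torch.ones(1))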
class VectorQuantizerEMA(nn.Module):
    """
    VQ-VAE layer: quantizes the input tensor, using EMA to update the embeddings.

    Args:
        embedding_dim (int): the dimensionality of the tensors in the
            quantized space. Inputs to the module must be in this format as well.
        num_embeddings (int): the number of vectors in the quantized space.
        commitment_cost (float): scalar which controls the weighting of the
            loss terms (beta in equation 4 of the VQ-VAE paper).
        decay (float): decay for the moving averages.
        epsilon (float): small float constant to avoid numerical instability.
    """

    def __init__(self, embedding_dim, num_embeddings, commitment_cost, decay,
                 epsilon=1e-5):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.epsilon = epsilon

        # initialize embeddings as buffers
        embeddings = torch.empty(self.num_embeddings, self.embedding_dim)
        nn.init.xavier_uniform_(embeddings)
        self.register_buffer("embeddings", embeddings)
        self.ema_dw = ExponentialMovingAverage(self.embeddings, decay)

        # also maintain ema_cluster_size, which records the size of each cluster
        self.ema_cluster_size = ExponentialMovingAverage(torch.zeros((self.num_embeddings,)), decay)

    def forward(self, x):
        # [B, C, T] -> [B, T, C]
        x = x.permute(0, 2, 1).contiguous()
        # [B, T, C] -> [BT, C]
        flat_x = x.reshape(-1, self.embedding_dim)

        encoding_indices = self.get_code_indices(flat_x)
        quantized = self.quantize(encoding_indices)
        quantized = quantized.view_as(x)  # [B, T, C]

        if not self.training:
            quantized = quantized.permute(0, 2, 1).contiguous()
            return quantized, encoding_indices.view(quantized.shape[0], quantized.shape[2])

        # update embeddings with EMA
        with torch.no_grad():
            encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
            updated_ema_cluster_size = self.ema_cluster_size(torch.sum(encodings, dim=0))
            n = torch.sum(updated_ema_cluster_size)
            updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                        (n + self.num_embeddings * self.epsilon) * n)
            dw = torch.matmul(encodings.t(), flat_x)  # sum encoding vectors of each cluster
            updated_ema_dw = self.ema_dw(dw)
            normalised_updated_ema_w = (
                updated_ema_dw / updated_ema_cluster_size.reshape(-1, 1))
            self.embeddings.data = normalised_updated_ema_w

        # commitment loss
        e_latent_loss = F.mse_loss(x, quantized.detach())
        loss = self.commitment_cost * e_latent_loss

        # straight-through estimator
        quantized = x + (quantized - x).detach()
        quantized = quantized.permute(0, 2, 1).contiguous()
        return quantized, loss

    def get_code_indices(self, flat_x):
        # squared L2 distance to every code:
        # ||x - e||^2 = ||x||^2 + ||e||^2 - 2 <x, e>
        distances = (
            torch.sum(flat_x ** 2, dim=1, keepdim=True) +
            torch.sum(self.embeddings ** 2, dim=1) -
            2. * torch.matmul(flat_x, self.embeddings.t())
        )  # [N, M]
        encoding_indices = torch.argmin(distances, dim=1)  # [N,]
        return encoding_indices

    def quantize(self, encoding_indices):
        """Returns embedding tensor for a batch of indices."""
        return F.embedding(encoding_indices, self.embeddings)
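
# Illustrative sketch (hypothetical helper; hyperparameters are invented):
# in training mode the layer returns a straight-through output plus the
# commitment loss, while in eval mode it returns the quantized tensor together
# with one code index per timestep.
def _demo_vq():
    vq = VectorQuantizerEMA(embedding_dim=8, num_embeddings=32,
                            commitment_cost=0.25, decay=0.99)
    z = torch.randn(4, 8, 16)          # (B, C=embedding_dim, T)
    quantized, commit_loss = vq(z)     # training mode by default
    assert quantized.shape == z.shape
    vq.eval()
    quantized, indices = vq(z)
    assert indices.shape == (4, 16)    # one code index per timestep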
class Casual_Encoder(nn.Module):
    def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Casual_Encoder, self).__init__()
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens
        self.project = nn.Conv1d(in_dim, self._num_hiddens // 4, 1, 1)
        self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
        self._down_1 = CasualConv(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, downsample=True)
        self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
        self._down_2 = CasualConv(self._num_hiddens // 2, self._num_hiddens, leaky=True, downsample=True)
        self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
        # self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)

    def forward(self, x):
        h = self.project(x)
        h, _ = self._enc_1(h)
        h = self._down_1(h)
        h, _ = self._enc_2(h)
        h = self._down_2(h)
        h, _ = self._enc_3(h)
        # h = self.pre_vq_conv(h)
        return h
class Casual_Decoder(nn.Module):
    def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Casual_Decoder, self).__init__()
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens
        # self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
        self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
        self._up_2 = CasualCT(self._num_hiddens, self._num_hiddens // 2, leaky=True)
        self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
        self._up_3 = CasualCT(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True)
        self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
        self.project = nn.Conv1d(self._num_hiddens // 4, out_dim, 1, 1)

    def forward(self, h, pre_state=None):
        cur_state = []
        # h = self.aft_vq_conv(x)
        h, s = self._dec_1(h, pre_state[0] if pre_state is not None else None)
        cur_state.append(s)
        h = self._up_2(h)
        h, s = self._dec_2(h, pre_state[1] if pre_state is not None else None)
        cur_state.append(s)
        h = self._up_3(h)
        h, s = self._dec_3(h, pre_state[2] if pre_state is not None else None)
        cur_state.append(s)
        recon = self.project(h)
        return recon, cur_state
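
# Illustrative end-to-end sketch (hyperparameters are invented, not from the
# original): encode a feature sequence, quantize it, decode it back. The
# encoder downsamples time by 4x, so T should be divisible by 4; since the
# pre/aft VQ projections are commented out, the VQ embedding_dim must equal
# num_hiddens.
if __name__ == '__main__':
    enc = Casual_Encoder(in_dim=64, embedding_dim=128, num_hiddens=256,
                         num_residual_layers=2, num_residual_hiddens=256)
    vq = VectorQuantizerEMA(embedding_dim=256, num_embeddings=512,
                            commitment_cost=0.25, decay=0.99)
    dec = Casual_Decoder(out_dim=64, embedding_dim=128, num_hiddens=256,
                         num_residual_layers=2, num_residual_hiddens=256)
    x = torch.randn(2, 64, 32)   # (B, in_dim, T)
    h = enc(x)                   # (B, num_hiddens, T // 4)
    h_q, commit_loss = vq(h)     # quantized latents + commitment loss
    recon, _ = dec(h_q)          # (B, out_dim, T)
    assert recon.shape == x.shape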