import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantizeEMAReset(nn.Module):
    """Vector quantizer with EMA codebook updates and periodic reset of dead codes."""

    def __init__(self, nb_code, code_dim, args):
        super().__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.mu = args.mu  # EMA decay factor
        self.reset_codebook()
        self.reset_count = 0
        # Tracks which codes were used at least once inside the current reset window.
        self.usage = torch.zeros((self.nb_code, 1))

    def reset_codebook(self):
        self.init = False
        self.code_sum = None
        self.code_count = None
        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim))

    def _tile(self, x):
        # Repeat x (with small Gaussian jitter) until there are at least nb_code rows,
        # so even a small batch can seed or re-sample the full codebook.
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    def init_codebook(self, x):
        # Initialize the codebook from the first training batch.
        out = self._tile(x)
        self.codebook = out[:self.nb_code]
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        # Perplexity = exp(entropy) of the empirical code distribution; it measures
        # how evenly the codebook is used.
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x)  # per-code sum of assigned vectors
        code_count = code_onehot.sum(dim=-1)     # per-code assignment count

        # Random batch vectors used to re-initialize dead codes.
        out = self._tile(x)
        code_rand = out[torch.randperm(out.shape[0])[:self.nb_code]]

        # Exponential moving averages of per-code sums and counts.
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count

        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
        self.usage = self.usage.to(usage.device)
        if self.reset_count >= 20:
            # End of the reset window: a code survives if it was used now or at any
            # step during the window; the rest are re-sampled below.
            self.reset_count = 0
            usage = (usage + self.usage >= 1.0).float()
        else:
            # Inside the window: accumulate usage flags and skip resetting this step.
            self.reset_count += 1
            self.usage = (usage + self.usage >= 1.0).float()
            usage = torch.ones_like(self.usage, device=x.device)
        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)

        # Keep the EMA centroid estimate for live codes; replace dead ones.
        self.codebook = usage * code_update + (1 - usage) * code_rand
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))

        return perplexity

    def preprocess(self, x):
        # (N, C, T) -> (N*T, C): one feature vector per time step.
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        # Squared L2 distance via the expansion ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2.
        k_w = self.codebook.t()
        distance = (torch.sum(x ** 2, dim=-1, keepdim=True)
                    - 2 * torch.matmul(x, k_w)
                    + torch.sum(k_w ** 2, dim=0, keepdim=True))
        _, code_idx = torch.min(distance, dim=-1)
        return code_idx

    def dequantize(self, code_idx):
        x = F.embedding(code_idx, self.codebook)
        return x

    def forward(self, x):
        N, width, T = x.shape

        x = self.preprocess(x)

        if self.training and not self.init:
            self.init_codebook(x)

        code_idx = self.quantize(x)
        x_d = self.dequantize(code_idx)

        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else:
            perplexity = self.compute_perplexity(code_idx)

        # Commitment loss pulls the encoder output toward its (frozen) code.
        commit_loss = F.mse_loss(x, x_d.detach())

        # Straight-through estimator: forward pass uses x_d, gradients flow to x.
        x_d = x + (x_d - x).detach()

        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()

        return x_d, commit_loss, perplexity


class Quantizer(nn.Module):
    """Plain VQ-VAE quantizer: the codebook is a learnable nn.Embedding."""

    def __init__(self, n_e, e_dim, beta):
        super().__init__()

        self.e_dim = e_dim
        self.n_e = n_e
        self.beta = beta  # commitment loss weight

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        N, width, T = z.shape
        z = self.preprocess(z)
        assert z.shape[-1] == self.e_dim
        z_flattened = z.contiguous().view(-1, self.e_dim)

        # Squared L2 distance between each vector and every codebook entry.
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # Codebook loss (moves codes toward encoder outputs) plus
        # beta-weighted commitment loss (moves encoder outputs toward codes).
        loss = torch.mean((z_q - z.detach()) ** 2) + self.beta * \
               torch.mean((z_q.detach() - z) ** 2)

        # Straight-through estimator.
        z_q = z + (z_q - z).detach()
        z_q = z_q.view(N, T, -1).permute(0, 2, 1).contiguous()

        min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        return z_q, loss, perplexity

    def quantize(self, z):
        assert z.shape[-1] == self.e_dim

        # Squared L2 distance to every codebook entry.
        d = torch.sum(z ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.matmul(z, self.embedding.weight.t())

        min_encoding_indices = torch.argmin(d, dim=1)
        return min_encoding_indices

    def dequantize(self, indices):
        index_flattened = indices.view(-1)
        z_q = self.embedding(index_flattened)
        z_q = z_q.view(indices.shape + (self.e_dim,)).contiguous()
        return z_q

    def preprocess(self, x):
        # (N, C, T) -> (N*T, C)
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x


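# A small round-trip sketch (hypothetical helper, not used elsewhere): encode a
# batch of vectors to nearest-code indices with Quantizer and decode them back.
# The sizes (1024 codes, 512 dims, 32 vectors) are illustrative only.
def _quantizer_round_trip_example():
    q = Quantizer(n_e=1024, e_dim=512, beta=0.25)
    idx = q.quantize(torch.randn(32, 512))  # (32,) indices of nearest codes
    rec = q.dequantize(idx)                 # (32, 512) corresponding code vectors
    return idx, rec

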
class QuantizeReset(nn.Module):
    """Vector quantizer with a learnable codebook; unused codes are re-sampled
    from the current batch on every training step."""

    def __init__(self, nb_code, code_dim, args):
        super().__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.reset_codebook()
        self.codebook = nn.Parameter(torch.randn(nb_code, code_dim))

    def reset_codebook(self):
        self.init = False
        self.code_count = None

    def _tile(self, x):
        # Repeat x (with small Gaussian jitter) until there are at least nb_code rows.
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    def init_codebook(self, x):
        out = self._tile(x)
        # Note: this replaces the Parameter registered in __init__, so an optimizer
        # constructed before the first forward pass will not track the new tensor.
        self.codebook = nn.Parameter(out[:self.nb_code])
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)

        out = self._tile(x)
        code_rand = out[:self.nb_code]

        self.code_count = code_count
        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()

        # Keep codes hit this step; overwrite the rest with batch vectors.
        self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))

        return perplexity

    def preprocess(self, x):
        # (N, C, T) -> (N*T, C)
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        # Squared L2 distance via ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2.
        k_w = self.codebook.t()
        distance = (torch.sum(x ** 2, dim=-1, keepdim=True)
                    - 2 * torch.matmul(x, k_w)
                    + torch.sum(k_w ** 2, dim=0, keepdim=True))
        _, code_idx = torch.min(distance, dim=-1)
        return code_idx

    def dequantize(self, code_idx):
        x = F.embedding(code_idx, self.codebook)
        return x

    def forward(self, x):
        N, width, T = x.shape

        x = self.preprocess(x)

        if self.training and not self.init:
            self.init_codebook(x)

        code_idx = self.quantize(x)
        x_d = self.dequantize(code_idx)

        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else:
            perplexity = self.compute_perplexity(code_idx)

        # Commitment loss pulls the encoder output toward its (frozen) code.
        commit_loss = F.mse_loss(x, x_d.detach())

        # Straight-through estimator: forward pass uses x_d, gradients flow to x.
        x_d = x + (x_d - x).detach()

        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()

        return x_d, commit_loss, perplexity


class QuantizeEMA(nn.Module):
    """Vector quantizer with EMA codebook updates (no dead-code reset)."""

    def __init__(self, nb_code, code_dim, args):
        super().__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.mu = 0.99  # fixed EMA decay; args is accepted for interface parity
        self.reset_codebook()

    def reset_codebook(self):
        self.init = False
        self.code_sum = None
        self.code_count = None
        # Registered as a buffer so it follows the module across devices
        # via .to(device) / .cuda().
        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim))

    def _tile(self, x):
        # Repeat x (with small Gaussian jitter) until there are at least nb_code rows.
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    def init_codebook(self, x):
        # Initialize the codebook from the first training batch.
        out = self._tile(x)
        self.codebook = out[:self.nb_code]
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x)  # per-code sum of assigned vectors
        code_count = code_onehot.sum(dim=-1)     # per-code assignment count

        # Exponential moving averages of per-code sums and counts; each codebook
        # entry is their ratio, an EMA estimate of the cluster centroid.
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count

        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)

        self.codebook = code_update
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))

        return perplexity

    def preprocess(self, x):
        # (N, C, T) -> (N*T, C)
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        # Squared L2 distance via ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2.
        k_w = self.codebook.t()
        distance = (torch.sum(x ** 2, dim=-1, keepdim=True)
                    - 2 * torch.matmul(x, k_w)
                    + torch.sum(k_w ** 2, dim=0, keepdim=True))
        _, code_idx = torch.min(distance, dim=-1)
        return code_idx

    def dequantize(self, code_idx):
        x = F.embedding(code_idx, self.codebook)
        return x

    def forward(self, x):
        N, width, T = x.shape

        x = self.preprocess(x)

        if self.training and not self.init:
            self.init_codebook(x)

        code_idx = self.quantize(x)
        x_d = self.dequantize(code_idx)

        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else:
            perplexity = self.compute_perplexity(code_idx)

        # Commitment loss pulls the encoder output toward its (frozen) code.
        commit_loss = F.mse_loss(x, x_d.detach())

        # Straight-through estimator: forward pass uses x_d, gradients flow to x.
        x_d = x + (x_d - x).detach()

        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()

        return x_d, commit_loss, perplexity


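# Minimal smoke test (a sketch): `args` is a hypothetical stand-in exposing only
# the `mu` attribute that QuantizeEMAReset reads; real training code passes its
# own argument object. Shapes below are illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(mu=0.99)  # hypothetical args; only .mu is required here
    x = torch.randn(2, 512, 16)      # (batch, code_dim, time)

    for quant in (QuantizeEMAReset(1024, 512, args),
                  QuantizeEMA(1024, 512, args),
                  QuantizeReset(1024, 512, args),
                  Quantizer(n_e=1024, e_dim=512, beta=0.25)):
        quant.train()
        x_d, loss, perplexity = quant(x)
        assert x_d.shape == x.shape
        print(type(quant).__name__, float(loss), float(perplexity))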