| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
|
|
| |
class VQEmbeddingEMA(nn.Module):
    """Per-band vector-quantisation codebook trained with exponential moving averages.

    The codebook lives in buffers (not parameters) and is updated in-place
    inside ``forward`` whenever the module is in training mode:

    * ``embedding``     -- ``(nband, num_code, code_dim)`` code vectors.
    * ``ema_weight``    -- EMA of the summed inputs assigned to each code.
    * ``ema_count``     -- EMA of the number of inputs assigned to each code.
    * ``stale_counter`` -- consecutive training steps a code went unused.

    Args:
        nband: Number of independent codebooks (first axis of every buffer).
        num_code: Number of code vectors per codebook.
        code_dim: Dimensionality of each code vector.
        decay: EMA decay factor used for ``ema_count`` and ``ema_weight``.
        layer: Index of this quantiser in a residual stack.  ``0`` keeps the
            codebook unit-norm; any other value uses a small-variance
            initialisation with code 0 set to the zero vector.
    """

    def __init__(self, nband, num_code, code_dim, decay=0.99, layer=0):
        super().__init__()

        self.nband = nband
        self.num_code = num_code
        self.code_dim = code_dim
        self.decay = decay
        self.layer = layer
        # A code unused for exactly this many consecutive training steps is re-seeded.
        self.stale_tolerance = 50
        self.eps = torch.finfo(torch.float32).eps

        if layer == 0:
            # First layer: random directions projected onto the unit sphere.
            embedding = torch.empty(nband, num_code, code_dim).normal_()
            embedding = embedding / (embedding.pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1)
        else:
            embedding = torch.empty(nband, num_code, code_dim).normal_() / code_dim
            # NOTE(review): the source indentation was lost; zeroing code 0 is
            # assumed to belong to this ``else`` branch only -- confirm.
            embedding[:, 0] = 0
        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_weight", self.embedding.clone())
        self.register_buffer("ema_count", torch.zeros(self.nband, self.num_code))
        self.register_buffer("stale_counter", torch.zeros(nband, self.num_code))

    def forward(self, input):
        """Quantise ``input`` against the codebook and, in training, update it.

        Args:
            input: Tensor of shape ``(B, C, N, T)`` with ``N == code_dim`` and
                ``1 <= C <= nband``.  Band ``c`` of the input is quantised
                with codebook ``c``.  (``T`` is presumably the frame axis.)

        Returns:
            Tuple ``(quantized, indices, perplexity)``:
            ``quantized`` has the shape of ``input`` and holds the selected
            code vectors (gathered from a buffer, so it carries no gradient
            back to ``input``); ``indices`` is ``(B, T, C)`` integer code
            ids; ``perplexity`` is a scalar, the band-averaged exponential of
            the code-usage entropy for this batch.

        Raises:
            AssertionError: If the input shape is incompatible with the codebook.
        """
        B, C, N, T = input.shape
        assert N == self.code_dim, "input dim 2 must equal code_dim"
        assert 1 <= C <= self.nband, "input dim 1 must be in [1, nband]"
        # Only the first C codebooks take part; the rest are left untouched.
        num_valid_bands = C

        # (B, C, N, T) -> (B*T, C, N); detached because the codebook is
        # learned by EMA statistics rather than by gradients.
        flat_input = (
            input.detach()
            .permute(0, 3, 1, 2)
            .contiguous()
            .view(B * T, num_valid_bands, self.code_dim)
        )
        codebook = self.embedding[:num_valid_bands].contiguous()

        # Squared Euclidean distance ||x||^2 + ||e||^2 - 2 <x, e>, shape (B*T, C, K).
        cross = torch.bmm(flat_input.transpose(0, 1), codebook.transpose(1, 2)).transpose(0, 1)
        eu_dis = flat_input.pow(2).sum(2).unsqueeze(2) + codebook.pow(2).sum(2).unsqueeze(0)
        eu_dis = eu_dis - 2 * cross

        # Nearest code per (sample, band), then look the vectors up.
        indices = torch.argmin(eu_dis, dim=-1)
        band_index = torch.arange(num_valid_bands, device=indices.device)
        quantized = codebook[band_index, indices]  # (B*T, C, N)
        quantized = quantized.view(B, T, C, N).permute(0, 2, 3, 1).contiguous()

        # Usage statistics of this batch.
        encodings = F.one_hot(indices, self.num_code).float()
        avg_probs = encodings.mean(0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)).mean()

        if self.training:
            self._update_codebook(flat_input, encodings, num_valid_bands)

        return quantized, indices.reshape(B, T, -1), perplexity

    def _update_codebook(self, flat_input, encodings, num_valid_bands):
        """Apply one EMA step to the first ``num_valid_bands`` codebooks in-place."""
        nvb = num_valid_bands
        usage = torch.sum(encodings, dim=0)  # (C, K) assignments per code

        self.ema_count[:nvb] = self.decay * self.ema_count[:nvb] + (1 - self.decay) * usage

        # Sum of the inputs assigned to every code, shape (C, K, N).
        update_direction = encodings.permute(1, 2, 0).bmm(flat_input.permute(1, 0, 2))
        self.ema_weight[:nvb] = self.decay * self.ema_weight[:nvb] + (1 - self.decay) * update_direction

        # Laplace smoothing keeps every count strictly positive so the
        # division below never hits zero.
        n = torch.sum(self.ema_count[:nvb], dim=-1, keepdim=True)
        self.ema_count[:nvb] = (self.ema_count[:nvb] + self.eps) / (n + self.num_code * self.eps) * n

        self.embedding[:nvb] = self.ema_weight[:nvb] / self.ema_count[:nvb].unsqueeze(-1)

        # Counter grows by one for codes unused in this batch and resets to
        # zero for codes that were used.
        stale_codes = (usage == 0).float()
        self.stale_counter[:nvb] = self.stale_counter[:nvb] * stale_codes + stale_codes
        print("Layer {}, Ratio of unused vector : {}, {:.1f}, {:.1f}".format(
            self.layer,
            encodings.sum(),
            stale_codes.sum() / torch.numel(stale_codes) * 100.,
            (self.stale_counter > self.stale_tolerance // 2).sum() / torch.numel(self.stale_counter) * 100.,
        ))

        replace_code = (self.stale_counter[:nvb] == self.stale_tolerance).float()
        if replace_code.sum(-1).max() > 0:
            self._reseed_stale_codes(flat_input, replace_code, nvb)

        # NOTE(review): the source indentation was lost; this re-normalisation
        # is assumed to run on every training step (not only after a
        # re-seed) -- confirm.
        if self.layer == 0:
            self.embedding[:nvb] = self.embedding[:nvb] / (
                self.embedding[:nvb].pow(2).sum(-1) + self.eps
            ).sqrt().unsqueeze(-1)

    def _reseed_stale_codes(self, flat_input, replace_code, num_valid_bands):
        """Overwrite the codes flagged in ``replace_code`` with random input vectors."""
        nvb = num_valid_bands

        random_input = flat_input[torch.randperm(flat_input.shape[0])]
        if random_input.shape[0] < self.num_code:
            # Fewer samples than codes: tile the shuffled batch so that every
            # code slot has a candidate vector.
            repeats = self.num_code // random_input.shape[0] + 1
            random_input = torch.cat([random_input] * repeats, 0)
        random_input = random_input[:self.num_code].contiguous().transpose(0, 1)  # (C, K, N)

        keep = 1 - replace_code
        self.embedding[:nvb] = self.embedding[:nvb] * keep.unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
        self.ema_weight[:nvb] = self.ema_weight[:nvb] * keep.unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
        # Re-seeded codes restart their usage statistics from zero.
        self.ema_count[:nvb] = self.ema_count[:nvb] * keep
        self.stale_counter[:nvb] = self.stale_counter[:nvb] * keep
|
|
class RVQEmbedding(nn.Module):
    """Residual vector quantiser: a stack of ``VQEmbeddingEMA`` stages.

    The input is L2-normalised along dim ``-2``; each stage quantises what
    the previous stages left over, and the running sum of stage outputs is
    the reconstruction after that stage.  The pre-normalisation L2 norm is
    separately quantised into one of eight fixed levels and returned.

    Args:
        nband: Number of codebooks per stage (forwarded to every stage).
        code_dim: Dimensionality of the code vectors.
        decay: EMA decay forwarded to every stage.
        num_codes: Codebook size of each stage; its length sets the number
            of stages.
    """

    def __init__(self, nband, code_dim, decay=0.99, num_codes=(1024, 1024)):
        super().__init__()

        self.nband = nband
        self.code_dim = code_dim
        self.decay = decay
        self.eps = torch.finfo(torch.float32).eps
        # Running [min, max] of the input norms seen by ``forward`` (diagnostics only).
        self.min_max = [10000, -10000]
        # Not used anywhere in this class; kept for callers that may read it.
        self.bins = [256 + i * 8 for i in range(32)]

        self.VQEmbedding = nn.ModuleList([
            VQEmbeddingEMA(nband, num_code, code_dim, decay, layer=layer)
            for layer, num_code in enumerate(num_codes)
        ])

    def forward(self, input):
        """Quantise ``input`` with every stage in turn.

        Args:
            input: Tensor whose dim ``-2`` is the feature axis.  Each stage
                expects shape ``(B, C, N, T)``.

        Returns:
            Tuple ``(quantized_list, norm_value, indices_list, latent_loss)``:
            ``quantized_list`` stacks, along a new last axis, the cumulative
            reconstruction after each stage, re-normalised to unit L2 norm
            over dim 2; ``norm_value`` is the quantised input norm;
            ``indices_list`` stacks the per-stage code indices along a new
            last axis; ``latent_loss`` is the sum over stages of the MSE
            between the normalised input and the cumulative reconstruction.
        """
        norm_value = torch.norm(input, p=2, dim=-2)

        # Track the extreme norms seen so far (diagnostic state only).
        self.min_max[0] = min(self.min_max[0], norm_value.min().item())
        self.min_max[-1] = max(self.min_max[-1], norm_value.max().item())
        print("Min-max : {}".format(self.min_max))

        # Quantise the norm into 8 bins of width 20 starting at 256 and map
        # each bin to its centre (266, 286, ..., 406).
        norm_value = (((norm_value - 256) / 20).clamp(min=0, max=7).int() * 20 + 256 + 10).float()
        print("Min-max : {}, {}".format(norm_value.min(), norm_value.max()))

        input = F.normalize(input, p=2, dim=-2)

        cumulative = []   # reconstruction after stage 0, 0..1, 0..2, ...
        stage_indices = []
        latent_loss = 0

        residual = input
        for stage in self.VQEmbedding:
            # The per-stage perplexity is computed by the stage but not returned here.
            stage_quantized, indices, _ = stage(residual)
            stage_indices.append(indices)
            residual = residual - stage_quantized

            if cumulative:
                cumulative.append(cumulative[-1] + stage_quantized)
            else:
                cumulative.append(stage_quantized)
            latent_loss = latent_loss + F.mse_loss(input, cumulative[-1].detach())

        quantized_list = torch.stack(cumulative, -1)
        indices_list = torch.stack(stage_indices, -1)

        # Project every cumulative reconstruction back to unit norm over dim 2.
        quantized_list = quantized_list / (quantized_list.pow(2).sum(2) + self.eps).sqrt().unsqueeze(2)

        return quantized_list, norm_value, indices_list, latent_loss
|
|