import torch
from torch import nn
from torch.nn import functional as F

from .layers import MLPLayers
from .rq import ResidualVectorQuantizer


class RQVAE(nn.Module):
    """RQ-VAE: an MLP autoencoder whose bottleneck is residually quantized."""

    def __init__(
        self,
        in_dim=768,
        num_emb_list=None,   # codebook sizes per level, e.g. [256, 256, 256, 256]
        e_dim=64,
        layers=None,         # encoder hidden sizes, e.g. [512, 256, 128]
        dropout_prob=0.0,
        bn=False,
        loss_type="mse",
        quant_loss_weight=1.0,
        kmeans_init=False,
        kmeans_iters=100,
        sk_epsilons=None,    # Sinkhorn epsilons per level, e.g. [0.0, 0.0, 0.003, 0.01]
        sk_iters=100,
    ):
        super(RQVAE, self).__init__()

        # Mutable defaults are passed as None and filled in here, matching the
        # values noted in the parameter comments above.
        if num_emb_list is None:
            num_emb_list = [256, 256, 256, 256]
        if layers is None:
            layers = [512, 256, 128]
        if sk_epsilons is None:
            sk_epsilons = [0.0, 0.0, 0.003, 0.01]

        self.in_dim = in_dim
        self.num_emb_list = num_emb_list
        self.e_dim = e_dim
        self.layers = layers
        self.dropout_prob = dropout_prob
        self.bn = bn
        self.loss_type = loss_type
        self.quant_loss_weight = quant_loss_weight
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.sk_epsilons = sk_epsilons
        self.sk_iters = sk_iters

        # Encoder: in_dim -> hidden layers -> e_dim (the quantized latent size).
        self.encode_layer_dims = [self.in_dim] + self.layers + [self.e_dim]
        self.encoder = MLPLayers(layers=self.encode_layer_dims,
                                 dropout=self.dropout_prob, bn=self.bn)

        # Residual quantizer: one codebook per entry of num_emb_list.
        self.rq = ResidualVectorQuantizer(
            num_emb_list,
            e_dim,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            sk_epsilons=self.sk_epsilons,
            sk_iters=self.sk_iters,
        )

        # Decoder mirrors the encoder: e_dim -> reversed hidden layers -> in_dim.
        self.decode_layer_dims = self.encode_layer_dims[::-1]
        self.decoder = MLPLayers(layers=self.decode_layer_dims,
                                 dropout=self.dropout_prob, bn=self.bn)

    def forward(self, x, use_sk=True):
        x = self.encoder(x)
        x_q, rq_loss, indices = self.rq(x, use_sk=use_sk)
        out = self.decoder(x_q)
        return out, rq_loss, indices

    @torch.no_grad()
    def get_indices(self, xs, use_sk=False):
        """Return the per-level codebook indices for a batch of inputs."""
        x_e = self.encoder(xs)
        _, _, indices = self.rq(x_e, use_sk=use_sk)
        return indices

    def compute_loss(self, out, quant_loss, xs=None):
        # Reconstruction loss between decoder output and the original input.
        if self.loss_type == 'mse':
            loss_recon = F.mse_loss(out, xs, reduction='mean')
        elif self.loss_type == 'l1':
            loss_recon = F.l1_loss(out, xs, reduction='mean')
        else:
            raise ValueError('incompatible loss type')

        # Total loss: reconstruction plus weighted quantization loss.
        loss_total = loss_recon + self.quant_loss_weight * quant_loss
        return loss_total, loss_recon
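

# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal training step, assuming MLPLayers and ResidualVectorQuantizer
# behave as the constructor calls above suggest (MLPLayers maps between the
# listed dims; the quantizer returns the quantized latents, a quantization
# loss, and per-level indices). The batch here is random data for shape only.
#
#     model = RQVAE(in_dim=768, num_emb_list=[256, 256, 256, 256],
#                   e_dim=64, layers=[512, 256, 128],
#                   sk_epsilons=[0.0, 0.0, 0.003, 0.01])
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#
#     xs = torch.randn(32, 768)                 # a batch of input embeddings
#     out, rq_loss, indices = model(xs)
#     loss, loss_recon = model.compute_loss(out, rq_loss, xs=xs)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#
#     # After training, get_indices yields the discrete codes for each item.
#     codes = model.get_indices(xs)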