|
|
import numpy as np |
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from .layers import MLPLayers |
|
|
from .rq import ResidualVectorQuantizer |
|
|
|
|
|
|
|
|
class RQVAE(nn.Module):
    """Residual-quantized autoencoder (RQ-VAE).

    Encodes inputs with an MLP to a latent of size ``e_dim``, quantizes the
    latent with a residual vector quantizer, and decodes the quantized latent
    with a mirrored MLP. The training objective combines a reconstruction
    loss (MSE or L1) with the quantizer's loss, weighted by
    ``quant_loss_weight`` (see ``compute_loss``).
    """

    def __init__(self,
                 in_dim=768,
                 num_emb_list=None,
                 e_dim=64,
                 layers=None,
                 dropout_prob=0.0,
                 bn=False,
                 loss_type="mse",
                 quant_loss_weight=1.0,
                 kmeans_init=False,
                 kmeans_iters=100,
                 sk_epsilons=None,
                 sk_iters=100,
                 ):
        """
        Args:
            in_dim: dimensionality of the input features.
            num_emb_list: per-level codebook sizes, forwarded to
                ``ResidualVectorQuantizer`` (semantics defined there).
            e_dim: dimensionality of the latent / codebook embeddings.
            layers: hidden-layer sizes for the encoder MLP; the decoder
                mirrors them. ``None`` is treated as "no hidden layers".
            dropout_prob: dropout probability inside both MLPs.
            bn: whether the MLPs apply batch normalization.
            loss_type: reconstruction loss, ``"mse"`` or ``"l1"``.
            quant_loss_weight: weight of the quantization loss term.
            kmeans_init: k-means codebook initialization flag (forwarded to
                the quantizer).
            kmeans_iters: number of k-means iterations (forwarded).
            sk_epsilons: Sinkhorn-Knopp epsilons (forwarded; semantics
                defined by the quantizer).
            sk_iters: number of Sinkhorn-Knopp iterations (forwarded).
        """
        super().__init__()

        self.in_dim = in_dim
        self.num_emb_list = num_emb_list
        self.e_dim = e_dim
        # Fix: the original concatenated `[in_dim] + self.layers` directly,
        # which raised TypeError when `layers` was left at its default None.
        # Treat None as an empty hidden-layer list instead.
        self.layers = [] if layers is None else layers
        self.dropout_prob = dropout_prob
        self.bn = bn
        self.loss_type = loss_type
        self.quant_loss_weight = quant_loss_weight
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.sk_epsilons = sk_epsilons
        self.sk_iters = sk_iters

        # Encoder: in_dim -> hidden layers -> e_dim.
        self.encode_layer_dims = [self.in_dim] + self.layers + [self.e_dim]
        self.encoder = MLPLayers(layers=self.encode_layer_dims,
                                 dropout=self.dropout_prob, bn=self.bn)

        self.rq = ResidualVectorQuantizer(num_emb_list, e_dim,
                                          kmeans_init=self.kmeans_init,
                                          kmeans_iters=self.kmeans_iters,
                                          sk_epsilons=self.sk_epsilons,
                                          sk_iters=self.sk_iters)

        # Decoder mirrors the encoder dimensions: e_dim -> ... -> in_dim.
        self.decode_layer_dims = self.encode_layer_dims[::-1]
        self.decoder = MLPLayers(layers=self.decode_layer_dims,
                                 dropout=self.dropout_prob, bn=self.bn)

    def forward(self, x, use_sk=True):
        """Encode, quantize, and decode ``x``.

        Args:
            x: input features of shape ``(..., in_dim)`` — presumably a
                batch of flat embeddings; verify against callers.
            use_sk: forwarded to the quantizer (Sinkhorn-Knopp assignment).

        Returns:
            Tuple of ``(out, rq_loss, indices)`` — the reconstruction, the
            quantization loss, and the per-level codebook indices.
        """
        x = self.encoder(x)
        x_q, rq_loss, indices = self.rq(x, use_sk=use_sk)
        out = self.decoder(x_q)
        return out, rq_loss, indices

    @torch.no_grad()
    def get_indices(self, xs, use_sk=False):
        """Return codebook indices for ``xs`` without tracking gradients.

        Note: unlike ``forward``, ``use_sk`` defaults to False here.
        """
        x_e = self.encoder(xs)
        _, _, indices = self.rq(x_e, use_sk=use_sk)
        return indices

    def compute_loss(self, out, quant_loss, xs=None):
        """Combine the reconstruction loss with the weighted quantization loss.

        Args:
            out: model reconstruction (first element of ``forward``'s output).
            quant_loss: quantization loss (second element of ``forward``'s
                output).
            xs: reconstruction target (the original input).

        Returns:
            Tuple ``(loss_total, loss_recon)`` where
            ``loss_total = loss_recon + quant_loss_weight * quant_loss``.

        Raises:
            ValueError: if ``self.loss_type`` is neither ``"mse"`` nor
                ``"l1"``.
        """
        if self.loss_type == 'mse':
            loss_recon = F.mse_loss(out, xs, reduction='mean')
        elif self.loss_type == 'l1':
            loss_recon = F.l1_loss(out, xs, reduction='mean')
        else:
            raise ValueError('incompatible loss type')

        loss_total = loss_recon + self.quant_loss_weight * quant_loss

        return loss_total, loss_recon