# Source: NMR/src/models/tokenizers/vqvae.py
# Uploaded by RayZhao -- commit 4cc0d6c ("initial push")
import torch
from mmengine.model import BaseModel
from mmengine.registry import MODELS
from .encdoc.encdoc_attn import length_to_mask
from torch.nn.utils.rnn import pad_sequence
@MODELS.register_module()
class VQVAE(BaseModel):
    """Encoder / quantizer / decoder model for motion sequences.

    All sub-modules are built from registry configs via ``MODELS.build``.

    Args:
        encoder_cfg: Registry config of the encoder.
        decoder_cfg: Registry config of the decoder.
        quantizer_cfg: Registry config of the quantizer.
        loss_cfg: Optional registry config of the reconstruction loss. When
            falsy, no loss module is built and ``forward_loss`` is
            unavailable.
        **kwargs: Forwarded to ``BaseModel.__init__``.
    """

    def __init__(self, encoder_cfg, decoder_cfg, quantizer_cfg, loss_cfg=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = MODELS.build(encoder_cfg)
        self.decoder = MODELS.build(decoder_cfg)
        self.quantizer = MODELS.build(quantizer_cfg)
        # The loss is optional so the model can be built for inference only.
        if loss_cfg:
            self.recons_loss = MODELS.build(loss_cfg)

    def preprocess(self, x):
        """Cast to float and swap the last two axes: (bs, T, Jx3) -> (bs, Jx3, T)."""
        return x.permute(0, 2, 1).float()

    def postprocess(self, x):
        """Swap the last two axes back: (bs, Jx3, T) -> (bs, T, Jx3)."""
        return x.permute(0, 2, 1)

    def _reconstruct(self, motion, motion_length):
        """Run encoder -> quantizer -> mask -> decoder on a batch.

        Shared by ``forward_loss`` and ``forward_predict`` so both modes run
        exactly the same pipeline.

        Args:
            motion: Batched motion, (B, T, C) per the shape notes in this file.
            motion_length: Per-sample valid lengths at input resolution.

        Returns:
            Tuple ``(pred_motion, commit_loss, code_indices, latent_length)``
            where ``latent_length`` is ``motion_length`` rescaled to the
            encoder's temporal resolution.
        """
        x_in = self.preprocess(motion)
        x_encoder = self.encoder(x_in, motion_length=motion_length)

        # Integer temporal downsampling factor of the encoder, used to
        # express the valid lengths at latent resolution.
        downsample = x_in.shape[2] // x_encoder.shape[2]
        latent_length = motion_length // downsample

        # The quantizer returns a 5-tuple; perplexity and the fourth entry
        # are not needed here.
        x_quantized, commit_loss, _, _, code_indices = self.quantizer(
            x_encoder, motion_length=latent_length)  # (B, C, T')

        # Zero out latent frames that lie beyond each sample's valid length.
        mask = length_to_mask(
            latent_length, max_length=x_quantized.shape[-1]).unsqueeze(1)  # (B, 1, T')
        x_quantized = x_quantized * mask  # (B, C, T')

        x_decoder = self.decoder(x_quantized, motion_length=latent_length)
        pred_motion = self.postprocess(x_decoder)  # (B, T, C)
        return pred_motion, commit_loss, code_indices, latent_length

    def forward_loss(self, motion, motion_length, **kwargs):
        """Reconstruct ``motion`` and return the output of the loss module.

        Args:
            motion: Batched motion, (B, T, C) per the shape notes in this file.
            motion_length: Per-sample valid lengths at input resolution.
            **kwargs: Ignored.

        Returns:
            Whatever ``self.recons_loss`` returns.

        Raises:
            AttributeError: If the model was built without ``loss_cfg``.
        """
        recons_loss = getattr(self, 'recons_loss', None)
        if recons_loss is None:
            # Fail before the forward pass instead of after it with an
            # opaque nn.Module attribute error.
            raise AttributeError(
                "VQVAE was built without `loss_cfg`; `recons_loss` is not "
                "available, so mode='loss' cannot be used.")

        pred_motion, commit_loss, _, latent_length = self._reconstruct(
            motion, motion_length)
        # NOTE(review): the loss receives the *downsampled* length although
        # pred_motion / motion are at full temporal resolution. Kept as in the
        # original; confirm against the loss implementation whether the
        # input-resolution length is what it expects.
        return recons_loss(
            pred_motion, motion, motion_length=latent_length, commit_loss=commit_loss)

    def forward_predict(self, motion, motion_length, **kwargs):
        """Reconstruct ``motion`` and return it with its code indices.

        Args:
            motion: Batched motion, (B, T, C) per the shape notes in this file.
            motion_length: Per-sample valid lengths at input resolution.
            **kwargs: Ignored.

        Returns:
            Tuple ``(pred_motion, code_indices)``.
        """
        pred_motion, _, code_indices, _ = self._reconstruct(motion, motion_length)
        return pred_motion, code_indices

    def forward(self, motion: torch.Tensor, mode: str='tensor', **kwargs):  # type: ignore
        """Dispatch to ``forward_loss`` or ``forward_predict`` by ``mode``.

        Args:
            motion: Batched motion tensor.
            mode: ``'loss'`` or ``'predict'``. Any other value, including the
                default ``'tensor'``, is not implemented.
            **kwargs: Forwarded to the selected method (must contain
                ``motion_length``).

        Raises:
            NotImplementedError: If ``mode`` is neither ``'loss'`` nor
                ``'predict'``.
        """
        if mode == 'loss':
            return self.forward_loss(motion, **kwargs)
        if mode == 'predict':
            return self.forward_predict(motion, **kwargs)
        raise NotImplementedError(
            f"Unsupported mode {mode!r}; expected 'loss' or 'predict'.")

    def encode(self, motion, motion_length):
        """Encode motion into per-sample sequences of code indices.

        Args:
            motion: Batched motion tensor, or a list of per-sample tensors
                which is padded into a batch with ``pad_sequence``.
            motion_length: Per-sample valid lengths at input resolution.

        Returns:
            A list with one entry per sample, each holding that sample's
            code indices truncated to its valid latent length.
        """
        if isinstance(motion, list):
            motion = pad_sequence(motion, batch_first=True)

        x_in = self.preprocess(motion)
        x_encoder = self.encoder(x_in, motion_length=motion_length)
        code_lengths = motion_length // (x_in.shape[2] // x_encoder.shape[2])

        # Pass the latent-resolution lengths to the quantizer, as
        # forward_loss / forward_predict do (the original passed the
        # input-resolution lengths here).
        code_indices = self.quantizer(
            x_encoder, motion_length=code_lengths)[-1]  # (B, T')

        # Drop the padded tail of every sample.
        return [
            codes[:length] for codes, length in zip(code_indices, code_lengths)
        ]

    def decode(self, x, motion_length=None):
        """Decode a batch of code indices back into motion.

        Args:
            x: Code indices of shape (B, T).
            motion_length: Optional per-sample valid lengths at latent
                resolution. Defaults to the full length ``T`` for every
                sample.

        Returns:
            The decoder output after ``postprocess``.
        """
        assert x.dim() == 2, f"expected code indices of shape (B, T), got {tuple(x.shape)}"
        batch_size, num_tokens = x.shape

        x_d = self.quantizer.dequantize(x)  # (B, T, C)
        x_d = x_d.permute(0, 2, 1)  # (B, T, C) -> (B, C, T)

        if motion_length is None:
            # Treat every sequence as fully valid.
            motion_length = torch.full(
                (batch_size,), num_tokens, dtype=torch.long, device=x_d.device)

        # Keyword argument, matching every other decoder call in this class.
        x_out = self.decoder(x_d, motion_length=motion_length)
        return self.postprocess(x_out)