Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .base_architecture import BaseArchitecture | |
| from ..builder import ( | |
| ARCHITECTURES, | |
| build_architecture, | |
| build_submodule, | |
| build_loss | |
| ) | |
| class PoseVAE(BaseArchitecture): | |
| def __init__(self, | |
| encoder=None, | |
| decoder=None, | |
| loss_recon=None, | |
| kl_div_loss_weight=None, | |
| init_cfg=None, | |
| **kwargs): | |
| super().__init__(init_cfg=init_cfg, **kwargs) | |
| self.encoder = build_submodule(encoder) | |
| self.decoder = build_submodule(decoder) | |
| self.loss_recon = build_loss(loss_recon) | |
| self.kl_div_loss_weight = kl_div_loss_weight | |
| def reparameterize(self, mu, logvar): | |
| std = torch.exp(logvar / 2) | |
| eps = std.data.new(std.size()).normal_() | |
| latent_code = eps.mul(std).add_(mu) | |
| return latent_code | |
| def encode(self, pose): | |
| mu, logvar = self.encoder(pose) | |
| return mu | |
| def forward(self, **kwargs): | |
| motion = kwargs['motion'].float() | |
| B, T = motion.shape[:2] | |
| pose = motion.reshape(B * T, -1) | |
| pose = pose[:, :-4] | |
| mu, logvar = self.encoder(pose) | |
| z = self.reparameterize(mu, logvar) | |
| pred = self.decoder(z) | |
| loss = dict() | |
| recon_loss = self.loss_recon(pred, pose, reduction_override='none') | |
| loss['recon_loss'] = recon_loss | |
| if self.kl_div_loss_weight is not None: | |
| loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) | |
| loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) | |
| return loss | |
| class MotionVAE(BaseArchitecture): | |
| def __init__(self, | |
| encoder=None, | |
| decoder=None, | |
| loss_recon=None, | |
| kl_div_loss_weight=None, | |
| init_cfg=None, | |
| **kwargs): | |
| super().__init__(init_cfg=init_cfg, **kwargs) | |
| self.encoder = build_submodule(encoder) | |
| self.decoder = build_submodule(decoder) | |
| self.loss_recon = build_loss(loss_recon) | |
| self.kl_div_loss_weight = kl_div_loss_weight | |
| def sample(self, std=1, latent_code=None): | |
| if latent_code is not None: | |
| z = latent_code | |
| else: | |
| z = torch.randn(1, 7, self.decoder.latent_dim).cuda() * std | |
| output = self.decoder(z) | |
| if self.use_normalization: | |
| output = output * self.motion_std | |
| output = output + self.motion_mean | |
| return output | |
| def reparameterize(self, mu, logvar): | |
| std = torch.exp(logvar / 2) | |
| eps = std.data.new(std.size()).normal_() | |
| latent_code = eps.mul(std).add_(mu) | |
| return latent_code | |
| def encode(self, motion, motion_mask): | |
| mu, logvar = self.encoder(motion, motion_mask) | |
| return self.reparameterize(mu, logvar) | |
| def decode(self, z, motion_mask): | |
| return self.decoder(z, motion_mask) | |
| def forward(self, **kwargs): | |
| motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'] | |
| B, T = motion.shape[:2] | |
| mu, logvar = self.encoder(motion, motion_mask) | |
| z = self.reparameterize(mu, logvar) | |
| pred = self.decoder(z, motion_mask) | |
| loss = dict() | |
| recon_loss = self.loss_recon(pred, motion, reduction_override='none') | |
| recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum() | |
| loss['recon_loss'] = recon_loss | |
| if self.kl_div_loss_weight is not None: | |
| loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) | |
| loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) | |
| return loss |