import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

sys.path.append(os.getcwd())

from nets.layers import *
from nets.base import TrainWrapperBaseClass
from nets.spg.gated_pixelcnn_v2 import GatedPixelCNN as pixelcnn
from nets.spg.vqvae_1d import VQVAE as s2g_body, Wav2VecEncoder, AudioEncoder
from nets.utils import parse_audio, denormalize
from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
from data_utils.lower_body import c_index, c_index_3d, c_index_6d
from data_utils.utils import smooth_geom, get_mfcc_sepa
import numpy as np
from sklearn.preprocessing import normalize
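
# Two-stage speech-to-gesture pipeline, as wired up below: pretrained VQ-VAEs
# (g_body, g_hand) discretize body and hand poses into codebook indices, and a
# GatedPixelCNN prior, optionally conditioned on an audio encoding and a speaker
# id, is trained to predict those indices with a cross-entropy loss.
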
class TrainWrapper(TrainWrapperBaseClass):
    '''
    A wrapper that receives a batch from data_utils and computes the loss.
    '''
    def __init__(self, args, config):
        self.args = args
        self.config = config
        self.global_step = 0
        # Force CPU device
        self.device = torch.device('cpu')
        self.convert_to_6d = self.config.Data.pose.convert_to_6d
        self.expression = self.config.Data.pose.expression
        self.epoch = 0
        self.init_params()
        self.num_classes = 4
        self.audio = True
        self.composition = self.config.Model.composition
        self.bh_model = self.config.Model.bh_model

        if self.audio:
            self.audioencoder = AudioEncoder(
                in_dim=64,
                num_hiddens=256,
                num_residual_layers=2,
                num_residual_hiddens=256
            ).to(self.device)
        else:
            self.audioencoder = None

        if self.convert_to_6d:
            dim, layer = 512, 10
        else:
            dim, layer = 256, 15

        self.generator = pixelcnn(2048, dim, layer, self.num_classes, self.audio, self.bh_model).to(self.device)
        self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num,
                               num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
        self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num,
                               num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)

        model_path = self.config.Model.vq_path
        model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
        self.g_body.load_state_dict(model_ckpt['generator']['g_body'])
        self.g_hand.load_state_dict(model_ckpt['generator']['g_hand'])

        self.discriminator = None
        if self.convert_to_6d:
            self.c_index = c_index_6d
        else:
            self.c_index = c_index_3d

        super().__init__(args, config)
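
    # Note: g_body and g_hand are loaded from the checkpoint at config.Model.vq_path
    # and kept frozen during training (eval() + torch.no_grad() in __call__); only
    # the PixelCNN generator and, when enabled, the audio encoder receive gradients.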
    def init_optimizer(self):
        print('using Adam')
        self.generator_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=self.config.Train.learning_rate.generator_learning_rate,
            betas=[0.9, 0.999]
        )
        if self.audioencoder is not None:
            opt = self.config.Model.AudioOpt
            if opt == 'Adam':
                self.audioencoder_optimizer = optim.Adam(
                    self.audioencoder.parameters(),
                    lr=self.config.Train.learning_rate.generator_learning_rate,
                    betas=[0.9, 0.999]
                )
            else:
                print('using SGD')
                self.audioencoder_optimizer = optim.SGD(
                    filter(lambda p: p.requires_grad, self.audioencoder.parameters()),
                    lr=self.config.Train.learning_rate.generator_learning_rate * 10,
                    momentum=0.9,
                    nesterov=False
                )
    def state_dict(self):
        return {
            'generator': self.generator.state_dict(),
            'generator_optim': self.generator_optimizer.state_dict(),
            'audioencoder': self.audioencoder.state_dict() if self.audio else None,
            'audioencoder_optim': self.audioencoder_optimizer.state_dict() if self.audio else None,
            'discriminator': self.discriminator.state_dict() if self.discriminator else None,
            'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator else None
        }
    def load_state_dict(self, state_dict):
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            sub_dict = OrderedDict()
            if v is not None:
                for k1, v1 in v.items():
                    # strip the 'module.' prefix left by DataParallel checkpoints
                    name = k1.replace('module.', '')
                    sub_dict[name] = v1
            new_state_dict[k] = sub_dict
        state_dict = new_state_dict

        if 'generator' in state_dict:
            self.generator.load_state_dict(state_dict['generator'])
        else:
            self.generator.load_state_dict(state_dict)

        if 'generator_optim' in state_dict and self.generator_optimizer is not None:
            self.generator_optimizer.load_state_dict(state_dict['generator_optim'])

        if self.discriminator is not None:
            self.discriminator.load_state_dict(state_dict['discriminator'])

        if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
            self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])

        if 'audioencoder' in state_dict and self.audioencoder is not None:
            self.audioencoder.load_state_dict(state_dict['audioencoder'])
    def init_params(self):
        if self.config.Data.pose.convert_to_6d:
            scale = 2
        else:
            scale = 1

        global_orient = round(0 * scale)
        leye_pose = reye_pose = round(0 * scale)
        jaw_pose = round(0 * scale)
        body_pose = round((63 - 24) * scale)
        left_hand_pose = right_hand_pose = round(45 * scale)
        if self.expression:
            expression = 100
        else:
            expression = 0

        b_j = 0
        jaw_dim = jaw_pose
        b_e = b_j + jaw_dim
        eye_dim = leye_pose + reye_pose
        b_b = b_e + eye_dim
        body_dim = global_orient + body_pose
        b_h = b_b + body_dim
        hand_dim = left_hand_pose + right_hand_pose
        b_f = b_h + hand_dim
        face_dim = expression

        self.dim_list = [b_j, b_e, b_b, b_h, b_f]
        self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
        self.pose = int(self.full_dim / round(3 * scale))
        self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
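
    # Worked example of the arithmetic above: with convert_to_6d=False (scale=1),
    # jaw/eye dims are 0, body_dim = 63 - 24 = 39, hand_dim = 2 * 45 = 90, so
    # full_dim = 129 axis-angle values, self.pose = 129 / 3 = 43 joints, and
    # each_dim = [0, 39, 90, expression].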
    def __call__(self, bat):
        self.global_step += 1
        total_loss = None
        loss_dict = {}

        aud, poses = bat['aud_feat'].to(self.device).float(), bat['poses'].to(self.device).float()
        id = bat['speaker'].to(self.device) - 20
        poses = poses[:, self.c_index, :]

        aud = aud.permute(0, 2, 1)
        gt_poses = poses.permute(0, 2, 1)

        with torch.no_grad():
            self.g_body.eval()
            self.g_hand.eval()
            _, body_latents = self.g_body.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
            _, hand_latents = self.g_hand.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
            latents = torch.cat([body_latents.unsqueeze(-1), hand_latents.unsqueeze(-1)], dim=-1).detach()
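
        # Assuming encode() returns one codebook index per latent frame, `latents`
        # stacks body and hand indices along a new last axis, giving shape
        # (batch, latent_frames, 2). These discrete codes are the cross-entropy
        # targets for the PixelCNN prior below.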
        if self.audio:
            audio = self.audioencoder(aud.transpose(1, 2), frame_num=latents.shape[1] * 4).unsqueeze(-1).repeat(1, 1, 1, 2)
            logits = self.generator(latents, id, audio)
        else:
            logits = self.generator(latents, id)
        logits = logits.permute(0, 2, 3, 1).contiguous()

        self.generator_optimizer.zero_grad()
        if self.audio:
            self.audioencoder_optimizer.zero_grad()

        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), latents.view(-1))
        loss.backward()
        # clip_grad_norm is deprecated and removed in current PyTorch; the in-place
        # variant clip_grad_norm_ replaces it and returns the total gradient norm.
        grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
        loss_dict['grad'] = grad.item()
        loss_dict['ce_loss'] = loss.item()
        self.generator_optimizer.step()
        if self.audio:
            self.audioencoder_optimizer.step()

        # total_loss stays None by design; scalar losses are reported via loss_dict
        return total_loss, loss_dict
# ----------------------------------------
# Simple wrapper class for inference
# ----------------------------------------
class s2g_body_pixel(nn.Module):
    def __init__(self, args, config):
        super().__init__()
        self.wrapper = TrainWrapper(args, config)

    def infer_on_audio(self, *args, **kwargs):
        return self.wrapper.infer_on_audio(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.wrapper(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        return self.wrapper.load_state_dict(*args, **kwargs)
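
# Minimal usage sketch (hypothetical: assumes `args`/`config` objects exposing the
# fields read above, e.g. config.Model.vq_path and config.Train.learning_rate, and
# a batch dict with 'aud_feat', 'poses', and 'speaker' tensors from data_utils):
#
#     model = s2g_body_pixel(args, config)
#     ckpt = torch.load('path/to/checkpoint.pth', map_location='cpu')  # hypothetical path
#     model.load_state_dict(ckpt['generator'])
#     total_loss, loss_dict = model(batch)  # total_loss is None; see loss_dict['ce_loss']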