import torch
import torch.nn.functional as F
import random
import os
import time
import numpy as np
from data_loaders.humanml.networks.modules import *
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from collections import OrderedDict
from data_loaders.humanml.utils.utils import *
from os.path import join as pjoin
from data_loaders.humanml.data.dataset import collate_fn
import codecs as cs


def _mean_over(total, num_batches):
    """Return the per-batch average ``total / num_batches``.

    Divides by the real number of batches (not ``num_batches + 1``) and guards
    against an empty loader so the average never raises ZeroDivisionError.
    """
    return total / max(num_batches, 1)


class Logger(object):
    """Minimal scalar logger writing TensorBoard event files into ``log_dir``."""

    def __init__(self, log_dir):
        # TensorFlow is not imported by this module, so the writer comes from
        # torch's bundled TensorBoard support. The import is local so that merely
        # importing this module (e.g. for evaluation) does not need tensorboard.
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Record scalar ``value`` under ``tag`` at global step ``step`` and flush."""
        self.writer.add_scalar(tag, value, global_step=step)
        self.writer.flush()


class DecompTrainerV3(object):
    """Trains a movement auto-encoder ``movement_enc`` -> ``movement_dec``.

    The objective is an L1 reconstruction loss plus an L1 sparsity penalty on
    the latents (weighted by ``opt.lambda_sparsity``) and an L1 smoothness
    penalty between neighbouring latents along dim 1 (weighted by
    ``opt.lambda_smooth``).
    """

    def __init__(self, args, movement_enc, movement_dec):
        self.opt = args
        self.movement_enc = movement_enc
        self.movement_dec = movement_dec
        self.device = args.device

        if args.is_train:
            self.logger = Logger(args.log_dir)
            self.sml1_criterion = torch.nn.SmoothL1Loss()
            self.l1_criterion = torch.nn.L1Loss()
            self.mse_criterion = torch.nn.MSELoss()

    @staticmethod
    def zero_grad(opt_list):
        """Zero the gradients of every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.zero_grad()

    @staticmethod
    def clip_norm(network_list):
        """Clip each network's gradient norm to 0.5."""
        for network in network_list:
            clip_grad_norm_(network.parameters(), 0.5)

    @staticmethod
    def step(opt_list):
        """Take one optimisation step with every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.step()

    def forward(self, batch_data):
        """Encode and reconstruct one batch of motions.

        ``batch_data`` is the motion tensor itself. The last 4 feature channels
        are dropped before encoding (NOTE(review): presumably foot-contact
        flags -- confirm), while :meth:`backward` compares the reconstruction
        against the full tensor.
        """
        motions = batch_data
        self.motions = motions.detach().to(self.device).float()
        self.latents = self.movement_enc(self.motions[..., :-4])
        self.recon_motions = self.movement_dec(self.latents)

    def backward(self):
        """Compute the reconstruction, sparsity and smoothness losses into ``self.loss``."""
        self.loss_rec = self.l1_criterion(self.recon_motions, self.motions)
        self.loss_sparsity = torch.mean(torch.abs(self.latents))
        # Penalise differences between neighbouring latents along dim 1
        # (presumably the temporal axis -- TODO confirm).
        self.loss_smooth = self.l1_criterion(self.latents[:, 1:], self.latents[:, :-1])
        self.loss = (self.loss_rec
                     + self.loss_sparsity * self.opt.lambda_sparsity
                     + self.loss_smooth * self.opt.lambda_smooth)

    def update(self):
        """Run one optimisation step on the current batch and return loss logs."""
        self.zero_grad([self.opt_movement_enc, self.opt_movement_dec])
        self.backward()
        self.loss.backward()
        # No gradient clipping is applied for this trainer.
        self.step([self.opt_movement_enc, self.opt_movement_dec])

        loss_logs = OrderedDict()
        # 'loss' is the full weighted objective; its components are logged separately.
        loss_logs['loss'] = self.loss.item()
        loss_logs['loss_rec'] = self.loss_rec.item()
        loss_logs['loss_sparsity'] = self.loss_sparsity.item()
        loss_logs['loss_smooth'] = self.loss_smooth.item()
        return loss_logs

    def save(self, file_name, ep, total_it):
        """Save networks, optimizers, epoch and iteration count to ``file_name``."""
        state = {
            'movement_enc': self.movement_enc.state_dict(),
            'movement_dec': self.movement_dec.state_dict(),

            'opt_movement_enc': self.opt_movement_enc.state_dict(),
            'opt_movement_dec': self.opt_movement_dec.state_dict(),

            'ep': ep,
            'total_it': total_it,
        }
        torch.save(state, file_name)
        return

    def resume(self, model_dir):
        """Restore state saved by :meth:`save`; return ``(epoch, total_it)``."""
        checkpoint = torch.load(model_dir, map_location=self.device)
        self.movement_dec.load_state_dict(checkpoint['movement_dec'])
        self.movement_enc.load_state_dict(checkpoint['movement_enc'])

        self.opt_movement_enc.load_state_dict(checkpoint['opt_movement_enc'])
        self.opt_movement_dec.load_state_dict(checkpoint['opt_movement_dec'])

        return checkpoint['ep'], checkpoint['total_it']

    def train(self, train_dataloader, val_dataloader, plot_eval):
        """Train for ``opt.max_epoch`` epochs, validating after every epoch.

        Args:
            train_dataloader: iterable of motion batches used for optimisation.
            val_dataloader: iterable of motion batches evaluated after each epoch.
            plot_eval: callable ``plot_eval(data, save_dir)``. Every
                ``opt.eval_every_e`` epochs it receives a numpy array holding 4
                reconstructions followed by their 4 ground-truth motions.
        """
        self.movement_enc.to(self.device)
        self.movement_dec.to(self.device)

        self.opt_movement_enc = optim.Adam(self.movement_enc.parameters(), lr=self.opt.lr)
        self.opt_movement_dec = optim.Adam(self.movement_dec.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0

        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it = self.resume(model_dir)

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader)))
        val_loss = 0
        logs = OrderedDict()
        while epoch < self.opt.max_epoch:
            for i, batch_data in enumerate(train_dataloader):
                self.movement_dec.train()
                self.movement_enc.train()

                self.forward(batch_data)
                log_dict = self.update()
                for k, v in log_dict.items():
                    logs[k] = logs.get(k, 0) + v

                it += 1
                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict({'val_loss': val_loss})
                    self.logger.scalar_summary('val_loss', val_loss, it)

                    for tag, value in logs.items():
                        self.logger.scalar_summary(tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss_decomp(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            epoch += 1
            if epoch % self.opt.save_every_e == 0:
                self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, total_it=it)

            print('Validation time:')
            # Validate in inference mode; the training loop switches the
            # networks back to train() at the start of the next batch.
            self.movement_enc.eval()
            self.movement_dec.eval()

            val_loss = 0
            val_rec_loss = 0
            val_sparcity_loss = 0
            val_smooth_loss = 0
            with torch.no_grad():
                for batch_data in val_dataloader:
                    self.forward(batch_data)
                    self.backward()
                    val_rec_loss += self.loss_rec.item()
                    val_sparcity_loss += self.loss_sparsity.item()
                    val_smooth_loss += self.loss_smooth.item()
                    val_loss += self.loss.item()

            n_val = len(val_dataloader)
            val_loss = _mean_over(val_loss, n_val)
            val_rec_loss = _mean_over(val_rec_loss, n_val)
            val_sparcity_loss = _mean_over(val_sparcity_loss, n_val)
            val_smooth_loss = _mean_over(val_smooth_loss, n_val)
            print('Validation Loss: %.5f Reconstruction Loss: %.5f '
                  'Sparsity Loss: %.5f Smooth Loss: %.5f' % (val_loss, val_rec_loss, val_sparcity_loss,
                                                             val_smooth_loss))

            if epoch % self.opt.eval_every_e == 0:
                # Uses the last validation batch: 4 reconstructions, then 4 ground truths.
                data = torch.cat([self.recon_motions[:4], self.motions[:4]], dim=0).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, 'E%04d' % (epoch))
                os.makedirs(save_dir, exist_ok=True)
                plot_eval(data, save_dir)


# VAE Sequence Decoder/Prior/Posterior latent by latent
class CompTrainerV6(object):
    """Trains a text-conditioned sequential VAE over movement latents.

    At each movement step an attention layer summarises the word features.
    During training the posterior ``seq_post`` also sees the target movement,
    while the prior ``seq_pri`` does not. The decoder ``seq_dec`` consumes
    posterior samples in training and prior samples in eval mode or in
    :meth:`generate`. The predicted movements are decoded back to motions by
    ``mov_dec``. ``mov_enc`` supplies (detached) target movements.
    """

    def __init__(self, args, text_enc, seq_pri, seq_dec, att_layer, mov_dec, mov_enc=None, seq_post=None):
        self.opt = args
        self.text_enc = text_enc
        self.seq_pri = seq_pri
        self.att_layer = att_layer
        self.device = args.device
        self.seq_dec = seq_dec
        self.mov_dec = mov_dec
        self.mov_enc = mov_enc

        if args.is_train:
            self.seq_post = seq_post
            self.logger = Logger(args.log_dir)
            # NOTE: despite the name this is a Smooth-L1 criterion.
            self.l1_criterion = torch.nn.SmoothL1Loss()
            self.gan_criterion = torch.nn.BCEWithLogitsLoss()
            self.mse_criterion = torch.nn.MSELoss()

    @staticmethod
    def reparametrize(mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        s_var = logvar.mul(0.5).exp_()
        eps = torch.randn_like(s_var)
        return eps * s_var + mu

    @staticmethod
    def ones_like(tensor, val=1.):
        """Return a float32 tensor shaped like ``tensor`` filled with ``val`` (no grad)."""
        return torch.full(tuple(tensor.size()), val, dtype=torch.float32, device=tensor.device)

    @staticmethod
    def zeros_like(tensor, val=0.):
        """Return a float32 tensor shaped like ``tensor`` filled with ``val`` (no grad)."""
        return torch.full(tuple(tensor.size()), val, dtype=torch.float32, device=tensor.device)

    @staticmethod
    def zero_grad(opt_list):
        """Zero the gradients of every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.zero_grad()

    @staticmethod
    def clip_norm(network_list):
        """Clip each network's gradient norm to 0.5."""
        for network in network_list:
            clip_grad_norm_(network.parameters(), 0.5)

    @staticmethod
    def step(opt_list):
        """Take one optimisation step with every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.step()

    @staticmethod
    def kl_criterion(mu1, logvar1, mu2, logvar2):
        """KL( N(mu1, var1) || N(mu2, var2) ), summed and divided by ``mu1.shape[0]``.

        Per element: log(sigma2/sigma1) + (var1 + (mu1 - mu2)^2) / (2 * var2) - 1/2.
        The log-ratio is computed as ``0.5 * (logvar2 - logvar1)``, which equals
        ``log(sigma2 / sigma1)`` without dividing two exponentials.
        """
        kld = 0.5 * (logvar2 - logvar1) + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (
                2 * torch.exp(logvar2)) - 1 / 2
        return kld.sum() / mu1.shape[0]

    @staticmethod
    def kl_criterion_unit(mu, logvar):
        """KL( N(mu, var) || N(0, I) ), summed and divided by ``mu.shape[0]``."""
        kld = ((torch.exp(logvar) + mu ** 2) - logvar - 1) / 2
        return kld.sum() / mu.shape[0]

    def forward(self, batch_data, tf_ratio, mov_len, eval_mode=False):
        """Unroll the sequential VAE for ``mov_len`` movement steps on one batch.

        Args:
            batch_data: tuple ``(word_emb, pos_ohot, caption, cap_lens, motions, m_lens)``.
            tf_ratio: probability that this call is teacher-forced, i.e. the
                ground-truth movement of step ``i`` is fed as the next input
                instead of the model's own prediction.
            mov_len: number of movement steps; must equal ``self.movements.shape[1]``.
            eval_mode: if True the decoder consumes prior samples ``z_pri``,
                otherwise posterior samples ``z_pos``.

        Sets ``self.fake_movements``, ``self.fake_motions`` and the prior and
        posterior ``mus_*`` / ``logvars_*`` (concatenated along dim 0 over
        steps) that :meth:`backward_G` consumes.
        """
        word_emb, pos_ohot, caption, cap_lens, motions, m_lens = batch_data
        word_emb = word_emb.detach().to(self.device).float()
        pos_ohot = pos_ohot.detach().to(self.device).float()
        motions = motions.detach().to(self.device).float()
        self.cap_lens = cap_lens
        self.caption = caption

        self.motions = motions

        # Movement encoding: targets for every step, detached from the graph.
        self.movements = self.mov_enc(self.motions[..., :-4]).detach()
        # Initial decoder input: the encoding of an all-zero motion window.
        mov_in = self.mov_enc(
            torch.zeros((self.motions.shape[0], self.opt.unit_length, self.motions.shape[-1] - 4), device=self.device)
        ).squeeze(1).detach()
        assert self.movements.shape[1] == mov_len

        # Teacher forcing is decided once for the whole unrolled sequence.
        teacher_force = random.random() < tf_ratio

        # Text encoding
        word_hids, hidden = self.text_enc(word_emb, pos_ohot, cap_lens)
        if self.opt.text_enc_mod == 'bigru':
            hidden_pos = self.seq_post.get_init_hidden(hidden)
            hidden_pri = self.seq_pri.get_init_hidden(hidden)
            hidden_dec = self.seq_dec.get_init_hidden(hidden)
        elif self.opt.text_enc_mod == 'transformer':
            # Posterior/prior must not back-propagate into the transformer text encoder.
            hidden_pos = self.seq_post.get_init_hidden(hidden.detach())
            hidden_pri = self.seq_pri.get_init_hidden(hidden.detach())
            hidden_dec = self.seq_dec.get_init_hidden(hidden)
        else:
            raise ValueError('Unsupported text_enc_mod: %r' % (self.opt.text_enc_mod,))

        mus_pri = []
        logvars_pri = []
        mus_post = []
        logvars_post = []
        fake_mov_batch = []

        for i in range(mov_len):
            mov_tgt = self.movements[:, i]
            # Local attention vector over the word features.
            att_vec, _ = self.att_layer(hidden_dec[-1], word_hids)

            # Remaining steps until the end of each motion, in movement units.
            tta = m_lens // self.opt.unit_length - i

            if self.opt.text_enc_mod == 'bigru':
                pos_in = torch.cat([mov_in, mov_tgt, att_vec], dim=-1)
                pri_in = torch.cat([mov_in, att_vec], dim=-1)
            else:  # 'transformer' (validated above)
                pos_in = torch.cat([mov_in, mov_tgt, att_vec.detach()], dim=-1)
                pri_in = torch.cat([mov_in, att_vec.detach()], dim=-1)

            # Posterior (sees the target movement).
            z_pos, mu_pos, logvar_pos, hidden_pos = self.seq_post(pos_in, hidden_pos, tta)

            # Prior (text/attention only).
            z_pri, mu_pri, logvar_pri, hidden_pri = self.seq_pri(pri_in, hidden_pri, tta)

            # Decoder
            if eval_mode:
                dec_in = torch.cat([mov_in, att_vec, z_pri], dim=-1)
            else:
                dec_in = torch.cat([mov_in, att_vec, z_pos], dim=-1)
            fake_mov, hidden_dec = self.seq_dec(dec_in, mov_in, hidden_dec, tta)

            mus_post.append(mu_pos)
            logvars_post.append(logvar_pos)
            mus_pri.append(mu_pri)
            logvars_pri.append(logvar_pri)
            fake_mov_batch.append(fake_mov.unsqueeze(1))

            if teacher_force:
                mov_in = self.movements[:, i].detach()
            else:
                mov_in = fake_mov.detach()

        self.fake_movements = torch.cat(fake_mov_batch, dim=1)
        self.fake_motions = self.mov_dec(self.fake_movements)

        self.mus_post = torch.cat(mus_post, dim=0)
        self.mus_pri = torch.cat(mus_pri, dim=0)
        self.logvars_post = torch.cat(logvars_post, dim=0)
        self.logvars_pri = torch.cat(logvars_pri, dim=0)

    def generate(self, word_emb, pos_ohot, cap_lens, m_lens, mov_len, dim_pose):
        """Generate motions from text using only the prior network.

        Returns:
            ``(fake_motions, mus_pri, att_wgts)``: the decoded motions, prior
            means concatenated along dim 0 over steps, and the attention
            weights of every step concatenated along the last dim.
        """
        word_emb = word_emb.detach().to(self.device).float()
        pos_ohot = pos_ohot.detach().to(self.device).float()
        self.cap_lens = cap_lens

        # Initial decoder input: the encoding of an all-zero motion window.
        mov_in = self.mov_enc(
            torch.zeros((word_emb.shape[0], self.opt.unit_length, dim_pose - 4), device=self.device)
        ).squeeze(1).detach()

        # Text encoding
        word_hids, hidden = self.text_enc(word_emb, pos_ohot, cap_lens)

        hidden_pri = self.seq_pri.get_init_hidden(hidden)
        hidden_dec = self.seq_dec.get_init_hidden(hidden)

        mus_pri = []
        logvars_pri = []
        fake_mov_batch = []
        att_wgt = []

        for i in range(mov_len):
            # Local attention vector over the word features.
            att_vec, co_weights = self.att_layer(hidden_dec[-1], word_hids)

            tta = m_lens // self.opt.unit_length - i

            # Prior
            pri_in = torch.cat([mov_in, att_vec], dim=-1)
            z_pri, mu_pri, logvar_pri, hidden_pri = self.seq_pri(pri_in, hidden_pri, tta)

            # Decoder
            dec_in = torch.cat([mov_in, att_vec, z_pri], dim=-1)
            fake_mov, hidden_dec = self.seq_dec(dec_in, mov_in, hidden_dec, tta)

            mus_pri.append(mu_pri)
            logvars_pri.append(logvar_pri)
            fake_mov_batch.append(fake_mov.unsqueeze(1))
            att_wgt.append(co_weights)

            # Autoregressive: feed back the model's own prediction.
            mov_in = fake_mov.detach()

        fake_movements = torch.cat(fake_mov_batch, dim=1)
        att_wgts = torch.cat(att_wgt, dim=-1)

        fake_motions = self.mov_dec(fake_movements)

        mus_pri = torch.cat(mus_pri, dim=0)

        return fake_motions, mus_pri, att_wgts

    def backward_G(self):
        """Compute the generator objective into ``self.loss_gen``; return loss logs.

        The movement reconstruction term is weighted by ``opt.lambda_rec_mov``
        and the motion reconstruction term by ``opt.lambda_rec_mot``.
        """
        self.loss_mot_rec = self.l1_criterion(self.fake_motions, self.motions)
        self.loss_mov_rec = self.l1_criterion(self.fake_movements, self.movements)

        self.loss_kld = self.kl_criterion(self.mus_post, self.logvars_post, self.mus_pri, self.logvars_pri)

        self.loss_gen = (self.loss_mot_rec * self.opt.lambda_rec_mot
                         + self.loss_mov_rec * self.opt.lambda_rec_mov
                         + self.loss_kld * self.opt.lambda_kld)
        loss_logs = OrderedDict()
        loss_logs['loss_gen'] = self.loss_gen.item()
        loss_logs['loss_mot_rec'] = self.loss_mot_rec.item()
        loss_logs['loss_mov_rec'] = self.loss_mov_rec.item()
        loss_logs['loss_kld'] = self.loss_kld.item()

        return loss_logs

    def update(self):
        """Run one clipped optimisation step on all trainable modules; return loss logs."""
        self.zero_grad([self.opt_text_enc, self.opt_seq_dec, self.opt_seq_post,
                        self.opt_seq_pri, self.opt_att_layer, self.opt_mov_dec])
        loss_logs = self.backward_G()
        self.loss_gen.backward()
        self.clip_norm([self.text_enc, self.seq_dec, self.seq_post, self.seq_pri,
                        self.att_layer, self.mov_dec])
        self.step([self.opt_text_enc, self.opt_seq_dec, self.opt_seq_post,
                   self.opt_seq_pri, self.opt_att_layer, self.opt_mov_dec])
        return loss_logs

    def to(self, device):
        """Move all modules (and, when training, criteria and posterior) to ``device``."""
        if self.opt.is_train:
            self.gan_criterion.to(device)
            self.mse_criterion.to(device)
            self.l1_criterion.to(device)
            self.seq_post.to(device)
        self.mov_enc.to(device)
        self.text_enc.to(device)
        self.mov_dec.to(device)
        self.seq_pri.to(device)
        self.att_layer.to(device)
        self.seq_dec.to(device)

    def train_mode(self):
        """Put trainable modules in train mode; the movement encoder stays in eval."""
        if self.opt.is_train:
            self.seq_post.train()
            self.mov_enc.eval()
        self.mov_dec.train()
        self.text_enc.train()
        self.seq_pri.train()
        self.att_layer.train()
        self.seq_dec.train()

    def eval_mode(self):
        """Put every module in eval mode."""
        if self.opt.is_train:
            self.seq_post.eval()
            self.mov_enc.eval()
        self.mov_dec.eval()
        self.text_enc.eval()
        self.seq_pri.eval()
        self.att_layer.eval()
        self.seq_dec.eval()

    def save(self, file_name, ep, total_it, sub_ep, sl_len):
        """Save networks, optimizers and schedule position to ``file_name``."""
        state = {
            'text_enc': self.text_enc.state_dict(),
            'seq_post': self.seq_post.state_dict(),
            'att_layer': self.att_layer.state_dict(),
            'seq_dec': self.seq_dec.state_dict(),
            'seq_pri': self.seq_pri.state_dict(),
            'mov_enc': self.mov_enc.state_dict(),
            'mov_dec': self.mov_dec.state_dict(),

            'opt_mov_dec': self.opt_mov_dec.state_dict(),
            'opt_text_enc': self.opt_text_enc.state_dict(),
            'opt_seq_pri': self.opt_seq_pri.state_dict(),
            'opt_att_layer': self.opt_att_layer.state_dict(),
            'opt_seq_post': self.opt_seq_post.state_dict(),
            'opt_seq_dec': self.opt_seq_dec.state_dict(),

            'ep': ep,
            'total_it': total_it,
            'sub_ep': sub_ep,
            'sl_len': sl_len
        }
        torch.save(state, file_name)
        return

    def load(self, model_dir):
        """Restore state saved by :meth:`save`.

        Returns:
            ``(ep, total_it, sub_ep, sl_len)``. Optimizer and posterior state are
            only restored when ``opt.is_train`` is set.
        """
        checkpoint = torch.load(model_dir, map_location=self.device)
        if self.opt.is_train:
            self.seq_post.load_state_dict(checkpoint['seq_post'])

            self.opt_text_enc.load_state_dict(checkpoint['opt_text_enc'])
            self.opt_seq_post.load_state_dict(checkpoint['opt_seq_post'])
            self.opt_att_layer.load_state_dict(checkpoint['opt_att_layer'])
            self.opt_seq_pri.load_state_dict(checkpoint['opt_seq_pri'])
            self.opt_seq_dec.load_state_dict(checkpoint['opt_seq_dec'])
            self.opt_mov_dec.load_state_dict(checkpoint['opt_mov_dec'])

        self.text_enc.load_state_dict(checkpoint['text_enc'])
        self.mov_dec.load_state_dict(checkpoint['mov_dec'])

        self.seq_pri.load_state_dict(checkpoint['seq_pri'])
        self.att_layer.load_state_dict(checkpoint['att_layer'])
        self.seq_dec.load_state_dict(checkpoint['seq_dec'])
        self.mov_enc.load_state_dict(checkpoint['mov_enc'])

        return checkpoint['ep'], checkpoint['total_it'], checkpoint['sub_ep'], checkpoint['sl_len']

    def train(self, train_dataset, val_dataset, plot_eval):
        """Curriculum training over increasing sequence lengths.

        Starting from a dataset-dependent ``schedule_len`` (10 for 't2m', 6 for
        'kit'), each stage trains for up to ``opt.max_sub_epoch`` sub-epochs with
        early stopping on validation loss, then increments ``schedule_len``.
        Training stops once ``schedule_len`` exceeds 49.

        Raises:
            ValueError: if ``opt.dataset_name`` is neither 't2m' nor 'kit'.
        """
        self.to(self.device)

        self.opt_text_enc = optim.Adam(self.text_enc.parameters(), lr=self.opt.lr)
        self.opt_seq_post = optim.Adam(self.seq_post.parameters(), lr=self.opt.lr)
        self.opt_seq_pri = optim.Adam(self.seq_pri.parameters(), lr=self.opt.lr)
        self.opt_att_layer = optim.Adam(self.att_layer.parameters(), lr=self.opt.lr)
        self.opt_seq_dec = optim.Adam(self.seq_dec.parameters(), lr=self.opt.lr)

        self.opt_mov_dec = optim.Adam(self.mov_dec.parameters(), lr=self.opt.lr*0.1)

        epoch = 0
        it = 0
        if self.opt.dataset_name == 't2m':
            schedule_len = 10
        elif self.opt.dataset_name == 'kit':
            schedule_len = 6
        else:
            raise ValueError('Unsupported dataset_name for CompTrainerV6: %r' % (self.opt.dataset_name,))
        sub_ep = 0

        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it, sub_ep, schedule_len = self.load(model_dir)

        invalid = True
        start_time = time.time()
        val_loss = 0
        is_continue_and_first = self.opt.is_continue
        while invalid:
            train_dataset.reset_max_len(schedule_len * self.opt.unit_length)
            val_dataset.reset_max_len(schedule_len * self.opt.unit_length)

            train_loader = DataLoader(train_dataset, batch_size=self.opt.batch_size, drop_last=True, num_workers=4,
                                      shuffle=True, collate_fn=collate_fn, pin_memory=True)
            val_loader = DataLoader(val_dataset, batch_size=self.opt.batch_size, drop_last=True, num_workers=4,
                                    shuffle=True, collate_fn=collate_fn, pin_memory=True)
            print("Max_Length:%03d Training Split:%05d Validation Split:%04d" % (schedule_len, len(train_loader), len(val_loader)))

            min_val_loss = np.inf
            stop_cnt = 0
            logs = OrderedDict()
            for sub_epoch in range(sub_ep, self.opt.max_sub_epoch):
                self.train_mode()

                if is_continue_and_first:
                    # Only the resumed stage starts mid-way; later stages start at 0.
                    sub_ep = 0
                    is_continue_and_first = False

                tf_ratio = self.opt.tf_ratio

                for i, batch_data in enumerate(train_loader):
                    self.forward(batch_data, tf_ratio, schedule_len)
                    log_dict = self.update()
                    for k, v in log_dict.items():
                        logs[k] = logs.get(k, 0) + v

                    it += 1
                    if it % self.opt.log_every == 0:
                        mean_loss = OrderedDict({'val_loss': val_loss})
                        self.logger.scalar_summary('val_loss', val_loss, it)
                        self.logger.scalar_summary('scheduled_length', schedule_len, it)

                        for tag, value in logs.items():
                            self.logger.scalar_summary(tag, value/self.opt.log_every, it)
                            mean_loss[tag] = value / self.opt.log_every
                        logs = OrderedDict()
                        print_current_loss(start_time, it, mean_loss, epoch, sub_epoch=sub_epoch, inner_iter=i,
                                           tf_ratio=tf_ratio, sl_steps=schedule_len)

                    if it % self.opt.save_latest == 0:
                        self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it, sub_epoch, schedule_len)

                self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it, sub_epoch, schedule_len)

                epoch += 1
                if epoch % self.opt.save_every_e == 0:
                    self.save(pjoin(self.opt.model_dir, 'E%03d_SE%02d_SL%02d.tar'%(epoch, sub_epoch, schedule_len)),
                              epoch, total_it=it, sub_ep=sub_epoch, sl_len=schedule_len)

                print('Validation time:')
                # Validate in inference mode; train_mode() is restored at the
                # start of the next sub-epoch.
                self.eval_mode()

                loss_mot_rec = 0
                loss_mov_rec = 0
                loss_kld = 0
                val_loss = 0
                with torch.no_grad():
                    for batch_data in val_loader:
                        self.forward(batch_data, 0, schedule_len)
                        self.backward_G()
                        loss_mot_rec += self.loss_mot_rec.item()
                        loss_mov_rec += self.loss_mov_rec.item()
                        loss_kld += self.loss_kld.item()
                        val_loss += self.loss_gen.item()

                n_val = len(val_loader)
                loss_mot_rec = _mean_over(loss_mot_rec, n_val)
                loss_mov_rec = _mean_over(loss_mov_rec, n_val)
                loss_kld = _mean_over(loss_kld, n_val)
                val_loss = _mean_over(val_loss, n_val)
                print('Validation Loss: %.5f Movement Recon Loss: %.5f Motion Recon Loss: %.5f KLD Loss: %.5f:' %
                      (val_loss, loss_mov_rec, loss_mot_rec, loss_kld))

                if epoch % self.opt.eval_every_e == 0:
                    # Posterior reconstructions from the last validation batch,
                    # then prior samples and ground truth for the same batch.
                    reco_data = self.fake_motions[:4]
                    with torch.no_grad():
                        self.forward(batch_data, 0, schedule_len, eval_mode=True)
                    fake_data = self.fake_motions[:4]
                    gt_data = self.motions[:4]
                    data = torch.cat([fake_data, reco_data, gt_data], dim=0).cpu().numpy()
                    captions = self.caption[:4] * 3
                    save_dir = pjoin(self.opt.eval_dir, 'E%03d_SE%02d_SL%02d'%(epoch, sub_epoch, schedule_len))
                    os.makedirs(save_dir, exist_ok=True)
                    plot_eval(data, save_dir, captions)

                # Early stopping on validation loss for the current stage.
                if val_loss < min_val_loss:
                    min_val_loss = val_loss
                    stop_cnt = 0
                elif stop_cnt < self.opt.early_stop_count:
                    stop_cnt += 1
                elif stop_cnt >= self.opt.early_stop_count:
                    break
                if val_loss - min_val_loss >= 0.1:
                    break

            schedule_len += 1

            if schedule_len > 49:
                invalid = False


class LengthEstTrainer(object):
    """Trains ``estimator`` to classify motion length (in ``opt.unit_length`` units) from text."""

    def __init__(self, args, estimator):
        self.opt = args
        self.estimator = estimator
        self.device = args.device

        if args.is_train:
            self.logger = Logger(args.log_dir)
            self.mul_cls_criterion = torch.nn.CrossEntropyLoss()

    def resume(self, model_dir):
        """Restore state saved by :meth:`save`; return ``(epoch, iteration)``."""
        checkpoints = torch.load(model_dir, map_location=self.device)
        self.estimator.load_state_dict(checkpoints['estimator'])
        self.opt_estimator.load_state_dict(checkpoints['opt_estimator'])
        # save() stores the iteration count under 'niter'; 'iter' is accepted too.
        niter = checkpoints['niter'] if 'niter' in checkpoints else checkpoints['iter']
        return checkpoints['epoch'], niter

    def save(self, model_dir, epoch, niter):
        """Save estimator, optimizer, epoch and iteration count to ``model_dir``."""
        state = {
            'estimator': self.estimator.state_dict(),
            'opt_estimator': self.opt_estimator.state_dict(),
            'epoch': epoch,
            'niter': niter,
        }
        torch.save(state, model_dir)

    @staticmethod
    def zero_grad(opt_list):
        """Zero the gradients of every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.zero_grad()

    @staticmethod
    def clip_norm(network_list):
        """Clip each network's gradient norm to 0.5."""
        for network in network_list:
            clip_grad_norm_(network.parameters(), 0.5)

    @staticmethod
    def step(opt_list):
        """Take one optimisation step with every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.step()

    def train(self, train_dataloader, val_dataloader):
        """Train for ``opt.max_epoch`` epochs; keep the best-validation model as 'finest.tar'."""
        self.estimator.to(self.device)

        self.opt_estimator = optim.Adam(self.estimator.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0

        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it = self.resume(model_dir)

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader)))
        val_loss = 0
        min_val_loss = np.inf
        logs = OrderedDict({'loss': 0})
        while epoch < self.opt.max_epoch:
            for i, batch_data in enumerate(train_dataloader):
                self.estimator.train()

                word_emb, pos_ohot, _, cap_lens, _, m_lens = batch_data
                word_emb = word_emb.detach().to(self.device).float()
                pos_ohot = pos_ohot.detach().to(self.device).float()

                pred_dis = self.estimator(word_emb, pos_ohot, cap_lens)

                self.zero_grad([self.opt_estimator])

                # Target class: motion length measured in unit_length chunks.
                gt_labels = m_lens // self.opt.unit_length
                gt_labels = gt_labels.long().to(self.device)
                loss = self.mul_cls_criterion(pred_dis, gt_labels)

                loss.backward()

                self.clip_norm([self.estimator])
                self.step([self.opt_estimator])

                logs['loss'] += loss.item()

                it += 1
                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict({'val_loss': val_loss})
                    self.logger.scalar_summary('val_loss', val_loss, it)

                    for tag, value in logs.items():
                        self.logger.scalar_summary(tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict({'loss': 0})
                    print_current_loss_decomp(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            epoch += 1
            if epoch % self.opt.save_every_e == 0:
                self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, it)

            print('Validation time:')
            # Validate in inference mode; train() is restored at the next batch.
            self.estimator.eval()

            val_loss = 0
            with torch.no_grad():
                for batch_data in val_dataloader:
                    word_emb, pos_ohot, _, cap_lens, _, m_lens = batch_data
                    word_emb = word_emb.detach().to(self.device).float()
                    pos_ohot = pos_ohot.detach().to(self.device).float()

                    pred_dis = self.estimator(word_emb, pos_ohot, cap_lens)

                    gt_labels = m_lens // self.opt.unit_length
                    gt_labels = gt_labels.long().to(self.device)
                    loss = self.mul_cls_criterion(pred_dis, gt_labels)

                    val_loss += loss.item()

            val_loss = _mean_over(val_loss, len(val_dataloader))
            print('Validation Loss: %.5f' % (val_loss))

            if val_loss < min_val_loss:
                self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
                min_val_loss = val_loss


class TextMotionMatchTrainer(object):
    """Trains text and motion encoders into a joint space with a contrastive loss.

    Positive pairs are matching (text, motion) items. Negative pairs are made by
    cyclically shifting the motion embeddings within the batch. The movement
    encoder is frozen: it stays in eval mode and its output is detached.
    """

    def __init__(self, args, text_encoder, motion_encoder, movement_encoder):
        self.opt = args
        self.text_encoder = text_encoder
        self.motion_encoder = motion_encoder
        self.movement_encoder = movement_encoder
        self.device = args.device

        if args.is_train:
            self.logger = Logger(args.log_dir)
            self.contrastive_loss = ContrastiveLoss(self.opt.negative_margin)

    def resume(self, model_dir):
        """Restore state saved by :meth:`save`; return ``(epoch, iteration)``."""
        checkpoints = torch.load(model_dir, map_location=self.device)
        self.text_encoder.load_state_dict(checkpoints['text_encoder'])
        self.motion_encoder.load_state_dict(checkpoints['motion_encoder'])
        self.movement_encoder.load_state_dict(checkpoints['movement_encoder'])

        self.opt_text_encoder.load_state_dict(checkpoints['opt_text_encoder'])
        self.opt_motion_encoder.load_state_dict(checkpoints['opt_motion_encoder'])
        return checkpoints['epoch'], checkpoints['iter']

    def save(self, model_dir, epoch, niter):
        """Save encoders, optimizers, epoch and iteration count to ``model_dir``."""
        state = {
            'text_encoder': self.text_encoder.state_dict(),
            'motion_encoder': self.motion_encoder.state_dict(),
            'movement_encoder': self.movement_encoder.state_dict(),

            'opt_text_encoder': self.opt_text_encoder.state_dict(),
            'opt_motion_encoder': self.opt_motion_encoder.state_dict(),
            'epoch': epoch,
            'iter': niter,
        }
        torch.save(state, model_dir)

    @staticmethod
    def zero_grad(opt_list):
        """Zero the gradients of every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.zero_grad()

    @staticmethod
    def clip_norm(network_list):
        """Clip each network's gradient norm to 0.5."""
        for network in network_list:
            clip_grad_norm_(network.parameters(), 0.5)

    @staticmethod
    def step(opt_list):
        """Take one optimisation step with every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.step()

    def to(self, device):
        """Move all three encoders to ``device``."""
        self.text_encoder.to(device)
        self.motion_encoder.to(device)
        self.movement_encoder.to(device)

    def train_mode(self):
        """Train the text/motion encoders; keep the frozen movement encoder in eval."""
        self.text_encoder.train()
        self.motion_encoder.train()
        self.movement_encoder.eval()

    def eval_mode(self):
        """Put all three encoders in eval mode."""
        self.text_encoder.eval()
        self.motion_encoder.eval()
        self.movement_encoder.eval()

    def forward(self, batch_data):
        """Embed a batch of texts and motions, aligned by descending motion length.

        ``batch_data`` is ``(word_emb, pos_ohot, caption, cap_lens, motions, m_lens, _)``.
        """
        word_emb, pos_ohot, caption, cap_lens, motions, m_lens, _ = batch_data
        word_emb = word_emb.detach().to(self.device).float()
        pos_ohot = pos_ohot.detach().to(self.device).float()
        motions = motions.detach().to(self.device).float()

        # Sort motions by length in descending order. NOTE(review): texts are
        # assumed to arrive already sorted by length -- confirm against the loader.
        self.align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
        motions = motions[self.align_idx]
        m_lens = m_lens[self.align_idx]

        # Movement encoding (frozen encoder, detached output).
        movements = self.movement_encoder(motions[..., :-4]).detach()
        m_lens = m_lens // self.opt.unit_length
        self.motion_embedding = self.motion_encoder(movements, m_lens)

        # Text encoding, reordered to match the motion order above.
        self.text_embedding = self.text_encoder(word_emb, pos_ohot, cap_lens)
        self.text_embedding = self.text_embedding.clone()[self.align_idx]

    def backward(self):
        """Compute the positive/negative contrastive losses; return loss logs.

        Raises:
            ValueError: if the batch has fewer than 2 items, since no
                non-matching pair can then be formed.
        """
        batch_size = self.text_embedding.shape[0]
        if batch_size < 2:
            raise ValueError('TextMotionMatchTrainer needs a batch size of at least 2 to build negative pairs, '
                             'got %d' % batch_size)
        # Positive pairs
        pos_labels = torch.zeros(batch_size).to(self.text_embedding.device)
        self.loss_pos = self.contrastive_loss(self.text_embedding, self.motion_embedding, pos_labels)

        # Negative pairs, built by shifting the index. The shift is drawn from
        # [1, batch_size), so no text is ever paired with its own motion.
        neg_labels = torch.ones(batch_size).to(self.text_embedding.device)
        shift = np.random.randint(1, batch_size)
        new_idx = np.arange(shift, batch_size + shift) % batch_size
        self.mis_motion_embedding = self.motion_embedding.clone()[new_idx]
        self.loss_neg = self.contrastive_loss(self.text_embedding, self.mis_motion_embedding, neg_labels)
        self.loss = self.loss_pos + self.loss_neg

        loss_logs = OrderedDict()
        loss_logs['loss'] = self.loss.item()
        loss_logs['loss_pos'] = self.loss_pos.item()
        loss_logs['loss_neg'] = self.loss_neg.item()
        return loss_logs

    def update(self):
        """Run one clipped optimisation step on the text/motion encoders; return loss logs."""
        self.zero_grad([self.opt_motion_encoder, self.opt_text_encoder])
        loss_logs = self.backward()
        self.loss.backward()
        self.clip_norm([self.text_encoder, self.motion_encoder])
        self.step([self.opt_text_encoder, self.opt_motion_encoder])

        return loss_logs

    def train(self, train_dataloader, val_dataloader):
        """Train for ``opt.max_epoch`` epochs; keep the best-validation model as 'finest.tar'."""
        self.to(self.device)

        self.opt_motion_encoder = optim.Adam(self.motion_encoder.parameters(), lr=self.opt.lr)
        self.opt_text_encoder = optim.Adam(self.text_encoder.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0

        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it = self.resume(model_dir)

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader)))
        val_loss = 0
        logs = OrderedDict()

        min_val_loss = np.inf
        while epoch < self.opt.max_epoch:
            for i, batch_data in enumerate(train_dataloader):
                self.train_mode()

                self.forward(batch_data)
                log_dict = self.update()
                for k, v in log_dict.items():
                    logs[k] = logs.get(k, 0) + v

                it += 1
                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict({'val_loss': val_loss})
                    self.logger.scalar_summary('val_loss', val_loss, it)

                    for tag, value in logs.items():
                        self.logger.scalar_summary(tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss_decomp(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            epoch += 1
            if epoch % self.opt.save_every_e == 0:
                self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, it)

            print('Validation time:')
            # Validate in inference mode; train_mode() is restored at the next batch.
            self.eval_mode()

            loss_pos_pair = 0
            loss_neg_pair = 0
            val_loss = 0
            with torch.no_grad():
                for batch_data in val_dataloader:
                    self.forward(batch_data)
                    self.backward()
                    loss_pos_pair += self.loss_pos.item()
                    loss_neg_pair += self.loss_neg.item()
                    val_loss += self.loss.item()

            n_val = len(val_dataloader)
            loss_pos_pair = _mean_over(loss_pos_pair, n_val)
            loss_neg_pair = _mean_over(loss_neg_pair, n_val)
            val_loss = _mean_over(val_loss, n_val)
            print('Validation Loss: %.5f Positive Loss: %.5f Negative Loss: %.5f' %
                  (val_loss, loss_pos_pair, loss_neg_pair))

            if val_loss < min_val_loss:
                self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
                min_val_loss = val_loss

            if epoch % self.opt.eval_every_e == 0:
                # Distances for the last validation batch.
                pos_dist = F.pairwise_distance(self.text_embedding, self.motion_embedding)
                neg_dist = F.pairwise_distance(self.text_embedding, self.mis_motion_embedding)

                pos_str = ' '.join(['%.3f' % (d) for d in pos_dist])
                neg_str = ' '.join(['%.3f' % (d) for d in neg_dist])

                save_path = pjoin(self.opt.eval_dir, 'E%03d.txt' % (epoch))
                with cs.open(save_path, 'w') as f:
                    f.write('Positive Pairs Distance\n')
                    f.write(pos_str + '\n')
                    f.write('Negative Pairs Distance\n')
                    f.write(neg_str + '\n')