from functools import reduce
from pathlib import Path

import torch
import torch.nn.functional as F


class NullableArgs:
    """Attribute bag copied from a namespace; unknown options read as ``None``.

    Every entry of ``namespace.__dict__`` is copied onto the instance. Looking
    up a name that was never set does not raise; instead:

    * ``align_mask_width`` -> ``1`` if ``use_alignment_mask`` was set and is
      truthy, otherwise ``0``.
    * ``no_head_pose`` -> ``not self.predict_head_pose``.
    * ``no_use_learnable_pe`` -> ``not self.use_learnable_pe``.
    * any other regular name -> ``None``.

    Dunder names (``__x__``) that are missing raise ``AttributeError`` as usual,
    so that protocol probes done by ``copy``/``pickle`` (e.g. ``__setstate__``,
    ``__deepcopy__``) do not receive a non-callable ``None``.
    """

    def __init__(self, namespace):
        for key, value in namespace.__dict__.items():
            setattr(self, key, value)

    def __getattr__(self, key):
        # Only invoked when the normal attribute lookup has not found `key`.
        if key.startswith('__') and key.endswith('__'):
            # Protocol hooks must look "absent", not "present but None".
            raise AttributeError(key)
        if key == 'align_mask_width':
            if 'use_alignment_mask' in self.__dict__:
                return 1 if self.use_alignment_mask else 0
            return 0
        if key == 'no_head_pose':
            return not self.predict_head_pose
        if key == 'no_use_learnable_pe':
            return not self.use_learnable_pe
        return None


def count_parameters(model):
    """Return the total number of elements in `model`'s trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def get_option_text(args, parser):
    """Format the options in `args` as one aligned ``name: value`` line each.

    Options are sorted by name. When a value differs from
    ``parser.get_default(name)``, the default is appended to the line.

    Returns:
        str: the concatenated lines, each terminated by a newline.
    """
    lines = []
    for k, v in sorted(vars(args).items()):
        default = parser.get_default(k)
        comment = f'\t[default: {default!s}]' if v != default else ''
        lines.append(f'{k!s:>30}: {v!s:<30}{comment}\n')
    return ''.join(lines)


def get_model_path(exp_name, iteration, model_type='DPT'):
    """Locate the checkpoint file of an experiment.

    The experiment directory is ``<parent of this file's dir>/experiments/
    <model_type>/<exp_name>``; if that does not exist, the first directory
    entry whose name starts with `exp_name` is used instead.

    Returns:
        tuple: ``(model_path, exp_dir_relative_to_root)`` as ``Path`` objects.

    Raises:
        FileNotFoundError: if no experiment directory matches `exp_name`.
    """
    exp_root_dir = Path(__file__).parent.parent / 'experiments' / model_type
    exp_dir = exp_root_dir / exp_name
    if not exp_dir.exists():
        # Fall back to prefix matching; report a clear error instead of
        # leaking a bare StopIteration when nothing matches.
        exp_dir = next(exp_root_dir.glob(f'{exp_name}*'), None)
        if exp_dir is None:
            raise FileNotFoundError(
                f'No experiment matching {exp_name!r} found in {exp_root_dir}')
    model_path = exp_dir / f'checkpoints/iter_{iteration:07}.pt'
    return model_path, exp_dir.relative_to(exp_root_dir)


def get_pose_input(coef_dict, rot_repr, with_global_pose):
    """Select the pose part of `coef_dict` that is fed to the model.

    For ``rot_repr == 'aa'`` this is the whole ``coef_dict['pose']`` when
    `with_global_pose` is true, otherwise its last-axis slice ``63:70``.

    Raises:
        ValueError: for any other `rot_repr`.
    """
    if rot_repr != 'aa':
        raise ValueError(f'Unknown rotation representation: {rot_repr}')
    pose = coef_dict['pose']
    return pose if with_global_pose else pose[..., 63:70]


def get_motion_coef(coef_dict, rot_repr, with_global_pose=False, norm_stats=None):
    """Concatenate the entries of `coef_dict` into one motion tensor (last axis).

    * ``'aa'``: ``[exp, pose_input]``. If `norm_stats` is given, ``exp`` and
      ``pose`` are first normalised as ``(x - <k>_mean) / <k>_std``.
    * ``'emo'``: ``[exp, pose, emotion]``.
      NOTE(review): `norm_stats` is not applied in this mode (unchanged from
      the previous behaviour) - confirm this is intended.

    Raises:
        ValueError: for any other `rot_repr`.
    """
    if rot_repr == 'aa':
        if norm_stats is not None:
            coef_dict = {
                k: (coef_dict[k] - norm_stats[f'{k}_mean']) / norm_stats[f'{k}_std']
                for k in ('exp', 'pose')
            }
        pose_coef = get_pose_input(coef_dict, rot_repr, with_global_pose)
        return torch.cat([coef_dict['exp'], pose_coef], dim=-1)
    if rot_repr == 'emo':
        return torch.cat([coef_dict['exp'], coef_dict['pose'], coef_dict['emotion']], dim=-1)
    raise ValueError(f'Unknown rotation representation {rot_repr}!')


def get_coef_dict(motion_coef, shape_coef=None, denorm_stats=None, with_global_pose=False, rot_repr='aa'):
    """Split a motion tensor back into an ``{'exp', 'pose'}`` dictionary.

    ``exp`` is ``motion_coef[..., :63]``. With `with_global_pose`, ``pose`` is
    ``motion_coef[..., 63:]``; otherwise it is three zeros followed by the last
    channel of `motion_coef`. If `denorm_stats` is given each entry is mapped
    through ``x * <k>_std + <k>_mean``; without global pose the first three
    pose channels are then reset to zero.

    `shape_coef` is accepted for interface compatibility but is not used here.

    Raises:
        ValueError: if `rot_repr` is not ``'aa'``.
    """
    if rot_repr != 'aa':
        raise ValueError(f'Unknown rotation representation {rot_repr}!')

    coef_dict = {'exp': motion_coef[..., :63]}
    if with_global_pose:
        coef_dict['pose'] = motion_coef[..., 63:]
    else:
        placeholder = torch.zeros_like(motion_coef[..., :3])
        coef_dict['pose'] = torch.cat([placeholder, motion_coef[..., -1:]], dim=-1)

    if denorm_stats is not None:
        coef_dict = {k: coef_dict[k] * denorm_stats[f'{k}_std'] + denorm_stats[f'{k}_mean']
                     for k in coef_dict}

    if not with_global_pose:
        # De-normalisation may have shifted the placeholder away from zero.
        coef_dict['pose'][..., :3] = 0

    return coef_dict


def coef_dict_to_vertices(coef_dict, flame, rot_repr='aa', ignore_global_rot=False, flame_batch_size=512):
    """Run `flame` over `coef_dict` in chunks and return the vertices.

    All leading dimensions of ``coef_dict['exp']`` are flattened, the entries
    are passed to `flame` in slices of `flame_batch_size`, and the first value
    of each returned triple is concatenated and reshaped to ``(*lead, -1, 3)``.

    `coef_dict` must contain ``'shape'``, ``'exp'`` and ``'pose'``.
    NOTE(review): `get_coef_dict` in this file does not add ``'shape'``;
    callers presumably add it themselves - confirm.

    Raises:
        ValueError: if `rot_repr` is not ``'aa'``.
    """
    shape = coef_dict['exp'].shape[:-1]
    # reshape == view for view-compatible inputs, but also accepts the rest.
    coef_dict = {k: v.reshape(-1, v.shape[-1]) for k, v in coef_dict.items()}
    n_samples = reduce(lambda x, y: x * y, shape, 1)

    vert_list = []
    for start in range(0, n_samples, flame_batch_size):
        batch = {k: v[start:start + flame_batch_size] for k, v in coef_dict.items()}
        if rot_repr != 'aa':
            raise ValueError(f'Unknown rot_repr: {rot_repr}')
        vert, _, _ = flame(
            batch['shape'], batch['exp'], batch['pose'],
            pose2rot=True, ignore_global_rot=ignore_global_rot,
            return_lm2d=False, return_lm3d=False)
        vert_list.append(vert)

    vert_list = torch.cat(vert_list, dim=0)  # (n_samples, n_vertices, 3)
    return vert_list.view(*shape, -1, 3)  # (..., n_vertices, 3)


def _truncate_audio(audio, end_idx, pad_mode='zero'):
    """Return a copy of `audio` whose row ``i`` is padded from ``end_idx[i]`` on.

    ``'zero'`` fills with zeros; ``'replicate'`` repeats the element just
    before ``end_idx[i]``.

    Raises:
        ValueError: for any other `pad_mode`.
    """
    if pad_mode not in ('replicate', 'zero'):
        raise ValueError(f'Unknown pad mode {pad_mode}!')
    audio_trunc = audio.clone()
    for i in range(audio.shape[0]):
        end = end_idx[i]
        audio_trunc[i, end:] = audio_trunc[i, end - 1] if pad_mode == 'replicate' else 0
    return audio_trunc


def _truncate_coef_dict(coef_dict, end_idx, pad_mode='zero'):
    """Return a cloned `coef_dict` with every entry padded from ``end_idx[i]`` on.

    Padding semantics match `_truncate_audio`; the batch size is taken from
    ``coef_dict['exp']``.

    Raises:
        ValueError: for an unknown `pad_mode`.
    """
    if pad_mode not in ('replicate', 'zero'):
        raise ValueError(f'Unknown pad mode: {pad_mode}!')
    batch_size = coef_dict['exp'].shape[0]
    coef_dict_trunc = {k: v.clone() for k, v in coef_dict.items()}
    for i in range(batch_size):
        end = end_idx[i]
        for value in coef_dict_trunc.values():
            value[i, end:] = value[i, end - 1] if pad_mode == 'replicate' else 0
    return coef_dict_trunc


def truncate_coef_dict_and_audio(audio, coef_dict, n_motions, audio_unit=640, pad_mode='zero'):
    """Randomly truncate each sample of `audio` and `coef_dict` consistently.

    A cut point is drawn per sample with ``torch.randint(1, n_motions)`` (so
    `n_motions` must be at least 2); motions are padded from that frame on and
    audio from ``cut * audio_unit`` on.

    Returns:
        tuple: ``(audio_trunc, coef_dict_trunc, end_idx)``.
    """
    batch_size = audio.shape[0]
    end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
    audio_end_idx = (end_idx * audio_unit).long()

    audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)
    coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)
    return audio_trunc, coef_dict_trunc, end_idx


def truncate_motion_coef_and_audio(audio, motion_coef, n_motions, audio_unit=640, pad_mode='zero'):
    """Randomly truncate each sample of `audio` and `motion_coef` consistently.

    Same cut-point sampling as `truncate_coef_dict_and_audio`, but operating
    on a single motion tensor (split at channel 63 internally and re-joined).

    Returns:
        tuple: ``(audio_trunc, motion_coef_trunc, end_idx)``.
    """
    batch_size = audio.shape[0]
    end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
    audio_end_idx = (end_idx * audio_unit).long()

    audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)

    coef_dict = {'exp': motion_coef[..., :63], 'pose_any': motion_coef[..., 63:]}
    coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)
    motion_coef_trunc = torch.cat([coef_dict_trunc['exp'], coef_dict_trunc['pose_any']], dim=-1)

    return audio_trunc, motion_coef_trunc, end_idx


def nt_xent_loss(feature_a, feature_b, temperature):
    """
    Normalized temperature-scaled cross entropy loss.

    (Adapted from https://github.com/sthalles/SimCLR/blob/master/simclr.py)

    Args:
        feature_a (torch.Tensor): shape (batch_size, feature_dim)
        feature_b (torch.Tensor): shape (batch_size, feature_dim)
        temperature (float): temperature scaling factor

    Returns:
        torch.Tensor: scalar
    """
    batch_size = feature_a.shape[0]
    device = feature_a.device

    features = F.normalize(torch.cat([feature_a, feature_b], dim=0), dim=1)

    # labels[i, j] is True when rows i and j come from the same sample index.
    sample_ids = torch.arange(batch_size, device=device).repeat(2)
    labels = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

    similarity_matrix = torch.matmul(features, features.T)

    # Discard the main diagonal from both the labels and the similarities.
    n_rows = labels.shape[0]
    off_diag = ~torch.eye(n_rows, dtype=torch.bool, device=device)
    labels = labels[off_diag].view(n_rows, -1)
    similarity_matrix = similarity_matrix[off_diag].view(n_rows, -1)

    # Put the positive first in every row so the target class is always 0.
    positives = similarity_matrix[labels].view(n_rows, -1)
    negatives = similarity_matrix[~labels].view(n_rows, -1)
    logits = torch.cat([positives, negatives], dim=1) / temperature

    targets = torch.zeros(n_rows, dtype=torch.long, device=device)
    return F.cross_entropy(logits, targets)


def _time_diff(seq):
    """Return first-order differences of `seq` along dim 1."""
    return seq[:, 1:] - seq[:, :-1]


def compute_loss_new(args, is_starting_sample, motion_coef_gt, noise, target, prev_motion_coef, end_idx=None):
    """Compute the diffusion loss and its auxiliary expression/head-pose terms.

    Args:
        args: options object; reads ``criterion`` (``'l1'``/``'l2'``),
            ``target`` (``'noise'``/``'sample'``), ``n_prev_motions``,
            ``n_motions``, ``rot_repr``, ``no_constrain_prev``,
            ``no_head_pose`` and the weights ``l_exp_vel``, ``l_exp_smooth``,
            ``l_head_angle``, ``l_head_vel``, ``l_head_smooth``,
            ``l_head_trans``.
        is_starting_sample: when false (and ``target == 'sample'``),
            `prev_motion_coef` is prepended to the ground truth.
        motion_coef_gt: ground-truth motion, indexed as ``[batch, time, ch]``.
        noise: regression target used when ``args.target == 'noise'``.
        target: model output; its first ``n_prev_motions`` frames are the
            predicted previous motions.
        prev_motion_coef: ground-truth previous motions.
        end_idx: optional per-sample number of valid frames; frames at or
            beyond it are excluded from every term.

    Returns:
        tuple: ``(loss_noise, loss_exp, loss_exp_vel, loss_exp_smooth,
        loss_head_angle, loss_head_vel, loss_head_smooth, loss_head_trans)``.
        Terms that are not computed are ``None``; ``loss_head_vel`` and
        ``loss_head_smooth`` are also ``None`` when the mask selects nothing.

    Raises:
        NotImplementedError: unknown ``args.criterion``.
        ValueError: unknown ``args.target`` or ``args.rot_repr``.
    """
    criterion_name = args.criterion.lower()
    if criterion_name == 'l2':
        criterion_func = F.mse_loss
    elif criterion_name == 'l1':
        criterion_func = F.l1_loss
    else:
        raise NotImplementedError(f'Criterion {args.criterion} not implemented.')

    # Expression losses
    loss_exp = loss_exp_vel = loss_exp_smooth = None
    # Head-motion losses
    loss_head_angle = loss_head_vel = loss_head_smooth = None
    # Transition losses (previous window -> current window)
    loss_head_trans_vel = loss_head_trans_accel = loss_head_trans = None

    n_prev = args.n_prev_motions

    if args.target == 'noise':
        loss_noise = criterion_func(noise, target[:, n_prev:], reduction='none')
    elif args.target == 'sample':
        if is_starting_sample:
            target = target[:, n_prev:]
        else:
            motion_coef_gt = torch.cat([prev_motion_coef, motion_coef_gt], dim=1)
            if args.no_constrain_prev:
                target = torch.cat([prev_motion_coef, target[:, n_prev:]], dim=1)

        loss_noise = criterion_func(motion_coef_gt, target, reduction='none')

        # Expression-related losses
        if args.rot_repr == 'aa':
            exp_gt = motion_coef_gt[:, :, :63]
            exp_pred = target[:, :, :63]
        elif args.rot_repr == 'emo':
            # First 63 channels plus the trailing 3 channels.
            exp_gt = torch.cat([motion_coef_gt[:, :, :63], motion_coef_gt[:, :, -3:]], -1)
            exp_pred = torch.cat([target[:, :, :63], target[:, :, -3:]], -1)
        else:
            raise ValueError(f'Unknown rotation representation {args.rot_repr}!')

        loss_exp = criterion_func(exp_gt, exp_pred, reduction='none')

        if args.l_exp_vel > 0:
            loss_exp_vel = criterion_func(_time_diff(exp_gt), _time_diff(exp_pred), reduction='none')

        if args.l_exp_smooth > 0:
            vel_exp_pred = _time_diff(exp_pred)
            loss_exp_smooth = criterion_func(vel_exp_pred[:, 1:], vel_exp_pred[:, :-1], reduction='none')

        # Head-motion-related losses
        if not args.no_head_pose:
            if args.rot_repr == 'aa':
                head_pose_gt = motion_coef_gt[:, :, 63:]
                head_pose_pred = target[:, :, 63:]
            elif args.rot_repr == 'emo':
                head_pose_gt = motion_coef_gt[:, :, 63:70]
                head_pose_pred = target[:, :, 63:70]
            else:
                raise ValueError(f'Unknown rotation representation {args.rot_repr}!')

            # Direct loss between ground-truth and predicted pose.
            if args.l_head_angle > 0:
                loss_head_angle = criterion_func(head_pose_gt, head_pose_pred, reduction='none')

            if args.l_head_vel > 0:
                loss_head_vel = criterion_func(
                    _time_diff(head_pose_gt), _time_diff(head_pose_pred), reduction='none')

            if args.l_head_smooth > 0:
                head_vel_pred = _time_diff(head_pose_pred)
                loss_head_smooth = criterion_func(
                    head_vel_pred[:, 1:], head_vel_pred[:, :-1], reduction='none')

            if not is_starting_sample and args.l_head_trans > 0:
                # Constrain only the predicted current motions (x_0 ~ x_2):
                # the last three ground-truth previous frames are followed by
                # the first three predicted current frames.
                head_pose_trans = torch.cat(
                    [head_pose_gt[:, n_prev - 3:n_prev], head_pose_pred[:, n_prev:n_prev + 3]], dim=1)
                head_vel_pred = _time_diff(head_pose_trans)
                head_accel_pred = _time_diff(head_vel_pred)

                # will constrain x_{-2|0} ~ x_{1}
                loss_head_trans_vel = criterion_func(
                    head_vel_pred[:, 2:4], head_vel_pred[:, 1:3], reduction='none')
                # will constrain x_{-3|0} ~ x_{2}
                loss_head_trans_accel = criterion_func(
                    head_accel_pred[:, 1:], head_accel_pred[:, :-1], reduction='none')
    else:
        raise ValueError(f'Unknown diffusion target: {args.target}')

    # Frame-validity mask over the current window: (batch, n_motions).
    batch_size = target.shape[0]
    if end_idx is None:
        mask = torch.ones((batch_size, args.n_motions), dtype=torch.bool, device=target.device)
    else:
        frame_ids = torch.arange(args.n_motions, device=target.device).expand(batch_size, -1)
        mask = frame_ids < end_idx.unsqueeze(1)

    if args.target == 'sample' and not is_starting_sample:
        prev_part = mask[:, :n_prev]
        if args.no_constrain_prev:
            # Warning: this option will be deprecated in the future
            mask = torch.cat([torch.zeros_like(prev_part), mask], dim=1)
        else:
            mask = torch.cat([torch.ones_like(prev_part), mask], dim=1)

    loss_noise = loss_noise[mask].mean()
    if loss_exp is not None:
        loss_exp = loss_exp[mask].mean()
    # NOTE(review): unlike the head-pose terms below, the two expression terms
    # are not guarded against an empty selection (mean of nothing is NaN).
    # Kept as-is because callers may rely on always getting a tensor here.
    if loss_exp_vel is not None:
        loss_exp_vel = loss_exp_vel[mask[:, 1:]].mean()
    if loss_exp_smooth is not None:
        loss_exp_smooth = loss_exp_smooth[mask[:, 2:]].mean()
    if loss_head_angle is not None:
        loss_head_angle = loss_head_angle[mask].mean()
    if loss_head_vel is not None:
        loss_head_vel = loss_head_vel[mask[:, 1:]]
        loss_head_vel = loss_head_vel.mean() if torch.numel(loss_head_vel) > 0 else None
    if loss_head_smooth is not None:
        loss_head_smooth = loss_head_smooth[mask[:, 2:]]
        loss_head_smooth = loss_head_smooth.mean() if torch.numel(loss_head_smooth) > 0 else None
    if loss_head_trans_vel is not None:
        vel_mask = mask[:, n_prev:n_prev + 2]
        accel_mask = mask[:, n_prev:n_prev + 3]
        loss_head_trans_vel = loss_head_trans_vel[vel_mask].mean()
        loss_head_trans_accel = loss_head_trans_accel[accel_mask].mean()
        loss_head_trans = loss_head_trans_vel + loss_head_trans_accel

    return (loss_noise, loss_exp, loss_exp_vel, loss_exp_smooth,
            loss_head_angle, loss_head_vel, loss_head_smooth, loss_head_trans)