import numpy as np
import torch

from diffusion.nn import mean_flat, sum_flat


def angle_l2(angle1, angle2):
    # Squared angular error: the difference is wrapped into [-pi/2, pi/2)
    # (i.e. treated modulo pi) before squaring.
    a = angle1 - angle2
    a = (a + (torch.pi / 2)) % torch.pi - (torch.pi / 2)
    return a ** 2


def diff_l2(a, b):
    # Plain element-wise squared error.
    return (a - b) ** 2


def masked_l2(a, b, mask, loss_fn=diff_l2, epsilon=1e-8, entries_norm=True):
    # Assuming a.shape == b.shape == [bs, J, Jdim, seqlen]
    # and mask.shape == [bs, 1, 1, seqlen].
    loss = loss_fn(a, b)
    loss = sum_flat(loss * mask.float())  # sums the per-element loss over the unmasked elements
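    # Note (assumption, not in the original comments): sum_flat from diffusion.nn
    # is taken to reduce over every non-batch dimension, roughly
    #     tensor.sum(dim=list(range(1, len(tensor.shape))))
    # as in the improved-diffusion codebase, so from here on `loss` holds one
    # value per batch element.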
    n_entries = a.shape[1]
    if len(a.shape) > 3:
        n_entries *= a.shape[2]
    non_zero_elements = sum_flat(mask)
    if entries_norm:
        # When the mask is per frame and does not account for the number of
        # entries per frame, this normalization is needed; otherwise set
        # entries_norm to False.
        non_zero_elements *= n_entries
    mse_loss_val = loss / (non_zero_elements + epsilon)  # epsilon avoids division by zero
    return mse_loss_val


def masked_goal_l2(pred_goal, ref_goal, cond, all_goal_joint_names):
    # Goal rows: one per goal joint, then 'traj', then a final heading row.
    all_goal_joint_names_w_traj = np.append(all_goal_joint_names, 'traj')
    target_joint_idx = [[np.where(all_goal_joint_names_w_traj == j)[0][0] for j in sample_joints] for sample_joints in cond['target_joint_names']]

    # Location loss: only the joints requested for each sample contribute.
    loc_mask = torch.zeros_like(pred_goal[:, :-1], dtype=torch.bool)
    for sample_idx in range(loc_mask.shape[0]):
        loc_mask[sample_idx, target_joint_idx[sample_idx]] = True
    loc_mask[:, -1, 1] = False  # the vertical component of 'traj' is always masked out
    loc_loss = masked_l2(pred_goal[:, :-1], ref_goal[:, :-1], loc_mask, entries_norm=False)

    # Heading loss: angular L2 on the last row, masked by cond['is_heading'].
    heading_loss = masked_l2(pred_goal[:, -1:, :1], ref_goal[:, -1:, :1],
                             cond['is_heading'].unsqueeze(1).unsqueeze(1), loss_fn=angle_l2, entries_norm=False)

    loss = loc_loss + heading_loss
    return loss
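

if __name__ == "__main__":
    # Illustrative smoke test (added for clarity, not part of the original file).
    # It exercises masked_l2 with a per-frame mask and masked_goal_l2 with dummy
    # goals; all shapes and joint names below are hypothetical placeholders, and
    # running it assumes the diffusion package is importable.
    torch.manual_seed(0)

    # masked_l2 with a per-frame mask over [bs, J, Jdim, seqlen] tensors.
    bs, J, Jdim, seqlen = 2, 22, 3, 60
    a = torch.randn(bs, J, Jdim, seqlen)
    b = torch.randn(bs, J, Jdim, seqlen)
    frame_mask = torch.ones(bs, 1, 1, seqlen)
    frame_mask[:, :, :, seqlen // 2:] = 0.  # treat the second half as padding
    print('masked_l2 per sample:', masked_l2(a, b, frame_mask))  # shape [bs]

    # masked_goal_l2 with hypothetical joint names; each goal row is (x, y, z),
    # the row after the joints is 'traj' and the last row carries the heading.
    joint_names = np.array(['pelvis', 'left_wrist', 'right_wrist'])
    pred_goal = torch.randn(bs, len(joint_names) + 2, 3)
    ref_goal = torch.randn(bs, len(joint_names) + 2, 3)
    cond = {'target_joint_names': [['pelvis'], ['left_wrist', 'traj']],
            'is_heading': torch.tensor([1., 0.])}
    print('masked_goal_l2 per sample:', masked_goal_l2(pred_goal, ref_goal, cond, joint_names))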