Spaces:
Build error
Build error
| from diffusion.nn import sum_flat | |
def diff_l2(a, b):
    """Return the element-wise squared difference between ``a`` and ``b``.

    Works on anything supporting ``-`` and ``**`` (Python numbers, tensors).
    """
    residual = a - b
    return pow(residual, 2)
def masked_l2(a, b, mask, loss_fn=diff_l2, epsilon=1e-8, entries_norm=True):
    """Average an element-wise loss over the unmasked elements of each sample.

    ``loss_fn(a, b)`` is multiplied by ``mask``, summed with ``sum_flat``, and
    divided by the number of contributing elements plus ``epsilon``.

    Args:
        a, b: Tensors of identical shape, assumed to be
            ``(bs, J, Jdim, seqlen)`` or ``(bs, J, seqlen)``.
            NOTE(review): confirm against callers.
        mask: Tensor multiplied into the loss (so it must broadcast against
            ``a``). With ``entries_norm=True`` it is assumed to be a per-frame
            mask such as ``(bs, 1, 1, seqlen)``.
        loss_fn: Element-wise loss called as ``loss_fn(a, b)``. Defaults to
            ``diff_l2``.
        epsilon: Added to the denominator so a sample whose mask is entirely
            zero does not cause a division by zero.
        entries_norm: If True, multiply the mask count by the number of
            entries each mask element stands for. These are all dimensions
            of ``a`` after the batch dimension and before the last one. Set
            it to False when ``mask`` already has one entry per element of
            ``a``.

    Returns:
        The masked sum divided by the (scaled) mask count. This is presumably
        one value per batch item, assuming ``sum_flat`` reduces every
        non-batch dimension (confirm in ``diffusion.nn``).
    """
    loss = loss_fn(a, b)
    # Masked sum of the element-wise loss; masked-out elements contribute 0.
    loss = sum_flat(loss * mask.float())

    # Number of entries of `a` that a single per-frame mask element covers:
    # every dimension between the batch dim and the last (frame) dim. Looping
    # over a.shape[2:-1] gives the same result as before for rank 3 and rank 4
    # inputs, and also counts the extra middle dims of rank >= 5 inputs.
    # NOTE(review): for a rank-2 `a` this is a.shape[1] (the last dim), which
    # looks unintended. It is kept unchanged for backward compatibility.
    n_entries = a.shape[1]
    for dim_size in a.shape[2:-1]:
        n_entries *= dim_size

    non_zero_elements = sum_flat(mask)
    if entries_norm:
        # The mask is per frame and does not count the entries inside each
        # frame, so scale its count by n_entries. Pass entries_norm=False
        # when the mask is already per element.
        non_zero_elements = non_zero_elements * n_entries

    # epsilon avoids division by zero for fully-masked samples.
    return loss / (non_zero_elements + epsilon)