mdm / utils /loss_util.py
hassanjbara's picture
update model
5007d4b
from diffusion.nn import sum_flat
def diff_l2(a, b):
    """Return the squared difference ``(a - b) ** 2``.

    No reduction is performed here; whatever type ``a - b`` produces is
    squared with ``**`` and returned as-is.
    """
    delta = a - b
    return delta ** 2
def masked_l2(a, b, mask, loss_fn=diff_l2, epsilon=1e-8, entries_norm=True):
    """Compute a masked loss normalized by the number of unmasked entries.

    Assumed shapes (stated by the original author, not validated here):
        a, b: (bs, J, Jdim, seqlen)
        mask: (bs, 1, 1, seqlen)

    Args:
        a, b: Operands handed to ``loss_fn``.
        mask: Multiplied into the unreduced loss after ``mask.float()``.
        loss_fn: Callable returning the unreduced loss; defaults to
            ``diff_l2``.
        epsilon: Added to the denominator to avoid division by zero.
        entries_norm: When True, the mask sum is multiplied by
            ``a.shape[1]`` (and also by ``a.shape[2]`` when ``a`` has more
            than 3 dimensions). This is needed when the mask is per frame
            and does not itself account for the number of entries per
            frame; otherwise set it to False.

    Returns:
        ``sum_flat(loss * mask) / (normalizer + epsilon)``.

    NOTE(review): ``sum_flat`` is imported from ``diffusion.nn``; presumably
    it sums over every non-batch dimension -- confirm against that module.
    """
    # Masked-out positions contribute zero to the reduced loss.
    masked_sum = sum_flat(loss_fn(a, b) * mask.float())

    # Number of entries per frame: J, or J * Jdim when `a` has > 3 dims.
    dims = a.shape
    entries_per_frame = dims[1] * dims[2] if len(dims) > 3 else dims[1]

    denominator = sum_flat(mask)
    if entries_norm:
        # Scale in place by the number of entries per frame.
        denominator *= entries_per_frame

    return masked_sum / (denominator + epsilon)