# mae/loss/mae_loss.py
# Uploaded by adelelsayed1991 via huggingface_hub (commit 5ffe2e2, verified).
import torch
import torch.nn as nn
def mae_loss(pred, target, mask, *, patch_reduction="sum"):
    """Return the squared reconstruction error averaged over masked patches.

    Args:
        pred: Predicted patches, shape ``(B, N, P)``.
        target: Target patches, same shape as ``pred`` (or broadcastable to it).
        mask: Shape ``(B, N)``. 1 (or True) marks a masked patch whose
            reconstruction is scored; 0 marks a patch that is ignored.
        patch_reduction: How each patch's squared error is reduced over the
            ``P`` axis before averaging over masked patches. ``"sum"``
            (default, the historical behavior of this function) sums it;
            ``"mean"`` averages it.

    Returns:
        Scalar tensor. It is 0 when no patch is masked, because the
        denominator is clamped to at least 1.

    Raises:
        ValueError: if ``pred`` is not 3-D, if ``mask`` is not shaped
            ``(B, N)`` to match ``pred``, or if ``patch_reduction`` is unknown.

    NOTE(review): the reference MAE loss (He et al., 2021) averages over the
    patch pixels, which corresponds to ``patch_reduction="mean"``. The default
    ``"sum"`` is exactly ``P`` times larger. Confirm which scale the training
    config (learning rate, gradient clipping, loss mixing) was tuned for.
    """
    if pred.dim() != 3:
        raise ValueError(f"pred must have shape (B, N, P), got {tuple(pred.shape)}")
    if mask.shape != pred.shape[:2]:
        # Without this check a mismatched mask (e.g. shape (N,)) silently
        # broadcasts, and the denominator then counts the wrong patches.
        raise ValueError(
            f"mask must have shape {tuple(pred.shape[:2])}, got {tuple(mask.shape)}"
        )
    if patch_reduction not in ("sum", "mean"):
        raise ValueError(
            f"patch_reduction must be 'sum' or 'mean', got {patch_reduction!r}"
        )

    # Keep the weights in float32 so the patch count stays exact even when
    # pred is half precision.
    weights = mask.float()  # (B, N)
    sq_err = (pred - target) ** 2  # (B, N, P)
    # Reduce over P first: this is equivalent to masking every element and
    # summing, but it never builds a (B, N, P) masked copy.
    if patch_reduction == "sum":
        per_patch = sq_err.sum(dim=-1)
    else:
        per_patch = sq_err.mean(dim=-1)
    # Average over masked patches only; the clamp avoids 0/0 when nothing is masked.
    return (per_patch * weights).sum() / weights.sum().clamp_min(1.0)