import torch
import torch.nn as nn


def mae_loss(pred, target, mask):
    """Return the masked-autoencoder reconstruction loss (MSE over masked patches).

    Despite the name, this is a squared-error loss: "MAE" here means Masked
    AutoEncoder, not mean absolute error. The squared error is first averaged
    over the patch dimension ``P``, giving one value per patch. That value is
    then averaged over the patches selected by ``mask``. This keeps the
    numerator and denominator both counted in patches, so the loss magnitude
    does not grow with the patch size ``P``.

    Args:
        pred: Predicted patches, shape ``(B, N, P)``.
        target: Reconstruction targets, same shape as ``pred``.
        mask: Shape ``(B, N)``. 1 (or True) marks a masked patch that
            contributes to the loss; 0 marks a visible patch that is ignored.

    Returns:
        Scalar tensor. It is 0 when no patch is masked, because the
        denominator is clamped to 1 to avoid division by zero.

    Raises:
        ValueError: If ``pred`` is not 3-D, if ``target`` does not match
            ``pred``'s shape, or if ``mask`` is not ``(B, N)``.
    """
    if pred.dim() != 3:
        raise ValueError(
            f"pred must have shape (B, N, P), got {tuple(pred.shape)}"
        )
    if target.shape != pred.shape:
        raise ValueError(
            f"target shape {tuple(target.shape)} does not match "
            f"pred shape {tuple(pred.shape)}"
        )
    if mask.shape != pred.shape[:2]:
        # Without this check a (N,) or (B, 1) mask would silently broadcast
        # and mask.sum() would count the wrong number of patches.
        raise ValueError(
            f"mask must have shape {tuple(pred.shape[:2])}, "
            f"got {tuple(mask.shape)}"
        )

    # Per-patch mean squared error, shape (B, N).
    per_patch = ((pred - target) ** 2).mean(dim=-1)
    weights = mask.float()
    # Mean over masked patches only; clamp avoids 0/0 when nothing is masked.
    return (per_patch * weights).sum() / weights.sum().clamp_min(1.0)