File size: 328 Bytes
5ffe2e2 |
1 2 3 4 5 6 7 8 9 10 |
import torch
import torch.nn as nn
def mae_loss(pred, target, mask):
    """Compute the mean squared reconstruction error over masked patches only.

    The squared error is first averaged over the patch dimension ``P``,
    giving one MSE value per patch. Those values are then averaged over the
    patches selected by ``mask``. Averaging over ``P`` keeps the loss scale
    independent of the patch size. The earlier version summed over ``P``,
    so its loss grew linearly with ``P``.

    Args:
        pred: Predicted patches, shape ``(B, N, P)``.
        target: Reconstruction targets, same shape as ``pred``.
        mask: Patch mask of shape ``(B, N)``. 1 marks a masked patch that
            contributes to the loss; 0 marks a visible patch, which is
            ignored. Bool masks are accepted. Non-binary values act as
            per-patch weights.

    Returns:
        Scalar tensor. It is 0 when no patch is masked, because the
        denominator is clamped to at least 1.

    Raises:
        ValueError: If ``pred`` is not 3-D, ``target`` does not match
            ``pred``'s shape, or ``mask`` is not ``(B, N)``. Without these
            checks, broadcasting would silently give a wrong result.
    """
    if pred.ndim != 3:
        raise ValueError(f"pred must have shape (B, N, P), got {tuple(pred.shape)}")
    if target.shape != pred.shape:
        raise ValueError(
            f"target shape {tuple(target.shape)} must match pred shape {tuple(pred.shape)}"
        )
    batch, num_patches, _ = pred.shape
    if tuple(mask.shape) != (batch, num_patches):
        raise ValueError(
            f"mask must have shape {(batch, num_patches)}, got {tuple(mask.shape)}"
        )

    # Per-patch MSE, shape (B, N).
    per_patch = ((pred - target) ** 2).mean(dim=-1)
    weights = mask.float()
    # Average over masked patches only; the clamp avoids 0/0 for an empty mask.
    return (per_patch * weights).sum() / weights.sum().clamp_min(1.0)