| import torch | |
| import torch.nn as nn | |
def mae_loss(pred, target, mask):
    """Masked mean-squared reconstruction loss (MAE style).

    The squared error is first averaged over the last (patch-content)
    dimension, giving one loss value per patch. Those values are then
    averaged over the masked patches only. Averaging per patch first makes
    the loss magnitude independent of the patch size ``P``. Summing over
    ``P`` and dividing by the number of masked patches would instead scale
    the loss by ``P``.

    Args:
        pred: Predicted patches, shape (B, N, P). More generally (..., P).
        target: Reconstruction targets, same shape as ``pred``.
        mask: Patch mask, shape (B, N), i.e. ``pred.shape[:-1]``.
            1 = masked (counted in the loss), 0 = visible (ignored).
            Any dtype convertible with ``.float()`` (bool, int, float).

    Returns:
        Scalar tensor: mean per-patch squared error over the masked patches.
        Zero (not NaN) when no patch is masked.
    """
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # (B, N)
    # Keep the mask in float32 so the masked-patch count cannot overflow
    # when pred/target are half precision.
    mask = mask.float()
    # clamp_min(1.0) avoids 0/0 = NaN when nothing is masked.
    return (per_patch * mask).sum() / mask.sum().clamp_min(1.0)