File size: 328 Bytes
5ffe2e2
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
import torch
import torch.nn as nn

def mae_loss(pred, target, mask):
    """Mean squared reconstruction error over masked patches (MAE-style).

    The squared error is first averaged over the last (per-patch) dimension,
    then averaged over the patches selected by ``mask``. The result is the
    mean squared error per masked element, independent of the patch size P.

    Args:
        pred: Predictions, shape ``(B, N, P)``. Other leading dims also work
            as long as ``mask`` has shape ``pred.shape[:-1]``.
        target: Reconstruction targets, same shape as ``pred``.
        mask: Shape ``(B, N)``. 1 (or True) marks a masked patch that
            contributes to the loss; 0 marks a patch that is ignored.

    Returns:
        Scalar tensor. It is 0 when no patch is masked, because the
        denominator is clamped to at least 1.
    """
    # Per-patch MSE with shape (B, N). Averaging over P here, rather than
    # summing, keeps the numerator on the same scale as the denominator,
    # which counts patches, not elements.
    per_patch = ((pred - target) ** 2).mean(dim=-1)
    mask = mask.float()
    # clamp_min(1.0) guards against an all-zero mask (0 / 0 would be NaN).
    return (per_patch * mask).sum() / mask.sum().clamp_min(1.0)