File size: 497 Bytes
16405f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch

from train.losses import permutation_invariant_role_loss


def test_role_loss_permutation():
    """Check that the role loss does not depend on element order along dim 1.

    ``permutation_invariant_role_loss`` is evaluated on each input and on the
    same input with dim 1 reversed (``flip(1)``); the two results must match.
    """
    logits = torch.tensor(
        [
            [[0.1, 4.0, 0.0, -1.0], [0.1, 0.0, 4.0, -1.0]],
            [[0.1, 0.0, 4.0, -1.0], [0.1, 4.0, 0.0, -1.0]],
        ],
        dtype=torch.float32,
    )
    # NOTE: for this fixture logits.flip(1) == logits.flip(0). Reversing dim 1
    # merely swaps the two dim-0 rows, so a loss that reduces over dim 0
    # (presumably the batch axis -- TODO confirm) would pass this case even if
    # it were NOT invariant along dim 1. The extra cases below remove that
    # symmetry so the check is meaningful.
    cases = [
        logits,
        # Single dim-0 row: a dim-1 reordering can no longer be masked by a
        # dim-0 reordering.
        logits[:1],
        # Generic values with no cross-row symmetry; a seeded generator keeps
        # the test deterministic without touching global RNG state.
        torch.randn(3, 2, 4, generator=torch.Generator().manual_seed(0)),
    ]
    for case in cases:
        loss_original = permutation_invariant_role_loss(case)
        loss_swapped = permutation_invariant_role_loss(case.flip(1))
        torch.testing.assert_close(loss_original, loss_swapped)