# File size: 497 Bytes | 16405f2
import torch
from train.losses import permutation_invariant_role_loss
def test_role_loss_permutation():
    """Check that reversing dim 1 of the logits leaves the role loss unchanged.

    Builds a (2, 2, 4) float32 tensor, evaluates
    ``permutation_invariant_role_loss`` on it and on the same tensor with
    dim 1 reversed, and requires both results to agree within the default
    tolerances of ``torch.testing.assert_close``.
    """
    # The two distinct length-4 rows differ only in whether index 1 or
    # index 2 holds the large value.
    peak_at_1 = [0.1, 4.0, 0.0, -1.0]
    peak_at_2 = [0.1, 0.0, 4.0, -1.0]

    # NOTE(review): entry 1 is entry 0 with its two rows swapped, so
    # flip(1) produces the same tensor as flip(0) for this fixture. A loss
    # that merely averages over dim 0 would presumably pass as well --
    # consider an asymmetric fixture if stricter coverage is wanted.
    sample = torch.tensor(
        [
            [peak_at_1, peak_at_2],
            [peak_at_2, peak_at_1],
        ],
        dtype=torch.float32,
    )

    # Evaluate the untouched input first, then the dim-1-reversed copy.
    baseline = permutation_invariant_role_loss(sample)
    reversed_sample = sample.flip(1)
    after_flip = permutation_invariant_role_loss(reversed_sample)

    torch.testing.assert_close(baseline, after_flip)
|