import torch

from train.losses import permutation_invariant_role_loss


def test_role_loss_permutation():
    """Check that the role loss is unchanged when the role axis (dim 1) is flipped.

    The two samples in the fixture are deliberately NOT role-swapped copies of
    each other. If they were, ``logits.flip(1)`` would equal ``logits.flip(0)``,
    which only reorders the batch. Any batch-averaged loss, even one that is not
    permutation invariant, would then pass this test.
    """
    # NOTE(review): presumably logits are (batch, roles, classes); the test only
    # relies on dim 1 being the axis the loss should be invariant to -- confirm
    # against train.losses.
    logits = torch.tensor(
        [
            [[0.1, 4.0, 0.0, -1.0], [0.1, 0.0, 4.0, -1.0]],
            [[1.5, -0.5, 0.3, 2.0], [-1.0, 2.5, 0.7, 0.2]],
        ],
        dtype=torch.float32,
    )
    swapped = logits.flip(1)

    # Guard against a vacuous fixture: swapping roles must not be equivalent
    # to merely reordering samples within the batch.
    assert not torch.equal(swapped, logits.flip(0))

    loss_original = permutation_invariant_role_loss(logits)
    loss_swapped = permutation_invariant_role_loss(swapped)
    torch.testing.assert_close(loss_original, loss_swapped)