| import torch | |
| from train.losses import permutation_invariant_role_loss | |
def test_role_loss_permutation():
    """Check that permutation_invariant_role_loss ignores ordering along dim 1.

    The loss is compared on ``logits`` and on ``logits.flip(1)``. The two
    batch samples are deliberately *not* reorderings of each other. If they
    were, ``logits.flip(1)`` would equal ``logits.flip(0)`` (a plain batch
    swap), and any batch-reduced loss would pass without actually being
    invariant per sample.
    """
    # NOTE(review): assumes logits are (batch, slots, roles) and that dim 1 is
    # the axis the loss should be invariant to (the original test flipped
    # dim 1) -- confirm against train.losses.
    logits = torch.tensor(
        [
            [[0.1, 4.0, 0.0, -1.0], [0.1, 0.0, 4.0, -1.0]],
            [[0.1, 3.0, 0.0, -1.0], [0.1, 0.0, 1.0, -1.0]],
        ],
        dtype=torch.float32,
    )
    # Guard the fixture itself: flipping dim 1 must not reduce to a batch swap,
    # otherwise the comparison below is trivially true for batch-mean losses.
    assert not torch.equal(logits.flip(1), logits.flip(0))

    loss_original = permutation_invariant_role_loss(logits)
    loss_swapped = permutation_invariant_role_loss(logits.flip(1))
    torch.testing.assert_close(loss_original, loss_swapped)

    # Also check each sample on its own (batch of 1), so a reduction over the
    # batch cannot mask a per-sample term that depends on slot order.
    for sample in logits.split(1):
        torch.testing.assert_close(
            permutation_invariant_role_loss(sample),
            permutation_invariant_role_loss(sample.flip(1)),
        )