# VLAarchtests / tests/test_role_loss_permutation.py
# Uploaded by lsnu with the upload-large-folder tool (commit 16405f2).
import torch
from train.losses import permutation_invariant_role_loss
def test_role_loss_permutation():
    """Check that swapping the two role slots leaves the role loss unchanged.

    ``logits`` is a (2, 2, 4) tensor. Dim 1 is assumed to index the role
    slots whose order the loss should ignore (TODO confirm against
    ``train.losses``). ``flip(1)`` reverses that axis, swapping the two
    slot predictions within every sample. A permutation-invariant loss
    must return the same value for both orderings.
    """
    logits = torch.tensor(
        [
            [[0.1, 4.0, 0.0, -1.0], [0.1, 0.0, 4.0, -1.0]],
            [[0.1, 0.0, 4.0, -1.0], [0.1, 4.0, 0.0, -1.0]],
        ],
        dtype=torch.float32,
    )
    loss_original = permutation_invariant_role_loss(logits)
    loss_swapped = permutation_invariant_role_loss(logits.flip(1))
    torch.testing.assert_close(loss_original, loss_swapped)

    # In the batch above, sample 1 is exactly sample 0 with its slots swapped,
    # so logits.flip(1) == logits.flip(0). The check above therefore only
    # proves invariance to reordering the *batch*. A loss that is NOT
    # slot-invariant but reduces over the batch (e.g. fixed-target
    # cross-entropy with a mean) would pass it as well. Re-check each sample
    # on its own so that batch symmetry cannot hide a regression.
    # NOTE(review): assumes the loss accepts a batch of size one.
    for index in range(logits.shape[0]):
        sample = logits[index : index + 1]
        torch.testing.assert_close(
            permutation_invariant_role_loss(sample),
            permutation_invariant_role_loss(sample.flip(1)),
        )