File size: 257 Bytes
16405f2 | 1 2 3 4 5 6 7 8 9 10 | import torch
from train.losses import proposal_diversity_loss
def test_proposal_diversity():
    """Check that ``proposal_diversity_loss`` scores a collapsed input higher.

    Two inputs of identical shape are compared: one that is all zeros (every
    element identical, i.e. fully collapsed) and one drawn from a standard
    normal distribution. The loss on the collapsed input must be strictly
    greater than the loss on the random input.
    """
    # Shape copied from the original test. The meaning of each axis is defined
    # by proposal_diversity_loss -- presumably (batch, num_proposals, ...);
    # TODO confirm against train.losses.
    shape = (2, 4, 2, 14)

    # A private, fixed-seed generator makes the random input reproducible
    # across runs (no flaky outcomes) without mutating torch's global RNG.
    gen = torch.Generator().manual_seed(0)

    collapsed = torch.zeros(shape)
    diverse = torch.randn(shape, generator=gen)

    collapsed_loss = proposal_diversity_loss(collapsed)
    diverse_loss = proposal_diversity_loss(diverse)
    assert collapsed_loss > diverse_loss, (
        f"collapsed loss {collapsed_loss} should exceed diverse loss {diverse_loss}"
    )
|