| import torch | |
| from train.losses import proposal_diversity_loss | |
def test_proposal_diversity():
    """Check that an all-zero input gets a larger diversity loss than a random one.

    Builds two tensors of the same shape, one filled with zeros and one drawn
    from a standard normal distribution, and asserts that
    ``proposal_diversity_loss`` is strictly greater for the all-zero tensor.
    """
    # NOTE(review): the meaning of each axis is not visible from this test;
    # confirm the expected layout against proposal_diversity_loss.
    shape = (2, 4, 2, 14)

    # Use a private, seeded generator so the random sample is identical on
    # every run and the global torch RNG state is left untouched.
    rng = torch.Generator().manual_seed(0)

    collapsed = torch.zeros(shape)
    diverse = torch.randn(shape, generator=rng)

    assert proposal_diversity_loss(collapsed) > proposal_diversity_loss(diverse)
| import torch | |
| from train.losses import proposal_diversity_loss | |
def test_proposal_diversity():
    """Check that an all-zero input gets a larger diversity loss than a random one.

    Builds two tensors of the same shape, one filled with zeros and one drawn
    from a standard normal distribution, and asserts that
    ``proposal_diversity_loss`` is strictly greater for the all-zero tensor.
    """
    # NOTE(review): a function with this same name is also defined earlier in
    # this file; this later definition is the one left bound to the name, so
    # only one of the two is collected. Confirm whether the duplicate is
    # intentional or should be removed.

    # NOTE(review): the meaning of each axis is not visible from this test;
    # confirm the expected layout against proposal_diversity_loss.
    shape = (2, 4, 2, 14)

    # Use a private, seeded generator so the random sample is identical on
    # every run and the global torch RNG state is left untouched.
    rng = torch.Generator().manual_seed(0)

    collapsed = torch.zeros(shape)
    diverse = torch.randn(shape, generator=rng)

    assert proposal_diversity_loss(collapsed) > proposal_diversity_loss(diverse)