| """Unit tests for loss functions.""" |
|
|
| import pytest |
| import torch |
|
|
| from src.loss import cornernet_focal_loss, offset_loss, total_loss |
|
|
|
|
| class TestCornerNetFocalLoss: |
| def test_perfect_prediction_zero_loss(self): |
| """Perfect predictions should produce near-zero loss.""" |
| gt = torch.zeros(1, 2, 64, 64) |
| gt[0, 0, 32, 32] = 1.0 |
|
|
| |
| pred = torch.zeros(1, 2, 64, 64) + 1e-6 |
| pred[0, 0, 32, 32] = 1.0 - 1e-6 |
|
|
| loss = cornernet_focal_loss(pred, gt) |
| assert loss.item() < 0.1 |
|
|
| def test_all_zeros_prediction_nonzero_loss(self): |
| """Predicting all zeros when particles exist should give positive loss.""" |
| gt = torch.zeros(1, 2, 64, 64) |
| gt[0, 0, 32, 32] = 1.0 |
|
|
| pred = torch.zeros(1, 2, 64, 64) + 1e-6 |
| loss = cornernet_focal_loss(pred, gt) |
| assert loss.item() > 0 |
|
|
| def test_high_false_positive_penalized(self): |
| """Predicting high confidence where GT is zero should be penalized.""" |
| gt = torch.zeros(1, 2, 64, 64) |
| pred_low_fp = torch.zeros(1, 2, 64, 64) + 0.01 |
| pred_high_fp = torch.zeros(1, 2, 64, 64) + 0.9 |
|
|
| loss_low = cornernet_focal_loss(pred_low_fp, gt) |
| loss_high = cornernet_focal_loss(pred_high_fp, gt) |
|
|
| assert loss_high.item() > loss_low.item() |
|
|
| def test_near_peak_reduced_penalty(self): |
| """Pixels near GT peaks should have reduced negative penalty via beta term.""" |
| gt = torch.zeros(1, 2, 64, 64) |
| gt[0, 0, 32, 32] = 1.0 |
| gt[0, 0, 31, 32] = 0.8 |
|
|
| |
| pred = torch.zeros(1, 2, 64, 64) + 0.01 |
| pred[0, 0, 31, 32] = 0.5 |
|
|
| loss = cornernet_focal_loss(pred, gt) |
| |
| assert loss.item() < 10 |
|
|
| def test_confidence_weighting(self): |
| """Confidence weights should scale the loss.""" |
| gt = torch.zeros(1, 2, 64, 64) |
| gt[0, 0, 32, 32] = 1.0 |
| pred = torch.zeros(1, 2, 64, 64) + 0.5 |
|
|
| weights_full = torch.ones(1, 2, 64, 64) |
| weights_half = torch.ones(1, 2, 64, 64) * 0.5 |
|
|
| loss_full = cornernet_focal_loss(pred, gt, conf_weights=weights_full) |
| loss_half = cornernet_focal_loss(pred, gt, conf_weights=weights_half) |
|
|
| |
| assert loss_half.item() < loss_full.item() |
|
|
|
|
| class TestOffsetLoss: |
| def test_zero_when_no_particles(self): |
| """Offset loss should be zero when mask is empty.""" |
| pred = torch.randn(1, 2, 64, 64) |
| gt = torch.zeros(1, 2, 64, 64) |
| mask = torch.zeros(1, 64, 64, dtype=torch.bool) |
|
|
| loss = offset_loss(pred, gt, mask) |
| assert loss.item() == 0.0 |
|
|
| def test_nonzero_with_particles(self): |
| """Offset loss should be nonzero when predictions differ from GT.""" |
| pred = torch.randn(1, 2, 64, 64) |
| gt = torch.zeros(1, 2, 64, 64) |
| mask = torch.zeros(1, 64, 64, dtype=torch.bool) |
| mask[0, 32, 32] = True |
|
|
| loss = offset_loss(pred, gt, mask) |
| assert loss.item() > 0 |
|
|
|
|
| class TestTotalLoss: |
| def test_returns_three_values(self): |
| """total_loss should return (total, hm_loss, off_loss).""" |
| hm_pred = torch.sigmoid(torch.randn(1, 2, 64, 64)) |
| hm_gt = torch.zeros(1, 2, 64, 64) |
| off_pred = torch.randn(1, 2, 64, 64) |
| off_gt = torch.zeros(1, 2, 64, 64) |
| mask = torch.zeros(1, 64, 64, dtype=torch.bool) |
|
|
| total, hm_val, off_val = total_loss( |
| hm_pred, hm_gt, off_pred, off_gt, mask, |
| ) |
|
|
| assert isinstance(total, torch.Tensor) |
| assert isinstance(hm_val, float) |
| assert isinstance(off_val, float) |
| assert total.requires_grad |
|
|