Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| import segmentation_models_pytorch as smp | |
| import segmentation_models_pytorch.losses._functional as F | |
| from segmentation_models_pytorch.losses import ( | |
| DiceLoss, | |
| JaccardLoss, | |
| SoftBCEWithLogitsLoss, | |
| SoftCrossEntropyLoss, | |
| ) | |
def test_focal_loss_with_logits():
    """A confident correct prediction must incur lower binary focal loss
    than a wrong/uncertain one for the same targets."""
    target = torch.tensor([1, 0, 1])
    confident_logits = torch.tensor([10, -10, 10]).float()
    wrong_logits = torch.tensor([-1, 2, 0]).float()

    loss_confident = F.focal_loss_with_logits(confident_logits, target)
    loss_wrong = F.focal_loss_with_logits(wrong_logits, target)
    assert loss_confident < loss_wrong
def test_softmax_focal_loss_with_logits():
    """Softmax focal loss must rank confident correct logits below logits
    that put mass on the wrong class for the first two samples."""
    target = torch.tensor([1, 0, 2]).long()
    correct_logits = torch.tensor(
        [[0, 10, 0], [10, 0, 0], [0, 0, 10]]
    ).float()
    wrong_logits = torch.tensor(
        [[0, -10, 0], [0, 10, 0], [0, 0, 10]]
    ).float()

    loss_correct = F.softmax_focal_loss_with_logits(correct_logits, target)
    loss_wrong = F.softmax_focal_loss_with_logits(wrong_logits, target)
    assert loss_correct < loss_wrong
@pytest.mark.parametrize(
    ["y_true", "y_pred", "expected", "eps"],
    [
        # perfect match -> IoU 1
        [[1, 1, 1, 1], [1, 1, 1, 1], 1.0, 1e-5],
        [[0, 1, 1, 0], [0, 1, 1, 0], 1.0, 1e-5],
        # intersection 2, union 4 -> IoU 0.5
        [[1, 1, 1, 1], [1, 1, 0, 0], 0.5, 1e-5],
    ],
)
def test_soft_jaccard_score(y_true, y_pred, expected, eps):
    """soft_jaccard_score on hard 0/1 inputs equals plain IoU:
    |intersection| / |union|.

    Without the parametrize decorator the arguments were treated as
    missing pytest fixtures and the test errored at collection time.
    """
    y_true = torch.tensor(y_true, dtype=torch.float32)
    y_pred = torch.tensor(y_pred, dtype=torch.float32)
    actual = F.soft_jaccard_score(y_pred, y_true, eps=eps)
    assert float(actual) == pytest.approx(expected, eps)
@pytest.mark.parametrize(
    ["y_true", "y_pred", "expected", "eps"],
    [
        # both rows match exactly -> per-row IoU 1, mean 1
        [
            [[1, 1, 0, 0], [0, 0, 1, 1]],
            [[1, 1, 0, 0], [0, 0, 1, 1]],
            1.0,
            1e-5,
        ],
        # row 1: IoU 2/4 = 0.5; row 2: IoU ~0 -> mean 0.25
        [
            [[1, 1, 1, 1], [1, 1, 1, 1]],
            [[1, 1, 0, 0], [0, 0, 0, 0]],
            0.25,
            1e-5,
        ],
    ],
)
def test_soft_jaccard_score_2(y_true, y_pred, expected, eps):
    """soft_jaccard_score with dims=[1] scores each row independently;
    the mean of the per-row IoUs is compared against *expected*.

    Without the parametrize decorator the arguments were treated as
    missing pytest fixtures and the test errored at collection time.
    """
    y_true = torch.tensor(y_true, dtype=torch.float32)
    y_pred = torch.tensor(y_pred, dtype=torch.float32)
    actual = F.soft_jaccard_score(y_pred, y_true, dims=[1], eps=eps)
    actual = actual.mean()
    assert float(actual) == pytest.approx(expected, eps)
@pytest.mark.parametrize(
    ["y_true", "y_pred", "expected", "eps"],
    [
        # perfect match -> Dice 1
        [[1, 1, 1, 1], [1, 1, 1, 1], 1.0, 1e-5],
        [[0, 1, 1, 0], [0, 1, 1, 0], 1.0, 1e-5],
        # Dice = 2*|intersection| / (|a| + |b|) = 2*2 / (4 + 2) = 2/3
        [[1, 1, 1, 1], [1, 1, 0, 0], 2.0 / 3.0, 1e-5],
    ],
)
def test_soft_dice_score(y_true, y_pred, expected, eps):
    """soft_dice_score on hard 0/1 inputs equals the classic Dice
    coefficient 2*|intersection| / (|y_true| + |y_pred|).

    Without the parametrize decorator the arguments were treated as
    missing pytest fixtures and the test errored at collection time.
    """
    y_true = torch.tensor(y_true, dtype=torch.float32)
    y_pred = torch.tensor(y_pred, dtype=torch.float32)
    actual = F.soft_dice_score(y_pred, y_true, eps=eps)
    assert float(actual) == pytest.approx(expected, eps)
def test_dice_loss_binary():
    """DiceLoss in binary mode (inputs are probabilities, not logits).

    Perfect predictions give loss ~0; fully wrong predictions give loss
    ~1, except when the ground truth is empty (all zeros).
    """
    eps = 1e-5
    criterion = DiceLoss(mode=smp.losses.BINARY_MODE, from_logits=False)

    # Ideal case: prediction equals ground truth -> loss 0.
    y_pred = torch.tensor([1.0, 1.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([1, 1, 1]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    y_pred = torch.tensor([1.0, 0.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([1, 0, 1]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([0, 0, 0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    # Worst case: prediction is the exact opposite of the ground truth.
    # Views normalized to (1, 1, 1, -1) to match the ideal cases above;
    # the previous mixed (1, 1, -1) shape flattens identically inside the
    # loss, so behavior is unchanged.
    # NOTE(review): the expected loss here is 0.0, not 1.0 — DiceLoss
    # appears to mask out samples whose ground truth is empty; confirm
    # against the smp implementation.
    y_pred = torch.tensor([1.0, 1.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([0, 0, 0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    y_pred = torch.tensor([1.0, 0.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([0, 1, 0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(1.0, abs=eps)

    y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([1, 1, 1]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(1.0, abs=eps)
def test_binary_jaccard_loss():
    """JaccardLoss in binary mode (inputs are probabilities, not logits).

    Perfect predictions give loss ~0; fully wrong predictions give loss
    ~1, except when the ground truth is empty (all zeros).
    """
    eps = 1e-5
    criterion = JaccardLoss(mode=smp.losses.BINARY_MODE, from_logits=False)

    # Ideal case: prediction equals ground truth -> loss 0.
    y_pred = torch.tensor([1.0]).view(1, 1, 1, 1)
    y_true = torch.tensor([1]).view(1, 1, 1, 1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    y_pred = torch.tensor([1.0, 0.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([1, 0, 1]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([0, 0, 0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    # Worst case: prediction is the exact opposite of the ground truth.
    # Views normalized to (1, 1, 1, -1) to match the ideal cases above;
    # the previous mixed (1, 1, -1) shape flattens identically inside the
    # loss, so behavior is unchanged.
    # NOTE(review): the expected loss here is 0.0, not 1.0 — JaccardLoss
    # appears to mask out samples whose ground truth is empty; confirm
    # against the smp implementation.
    y_pred = torch.tensor([1.0, 1.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([0, 0, 0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    # approx(..., abs=eps) used throughout; the original passed eps
    # positionally here (relative tolerance), inconsistently with the
    # other assertions. For expected 1.0 the outcome is the same.
    y_pred = torch.tensor([1.0, 0.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([0, 1, 0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(1.0, abs=eps)

    y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([1, 1, 1]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(1.0, abs=eps)
def test_multiclass_jaccard_loss():
    """JaccardLoss in multiclass mode on a 1x2x4 per-class probability map
    against integer labels."""
    eps = 1e-5
    criterion = JaccardLoss(mode=smp.losses.MULTICLASS_MODE, from_logits=False)

    # (class probabilities, integer labels, expected loss)
    cases = [
        # perfect prediction -> zero loss
        (
            [[[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]],
            [[0, 0, 1, 1]],
            0.0,
        ),
        # completely inverted labels -> unit loss
        (
            [[[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]],
            [[1, 1, 0, 0]],
            1.0,
        ),
        # half the pixels right per class -> loss 1 - 1/3
        (
            [[[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]],
            [[1, 1, 0, 0]],
            1.0 - 1.0 / 3.0,
        ),
    ]
    for probs, labels, expected in cases:
        loss = criterion(torch.tensor(probs), torch.tensor(labels))
        assert float(loss) == pytest.approx(expected, abs=eps)
def test_multilabel_jaccard_loss():
    """JaccardLoss in multilabel mode on 1x2x4 binary channel masks."""
    eps = 1e-5
    criterion = JaccardLoss(mode=smp.losses.MULTILABEL_MODE, from_logits=False)

    # Perfect match -> zero loss.
    perfect = torch.tensor([[[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]])
    loss = criterion(perfect, perfect.clone())
    assert float(loss) == pytest.approx(0.0, abs=eps)

    # Complete mismatch (target is the complement) -> unit loss.
    inverted_target = 1 - perfect
    loss = criterion(perfect, inverted_target)
    assert float(loss) == pytest.approx(1.0, abs=eps)

    # One overlapping pixel out of three in the union, per channel
    # -> loss 1 - 1/3.
    prediction = torch.tensor([[[0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0]]])
    target = torch.tensor([[[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]])
    loss = criterion(prediction, target)
    assert float(loss) == pytest.approx(1.0 - 1.0 / 3.0, abs=eps)
def test_soft_ce_loss():
    """SoftCrossEntropyLoss with label smoothing and an ignored target.

    The original test only printed the loss and asserted nothing; it now
    checks that the -100 row does not poison the result and that the
    loss is a valid non-negative finite scalar.
    """
    criterion = SoftCrossEntropyLoss(smooth_factor=0.1, ignore_index=-100)
    # Near-one-hot logits; row 2 carries the ignore_index target.
    y_pred = torch.tensor(
        [[+9, -9, -9, -9], [-9, +9, -9, -9], [-9, -9, +9, -9], [-9, -9, -9, +9]]
    ).float()
    y_true = torch.tensor([0, 1, -100, 3]).long()
    loss = criterion(y_pred, y_true)
    assert torch.isfinite(loss)
    assert float(loss) >= 0.0
def test_soft_bce_loss():
    """SoftBCEWithLogitsLoss with label smoothing and an ignored target.

    The original test only printed the loss and asserted nothing; it now
    checks that the -100 position does not poison the result and that
    the loss is a valid non-negative finite scalar.
    """
    criterion = SoftBCEWithLogitsLoss(smooth_factor=0.1, ignore_index=-100)
    # Confident logits; index 2 carries the ignore_index target.
    y_pred = torch.tensor([-9, 9, 1, 9, -9]).float()
    y_true = torch.tensor([0, 1, -100, 1, 0]).long()
    loss = criterion(y_pred, y_true)
    assert torch.isfinite(loss)
    assert float(loss) >= 0.0