File size: 2,910 Bytes
090a270
 
f2a237f
090a270
f2a237f
090a270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# tests/test_preprocessing/test_augment.py
import torch

from app.preprocessing import rotation
from app.preprocessing.augment import get_augmentation_functions


class TestGetAugmentationFunctions:
    """Contract checks for the list returned by get_augmentation_functions()."""

    def test_returns_list(self):
        # The factory is expected to hand back a plain list of exactly four augmentations.
        augmentations = get_augmentation_functions()
        assert isinstance(augmentations, list)
        assert len(augmentations) == 4

    def test_all_callable(self):
        # Every entry must be invocable so it can be applied to an image.
        augmentations = get_augmentation_functions()
        assert all(callable(fn) for fn in augmentations)


class TestAugmentationEffect:
    """Test that augmentation functions actually transform the image."""

    @staticmethod
    def _asymmetric_image(size: int = 64) -> torch.Tensor:
        """Return a 3-channel (C, H, W) image with no horizontal or vertical symmetry.

        Every pixel in a channel holds a distinct value (0 .. size*size - 1),
        so mirroring along height or width always yields a different tensor.
        """
        plane = torch.arange(size * size, dtype=torch.float).reshape(1, size, size)
        return plane.expand(3, -1, -1)

    def test_identity_preserves_image(self):
        """identity() should return a tensor equal (same shape and values) to its input."""
        # Seeded so the test is deterministic across runs.
        generator = torch.Generator().manual_seed(0)
        img = torch.randn(3, 64, 64, generator=generator)
        result = rotation.identity(img)
        assert torch.equal(result, img)

    def test_horizontal_flip_changes_image(self):
        """horizontal_flip should differ from identity for non-symmetric images."""
        img = self._asymmetric_image()
        flipped = rotation.horizontal_flip(img)
        assert not torch.equal(flipped, img)

    def test_horizontal_flip_is_left_right(self):
        """horizontal_flip should mirror along the width dimension."""
        # Only the left-most column is set, so a left/right mirror moves it to the right edge.
        img = torch.zeros(3, 4, 4)
        img[:, :, 0] = 1.0
        flipped = rotation.horizontal_flip(img)
        assert torch.equal(flipped[:, :, -1], torch.ones(3, 4))  # now right column = 1
        assert torch.equal(flipped[:, :, 0], torch.zeros(3, 4))  # left column = 0

    def test_vertical_flip_changes_image(self):
        """vertical_flip should differ from identity for non-symmetric images."""
        img = self._asymmetric_image()
        flipped = rotation.vertical_flip(img)
        assert not torch.equal(flipped, img)

    def test_vertical_flip_is_top_bottom(self):
        """vertical_flip should mirror along the height dimension."""
        # Only the top row is set, so a top/bottom mirror moves it to the bottom edge.
        img = torch.zeros(3, 4, 4)
        img[:, 0, :] = 1.0
        flipped = rotation.vertical_flip(img)
        assert torch.equal(flipped[:, -1, :], torch.ones(3, 4))  # now bottom row = 1
        assert torch.equal(flipped[:, 0, :], torch.zeros(3, 4))  # top row = 0

    def test_double_flip_returns_original(self):
        """horizontal_and_vertical_flip mirrors both axes and is undone by hflip + vflip."""
        img = self._asymmetric_image()
        double_flipped = rotation.horizontal_and_vertical_flip(img)
        # Mirroring height (dim -2) and width (dim -1) together is a 180-degree rotation;
        # pin the output against torch.flip as an independent reference.
        expected = torch.flip(img, dims=(-2, -1))
        assert torch.equal(double_flipped, expected)
        # The input is asymmetric, so the transform must actually change it.
        assert not torch.equal(double_flipped, img)
        # Flips are self-inverse and commute, so applying hflip and vflip again restores
        # the input exactly (flips only permute elements; no float arithmetic involved).
        restored = rotation.horizontal_flip(rotation.vertical_flip(double_flipped))
        assert torch.equal(restored, img)