File size: 1,767 Bytes
9c2e807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Tests for physics-aware augmentations."""

from __future__ import annotations

import numpy as np
import pytest

from dipauglib.transforms.physics import (
    CastShadow,
    ColourFade,
    ColourTempShift,
    DefocusBlur,
    DustOverlay,
    IlluminationGradient,
    MotionBlur,
    SensorNoise,
)


# Every transform class exercised by the parametrized tests in this module.
# Each is constructed as ``cls(intensity=..., p=...)`` and invoked with
# ``image=`` / ``mask=`` keyword arguments. The list order fixes the order in
# which pytest generates the parameter cases.
TRANSFORMS: list[type] = [
    IlluminationGradient,
    CastShadow,
    MotionBlur,
    DefocusBlur,
    ColourTempShift,
    ColourFade,
    DustOverlay,
    SensorNoise,
]


@pytest.mark.parametrize("transform_cls", TRANSFORMS)
def test_transform_preserves_shape_and_mask(transform_cls):
    """Check that a forced transform keeps the image shape and leaves the mask intact.

    The returned mask is compared with a snapshot taken *before* the call.
    Comparing ``output["mask"]`` with ``mask`` itself would pass vacuously if
    the transform mutated the caller's array in place and handed it back.
    """
    image = np.full((32, 32, 3), 128, dtype=np.uint8)
    mask = np.zeros((32, 32), dtype=np.uint8)
    mask[8:20, 10:22] = 1
    expected_mask = mask.copy()

    # p=1.0 -- presumably the apply probability, so the transform always runs.
    transform = transform_cls(intensity=0.5, p=1.0)
    output = transform(image=image, mask=mask)

    assert output["image"].shape == image.shape
    assert output["mask"].shape == expected_mask.shape
    assert np.array_equal(output["mask"], expected_mask)
    # The caller's mask must not have been modified in place either.
    assert np.array_equal(mask, expected_mask)


@pytest.mark.parametrize("transform_cls", TRANSFORMS)
def test_transform_handles_all_black_image(transform_cls):
    """Check that maximum intensity on an all-zero image does not break the transform.

    An all-zero image is a degenerate input for any computation that divides
    by image content. Besides shapes, the test therefore checks that:

    * the output is finite (meaningful if a transform returns floats; always
      true for integer output), and
    * the mask is left untouched, matching
      ``test_transform_preserves_shape_and_mask``.
    """
    image = np.zeros((24, 24, 3), dtype=np.uint8)
    mask = np.zeros((24, 24), dtype=np.uint8)
    transform = transform_cls(intensity=1.0, p=1.0)
    output = transform(image=image, mask=mask)
    assert output["image"].shape == image.shape
    assert output["mask"].shape == mask.shape
    assert np.all(np.isfinite(output["image"]))
    # Compare against a fresh zero array rather than ``mask`` so an in-place
    # mutation of the input mask cannot make this check pass vacuously.
    assert np.array_equal(output["mask"], np.zeros_like(mask))


@pytest.mark.parametrize("transform_cls", TRANSFORMS)
def test_transform_handles_single_pixel_image(transform_cls):
    """Check that each transform copes with the smallest possible input: a 1x1 image.

    Besides shapes, the test checks that the output is finite and that the mask
    is preserved, matching ``test_transform_preserves_shape_and_mask``.
    """
    image = np.array([[[255, 0, 0]]], dtype=np.uint8)
    mask = np.array([[1]], dtype=np.uint8)
    # Snapshot before the call so in-place mutation of ``mask`` is detected.
    expected_mask = mask.copy()
    transform = transform_cls(intensity=0.2, p=1.0)
    output = transform(image=image, mask=mask)
    assert output["image"].shape == image.shape
    assert output["mask"].shape == mask.shape
    assert np.all(np.isfinite(output["image"]))
    assert np.array_equal(output["mask"], expected_mask)