# wheat-disease-classifier / tests/test_transforms.py
# Uploaded by: abersbail
# Commit message: "Upload tests/test_transforms.py with huggingface_hub"
# Commit: 10694f9 (verified)
"""
Unit tests for DIPAug transforms.
Tests cover:
- Basic functionality (apply to image and mask)
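- Edge cases (all-black image; a single-pixel fixture is defined but not yet exercised)
- Mask consistency (spatially consistent transforms)
- Intensity levels (four settings per transform, e.g. strength 0.2, 0.5, 0.8, 1.0)
- Performance (target ≥120 images/second; the CPU test asserts a conservative >50)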
"""
import pytest
import numpy as np
import cv2
from pathlib import Path
from dipauglib.transforms import (
IlluminationGradient,
CastShadow,
MotionBlur,
DefocusBlur,
ColourTempShift,
ColourFade,
DustOverlay,
SensorNoise,
)
# Helper functions
def create_test_image(height=384, width=384, channels=3):
    """Create a random ``uint8`` test image of shape ``(height, width, channels)``.

    Pixel values are drawn uniformly from the full ``uint8`` range 0-255
    using NumPy's global random state.
    """
    # ``randint``'s upper bound is exclusive: 256 is required for 255
    # (saturated pixels) to ever appear in the test data.
    return np.random.randint(0, 256, (height, width, channels), dtype=np.uint8)
def create_test_mask(height=384, width=384, n_classes=5):
    """Create a random ``uint8`` label mask of shape ``(height, width)``.

    Labels are drawn uniformly from ``0 .. n_classes - 1`` using NumPy's
    global random state. Defaults to 5 classes; pass ``n_classes=2`` for a
    binary mask.

    Raises:
        ValueError: if ``n_classes`` is outside ``[1, 256]`` (the range of
            label counts representable in ``uint8``).
    """
    if not 1 <= n_classes <= 256:
        raise ValueError(f"n_classes must be in [1, 256], got {n_classes}")
    return np.random.randint(0, n_classes, (height, width), dtype=np.uint8)
def create_all_black_image(height=384, width=384):
    """Return a zero-filled 3-channel ``uint8`` image of the given size."""
    shape = (height, width, 3)
    return np.full(shape, 0, dtype=np.uint8)
def create_single_pixel_image():
    """Return a 1x1, 3-channel ``uint8`` image whose pixel is (128, 64, 192)."""
    channels = (128, 64, 192)
    return np.array([[channels]], dtype=np.uint8)
# Test data
@pytest.fixture
def sample_image():
    """Provide a freshly generated random test image (default size 384x384x3)."""
    image = create_test_image()
    return image
@pytest.fixture
def sample_mask():
    """Provide a freshly generated random label mask (default size 384x384)."""
    mask = create_test_mask()
    return mask
@pytest.fixture
def black_image():
    """Provide an all-zero 3-channel image (default size 384x384)."""
    image = create_all_black_image()
    return image
@pytest.fixture
def single_pixel_image():
    """Provide a 1x1 3-channel image for single-pixel edge-case tests.

    NOTE(review): no test visible in this file requests this fixture yet.
    """
    image = create_single_pixel_image()
    return image
# Test all 8 transforms
class TestIlluminationGradient:
    """Tests for the ``IlluminationGradient`` transform."""

    def test_basic(self, sample_image, sample_mask):
        """Shapes are preserved and the mask passes through unchanged."""
        # Snapshot the mask before the call so an in-place modification by
        # the transform cannot make the comparison pass trivially.
        expected_mask = sample_mask.copy()
        aug = IlluminationGradient(p=1.0, angle=180, strength=0.5)
        result = aug(image=sample_image, mask=sample_mask)
        assert result["image"].shape == sample_image.shape
        assert result["mask"].shape == sample_mask.shape
        # Element-wise check: equal sums alone would also accept a mask whose
        # labels had been rearranged.
        np.testing.assert_array_equal(result["mask"], expected_mask)

    def test_intensities(self, sample_image):
        """Each ``strength`` level preserves the image shape."""
        for strength in [0.2, 0.5, 0.8, 1.0]:
            aug = IlluminationGradient(p=1.0, strength=strength)
            result = aug(image=sample_image)
            assert result["image"].shape == sample_image.shape, f"strength={strength}"

    def test_black_image(self, black_image):
        """An all-zero input keeps its shape."""
        aug = IlluminationGradient(p=1.0, strength=0.5)
        result = aug(image=black_image)
        assert result["image"].shape == black_image.shape

    def test_all_angles(self, sample_image):
        """Angles 0 through 360 (both endpoints included) preserve the shape."""
        for angle in [0, 90, 180, 270, 360]:
            aug = IlluminationGradient(p=1.0, angle=angle)
            result = aug(image=sample_image)
            assert result["image"].shape == sample_image.shape, f"angle={angle}"
class TestCastShadow:
    """Shape checks for the ``CastShadow`` transform."""

    def test_basic(self, sample_image, sample_mask):
        """Image and mask keep their shapes (area=0.25, blur_sigma=10)."""
        shadow = CastShadow(p=1.0, area=0.25, blur_sigma=10)
        out = shadow(image=sample_image, mask=sample_mask)
        expected_img_shape = sample_image.shape
        expected_mask_shape = sample_mask.shape
        assert out["image"].shape == expected_img_shape
        assert out["mask"].shape == expected_mask_shape

    def test_intensities(self, sample_image):
        """Every shadow ``area`` in the sweep preserves the image shape."""
        shadow_areas = (0.1, 0.25, 0.4, 0.5)
        for shadow_area in shadow_areas:
            shadow = CastShadow(p=1.0, area=shadow_area)
            shaded = shadow(image=sample_image)["image"]
            assert shaded.shape == sample_image.shape

    def test_black_image(self, black_image):
        """An all-zero input keeps its shape."""
        shaded = CastShadow(p=1.0, area=0.25)(image=black_image)["image"]
        assert shaded.shape == black_image.shape
class TestMotionBlur:
    """Shape checks for the ``MotionBlur`` transform."""

    def test_basic(self, sample_image, sample_mask):
        """Image and mask keep their shapes (kernel_size=15, angle=90)."""
        blur = MotionBlur(p=1.0, kernel_size=15, angle=90)
        out = blur(image=sample_image, mask=sample_mask)
        assert out["image"].shape == sample_image.shape
        assert out["mask"].shape == sample_mask.shape

    def test_intensities(self, sample_image):
        """Each ``kernel_size`` in the sweep preserves the image shape."""
        for ksize in (5, 10, 15, 25):
            blurred = MotionBlur(p=1.0, kernel_size=ksize)(image=sample_image)
            assert blurred["image"].shape == sample_image.shape

    def test_all_angles(self, sample_image):
        """Angles 0 through 180 preserve the image shape."""
        for theta in (0, 45, 90, 135, 180):
            blurred = MotionBlur(p=1.0, angle=theta)(image=sample_image)
            assert blurred["image"].shape == sample_image.shape

    def test_black_image(self, black_image):
        """An all-zero input keeps its shape."""
        blur = MotionBlur(p=1.0, kernel_size=15)
        assert blur(image=black_image)["image"].shape == black_image.shape
class TestDefocusBlur:
    """Shape checks for the ``DefocusBlur`` transform."""

    def test_basic(self, sample_image, sample_mask):
        """Image and mask keep their shapes with ``radius=8``."""
        defocus = DefocusBlur(p=1.0, radius=8)
        outputs = defocus(image=sample_image, mask=sample_mask)
        assert outputs["image"].shape == sample_image.shape
        assert outputs["mask"].shape == sample_mask.shape

    def test_intensities(self, sample_image):
        """Every ``radius`` in the sweep preserves the image shape."""
        radii = (3, 5, 10, 15)
        for r in radii:
            defocused = DefocusBlur(p=1.0, radius=r)(image=sample_image)["image"]
            assert defocused.shape == sample_image.shape

    def test_black_image(self, black_image):
        """An all-zero input keeps its shape."""
        defocused = DefocusBlur(p=1.0, radius=8)(image=black_image)["image"]
        assert defocused.shape == black_image.shape
class TestColourTempShift:
    """Shape checks for the ``ColourTempShift`` transform."""

    def test_basic(self, sample_image, sample_mask):
        """Image and mask keep their shapes at ``cct_kelvin=5500``."""
        shift = ColourTempShift(p=1.0, cct_kelvin=5500)
        shifted = shift(image=sample_image, mask=sample_mask)
        assert shifted["image"].shape == sample_image.shape
        assert shifted["mask"].shape == sample_mask.shape

    def test_intensities(self, sample_image):
        """Each ``cct_kelvin`` value in the sweep preserves the image shape."""
        for kelvin in (3200, 4500, 6500, 8000):
            shift = ColourTempShift(p=1.0, cct_kelvin=kelvin)
            shifted_image = shift(image=sample_image)["image"]
            assert shifted_image.shape == sample_image.shape

    def test_black_image(self, black_image):
        """An all-zero input keeps its shape."""
        shift = ColourTempShift(p=1.0, cct_kelvin=5500)
        shifted_image = shift(image=black_image)["image"]
        assert shifted_image.shape == black_image.shape
class TestColourFade:
    """Shape checks for the ``ColourFade`` transform."""

    def test_basic(self, sample_image, sample_mask):
        """Image and mask keep their shapes (sat_factor=-0.5, gamma=1.0)."""
        fade = ColourFade(p=1.0, sat_factor=-0.5, gamma=1.0)
        faded = fade(image=sample_image, mask=sample_mask)
        assert faded["image"].shape == sample_image.shape
        assert faded["mask"].shape == sample_mask.shape

    def test_intensities(self, sample_image):
        """Each ``sat_factor`` in the sweep preserves the image shape."""
        saturation_factors = (-0.3, -0.5, -0.7, -0.9)
        for factor in saturation_factors:
            faded_image = ColourFade(p=1.0, sat_factor=factor)(image=sample_image)["image"]
            assert faded_image.shape == sample_image.shape

    def test_black_image(self, black_image):
        """An all-zero input keeps its shape."""
        fade = ColourFade(p=1.0, sat_factor=-0.5)
        assert fade(image=black_image)["image"].shape == black_image.shape
class TestDustOverlay:
    """Shape checks for the ``DustOverlay`` transform."""

    def test_basic(self, sample_image, sample_mask):
        """Image and mask keep their shapes (n_particles=150, opacity=0.4)."""
        dust = DustOverlay(p=1.0, n_particles=150, opacity=0.4)
        dusted = dust(image=sample_image, mask=sample_mask)
        assert dusted["image"].shape == sample_image.shape
        assert dusted["mask"].shape == sample_mask.shape

    def test_intensities(self, sample_image):
        """Each ``n_particles`` count in the sweep preserves the image shape."""
        for count in (50, 100, 200, 300):
            dust = DustOverlay(p=1.0, n_particles=count)
            dusted_image = dust(image=sample_image)["image"]
            assert dusted_image.shape == sample_image.shape

    def test_black_image(self, black_image):
        """An all-zero input keeps its shape."""
        dusted_image = DustOverlay(p=1.0, n_particles=150)(image=black_image)["image"]
        assert dusted_image.shape == black_image.shape
class TestSensorNoise:
    """Shape checks for the ``SensorNoise`` transform."""

    def test_basic(self, sample_image, sample_mask):
        """Image and mask keep their shapes (sigma=15, jpeg_qf=70)."""
        noise = SensorNoise(p=1.0, sigma=15, jpeg_qf=70)
        noisy = noise(image=sample_image, mask=sample_mask)
        assert noisy["image"].shape == sample_image.shape
        assert noisy["mask"].shape == sample_mask.shape

    def test_intensities(self, sample_image):
        """Each ``sigma`` in the sweep preserves the image shape."""
        sigmas = (5, 10, 20, 30)
        for noise_sigma in sigmas:
            noisy_image = SensorNoise(p=1.0, sigma=noise_sigma)(image=sample_image)["image"]
            assert noisy_image.shape == sample_image.shape

    def test_black_image(self, black_image):
        """An all-zero input keeps its shape."""
        noise = SensorNoise(p=1.0, sigma=15)
        noisy_image = noise(image=black_image)["image"]
        assert noisy_image.shape == black_image.shape
# Performance test
class TestPerformance:
    """Throughput smoke test for a single augmentation."""

    # Documented project target (>=120 img/s). Not asserted here because
    # measured throughput depends on the machine running the test.
    TARGET_IMAGES_PER_SECOND = 120
    # Conservative floor actually enforced by this CPU test.
    MIN_IMAGES_PER_SECOND = 50

    def test_throughput(self, sample_image):
        """Check that MotionBlur throughput exceeds ``MIN_IMAGES_PER_SECOND``.

        Only the augmentation calls are timed. Per-batch input copies are made
        outside the timed region, and one untimed warm-up call keeps
        first-call setup cost out of the measurement.
        """
        import time

        aug = MotionBlur(p=1.0, kernel_size=15)
        batch_size = 32
        n_batches = 10

        aug(image=sample_image.copy())  # warm-up, not timed

        elapsed = 0.0
        for _ in range(n_batches):
            # Fresh copies per batch, in case the transform modifies its
            # input in place; built before the clock starts.
            images = [sample_image.copy() for _ in range(batch_size)]
            start = time.perf_counter()  # monotonic, high-resolution clock
            for img in images:
                aug(image=img)
            elapsed += time.perf_counter() - start

        total_images = batch_size * n_batches
        throughput = total_images / elapsed
        assert throughput > self.MIN_IMAGES_PER_SECOND, (
            f"MotionBlur throughput {throughput:.1f} img/s is below "
            f"{self.MIN_IMAGES_PER_SECOND} img/s "
            f"(project target {self.TARGET_IMAGES_PER_SECOND} img/s)"
        )
# Mask consistency test
class TestMaskConsistency:
    """Checks that image and mask outputs stay spatially aligned."""

    def test_spatial_consistency(self):
        """Verify that the image and mask outputs keep their input spatial sizes.

        This compares shapes only; it does not check that the mask received
        the same pixel-level change as the image.
        """
        # Image with a known pattern: a white square on black.
        img = np.zeros((100, 100, 3), dtype=np.uint8)
        img[20:80, 20:80] = 255
        # Mask labelling the same square.
        mask = np.zeros((100, 100), dtype=np.uint8)
        mask[20:80, 20:80] = 1

        aug = MotionBlur(p=1.0, kernel_size=15, angle=0)
        result = aug(image=img, mask=mask)

        # Image and mask must still share the same height and width...
        assert result["image"].shape[:2] == result["mask"].shape
        # ...and neither may have been resized relative to its input.
        assert result["image"].shape == img.shape
        assert result["mask"].shape == mask.shape
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly
    # exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))