File size: 1,047 Bytes
eef8873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import logging
from PIL import Image

from src.data.augmentation import (
    get_resnet_train_transforms,
    get_fusion_train_transforms
)

# Module-level logger named after this module; test_augmentation reports progress through it.
logger = logging.getLogger(__name__)


def test_augmentation():
    """Smoke-test the ResNet and Fusion train-time transform pipelines.

    A blank 300x300 RGB PIL image is pushed through each pipeline built by
    ``get_resnet_train_transforms`` and ``get_fusion_train_transforms``.
    The ``.shape`` of each output must equal (3, 128, 128) for ResNet and
    (3, 260, 260) for Fusion.

    Raises:
        AssertionError: if either pipeline produces an output of an
            unexpected shape.
    """
    logger.info("Testing augmentation pipelines...")

    # Content is irrelevant here; only the output geometry is checked.
    sample = Image.new("RGB", (300, 300))

    resnet_pipeline, fusion_pipeline = (
        get_resnet_train_transforms(),
        get_fusion_train_transforms(),
    )

    resnet_out = resnet_pipeline(sample)
    fusion_out = fusion_pipeline(sample)

    expected_resnet = (3, 128, 128)
    expected_fusion = (3, 260, 260)

    assert resnet_out.shape == expected_resnet, \
        f"Unexpected ResNet shape: {resnet_out.shape}"
    assert fusion_out.shape == expected_fusion, \
        f"Unexpected Fusion shape: {fusion_out.shape}"

    logger.info("Augmentation test passed.")


if __name__ == "__main__":
    # When run directly (not via pytest), configure root logging at INFO
    # so the logger.info progress messages from the test are emitted.
    _log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=_log_format, level=logging.INFO)

    test_augmentation()
    print("Augmentation test completed successfully.")