File size: 1,725 Bytes
eef8873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import logging

from src.data.dataset import (
    create_resnet_dataloaders,
    create_fusion_dataloaders
)

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def test_dataset():
    """Smoke-check one batch from the ResNet and Fusion dataloaders.

    Each factory is called with no arguments and must return a two-item
    result; the first item is iterated once and the batch it yields is
    validated:

    * ResNet: the batch unpacks into ``(images, labels)``; the per-sample
      image shape must be ``(3, 128, 128)`` and labels must be
      one-dimensional.
    * Fusion: the batch must contain ``pixel_values_eff``,
      ``pixel_values_cnx`` and ``labels``; the per-sample shapes must be
      ``(3, 260, 260)`` and ``(3, 224, 224)`` respectively, and labels
      must be one-dimensional.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    logger.info("Testing dataset loaders...")

    # ---------------- ResNet ----------------
    resnet_image_dims = (3, 128, 128)

    resnet_batches, _ = create_resnet_dataloaders()
    images, labels = next(iter(resnet_batches))

    # shape[1:] drops the leading (batch) dimension before comparing.
    assert images.shape[1:] == resnet_image_dims, (
        f"Unexpected ResNet image shape: {images.shape}"
    )
    assert len(labels.shape) == 1, (
        f"Unexpected ResNet labels shape: {labels.shape}"
    )

    logger.info("ResNet dataloader test passed.")

    # ---------------- Fusion ----------------
    eff_image_dims = (3, 260, 260)
    cnx_image_dims = (3, 224, 224)
    # (key, failure message) pairs, checked in this order.
    required_keys = (
        ("pixel_values_eff", "Missing EfficientNet input"),
        ("pixel_values_cnx", "Missing ConvNeXt input"),
        ("labels", "Missing labels"),
    )

    fusion_batches, _ = create_fusion_dataloaders()
    batch = next(iter(fusion_batches))

    for key, missing_msg in required_keys:
        assert key in batch, missing_msg

    assert batch["pixel_values_eff"].shape[1:] == eff_image_dims, (
        f"Unexpected Fusion EfficientNet shape: {batch['pixel_values_eff'].shape}"
    )
    assert batch["pixel_values_cnx"].shape[1:] == cnx_image_dims, (
        f"Unexpected Fusion ConvNeXt shape: {batch['pixel_values_cnx'].shape}"
    )
    assert len(batch["labels"].shape) == 1, (
        f"Unexpected Fusion labels shape: {batch['labels'].shape}"
    )

    logger.info("Fusion dataloader test passed.")
    logger.info("Dataset test passed successfully.")


if __name__ == "__main__":
    # Script entry point: enable INFO-level log output, run the dataloader
    # checks, then report completion on stdout.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    test_dataset()
    print("Dataset test completed successfully.")