Spaces:
Sleeping
Sleeping
File size: 1,725 Bytes
eef8873 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | import logging
from src.data.dataset import (
create_resnet_dataloaders,
create_fusion_dataloaders
)
# Module-level logger keyed by this module's dotted name; test_dataset
# reports its progress messages through it.
logger = logging.getLogger(name=__name__)
def test_dataset():
    """Smoke-test the ResNet and Fusion dataloaders on their first batch.

    Builds each loader through the project factory functions, takes the
    first loader of each returned pair, pulls one batch from it, and checks
    tensor shapes, required keys, and that every tensor in a batch agrees
    on the leading (batch) dimension.

    Raises:
        AssertionError: if a loader yields no batches, a required key is
            missing, a tensor has an unexpected shape, or tensors within
            one batch disagree on batch size.
    """
    logger.info("Testing dataset loaders...")
    _check_resnet_loader()
    _check_fusion_loader()
    logger.info("Dataset test passed successfully.")


def _first_batch(loader, name):
    """Return the first batch yielded by *loader*, failing clearly if empty."""
    # A bare next(iter(loader)) on an empty loader raises StopIteration with
    # no hint of which loader was empty; a sentinel lets us assert instead.
    missing = object()
    batch = next(iter(loader), missing)
    assert batch is not missing, f"{name} dataloader yielded no batches"
    return batch


def _check_resnet_loader():
    """Check one ResNet batch: images with trailing dims (3, 128, 128), 1-D labels."""
    resnet_loader, _ = create_resnet_dataloaders()
    images, labels = _first_batch(resnet_loader, "ResNet")
    assert images.shape[1:] == (3, 128, 128), \
        f"Unexpected ResNet image shape: {images.shape}"
    assert len(labels.shape) == 1, \
        f"Unexpected ResNet labels shape: {labels.shape}"
    # Every image in the batch needs exactly one label; a misaligned
    # collate would otherwise pass the per-tensor shape checks above.
    assert labels.shape[0] == images.shape[0], (
        f"ResNet batch size mismatch: {images.shape[0]} images vs "
        f"{labels.shape[0]} labels"
    )
    logger.info("ResNet dataloader test passed.")


def _check_fusion_loader():
    """Check one Fusion batch: EfficientNet and ConvNeXt inputs plus labels."""
    fusion_loader, _ = create_fusion_dataloaders()
    batch = _first_batch(fusion_loader, "Fusion")
    assert "pixel_values_eff" in batch, "Missing EfficientNet input"
    assert "pixel_values_cnx" in batch, "Missing ConvNeXt input"
    assert "labels" in batch, "Missing labels"

    eff = batch["pixel_values_eff"]
    cnx = batch["pixel_values_cnx"]
    labels = batch["labels"]
    assert eff.shape[1:] == (3, 260, 260), \
        f"Unexpected Fusion EfficientNet shape: {eff.shape}"
    assert cnx.shape[1:] == (3, 224, 224), \
        f"Unexpected Fusion ConvNeXt shape: {cnx.shape}"
    assert len(labels.shape) == 1, \
        f"Unexpected Fusion labels shape: {labels.shape}"
    # Both backbone inputs and the labels must describe the same samples.
    batch_sizes = {eff.shape[0], cnx.shape[0], labels.shape[0]}
    assert len(batch_sizes) == 1, (
        f"Fusion batch size mismatch: eff={eff.shape[0]}, "
        f"cnx={cnx.shape[0]}, labels={labels.shape[0]}"
    )
    logger.info("Fusion dataloader test passed.")
if __name__ == "__main__":
    # Running the file directly: configure root logging at INFO so the
    # progress messages emitted by test_dataset are shown, then run it.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    test_dataset()
    print("Dataset test completed successfully.")