DamageLensAI / test /test_dataset.py
junaid17's picture
Upload 43 files
eef8873 verified
Raw
History Blame Contribute Delete
1.73 kB
import logging
from src.data.dataset import (
create_resnet_dataloaders,
create_fusion_dataloaders
)
logger = logging.getLogger(__name__)
def _first_batch(loader, name):
    """Return the first batch yielded by *loader*.

    Args:
        loader: Any iterable of batches (e.g. a dataloader).
        name: Human-readable loader name used in the failure message.

    Raises:
        AssertionError: if *loader* yields no batches. A bare ``next()``
            would raise ``StopIteration`` instead, which obscures the cause.
    """
    try:
        return next(iter(loader))
    except StopIteration:
        raise AssertionError(f"{name} dataloader yielded no batches") from None


def test_dataset():
    """Smoke-test the ResNet and fusion dataloaders.

    Takes the first batch from the first loader returned by
    ``create_resnet_dataloaders`` and ``create_fusion_dataloaders``. It then
    checks the following:

    * Image tensors have the expected per-sample shape: (3, 128, 128) for
      ResNet, (3, 260, 260) for EfficientNet and (3, 224, 224) for ConvNeXt.
    * Label tensors are 1-D.
    * Every tensor in a batch has the same batch size.

    Raises:
        AssertionError: if a loader is empty, a fusion batch lacks a required
            key, or any shape check fails.
    """
    logger.info("Testing dataset loaders...")

    # ---------------- ResNet ----------------
    # Only the first returned loader is exercised; the second value
    # (presumably the validation loader — confirm in src.data.dataset) is ignored.
    resnet_loader, _ = create_resnet_dataloaders()
    images, labels = _first_batch(resnet_loader, "ResNet")

    assert images.shape[1:] == (3, 128, 128), \
        f"Unexpected ResNet image shape: {images.shape}"
    assert len(labels.shape) == 1, \
        f"Unexpected ResNet labels shape: {labels.shape}"
    # Both checks above guarantee a leading batch dimension exists, so the
    # image and label counts can be compared directly.
    assert images.shape[0] == labels.shape[0], \
        f"ResNet batch size mismatch: {images.shape[0]} images vs {labels.shape[0]} labels"
    logger.info("ResNet dataloader test passed.")

    # ---------------- Fusion ----------------
    # Fusion batches are dict-like, with one pixel tensor per backbone plus labels.
    fusion_loader, _ = create_fusion_dataloaders()
    batch = _first_batch(fusion_loader, "Fusion")

    assert "pixel_values_eff" in batch, "Missing EfficientNet input"
    assert "pixel_values_cnx" in batch, "Missing ConvNeXt input"
    assert "labels" in batch, "Missing labels"

    eff = batch["pixel_values_eff"]
    cnx = batch["pixel_values_cnx"]
    fusion_labels = batch["labels"]

    assert eff.shape[1:] == (3, 260, 260), \
        f"Unexpected Fusion EfficientNet shape: {batch['pixel_values_eff'].shape}"
    assert cnx.shape[1:] == (3, 224, 224), \
        f"Unexpected Fusion ConvNeXt shape: {batch['pixel_values_cnx'].shape}"
    assert len(fusion_labels.shape) == 1, \
        f"Unexpected Fusion labels shape: {batch['labels'].shape}"
    # Both backbone inputs and the labels must describe the same samples.
    assert eff.shape[0] == cnx.shape[0] == fusion_labels.shape[0], (
        "Fusion batch size mismatch: "
        f"eff={eff.shape[0]}, cnx={cnx.shape[0]}, labels={fusion_labels.shape[0]}"
    )
    logger.info("Fusion dataloader test passed.")

    logger.info("Dataset test passed successfully.")
if __name__ == "__main__":
    # When run directly (outside pytest), configure root logging at INFO so
    # the logger.info progress messages from test_dataset are visible.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    test_dataset()

    # Reached only if every assertion in test_dataset passed.
    print("Dataset test completed successfully.")