"""Quick test that model creation, dataset loading, and loss work correctly.

Checks, in order: U-Net++ construction, a forward pass, Dice+Focal loss with
backward, the 30->5 label remap table, train/val transforms, the dataset class
on synthetic files, and a rough VRAM budget. Every check is a plain ``assert``,
so run without ``python -O`` (which strips asserts).
"""
import sys
import os
import numpy as np
import torch

print("=" * 60)
print("TESTING: Model + Dataset + Loss Setup")
print("=" * 60)

# Test 1: Model creation
print("\n1. Creating U-Net++ + EfficientNet-B4...")
import segmentation_models_pytorch as smp

model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights=None,  # Don't download in test
    in_channels=1,
    classes=5,
    decoder_attention_type="scse",
)
params = sum(p.numel() for p in model.parameters())
print(f"   ✅ Model created: {params:,} parameters")

# Test 2: Forward pass
print("\n2. Forward pass test...")
# eval(): disable dropout and keep BatchNorm running stats from being
# updated with random noise during this inference-style check.
model.eval()
x = torch.randn(1, 1, 512, 512)
with torch.no_grad():
    y = model(x)
assert y.shape == (1, 5, 512, 512), f"Wrong shape: {y.shape}"
print(f"   ✅ Input {x.shape} → Output {y.shape}")

# Test 3: Loss functions
print("\n3. Loss functions...")
dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True, smooth=1.0)
focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=2.0)
pred = torch.randn(2, 5, 64, 64, requires_grad=True)
target = torch.randint(0, 5, (2, 64, 64))
loss = 0.5 * dice_loss(pred, target) + 0.5 * focal_loss(pred, target)
loss.backward()
print(f"   ✅ Dice+Focal loss: {loss.item():.4f}, backward OK")

# Test 4: Label remap
print("\n4. Label remapping (30→5 classes)...")
from train import LABEL_REMAP

assert LABEL_REMAP[0] == 0   # background
assert LABEL_REMAP[9] == 1   # busbar
assert LABEL_REMAP[14] == 2  # crack
assert LABEL_REMAP[10] == 2  # crack_rbn_edge → crack
assert LABEL_REMAP[11] == 3  # inactive → dark
assert LABEL_REMAP[17] == 3  # dead_cell → dark
assert LABEL_REMAP[20] == 3  # edge_dark → dark
assert LABEL_REMAP[12] == 4  # rings → other
print(f"   ✅ Remap table verified: {np.unique(LABEL_REMAP)}")

# Test 5: Augmentation pipeline
print("\n5. Augmentation pipeline...")
from train import get_train_transforms, get_val_transforms

train_tf = get_train_transforms(512)
val_tf = get_val_transforms(512)

# Simulate a 512x512 grayscale image + mask.
# randint's upper bound is exclusive, so 256 covers the full 0..255 range.
fake_img = np.random.randint(0, 256, (512, 512), dtype=np.uint8).astype(np.float32)
fake_mask = np.random.randint(0, 5, (512, 512), dtype=np.uint8)

aug = train_tf(image=fake_img, mask=fake_mask)
assert aug["image"].shape == torch.Size([1, 512, 512]), f"Wrong aug img shape: {aug['image'].shape}"
assert aug["mask"].shape == torch.Size([512, 512]), f"Wrong aug mask shape: {aug['mask'].shape}"
print(f"   ✅ Train transform: img {aug['image'].shape}, mask {aug['mask'].shape}")

# The model consumes both pipelines, so val output must have the same shapes.
aug = val_tf(image=fake_img, mask=fake_mask)
assert aug["image"].shape == torch.Size([1, 512, 512]), f"Wrong val img shape: {aug['image'].shape}"
assert aug["mask"].shape == torch.Size([512, 512]), f"Wrong val mask shape: {aug['mask'].shape}"
print(f"   ✅ Val transform: img {aug['image'].shape}, mask {aug['mask'].shape}")

# Test 6: Dataset class
print("\n6. Dataset class (with fake data)...")
import tempfile
from PIL import Image as PILImage

with tempfile.TemporaryDirectory() as tmpdir:
    img_dir = os.path.join(tmpdir, "images")
    mask_dir = os.path.join(tmpdir, "masks")
    os.makedirs(img_dir)
    os.makedirs(mask_dir)

    # Mask with valid E-SCDD labels (identical for every sample).
    mask_arr = np.zeros((512, 512), dtype=np.uint8)
    mask_arr[100:200, 100:200] = 14  # crack
    mask_arr[200:300, 200:300] = 11  # inactive
    mask_arr[50:60, :] = 9           # busbar

    # Create fake RGBA images and L masks. Pillow infers the mode from the
    # array: uint8 (H, W, 4) -> "RGBA", uint8 (H, W) -> "L". The explicit
    # ``mode=`` argument to fromarray is deprecated in recent Pillow.
    for i in range(3):
        rgba = np.random.randint(0, 256, (512, 512, 4), dtype=np.uint8)
        PILImage.fromarray(rgba).save(os.path.join(img_dir, f"test_{i}.png"))
        PILImage.fromarray(mask_arr).save(os.path.join(mask_dir, f"test_{i}.png"))

    from train import ESCDDDataset

    # Shape check through the (random) training augmentations.
    ds = ESCDDDataset(img_dir, mask_dir, transform=get_train_transforms(512))
    img, mask = ds[0]
    assert img.shape == (1, 512, 512), f"Wrong dataset img shape: {img.shape}"
    assert mask.shape == (512, 512), f"Wrong dataset mask shape: {mask.shape}"

    # Remap check through the validation transforms: a random crop/resize in
    # the training pipeline could drop a labelled region and make this flaky.
    # NOTE(review): assumes get_val_transforms(512) does not crop — confirm in train.py.
    val_ds = ESCDDDataset(img_dir, mask_dir, transform=get_val_transforms(512))
    _, val_mask = val_ds[0]
    unique_classes = torch.unique(val_mask).tolist()
    print(f"   ✅ Dataset OK: img {img.shape}, mask {mask.shape}, classes: {unique_classes}")

    # Verify remapping
    assert 0 in unique_classes, f"background missing: {unique_classes}"
    assert 2 in unique_classes, f"crack (was 14) missing: {unique_classes}"
    assert 3 in unique_classes, f"dark (was 11) missing: {unique_classes}"
    assert 1 in unique_classes, f"busbar (was 9) missing: {unique_classes}"
    print(f"   ✅ Class remapping verified in dataset output")

# Test 7: VRAM estimate
print("\n7. VRAM estimate for RTX 4060...")
# Rough fp32 training budget with AdamW:
#   weights (1x) + gradients (1x) + AdamW exp_avg/exp_avg_sq (2x) = 4x param memory
#   + activations for UNet++ at 512x512, bs=4 with AMP (rough guess, not measured).
# "MB" below means MiB (2**20 bytes) so it is comparable with the 8192 figure.
MIB = 1024 ** 2
BYTES_PER_PARAM = 4        # fp32
ACTIVATIONS_MB = 3000      # rough estimate
VRAM_MB = 8192             # RTX 4060: 8 GB

weights_mb = params * BYTES_PER_PARAM / MIB
grads_mb = weights_mb
optim_mb = weights_mb * 2
total_mb = weights_mb + grads_mb + optim_mb + ACTIVATIONS_MB
headroom_mb = VRAM_MB - total_mb

print(f"   Model weights: {weights_mb:.0f} MB")
print(f"   Gradients: ~{grads_mb:.0f} MB")
print(f"   Optimizer states (AdamW): ~{optim_mb:.0f} MB")
print(f"   Estimated activations (bs=4, AMP): ~{ACTIVATIONS_MB} MB")
print(f"   Estimated total: ~{total_mb:.0f} MB")
print(f"   RTX 4060 VRAM: {VRAM_MB} MB")
# Only an estimate, so warn instead of failing the run.
if headroom_mb > 0:
    print(f"   ✅ Fits with ~{headroom_mb:.0f} MB headroom")
else:
    print(f"   ⚠️ Estimated total exceeds VRAM by ~{-headroom_mb:.0f} MB")

print("\n" + "=" * 60)
print("ALL TESTS PASSED ✅")
print("=" * 60)
print("\nYou can now run: python train.py")