# Source: el-defect-training / test_setup.py
# Uploaded by nithishbasireddy with huggingface_hub (commit 481b13c, verified).
"""Quick test that model creation, dataset loading, and loss work correctly."""
import sys
import os
import numpy as np
import torch
# Every check in this script is a bare `assert`, and `python -O` strips
# asserts. Refuse to run rather than print "ALL TESTS PASSED" without
# having verified anything.
if not __debug__:
    sys.exit("test_setup.py relies on assert statements; run it without -O/-OO.")
print("=" * 60)
print("TESTING: Model + Dataset + Loss Setup")
print("=" * 60)
# Test 1: Model creation
# Builds a U-Net++ with an EfficientNet-B4 encoder for 1-channel input and
# 5 output classes, then records the parameter count (`params` is reused by
# the VRAM estimate further down).
print("\n1. Creating U-Net++ + EfficientNet-B4...")
import segmentation_models_pytorch as smp

model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights=None,  # Don't download in test
    in_channels=1,
    classes=5,
    decoder_attention_type="scse",
)
params = sum(p.numel() for p in model.parameters())
print(f" ✅ Model created: {params:,} parameters")
# Test 2: Forward pass
# Feeds one random 1x512x512 image through `model` and checks the logits have
# shape (batch, classes, H, W) and contain no NaN/inf.
print("\n2. Forward pass test...")
# eval() so BatchNorm uses its running statistics instead of updating them
# from this random batch; torch.no_grad() alone does not stop that update.
model.eval()
x = torch.randn(1, 1, 512, 512)
with torch.no_grad():
    y = model(x)
assert y.shape == (1, 5, 512, 512), f"Wrong shape: {y.shape}"
assert torch.isfinite(y).all().item(), "Forward pass produced non-finite logits"
print(f" ✅ Input {x.shape} → Output {y.shape}")
# Test 3: Loss functions
# Combines smp's multiclass Dice and Focal losses 50/50 on random logits and
# targets, then checks that the loss is finite and that backward() actually
# produces finite gradients on the logits.
print("\n3. Loss functions...")
dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True, smooth=1.0)
focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=2.0)
pred = torch.randn(2, 5, 64, 64, requires_grad=True)
target = torch.randint(0, 5, (2, 64, 64))
loss = 0.5 * dice_loss(pred, target) + 0.5 * focal_loss(pred, target)
assert torch.isfinite(loss).item(), f"Non-finite loss: {loss.item()}"
loss.backward()
# Verify the "backward OK" claim instead of only printing it.
assert pred.grad is not None, "backward() produced no gradient for the logits"
assert torch.isfinite(pred.grad).all().item(), "Non-finite gradients from Dice+Focal loss"
print(f" ✅ Dice+Focal loss: {loss.item():.4f}, backward OK")
# Test 4: Label remap
# Spot-checks train.LABEL_REMAP, which maps raw E-SCDD labels to the 5
# training classes: 0=background, 1=busbar, 2=crack, 3=dark, 4=other.
print("\n4. Label remapping (30→5 classes)...")
from train import LABEL_REMAP

# raw E-SCDD label -> expected training class
EXPECTED_REMAP = {
    0: 0,   # background
    9: 1,   # busbar
    14: 2,  # crack
    10: 2,  # crack_rbn_edge → crack
    11: 3,  # inactive → dark
    17: 3,  # dead_cell → dark
    20: 3,  # edge_dark → dark
    12: 4,  # rings → other
}
for raw_label, expected_class in EXPECTED_REMAP.items():
    actual_class = LABEL_REMAP[raw_label]
    # Name the failing label so a broken table is easy to diagnose.
    assert actual_class == expected_class, (
        f"LABEL_REMAP[{raw_label}] = {actual_class}, expected {expected_class}"
    )
print(f" ✅ Remap table verified: {np.unique(LABEL_REMAP)}")
# Test 5: Augmentation pipeline
# Runs a fake grayscale image + class mask through the train and val
# pipelines and checks both yield a (1, 512, 512) image tensor and a
# (512, 512) mask tensor, i.e. the shapes the model and losses consume.
print("\n5. Augmentation pipeline...")
from train import get_train_transforms, get_val_transforms
train_tf = get_train_transforms(512)
val_tf = get_val_transforms(512)
# Simulate a 512x512 grayscale image + mask.
# randint's upper bound is exclusive: 256 covers the full uint8 range 0-255.
fake_img = np.random.randint(0, 256, (512, 512), dtype=np.uint8).astype(np.float32)
fake_mask = np.random.randint(0, 5, (512, 512), dtype=np.uint8)
aug = train_tf(image=fake_img, mask=fake_mask)
assert aug["image"].shape == torch.Size([1, 512, 512]), f"Wrong aug img shape: {aug['image'].shape}"
assert aug["mask"].shape == torch.Size([512, 512]), f"Wrong aug mask shape: {aug['mask'].shape}"
print(f" ✅ Train transform: img {aug['image'].shape}, mask {aug['mask'].shape}")
# Validation output feeds the same model, so it must have the same shapes.
# NOTE(review): assumes get_val_transforms ends with the same tensor
# conversion as the train pipeline — confirm against train.py.
aug = val_tf(image=fake_img, mask=fake_mask)
assert aug["image"].shape == torch.Size([1, 512, 512]), f"Wrong val img shape: {aug['image'].shape}"
assert aug["mask"].shape == torch.Size([512, 512]), f"Wrong val mask shape: {aug['mask'].shape}"
print(f" ✅ Val transform: img {aug['image'].shape}, mask {aug['mask'].shape}")
# Test 6: Dataset class
# Writes three fake RGBA images plus raw-label masks to a temp dir, loads them
# through train.ESCDDDataset, checks the output tensor shapes, and checks that
# raw labels 9/14/11 come out remapped to classes 1/2/3.
print("\n6. Dataset class (with fake data)...")
import tempfile
from PIL import Image as PILImage

with tempfile.TemporaryDirectory() as tmpdir:
    img_dir = os.path.join(tmpdir, "images")
    mask_dir = os.path.join(tmpdir, "masks")
    os.makedirs(img_dir)
    os.makedirs(mask_dir)

    # One mask with valid raw E-SCDD labels, shared by every sample.
    mask_arr = np.zeros((512, 512), dtype=np.uint8)
    mask_arr[100:200, 100:200] = 14  # crack
    mask_arr[200:300, 200:300] = 11  # inactive
    mask_arr[50:60, :] = 9  # busbar

    # Fake RGBA images and matching L masks. fromarray infers "RGBA" from an
    # (H, W, 4) uint8 array and "L" from (H, W) uint8; the explicit mode=
    # argument is deprecated in recent Pillow. randint's high bound is
    # exclusive, so 256 covers the full 0-255 range.
    for i in range(3):
        rgba = np.random.randint(0, 256, (512, 512, 4), dtype=np.uint8)
        PILImage.fromarray(rgba).save(os.path.join(img_dir, f"test_{i}.png"))
        PILImage.fromarray(mask_arr).save(os.path.join(mask_dir, f"test_{i}.png"))

    from train import ESCDDDataset, get_train_transforms, get_val_transforms

    # Shape check through the full (random) training pipeline.
    ds = ESCDDDataset(img_dir, mask_dir, transform=get_train_transforms(512))
    img, mask = ds[0]
    assert img.shape == (1, 512, 512), f"Wrong dataset img shape: {img.shape}"
    assert mask.shape == (512, 512), f"Wrong dataset mask shape: {mask.shape}"
    unique_classes = torch.unique(mask).tolist()
    print(f" ✅ Dataset OK: img {img.shape}, mask {mask.shape}, classes: {unique_classes}")

    # Verify remapping through the validation pipeline. Random train
    # augmentations (crops/shifts, if train.py uses any) could move the thin
    # busbar strip out of frame and make a presence check flaky; the val
    # pipeline is presumably deterministic — TODO confirm in train.py.
    val_ds = ESCDDDataset(img_dir, mask_dir, transform=get_val_transforms(512))
    _, val_mask = val_ds[0]
    val_classes = torch.unique(val_mask).tolist()
    for cls, label in (
        (0, "background"),
        (2, "crack (was 14)"),
        (3, "dark (was 11)"),
        (1, "busbar (was 9)"),
    ):
        assert cls in val_classes, f"Class {cls} ({label}) missing from remapped mask: {val_classes}"
    print(f" ✅ Class remapping verified in dataset output")
# Test 7: VRAM estimate
# Back-of-the-envelope training-memory budget for the model built in Test 1
# (uses `params`). Everything is in MiB so it compares directly with the
# card's 8 GiB = 8192 MiB. Assuming AdamW with standard torch autocast (fp32
# parameters):
#   weights (fp32)                      params x 4 bytes
#   gradients (fp32)                    params x 4 bytes
#   AdamW state (exp_avg, exp_avg_sq)   params x 8 bytes
#   activations (UNet++, 512x512, bs=4, AMP): rough guess of ~2.5-3.5 GiB
# CUDA context and allocator overhead are not included.
print("\n7. VRAM estimate for RTX 4060...")
BYTES_PER_MIB = 1024 ** 2
weights_mib = params * 4 / BYTES_PER_MIB
grads_mib = weights_mib
optimizer_mib = 2 * weights_mib
activations_mib = 3000  # rough estimate, not measured
vram_mib = 8192  # RTX 4060: 8 GiB
total_mib = weights_mib + grads_mib + optimizer_mib + activations_mib
headroom_mib = vram_mib - total_mib
print(f" Model weights: {weights_mib:.0f} MiB")
print(f" Gradients: ~{grads_mib:.0f} MiB")
print(f" Optimizer states (AdamW): ~{optimizer_mib:.0f} MiB")
print(f" Estimated activations (bs=4, AMP): ~{activations_mib} MiB")
print(f" Estimated total: ~{total_mib:.0f} MiB")
print(f" RTX 4060 VRAM: {vram_mib} MiB")
# Only claim it fits when the estimate actually leaves headroom.
if headroom_mib >= 0:
    print(f" ✅ Fits with ~{headroom_mib:.0f} MiB headroom")
else:
    print(f" ⚠️ Likely does NOT fit: over budget by ~{-headroom_mib:.0f} MiB")
# Reached only if no assertion above failed.
print("\n" + "=" * 60)
print("ALL TESTS PASSED ✅")
print("=" * 60)
print("\nYou can now run: python train.py")