"""Quick test that model creation, dataset loading, and loss work correctly."""
import sys
import os
import numpy as np
import torch
# Opening banner: a title framed by two 60-character rules.
banner_rule = "=" * 60
print(banner_rule, "TESTING: Model + Dataset + Loss Setup", banner_rule, sep="\n")
# Test 1: Model creation
# Build the segmentation network with randomly initialised encoder weights
# (encoder_weights=None) so this smoke test never triggers a download.
print("\n1. Creating U-Net++ + EfficientNet-B4...")
import segmentation_models_pytorch as smp

model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights=None,  # Don't download in test
    in_channels=1,  # single-channel input
    classes=5,  # number of output classes
    decoder_attention_type="scse",
)

# Total number of scalar parameters; also feeds the VRAM estimate in Test 7.
params = sum(p.numel() for p in model.parameters())
print(f" ✅ Model created: {params:,} parameters")
# Test 2: Forward pass
# Push one random 1x1x512x512 tensor through the model and check that the
# output has 5 channels at the same spatial resolution.
print("\n2. Forward pass test...")
x = torch.randn(1, 1, 512, 512)

# Switch to evaluation mode for the inference check so that any
# normalisation/dropout layers behave deterministically and no running
# statistics are updated by this throw-away batch.
model.eval()
with torch.no_grad():
    y = model(x)

assert y.shape == (1, 5, 512, 512), f"Wrong shape: {y.shape}"
print(f" ✅ Input {x.shape} → Output {y.shape}")
# Test 3: Loss functions
# Combine Dice and Focal loss with equal weights on random logits/targets and
# make sure the result is usable for training (finite value, gradient flows).
print("\n3. Loss functions...")
dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True, smooth=1.0)
focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=2.0)

pred = torch.randn(2, 5, 64, 64, requires_grad=True)  # raw logits, 5 classes
target = torch.randint(0, 5, (2, 64, 64))  # integer class ids in [0, 4]

loss = 0.5 * dice_loss(pred, target) + 0.5 * focal_loss(pred, target)
loss.backward()

# Actually verify what the success message claims instead of only relying on
# backward() not raising.
assert torch.isfinite(loss), f"Loss is not finite: {loss.item()}"
assert pred.grad is not None, "backward() produced no gradient for pred"
print(f" ✅ Dice+Focal loss: {loss.item():.4f}, backward OK")
# Test 4: Label remap
# Spot-check the lookup table that collapses the raw labels into the 5
# training classes.
print("\n4. Label remapping (30→5 classes)...")
from train import LABEL_REMAP

# (raw label, expected class, meaning) -- checked one by one so a failure
# reports exactly which entry is wrong.
expected_remap = [
    (0, 0, "background"),
    (9, 1, "busbar"),
    (14, 2, "crack"),
    (10, 2, "crack_rbn_edge → crack"),
    (11, 3, "inactive → dark"),
    (17, 3, "dead_cell → dark"),
    (20, 3, "edge_dark → dark"),
    (12, 4, "rings → other"),
]
for raw_label, expected_class, meaning in expected_remap:
    actual_class = LABEL_REMAP[raw_label]
    assert actual_class == expected_class, (
        f"LABEL_REMAP[{raw_label}] = {actual_class}, "
        f"expected {expected_class} ({meaning})"
    )
print(f" ✅ Remap table verified: {np.unique(LABEL_REMAP)}")
# Test 5: Augmentation pipeline
# Run the train and validation transforms on a synthetic image/mask pair and
# check the shapes that come out of the train pipeline.
print("\n5. Augmentation pipeline...")
from train import get_train_transforms, get_val_transforms

train_tf = get_train_transforms(512)
val_tf = get_val_transforms(512)

# Simulate a 512x512 grayscale image + mask.
# np.random.randint excludes the upper bound, so 256 is needed for the
# image to cover the full uint8 range 0..255.
fake_img = np.random.randint(0, 256, (512, 512), dtype=np.uint8).astype(np.float32)
fake_mask = np.random.randint(0, 5, (512, 512), dtype=np.uint8)  # class ids 0..4

aug = train_tf(image=fake_img, mask=fake_mask)
assert aug["image"].shape == torch.Size([1, 512, 512]), f"Wrong aug img shape: {aug['image'].shape}"
assert aug["mask"].shape == torch.Size([512, 512]), f"Wrong aug mask shape: {aug['mask'].shape}"
print(f" ✅ Train transform: img {aug['image'].shape}, mask {aug['mask'].shape}")

# NOTE(review): the validation output is only printed, not asserted --
# confirm whether it should satisfy the same shape checks as the train output.
aug = val_tf(image=fake_img, mask=fake_mask)
print(f" ✅ Val transform: img {aug['image'].shape}, mask {aug['mask'].shape}")
# Test 6: Dataset class
# Write a few synthetic image/mask PNGs to a temporary directory, load the
# first sample through ESCDDDataset and check shapes and label remapping.
print("\n6. Dataset class (with fake data)...")
import tempfile
from PIL import Image as PILImage

with tempfile.TemporaryDirectory() as tmpdir:
    img_dir = os.path.join(tmpdir, "images")
    mask_dir = os.path.join(tmpdir, "masks")
    os.makedirs(img_dir)
    os.makedirs(mask_dir)

    # Create fake RGBA image and L mask
    for i in range(3):
        # randint's upper bound is exclusive: 256 gives the full 0..255 range.
        img = PILImage.fromarray(
            np.random.randint(0, 256, (512, 512, 4), dtype=np.uint8), mode="RGBA"
        )
        img.save(os.path.join(img_dir, f"test_{i}.png"))

        # Mask with valid E-SCDD labels
        mask_arr = np.zeros((512, 512), dtype=np.uint8)
        mask_arr[100:200, 100:200] = 14  # crack
        mask_arr[200:300, 200:300] = 11  # inactive
        mask_arr[50:60, :] = 9  # busbar
        PILImage.fromarray(mask_arr, mode="L").save(os.path.join(mask_dir, f"test_{i}.png"))

    # Import the transform factory here as well so this section does not
    # silently depend on the import made by the augmentation test.
    from train import ESCDDDataset, get_train_transforms

    ds = ESCDDDataset(img_dir, mask_dir, transform=get_train_transforms(512))
    # The sample must be read while the temporary directory still exists.
    img, mask = ds[0]
    assert img.shape == (1, 512, 512), f"Wrong dataset img shape: {img.shape}"
    assert mask.shape == (512, 512), f"Wrong dataset mask shape: {mask.shape}"

    unique_classes = torch.unique(mask).tolist()
    print(f" ✅ Dataset OK: img {img.shape}, mask {mask.shape}, classes: {unique_classes}")

    # Verify remapping: raw labels written above must show up as their
    # remapped class ids in the dataset output.
    # NOTE(review): this uses the (presumably random) train transforms;
    # confirm they cannot remove a whole class region from the sample.
    for class_id, meaning in [
        (0, "background"),
        (2, "crack (was 14)"),
        (3, "dark (was 11)"),
        (1, "busbar (was 9)"),
    ]:
        assert class_id in unique_classes, (
            f"Class {class_id} ({meaning}) missing from dataset mask: {unique_classes}"
        )
    print(" ✅ Class remapping verified in dataset output")
# Test 7: VRAM estimate
# Back-of-the-envelope memory budget for training on an 8 GB card.
print("\n7. VRAM estimate for RTX 4060...")
# Model params: ~20.9M × 4 bytes = 83.6 MB
# Optimizer states: ×2 = 167.2 MB
# Activations for UNet++ at 512x512, bs=4 with AMP: ~2.5-3.5 GB
# Total: ~3.0-4.0 GB → fits 8GB easily
ACTIVATIONS_MB = 3000  # assumed activation footprint (bs=4, AMP)
VRAM_MB = 8192  # RTX 4060

# 4 bytes per parameter; despite its name this value is in megabytes (1e6 B).
model_bytes = params * 4 / 1e6
# Weights (1x) + optimizer states (2x) + activations.
# NOTE(review): gradient storage is not part of this estimate.
total_mb = model_bytes * 3 + ACTIVATIONS_MB
headroom_mb = VRAM_MB - model_bytes * 3 - ACTIVATIONS_MB

print(f" Model weights: {model_bytes:.0f} MB")
print(f" Optimizer states (AdamW): ~{model_bytes * 2:.0f} MB")
print(f" Estimated activations (bs=4, AMP): ~{ACTIVATIONS_MB} MB")
print(f" Estimated total: ~{total_mb:.0f} MB")
print(f" RTX 4060 VRAM: {VRAM_MB} MB")
# Only claim the model fits when the estimate actually leaves headroom.
if headroom_mb > 0:
    print(f" ✅ Fits with ~{headroom_mb:.0f} MB headroom")
else:
    print(f" ⚠️ Estimated to exceed VRAM by ~{-headroom_mb:.0f} MB")
# Closing banner; reached only if none of the assertions above failed.
print("\n" + "=" * 60)
print("ALL TESTS PASSED ✅")
print("=" * 60)
print("\nYou can now run: python train.py")