File size: 5,052 Bytes
481b13c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Quick test that model creation, dataset loading, and loss work correctly."""
import sys
import os
import numpy as np
import torch

# Opening banner: a 60-char rule, the run title, and another rule, one per line.
print("\n".join(["=" * 60, "TESTING: Model + Dataset + Loss Setup", "=" * 60]))

# Test 1: Model creation
print("\n1. Creating U-Net++ + EfficientNet-B4...")
import segmentation_models_pytorch as smp

# U-Net++ with an EfficientNet-B4 encoder, 1 input channel, 5 output classes and
# scSE decoder attention (presumably the same configuration train.py builds —
# keep the two in sync). encoder_weights=None skips the pretrained-weight download.
model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights=None,  # Don't download in test
    in_channels=1,
    classes=5,
    decoder_attention_type="scse",
)
params = sum(p.numel() for p in model.parameters())
assert params > 0, "Model has no parameters"
print(f"   ✅ Model created: {params:,} parameters")

# Test 2: Forward pass
print("\n2. Forward pass test...")
# Inference-mode check of output geometry only: eval() makes BatchNorm use its
# running statistics (instead of updating them from random input) and disables
# dropout-style layers, so the forward pass is deterministic.
model.eval()
x = torch.randn(1, 1, 512, 512)
with torch.no_grad():
    y = model(x)
# Expect one logit map per class at the full input resolution.
assert y.shape == (1, 5, 512, 512), f"Wrong shape: {y.shape}"
print(f"   ✅ Input {x.shape} → Output {y.shape}")

# Test 3: Loss functions
print("\n3. Loss functions...")
# Equal-weight Dice + Focal on raw logits. Multiclass mode takes integer
# class-index targets of shape (N, H, W), as built below.
dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True, smooth=1.0)
focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=2.0)

pred = torch.randn(2, 5, 64, 64, requires_grad=True)
target = torch.randint(0, 5, (2, 64, 64))

loss = 0.5 * dice_loss(pred, target) + 0.5 * focal_loss(pred, target)
assert torch.isfinite(loss), f"Non-finite loss: {loss.item()}"
loss.backward()
# Confirm gradients actually reached the logits, not merely that backward() returned.
assert pred.grad is not None, "No gradient on logits after backward()"
assert torch.isfinite(pred.grad).all(), "Non-finite gradient on logits"
print(f"   ✅ Dice+Focal loss: {loss.item():.4f}, backward OK")

# Test 4: Label remap
print("\n4. Label remapping (30→5 classes)...")
from train import LABEL_REMAP

# Spot-check the dataset-label → training-class mapping.
# Original label: (expected class, meaning).
expected_remap = {
    0: (0, "background"),
    9: (1, "busbar"),
    14: (2, "crack"),
    10: (2, "crack_rbn_edge → crack"),
    11: (3, "inactive → dark"),
    17: (3, "dead_cell → dark"),
    20: (3, "edge_dark → dark"),
    12: (4, "rings → other"),
}
for src_label, (dst_class, desc) in expected_remap.items():
    got = LABEL_REMAP[src_label]
    # Name the failing entry instead of raising a bare AssertionError.
    assert got == dst_class, f"LABEL_REMAP[{src_label}] ({desc}) = {got}, expected {dst_class}"
print(f"   ✅ Remap table verified: {np.unique(LABEL_REMAP)}")

# Test 5: Augmentation pipeline
print("\n5. Augmentation pipeline...")
from train import get_train_transforms, get_val_transforms

train_tf = get_train_transforms(512)
val_tf = get_val_transforms(512)

# Simulate a 512x512 grayscale image (float32 over the full 0-255 range) and a
# 5-class mask. randint's upper bound is exclusive, so 256 is needed to reach 255.
fake_img = np.random.randint(0, 256, (512, 512), dtype=np.uint8).astype(np.float32)
fake_mask = np.random.randint(0, 5, (512, 512), dtype=np.uint8)

aug = train_tf(image=fake_img, mask=fake_mask)
assert aug["image"].shape == torch.Size([1, 512, 512]), f"Wrong aug img shape: {aug['image'].shape}"
assert aug["mask"].shape == torch.Size([512, 512]), f"Wrong aug mask shape: {aug['mask'].shape}"
print(f"   ✅ Train transform: img {aug['image'].shape}, mask {aug['mask'].shape}")

# The validation pipeline must produce the same geometry the model and loss expect.
aug = val_tf(image=fake_img, mask=fake_mask)
assert aug["image"].shape == torch.Size([1, 512, 512]), f"Wrong val img shape: {aug['image'].shape}"
assert aug["mask"].shape == torch.Size([512, 512]), f"Wrong val mask shape: {aug['mask'].shape}"
print(f"   ✅ Val transform: img {aug['image'].shape}, mask {aug['mask'].shape}")

# Test 6: Dataset class
print("\n6. Dataset class (with fake data)...")
import tempfile
from PIL import Image as PILImage

with tempfile.TemporaryDirectory() as tmpdir:
    img_dir = os.path.join(tmpdir, "images")
    mask_dir = os.path.join(tmpdir, "masks")
    os.makedirs(img_dir)
    os.makedirs(mask_dir)

    # Create fake RGBA images and single-channel masks with matching file names.
    # Pillow infers the mode from shape/dtype ((H, W, 4) uint8 -> RGBA,
    # (H, W) uint8 -> L); passing mode= explicitly is deprecated since Pillow 11.3.
    for i in range(3):
        img = PILImage.fromarray(np.random.randint(0, 256, (512, 512, 4), dtype=np.uint8))
        img.save(os.path.join(img_dir, f"test_{i}.png"))

        # Mask with valid E-SCDD labels
        mask_arr = np.zeros((512, 512), dtype=np.uint8)
        mask_arr[100:200, 100:200] = 14  # crack
        mask_arr[200:300, 200:300] = 11  # inactive
        mask_arr[50:60, :] = 9           # busbar
        PILImage.fromarray(mask_arr).save(os.path.join(mask_dir, f"test_{i}.png"))

    from train import ESCDDDataset

    # Shape check through the (randomly augmented) training pipeline.
    ds = ESCDDDataset(img_dir, mask_dir, transform=get_train_transforms(512))
    assert len(ds) == 3, f"Expected 3 samples, found {len(ds)}"
    img, mask = ds[0]
    assert img.shape == (1, 512, 512), f"Wrong dataset img shape: {img.shape}"
    assert mask.shape == (512, 512), f"Wrong dataset mask shape: {mask.shape}"

    # Remap check through the validation pipeline (presumably deterministic).
    # Random geometric augmentation could crop or shift a small region (e.g. the
    # 10-row busbar strip) out of frame and make the membership asserts flaky.
    val_ds = ESCDDDataset(img_dir, mask_dir, transform=get_val_transforms(512))
    _, val_mask = val_ds[0]
    unique_classes = torch.unique(val_mask).tolist()
    print(f"   ✅ Dataset OK: img {img.shape}, mask {mask.shape}, classes: {unique_classes}")

    # Verify remapping
    assert 0 in unique_classes, f"background missing: {unique_classes}"
    assert 2 in unique_classes, f"crack (was 14) missing: {unique_classes}"
    assert 3 in unique_classes, f"dark (was 11) missing: {unique_classes}"
    assert 1 in unique_classes, f"busbar (was 9) missing: {unique_classes}"
    print(f"   ✅ Class remapping verified in dataset output")

# Test 7: VRAM estimate
print("\n7. VRAM estimate for RTX 4060...")
# Rough training-memory budget for the model built in test 1 (an estimate, not a measurement):
#   weights       : params × 4 bytes (fp32)
#   gradients     : same size as the weights (one fp32 grad per parameter)
#   AdamW states  : 2 × weights (exp_avg + exp_avg_sq)
#   activations   : UNet++ at 512x512, bs=4 with AMP, assumed ~2.5-3.5 GB (3000 MB used)
# Autocast lowers activation precision only; parameters, gradients and optimizer
# states stay fp32. Sizes use binary MB (MiB) so they compare directly with 8192 MB (8 GiB).
bytes_per_mb = 1024 ** 2
vram_mb = 8192
activations_mb = 3000
model_mb = params * 4 / bytes_per_mb
total_mb = model_mb * 4 + activations_mb  # weights + grads + 2 AdamW states + activations
headroom_mb = vram_mb - total_mb
print(f"   Model weights: {model_mb:.0f} MB")
print(f"   Gradients: ~{model_mb:.0f} MB")
print(f"   Optimizer states (AdamW): ~{model_mb * 2:.0f} MB")
print(f"   Estimated activations (bs=4, AMP): ~{activations_mb} MB")
print(f"   Estimated total: ~{total_mb:.0f} MB")
print(f"   RTX 4060 VRAM: {vram_mb} MB")
if headroom_mb > 0:
    print(f"   ✅ Fits with ~{headroom_mb:.0f} MB headroom")
else:
    # Report the overrun instead of unconditionally claiming the model fits.
    print(f"   ⚠️ Estimate exceeds VRAM by ~{-headroom_mb:.0f} MB")

# Closing banner; reached only if every assert above passed.
print("\n" + "=" * 60)
print("ALL TESTS PASSED ✅")
print("=" * 60)
print("\nYou can now run: python train.py")