| """Quick test that model creation, dataset loading, and loss work correctly.""" |
| import sys |
| import os |
| import numpy as np |
| import torch |
|
|
# Opening banner: a 60-character rule above and below the title so the start
# of the smoke test is easy to spot in console output.
print("=" * 60, "TESTING: Model + Dataset + Loss Setup", "=" * 60, sep="\n")
|
|
| |
# --- Step 1: build the segmentation model ---------------------------------
print("\n1. Creating U-Net++ + EfficientNet-B4...")
# Imported after the step banner so an import failure is clearly attributed
# to this step in the console output.
import segmentation_models_pytorch as smp

# encoder_weights=None means the encoder is randomly initialised (no
# pretrained weights are loaded); only the architecture is exercised here.
model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights=None,
    in_channels=1,
    classes=5,
    decoder_attention_type="scse",
)
# Total number of scalar parameters; reused later for the VRAM estimate.
params = sum(p.numel() for p in model.parameters())
print(f" ✅ Model created: {params:,} parameters")
|
|
| |
# --- Step 2: forward pass shape check --------------------------------------
print("\n2. Forward pass test...")
# torch.no_grad() only disables autograd. Switch to eval mode as well so
# dropout is inactive and BatchNorm does not update its running statistics
# from this random probe input.
model.eval()
x = torch.randn(1, 1, 512, 512)
with torch.no_grad():
    y = model(x)
# The model was built with classes=5, so the output must keep the spatial
# size of the input and carry 5 channels.
assert y.shape == (1, 5, 512, 512), f"Wrong shape: {y.shape}"
print(f" ✅ Input {x.shape} → Output {y.shape}")
|
|
| |
# --- Step 3: combined Dice + Focal loss ------------------------------------
print("\n3. Loss functions...")
dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True, smooth=1.0)
focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=2.0)

# Random logits for 5 classes and random integer class targets in [0, 5).
pred = torch.randn(2, 5, 64, 64, requires_grad=True)
target = torch.randint(0, 5, (2, 64, 64))

# Equal-weight blend of the two losses.
loss = 0.5 * dice_loss(pred, target) + 0.5 * focal_loss(pred, target)
assert torch.isfinite(loss), f"Non-finite loss: {loss.item()}"
loss.backward()
# Actually verify that backward() delivered a usable gradient before
# reporting success.
assert pred.grad is not None, "backward() produced no gradient for pred"
assert torch.isfinite(pred.grad).all(), "Gradient contains NaN/Inf values"
print(f" ✅ Dice+Focal loss: {loss.item():.4f}, backward OK")
|
|
| |
# --- Step 4: spot-check the label remapping table ---------------------------
print("\n4. Label remapping (30→5 classes)...")
from train import LABEL_REMAP

# Original label id -> expected remapped class id.
expected_remap = {
    0: 0,
    9: 1,
    14: 2,
    10: 2,
    11: 3,
    17: 3,
    20: 3,
    12: 4,
}
for source_id, expected_class in expected_remap.items():
    actual_class = LABEL_REMAP[source_id]
    assert actual_class == expected_class, (
        f"LABEL_REMAP[{source_id}] is {actual_class}, expected {expected_class}"
    )
print(f" ✅ Remap table verified: {np.unique(LABEL_REMAP)}")
|
|
| |
# --- Step 5: augmentation pipelines -----------------------------------------
print("\n5. Augmentation pipeline...")
from train import get_train_transforms, get_val_transforms

train_tf = get_train_transforms(512)
val_tf = get_val_transforms(512)

# Synthetic single-channel image (float32) and class mask with ids in [0, 5).
# randint's upper bound is exclusive, so 256 is needed to cover the full
# 0..255 intensity range.
fake_img = np.random.randint(0, 256, (512, 512), dtype=np.uint8).astype(np.float32)
fake_mask = np.random.randint(0, 5, (512, 512), dtype=np.uint8)

expected_img_shape = torch.Size([1, 512, 512])
expected_mask_shape = torch.Size([512, 512])

aug = train_tf(image=fake_img, mask=fake_mask)
assert aug["image"].shape == expected_img_shape, f"Wrong aug img shape: {aug['image'].shape}"
assert aug["mask"].shape == expected_mask_shape, f"Wrong aug mask shape: {aug['mask'].shape}"
print(f" ✅ Train transform: img {aug['image'].shape}, mask {aug['mask'].shape}")

# Hold the validation pipeline to the same shape contract as the training
# one; both were built for size 512.
aug = val_tf(image=fake_img, mask=fake_mask)
assert aug["image"].shape == expected_img_shape, f"Wrong val img shape: {aug['image'].shape}"
assert aug["mask"].shape == expected_mask_shape, f"Wrong val mask shape: {aug['mask'].shape}"
print(f" ✅ Val transform: img {aug['image'].shape}, mask {aug['mask'].shape}")
|
|
| |
# --- Step 6: dataset class on synthetic files -------------------------------
print("\n6. Dataset class (with fake data)...")
import tempfile
from PIL import Image as PILImage

from train import ESCDDDataset, get_train_transforms

# Every sample shares the same mask, so build it once outside the loop.
# It contains raw label ids 14, 11 and 9 on a background of 0; the remap
# spot checks in this script expect those to become classes 2, 3 and 1.
mask_arr = np.zeros((512, 512), dtype=np.uint8)
mask_arr[100:200, 100:200] = 14
mask_arr[200:300, 200:300] = 11
mask_arr[50:60, :] = 9

with tempfile.TemporaryDirectory() as tmpdir:
    img_dir = os.path.join(tmpdir, "images")
    mask_dir = os.path.join(tmpdir, "masks")
    os.makedirs(img_dir)
    os.makedirs(mask_dir)

    # Three random RGBA images, each paired with a same-named mask file.
    for i in range(3):
        # randint's upper bound is exclusive; 256 covers the full byte range.
        rgba = np.random.randint(0, 256, (512, 512, 4), dtype=np.uint8)
        PILImage.fromarray(rgba, mode="RGBA").save(os.path.join(img_dir, f"test_{i}.png"))
        PILImage.fromarray(mask_arr, mode="L").save(os.path.join(mask_dir, f"test_{i}.png"))

    ds = ESCDDDataset(img_dir, mask_dir, transform=get_train_transforms(512))

    # The sample must be loaded while the temporary directory still exists.
    img, mask = ds[0]
    assert img.shape == (1, 512, 512), f"Wrong dataset img shape: {img.shape}"
    assert mask.shape == (512, 512), f"Wrong dataset mask shape: {mask.shape}"
    unique_classes = torch.unique(mask).tolist()
    print(f" ✅ Dataset OK: img {img.shape}, mask {mask.shape}, classes: {unique_classes}")

    # NOTE(review): the training transforms are presumably randomised; if
    # this check ever turns flaky, a deterministic transform would be a
    # safer choice here. TODO confirm against train.py.
    for expected_class in (0, 1, 2, 3):
        assert expected_class in unique_classes, (
            f"Class {expected_class} missing from dataset mask: {unique_classes}"
        )
    print(" ✅ Class remapping verified in dataset output")
|
|
| |
# --- Step 7: rough training-memory estimate ---------------------------------
print("\n7. VRAM estimate for RTX 4060...")
# Back-of-the-envelope fp32 training footprint, all in MB:
#   weights     - 4 bytes per parameter
#   gradients   - one value per parameter, same size as the weights
#   AdamW state - two moment buffers per parameter (2x the weights)
#   activations - fixed rough guess for batch size 4 with AMP
# This is an estimate only; it does not measure real GPU usage.
vram_budget_mb = 8192
activations_mb = 3000
weights_mb = params * 4 / 1e6
gradients_mb = weights_mb
optimizer_mb = weights_mb * 2
total_mb = weights_mb + gradients_mb + optimizer_mb + activations_mb
headroom_mb = vram_budget_mb - total_mb

print(f" Model weights: {weights_mb:.0f} MB")
print(f" Gradients: ~{gradients_mb:.0f} MB")
print(f" Optimizer states (AdamW): ~{optimizer_mb:.0f} MB")
print(f" Estimated activations (bs=4, AMP): ~{activations_mb} MB")
print(f" Estimated total: ~{total_mb:.0f} MB")
print(f" RTX 4060 VRAM: {vram_budget_mb} MB")
# Only claim that the model fits when the estimate actually leaves headroom.
if headroom_mb >= 0:
    print(f" ✅ Fits with ~{headroom_mb:.0f} MB headroom")
else:
    print(f" ⚠️ Estimated to exceed VRAM by ~{-headroom_mb:.0f} MB")
|
|
# Closing banner. Because every check above is a plain assert or a raising
# call, reaching this point means none of them failed.
print("\n" + "=" * 60)
print("ALL TESTS PASSED ✅")
print("=" * 60)
print("\nYou can now run: python train.py")
|
|