"""
EL Defect Detection — Training Script for RTX 4060 (8GB VRAM)

Model: U-Net++ with EfficientNet-B4 encoder + scSE attention
Dataset: E-SCDD (snt-ubix/e-scdd) — 903 images, 512x512
Loss: 0.5 * Dice + 0.5 * Focal (handles severe class imbalance)
Classes: 0=background, 1=busbar, 2=crack, 3=dark/inactive, 4=other_defects

Usage:
    pip install torch torchvision segmentation-models-pytorch albumentations \
        huggingface-hub scikit-image scipy opencv-python-headless pillow
    python train.py
"""
|
|
| import os |
| import sys |
| import json |
| import time |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from pathlib import Path |
| from PIL import Image |
|
|
| import segmentation_models_pytorch as smp |
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
|
|
|
|
| |
| |
| |
|
|
class Config:
    """Static settings for one training run.

    Everything is a plain class attribute; ``train()`` instantiates the class
    and reads the values without modifying them.
    """

    # ---- Filesystem locations ------------------------------------------
    DATA_DIR = "./data"        # dataset snapshot is downloaded/read here
    OUTPUT_DIR = "./output"    # checkpoints and history.json are written here

    # ---- Network definition --------------------------------------------
    ARCHITECTURE = "UnetPlusPlus"   # attribute name looked up on the smp module
    ENCODER = "efficientnet-b4"
    ENCODER_WEIGHTS = "imagenet"
    IN_CHANNELS = 1                 # images are converted to single-channel "L"
    NUM_CLASSES = 5

    # ---- Optimisation --------------------------------------------------
    IMG_SIZE = 512          # crop size handed to the transform builders
    BATCH_SIZE = 4
    NUM_EPOCHS = 100        # also used as T_max of the cosine LR schedule
    ENCODER_LR = 1e-4       # learning rate of the encoder parameter group
    DECODER_LR = 5e-4       # learning rate of decoder and segmentation head
    WEIGHT_DECAY = 1e-4
    USE_AMP = True          # mixed-precision autocast + gradient scaling
    NUM_WORKERS = 4         # DataLoader worker processes
    GRADIENT_CLIP = 1.0     # max gradient norm passed to clip_grad_norm_

    # ---- Loss mixing: DICE_WEIGHT * Dice + FOCAL_WEIGHT * Focal --------
    DICE_WEIGHT = 0.5
    FOCAL_WEIGHT = 0.5
    FOCAL_GAMMA = 2.0

    # ---- Optional upload of OUTPUT_DIR to the HuggingFace Hub ----------
    HUB_MODEL_ID = None     # repo id; upload is skipped while this is None
    PUSH_TO_HUB = False

    # ---- Class index -> display name (index 0 is background) -----------
    CLASS_NAMES = [
        "background",
        "busbar",
        "crack",
        "dark",
        "other_defect",
    ]
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
# Lookup table that collapses raw mask values (0-29) into the five training
# classes: 0=background, 1=busbar, 2=crack, 3=dark/inactive, 4=other defects.
# It is indexed with a whole mask array at once (LABEL_REMAP[mask]).
# Raw values not listed below keep the initial 0, i.e. map to background.
LABEL_REMAP = np.zeros(30, dtype=np.uint8)
LABEL_REMAP[[9]] = 1                                        # -> busbar
LABEL_REMAP[[10, 14]] = 2                                   # -> crack
LABEL_REMAP[[11, 17, 20]] = 3                               # -> dark / inactive
LABEL_REMAP[[12, 13, 15, 16, 18, 19, 25, 26, 27, 28]] = 4   # -> other defects
|
|
|
|
| |
| |
| |
|
|
class ESCDDDataset(Dataset):
    """E-SCDD image/mask pairs for semantic segmentation.

    Images are ``*.png`` files that get converted to single-channel grayscale;
    masks are ``*.png`` files whose raw values are clipped to 0-29 and
    translated to training class ids through ``LABEL_REMAP``.  An image and a
    mask belong together when their file stems are equal; files without a
    partner are ignored.

    Args:
        img_dir: Directory containing the image PNGs.
        mask_dir: Directory containing the mask PNGs.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries (the albumentations calling
            convention).
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform

        # Index both directories by file stem; the image listing is sorted so
        # the sample order is deterministic.
        img_files = {f.stem: f for f in sorted(self.img_dir.glob("*.png"))}
        mask_files = {f.stem: f for f in self.mask_dir.glob("*.png")}

        self.pairs = [
            (img_path, mask_files[stem])
            for stem, img_path in img_files.items()
            if stem in mask_files
        ]

        print(f" {img_dir}: {len(self.pairs)} image-mask pairs")

    def __len__(self):
        """Return the number of matched image/mask pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image, mask)``.

        With a transform, ``image`` is whatever the transform produces and
        ``mask`` is converted to an int64 tensor.  Without one, ``image`` is a
        float tensor of shape ``(1, H, W)`` scaled to [0, 1] and ``mask`` is
        an int64 tensor of shape ``(H, W)``.
        """
        img_path, mask_path = self.pairs[idx]

        # Context managers release the file handles right away instead of
        # leaving that to garbage collection inside DataLoader workers.
        with Image.open(img_path) as im:
            # Keep 8-bit pixels: albumentations documents float32 inputs as
            # lying in [0, 1], so a float32 array holding 0-255 would be
            # mis-scaled by the brightness/noise augmentations.  Normalize
            # (max_pixel_value=255) performs the conversion to float.
            img = np.array(im.convert("L"), dtype=np.uint8)
        with Image.open(mask_path) as im:
            mask = np.array(im, dtype=np.uint8)

        # Clip first so out-of-table raw values cannot index past the LUT.
        mask = LABEL_REMAP[np.clip(mask, 0, 29)]

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]
            # The transform hands the mask back with its uint8 dtype; the
            # losses need integer class indices, so both code paths return
            # int64.  as_tensor accepts a tensor or an ndarray.
            mask = torch.as_tensor(augmented["mask"]).long()
        else:
            img = torch.from_numpy(img).float().unsqueeze(0) / 255.0
            mask = torch.from_numpy(mask).long()

        return img, mask
|
|
|
|
def get_train_transforms(img_size=512):
    """Build the augmentation pipeline used for training samples.

    Args:
        img_size: Side length of the square random crop.

    Returns:
        An ``A.Compose`` that applies, in order: random crop, horizontal and
        vertical flips, 90-degree rotation, brightness/contrast jitter,
        Gaussian noise, grid distortion, normalisation and tensor conversion.
    """
    crop = [A.RandomCrop(img_size, img_size, p=1.0)]
    flips_and_rotation = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
    appearance = [
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.GaussNoise(std_range=(0.02, 0.1), p=0.3),
    ]
    warp = [A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3)]
    # mean 0 / std 1 with max_pixel_value=255 amounts to a division by 255.
    finalize = [
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(crop + flips_and_rotation + appearance + warp + finalize)
|
|
|
|
def get_val_transforms(img_size=512):
    """Build the deterministic pipeline used for validation samples.

    Args:
        img_size: Side length of the square centre crop.

    Returns:
        An ``A.Compose`` that centre-crops, normalises and converts to tensors;
        no random augmentation is included.
    """
    steps = [
        A.CenterCrop(img_size, img_size, p=1.0),
        # mean 0 / std 1 with max_pixel_value=255 amounts to a division by 255.
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)
|
|
|
|
| |
| |
| |
|
|
def download_dataset(data_dir):
    """Fetch the E-SCDD snapshot from the HuggingFace Hub unless it is present.

    The dataset counts as present when ``<data_dir>/el_images_train`` exists
    and holds more than 100 entries; in that case nothing is downloaded.

    Args:
        data_dir: Directory the dataset snapshot is stored in.
    """
    train_img = os.path.join(data_dir, "el_images_train")
    already_present = os.path.exists(train_img) and len(os.listdir(train_img)) > 100

    if not already_present:
        print("Downloading E-SCDD dataset from HuggingFace Hub...")
        # Imported lazily so the dependency is only needed for a real download.
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="snt-ubix/e-scdd",
            repo_type="dataset",
            local_dir=data_dir,
        )
        print(f"Downloaded to {data_dir}")
        return

    print("Dataset already downloaded.")
|
|
|
|
| |
| |
| |
|
|
def compute_metrics(pred_logits, target, num_classes=5, class_names=None):
    """Compute per-class and mean IoU / Dice for one batch.

    Predictions are the argmax of ``pred_logits`` over dimension 1.  For each
    class index the overlap with ``target`` is measured; a smoothing term of
    1e-6 is added to numerator and denominator, so a class that is absent from
    both prediction and target scores 1.0.

    Args:
        pred_logits: Tensor of raw class scores with the class axis at dim 1.
        target: Integer tensor of class ids, shaped like ``pred_logits``
            without the class axis.
        num_classes: Number of class indices (``0 .. num_classes - 1``) scored.
        class_names: Optional names used as keys of the per-class results, in
            class-index order.  Defaults to ``Config.CLASS_NAMES``.  The
            per-class dictionaries are built with ``zip``, so they contain
            ``min(num_classes, len(class_names))`` entries.

    Returns:
        dict with ``"mean_iou"``, ``"mean_dice"`` (means over all
        ``num_classes`` classes) and ``"per_class_iou"``, ``"per_class_dice"``
        (name -> score mappings).
    """
    if class_names is None:
        # Looked up at call time so the default follows the project config.
        class_names = Config.CLASS_NAMES

    pred = torch.argmax(pred_logits, dim=1)

    eps = 1e-6  # avoids 0/0 for classes missing from both pred and target
    ious, dices = [], []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c

        intersection = (pred_c & target_c).float().sum()
        union = (pred_c | target_c).float().sum()
        area_sum = pred_c.float().sum() + target_c.float().sum()

        ious.append(((intersection + eps) / (union + eps)).item())
        dices.append(((2 * intersection + eps) / (area_sum + eps)).item())

    return {
        "mean_iou": np.mean(ious),
        "mean_dice": np.mean(dices),
        "per_class_iou": dict(zip(class_names, ious)),
        "per_class_dice": dict(zip(class_names, dices)),
    }
|
|
|
|
| |
| |
| |
|
|
def _print_class_distribution(dataset, cfg, max_samples=200):
    """Print each class's share of pixels over the first *max_samples* items."""
    print("\nComputing class distribution...")
    class_pixels = np.zeros(cfg.NUM_CLASSES, dtype=np.float64)
    for i in range(min(len(dataset), max_samples)):
        _, mask = dataset[i]
        if isinstance(mask, torch.Tensor):
            mask = mask.numpy()
        for c in range(cfg.NUM_CLASSES):
            class_pixels[c] += (mask == c).sum()

    total = class_pixels.sum()
    # Guard against an empty dataset so the report shows zeros, not NaN.
    class_freq = class_pixels / total if total > 0 else np.zeros_like(class_pixels)
    print("Class distribution:")
    for i, name in enumerate(cfg.CLASS_NAMES):
        print(f" {name}: {class_freq[i]*100:.2f}% ({int(class_pixels[i]):,} px)")


def _run_validation(model, loader, criterion, device, cfg, use_amp):
    """Evaluate *model* on *loader* in a single pass.

    Returns:
        ``(val_loss, val_dice, val_iou, per_class_dice)`` where the first
        three are plain floats averaged over batches and ``per_class_dice``
        maps each class name to its batch-averaged Dice score.
    """
    model.eval()
    loss_sum = 0.0
    batch_ious, batch_dices = [], []
    per_class = {name: [] for name in cfg.CLASS_NAMES}

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            # The multiclass losses need integer (int64) class indices.
            masks = masks.to(device).long()

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                logits = model(images)
                loss = criterion(logits, masks)

            loss_sum += loss.item()
            metrics = compute_metrics(logits, masks, cfg.NUM_CLASSES)
            batch_ious.append(metrics["mean_iou"])
            batch_dices.append(metrics["mean_dice"])
            for name in cfg.CLASS_NAMES:
                per_class[name].append(metrics["per_class_dice"][name])

    # Plain floats (not NumPy scalars) keep history.json and the checkpoints
    # free of NumPy objects.
    return (
        loss_sum / len(loader),
        float(np.mean(batch_dices)),
        float(np.mean(batch_ious)),
        {name: float(np.mean(scores)) for name, scores in per_class.items()},
    )


def _checkpoint_metadata(cfg):
    """Return the descriptive fields stored alongside model weights."""
    return {
        "architecture": cfg.ARCHITECTURE,
        "encoder": cfg.ENCODER,
        "num_classes": cfg.NUM_CLASSES,
        "img_size": cfg.IMG_SIZE,
        "class_names": cfg.CLASS_NAMES,
        "label_remap": LABEL_REMAP.tolist(),
    }


def train():
    """Run the complete training pipeline configured by ``Config``.

    Steps: download the dataset if needed, build the train/val loaders, print
    a class-distribution report, create the model, then train for
    ``NUM_EPOCHS`` epochs with a Dice + Focal loss, AdamW, cosine LR annealing,
    gradient clipping and (on CUDA only) mixed precision.  The best model by
    validation Dice, a checkpoint every 25 epochs, the final model and
    ``history.json`` are written to ``OUTPUT_DIR``; the folder is optionally
    uploaded to the HuggingFace Hub.

    Raises:
        RuntimeError: If the training or validation loader yields no batches.
    """
    cfg = Config()
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    # Mixed precision with gradient scaling is only used on CUDA; on CPU the
    # run falls back to full precision.
    use_amp = cfg.USE_AMP and on_cuda
    print(f"Device: {device}")
    if on_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        # The device-properties attribute is ``total_memory``.
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    download_dataset(cfg.DATA_DIR)

    print("\nLoading datasets...")
    train_ds = ESCDDDataset(
        os.path.join(cfg.DATA_DIR, "el_images_train"),
        os.path.join(cfg.DATA_DIR, "el_masks_train"),
        transform=get_train_transforms(cfg.IMG_SIZE),
    )
    val_ds = ESCDDDataset(
        os.path.join(cfg.DATA_DIR, "el_images_val"),
        os.path.join(cfg.DATA_DIR, "el_masks_val"),
        transform=get_val_transforms(cfg.IMG_SIZE),
    )

    train_loader = DataLoader(
        train_ds, batch_size=cfg.BATCH_SIZE, shuffle=True,
        num_workers=cfg.NUM_WORKERS, pin_memory=on_cuda, drop_last=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.BATCH_SIZE, shuffle=False,
        num_workers=cfg.NUM_WORKERS, pin_memory=on_cuda,
    )

    # Fail early with a clear message instead of a ZeroDivisionError when the
    # epoch losses are averaged.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise RuntimeError(
            f"No batches to train on: {len(train_ds)} training and "
            f"{len(val_ds)} validation samples found under {cfg.DATA_DIR} "
            f"(batch size {cfg.BATCH_SIZE})."
        )

    _print_class_distribution(train_ds, cfg)

    print(f"\nCreating {cfg.ARCHITECTURE} + {cfg.ENCODER}...")
    ModelClass = getattr(smp, cfg.ARCHITECTURE)
    model = ModelClass(
        encoder_name=cfg.ENCODER,
        encoder_weights=cfg.ENCODER_WEIGHTS,
        in_channels=cfg.IN_CHANNELS,
        classes=cfg.NUM_CLASSES,
        decoder_attention_type="scse",
    )
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {total_params:,}")

    dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True, smooth=1.0)
    focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=cfg.FOCAL_GAMMA)

    def criterion(pred, target):
        return cfg.DICE_WEIGHT * dice_loss(pred, target) + cfg.FOCAL_WEIGHT * focal_loss(pred, target)

    # Separate parameter groups: the encoder uses ENCODER_LR, the decoder and
    # segmentation head use DECODER_LR.
    optimizer = AdamW([
        {"params": model.encoder.parameters(), "lr": cfg.ENCODER_LR},
        {"params": model.decoder.parameters(), "lr": cfg.DECODER_LR},
        {"params": model.segmentation_head.parameters(), "lr": cfg.DECODER_LR},
    ], weight_decay=cfg.WEIGHT_DECAY)

    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.NUM_EPOCHS, eta_min=1e-6)
    scaler = torch.amp.GradScaler(enabled=use_amp)

    best_val_dice = 0.0
    history = {"train_loss": [], "val_loss": [], "val_dice": [], "val_iou": []}
    metadata = _checkpoint_metadata(cfg)

    print(f"\n{'='*60}")
    print(f"Starting training: {cfg.NUM_EPOCHS} epochs")
    print(f"{'='*60}\n")

    for epoch in range(cfg.NUM_EPOCHS):
        t_start = time.time()

        # ---- training pass ---------------------------------------------
        model.train()
        train_loss = 0.0

        for images, masks in train_loader:
            images = images.to(device)
            # The multiclass losses need integer (int64) class indices.
            masks = masks.to(device).long()

            optimizer.zero_grad()

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                logits = model(images)
                loss = criterion(logits, masks)

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.GRADIENT_CLIP)
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()

        train_loss /= len(train_loader)
        scheduler.step()

        # ---- validation pass (also collects per-class Dice) ------------
        val_loss, val_dice, val_iou, per_class_dice = _run_validation(
            model, val_loader, criterion, device, cfg, use_amp
        )

        t_elapsed = time.time() - t_start
        lr_enc = optimizer.param_groups[0]["lr"]

        print(f"Epoch {epoch+1:3d}/{cfg.NUM_EPOCHS} | "
              f"train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | "
              f"val_dice={val_dice:.4f} | val_iou={val_iou:.4f} | "
              f"lr_enc={lr_enc:.6f} | {t_elapsed:.1f}s")

        # Detailed report every 10 epochs, reusing the scores gathered above
        # rather than running the validation set a second time.
        if (epoch + 1) % 10 == 0:
            print(" Per-class Dice:")
            for name in cfg.CLASS_NAMES:
                print(f" {name:20s}: {per_class_dice[name]:.4f}")

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_dice"].append(val_dice)
        history["val_iou"].append(val_iou)

        if val_dice > best_val_dice:
            best_val_dice = val_dice
            save_path = os.path.join(cfg.OUTPUT_DIR, "best_model.pth")
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_dice": val_dice,
                "val_iou": val_iou,
                **metadata,
            }, save_path)
            print(f" * Best model saved (dice={val_dice:.4f})")

        # Periodic weights-only snapshot.
        if (epoch + 1) % 25 == 0:
            ckpt_path = os.path.join(cfg.OUTPUT_DIR, f"checkpoint_ep{epoch+1}.pth")
            torch.save({"epoch": epoch+1, "model_state_dict": model.state_dict()}, ckpt_path)

    final_path = os.path.join(cfg.OUTPUT_DIR, "final_model.pth")
    torch.save({
        "epoch": cfg.NUM_EPOCHS,
        "model_state_dict": model.state_dict(),
        # History is empty when NUM_EPOCHS is 0; store None in that case.
        "val_dice": history["val_dice"][-1] if history["val_dice"] else None,
        "val_iou": history["val_iou"][-1] if history["val_iou"] else None,
        **metadata,
        "history": history,
    }, final_path)

    with open(os.path.join(cfg.OUTPUT_DIR, "history.json"), "w") as f:
        json.dump(history, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training complete! Best val dice: {best_val_dice:.4f}")
    print(f"Models saved to {cfg.OUTPUT_DIR}/")
    print(f"{'='*60}")

    if cfg.PUSH_TO_HUB and cfg.HUB_MODEL_ID:
        # Deliberately best-effort: the models are already saved locally, so a
        # failed upload is reported but must not fail the run.
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            api.create_repo(cfg.HUB_MODEL_ID, exist_ok=True)
            api.upload_folder(
                folder_path=cfg.OUTPUT_DIR,
                repo_id=cfg.HUB_MODEL_ID,
                commit_message=f"Trained model (dice={best_val_dice:.4f})",
            )
            print(f"Pushed to hub: {cfg.HUB_MODEL_ID}")
        except Exception as e:
            print(f"Hub push failed: {e}")
|
|
|
|
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    train()
|
|