"""
EL Defect Detection — Training Script for RTX 4060 (8GB VRAM)

Model:   U-Net++ with EfficientNet-B4 encoder + scSE attention
Dataset: E-SCDD (snt-ubix/e-scdd) — 903 images, 512x512
Loss:    0.5 * Dice + 0.5 * Focal (handles severe class imbalance)
Classes: 0=background, 1=busbar, 2=crack, 3=dark/inactive, 4=other_defects

Usage:
    pip install torch torchvision segmentation-models-pytorch albumentations \
        huggingface-hub scikit-image scipy opencv-python-headless pillow
    python train.py
"""

import json
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from PIL import Image

import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2


# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════

class Config:
    """Static training configuration, read as class attributes or via an instance."""

    # Data
    DATA_DIR = "./data"            # Dataset is downloaded here
    OUTPUT_DIR = "./output"        # Checkpoints, history.json

    # Model — U-Net++ dense skip connections help preserve thin structures
    # (cracks) that a plain U-Net decoder tends to blur out.
    ARCHITECTURE = "UnetPlusPlus"  # Name of a model class in segmentation_models_pytorch
    ENCODER = "efficientnet-b4"    # Good accuracy/size trade-off
    ENCODER_WEIGHTS = "imagenet"
    IN_CHANNELS = 1                # EL images are loaded as grayscale
    NUM_CLASSES = 5                # bg, busbar, crack, dark, other_defects

    # Training — tuned for RTX 4060 (8GB VRAM)
    IMG_SIZE = 512                 # E-SCDD native resolution
    BATCH_SIZE = 4                 # Safe for 8GB with AMP
    NUM_EPOCHS = 100
    ENCODER_LR = 1e-4              # Lower LR for the pretrained encoder
    DECODER_LR = 5e-4              # Higher LR for the randomly initialised decoder/head
    WEIGHT_DECAY = 1e-4
    USE_AMP = True                 # Mixed precision; only honoured when training on CUDA
    NUM_WORKERS = 4
    GRADIENT_CLIP = 1.0            # Max global gradient L2 norm

    # Loss
    DICE_WEIGHT = 0.5
    FOCAL_WEIGHT = 0.5
    FOCAL_GAMMA = 2.0

    # Hub
    HUB_MODEL_ID = None            # Set to "username/model-name" to push
    PUSH_TO_HUB = False

    # Class names, indexed by remapped class id
    CLASS_NAMES = ["background", "busbar", "crack", "dark", "other_defect"]


# ═══════════════════════════════════════════════════════════════
# CLASS MAPPING: E-SCDD 30 classes → 5 classes
# ═══════════════════════════════════════════════════════════════
# Mask pixel values in E-SCDD are integers 0-29 (Label column in CSV).
# We remap them to 5 meaningful classes:
#   0 = background    (all spacing, borders, padding, text, clamp, frame, jbox)
#   1 = busbar        (label 9)
#   2 = crack         (label 14=crack, label 10=crack_rbn_edge)
#   3 = dark/inactive (label 11=inactive, label 17=dead_cell, label 20=edge_dark)
#   4 = other_defect  (rings, material, gridline, splice, corrosion, belt_mark, etc.)

LABEL_REMAP = np.zeros(30, dtype=np.uint8)  # default: everything → 0 (background)

# Background features (labels 0-8, 21-24, 29): already 0 by default

# Busbar
LABEL_REMAP[9] = 1    # busbars → busbar

# Crack (HIGH IMPORTANCE)
LABEL_REMAP[10] = 2   # crack_rbn_edge → crack
LABEL_REMAP[14] = 2   # crack → crack

# Dark/Inactive (HIGH IMPORTANCE)
LABEL_REMAP[11] = 3   # inactive → dark
LABEL_REMAP[17] = 3   # dead_cell → dark
LABEL_REMAP[20] = 3   # edge_dark → dark

# Other defects (MEDIUM IMPORTANCE)
LABEL_REMAP[12] = 4   # rings
LABEL_REMAP[13] = 4   # material
LABEL_REMAP[15] = 4   # gridline defect
LABEL_REMAP[16] = 4   # splice
LABEL_REMAP[18] = 4   # corrosion_rbn
LABEL_REMAP[19] = 4   # belt_mark
LABEL_REMAP[25] = 4   # scuff
LABEL_REMAP[26] = 4   # corrosion_cell
LABEL_REMAP[27] = 4   # brightening
LABEL_REMAP[28] = 4   # star


# ═══════════════════════════════════════════════════════════════
# DATASET
# ═══════════════════════════════════════════════════════════════

def load_mask(mask_path):
    """Load a mask image and remap its raw labels to the 5 training classes.

    Pixel values above the last LABEL_REMAP index (29) are clipped to it, and
    label 29 maps to background.

    Returns:
        np.ndarray of uint8 class ids in [0, 4], same spatial shape as the file.
    """
    with Image.open(mask_path) as im:
        raw = np.array(im, dtype=np.uint8)
    return LABEL_REMAP[np.clip(raw, 0, len(LABEL_REMAP) - 1)]


class ESCDDDataset(Dataset):
    """E-SCDD dataset: 512x512 EL images (RGBA) + grayscale masks (L, values 0-29).

    Images and masks are paired by file stem. Files without a counterpart in
    the other directory are skipped.

    Each item is ``(image, mask)``: a ``(1, H, W)`` float32 tensor scaled to
    [0, 1] and an ``(H, W)`` int64 tensor of remapped class ids.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform

        # Match images to masks by filename stem (sorted for a stable order)
        mask_files = {f.stem: f for f in self.mask_dir.glob("*.png")}
        self.pairs = [
            (img_file, mask_files[img_file.stem])
            for img_file in sorted(self.img_dir.glob("*.png"))
            if img_file.stem in mask_files
        ]
        print(f" {img_dir}: {len(self.pairs)} image-mask pairs")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]

        # Keep the image as uint8. Albumentations treats float32 images as
        # [0, 1], and its pixel-level transforms scale and clip for that
        # range, so a float32 image in [0, 255] would be corrupted. The
        # Normalize step in the pipelines rescales uint8 to [0, 1].
        with Image.open(img_path) as im:
            img = np.array(im.convert("L"), dtype=np.uint8)

        # Remap 30 → 5 classes using the lookup table
        mask = load_mask(mask_path)

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]  # (1, H, W) float tensor after ToTensorV2
            mask = augmented["mask"]
            if not isinstance(mask, torch.Tensor):
                mask = torch.from_numpy(np.ascontiguousarray(mask))
            # ToTensorV2 keeps the mask's uint8 dtype; cast to match the
            # no-transform branch and the losses' class-index targets.
            mask = mask.long()
        else:
            img = torch.from_numpy(img).float().unsqueeze(0) / 255.0
            mask = torch.from_numpy(mask).long()

        return img, mask


def get_train_transforms(img_size=512):
    """Training augmentations for uint8 grayscale images.

    Ends with Normalize (to [0, 1]) and ToTensorV2.
    """
    return A.Compose([
        A.RandomCrop(img_size, img_size, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.GaussNoise(std_range=(0.02, 0.1), p=0.3),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ])


def get_val_transforms(img_size=512):
    """Deterministic validation pipeline: center crop, scale to [0, 1], to tensor."""
    return A.Compose([
        A.CenterCrop(img_size, img_size, p=1.0),
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ])


# ═══════════════════════════════════════════════════════════════
# DOWNLOAD DATASET
# ═══════════════════════════════════════════════════════════════

def download_dataset(data_dir):
    """Download E-SCDD from HuggingFace Hub into ``data_dir``.

    The download is skipped when ``data_dir/el_images_train`` already holds
    more than 100 entries.
    """
    train_img = os.path.join(data_dir, "el_images_train")
    if os.path.exists(train_img) and len(os.listdir(train_img)) > 100:
        print("Dataset already downloaded.")
        return

    print("Downloading E-SCDD dataset from HuggingFace Hub...")
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="snt-ubix/e-scdd",
        repo_type="dataset",
        local_dir=data_dir,
    )
    print(f"Downloaded to {data_dir}")


# ═══════════════════════════════════════════════════════════════
# METRICS
# ═══════════════════════════════════════════════════════════════

def _class_counts(pred_logits, target, num_classes):
    """Return per-class pixel counts as a ``(num_classes, 3)`` float64 CPU tensor.

    Columns are [intersection, predicted pixels, target pixels]. The
    prediction is ``argmax(pred_logits, dim=1)``. Counts from several batches
    can be summed to obtain dataset-level metrics.
    """
    pred = torch.argmax(pred_logits, dim=1)  # (B, H, W)
    rows = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        rows.append(torch.stack([(pred_c & target_c).sum(), pred_c.sum(), target_c.sum()]))
    return torch.stack(rows).to(device="cpu", dtype=torch.float64)


def _metrics_from_counts(counts):
    """Turn ``_class_counts`` output (possibly summed) into IoU/Dice metrics.

    The 1e-6 smoothing makes a class absent from both prediction and target
    score 1.0.
    """
    inter, pred_px, target_px = counts.unbind(dim=1)
    union = pred_px + target_px - inter  # == |pred ∪ target|
    ious = ((inter + 1e-6) / (union + 1e-6)).tolist()
    dices = ((2 * inter + 1e-6) / (pred_px + target_px + 1e-6)).tolist()
    return {
        "mean_iou": float(np.mean(ious)),
        "mean_dice": float(np.mean(dices)),
        "per_class_iou": dict(zip(Config.CLASS_NAMES, ious)),
        "per_class_dice": dict(zip(Config.CLASS_NAMES, dices)),
    }


def compute_metrics(pred_logits, target, num_classes=5):
    """Compute per-class IoU and Dice for a single batch.

    Args:
        pred_logits: logits with classes along dim 1.
        target: integer class-id tensor matching ``argmax(pred_logits, 1)``.
        num_classes: number of classes to score.

    Returns:
        A dict with keys ``mean_iou``, ``mean_dice``, ``per_class_iou`` and
        ``per_class_dice``; the per-class dicts are keyed by Config.CLASS_NAMES.
        Classes absent from both prediction and target score 1.0, so
        per-batch averages are optimistic. ``train`` accumulates counts over
        the whole validation set instead.
    """
    return _metrics_from_counts(_class_counts(pred_logits, target, num_classes))


# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════

def _log_class_distribution(dataset, num_classes, class_names, max_images=200):
    """Print the remapped class-pixel distribution of up to ``max_images`` masks.

    Masks are read straight from disk without augmentation, so the numbers
    are deterministic. This is informational only; it does not affect the loss.
    """
    class_pixels = np.zeros(num_classes, dtype=np.float64)
    for _, mask_path in dataset.pairs[:max_images]:
        mask = load_mask(mask_path)
        class_pixels += np.bincount(mask.ravel(), minlength=num_classes)[:num_classes]

    total = class_pixels.sum()
    class_freq = class_pixels / total if total > 0 else class_pixels
    print("Class distribution:")
    for i, name in enumerate(class_names):
        print(f" {name}: {class_freq[i]*100:.2f}% ({int(class_pixels[i]):,} px)")


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp, grad_clip):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for images, masks in loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, masks)

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
    return running_loss / len(loader)


@torch.no_grad()
def _evaluate(model, loader, criterion, device, use_amp, num_classes):
    """Evaluate on ``loader`` in a single pass.

    Returns:
        ``(val_loss, metrics)``. ``val_loss`` is the per-batch loss weighted
        by batch size. ``metrics`` uses dataset-level IoU/Dice from counts
        summed over every batch.
    """
    model.eval()
    loss_sum, n_samples = 0.0, 0
    counts = torch.zeros(num_classes, 3, dtype=torch.float64)
    for images, masks in loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, masks)
        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
        counts += _class_counts(logits, masks, num_classes)
    return loss_sum / max(n_samples, 1), _metrics_from_counts(counts)


def _model_metadata(cfg):
    """Config fields saved with every full checkpoint to rebuild the model later."""
    return {
        "architecture": cfg.ARCHITECTURE,
        "encoder": cfg.ENCODER,
        "num_classes": cfg.NUM_CLASSES,
        "img_size": cfg.IMG_SIZE,
        "class_names": list(cfg.CLASS_NAMES),
        "label_remap": LABEL_REMAP.tolist(),
    }


def train():
    """Download E-SCDD, train the configured model, and write the outputs.

    Writes to Config.OUTPUT_DIR:
      - best_model.pth (highest validation mean Dice)
      - checkpoint_ep{N}.pth every 25 epochs
      - final_model.pth and history.json

    Optionally pushes OUTPUT_DIR to the Hugging Face Hub.

    Raises:
        RuntimeError: if the training set has fewer pairs than one batch, or
            the validation set is empty.
    """
    cfg = Config()
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    # AMP (autocast + GradScaler) is only used on CUDA; the CPU path trains in fp32.
    use_amp = cfg.USE_AMP and on_cuda
    print(f"Device: {device}")
    if on_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ── Download data ────────────────────────────────────────
    download_dataset(cfg.DATA_DIR)

    # ── Create datasets ──────────────────────────────────────
    print("\nLoading datasets...")
    train_ds = ESCDDDataset(
        os.path.join(cfg.DATA_DIR, "el_images_train"),
        os.path.join(cfg.DATA_DIR, "el_masks_train"),
        transform=get_train_transforms(cfg.IMG_SIZE),
    )
    val_ds = ESCDDDataset(
        os.path.join(cfg.DATA_DIR, "el_images_val"),
        os.path.join(cfg.DATA_DIR, "el_masks_val"),
        transform=get_val_transforms(cfg.IMG_SIZE),
    )
    # drop_last=True yields zero training batches below one full batch.
    if len(train_ds) < cfg.BATCH_SIZE or len(val_ds) == 0:
        raise RuntimeError(
            f"Not enough data in {cfg.DATA_DIR}: {len(train_ds)} training pairs "
            f"(need >= {cfg.BATCH_SIZE}) and {len(val_ds)} validation pairs (need >= 1)."
        )

    train_loader = DataLoader(
        train_ds, batch_size=cfg.BATCH_SIZE, shuffle=True,
        num_workers=cfg.NUM_WORKERS, pin_memory=on_cuda, drop_last=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.BATCH_SIZE, shuffle=False,
        num_workers=cfg.NUM_WORKERS, pin_memory=on_cuda,
    )

    # ── Class distribution of the training masks (informational) ──
    print("\nComputing class distribution...")
    _log_class_distribution(train_ds, cfg.NUM_CLASSES, cfg.CLASS_NAMES)

    # ── Create model ─────────────────────────────────────────
    print(f"\nCreating {cfg.ARCHITECTURE} + {cfg.ENCODER}...")
    model_cls = getattr(smp, cfg.ARCHITECTURE)
    model = model_cls(
        encoder_name=cfg.ENCODER,
        encoder_weights=cfg.ENCODER_WEIGHTS,
        in_channels=cfg.IN_CHANNELS,
        classes=cfg.NUM_CLASSES,
        decoder_attention_type="scse",
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {total_params:,}")

    # ── Loss: Dice + Focal (handles class imbalance) ─────────
    dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True, smooth=1.0)
    focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=cfg.FOCAL_GAMMA)

    def criterion(pred, target):
        return cfg.DICE_WEIGHT * dice_loss(pred, target) + cfg.FOCAL_WEIGHT * focal_loss(pred, target)

    # ── Optimizer with differential LR ───────────────────────
    optimizer = AdamW([
        {"params": model.encoder.parameters(), "lr": cfg.ENCODER_LR},
        {"params": model.decoder.parameters(), "lr": cfg.DECODER_LR},
        {"params": model.segmentation_head.parameters(), "lr": cfg.DECODER_LR},
    ], weight_decay=cfg.WEIGHT_DECAY)

    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.NUM_EPOCHS, eta_min=1e-6)
    scaler = torch.amp.GradScaler(enabled=use_amp)

    # ── Training loop ────────────────────────────────────────
    best_val_dice = 0.0
    history = {"train_loss": [], "val_loss": [], "val_dice": [], "val_iou": []}

    print(f"\n{'='*60}")
    print(f"Starting training: {cfg.NUM_EPOCHS} epochs")
    print(f"{'='*60}\n")

    for epoch in range(cfg.NUM_EPOCHS):
        t_start = time.time()
        # Read the LR before scheduler.step() so the log shows this epoch's value.
        lr_enc = optimizer.param_groups[0]["lr"]

        train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler,
            device, use_amp, cfg.GRADIENT_CLIP,
        )
        scheduler.step()

        val_loss, val_metrics = _evaluate(
            model, val_loader, criterion, device, use_amp, cfg.NUM_CLASSES,
        )
        val_dice = val_metrics["mean_dice"]
        val_iou = val_metrics["mean_iou"]
        t_elapsed = time.time() - t_start

        print(f"Epoch {epoch+1:3d}/{cfg.NUM_EPOCHS} | "
              f"train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | "
              f"val_dice={val_dice:.4f} | val_iou={val_iou:.4f} | "
              f"lr_enc={lr_enc:.6f} | {t_elapsed:.1f}s")

        # Per-class Dice every 10 epochs (reuses the validation pass above)
        if (epoch + 1) % 10 == 0:
            print(" Per-class Dice:")
            for name, score in val_metrics["per_class_dice"].items():
                print(f" {name:20s}: {score:.4f}")

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_dice"].append(val_dice)
        history["val_iou"].append(val_iou)

        # ── Save best model ──────────────────────────────────
        # Metrics are plain Python floats, so these checkpoints load with
        # torch.load's default weights_only=True.
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            save_path = os.path.join(cfg.OUTPUT_DIR, "best_model.pth")
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_dice": val_dice,
                "val_iou": val_iou,
                **_model_metadata(cfg),
            }, save_path)
            print(f" → Best model saved (dice={val_dice:.4f})")

        # Periodic checkpoint every 25 epochs
        if (epoch + 1) % 25 == 0:
            ckpt_path = os.path.join(cfg.OUTPUT_DIR, f"checkpoint_ep{epoch+1}.pth")
            torch.save({"epoch": epoch + 1, "model_state_dict": model.state_dict()}, ckpt_path)

    # ── Save final model + history ───────────────────────────
    final_path = os.path.join(cfg.OUTPUT_DIR, "final_model.pth")
    torch.save({
        "epoch": cfg.NUM_EPOCHS,
        "model_state_dict": model.state_dict(),
        "val_dice": history["val_dice"][-1] if history["val_dice"] else None,
        "val_iou": history["val_iou"][-1] if history["val_iou"] else None,
        **_model_metadata(cfg),
        "history": history,
    }, final_path)

    with open(os.path.join(cfg.OUTPUT_DIR, "history.json"), "w") as f:
        json.dump(history, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training complete! Best val dice: {best_val_dice:.4f}")
    print(f"Models saved to {cfg.OUTPUT_DIR}/")
    print(f"{'='*60}")

    # ── Push to Hub (best effort: failures are reported, not raised) ──
    if cfg.PUSH_TO_HUB and cfg.HUB_MODEL_ID:
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            api.create_repo(cfg.HUB_MODEL_ID, exist_ok=True)
            api.upload_folder(
                folder_path=cfg.OUTPUT_DIR,
                repo_id=cfg.HUB_MODEL_ID,
                commit_message=f"Trained model (dice={best_val_dice:.4f})",
            )
            print(f"Pushed to hub: {cfg.HUB_MODEL_ID}")
        except Exception as e:
            print(f"Hub push failed: {e}")


if __name__ == "__main__":
    train()