# el-defect-training / train.py
# Last commit by nithishbasireddy (d3f1e7d, verified):
# "Fix: mask dtype → LongTensor for DiceLoss one_hot compatibility"
"""
EL Defect Detection β€” Training Script for RTX 4060 (8GB VRAM)
Model: U-Net++ with EfficientNet-B4 encoder + scSE attention
Dataset: E-SCDD (snt-ubix/e-scdd) β€” 903 images, 512x512
Loss: 0.5 * Dice + 0.5 * Focal (handles severe class imbalance)
Classes: 0=background, 1=busbar, 2=crack, 3=dark/inactive, 4=other_defects
Usage:
pip install torch torchvision segmentation-models-pytorch albumentations \
huggingface-hub scikit-image scipy opencv-python-headless pillow
python train.py
"""
import os
import sys
import json
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from pathlib import Path
from PIL import Image
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════
class Config:
    """Static hyperparameters and paths for the EL defect training run.

    Every setting is a plain class attribute, so it can be read from the class
    itself (``Config.CLASS_NAMES``) or from an instance (``Config()``).
    """

    # ── Paths ────────────────────────────────────────────────
    DATA_DIR: str = "./data"  # the dataset is downloaded into this folder
    OUTPUT_DIR: str = "./output"  # checkpoints and history.json go here

    # ── Model ────────────────────────────────────────────────
    # U-Net++ (dense, nested skip connections) on an EfficientNet-B4 encoder,
    # picked for segmenting thin crack structures.
    ARCHITECTURE: str = "UnetPlusPlus"  # class name looked up on segmentation_models_pytorch
    ENCODER: str = "efficientnet-b4"
    ENCODER_WEIGHTS: str = "imagenet"
    IN_CHANNELS: int = 1  # EL images are fed as single-channel grayscale
    NUM_CLASSES: int = 5  # background, busbar, crack, dark, other_defect

    # ── Training (sized for an 8 GB RTX 4060) ────────────────
    IMG_SIZE: int = 512
    BATCH_SIZE: int = 4
    NUM_EPOCHS: int = 100
    ENCODER_LR: float = 1e-4  # smaller step for the pretrained encoder
    DECODER_LR: float = 5e-4  # larger step for the freshly initialised decoder/head
    WEIGHT_DECAY: float = 1e-4
    USE_AMP: bool = True  # mixed precision to reduce VRAM use
    NUM_WORKERS: int = 0  # 0 sidesteps DataLoader worker issues on Windows
    GRADIENT_CLIP: float = 1.0  # max global gradient norm

    # ── Loss: weighted sum of Dice and Focal terms ───────────
    DICE_WEIGHT: float = 0.5
    FOCAL_WEIGHT: float = 0.5
    FOCAL_GAMMA: float = 2.0

    # ── Hugging Face Hub upload ──────────────────────────────
    HUB_MODEL_ID: "str | None" = None  # e.g. "username/model-name"; needed to push
    PUSH_TO_HUB: bool = False

    # Display names, listed in class-id order 0..NUM_CLASSES-1.
    CLASS_NAMES: list = ["background", "busbar", "crack", "dark", "other_defect"]
# ═══════════════════════════════════════════════════════════════
# CLASS MAPPING: E-SCDD 30 classes → 5 classes
# ═══════════════════════════════════════════════════════════════
# Lookup table indexed by the raw E-SCDD mask value (0-29); each entry is the
# training class id. Raw ids not assigned below stay 0 (background).
LABEL_REMAP = np.zeros(30, dtype=np.uint8)
# busbar (1): busbars
LABEL_REMAP[[9]] = 1
# crack (2, high importance): crack_rbn_edge, crack
LABEL_REMAP[[10, 14]] = 2
# dark (3, high importance): inactive, dead_cell, edge_dark
LABEL_REMAP[[11, 17, 20]] = 3
# other_defect (4, medium importance): rings, material, gridline defect,
# splice, corrosion_rbn, belt_mark, scuff, corrosion_cell, brightening, star
LABEL_REMAP[[12, 13, 15, 16, 18, 19, 25, 26, 27, 28]] = 4
# ═══════════════════════════════════════════════════════════════
# DATASET
# ═══════════════════════════════════════════════════════════════
class ESCDDDataset(Dataset):
    """E-SCDD segmentation dataset: EL image / label-mask pairs matched by filename.

    Images are read as 8-bit grayscale. Masks hold raw E-SCDD label ids
    (expected 0-29) and are collapsed to the training classes via
    ``LABEL_REMAP``.

    Args:
        img_dir: directory containing ``*.png`` images.
        mask_dir: directory containing ``*.png`` masks with the same file stems.
        transform: optional albumentations pipeline, called as
            ``transform(image=..., mask=...)``. It is expected to end with
            ``Normalize`` + ``ToTensorV2`` so the image comes back as a float tensor.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        # Index both folders by file stem; keep only images that have a mask.
        img_files = {f.stem: f for f in sorted(self.img_dir.glob("*.png"))}
        mask_files = {f.stem: f for f in sorted(self.mask_dir.glob("*.png"))}
        self.pairs = [
            (img_files[stem], mask_files[stem])
            for stem in img_files
            if stem in mask_files
        ]
        print(f" {img_dir}: {len(self.pairs)} image-mask pairs")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(image, mask)``: a (1, H, W) float image and an (H, W) int64 mask."""
        img_path, mask_path = self.pairs[idx]
        # Keep the image as uint8 in [0, 255]. albumentations treats float32
        # input as already scaled to [0, 1], so pixel-level transforms
        # (RandomBrightnessContrast, GaussNoise) would be mis-scaled and may
        # clip a 0-255 float image. Normalize(max_pixel_value=255) does the /255.
        with Image.open(img_path) as im:
            img = np.array(im.convert("L"), dtype=np.uint8)
        with Image.open(mask_path) as im:
            mask = np.array(im, dtype=np.uint8)
        # Collapse raw ids to training classes; ids above 29 are clamped first.
        mask = LABEL_REMAP[np.clip(mask, 0, 29)]
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]
            # DiceLoss(mode="multiclass") one-hot encodes targets, which needs int64.
            mask = augmented["mask"].long()
        else:
            img = torch.from_numpy(img).unsqueeze(0).float() / 255.0
            mask = torch.from_numpy(mask).long()
        return img, mask
def get_train_transforms(img_size=512):
    """Build the training augmentation pipeline.

    Order: random crop to ``img_size``; flips and 90-degree rotation;
    brightness/contrast jitter and Gaussian noise; grid distortion. Finally,
    pixels are divided by 255 and converted to tensors.
    """
    crop = A.RandomCrop(img_size, img_size, p=1.0)
    orientation = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
    intensity = [
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.GaussNoise(std_range=(0.02, 0.1), p=0.3),
    ]
    warp = A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3)
    finalize = [
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose([crop, *orientation, *intensity, warp, *finalize])
def get_val_transforms(img_size=512):
    """Build the deterministic validation pipeline: center crop, /255 scaling, to tensor."""
    steps = [
        A.CenterCrop(img_size, img_size, p=1.0),
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)
# ═══════════════════════════════════════════════════════════════
# DOWNLOAD DATASET
# ═══════════════════════════════════════════════════════════════
def download_dataset(data_dir):
    """Fetch the E-SCDD dataset from the Hugging Face Hub into ``data_dir``.

    The download is skipped when ``<data_dir>/el_images_train`` exists and
    holds more than 100 entries, which is taken as a previous download.
    """
    train_img = os.path.join(data_dir, "el_images_train")
    have_local_copy = os.path.exists(train_img) and len(os.listdir(train_img)) > 100
    if have_local_copy:
        print("Dataset already downloaded.")
    else:
        print("Downloading E-SCDD dataset from HuggingFace Hub...")
        # Imported lazily: the Hub client is only needed when downloading.
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id="snt-ubix/e-scdd",
            repo_type="dataset",
            local_dir=data_dir,
        )
        print(f"Downloaded to {data_dir}")
# ═══════════════════════════════════════════════════════════════
# METRICS
# ═══════════════════════════════════════════════════════════════
def compute_metrics(pred_logits, target, num_classes=5, *, class_names=None,
                    empty_value=1.0):
    """Compute per-class and mean IoU / Dice for one batch of predictions.

    Args:
        pred_logits: logits with the class axis at dim 1; the prediction is
            their argmax over dim 1.
        target: integer label tensor with the same shape as that argmax.
        num_classes: class ids ``0..num_classes-1`` are scored.
        class_names: keys for the per-class dicts, in class-id order. Defaults
            to ``Config.CLASS_NAMES``. Classes beyond the supplied names are
            keyed ``"class_<id>"`` instead of being silently dropped.
        empty_value: score given to a class absent from both prediction and
            target. The default 1.0 matches the previous smoothed formula.
            Pass ``float("nan")`` to exclude such classes from the means,
            which are then taken with ``np.nanmean``.

    Returns:
        dict with ``mean_iou``, ``mean_dice``, ``per_class_iou`` and
        ``per_class_dice``.
    """
    names = list(Config.CLASS_NAMES if class_names is None else class_names)
    names += [f"class_{c}" for c in range(len(names), num_classes)]

    pred = torch.argmax(pred_logits, dim=1)
    ious, dices = [], []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        intersection = (pred_c & target_c).float().sum()
        union = (pred_c | target_c).float().sum()
        if union.item() == 0:
            # Nothing to measure for this class in this batch; use the convention.
            ious.append(empty_value)
            dices.append(empty_value)
            continue
        iou = (intersection + 1e-6) / (union + 1e-6)
        dice = (2 * intersection + 1e-6) / (pred_c.float().sum() + target_c.float().sum() + 1e-6)
        ious.append(iou.item())
        dices.append(dice.item())

    reduce = np.nanmean if np.isnan(empty_value) else np.mean
    return {
        "mean_iou": reduce(ious),
        "mean_dice": reduce(dices),
        "per_class_iou": dict(zip(names, ious)),
        "per_class_dice": dict(zip(names, dices)),
    }
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
def _log_class_distribution(dataset, num_classes, class_names, max_samples=200):
    """Print each class's share of mask pixels over the first ``max_samples`` items.

    Reporting only: the counts are not fed into the loss. Samples are drawn via
    ``dataset[i]``, so any random training transform is applied to them.
    """
    class_pixels = np.zeros(num_classes, dtype=np.float64)
    for i in range(min(len(dataset), max_samples)):
        _, mask = dataset[i]
        labels = torch.as_tensor(mask).long().flatten()
        # One bincount per mask; ids >= num_classes are dropped by the slice.
        class_pixels += torch.bincount(labels, minlength=num_classes)[:num_classes].numpy()
    total = class_pixels.sum()
    print("Class distribution:")
    for i, name in enumerate(class_names):
        share = class_pixels[i] / total * 100 if total > 0 else 0.0
        print(f" {name}: {share:.2f}% ({int(class_pixels[i]):,} px)")


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp, grad_clip):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for images, masks in loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True).long()  # DiceLoss one-hot needs int64
        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, masks)
        scaler.scale(loss).backward()
        # Unscale before clipping so the norm is measured on the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    return running_loss / len(loader)


def _validate(model, loader, criterion, device, use_amp, num_classes, class_names):
    """Evaluate ``model`` on ``loader`` in a single pass.

    Returns ``(mean_loss, mean_dice, mean_iou, per_class_dice)``. Each value
    is an average of per-batch results, and all are plain Python floats.
    """
    model.eval()
    total_loss = 0.0
    batch_dice, batch_iou = [], []
    per_class = {name: [] for name in class_names}
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True).long()
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                logits = model(images)
                loss = criterion(logits, masks)
            total_loss += loss.item()
            metrics = compute_metrics(logits, masks, num_classes)
            batch_dice.append(metrics["mean_dice"])
            batch_iou.append(metrics["mean_iou"])
            for name in class_names:
                per_class[name].append(metrics["per_class_dice"][name])
    # Python floats (not NumPy scalars) keep checkpoints loadable with
    # torch.load(weights_only=True).
    per_class_mean = {name: float(np.mean(scores)) for name, scores in per_class.items()}
    return (
        total_loss / len(loader),
        float(np.mean(batch_dice)),
        float(np.mean(batch_iou)),
        per_class_mean,
    )


def _checkpoint_metadata(cfg):
    """Return the model/config fields stored with the best and final checkpoints."""
    return {
        "architecture": cfg.ARCHITECTURE,
        "encoder": cfg.ENCODER,
        "num_classes": cfg.NUM_CLASSES,
        "img_size": cfg.IMG_SIZE,
        "class_names": cfg.CLASS_NAMES,
        "label_remap": LABEL_REMAP.tolist(),
    }


def _push_to_hub(cfg, best_val_dice):
    """Upload ``cfg.OUTPUT_DIR`` to ``cfg.HUB_MODEL_ID``; failures are printed, not raised."""
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(cfg.HUB_MODEL_ID, exist_ok=True)
        api.upload_folder(
            folder_path=cfg.OUTPUT_DIR,
            repo_id=cfg.HUB_MODEL_ID,
            commit_message=f"Trained model (dice={best_val_dice:.4f})",
        )
        print(f"Pushed to hub: {cfg.HUB_MODEL_ID}")
    except Exception as e:  # best-effort: all artefacts are already on disk
        print(f"Hub push failed: {e}")


def train():
    """Download E-SCDD, train the configured segmentation model and save the results.

    Writes ``best_model.pth`` (highest mean validation Dice), a
    ``checkpoint_ep<N>.pth`` every 25 epochs, ``final_model.pth`` and
    ``history.json`` to ``Config.OUTPUT_DIR``. Optionally uploads that folder
    to the Hugging Face Hub.

    Raises:
        RuntimeError: if the training or validation split yields no batches.
    """
    cfg = Config()
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Autocast/GradScaler are only used on CUDA; on CPU they would warn and no-op.
    use_amp = cfg.USE_AMP and device.type == "cuda"
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ── Download data ────────────────────────────────────────
    download_dataset(cfg.DATA_DIR)

    # ── Create datasets ──────────────────────────────────────
    print("\nLoading datasets...")
    train_ds = ESCDDDataset(
        os.path.join(cfg.DATA_DIR, "el_images_train"),
        os.path.join(cfg.DATA_DIR, "el_masks_train"),
        transform=get_train_transforms(cfg.IMG_SIZE),
    )
    val_ds = ESCDDDataset(
        os.path.join(cfg.DATA_DIR, "el_images_val"),
        os.path.join(cfg.DATA_DIR, "el_masks_val"),
        transform=get_val_transforms(cfg.IMG_SIZE),
    )
    pin = device.type == "cuda"  # pinned host memory only helps host->GPU copies
    train_loader = DataLoader(
        train_ds, batch_size=cfg.BATCH_SIZE, shuffle=True,
        num_workers=cfg.NUM_WORKERS, pin_memory=pin, drop_last=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.BATCH_SIZE, shuffle=False,
        num_workers=cfg.NUM_WORKERS, pin_memory=pin,
    )
    # Fail early with a clear message instead of a ZeroDivisionError in the loss averages.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise RuntimeError(
            f"Empty data split: {len(train_ds)} train / {len(val_ds)} val pairs "
            f"(batch size {cfg.BATCH_SIZE}, drop_last on train). Check {cfg.DATA_DIR}."
        )

    # ── Report class distribution (informational only) ───────
    print("\nComputing class distribution...")
    _log_class_distribution(train_ds, cfg.NUM_CLASSES, cfg.CLASS_NAMES)

    # ── Create model ─────────────────────────────────────────
    print(f"\nCreating {cfg.ARCHITECTURE} + {cfg.ENCODER}...")
    ModelClass = getattr(smp, cfg.ARCHITECTURE)
    model = ModelClass(
        encoder_name=cfg.ENCODER,
        encoder_weights=cfg.ENCODER_WEIGHTS,
        in_channels=cfg.IN_CHANNELS,
        classes=cfg.NUM_CLASSES,
        decoder_attention_type="scse",
    ).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {total_params:,}")

    # ── Loss: Dice + Focal (handles class imbalance) ─────────
    dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True, smooth=1.0)
    focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=cfg.FOCAL_GAMMA)

    def criterion(pred, target):
        return cfg.DICE_WEIGHT * dice_loss(pred, target) + cfg.FOCAL_WEIGHT * focal_loss(pred, target)

    # ── Optimizer with differential LR ───────────────────────
    optimizer = AdamW([
        {"params": model.encoder.parameters(), "lr": cfg.ENCODER_LR},
        {"params": model.decoder.parameters(), "lr": cfg.DECODER_LR},
        {"params": model.segmentation_head.parameters(), "lr": cfg.DECODER_LR},
    ], weight_decay=cfg.WEIGHT_DECAY)
    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.NUM_EPOCHS, eta_min=1e-6)
    scaler = torch.amp.GradScaler(enabled=use_amp)

    # ── Training loop ────────────────────────────────────────
    best_val_dice = 0.0
    history = {"train_loss": [], "val_loss": [], "val_dice": [], "val_iou": []}
    print(f"\n{'='*60}")
    print(f"Starting training: {cfg.NUM_EPOCHS} epochs")
    print(f"{'='*60}\n")
    for epoch in range(cfg.NUM_EPOCHS):
        t_start = time.time()
        train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, use_amp, cfg.GRADIENT_CLIP,
        )
        scheduler.step()
        val_loss, val_dice, val_iou, per_class_dice = _validate(
            model, val_loader, criterion, device, use_amp, cfg.NUM_CLASSES, cfg.CLASS_NAMES,
        )
        t_elapsed = time.time() - t_start
        lr_enc = optimizer.param_groups[0]["lr"]
        lr_dec = optimizer.param_groups[1]["lr"]
        print(f"Epoch {epoch+1:3d}/{cfg.NUM_EPOCHS} | "
              f"train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | "
              f"val_dice={val_dice:.4f} | val_iou={val_iou:.4f} | "
              f"lr_enc={lr_enc:.6f} | lr_dec={lr_dec:.6f} | {t_elapsed:.1f}s")

        # Per-class Dice every 10 epochs, reusing the validation pass above.
        if (epoch + 1) % 10 == 0:
            print(" Per-class Dice:")
            for name in cfg.CLASS_NAMES:
                print(f" {name:20s}: {per_class_dice[name]:.4f}")

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_dice"].append(val_dice)
        history["val_iou"].append(val_iou)

        # ── Save best model ──────────────────────────────────
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_dice": val_dice,
                "val_iou": val_iou,
                **_checkpoint_metadata(cfg),
            }, os.path.join(cfg.OUTPUT_DIR, "best_model.pth"))
            # ASCII arrow: non-ASCII output can fail on Windows when stdout is redirected.
            print(f" -> Best model saved (dice={val_dice:.4f})")

        # Periodic checkpoint every 25 epochs
        if (epoch + 1) % 25 == 0:
            ckpt_path = os.path.join(cfg.OUTPUT_DIR, f"checkpoint_ep{epoch+1}.pth")
            torch.save({"epoch": epoch + 1, "model_state_dict": model.state_dict()}, ckpt_path)

    # ── Save final model + history ───────────────────────────
    final_path = os.path.join(cfg.OUTPUT_DIR, "final_model.pth")
    torch.save({
        "epoch": cfg.NUM_EPOCHS,
        "model_state_dict": model.state_dict(),
        # None when NUM_EPOCHS == 0 (no epochs were run).
        "val_dice": history["val_dice"][-1] if history["val_dice"] else None,
        "val_iou": history["val_iou"][-1] if history["val_iou"] else None,
        **_checkpoint_metadata(cfg),
        "history": history,
    }, final_path)
    with open(os.path.join(cfg.OUTPUT_DIR, "history.json"), "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    print(f"\n{'='*60}")
    print(f"Training complete! Best val dice: {best_val_dice:.4f}")
    print(f"Models saved to {cfg.OUTPUT_DIR}/")
    print(f"{'='*60}")

    # ── Push to Hub ──────────────────────────────────────────
    if cfg.PUSH_TO_HUB and cfg.HUB_MODEL_ID:
        _push_to_hub(cfg, best_val_dice)


if __name__ == "__main__":
    train()