"""
Dataset class for E-SCDD solar cell EL images.

Design decisions:
- Loads raw PNG files (not HF datasets API, which has loading issues with
  this dataset)
- Remaps 30 E-SCDD classes → 5 target classes using ESCDD_LABEL_MAP
- Supports both augmented (synthetic) and original images
- Handles variable image sizes via resize to consistent input size
- Grayscale only: EL images are single-channel

The dataset has matched image/mask pairs:
    el_images_train/ARTS_00001_r4_c2.png ↔ el_masks_train/ARTS_00001_r4_c2.png
"""

import os
from pathlib import Path
from typing import Optional, Tuple, List, Callable

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset


class ESCDDDataset(Dataset):
    """
    E-SCDD dataset loader.

    Loads grayscale EL images and indexed PNG segmentation masks.
    Remaps 30 original classes → 5 target classes.

    Each item is a ``(image, mask)`` pair of tensors: ``image`` is float32
    and ``mask`` is int64. Without a transform the image is ``(1, S, S)``
    scaled to [0, 1] and the mask is ``(S, S)``, where ``S = img_size``.
    """

    # E-SCDD → target class mapping
    # 0=background, 1=dark, 2=crack, 3=cross, 4=busbar
    LABEL_MAP = {
        0: 0,   # bckgnd → background
        1: 0,   # sp multi → background
        2: 0,   # sp mono → background
        3: 0,   # sp dogbone → background
        4: 0,   # ribbons → background (feature, not defect)
        5: 0,   # border → background
        6: 0,   # text → background
        7: 0,   # padding → background
        8: 0,   # clamp → background
        9: 4,   # busbars → busbar
        10: 3,  # crack rbn edge → cross
        11: 1,  # inactive → dark
        12: 0,  # rings → background (too rare)
        13: 0,  # material → background
        14: 2,  # crack → crack
        15: 0,  # gridline → background
        16: 0,  # splice → background
        17: 1,  # dead cell → dark (similar appearance)
        18: 0,  # corrosion rbn → background
        19: 0,  # belt mark → background
        20: 1,  # edge dark → dark (similar appearance)
        21: 0,  # frame edge → background
        22: 0,  # jbox → background
        23: 0,  # meas artifact → background
        24: 0,  # sp mono halfcut → background
        25: 0,  # scuff → background
        26: 0,  # corrosion cell → background
        27: 0,  # brightening → background
        28: 0,  # star → background
        29: 0,  # sp multi halfcut → background
    }

    # E-SCDD color → original (30-class) label mapping, used for RGB masks.
    # Kept at class level so it is not rebuilt for every sample.
    COLOR_TO_LABEL = {
        (0, 0, 0): 0,
        (128, 128, 128): 1,
        (80, 80, 80): 2,
        (0, 0, 255): 3,
        (0, 255, 0): 4,
        (100, 50, 50): 5,
        (225, 0, 100): 6,
        (128, 128, 0): 7,
        (255, 215, 0): 8,
        (50, 50, 255): 9,
        (0, 255, 255): 10,
        (255, 0, 0): 11,
        (255, 0, 255): 12,
        (255, 255, 0): 13,
        (255, 255, 255): 14,
        (255, 165, 0): 15,
        (75, 0, 130): 16,
        (32, 32, 32): 17,
        (0, 150, 0): 18,
        (218, 165, 32): 19,
        (184, 134, 11): 20,
        (127, 255, 215): 21,
        (45, 45, 255): 22,
        (50, 50, 50): 23,
        (100, 100, 100): 24,
        (200, 200, 0): 25,
        (0, 100, 0): 26,
        (192, 192, 192): 27,
        (200, 0, 200): 28,
        (135, 206, 235): 29,
    }

    def __init__(
        self,
        image_dir: str,
        mask_dir: str,
        transform: Optional[Callable] = None,
        img_size: int = 512,
    ):
        """
        Args:
            image_dir: Path to el_images_train/ or el_images_test/
            mask_dir: Path to el_masks_train/ or el_masks_test/
            transform: albumentations transform (applied to both image and
                mask). It is called as ``transform(image=..., mask=...)``
                with uint8 ``(H, W)`` arrays and must return a mapping with
                ``"image"`` and ``"mask"`` entries.
            img_size: Target size for resize (square). Only used when
                ``transform`` is None.

        Raises:
            ValueError: If no image has a same-named mask file.
        """
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        self.img_size = img_size

        # Find matching image-mask pairs (sorted for a deterministic order)
        self.image_files = sorted(self.image_dir.glob("*.png"))

        # Build mask lookup for fast matching by file name
        mask_files = {f.name: f for f in self.mask_dir.glob("*.png")}

        # Only keep images that have matching masks
        self.pairs = [
            (img_file, mask_files[img_file.name])
            for img_file in self.image_files
            if img_file.name in mask_files
        ]

        if not self.pairs:
            raise ValueError(
                f"No matching image-mask pairs found in "
                f"{image_dir} and {mask_dir}"
            )

        print(f"ESCDDDataset: Found {len(self.pairs)} image-mask pairs")

    def _label_lut(self) -> np.ndarray:
        """Build a 256-entry uint8 lookup table from ``LABEL_MAP``.

        Entries without a mapping stay 0 (background), matching the
        behaviour of the label-by-label remap.
        """
        lut = np.zeros(256, dtype=np.uint8)
        for src_label, dst_label in self.LABEL_MAP.items():
            # Labels outside the uint8 range can never occur in a uint8 mask.
            if 0 <= src_label < lut.size:
                lut[src_label] = dst_label
        return lut

    def remap_mask(self, mask: np.ndarray) -> np.ndarray:
        """Remap E-SCDD 30-class mask to 5-class target mask.

        Args:
            mask: Array of original E-SCDD label indices.

        Returns:
            uint8 array of the same shape with target labels. Values that
            have no entry in ``LABEL_MAP`` become 0 (background).
        """
        mask = np.asarray(mask)

        if mask.dtype == np.uint8:
            # Single vectorised pass instead of one comparison per class.
            return self._label_lut()[mask]

        # Generic fallback for any other dtype: compare label by label.
        output = np.zeros_like(mask, dtype=np.uint8)
        for src_label, dst_label in self.LABEL_MAP.items():
            output[mask == src_label] = dst_label
        return output

    def __len__(self) -> int:
        """Return the number of matched image-mask pairs."""
        return len(self.pairs)

    def _load_image(self, img_path: Path) -> np.ndarray:
        """Load an EL image as a uint8 ``(H, W)`` grayscale array."""
        with Image.open(img_path) as img_pil:
            return np.array(img_pil.convert("L"), dtype=np.uint8)

    def _load_mask(self, mask_path: Path) -> np.ndarray:
        """Load a mask as a uint8 ``(H, W)`` array of original labels."""
        with Image.open(mask_path) as mask_pil:
            if mask_pil.mode == "P":
                # Palette image: the array holds the palette indices
                return np.array(mask_pil, dtype=np.uint8)
            if mask_pil.mode == "RGB":
                # RGB mask: need to map colors to labels
                return self._rgb_to_label(np.array(mask_pil))
            return np.array(mask_pil.convert("L"), dtype=np.uint8)

    def _default_preprocess(
        self, img: np.ndarray, mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Resize to ``img_size`` and scale the image to [0, 1].

        The mask is resized with nearest-neighbour so that no interpolated
        (non-existent) class indices are introduced.
        """
        size = (self.img_size, self.img_size)
        img_resized = Image.fromarray(img).resize(size, Image.BILINEAR)
        mask_resized = Image.fromarray(mask).resize(size, Image.NEAREST)
        return (
            np.array(img_resized, dtype=np.float32) / 255.0,
            np.array(mask_resized, dtype=np.int64),
        )

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, remap and preprocess the ``idx``-th image-mask pair.

        Returns:
            ``(image, mask)``: a float32 image tensor (a 2-D result gets a
            leading channel axis) and an int64 mask tensor.
        """
        img_path, mask_path = self.pairs[idx]

        # Keep the image as uint8: albumentations interprets float32 images
        # as being in [0, 1], so feeding float32 data in 0..255 makes
        # brightness/noise augmentations clip the image away.
        img = self._load_image(img_path)

        # Remap to 5 classes
        mask = self.remap_mask(self._load_mask(mask_path))

        if self.transform is not None:
            # Apply augmentations
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]
            mask = augmented["mask"]
        else:
            # Default: resize and normalize
            img, mask = self._default_preprocess(img, mask)

        # Ensure correct types whether the transform returned numpy arrays
        # or tensors (e.g. ToTensorV2 yields a uint8 mask tensor).
        img = torch.as_tensor(img).float()
        if img.ndim == 2:
            img = img.unsqueeze(0)  # (1, H, W)
        mask = torch.as_tensor(mask).long()

        return img, mask

    def _rgb_to_label(self, rgb_mask: np.ndarray) -> np.ndarray:
        """Convert RGB color mask to label indices using E-SCDD color codes.

        Args:
            rgb_mask: ``(H, W, 3)`` array of RGB colors.

        Returns:
            uint8 ``(H, W)`` array of original E-SCDD labels. Colors that
            are not in ``COLOR_TO_LABEL`` map to 0.
        """
        label_mask = np.zeros(rgb_mask.shape[:2], dtype=np.uint8)
        for color, label in self.COLOR_TO_LABEL.items():
            match = np.all(rgb_mask == np.array(color), axis=-1)
            label_mask[match] = label
        return label_mask


def _import_albumentations():
    """Import albumentations lazily and return ``(A, ToTensorV2)``.

    Raises:
        ImportError: If albumentations is not installed.
    """
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2
    except ImportError as err:
        raise ImportError(
            "albumentations required: pip install albumentations"
        ) from err
    return A, ToTensorV2


def get_train_transforms(img_size: int = 512):
    """
    Training augmentations for EL images.

    Key augmentations:
    - RandomRotate90 + Flips: EL defects can appear at any orientation
    - RandomBrightnessContrast: simulate exposure variation
    - GaussNoise: simulate sensor noise
    - ElasticTransform: simulate crack deformation patterns
    - CLAHE applied separately in preprocessing, not here

    Args:
        img_size: Side length of the square output.

    Returns:
        An ``albumentations.Compose`` pipeline ending in ``ToTensorV2``.

    Raises:
        ImportError: If albumentations is not installed.
    """
    A, ToTensorV2 = _import_albumentations()

    return A.Compose([
        A.Resize(img_size, img_size),
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(
            brightness_limit=0.2,
            contrast_limit=0.2,
            p=0.5,
        ),
        A.GaussNoise(p=0.3),
        A.ElasticTransform(alpha=30, sigma=5, p=0.3),
        # Single channel: with the default max_pixel_value of 255 this maps
        # 0..255 input to [-1, 1], not [0, 1].
        A.Normalize(mean=[0.5], std=[0.5]),
        ToTensorV2(),
    ])


def get_val_transforms(img_size: int = 512):
    """Validation transforms: resize + normalize only.

    Args:
        img_size: Side length of the square output.

    Returns:
        An ``albumentations.Compose`` pipeline ending in ``ToTensorV2``.

    Raises:
        ImportError: If albumentations is not installed.
    """
    A, ToTensorV2 = _import_albumentations()

    return A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=[0.5], std=[0.5]),
        ToTensorV2(),
    ])