| """ |
| Dataset class for E-SCDD solar cell EL images. |
| |
| Design decisions: |
| - Loads raw PNG files (not HF datasets API, which has loading issues with this dataset) |
| - Remaps 30 E-SCDD classes → 5 target classes using ESCDDDataset.LABEL_MAP |
| - Supports both augmented (synthetic) and original images |
| - Handles variable image sizes via resize to consistent input size |
| - Grayscale only: EL images are single-channel |
| |
| The dataset has matched image/mask pairs: |
| el_images_train/ARTS_00001_r4_c2.png ↔ el_masks_train/ARTS_00001_r4_c2.png |
| """ |
|
|
| import os |
| import numpy as np |
| from pathlib import Path |
| from PIL import Image |
| from typing import Optional, Tuple, List, Callable |
| import torch |
| from torch.utils.data import Dataset |
|
|
|
|
class ESCDDDataset(Dataset):
    """
    E-SCDD dataset loader.

    Loads grayscale EL images and their segmentation masks, pairs them by
    identical file name, and remaps the 30 original classes → 5 target
    classes via ``LABEL_MAP``.

    Each item is an ``(image, mask)`` tuple of tensors:

    - image: float32. Without a transform it has shape
      ``(1, img_size, img_size)`` and values in [0, 1]; with a transform the
      value range and shape are whatever the transform produces (a 2-D result
      gets a leading channel axis).
    - mask: int64 target-class ids.
    """

    # Source label id → target class id. Ids not listed here map to 0.
    LABEL_MAP = {
        0: 0,
        1: 0,
        2: 0,
        3: 0,
        4: 0,
        5: 0,
        6: 0,
        7: 0,
        8: 0,
        9: 4,
        10: 3,
        11: 1,
        12: 0,
        13: 0,
        14: 2,
        15: 0,
        16: 0,
        17: 1,
        18: 0,
        19: 0,
        20: 1,
        21: 0,
        22: 0,
        23: 0,
        24: 0,
        25: 0,
        26: 0,
        27: 0,
        28: 0,
        29: 0,
    }

    # RGB colour → source label id, used when a mask is stored as an RGB
    # image. Built once at class level instead of on every call.
    COLOR_TO_LABEL = {
        (0, 0, 0): 0,
        (128, 128, 128): 1,
        (80, 80, 80): 2,
        (0, 0, 255): 3,
        (0, 255, 0): 4,
        (100, 50, 50): 5,
        (225, 0, 100): 6,
        (128, 128, 0): 7,
        (255, 215, 0): 8,
        (50, 50, 255): 9,
        (0, 255, 255): 10,
        (255, 0, 0): 11,
        (255, 0, 255): 12,
        (255, 255, 0): 13,
        (255, 255, 255): 14,
        (255, 165, 0): 15,
        (75, 0, 130): 16,
        (32, 32, 32): 17,
        (0, 150, 0): 18,
        (218, 165, 32): 19,
        (184, 134, 11): 20,
        (127, 255, 215): 21,
        (45, 45, 255): 22,
        (50, 50, 50): 23,
        (100, 100, 100): 24,
        (200, 200, 0): 25,
        (0, 100, 0): 26,
        (192, 192, 192): 27,
        (200, 0, 200): 28,
        (135, 206, 235): 29,
    }

    def __init__(
        self,
        image_dir: str,
        mask_dir: str,
        transform: Optional[Callable] = None,
        img_size: int = 512,
    ):
        """
        Index all ``*.png`` images that have a same-named mask.

        Args:
            image_dir: Path to el_images_train/ or el_images_test/
            mask_dir: Path to el_masks_train/ or el_masks_test/
            transform: albumentations-style callable invoked as
                ``transform(image=..., mask=...)`` and returning a mapping
                with ``"image"`` and ``"mask"`` keys
            img_size: Target size for resize (square); only used when
                ``transform`` is None

        Raises:
            ValueError: If no image has a mask with the same file name.
        """
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        self.img_size = img_size

        # Sorted so that indices are stable across runs and machines.
        self.image_files = sorted(self.image_dir.glob("*.png"))

        masks_by_name = {f.name: f for f in self.mask_dir.glob("*.png")}

        # Images without a same-named mask are left out of the dataset.
        self.pairs = [
            (img_file, masks_by_name[img_file.name])
            for img_file in self.image_files
            if img_file.name in masks_by_name
        ]

        if not self.pairs:
            raise ValueError(
                f"No matching image-mask pairs found in "
                f"{image_dir} and {mask_dir}"
            )

        print(f"ESCDDDataset: Found {len(self.pairs)} image-mask pairs")

    def remap_mask(self, mask: np.ndarray) -> np.ndarray:
        """
        Remap E-SCDD 30-class mask to 5-class target mask.

        Args:
            mask: Array of source label ids.

        Returns:
            uint8 array of the same shape holding target class ids; any
            source id missing from ``LABEL_MAP`` becomes 0.
        """
        mask = np.asarray(mask)

        if mask.dtype == np.uint8:
            # One gather through a 256-entry table instead of one full-image
            # comparison per label.
            lut = np.zeros(256, dtype=np.uint8)
            for src_label, dst_label in self.LABEL_MAP.items():
                if 0 <= src_label < 256:
                    lut[src_label] = dst_label
            return lut[mask]

        # Generic path for other dtypes (values may fall outside 0..255).
        output = np.zeros(mask.shape, dtype=np.uint8)
        for src_label, dst_label in self.LABEL_MAP.items():
            output[mask == src_label] = dst_label
        return output

    def __len__(self) -> int:
        """Return the number of matched image-mask pairs."""
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load, remap and transform the pair at position ``idx``.

        Returns:
            ``(image, mask)`` where image is a float32 tensor and mask is an
            int64 tensor, regardless of whether the transform returned NumPy
            arrays or tensors.
        """
        img_path, mask_path = self.pairs[idx]

        img = self._load_image(img_path)
        mask = self.remap_mask(self._load_mask(mask_path))

        if self.transform is not None:
            # The image is handed over as uint8: Albumentations documents
            # float32 inputs as [0, 1]-ranged, so 0-255 floats would be
            # mis-scaled by value-based augmentations.
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]
            mask = augmented["mask"]
        else:
            img, mask = self._resize_pair(img, mask)

        return self._to_tensors(img, mask)

    @staticmethod
    def _load_image(img_path) -> np.ndarray:
        """Read an image file as a 2-D uint8 grayscale array."""
        with Image.open(img_path) as img_pil:
            return np.array(img_pil.convert("L"), dtype=np.uint8)

    def _load_mask(self, mask_path) -> np.ndarray:
        """Read a mask file as a 2-D uint8 array of source label ids."""
        with Image.open(mask_path) as mask_pil:
            if mask_pil.mode == "P":
                # Palette image: the stored pixel values are the indices.
                return np.array(mask_pil, dtype=np.uint8)
            if mask_pil.mode == "RGB":
                return self._rgb_to_label(np.array(mask_pil))
            return np.array(mask_pil.convert("L"), dtype=np.uint8)

    def _resize_pair(
        self, img: np.ndarray, mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Default path: resize to a square and scale the image to [0, 1]."""
        size = (self.img_size, self.img_size)

        img_resized = Image.fromarray(img).resize(size, Image.BILINEAR)
        # NEAREST keeps mask values valid class ids (no interpolated labels).
        mask_resized = Image.fromarray(mask).resize(size, Image.NEAREST)

        img = np.array(img_resized, dtype=np.float32) / 255.0
        mask = np.array(mask_resized, dtype=np.int64)
        return img, mask

    @staticmethod
    def _to_tensors(img, mask) -> Tuple[torch.Tensor, torch.Tensor]:
        """Coerce image/mask (ndarray or tensor) to float32 / int64 tensors."""
        if isinstance(img, np.ndarray):
            # from_numpy rejects negative strides; make the buffer contiguous.
            img = torch.from_numpy(np.ascontiguousarray(img))
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        img = img.float()
        if img.ndim == 2:
            # Add the channel axis for single-channel images.
            img = img.unsqueeze(0)

        # Always cast: a transform that already produced a tensor may have
        # left the mask as uint8.
        return img, mask.long()

    def _rgb_to_label(self, rgb_mask: np.ndarray) -> np.ndarray:
        """
        Convert RGB color mask to label indices using E-SCDD color codes.

        Args:
            rgb_mask: Array whose last axis holds the R, G, B values.

        Returns:
            uint8 array shaped like the first two axes of ``rgb_mask``;
            pixels whose colour is not in ``COLOR_TO_LABEL`` are 0.
        """
        label_mask = np.zeros(rgb_mask.shape[:2], dtype=np.uint8)
        for color, label in self.COLOR_TO_LABEL.items():
            match = np.all(rgb_mask == np.array(color), axis=-1)
            label_mask[match] = label

        return label_mask
|
|
|
|
def get_train_transforms(img_size: int = 512):
    """
    Training augmentations for EL images.

    Key augmentations:
    - RandomRotate90 + Flips: EL defects can appear at any orientation
    - RandomBrightnessContrast: simulate exposure variation
    - GaussNoise: simulate sensor noise
    - ElasticTransform: simulate crack deformation patterns
    - CLAHE applied separately in preprocessing, not here

    The pipeline ends with Normalize(mean=0.5, std=0.5) and ToTensorV2.

    NOTE(review): Albumentations documents float32 inputs as [0, 1]-ranged;
    confirm the dtype/range of the images fed to this pipeline.

    Args:
        img_size: Side length of the square output.

    Returns:
        An ``albumentations.Compose`` pipeline.

    Raises:
        ImportError: If albumentations is not installed.
    """
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "albumentations required: pip install albumentations"
        ) from err

    pipeline = [
        A.Resize(img_size, img_size),
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(
            brightness_limit=0.2, contrast_limit=0.2, p=0.5
        ),
        A.GaussNoise(p=0.3),
        A.ElasticTransform(alpha=30, sigma=5, p=0.3),
        A.Normalize(mean=[0.5], std=[0.5]),
        ToTensorV2(),
    ]
    return A.Compose(pipeline)
|
|
|
|
def get_val_transforms(img_size: int = 512):
    """
    Validation transforms: resize + normalize only.

    Args:
        img_size: Side length of the square output.

    Returns:
        An ``albumentations.Compose`` pipeline of Resize,
        Normalize(mean=0.5, std=0.5) and ToTensorV2.

    Raises:
        ImportError: If albumentations is not installed.
    """
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "albumentations required: pip install albumentations"
        ) from err

    pipeline = [
        A.Resize(img_size, img_size),
        A.Normalize(mean=[0.5], std=[0.5]),
        ToTensorV2(),
    ]
    return A.Compose(pipeline)
|
|