# el-defect-detection / src / datasets / el_dataset.py
# nithishbasireddy's picture
# Upload src/datasets/el_dataset.py with huggingface_hub
# ec6198e verified
"""
Dataset class for E-SCDD solar cell EL images.
Design decisions:
- Loads raw PNG files (not HF datasets API, which has loading issues with this dataset)
- Remaps 30 E-SCDD classes → 5 target classes using ESCDDDataset.LABEL_MAP
- Supports both augmented (synthetic) and original images
- Handles variable image sizes via resize to consistent input size
- Grayscale only: EL images are single-channel
The dataset has matched image/mask pairs:
el_images_train/ARTS_00001_r4_c2.png ↔ el_masks_train/ARTS_00001_r4_c2.png
"""
import os
import numpy as np
from pathlib import Path
from PIL import Image
from typing import Optional, Tuple, List, Callable
import torch
from torch.utils.data import Dataset
class ESCDDDataset(Dataset):
    """
    E-SCDD dataset loader.

    Loads grayscale EL images and indexed PNG segmentation masks, pairs them
    by file name, and remaps the 30 original classes → 5 target classes.

    Each item is an ``(image, mask)`` pair of tensors: the image is float32
    and the mask is int64 (class ids 0-4). Without a transform the image is
    resized to ``img_size`` x ``img_size``, scaled to [0, 1] and returned
    with shape (1, H, W); with a transform, value range and shape are
    whatever the transform produces.
    """

    # E-SCDD → target class mapping
    # 0=background, 1=dark, 2=crack, 3=cross, 4=busbar
    LABEL_MAP = {
        0: 0,   # bckgnd → background
        1: 0,   # sp multi → background
        2: 0,   # sp mono → background
        3: 0,   # sp dogbone → background
        4: 0,   # ribbons → background (feature, not defect)
        5: 0,   # border → background
        6: 0,   # text → background
        7: 0,   # padding → background
        8: 0,   # clamp → background
        9: 4,   # busbars → busbar
        10: 3,  # crack rbn edge → cross
        11: 1,  # inactive → dark
        12: 0,  # rings → background (too rare)
        13: 0,  # material → background
        14: 2,  # crack → crack
        15: 0,  # gridline → background
        16: 0,  # splice → background
        17: 1,  # dead cell → dark (similar appearance)
        18: 0,  # corrosion rbn → background
        19: 0,  # belt mark → background
        20: 1,  # edge dark → dark (similar appearance)
        21: 0,  # frame edge → background
        22: 0,  # jbox → background
        23: 0,  # meas artifact → background
        24: 0,  # sp mono halfcut → background
        25: 0,  # scuff → background
        26: 0,  # corrosion cell → background
        27: 0,  # brightening → background
        28: 0,  # star → background
        29: 0,  # sp multi halfcut → background
    }

    # E-SCDD color → original (30-class) label mapping, used for RGB masks.
    # Kept at class level so it is built once rather than on every item.
    _COLOR_TO_LABEL = {
        (0, 0, 0): 0,
        (128, 128, 128): 1,
        (80, 80, 80): 2,
        (0, 0, 255): 3,
        (0, 255, 0): 4,
        (100, 50, 50): 5,
        (225, 0, 100): 6,
        (128, 128, 0): 7,
        (255, 215, 0): 8,
        (50, 50, 255): 9,
        (0, 255, 255): 10,
        (255, 0, 0): 11,
        (255, 0, 255): 12,
        (255, 255, 0): 13,
        (255, 255, 255): 14,
        (255, 165, 0): 15,
        (75, 0, 130): 16,
        (32, 32, 32): 17,
        (0, 150, 0): 18,
        (218, 165, 32): 19,
        (184, 134, 11): 20,
        (127, 255, 215): 21,
        (45, 45, 255): 22,
        (50, 50, 50): 23,
        (100, 100, 100): 24,
        (200, 200, 0): 25,
        (0, 100, 0): 26,
        (192, 192, 192): 27,
        (200, 0, 200): 28,
        (135, 206, 235): 29,
    }

    def __init__(
        self,
        image_dir: str,
        mask_dir: str,
        transform: Optional[Callable] = None,
        img_size: int = 512,
    ):
        """
        Args:
            image_dir: Path to el_images_train/ or el_images_test/
            mask_dir: Path to el_masks_train/ or el_masks_test/
            transform: albumentations transform (applied to both image and
                mask); called as ``transform(image=..., mask=...)`` with a
                uint8 image and a uint8 mask that is already remapped
            img_size: Target size for resize (square); only used when
                ``transform`` is None

        Raises:
            ValueError: If no ``*.png`` image has a same-named mask.
        """
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        self.img_size = img_size

        # Sorted so that the index → file assignment is deterministic.
        self.image_files = sorted(self.image_dir.glob("*.png"))

        # Build mask lookup for fast matching
        mask_files = {f.name: f for f in self.mask_dir.glob("*.png")}

        # Only keep images that have matching masks; others are skipped.
        self.pairs = [
            (img_file, mask_files[img_file.name])
            for img_file in self.image_files
            if img_file.name in mask_files
        ]

        if not self.pairs:
            raise ValueError(
                f"No matching image-mask pairs found in "
                f"{image_dir} and {mask_dir}"
            )
        print(f"ESCDDDataset: Found {len(self.pairs)} image-mask pairs")

    def remap_mask(self, mask: np.ndarray) -> np.ndarray:
        """
        Remap E-SCDD 30-class mask to 5-class target mask.

        Values that are not keys of ``LABEL_MAP`` become 0 (background).
        The result is always a uint8 array with the same shape as ``mask``.
        """
        mask = np.asarray(mask)

        if mask.dtype == np.uint8:
            # Single pass through a 256-entry lookup table instead of one
            # full-image comparison per source label.
            lut = np.zeros(256, dtype=np.uint8)
            for src_label, dst_label in self.LABEL_MAP.items():
                if 0 <= src_label < 256:
                    lut[src_label] = dst_label
            return lut[mask]

        # Other dtypes: compare label by label so arbitrary values are safe.
        output = np.zeros(mask.shape, dtype=np.uint8)
        for src_label, dst_label in self.LABEL_MAP.items():
            output[mask == src_label] = dst_label
        return output

    def __len__(self) -> int:
        """Return the number of matched image-mask pairs."""
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load, remap and (optionally) augment the pair at position ``idx``.

        Returns:
            ``(image, mask)`` as a float32 tensor and an int64 tensor.
        """
        img_path, mask_path = self.pairs[idx]

        # Load grayscale image. Kept as uint8: albumentations interprets
        # float32 images as already scaled to [0, 1], so 0-255 data has to
        # be handed over as uint8 for its transforms to behave correctly.
        with Image.open(img_path) as img_pil:
            img = np.array(img_pil.convert("L"), dtype=np.uint8)

        # Load the mask and remap to 5 classes
        mask = self.remap_mask(self._load_mask(mask_path))

        if self.transform is not None:
            # Apply augmentations
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]
            mask = augmented["mask"]
        else:
            # Default: resize and normalize to [0, 1]
            size = (self.img_size, self.img_size)
            img_resized = Image.fromarray(img).resize(size, Image.BILINEAR)
            img = np.array(img_resized, dtype=np.float32) / 255.0
            # NEAREST keeps mask values valid class ids (no interpolation).
            mask_resized = Image.fromarray(mask).resize(size, Image.NEAREST)
            mask = np.array(mask_resized, dtype=np.int64)

        return self._to_tensors(img, mask)

    def _load_mask(self, mask_path: Path) -> np.ndarray:
        """Read a mask file and return original E-SCDD label ids as uint8."""
        with Image.open(mask_path) as mask_pil:
            if mask_pil.mode == "P":
                # Palette image: the raw array holds the palette indices
                return np.array(mask_pil, dtype=np.uint8)
            if mask_pil.mode in ("RGB", "RGBA"):
                # Color mask: map colors to labels. Alpha is dropped so an
                # RGBA file is decoded by color rather than by luminance.
                return self._rgb_to_label(np.array(mask_pil.convert("RGB")))
            return np.array(mask_pil.convert("L"), dtype=np.uint8)

    @staticmethod
    def _to_tensors(img, mask) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Coerce an image/mask pair to (float32 tensor, int64 tensor).

        Accepts numpy arrays or tensors for either argument, so the result
        has the right dtypes whether or not the transform already produced
        tensors. A 2-D image gets a leading channel axis.
        """
        if isinstance(img, np.ndarray):
            # torch.from_numpy rejects negatively-strided arrays (e.g. the
            # views produced by flips), so make the data contiguous first.
            img = torch.from_numpy(np.ascontiguousarray(img))
        img = img.float()
        if img.ndim == 2:
            img = img.unsqueeze(0)  # (1, H, W)

        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        mask = mask.long()
        return img, mask

    def _rgb_to_label(self, rgb_mask: np.ndarray) -> np.ndarray:
        """
        Convert RGB color mask to label indices using E-SCDD color codes.

        Args:
            rgb_mask: Array of shape (H, W, 3) with 0-255 channel values;
                channels beyond the first three are ignored.

        Returns:
            uint8 array of shape (H, W). Colors that are not in the E-SCDD
            table map to 0.
        """
        channels = np.asarray(rgb_mask)[..., :3].astype(np.uint32)
        # Pack each pixel into one integer so every color needs a single
        # comparison pass instead of a three-channel compare-and-reduce.
        codes = (
            (channels[..., 0] << 16) | (channels[..., 1] << 8) | channels[..., 2]
        )
        label_mask = np.zeros(codes.shape, dtype=np.uint8)
        for (red, green, blue), label in self._COLOR_TO_LABEL.items():
            label_mask[codes == ((red << 16) | (green << 8) | blue)] = label
        return label_mask
def get_train_transforms(img_size: int = 512):
    """
    Training augmentations for EL images.

    Key augmentations:
    - RandomRotate90 + Flips: EL defects can appear at any orientation
    - RandomBrightnessContrast: simulate exposure variation
    - GaussNoise: simulate sensor noise
    - ElasticTransform: simulate crack deformation patterns
    - CLAHE applied separately in preprocessing, not here

    Args:
        img_size: Side length passed to ``A.Resize`` for height and width.

    Returns:
        An ``albumentations.Compose`` pipeline ending in ``ToTensorV2``.

    Raises:
        ImportError: If albumentations is not installed.
    """
    # Imported lazily so the module can be used without albumentations.
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2
    except ImportError as exc:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "albumentations required: pip install albumentations"
        ) from exc

    return A.Compose([
        A.Resize(img_size, img_size),
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(
            brightness_limit=0.2, contrast_limit=0.2, p=0.5
        ),
        A.GaussNoise(p=0.3),
        A.ElasticTransform(alpha=30, sigma=5, p=0.3),
        # Single channel, mean=0.5 / std=0.5: with Normalize's default
        # max_pixel_value this maps 0-255 input to [-1, 1] (not [0, 1]).
        A.Normalize(mean=[0.5], std=[0.5]),
        ToTensorV2(),
    ])
def get_val_transforms(img_size: int = 512):
    """
    Validation transforms: resize + normalize only.

    Args:
        img_size: Side length passed to ``A.Resize`` for height and width.

    Returns:
        An ``albumentations.Compose`` pipeline ending in ``ToTensorV2``.

    Raises:
        ImportError: If albumentations is not installed.
    """
    # Imported lazily so the module can be used without albumentations.
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2
    except ImportError as exc:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "albumentations required: pip install albumentations"
        ) from exc

    return A.Compose([
        A.Resize(img_size, img_size),
        # Same single-channel mean/std as the training pipeline's Normalize.
        A.Normalize(mean=[0.5], std=[0.5]),
        ToTensorV2(),
    ])