""" Image dataset for Electrical Outlets. FINAL v5: Direct folder_to_class mapping — no pattern matching, no ambiguity. """ from pathlib import Path import json import logging from collections import defaultdict from typing import Optional, Callable, List, Tuple import torch from torch.utils.data import Dataset from PIL import Image logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format="%(message)s") class ElectricalOutletsImageDataset(Dataset): def __init__( self, root: Path, label_mapping_path: Path, split: str = "train", train_ratio: float = 0.7, val_ratio: float = 0.15, seed: int = 42, transform: Optional[Callable] = None, extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"), ): self.root = Path(root) self.transform = transform self.extensions = extensions self.split = split with open(label_mapping_path) as f: lm = json.load(f) self.folder_to_class = lm["image"]["folder_to_class"] self.class_to_idx = lm["image"]["class_to_idx"] self.idx_to_issue_type = lm["image"]["idx_to_issue_type"] self.idx_to_severity = lm["image"]["idx_to_severity"] self.num_classes = len(self.class_to_idx) # Build samples list self.samples: List[Tuple[Path, int]] = [] class_counts = defaultdict(int) matched_folders = [] unmatched_folders = [] for folder in sorted(self.root.iterdir()): if not folder.is_dir(): continue # Direct lookup by exact folder name class_key = self.folder_to_class.get(folder.name) if class_key is None: unmatched_folders.append(folder.name) continue cls_idx = self.class_to_idx[class_key] count = 0 for f in folder.iterdir(): if f.suffix.lower() in self.extensions: self.samples.append((f, cls_idx)) count += 1 class_counts[cls_idx] += count matched_folders.append(f" ✓ {folder.name} → {class_key} (idx={cls_idx}): {count} images") # Log results logger.info(f"\n{'='*60}") logger.info(f"Dataset loading from: {self.root}") logger.info(f"{'='*60}") for line in matched_folders: logger.info(line) for uf in unmatched_folders: logger.warning(f" ✗ SKIPPED: '{uf}' (not in folder_to_class)") logger.info(f"\nClass distribution:") for idx in sorted(class_counts.keys()): name = [k for k, v in self.class_to_idx.items() if v == idx][0] logger.info(f" Class {idx} ({name}): {class_counts[idx]} images") logger.info(f"Total: {len(self.samples)} images in {self.num_classes} classes") if len(self.samples) == 0: logger.error("NO SAMPLES FOUND! Check that data_root points to the folder containing your class subfolders.") raise ValueError(f"No images found in {self.root}. Check folder names match label_mapping.json folder_to_class keys.") # Stratified split by_class = defaultdict(list) for i, (_, cls) in enumerate(self.samples): by_class[cls].append(i) train_idx, val_idx, test_idx = [], [], [] for cls in sorted(by_class.keys()): indices = by_class[cls] g = torch.Generator().manual_seed(seed) perm = torch.randperm(len(indices), generator=g).tolist() n_cls = len(indices) n_tr = int(n_cls * train_ratio) n_va = int(n_cls * val_ratio) train_idx.extend([indices[p] for p in perm[:n_tr]]) val_idx.extend([indices[p] for p in perm[n_tr:n_tr + n_va]]) test_idx.extend([indices[p] for p in perm[n_tr + n_va:]]) if split == "train": self.indices = train_idx elif split == "val": self.indices = val_idx else: self.indices = test_idx logger.info(f"Split '{split}': {len(self.indices)} samples\n") def __len__(self): return len(self.indices) def __getitem__(self, idx): i = self.indices[idx] path, cls = self.samples[i] img = Image.open(path).convert("RGB") if self.transform: img = self.transform(img) return img, cls def get_issue_type(self, class_idx: int) -> str: return self.idx_to_issue_type[class_idx] def get_severity(self, class_idx: int) -> str: return self.idx_to_severity[class_idx] def get_image_class_weights(label_mapping_path: Path, root: Path) -> torch.Tensor: """Compute inverse frequency weights for class-weighted loss.""" with open(label_mapping_path) as f: lm = json.load(f) folder_to_class = lm["image"]["folder_to_class"] class_to_idx = lm["image"]["class_to_idx"] num_classes = len(class_to_idx) counts = [0] * num_classes root = Path(root) for folder in root.iterdir(): if not folder.is_dir(): continue class_key = folder_to_class.get(folder.name) if class_key is None: continue cls_idx = class_to_idx[class_key] n = sum(1 for f in folder.iterdir() if f.suffix.lower() in (".jpg", ".jpeg", ".png")) counts[cls_idx] += n total = sum(counts) if total == 0: return torch.ones(num_classes) weights = [total / (num_classes * c) if c else 1.0 for c in counts] return torch.tensor(weights, dtype=torch.float32)