File size: 5,644 Bytes
5666923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Image dataset for Electrical Outlets.
FINAL v5: Direct folder_to_class mapping — no pattern matching, no ambiguity.
"""
from pathlib import Path
import json
import logging
from collections import defaultdict
from typing import Optional, Callable, List, Tuple

import torch
from torch.utils.data import Dataset
from PIL import Image

# Module-level logger used by the dataset to report folder matching and split sizes.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig runs at import time and configures the *root* logger
# (INFO level, bare "%(message)s" format) for the whole process; per the stdlib
# docs it does nothing if the root logger already has handlers. Library modules
# usually leave this to the application entry point — consider moving it there.
logging.basicConfig(level=logging.INFO, format="%(message)s")


class ElectricalOutletsImageDataset(Dataset):
    """Image-classification dataset for electrical-outlet photos.

    Every immediate subdirectory of ``root`` whose exact name is a key of the
    label mapping's ``image.folder_to_class`` table contributes its image files
    to one class; other subdirectories are skipped and logged. The full sample
    list is split per class (stratified) into train / val / test partitions
    using a seeded permutation, and this instance exposes only the partition
    selected by ``split``. Folder and file enumeration is sorted, so the same
    data, ratios and seed always produce the same (disjoint) partitions.

    Args:
        root: Directory containing the class subfolders.
        label_mapping_path: JSON file with an ``"image"`` object holding
            ``folder_to_class``, ``class_to_idx``, ``idx_to_issue_type`` and
            ``idx_to_severity``.
        split: ``"train"``, ``"val"`` or ``"test"``.
        train_ratio: Fraction of each class assigned to train (floored).
        val_ratio: Fraction of each class assigned to val (floored); the rest
            of each class goes to test.
        seed: Seed for the per-class permutation.
        transform: Optional callable applied to each RGB PIL image.
        extensions: File suffixes (compared case-insensitively) treated as images.

    Raises:
        ValueError: If ``split`` is unknown, the ratios are invalid, or no
            images are found under ``root``.
    """

    _SPLITS = ("train", "val", "test")

    def __init__(
        self,
        root: Path,
        label_mapping_path: Path,
        split: str = "train",
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        seed: int = 42,
        transform: Optional[Callable] = None,
        extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"),
    ):
        # Fail fast: an unknown split name used to silently yield the test set.
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
            raise ValueError(
                f"Invalid ratios: train_ratio={train_ratio}, val_ratio={val_ratio} "
                "(both must be >= 0 and their sum <= 1)"
            )

        self.root = Path(root)
        self.transform = transform
        # Suffixes are compared lower-cased, so normalise the allowed set as well.
        self.extensions = tuple(ext.lower() for ext in extensions)
        self.split = split

        with open(label_mapping_path, encoding="utf-8") as f:
            lm = json.load(f)

        self.folder_to_class = lm["image"]["folder_to_class"]
        self.class_to_idx = lm["image"]["class_to_idx"]
        self.idx_to_issue_type = lm["image"]["idx_to_issue_type"]
        self.idx_to_severity = lm["image"]["idx_to_severity"]
        self.num_classes = len(self.class_to_idx)

        # (image path, class index) for every image found, in sorted order.
        self.samples: List[Tuple[Path, int]] = []
        class_counts, matched_folders, unmatched_folders = self._scan_folders()
        self._log_summary(class_counts, matched_folders, unmatched_folders)

        if not self.samples:
            logger.error("NO SAMPLES FOUND! Check that data_root points to the folder containing your class subfolders.")
            raise ValueError(f"No images found in {self.root}. Check folder names match label_mapping.json folder_to_class keys.")

        train_idx, val_idx, test_idx = self._stratified_split(train_ratio, val_ratio, seed)
        # Positions into self.samples that belong to the requested split.
        self.indices = {"train": train_idx, "val": val_idx, "test": test_idx}[split]

        logger.info(f"Split '{split}': {len(self.indices)} samples\n")

    def _scan_folders(self) -> Tuple[defaultdict, List[str], List[str]]:
        """Fill ``self.samples`` from ``self.root``; return counts and log lines."""
        class_counts = defaultdict(int)
        matched_folders: List[str] = []
        unmatched_folders: List[str] = []

        for folder in sorted(self.root.iterdir()):
            if not folder.is_dir():
                continue
            # Direct lookup by exact folder name
            class_key = self.folder_to_class.get(folder.name)
            if class_key is None:
                unmatched_folders.append(folder.name)
                continue
            cls_idx = self.class_to_idx[class_key]
            # Sort files so sample order (and hence the seeded split) does not
            # depend on the filesystem's enumeration order; skip directories
            # that merely have an image-like suffix.
            images = [
                f for f in sorted(folder.iterdir())
                if f.is_file() and f.suffix.lower() in self.extensions
            ]
            self.samples.extend((f, cls_idx) for f in images)
            class_counts[cls_idx] += len(images)
            matched_folders.append(
                f"  ✓ {folder.name} → {class_key} (idx={cls_idx}): {len(images)} images"
            )

        return class_counts, matched_folders, unmatched_folders

    def _log_summary(
        self,
        class_counts: defaultdict,
        matched_folders: List[str],
        unmatched_folders: List[str],
    ) -> None:
        """Log which folders were used/skipped and the per-class image counts."""
        logger.info(f"\n{'='*60}")
        logger.info(f"Dataset loading from: {self.root}")
        logger.info(f"{'='*60}")
        for line in matched_folders:
            logger.info(line)
        for uf in unmatched_folders:
            logger.warning(f"  ✗ SKIPPED: '{uf}' (not in folder_to_class)")
        logger.info("\nClass distribution:")
        # Invert class_to_idx once; keep the first name seen for each index.
        idx_to_class = {}
        for name, idx in self.class_to_idx.items():
            idx_to_class.setdefault(idx, name)
        for idx in sorted(class_counts):
            logger.info(f"  Class {idx} ({idx_to_class[idx]}): {class_counts[idx]} images")
        logger.info(f"Total: {len(self.samples)} images in {self.num_classes} classes")

    def _stratified_split(
        self, train_ratio: float, val_ratio: float, seed: int
    ) -> Tuple[List[int], List[int], List[int]]:
        """Partition sample positions per class into (train, val, test) lists."""
        by_class = defaultdict(list)
        for i, (_, cls) in enumerate(self.samples):
            by_class[cls].append(i)

        train_idx: List[int] = []
        val_idx: List[int] = []
        test_idx: List[int] = []
        for cls in sorted(by_class):
            indices = by_class[cls]
            # Re-seeding per class makes each class's permutation depend only on
            # the seed and the class size, not on which other classes exist.
            g = torch.Generator().manual_seed(seed)
            perm = torch.randperm(len(indices), generator=g).tolist()
            n_cls = len(indices)
            n_tr = int(n_cls * train_ratio)
            n_va = int(n_cls * val_ratio)
            train_idx.extend(indices[p] for p in perm[:n_tr])
            val_idx.extend(indices[p] for p in perm[n_tr:n_tr + n_va])
            test_idx.extend(indices[p] for p in perm[n_tr + n_va:])

        return train_idx, val_idx, test_idx

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the ``idx``-th sample of this split."""
        path, cls = self.samples[self.indices[idx]]
        # Close the file handle promptly; convert() loads the pixel data into a
        # new image, so the result stays valid after the source is closed.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, cls

    @staticmethod
    def _lookup_by_idx(mapping, class_idx: int):
        """Index a list by ``class_idx`` or a JSON dict keyed by its string form."""
        # JSON object keys are always strings, so a dict loaded from
        # label_mapping.json is keyed "0", "1", ... rather than 0, 1, ...
        if isinstance(mapping, dict):
            key = str(class_idx)
            if key in mapping:
                return mapping[key]
        return mapping[class_idx]

    def get_issue_type(self, class_idx: int) -> str:
        """Return the issue type recorded for ``class_idx`` in the label mapping."""
        return self._lookup_by_idx(self.idx_to_issue_type, class_idx)

    def get_severity(self, class_idx: int) -> str:
        """Return the severity recorded for ``class_idx`` in the label mapping."""
        return self._lookup_by_idx(self.idx_to_severity, class_idx)


def get_image_class_weights(
    label_mapping_path: Path,
    root: Path,
    extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"),
) -> torch.Tensor:
    """Compute inverse-frequency class weights for a class-weighted loss.

    Counts image files in every subfolder of ``root`` whose exact name is a key
    of the label mapping's ``image.folder_to_class`` table. Each class with
    ``count > 0`` gets ``total / (num_classes * count)``; classes with no
    images get 1.0, and if no images are found at all every weight is 1.0.
    The counts cover every file on disk, not a particular train/val/test split.

    Args:
        label_mapping_path: JSON file with an ``"image"`` object holding
            ``folder_to_class`` and ``class_to_idx``.
        root: Directory containing the class subfolders.
        extensions: File suffixes (compared case-insensitively) treated as
            images; the default matches the previously hard-coded set.

    Returns:
        Float32 tensor of length ``num_classes`` indexed by class index.
    """
    with open(label_mapping_path, encoding="utf-8") as f:
        image_map = json.load(f)["image"]
    folder_to_class = image_map["folder_to_class"]
    class_to_idx = image_map["class_to_idx"]
    num_classes = len(class_to_idx)
    allowed = {ext.lower() for ext in extensions}
    counts = [0] * num_classes

    for folder in Path(root).iterdir():
        if not folder.is_dir():
            continue
        class_key = folder_to_class.get(folder.name)
        if class_key is None:
            continue
        # Only count regular files: a directory named "x.jpg" is not an image.
        counts[class_to_idx[class_key]] += sum(
            1 for f in folder.iterdir() if f.is_file() and f.suffix.lower() in allowed
        )

    total = sum(counts)
    if total == 0:
        return torch.ones(num_classes)
    weights = [total / (num_classes * c) if c else 1.0 for c in counts]
    return torch.tensor(weights, dtype=torch.float32)