File size: 3,973 Bytes
5666923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
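Audio dataset for electrical-outlet sounds. Labels are inferred from WAV file names
(following the naming convention described in the README) using the pattern table in
config/label_mapping.json.
Changes from the original version: recursive search of subfolders (rglob), torchaudio
imported at module level, and stratified train/val/test splits.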
"""
from pathlib import Path
import json
import logging
from collections import defaultdict
from typing import Optional, Callable, List, Tuple

import torch
import torchaudio
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)


def _label_from_filename(filename: str, file_pattern_to_label: dict) -> str:
    for pattern, label in file_pattern_to_label.items():
        if filename.startswith(pattern) or pattern in filename:
            return label
    return "normal"


class ElectricalOutletsAudioDataset(Dataset):
    """Audio dataset built from the electrical_outlets_sounds_100 WAV files.

    Every ``*.wav`` file under ``root`` (searched recursively) is labelled from
    its file stem via the ``audio.file_pattern_to_label`` table of the label
    mapping JSON. Files whose label is not in ``audio.class_to_idx`` are
    skipped with a warning. The remaining files are split per class
    (stratified) into train/val/test with a seeded permutation. Files are
    sorted first, so instances built with the same arguments produce the same,
    mutually disjoint splits regardless of filesystem enumeration order.

    Items are ``(waveform, class_index)``. ``waveform`` is downmixed to one
    channel, resampled to ``sample_rate``, and centre-cropped or right
    zero-padded to ``target_length_sec`` seconds. Its shape is
    ``(1, n_samples)`` before ``transform`` is applied.

    Args:
        root: Directory searched recursively for ``*.wav`` files.
        label_mapping_path: JSON file whose ``"audio"`` section provides
            ``file_pattern_to_label``, ``class_to_idx``, ``idx_to_label``,
            ``label_to_severity`` and ``label_to_issue_type``.
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        train_ratio: Fraction of each class assigned to train (floored).
        val_ratio: Fraction of each class assigned to val (floored). The
            remainder of each class goes to test.
        seed: Seed for the per-class permutation.
        transform: Optional callable applied to the fixed-length waveform.
        target_length_sec: Output clip length in seconds.
        sample_rate: Output sample rate in Hz.

    Raises:
        ValueError: If ``split`` is not a known split name, or if the ratios
            are negative or sum to more than 1.
    """

    _SPLITS = ("train", "val", "test")

    def __init__(
        self,
        root: Path,
        label_mapping_path: Path,
        split: str = "train",
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        seed: int = 42,
        transform: Optional[Callable] = None,
        target_length_sec: float = 5.0,
        sample_rate: int = 22050,
    ):
        # Validate cheap arguments first so a misconfiguration fails before any I/O.
        # Previously any unknown split name silently selected the test set.
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1.0 + 1e-9:
            raise ValueError(
                "train_ratio and val_ratio must be non-negative and sum to at most 1, "
                f"got {train_ratio} and {val_ratio}"
            )

        self.root = Path(root)
        self.transform = transform
        self.target_length_sec = target_length_sec
        self.sample_rate = sample_rate
        with open(label_mapping_path, encoding="utf-8") as f:
            audio_cfg = json.load(f)["audio"]
        self.file_pattern_to_label = audio_cfg["file_pattern_to_label"]
        self.class_to_idx = audio_cfg["class_to_idx"]
        self.idx_to_label = audio_cfg["idx_to_label"]
        # Stored for callers; not used inside this class.
        self.label_to_severity = audio_cfg["label_to_severity"]
        self.label_to_issue_type = audio_cfg["label_to_issue_type"]
        self.num_classes = len(self.class_to_idx)

        self.samples: List[Tuple[Path, int]] = []
        # rglob searches subfolders, but its order follows the filesystem and is
        # not guaranteed. Sorting makes sample order (and hence the seeded
        # split) reproducible, and keeps train/val/test instances disjoint.
        for wav in sorted(self.root.rglob("*.wav")):
            label = _label_from_filename(wav.stem, self.file_pattern_to_label)
            if label not in self.class_to_idx:
                logger.warning(
                    "Unmatched audio file: %s → label '%s' not in class_to_idx",
                    wav.name,
                    label,
                )
                continue
            self.samples.append((wav, self.class_to_idx[label]))
        if not self.samples:
            logger.warning("No labelled .wav files found under %s", self.root)

        train_idx, val_idx, test_idx = self._stratified_split(
            [cls for _, cls in self.samples], train_ratio, val_ratio, seed
        )
        self.indices = {"train": train_idx, "val": val_idx, "test": test_idx}[split]

    @staticmethod
    def _stratified_split(
        labels: List[int], train_ratio: float, val_ratio: float, seed: int
    ) -> Tuple[List[int], List[int], List[int]]:
        """Split positions of ``labels`` per class into (train, val, test) index lists.

        Each class receives ``int(n * train_ratio)`` train and
        ``int(n * val_ratio)`` val items, and the rest go to test. Every class
        is shuffled with a freshly seeded generator. A class's split therefore
        depends only on its own members, not on which other classes exist.
        """
        by_class = defaultdict(list)
        for position, cls in enumerate(labels):
            by_class[cls].append(position)

        train_idx: List[int] = []
        val_idx: List[int] = []
        test_idx: List[int] = []
        for cls in sorted(by_class):
            members = by_class[cls]
            g = torch.Generator().manual_seed(seed)
            perm = torch.randperm(len(members), generator=g).tolist()
            n_tr = int(len(members) * train_ratio)
            n_va = int(len(members) * val_ratio)
            train_idx.extend(members[p] for p in perm[:n_tr])
            val_idx.extend(members[p] for p in perm[n_tr:n_tr + n_va])
            test_idx.extend(members[p] for p in perm[n_tr + n_va:])
        return train_idx, val_idx, test_idx

    @staticmethod
    def _fix_length(waveform: torch.Tensor, target_len: int) -> torch.Tensor:
        """Centre-crop or right zero-pad a 2-D waveform along dim 1 to ``target_len``."""
        n = waveform.shape[1]
        if n >= target_len:
            start = (n - target_len) // 2
            return waveform[:, start : start + target_len]
        return torch.nn.functional.pad(waveform, (0, target_len - n))

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int):
        """Load, downmix, resample and length-normalise the ``idx``-th item of this split."""
        path, cls = self.samples[self.indices[idx]]
        waveform, sr = torchaudio.load(str(path))
        # Downmix before resampling. Both steps are linear, so the result matches
        # resampling first (up to float rounding) while resampling only one channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        target_len = int(self.target_length_sec * self.sample_rate)
        waveform = self._fix_length(waveform, target_len)
        if self.transform is not None:
            waveform = self.transform(waveform)
        return waveform, cls