| """ |
| Audio dataset for Electrical Outlets. Uses README/file naming and config/label_mapping.json. |
| PATCHED: rglob for subfolders, torchaudio import at module level, stratified splits. |
| """ |
| from pathlib import Path |
| import json |
| import logging |
| from collections import defaultdict |
| from typing import Optional, Callable, List, Tuple |
|
|
| import torch |
| import torchaudio |
| from torch.utils.data import Dataset |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| def _label_from_filename(filename: str, file_pattern_to_label: dict) -> str: |
| for pattern, label in file_pattern_to_label.items(): |
| if filename.startswith(pattern) or pattern in filename: |
| return label |
| return "normal" |
|
|
|
|
class ElectricalOutletsAudioDataset(Dataset):
    """Audio dataset over the electrical_outlets_sounds_100 WAV files.

    WAV files are discovered recursively under ``root`` (subfolders included);
    each file's label is inferred from its filename via the
    ``file_pattern_to_label`` table of the label-mapping JSON. Samples are
    assigned to train/val/test with a deterministic per-class (stratified)
    shuffle so class proportions are preserved across splits.
    """

    def __init__(
        self,
        root: Path,
        label_mapping_path: Path,
        split: str = "train",
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        seed: int = 42,
        transform: Optional[Callable] = None,
        target_length_sec: float = 5.0,
        sample_rate: int = 22050,
    ):
        """Build the sample index for one split.

        Args:
            root: Directory scanned recursively for ``*.wav`` files.
            label_mapping_path: JSON file whose ``audio`` section holds the
                pattern/label/severity lookup tables.
            split: One of ``"train"``, ``"val"``, ``"test"``.
            train_ratio: Fraction of each class assigned to train.
            val_ratio: Fraction of each class assigned to val; the
                remainder goes to test.
            seed: Seed for the per-class shuffle (same seed => same splits).
            transform: Optional callable applied to the waveform tensor.
            target_length_sec: Clips are center-cropped / zero-padded to
                this duration in seconds.
            sample_rate: Target sample rate; files are resampled to it.

        Raises:
            ValueError: If ``split`` is not one of the three valid names
                (previously any unknown string silently became the test
                split).
        """
        # Fail fast on typos instead of silently serving the test split.
        if split not in ("train", "val", "test"):
            raise ValueError(
                f"split must be 'train', 'val' or 'test', got {split!r}"
            )
        self.root = Path(root)
        self.split = split
        self.transform = transform
        self.target_length_sec = target_length_sec
        self.sample_rate = sample_rate

        self._load_label_mapping(Path(label_mapping_path))
        self.samples: List[Tuple[Path, int]] = self._discover_samples()
        self.indices = self._stratified_split(train_ratio, val_ratio, seed)[split]

    def _load_label_mapping(self, path: Path) -> None:
        """Populate the lookup tables from the ``audio`` section of the JSON."""
        with open(path, encoding="utf-8") as f:
            audio = json.load(f)["audio"]
        self.file_pattern_to_label = audio["file_pattern_to_label"]
        self.class_to_idx = audio["class_to_idx"]
        # NOTE(review): JSON object keys are always strings, so this maps
        # str(idx) -> label; index it with str(i), not int.
        self.idx_to_label = audio["idx_to_label"]
        self.label_to_severity = audio["label_to_severity"]
        self.label_to_issue_type = audio["label_to_issue_type"]
        self.num_classes = len(self.class_to_idx)

    def _discover_samples(self) -> List[Tuple[Path, int]]:
        """Recursively collect ``(path, class_idx)`` pairs for known labels."""
        samples: List[Tuple[Path, int]] = []
        for wav in self.root.rglob("*.wav"):
            label = _label_from_filename(wav.stem, self.file_pattern_to_label)
            if label not in self.class_to_idx:
                # Lazy %-args: the message is only built if actually logged.
                logger.warning(
                    "Unmatched audio file: %s → label '%s' not in class_to_idx",
                    wav.name,
                    label,
                )
                continue
            samples.append((wav, self.class_to_idx[label]))
        return samples

    def _stratified_split(
        self, train_ratio: float, val_ratio: float, seed: int
    ) -> dict:
        """Deterministically shuffle each class and carve out the three splits.

        The same seed is reused for every class so the assignment is stable
        regardless of which classes happen to be present; split arithmetic
        (``int`` truncation, remainder to test) matches the original code so
        split membership is unchanged.
        """
        by_class = defaultdict(list)
        for i, (_, cls) in enumerate(self.samples):
            by_class[cls].append(i)

        splits: dict = {"train": [], "val": [], "test": []}
        for cls in sorted(by_class):
            indices = by_class[cls]
            g = torch.Generator().manual_seed(seed)
            perm = torch.randperm(len(indices), generator=g).tolist()
            n_cls = len(indices)
            n_tr = int(n_cls * train_ratio)
            n_va = int(n_cls * val_ratio)
            splits["train"].extend(indices[p] for p in perm[:n_tr])
            splits["val"].extend(indices[p] for p in perm[n_tr:n_tr + n_va])
            splits["test"].extend(indices[p] for p in perm[n_tr + n_va:])
        return splits

    def __len__(self) -> int:
        """Number of samples in this split."""
        return len(self.indices)

    def __getitem__(self, idx: int):
        """Load one clip as a mono, fixed-length waveform tensor.

        Returns:
            ``(waveform, class_idx)`` where the waveform has shape
            ``(1, target_length_sec * sample_rate)`` before ``transform``
            (a transform may change that shape, e.g. to a spectrogram).
        """
        i = self.indices[idx]
        path, cls = self.samples[i]
        waveform, sr = torchaudio.load(str(path))
        # Resample first so all length bookkeeping is in target-rate frames.
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        # Downmix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        target_len = int(self.target_length_sec * self.sample_rate)
        if waveform.shape[1] >= target_len:
            # Center-crop clips that are too long.
            start = (waveform.shape[1] - target_len) // 2
            waveform = waveform[:, start : start + target_len]
        else:
            # Zero-pad short clips at the end.
            waveform = torch.nn.functional.pad(
                waveform, (0, target_len - waveform.shape[1])
            )
        if self.transform:
            waveform = self.transform(waveform)
        return waveform, cls
|
|