""" Audio dataset for Electrical Outlets. Uses README/file naming and config/label_mapping.json. PATCHED: rglob for subfolders, torchaudio import at module level, stratified splits. """ from pathlib import Path import json import logging from collections import defaultdict from typing import Optional, Callable, List, Tuple import torch import torchaudio from torch.utils.data import Dataset logger = logging.getLogger(__name__) def _label_from_filename(filename: str, file_pattern_to_label: dict) -> str: for pattern, label in file_pattern_to_label.items(): if filename.startswith(pattern) or pattern in filename: return label return "normal" class ElectricalOutletsAudioDataset(Dataset): """Audio dataset from electrical_outlets_sounds_100 WAVs.""" def __init__( self, root: Path, label_mapping_path: Path, split: str = "train", train_ratio: float = 0.7, val_ratio: float = 0.15, seed: int = 42, transform: Optional[Callable] = None, target_length_sec: float = 5.0, sample_rate: int = 22050, ): self.root = Path(root) self.transform = transform self.target_length_sec = target_length_sec self.sample_rate = sample_rate with open(label_mapping_path) as f: lm = json.load(f) self.file_pattern_to_label = lm["audio"]["file_pattern_to_label"] self.class_to_idx = lm["audio"]["class_to_idx"] self.idx_to_label = lm["audio"]["idx_to_label"] self.label_to_severity = lm["audio"]["label_to_severity"] self.label_to_issue_type = lm["audio"]["label_to_issue_type"] self.num_classes = len(self.class_to_idx) self.samples: List[Tuple[Path, int]] = [] # rglob to search subfolders for wav in self.root.rglob("*.wav"): label = _label_from_filename(wav.stem, self.file_pattern_to_label) if label not in self.class_to_idx: logger.warning(f"Unmatched audio file: {wav.name} → label '{label}' not in class_to_idx") continue self.samples.append((wav, self.class_to_idx[label])) # Stratified split by_class = defaultdict(list) for i, (_, cls) in enumerate(self.samples): by_class[cls].append(i) train_idx, val_idx, test_idx = [], [], [] for cls in sorted(by_class.keys()): indices = by_class[cls] g = torch.Generator().manual_seed(seed) perm = torch.randperm(len(indices), generator=g).tolist() n_cls = len(indices) n_tr = int(n_cls * train_ratio) n_va = int(n_cls * val_ratio) train_idx.extend([indices[p] for p in perm[:n_tr]]) val_idx.extend([indices[p] for p in perm[n_tr:n_tr + n_va]]) test_idx.extend([indices[p] for p in perm[n_tr + n_va:]]) if split == "train": self.indices = train_idx elif split == "val": self.indices = val_idx else: self.indices = test_idx def __len__(self) -> int: return len(self.indices) def __getitem__(self, idx: int): i = self.indices[idx] path, cls = self.samples[i] waveform, sr = torchaudio.load(str(path)) if sr != self.sample_rate: waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate) if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True) target_len = int(self.target_length_sec * self.sample_rate) if waveform.shape[1] >= target_len: start = (waveform.shape[1] - target_len) // 2 waveform = waveform[:, start : start + target_len] else: waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1])) if self.transform: waveform = self.transform(waveform) return waveform, cls