# electrical-outlets-diagnostic / src/data/audio_dataset.py
# Author: Asadrizvi64 — Electrical Outlets diagnostic pipeline v1.0 (commit 5666923)
"""
Audio dataset for Electrical Outlets. Uses README/file naming and config/label_mapping.json.
PATCHED: rglob for subfolders, torchaudio import at module level, stratified splits.
"""
from pathlib import Path
import json
import logging
from collections import defaultdict
from typing import Optional, Callable, List, Tuple
import torch
import torchaudio
from torch.utils.data import Dataset
logger = logging.getLogger(__name__)
def _label_from_filename(filename: str, file_pattern_to_label: dict) -> str:
for pattern, label in file_pattern_to_label.items():
if filename.startswith(pattern) or pattern in filename:
return label
return "normal"
class ElectricalOutletsAudioDataset(Dataset):
    """Audio dataset over the electrical_outlets_sounds_100 WAV files.

    Recursively scans ``root`` for ``*.wav``, derives a class label for each
    file from its stem (see ``_label_from_filename``), and computes a
    deterministic, per-class (stratified) train/val/test split at
    construction time.  ``__getitem__`` returns a mono, fixed-length
    waveform tensor and the integer class index.
    """

    def __init__(
        self,
        root: Path,
        label_mapping_path: Path,
        split: str = "train",
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        seed: int = 42,
        transform: Optional[Callable] = None,
        target_length_sec: float = 5.0,
        sample_rate: int = 22050,
    ):
        """
        Args:
            root: Directory searched recursively for ``*.wav`` files.
            label_mapping_path: JSON file whose ``"audio"`` section provides
                ``file_pattern_to_label``, ``class_to_idx``, ``idx_to_label``,
                ``label_to_severity`` and ``label_to_issue_type``.
            split: ``"train"`` or ``"val"``; any other value (including the
                intended ``"test"``) selects the test split — the argument
                is not validated.
            train_ratio: Fraction of each class assigned to the train split.
            val_ratio: Fraction of each class assigned to the val split; the
                remainder becomes the test split.
            seed: Seed for the shuffling RNG, making splits reproducible.
            transform: Optional callable applied to the waveform tensor in
                ``__getitem__``.
            target_length_sec: Clips are center-cropped or right-zero-padded
                to exactly this many seconds.
            sample_rate: Target sample rate; files at other rates are
                resampled on load.
        """
        self.root = Path(root)
        self.transform = transform
        self.target_length_sec = target_length_sec
        self.sample_rate = sample_rate
        with open(label_mapping_path) as f:
            lm = json.load(f)
        self.file_pattern_to_label = lm["audio"]["file_pattern_to_label"]
        self.class_to_idx = lm["audio"]["class_to_idx"]
        # NOTE(review): JSON object keys are strings, so this is presumably
        # keyed by str(index), not int — confirm against config/label_mapping.json.
        self.idx_to_label = lm["audio"]["idx_to_label"]
        self.label_to_severity = lm["audio"]["label_to_severity"]
        self.label_to_issue_type = lm["audio"]["label_to_issue_type"]
        self.num_classes = len(self.class_to_idx)
        # (path, class_index) for every matched WAV, in rglob discovery order.
        self.samples: List[Tuple[Path, int]] = []
        # rglob to search subfolders
        for wav in self.root.rglob("*.wav"):
            label = _label_from_filename(wav.stem, self.file_pattern_to_label)
            if label not in self.class_to_idx:
                # Files mapping to an unknown label are skipped, not errored.
                logger.warning(f"Unmatched audio file: {wav.name} → label '{label}' not in class_to_idx")
                continue
            self.samples.append((wav, self.class_to_idx[label]))
        # Stratified split: shuffle and partition each class independently so
        # every split keeps (approximately) the same class proportions.
        by_class = defaultdict(list)
        for i, (_, cls) in enumerate(self.samples):
            by_class[cls].append(i)
        train_idx, val_idx, test_idx = [], [], []
        for cls in sorted(by_class.keys()):
            indices = by_class[cls]
            # A fresh generator seeded identically per class: one class's
            # permutation does not depend on how many other classes exist.
            g = torch.Generator().manual_seed(seed)
            perm = torch.randperm(len(indices), generator=g).tolist()
            n_cls = len(indices)
            n_tr = int(n_cls * train_ratio)
            n_va = int(n_cls * val_ratio)
            train_idx.extend([indices[p] for p in perm[:n_tr]])
            val_idx.extend([indices[p] for p in perm[n_tr:n_tr + n_va]])
            # Everything past train+val falls into test (absorbs rounding).
            test_idx.extend([indices[p] for p in perm[n_tr + n_va:]])
        if split == "train":
            self.indices = train_idx
        elif split == "val":
            self.indices = val_idx
        else:
            self.indices = test_idx

    def __len__(self) -> int:
        """Number of samples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load one clip: resample, mix to mono, and fix the length.

        Returns:
            ``(waveform, class_index)`` — waveform has shape
            ``(1, target_length_sec * sample_rate)`` unless ``transform``
            changes it.
        """
        i = self.indices[idx]
        path, cls = self.samples[i]
        waveform, sr = torchaudio.load(str(path))
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # Mix multi-channel audio down to a single mono channel.
            waveform = waveform.mean(dim=0, keepdim=True)
        target_len = int(self.target_length_sec * self.sample_rate)
        if waveform.shape[1] >= target_len:
            # Center-crop clips that are too long.
            start = (waveform.shape[1] - target_len) // 2
            waveform = waveform[:, start : start + target_len]
        else:
            # Zero-pad short clips on the right.
            waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
        if self.transform:
            waveform = self.transform(waveform)
        return waveform, cls