"""
Audio classifier for Electrical Outlets.

Consumes a spectrogram tensor and outputs class logits. Issue type and
severity are looked up from an optional label-mapping JSON file.
Small CNN intended for the ~100-sample regime.
"""
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn


class SpectrogramCNN(nn.Module):
    """Lightweight CNN on a mel spectrogram (n_mels x time).

    Three Conv-BatchNorm-ReLU stages (1 -> 32 -> 64 -> 128 channels); the first
    two are followed by 2x2 max pooling and the last by global average
    pooling, so the linear head always sees a 128-dim feature vector.

    Args:
        n_mels: Nominal number of mel bins. Accepted for interface symmetry;
            no layer depends on it because of the adaptive pooling.
        time_steps: Nominal number of time frames. Likewise unused by layers.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, n_mels: int = 64, time_steps: int = 128, num_classes: int = 4):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes).

        A 2-D input gets a batch and a channel dimension prepended; a 3-D
        input is treated as batch-first and gets a channel dimension inserted
        at position 1. Any other rank is handed to the conv stack unchanged.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.conv(x)
        x = x.flatten(1)
        return self.fc(x)


class ElectricalOutletsAudioModel(nn.Module):
    """Wrapper around SpectrogramCNN plus label / issue-type / severity lookup.

    No mel transform is applied here: ``forward`` passes its input straight
    to the backbone, so callers must supply a spectrogram tensor.

    Args:
        num_classes: Number of output classes.
        label_mapping_path: Optional path to a JSON file with an ``"audio"``
            section containing ``idx_to_label``, ``label_to_issue_type`` and
            ``label_to_severity``. Ignored when falsy or when the file does
            not exist. ``idx_to_label`` is assumed to be a list ordered by
            class index -- TODO confirm against the mapping file.
        n_mels: Forwarded to the backbone.
        time_steps: Forwarded to the backbone.

    Raises:
        ValueError: If the mapping lists fewer labels than ``num_classes``.
    """

    def __init__(
        self,
        num_classes: int = 4,
        label_mapping_path: Optional[Path] = None,
        n_mels: int = 64,
        time_steps: int = 128,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.n_mels = n_mels
        self.time_steps = time_steps
        self.backbone = SpectrogramCNN(n_mels=n_mels, time_steps=time_steps, num_classes=num_classes)
        # Lookup tables stay None when no mapping file is available;
        # predict_to_schema then falls back to "normal" / "medium".
        self.idx_to_label = None
        self.idx_to_issue_type = None
        self.idx_to_severity = None
        if label_mapping_path and Path(label_mapping_path).exists():
            self._load_label_mapping(Path(label_mapping_path))

    def _load_label_mapping(self, path: Path) -> None:
        """Populate the index -> label / issue type / severity tables from JSON."""
        with open(path, encoding="utf-8") as f:
            audio = json.load(f)["audio"]
        labels = audio["idx_to_label"]
        if len(labels) < self.num_classes:
            # Fail early with a clear message instead of an IndexError later.
            raise ValueError(
                f"label mapping {path} defines {len(labels)} audio labels, "
                f"but num_classes={self.num_classes}"
            )
        issue_types = audio["label_to_issue_type"]
        severities = audio["label_to_severity"]
        self.idx_to_label = labels
        # Both tables are built the same way so they always stay aligned.
        self.idx_to_issue_type = [issue_types.get(lbl, "normal") for lbl in labels]
        self.idx_to_severity = [severities.get(lbl, "medium") for lbl in labels]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's class logits for ``x``."""
        return self.backbone(x)

    def _schema_entry(self, class_idx: int, confidence: float) -> Dict[str, Any]:
        """Build the result record for one predicted class index."""
        issue_type = (self.idx_to_issue_type or ["normal"] * self.num_classes)[class_idx]
        severity = (self.idx_to_severity or ["medium"] * self.num_classes)[class_idx]
        result = "normal" if issue_type == "normal" else "issue_detected"
        return {
            "result": result,
            "issue_type": issue_type,
            "severity": severity,
            "confidence": float(confidence),
            "class_idx": int(class_idx),
        }

    def predict_to_schema(self, logits: torch.Tensor) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Convert logits into result record(s).

        Args:
            logits: Class logits; the last dimension is the class axis.

        Returns:
            A single dict when ``logits`` holds exactly one prediction (a
            1-D vector or a batch of one). Otherwise a list with one dict
            per prediction, in row-major order of the leading dimensions.
            Each dict has the keys ``result``, ``issue_type``, ``severity``,
            ``confidence`` (max softmax probability) and ``class_idx``.
        """
        probs = torch.softmax(logits, dim=-1)
        conf, pred = probs.max(dim=-1)
        # Flatten to plain Python scalars so list indexing works for any
        # batch size, not just single-element tensors.
        confidences = conf.detach().reshape(-1).tolist()
        class_indices = pred.detach().reshape(-1).tolist()
        records = [self._schema_entry(idx, c) for idx, c in zip(class_indices, confidences)]
        if len(records) == 1:
            return records[0]
        return records