| """ |
| Audio classifier for Electrical Outlets. Expects spectrogram or waveform; outputs class logits. |
| Severity from label_mapping. Small CNN for 100-sample regime. |
| """ |
| from pathlib import Path |
| from typing import Dict, Any, Optional |
| import json |
| import torch |
| import torch.nn as nn |
|
|
|
|
class SpectrogramCNN(nn.Module):
    """Lightweight CNN over a mel spectrogram (n_mels x time).

    Accepts (H, W), (B, H, W) or (B, 1, H, W) inputs and returns
    (B, num_classes) logits. `n_mels`/`time_steps` are accepted for
    interface compatibility but never referenced: the adaptive average
    pool makes the network input-size agnostic.
    """

    def __init__(self, n_mels: int = 64, time_steps: int = 128, num_classes: int = 4):
        super().__init__()
        # Two conv/BN/ReLU/pool stages followed by a third conv stage that
        # collapses spatial extent via adaptive pooling.
        stages = []
        in_channels = 1
        for out_channels in (32, 64):
            stages.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
            in_channels = out_channels
        stages.extend([
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        ])
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes)."""
        if x.dim() == 2:
            # Single unbatched spectrogram: add batch and channel dims.
            x = x[None, None]
        elif x.dim() == 3:
            # Batched but channel-less: add the channel dim.
            x = x[:, None]
        features = self.conv(x)
        return self.fc(features.flatten(1))
|
|
|
|
class ElectricalOutletsAudioModel(nn.Module):
    """Wrapper around SpectrogramCNN with schema-aware prediction output.

    Optionally loads a label-mapping JSON (expects an "audio" section with
    "idx_to_label", "label_to_issue_type" and "label_to_severity") so that
    predictions can be translated to issue type / severity strings.

    Fixes vs. previous revision:
    - `idx_to_severity` is now built by iterating the labels themselves
      (like `idx_to_issue_type`) instead of `range(num_classes)`, which
      raised IndexError when the mapping had fewer labels than num_classes.
    - `predict_to_schema` raises a clear ValueError for batches of more
      than one sample instead of crashing with an opaque TypeError when a
      tensor index hits a Python list.
    """

    def __init__(
        self,
        num_classes: int = 4,
        label_mapping_path: Optional[Path] = None,
        n_mels: int = 64,
        time_steps: int = 128,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.n_mels = n_mels
        self.time_steps = time_steps
        self.backbone = SpectrogramCNN(n_mels=n_mels, time_steps=time_steps, num_classes=num_classes)
        # Per-class metadata; stays None when no mapping file is available,
        # in which case predict_to_schema falls back to defaults.
        self.idx_to_label = None
        self.idx_to_issue_type = None
        self.idx_to_severity = None
        if label_mapping_path and Path(label_mapping_path).exists():
            with open(label_mapping_path) as f:
                lm = json.load(f)
            audio = lm["audio"]
            self.idx_to_label = audio["idx_to_label"]
            self.idx_to_issue_type = [
                audio["label_to_issue_type"].get(lbl, "normal")
                for lbl in self.idx_to_label
            ]
            # Iterate labels directly (not range(num_classes)) so a mapping
            # with fewer entries than num_classes cannot raise IndexError.
            self.idx_to_severity = [
                audio["label_to_severity"].get(lbl, "medium")
                for lbl in self.idx_to_label
            ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits from the backbone CNN."""
        return self.backbone(x)

    def predict_to_schema(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Convert logits for a single clip into the report schema dict.

        Args:
            logits: shape (num_classes,) or (1, num_classes).

        Returns:
            Dict with "result", "issue_type", "severity", "confidence",
            "class_idx".

        Raises:
            ValueError: if a batch of more than one sample is passed.
        """
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        if logits.shape[0] != 1:
            raise ValueError(
                f"predict_to_schema expects a single sample, got batch of {logits.shape[0]}"
            )
        probs = torch.softmax(logits, dim=-1)
        conf, pred = probs.max(dim=-1)
        pred_idx = int(pred.item())
        confidence = float(conf.item())
        # Fall back to neutral defaults when no label mapping was loaded.
        issue_types = self.idx_to_issue_type or ["normal"] * self.num_classes
        severities = self.idx_to_severity or ["medium"] * self.num_classes
        issue_type = issue_types[pred_idx]
        severity = severities[pred_idx]
        result = "normal" if issue_type == "normal" else "issue_detected"
        return {
            "result": result,
            "issue_type": issue_type,
            "severity": severity,
            "confidence": confidence,
            "class_idx": pred_idx,
        }
|