File size: 3,253 Bytes
5666923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Audio classifier for Electrical Outlets. Expects spectrogram or waveform; outputs class logits.
Severity from label_mapping. Small CNN for 100-sample regime.
"""
from pathlib import Path
from typing import Dict, Any, Optional
import json
import torch
import torch.nn as nn


class SpectrogramCNN(nn.Module):
    """Small three-stage 2-D CNN that maps a single-channel spectrogram to class logits.

    Each stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU. The first two
    stages end with 2x2 max-pooling and the last with global average pooling,
    so the linear head always receives a 128-dim vector regardless of input size.

    Args:
        n_mels: Accepted for interface compatibility; not used by any layer.
        time_steps: Accepted for interface compatibility; not used by any layer.
        num_classes: Number of output logits.
    """

    def __init__(self, n_mels: int = 64, time_steps: int = 128, num_classes: int = 4):
        super().__init__()
        widths = (1, 32, 64, 128)
        final_stage = len(widths) - 2
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU()]
            layers.append(nn.AdaptiveAvgPool2d(1) if stage == final_stage else nn.MaxPool2d(2))
        # One flat Sequential keeps the state_dict keys (conv.0 ... conv.11) stable for checkpoints.
        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(widths[-1], num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, num_classes).

        A 2-D input is treated as one unbatched single-channel image and a 3-D
        input as (batch, H, W) missing its channel axis. Any other rank is passed
        to the conv stack unchanged; the first Conv2d expects 1 input channel.
        """
        rank = x.dim()
        if rank == 2:
            x = x[None, None]
        elif rank == 3:
            x = x[:, None]
        features = self.conv(x).flatten(1)
        return self.fc(features)


class ElectricalOutletsAudioModel(nn.Module):
    """Spectrogram classifier for electrical-outlet audio plus a label-mapping lookup.

    ``forward`` passes its input straight to a ``SpectrogramCNN`` backbone; no
    waveform-to-mel transform happens here. ``predict_to_schema`` turns the logits
    of one sample into a result dict whose issue type and severity come from an
    optional label-mapping JSON file.

    Args:
        num_classes: Number of output logits.
        label_mapping_path: Optional path to a JSON file with an ``"audio"`` section
            containing ``idx_to_label``, ``label_to_issue_type`` and
            ``label_to_severity``. If the path is falsy or the file does not exist,
            the mapping stays unset and every prediction is reported with issue
            type "normal" and severity "medium".
        n_mels: Stored and forwarded to the backbone.
        time_steps: Stored and forwarded to the backbone.

    Raises:
        ValueError: if the mapping file lists fewer labels than ``num_classes``.
    """

    def __init__(
        self,
        num_classes: int = 4,
        label_mapping_path: Optional[Path] = None,
        n_mels: int = 64,
        time_steps: int = 128,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.n_mels = n_mels
        self.time_steps = time_steps
        self.backbone = SpectrogramCNN(n_mels=n_mels, time_steps=time_steps, num_classes=num_classes)
        # All three tables are None when no mapping file is available.
        self.idx_to_label, self.idx_to_issue_type, self.idx_to_severity = self._load_label_tables(
            label_mapping_path, num_classes
        )

    @staticmethod
    def _load_label_tables(label_mapping_path, num_classes: int) -> tuple:
        """Read (idx_to_label, idx_to_issue_type, idx_to_severity) from a mapping file.

        Returns three ``None`` values when the path is falsy or does not exist.
        Labels missing from ``label_to_issue_type`` default to "normal" and labels
        missing from ``label_to_severity`` default to "medium".

        Raises:
            ValueError: if the file lists fewer labels than ``num_classes``.
        """
        if not label_mapping_path or not Path(label_mapping_path).exists():
            return None, None, None
        with open(label_mapping_path, encoding="utf-8") as f:
            audio = json.load(f)["audio"]
        labels = audio["idx_to_label"]
        if len(labels) < num_classes:
            raise ValueError(
                f"label mapping {label_mapping_path} lists {len(labels)} audio labels "
                f"but the model has {num_classes} classes"
            )
        # Build both tables from the same label list so they always have equal length.
        issue_types = [audio["label_to_issue_type"].get(lbl, "normal") for lbl in labels]
        severities = [audio["label_to_severity"].get(lbl, "medium") for lbl in labels]
        return labels, issue_types, severities

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's class logits for ``x``."""
        return self.backbone(x)

    def predict_to_schema(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Convert the logits of a single sample into the prediction schema.

        Args:
            logits: Class logits for exactly one sample, e.g. shape ``(num_classes,)``
                or ``(1, num_classes)``. All leading dimensions must have size 1.

        Returns:
            Dict with ``result`` ("normal" or "issue_detected"), ``issue_type``,
            ``severity``, ``confidence`` (softmax probability of the top class) and
            ``class_idx``.

        Raises:
            ValueError: if ``logits`` is a scalar or holds more than one sample.
        """
        if logits.dim() == 0:
            raise ValueError("predict_to_schema expects class logits, got a scalar tensor")
        # Collapse leading singleton dims so (C,), (1, C) and (1, 1, C) are all accepted.
        probs = torch.softmax(logits.detach(), dim=-1).reshape(-1, logits.shape[-1])
        if probs.shape[0] != 1:
            raise ValueError(
                f"predict_to_schema expects logits for a single sample, got shape {tuple(logits.shape)}"
            )
        conf, pred = probs[0].max(dim=-1)
        idx = int(pred.item())
        issue_types = self.idx_to_issue_type or ["normal"] * self.num_classes
        severities = self.idx_to_severity or ["medium"] * self.num_classes
        issue_type = issue_types[idx]
        return {
            "result": "normal" if issue_type == "normal" else "issue_detected",
            "issue_type": issue_type,
            "severity": severities[idx],
            "confidence": float(conf),
            "class_idx": idx,
        }