File size: 2,388 Bytes
5666923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Image classifier for Electrical Outlets. EfficientNet-B0 backbone + MLP head.
FINAL v5: 5 classes (no GFCI).
"""
from pathlib import Path
from typing import Dict, Any, Optional
import json
import torch
import torch.nn as nn
from torchvision import models


class ElectricalOutletsImageModel(nn.Module):
    """EfficientNet-B0 feature extractor followed by a two-layer MLP classification head.

    The torchvision EfficientNet-B0 classifier is replaced with ``nn.Identity`` so the
    backbone emits its feature vector, which ``self.head`` maps to ``num_classes``
    logits. An optional JSON label-mapping file supplies a human-readable issue type
    and severity per class index, used to turn logits into a result dictionary.
    """

    def __init__(
        self,
        num_classes: int = 5,
        label_mapping_path: Optional[Path] = None,
        pretrained: bool = True,
        head_hidden: int = 256,
        head_dropout: float = 0.4,
    ):
        """Build the backbone and head, and optionally load the label mapping.

        Args:
            num_classes: Number of logits produced by the final head layer.
            label_mapping_path: Path to a JSON file whose ``"image"`` object contains
                ``"idx_to_issue_type"`` and ``"idx_to_severity"``. Ignored when falsy
                or when the file does not exist.
            pretrained: If True, initialise the backbone with torchvision's
                ``EfficientNet_B0_Weights.IMAGENET1K_V1``; otherwise no pretrained weights.
            head_hidden: Width of the hidden layer of the MLP head.
            head_dropout: Dropout probability before the first head layer; half of
                this value is used before the final layer.

        Raises:
            KeyError: If the mapping file lacks the ``"image"`` keys read below.
            json.JSONDecodeError: If the mapping file is not valid JSON.
        """
        super().__init__()
        self.num_classes = num_classes
        self.backbone = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        )
        in_features = self.backbone.classifier[1].in_features  # 1280
        # Drop torchvision's classifier so the backbone returns raw features for our head.
        self.backbone.classifier = nn.Identity()

        self.head = nn.Sequential(
            nn.Dropout(head_dropout),
            nn.Linear(in_features, head_hidden),
            nn.ReLU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(head_hidden, num_classes),
        )

        # Per-class labels; None means "no mapping loaded" and defaults are used.
        # NOTE(review): may be a JSON list or a JSON object (string keys) — both are
        # accepted by _label_for.
        self.idx_to_issue_type = None
        self.idx_to_severity = None
        if label_mapping_path and Path(label_mapping_path).exists():
            # Explicit encoding so loading does not depend on the platform locale.
            with open(label_mapping_path, encoding="utf-8") as f:
                lm = json.load(f)
            self.idx_to_issue_type = lm["image"]["idx_to_issue_type"]
            self.idx_to_severity = lm["image"]["idx_to_severity"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Args:
            x: Image batch fed to EfficientNet-B0. NOTE(review): presumably
                ``(N, 3, H, W)`` with ImageNet normalisation — confirm against the
                data pipeline.

        Returns:
            Logits whose last dimension is ``num_classes``.
        """
        features = self.backbone(x)
        return self.head(features)

    @staticmethod
    def _label_for(mapping: Any, idx: int, default: str) -> Any:
        """Return the label for class ``idx`` from a list/dict mapping, or ``default`` if none is loaded."""
        if not mapping:
            return default
        if isinstance(mapping, dict):
            # JSON object keys are always strings; accept int keys too.
            return mapping[idx] if idx in mapping else mapping[str(idx)]
        return mapping[idx]

    def predict_batch_to_schema(self, logits: torch.Tensor) -> "list[Dict[str, Any]]":
        """Convert a batch of logits into one result dictionary per row.

        Args:
            logits: Tensor of shape ``(batch, num_classes)``; a 1-D tensor is
                treated as a batch of one.

        Returns:
            A list (in row order) of dicts with the same keys as
            :meth:`predict_to_schema`.

        Raises:
            ValueError: If ``logits`` is not 1-D or 2-D, or has an empty class dimension.
        """
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        if logits.dim() != 2 or logits.shape[-1] == 0:
            raise ValueError(
                f"expected logits of shape (batch, num_classes), got {tuple(logits.shape)}"
            )
        probs = torch.softmax(logits, dim=-1)
        confs, preds = probs.max(dim=-1)

        results = []
        # tolist() yields plain Python numbers, so list/dict lookups work per row.
        for conf, pred in zip(confs.tolist(), preds.tolist()):
            issue_type = self._label_for(self.idx_to_issue_type, pred, "unknown")
            severity = self._label_for(self.idx_to_severity, pred, "medium")
            results.append({
                "result": "normal" if issue_type == "normal" else "issue_detected",
                "issue_type": issue_type,
                "severity": severity,
                "confidence": float(conf),
                "class_idx": int(pred),
            })
        return results

    def predict_to_schema(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Convert the logits of a single sample into a result dictionary.

        Args:
            logits: Logits for exactly one sample: shape ``(num_classes,)`` or any
                shape whose leading dimensions multiply to 1 (e.g. ``(1, num_classes)``).

        Returns:
            Dict with keys ``"result"`` (``"normal"`` when the predicted issue type is
            ``"normal"``, else ``"issue_detected"``), ``"issue_type"``, ``"severity"``,
            ``"confidence"`` (softmax probability of the top class) and ``"class_idx"``.
            Without a loaded mapping, issue type defaults to ``"unknown"`` and
            severity to ``"medium"``.

        Raises:
            ValueError: If ``logits`` does not hold exactly one sample; use
                :meth:`predict_batch_to_schema` for batches.
        """
        if logits.dim() == 0 or logits.shape[-1] == 0:
            raise ValueError(
                f"expected logits with a non-empty class dimension, got {tuple(logits.shape)}"
            )
        flat = logits.reshape(-1, logits.shape[-1])
        if flat.shape[0] != 1:
            # Previously a batch crashed with an opaque TypeError on list indexing.
            raise ValueError(
                f"predict_to_schema expects a single sample, got {flat.shape[0]}; "
                "use predict_batch_to_schema for batches"
            )
        return self.predict_batch_to_schema(flat)[0]