| | """ |
| | Image classifier for Electrical Outlets. EfficientNet-B0 backbone + MLP head. |
| | FINAL v5: 5 classes (no GFCI). |
| | """ |
| | from pathlib import Path |
| | from typing import Dict, Any, Optional |
| | import json |
| | import torch |
| | import torch.nn as nn |
| | from torchvision import models |
| |
|
| |
|
class ElectricalOutletsImageModel(nn.Module):
    """Electrical-outlet image classifier: EfficientNet-B0 features + MLP head.

    The stock torchvision EfficientNet-B0 classifier is replaced with
    ``nn.Identity`` so the backbone emits feature vectors. A two-layer MLP head
    then maps those features to ``num_classes`` logits.

    Args:
        num_classes: Number of output classes (width of the logit vector).
        label_mapping_path: Optional path to a JSON file whose ``"image"``
            section contains ``"idx_to_issue_type"`` and ``"idx_to_severity"``.
            Each may be a list indexed by class, or an object keyed by class
            index (JSON object keys are strings, e.g. ``"0"``). The file is
            silently ignored when it does not exist.
        pretrained: If True, load ImageNet-1k weights into the backbone.
        head_hidden: Width of the hidden layer of the MLP head.
        head_dropout: Dropout probability before the first head ``Linear``.
            Half of this value is applied before the output ``Linear``.
    """

    def __init__(
        self,
        num_classes: int = 5,
        label_mapping_path: Optional[Path] = None,
        pretrained: bool = True,
        head_hidden: int = 256,
        head_dropout: float = 0.4,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.backbone = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        )
        # Read the feature width from the stock classifier's Linear layer
        # (index 1) before discarding it, so the backbone returns raw features.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        self.head = nn.Sequential(
            nn.Dropout(head_dropout),
            nn.Linear(in_features, head_hidden),
            nn.ReLU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(head_hidden, num_classes),
        )

        # Optional human-readable labels. None means predict_to_schema falls
        # back to "unknown" / "medium".
        self.idx_to_issue_type = None
        self.idx_to_severity = None
        if label_mapping_path and Path(label_mapping_path).exists():
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            with open(label_mapping_path, encoding="utf-8") as f:
                lm = json.load(f)
            self.idx_to_issue_type = lm["image"]["idx_to_issue_type"]
            self.idx_to_severity = lm["image"]["idx_to_severity"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits (last dimension ``num_classes``) for the image batch ``x``.

        ``x`` is presumably an image batch of shape (batch, 3, H, W), as
        expected by torchvision's EfficientNet-B0 (TODO confirm with callers).
        """
        features = self.backbone(x)
        return self.head(features)

    @staticmethod
    def _lookup_label(mapping: Any, idx: int, default: str) -> Any:
        """Return the label for class ``idx`` in ``mapping``, or ``default`` if unset.

        ``mapping`` may be a list/sequence indexed by class, or a dict keyed by
        class index. The key may be an int or its string form: JSON objects
        always decode with string keys.
        """
        if not mapping:
            return default
        if isinstance(mapping, dict):
            if idx in mapping:
                return mapping[idx]
            return mapping[str(idx)]
        return mapping[idx]

    @staticmethod
    def _softmax_rows(logits: torch.Tensor) -> torch.Tensor:
        """Softmax over the last dim, flattened to (num_samples, num_classes)."""
        # detach(): this is an inference/reporting path, so no autograd graph is needed.
        probs = torch.softmax(logits.detach(), dim=-1)
        return probs.reshape(-1, probs.shape[-1])

    def _row_to_schema(self, probs_row: torch.Tensor) -> Dict[str, Any]:
        """Build the schema dict from one row of class probabilities."""
        conf, pred = probs_row.max(dim=-1)
        idx = int(pred.item())
        issue_type = self._lookup_label(self.idx_to_issue_type, idx, "unknown")
        severity = self._lookup_label(self.idx_to_severity, idx, "medium")
        result = "normal" if issue_type == "normal" else "issue_detected"
        return {
            "result": result,
            "issue_type": issue_type,
            "severity": severity,
            "confidence": float(conf.item()),
            "class_idx": idx,
        }

    def predict_to_schema(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Convert the logits of ONE image into the prediction schema dict.

        Args:
            logits: Logits for a single sample, e.g. shape ``(num_classes,)``
                or ``(1, num_classes)``.

        Returns:
            Dict with ``result`` ("normal" if the predicted issue type is
            "normal", else "issue_detected"), ``issue_type``, ``severity``,
            ``confidence`` (the maximum softmax probability) and ``class_idx``.

        Raises:
            ValueError: If ``logits`` contains more than one sample. Use
                ``predict_batch_to_schema`` for batches.
        """
        rows = self._softmax_rows(logits)
        if rows.shape[0] != 1:
            raise ValueError(
                "predict_to_schema expects logits for a single sample, shape "
                f"(C,) or (1, C); got {tuple(logits.shape)}. "
                "Use predict_batch_to_schema for batches."
            )
        return self._row_to_schema(rows[0])

    def predict_batch_to_schema(self, logits: torch.Tensor) -> list:
        """Convert a batch of logits (``(N, num_classes)``) into a list of schema dicts.

        Each element has the same keys as the dict returned by
        ``predict_to_schema``, in batch order.
        """
        return [self._row_to_schema(row) for row in self._softmax_rows(logits)]
| |
|