Asadrizvi64 committed on
Commit 5666923 · 1 Parent(s): 352967a

Electrical Outlets diagnostic pipeline v1.0
.gitignore CHANGED
@@ -1,10 +1,38 @@
- venv/
- __pycache__/
- *.pt
- *.pdf
- notebooks/
- ELECTRICAL*
- electrical_outlets_sounds_100/
- *.wav
- *.jpg
- *.png
+ ELECTRICAL OUTLETS-20260106T153508Z-3-001/
+ electrical_outlets_sounds_100/
+ 111/
+
+ # Model weights (upload separately via LFS)
+ weights/*.pt
+
+ # Binary files
+ *.pdf
+ *.jpg
+ *.jpeg
+ *.png
+ *.wav
+ *.mp3
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *.egg-info/
+ venv/
+ .venv/
+ env/
+
+ # IDE
+ .vscode/
+ .idea/
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Notebooks
+ notebooks/
+
+ # Misc
+ tmp/
+ *.log
+ wandb/
README.md ADDED
@@ -0,0 +1,237 @@
+ # Electrical Outlets & Switches Diagnostic Pipeline
+
+ Non-intrusive AI diagnostic system for electrical outlets and switches using **image classification** and **audio analysis** with decision-level fusion.
+
+ ## Overview
+
+ This pipeline analyzes photos and/or audio recordings of electrical outlets to detect potential safety issues without requiring physical inspection. It uses two independent models fused at the decision level for robust predictions.
+
+ ### Image Model
+ - **Architecture:** EfficientNet-B0 (frozen backbone) + MLP head (512 → 5 classes)
+ - **Classes:** burn/overheating, cracked faceplate, loose outlet, normal, water exposed
+ - **Performance:** 77.3% accuracy, 66.7% minimum per-class recall
+ - **Training data:** 1,299 images across 10 source categories merged into 5 classes
+
+ ### Audio Model
+ - **Architecture:** 3-layer spectrogram CNN (32→64→128 channels + adaptive pooling)
+ - **Classes:** normal, buzzing, crackling/arcing, arcing pop
+ - **Performance:** 100% macro recall on validation
+ - **Training data:** 100 WAV files (22050 Hz, mel spectrograms with SpecAugment)
+
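SpecAugment, mentioned above, masks random frequency and time bands of the mel spectrogram during training. A minimal NumPy sketch of the masking idea (mask counts and widths are illustrative, not the values used by `train_audio.py`):

```python
import numpy as np

def spec_augment(mel, n_freq_masks=2, n_time_masks=2, max_width=16, rng=None):
    """Zero out random frequency rows and time columns of a (n_mels, T) spectrogram."""
    rng = rng if rng is not None else np.random.default_rng()
    mel = mel.copy()  # leave the caller's spectrogram untouched
    n_mels, n_steps = mel.shape
    for _ in range(n_freq_masks):
        w = int(rng.integers(1, max_width + 1))
        f0 = int(rng.integers(0, max(1, n_mels - w)))
        mel[f0:f0 + w, :] = 0.0   # frequency mask
    for _ in range(n_time_masks):
        w = int(rng.integers(1, max_width + 1))
        t0 = int(rng.integers(0, max(1, n_steps - w)))
        mel[:, t0:t0 + w] = 0.0   # time mask
    return mel
```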
+ ### Fusion
+ - Decision-level fusion combining both modalities
+ - Safety-first: prefers "uncertain" over "normal" when in doubt
+ - Severity = max(image_severity, audio_severity)
+ - Configurable confidence thresholds in `config/thresholds.yaml`
+
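The `max(image_severity, audio_severity)` rule compares labels by their rank in the `severity_order` list from `config/thresholds.yaml`; a minimal sketch:

```python
SEVERITY_ORDER = ["low", "medium", "high", "critical"]  # from config/thresholds.yaml

def max_severity(a: str, b: str) -> str:
    """Return the more severe of two labels by position in SEVERITY_ORDER."""
    return max(a, b, key=SEVERITY_ORDER.index)
```

For example, `max_severity("medium", "critical")` returns `"critical"`.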
+ ## Project Structure
+
+ ```
+ CV/
+ ├── config/
+ │   ├── label_mapping.json          # Class definitions & folder→class mapping
+ │   ├── image_train_config.yaml     # Image training hyperparameters
+ │   ├── audio_train_config.yaml     # Audio training hyperparameters
+ │   ├── thresholds.yaml             # Fusion confidence thresholds
+ │   └── schema.yaml                 # API output schema
+ ├── src/
+ │   ├── data/
+ │   │   ├── image_dataset.py        # Image dataset with stratified splits
+ │   │   └── audio_dataset.py        # Audio dataset with stratified splits
+ │   ├── models/
+ │   │   ├── image_model.py          # EfficientNet-B0 + MLP classifier
+ │   │   └── audio_model.py          # Spectrogram CNN classifier
+ │   ├── fusion/
+ │   │   └── fusion_logic.py         # Decision-level fusion
+ │   └── inference/
+ │       └── wrapper.py              # End-to-end inference pipeline
+ ├── training/
+ │   ├── train_image.py              # Image model training (2-stage)
+ │   └── train_audio.py              # Audio model training
+ ├── api/
+ │   └── main.py                     # FastAPI endpoint
+ ├── weights/
+ │   ├── electrical_outlets_image_best.pt   # Trained image model
+ │   └── electrical_outlets_audio_best.pt   # Trained audio model
+ ├── tests/
+ │   └── test_fusion.py              # Fusion logic tests
+ ├── test_single_image.py            # Quick single-image testing
+ ├── requirements.txt
+ └── README.md
+ ```
+
+ ## Setup
+
+ ### Requirements
+
+ - Python 3.10+
+ - NVIDIA GPU with CUDA (recommended: RTX 3090 or better)
+
+ ### Installation
+
+ ```bash
+ git clone https://huggingface.co/<your-repo>/electrical-outlets-diagnostic
+ cd electrical-outlets-diagnostic
+
+ pip install -r requirements.txt
+
+ # If GPU: install CUDA-enabled PyTorch
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
+ # Also needed on Windows:
+ pip install soundfile
+ ```
+
+ ### Download Weights
+
+ Download the model weights from the HuggingFace repository and place them in `weights/`:
+
+ ```
+ weights/
+ ├── electrical_outlets_image_best.pt   (~ 17 MB)
+ └── electrical_outlets_audio_best.pt   (~ 2 MB)
+ ```
+
+ ## Usage
+
+ ### Test a Single Image
+
+ ```bash
+ python test_single_image.py --image path/to/outlet_photo.jpg
+ ```
+
+ Output:
+ ```
+ ==================================================
+ burned_outlet.jpg
+ ==================================================
+ → burn_overheating (high severity)
+ → 87.3% confidence
+ → issue_detected
+
+ burn_overheating    87.3%  ██████████████████████████ ◄
+ cracked_faceplate    5.2%  █
+ loose_outlet         3.1%  ▊
+ normal               2.8%  ▊
+ water_exposed        1.6%  ▍
+ ```
+
+ ### API Server
+
+ ```bash
+ uvicorn api.main:app --host 0.0.0.0 --port 8000
+ ```
+
+ #### Endpoints
+
+ **POST** `/v1/diagnose/electrical_outlets`
+
+ Upload image and/or audio for diagnosis:
+ ```bash
+ # Image only
+ curl -X POST http://localhost:8000/v1/diagnose/electrical_outlets \
+   -F "image=@outlet_photo.jpg"
+
+ # Image + Audio
+ curl -X POST http://localhost:8000/v1/diagnose/electrical_outlets \
+   -F "image=@outlet_photo.jpg" \
+   -F "audio=@outlet_recording.wav"
+ ```
+
+ Response:
+ ```json
+ {
+   "diagnostic_element": "electrical_outlets",
+   "result": "issue_detected",
+   "issue_type": "burn_overheating",
+   "severity": "high",
+   "confidence": 0.873,
+   "modality_contributions": null,
+   "primary_issue": "burn_overheating",
+   "secondary_issue": null
+ }
+ ```
+
+ **GET** `/health` — Check model availability
+
+ ### Python API
+
+ ```python
+ from src.inference.wrapper import run_electrical_outlets_inference
+
+ result = run_electrical_outlets_inference(
+     image_path="path/to/photo.jpg",
+     audio_path="path/to/recording.wav",  # optional
+ )
+ print(result)
+ ```
+
+ ## Training
+
+ ### Image Model
+
+ ```bash
+ python training/train_image.py --device cuda
+ ```
+
+ Two-stage training:
+ 1. **Stage 1:** Frozen EfficientNet-B0 backbone, train MLP head only (80-100 epochs)
+ 2. **Stage 2:** Unfreeze last 2 backbone blocks, fine-tune with low LR (25 epochs)
+
+ ### Audio Model
+
+ ```bash
+ python training/train_audio.py --device cuda
+ ```
+
+ Single-stage with SpecAugment, class-weighted loss, cosine LR schedule.
+
+ ## Class Mapping
+
+ ### Image Classes (5)
+
+ | Class | Issue Type | Severity | Source Folders |
+ |-------|-----------|----------|----------------|
+ | 0 | burn_overheating | high | Burn marks (250), Discoloration (100), Sparking damage (150) |
+ | 1 | cracked_faceplate | medium | Cracked faceplate (150), Damaged switches (50) |
+ | 2 | loose_outlet | medium | Loose outlet (200), Exposed wiring (150) |
+ | 3 | normal | low | Normal outlets (50), Normal switches (50) |
+ | 4 | water_exposed | high | Water intrusion (150) |
+
+ ### Audio Classes (4)
+
+ | Class | Issue Type | Severity |
+ |-------|-----------|----------|
+ | 0 | normal | low |
+ | 1 | buzzing | high |
+ | 2 | crackling_arcing | high |
+ | 3 | arcing_pop | critical |
+
+ ## Severity Levels
+
+ | Level | Action Required |
+ |-------|----------------|
+ | **low** | Monitor — no immediate action |
+ | **medium** | Schedule repair |
+ | **high** | Shut off circuit immediately |
+ | **critical** | Shut off main breaker immediately |
+
+ ## Fusion Logic
+
+ The fusion layer combines image and audio predictions:
+
+ - If **both agree** on issue → `issue_detected` with max severity
+ - If **both agree** on normal with high confidence → `normal`
+ - If **they disagree** → `uncertain` (unless one has >92% confidence)
+ - **Safety-first:** defaults to `uncertain` over `normal` when confidence is low
+
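These rules correspond to the `fuse_modalities(...)` call visible in `app.py`; a self-contained sketch of the decision logic follows. The real implementation lives in `src/fusion/fusion_logic.py`, and taking the minimum of the two confidences as the fused confidence is an assumption here:

```python
from dataclasses import dataclass
from typing import Optional

SEVERITY_ORDER = ["low", "medium", "high", "critical"]

@dataclass
class ModalityOutput:
    result: str                 # "issue_detected" | "normal" | "uncertain"
    issue_type: Optional[str]
    severity: str
    confidence: float

def fuse_modalities(img, aud, confidence_issue_min=0.6, confidence_normal_min=0.75,
                    uncertain_if_disagree=True, high_confidence_override=0.92):
    max_sev = max(img.severity, aud.severity, key=SEVERITY_ORDER.index)
    conf = min(img.confidence, aud.confidence)  # assumed aggregation

    if img.result == "issue_detected" and aud.result == "issue_detected":
        # agreement on an issue -> issue_detected with max severity
        return {"result": "issue_detected", "severity": max_sev, "confidence": conf}
    if img.result == "normal" and aud.result == "normal":
        # "normal" only when both are confidently normal; otherwise stay uncertain
        if conf >= confidence_normal_min:
            return {"result": "normal", "severity": "low", "confidence": conf}
        return {"result": "uncertain", "severity": "low", "confidence": conf}
    # disagreement: a single very confident issue report can override
    for m in (img, aud):
        if m.result == "issue_detected" and m.confidence >= high_confidence_override:
            return {"result": "issue_detected", "severity": m.severity,
                    "confidence": m.confidence}
    return {"result": "uncertain", "severity": max_sev, "confidence": conf}
```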
+ ## Limitations
+
+ - Image model trained on web-sourced images (some watermarked/AI-generated)
+ - Audio model trained on 100 synthetic clips — use as supporting evidence only
+ - Water damage and cracked faceplate classes have lower recall (64-67%)
+ - No GFCI failure detection (no training data available)
+ - Real-world accuracy will be lower than validation metrics
+
+ ## License
+
+ Proprietary — for use in the Electrical Outlets diagnostic pipeline only.
app.py ADDED
@@ -0,0 +1,262 @@
+ """
+ Electrical Outlets Diagnostic — Gradio Demo
+ Install: pip install gradio
+ Run: python app.py
+ """
+ from pathlib import Path
+ import sys
+ import json
+
+ import torch
+ import numpy as np
+ from torchvision import transforms
+ from PIL import Image
+
+ ROOT = Path(__file__).resolve().parent
+ sys.path.insert(0, str(ROOT))
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ IMAGE_MODEL = None
+ IMAGE_TEMP = 1.0
+ AUDIO_MODEL = None
+ AUDIO_TEMP = 1.0
+ AUDIO_CFG = {}
+
+
+ def load_models():
+     global IMAGE_MODEL, IMAGE_TEMP, AUDIO_MODEL, AUDIO_TEMP, AUDIO_CFG
+
+     img_weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
+     mapping = ROOT / "config" / "label_mapping.json"
+
+     if img_weights.exists():
+         from src.models.image_model import ElectricalOutletsImageModel
+         ckpt = torch.load(img_weights, map_location=DEVICE, weights_only=False)
+         head_hidden = ckpt["model_state_dict"]["head.1.weight"].shape[0]
+         IMAGE_MODEL = ElectricalOutletsImageModel(
+             num_classes=ckpt["num_classes"], label_mapping_path=mapping,
+             pretrained=False, head_hidden=head_hidden,
+         )
+         IMAGE_MODEL.load_state_dict(ckpt["model_state_dict"])
+         IMAGE_MODEL.idx_to_issue_type = ckpt.get("idx_to_issue_type")
+         IMAGE_MODEL.idx_to_severity = ckpt.get("idx_to_severity")
+         IMAGE_MODEL.eval().to(DEVICE)
+         T = ckpt.get("temperature", 1.0)
+         IMAGE_TEMP = T if 0 < T < 10 else 1.0
+         print(f"Image model loaded ({ckpt['num_classes']} classes, head={head_hidden})")
+
+     audio_weights = ROOT / "weights" / "electrical_outlets_audio_best.pt"
+     if audio_weights.exists():
+         from src.models.audio_model import ElectricalOutletsAudioModel
+         import yaml
+         ckpt = torch.load(audio_weights, map_location=DEVICE, weights_only=False)
+         audio_cfg_path = ROOT / "config" / "audio_train_config.yaml"
+         n_mels, time_steps = 128, 128
+         if audio_cfg_path.exists():
+             with open(audio_cfg_path) as f:
+                 AUDIO_CFG = yaml.safe_load(f)
+             n_mels = AUDIO_CFG.get("model", {}).get("n_mels", 128)
+             time_steps = AUDIO_CFG.get("model", {}).get("time_steps", 128)
+         AUDIO_MODEL = ElectricalOutletsAudioModel(
+             num_classes=ckpt["num_classes"], label_mapping_path=mapping,
+             n_mels=n_mels, time_steps=time_steps,
+         )
+         AUDIO_MODEL.load_state_dict(ckpt["model_state_dict"])
+         AUDIO_MODEL.idx_to_label = ckpt.get("idx_to_label")
+         AUDIO_MODEL.idx_to_issue_type = ckpt.get("idx_to_issue_type")
+         AUDIO_MODEL.idx_to_severity = ckpt.get("idx_to_severity")
+         AUDIO_MODEL.eval().to(DEVICE)
+         T = ckpt.get("temperature", 1.0)
+         AUDIO_TEMP = T if 0 < T < 10 else 1.0
+         print(f"Audio model loaded ({ckpt['num_classes']} classes)")
+
+
+ SEV_COLORS = {"low": "#22c55e", "medium": "#f59e0b", "high": "#ef4444", "critical": "#dc2626"}
+ SEV_ICONS = {"low": "✅", "medium": "⚠️", "high": "🔴", "critical": "🚨"}
+
+
+ def make_bar_html(probs_dict, highlight=None):
+     rows = ""
+     for name, prob in sorted(probs_dict.items(), key=lambda x: -x[1]):
+         pct = prob * 100
+         color = "#60a5fa" if name != highlight else "#f59e0b"
+         rows += f"""
+         <div style="display:flex;align-items:center;gap:8px;margin:3px 0;">
+           <div style="width:140px;font-size:13px;text-align:right;color:#ccc;">{name.replace('_', ' ')}</div>
+           <div style="flex:1;background:#2a2a3e;border-radius:4px;height:20px;overflow:hidden;">
+             <div style="width:{pct}%;background:{color};height:100%;border-radius:4px;"></div>
+           </div>
+           <div style="width:55px;font-size:13px;color:#eee;">{pct:.1f}%</div>
+         </div>"""
+     return f'<div style="padding:8px 0;">{rows}</div>'
+
+
+ def make_result_html(pred, title, probs_dict=None):
+     sev = pred.get("severity", "low")
+     color = SEV_COLORS.get(sev, "#666")
+     sev_icon = SEV_ICONS.get(sev, "")
+     conf = pred.get("confidence", 0)
+     issue = (pred.get("issue_type") or "uncertain").replace("_", " ").title()
+     result_text = pred.get("result", "").replace("_", " ").title()
+     bars = make_bar_html(probs_dict, pred.get("issue_type")) if probs_dict else ""
+
+     return f"""
+     <div style="background:#1a1a2e;border-radius:12px;padding:20px;margin:8px 0;
+                 border-left:4px solid {color};color:#e0e0e0;font-family:system-ui;">
+       <div style="font-size:12px;color:#888;text-transform:uppercase;letter-spacing:1px;margin-bottom:10px;">{title}</div>
+       <div style="font-size:26px;font-weight:700;margin-bottom:6px;">{result_text}</div>
+       <div style="font-size:18px;color:{color};font-weight:600;margin-bottom:14px;">{issue}</div>
+       <div style="display:flex;gap:32px;">
+         <div><div style="font-size:11px;color:#888;text-transform:uppercase;">Severity</div>
+         <div style="font-size:15px;font-weight:600;color:{color};">{sev_icon} {sev.upper()}</div></div>
+         <div><div style="font-size:11px;color:#888;text-transform:uppercase;">Confidence</div>
+         <div style="font-size:15px;font-weight:600;">{conf:.1%}</div></div>
+       </div>
+       {bars}
+     </div>"""
+
+
+ def predict_image_fn(img):
+     if IMAGE_MODEL is None:
+         return None, None
+     tf = transforms.Compose([
+         transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])
+     x = tf(img.convert("RGB")).unsqueeze(0).to(DEVICE)
+     with torch.no_grad():
+         logits = IMAGE_MODEL(x) / IMAGE_TEMP
+     probs = torch.softmax(logits, dim=-1)[0]
+     pred = IMAGE_MODEL.predict_to_schema(logits)
+     probs_dict = {IMAGE_MODEL.idx_to_issue_type[i]: p for i, p in enumerate(probs.tolist())}
+     return pred, probs_dict
+
+
+ def predict_audio_fn(audio_tuple):
+     if AUDIO_MODEL is None:
+         return None, None
+     import torchaudio
+     sr_in, audio_data = audio_tuple
+     if isinstance(audio_data, np.ndarray):
+         waveform = torch.from_numpy(audio_data.astype(np.float32))
+         if waveform.dim() == 1:
+             waveform = waveform.unsqueeze(0)
+         elif waveform.dim() == 2:
+             if waveform.shape[1] <= 2:
+                 waveform = waveform.T
+             if waveform.shape[0] > 1:
+                 waveform = waveform.mean(dim=0, keepdim=True)
+         mx = waveform.abs().max()
+         if mx > 0:
+             waveform = waveform / mx
+     else:
+         return None, None
+
+     sample_rate = AUDIO_CFG.get("data", {}).get("sample_rate", 22050)
+     if sr_in != sample_rate:
+         waveform = torchaudio.functional.resample(waveform, sr_in, sample_rate)
+     target_len = int(AUDIO_CFG.get("data", {}).get("target_length_sec", 5.0) * sample_rate)
+     if waveform.shape[1] >= target_len:
+         s = (waveform.shape[1] - target_len) // 2
+         waveform = waveform[:, s:s + target_len]
+     else:
+         waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
+
+     sc = AUDIO_CFG.get("spectrogram", {})
+     mel = torchaudio.transforms.MelSpectrogram(
+         sample_rate=sample_rate, n_fft=sc.get("n_fft", 1024),
+         hop_length=sc.get("hop_length", 512), win_length=sc.get("win_length", 1024),
+         n_mels=sc.get("n_mels", 128),
+     )(waveform)
+     log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(DEVICE)
+
+     with torch.no_grad():
+         logits = AUDIO_MODEL(log_mel) / AUDIO_TEMP
+     probs = torch.softmax(logits, dim=-1)[0]
+     pred = AUDIO_MODEL.predict_to_schema(logits)
+     labels = AUDIO_MODEL.idx_to_label or [f"class_{i}" for i in range(AUDIO_MODEL.num_classes)]
+     probs_dict = {labels[i]: p for i, p in enumerate(probs.tolist())}
+     return pred, probs_dict
+
+
+ def fuse_fn(image_pred, audio_pred):
+     from src.fusion.fusion_logic import fuse_modalities, ModalityOutput
+     import yaml
+     th_path = ROOT / "config" / "thresholds.yaml"
+     th = {}
+     if th_path.exists():
+         with open(th_path) as f:
+             th = yaml.safe_load(f) or {}
+     img_out = ModalityOutput(result=image_pred["result"], issue_type=image_pred.get("issue_type"),
+                              severity=image_pred["severity"], confidence=image_pred["confidence"])
+     aud_out = ModalityOutput(result=audio_pred["result"], issue_type=audio_pred.get("issue_type"),
+                              severity=audio_pred["severity"], confidence=audio_pred["confidence"])
+     return fuse_modalities(img_out, aud_out,
+                            confidence_issue_min=th.get("confidence_issue_min", 0.6),
+                            confidence_normal_min=th.get("confidence_normal_min", 0.75),
+                            uncertain_if_disagree=th.get("uncertain_if_disagree", True),
+                            high_confidence_override=th.get("high_confidence_override", 0.92))
+
+
+ def diagnose(image, audio):
+     if image is None and audio is None:
+         return '<div style="padding:40px;color:#888;text-align:center;font-style:italic;">Upload an image or audio to begin diagnosis...</div>'
+
+     img_pred, img_probs, aud_pred, aud_probs = None, None, None, None
+     try:
+         if image is not None:
+             img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
+             img_pred, img_probs = predict_image_fn(img)
+         if audio is not None:
+             aud_pred, aud_probs = predict_audio_fn(audio)
+     except Exception as e:
+         return f'<div style="padding:20px;color:#f87171;">Error: {e}</div>'
+
+     html = ""
+     if img_pred and aud_pred:
+         fused = fuse_fn(img_pred, aud_pred)
+         html += make_result_html(fused, "⚡ Fused Diagnosis")
+         html += '<div style="display:flex;gap:12px;">'
+         html += f'<div style="flex:1;">{make_result_html(img_pred, "📷 Image", img_probs)}</div>'
+         html += f'<div style="flex:1;">{make_result_html(aud_pred, "🎤 Audio", aud_probs)}</div>'
+         html += '</div>'
+     elif img_pred:
+         html += make_result_html(img_pred, "📷 Image Diagnosis", img_probs)
+     elif aud_pred:
+         html += make_result_html(aud_pred, "🎤 Audio Diagnosis", aud_probs)
+     else:
+         html = '<div style="padding:20px;color:#f87171;">Could not process input.</div>'
+     return html
+
+
+ if __name__ == "__main__":
+     import gradio as gr
+
+     print("Loading models...")
+     load_models()
+     print(f"Device: {DEVICE}\n")
+
+     with gr.Blocks(
+         title="Electrical Outlets Diagnostic",
+         theme=gr.themes.Base(primary_hue="red", secondary_hue="amber", neutral_hue="slate",
+                              font=gr.themes.GoogleFont("Inter")),
+         css=".gradio-container{max-width:960px!important} footer{display:none!important}"
+     ) as demo:
+
+         gr.Markdown("# ⚡ Electrical Outlets Diagnostic\nUpload a **photo** and/or **audio** to detect safety issues.")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 image_input = gr.Image(label="📷 Outlet Photo", type="numpy", height=300)
+                 audio_input = gr.Audio(label="🎤 Audio Recording", type="numpy")
+                 btn = gr.Button("🔍 Diagnose", variant="primary", size="lg")
+             with gr.Column(scale=1):
+                 output = gr.HTML(value='<div style="padding:40px;color:#888;text-align:center;font-style:italic;">Upload an image or audio to begin...</div>')
+
+         btn.click(fn=diagnose, inputs=[image_input, audio_input], outputs=[output])
+         image_input.change(fn=diagnose, inputs=[image_input, audio_input], outputs=[output])
+         audio_input.change(fn=diagnose, inputs=[image_input, audio_input], outputs=[output])
+
+         gr.Markdown("---\n| Severity | Action |\n|--|--|\n| ✅ Low | Monitor |\n| ⚠️ Medium | Schedule repair |\n| 🔴 High | Shut off circuit |\n| 🚨 Critical | Shut off main breaker |")
+
+     demo.launch(server_name="127.0.0.1", server_port=7860, share=False, show_error=True)
config/audio_train_config.yaml ADDED
@@ -0,0 +1,41 @@
+ # Audio model training config - Electrical Outlets
+ # 100 samples: heavy augmentation, balanced batching, treat as preliminary
+
+ data:
+   root: "electrical_outlets_sounds_100"
+   label_mapping: "config/label_mapping.json"
+   train_ratio: 0.7
+   val_ratio: 0.15
+   seed: 42
+   batch_size: 16
+   num_workers: 0
+   target_length_sec: 5.0
+   sample_rate: 16000
+
+ spectrogram:
+   n_mels: 64
+   n_fft: 512
+   hop_length: 256
+   win_length: 512
+
+ model:
+   num_classes: 4
+   n_mels: 64
+   time_steps: 128
+
+ training:
+   epochs: 80
+   lr: 1.0e-3
+   weight_decay: 1.0e-4
+   use_class_weights: true
+   early_stopping_patience: 12
+   early_stopping_metric: "val_macro_recall"
+
+ calibration:
+   use_temperature_scaling: true
+   val_fraction_for_calibration: 0.5
+
+ output:
+   weights_dir: "weights"
+   best_name: "electrical_outlets_audio_best.pt"
+   report_name: "audio_model_report.md"
config/image_train_config.yaml ADDED
@@ -0,0 +1,43 @@
+ # v5.1 — Push past 63% min recall
+ # Changes: higher finetune LR, bigger head, fixed temp scaling
+
+ data:
+   root: "ELECTRICAL OUTLETS-20260106T153508Z-3-001"
+   label_mapping: "config/label_mapping.json"
+   train_ratio: 0.7
+   val_ratio: 0.15
+   seed: 42
+   batch_size: 64
+   num_workers: 4
+
+ augmentation:
+   resize: 256
+   crop: 224
+
+ model:
+   num_classes: 5
+   pretrained: true
+   head_hidden: 512      # was 256 — more capacity with 1300 images
+   head_dropout: 0.5     # was 0.4 — stronger regularization
+
+ training:
+   epochs: 100                      # was 80 — give head more time
+   lr: 3.0e-3
+   weight_decay: 1.0e-3
+   use_class_weights: true
+   use_focal: true
+   focal_alpha: 0.25
+   focal_gamma: 2.0
+   early_stopping_patience: 25      # was 20
+   early_stopping_metric: "val_min_recall"
+   finetune_last_blocks: true
+   finetune_lr: 2.0e-4              # was 5e-5 — 4x higher, backbone needs to adapt more
+   finetune_epochs: 30              # was 25
+
+ calibration:
+   use_temperature_scaling: false   # DISABLED — was producing negative T
+
+ output:
+   weights_dir: "weights"
+   best_name: "electrical_outlets_image_best.pt"
+   report_name: "image_model_report.md"
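Temperature scaling was disabled here after producing a negative T (and `app.py` guards against out-of-range T at load time). Searching only over positive temperatures makes a negative T impossible by construction. A NumPy grid-search sketch of the idea; the training script presumably uses a gradient-based fit, so this is illustrative only:

```python
import numpy as np

def fit_temperature(logits, labels, grid=np.linspace(0.05, 10.0, 200)):
    """Pick T > 0 minimizing validation NLL of softmax(logits / T)."""
    def nll(T):
        z = logits / T
        z = z - z.max(axis=1, keepdims=True)                      # stable softmax
        log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
        return -log_probs[np.arange(len(labels)), labels].mean()  # mean NLL
    return min(grid, key=nll)   # positivity guaranteed by the search grid
```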
config/label_mapping.json ADDED
@@ -0,0 +1,93 @@
+ {
+   "image": {
+     "classes": [
+       {
+         "folder_key": "burn_marks_overheating",
+         "issue_type": "burn_overheating",
+         "severity": "high",
+         "description": "Fire, overheating, sparking, discoloration"
+       },
+       {
+         "folder_key": "cracked_faceplates",
+         "issue_type": "cracked_faceplate",
+         "severity": "medium",
+         "description": "Cracked/broken faceplate, damaged switches"
+       },
+       {
+         "folder_key": "loose_outlets",
+         "issue_type": "loose_outlet",
+         "severity": "medium",
+         "description": "Loose outlet, pulled from wall, exposed wiring"
+       },
+       {
+         "folder_key": "normal_outlets",
+         "issue_type": "normal",
+         "severity": "low",
+         "description": "Normal outlet/switch condition"
+       },
+       {
+         "folder_key": "water_exposed",
+         "issue_type": "water_exposed",
+         "severity": "high",
+         "description": "Water intrusion near outlet"
+       }
+     ],
+     "folder_to_class": {
+       "Burn marks - overheating 250": "burn_marks_overheating",
+       "Discoloration (heat aging) 100": "burn_marks_overheating",
+       "Sparking damage evidence 150": "burn_marks_overheating",
+       "Cracked faceplate 150": "cracked_faceplates",
+       "Damaged switches 50": "cracked_faceplates",
+       "Loose outlet - pulled from wall 200": "loose_outlets",
+       "Exposed wiring 150": "loose_outlets",
+       "Normal outlets 50": "normal_outlets",
+       "Normal switches 50": "normal_outlets",
+       "Water intrusion near outlet 150": "water_exposed"
+     },
+     "class_to_idx": {
+       "burn_marks_overheating": 0,
+       "cracked_faceplates": 1,
+       "loose_outlets": 2,
+       "normal_outlets": 3,
+       "water_exposed": 4
+     },
+     "idx_to_issue_type": [
+       "burn_overheating",
+       "cracked_faceplate",
+       "loose_outlet",
+       "normal",
+       "water_exposed"
+     ],
+     "idx_to_severity": ["high", "medium", "medium", "low", "high"]
+   },
+   "audio": {
+     "file_pattern_to_label": {
+       "normal_near_silent": "normal",
+       "plug_insert_remove_clicks": "normal",
+       "load_switching": "normal",
+       "buzzing_outlet": "buzzing",
+       "loose_contact_crackle": "crackling_arcing",
+       "arcing_pop": "arcing_pop"
+     },
+     "label_to_severity": {
+       "normal": "low",
+       "buzzing": "high",
+       "crackling_arcing": "high",
+       "arcing_pop": "critical"
+     },
+     "label_to_issue_type": {
+       "normal": "normal",
+       "buzzing": "buzzing",
+       "crackling_arcing": "crackling_arcing",
+       "arcing_pop": "arcing_pop"
+     },
+     "class_to_idx": {
+       "normal": 0,
+       "buzzing": 1,
+       "crackling_arcing": 2,
+       "arcing_pop": 3
+     },
+     "idx_to_label": ["normal", "buzzing", "crackling_arcing", "arcing_pop"],
+     "num_classes": 4
+   }
+ }
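The `folder_to_class` table merges the 10 source folders into 5 training classes. A sketch of how a dataset loader might resolve an image's label, using an inline excerpt of the mapping above (the paths are hypothetical; the real loaders are `src/data/image_dataset.py` and friends):

```python
from pathlib import Path

# Excerpt of config/label_mapping.json, "image" section
IMAGE_MAPPING = {
    "folder_to_class": {
        "Burn marks - overheating 250": "burn_marks_overheating",
        "Damaged switches 50": "cracked_faceplates",
        "Normal outlets 50": "normal_outlets",
    },
    "class_to_idx": {
        "burn_marks_overheating": 0, "cracked_faceplates": 1,
        "loose_outlets": 2, "normal_outlets": 3, "water_exposed": 4,
    },
}

def label_for(path, mapping=IMAGE_MAPPING):
    """Resolve '<root>/<source folder>/img.jpg' to a merged class index."""
    folder = Path(path).parent.name                        # e.g. "Damaged switches 50"
    class_key = mapping["folder_to_class"][folder]         # e.g. "cracked_faceplates"
    return mapping["class_to_idx"][class_key]
```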
config/schema.yaml ADDED
@@ -0,0 +1,81 @@
+ # Canonical output schema for Electrical Outlets diagnostic element
+ # Used by image model, audio model, fusion layer, and API
+ # Aligned to client PDF: Electrical outlet & switches diagnostics
+
+ diagnostic_element: electrical_outlets
+
+ result:
+   type: string
+   enum:
+     - issue_detected
+     - normal
+     - uncertain
+   description: "Final outcome; uncertain triggers backend-guided adjustment or escalation"
+
+ issue_type:
+   type: string
+   nullable: true
+   enum:
+     # Image-derived (NOT OPEN, PDF diagnostics 1-38)
+     - burn_overheating
+     - cracked_faceplate
+     - gfci_failure
+     - loose_outlet
+     - water_exposed
+     # Audio-derived (PDF diagnostics 21-28)
+     - buzzing
+     - humming
+     - crackling_arcing
+     - arcing_pop
+     - sizzling
+     - clicking_idle
+     # Combined / generic
+     - normal
+   description: "Primary issue type when result is issue_detected; null for normal/uncertain when no single type"
+
+ severity:
+   type: string
+   enum:
+     - low
+     - medium
+     - high
+     - critical
+   description: "Per PDF: low=monitor, medium=repair, high=shut circuit, critical=shut main breaker"
+
+ confidence:
+   type: number
+   minimum: 0
+   maximum: 1
+   description: "Calibrated probability; drives uncertain path when below threshold"
+
+ modality_contributions:
+   type: object
+   nullable: true
+   properties:
+     image:
+       type: object
+       nullable: true
+       properties:
+         result: { type: string }
+         issue_type: { type: string, nullable: true }
+         severity: { type: string }
+         confidence: { type: number }
+     audio:
+       type: object
+       nullable: true
+       properties:
+         result: { type: string }
+         issue_type: { type: string, nullable: true }
+         severity: { type: string }
+         confidence: { type: number }
+   description: "Per-modality outputs for transparency; present when both image and audio provided"
+
+ # For fusion when both modalities detect different issues
+ primary_issue:
+   type: string
+   nullable: true
+   description: "Higher-severity issue when both modalities detect issues"
+ secondary_issue:
+   type: string
+   nullable: true
+   description: "Other issue when both modalities detect different issues"
config/thresholds.yaml ADDED
@@ -0,0 +1,13 @@
+ # Threshold and safety configuration - Electrical Outlets
+ # Prefer "uncertain" over "normal" when in doubt (minimize false negatives)
+
+ confidence_issue_min: 0.6       # below this -> result = uncertain when issue_detected
+ confidence_normal_min: 0.75     # both modalities must exceed this to return "normal"
+ uncertain_if_disagree: true     # image defect + audio normal (or vice versa) -> uncertain unless one side very high
+ high_confidence_override: 0.92  # if one modality >= this and says issue_detected, can override disagree
+
+ severity_order:
+   - low
+   - medium
+   - high
+   - critical
releases.md ADDED
@@ -0,0 +1,8 @@
+ # Releases
+
+ ## v1.0 — February 2026
+ - 5-class image model (EfficientNet-B0 + MLP head): 77% accuracy, 67% min recall
+ - 4-class audio model (Spectrogram CNN): 100% recall
+ - Decision-level fusion with configurable thresholds
+ - Gradio demo app
+ - FastAPI endpoint
requirements.txt ADDED
@@ -0,0 +1,145 @@
+ aiofiles==23.2.1
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.11.18
+ aiosignal==1.3.2
+ altair==5.5.0
+ annotated-doc==0.0.4
+ annotated-types==0.7.0
+ anyio==4.9.0
+ async-timeout==4.0.3
+ attrs==25.3.0
+ beautifulsoup4==4.13.4
+ Brotli @ file:///D:/bld/brotli-split_1725267609074/work
+ certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1739515848642/work/certifi
+ cffi @ file:///D:/bld/cffi_1725560792189/work
+ charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1746214863626/work
+ click==8.1.8
+ colorama==0.4.6
+ comtypes==1.4.10
+ contourpy==1.3.0
+ cycler==0.12.1
+ dataclasses-json==0.6.7
+ docopt==0.6.2
+ exceptiongroup==1.2.2
+ fastapi==0.95.2
+ ffmpy==1.0.0
+ filelock==3.18.0
+ fonttools==4.60.2
+ frozenlist==1.6.0
+ fsspec==2025.3.2
+ gradio==3.50.2
+ gradio_client==0.6.1
+ greenlet==3.2.1
+ h11==0.16.0
+ h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1738578511449/work
+ hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
+ httpcore==1.0.9
+ httpx==0.28.1
+ httpx-sse==0.4.0
+ huggingface-hub==0.31.1
+ hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
+ idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
+ importlib_resources==6.5.2
+ Jinja2==3.1.6
+ joblib==1.5.3
+ Js2Py==0.74
+ jsonpatch==1.33
+ jsonpointer==3.0.0
+ jsonschema==4.25.1
+ jsonschema-specifications==2025.9.1
+ kiwisolver==1.4.7
+ langchain==0.3.25
+ langchain-community==0.3.23
+ langchain-core==0.3.58
+ langchain-ollama==0.3.2
+ langchain-text-splitters==0.3.8
+ langgraph==0.4.1
+ langgraph-checkpoint==2.0.25
+ langgraph-prebuilt==0.1.8
+ langgraph-sdk==0.1.66
+ langsmith==0.3.42
+ llvmlite==0.43.0
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ marshmallow==3.26.1
+ matplotlib==3.9.4
+ mdurl==0.1.2
+ more-itertools==10.7.0
+ mpmath==1.3.0
+ multidict==6.4.3
+ mypy_extensions==1.1.0
+ narwhals==2.17.0
+ networkx==3.2.1
+ numba==0.60.0
+ numpy==1.26.4
+ ollama==0.4.8
+ openai-whisper==20240930
+ orjson==3.10.18
+ ormsgpack==1.9.1
+ packaging==24.2
+ pandas==2.3.3
+ pillow==10.4.0
+ pipwin==0.5.2
+ propcache==0.3.1
+ PyAudio==0.2.14
+ pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
+ pydantic==1.10.13
+ pydantic-settings==2.9.1
+ pydantic_core==2.41.5
+ pydub==0.25.1
+ pygame==2.6.1
+ Pygments==2.19.2
+ pyjsparser==2.7.1
+ pyparsing==3.3.2
+ pypiwin32==223
+ PyPrind==2.11.3
+ pySmartDL==1.3.4
+ PySocks @ file:///D:/bld/pysocks_1733217287171/work
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.1.0
+ python-multipart==0.0.20
+ pyttsx3==2.98
+ pytz==2025.2
+ pywin32==310
+ PyYAML==6.0.2
+ referencing==0.36.2
+ regex==2024.11.6
+ requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1733217035951/work
+ requests-toolbelt==1.0.0
+ rich==14.3.3
+ rpds-py==0.27.1
+ ruff==0.15.2
+ scikit-learn==1.6.1
+ scipy==1.13.1
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ soundfile==0.13.1
+ soupsieve==2.7
+ SpeechRecognition @ file:///home/conda/feedstock_root/build_artifacts/speechrecognition_1742707644995/work
+ SQLAlchemy==2.0.40
+ starlette==0.27.0
+ sympy==1.13.1
+ tenacity==9.1.2
+ threadpoolctl==3.6.0
+ tiktoken==0.9.0
+ tomlkit==0.12.0
+ torch==2.6.0+cu124
+ torchaudio==2.6.0+cu124
+ torchvision==0.21.0+cu124
+ tqdm==4.67.1
+ typer==0.23.2
+ typing-inspect==0.9.0
+ typing-inspection==0.4.2
+ typing_extensions==4.15.0
+ tzdata==2025.2
+ tzlocal==5.3.1
+ urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1744323578849/work
+ uvicorn==0.39.0
+ websockets==11.0.3
+ win_inet_pton @ file:///D:/bld/win_inet_pton_1733130564612/work
+ xxhash==3.5.0
+ yarl==1.20.0
+ zipp==3.23.0
+ zstandard==0.23.0
src/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .models.image_model import ElectricalOutletsImageModel
+ from .models.audio_model import ElectricalOutletsAudioModel
+
+ __all__ = ["ElectricalOutletsImageModel", "ElectricalOutletsAudioModel"]
src/data/audio_dataset.py ADDED
@@ -0,0 +1,105 @@
+ """
+ Audio dataset for Electrical Outlets. Uses README/file naming and config/label_mapping.json.
+ PATCHED: rglob for subfolders, torchaudio import at module level, stratified splits.
+ """
+ from pathlib import Path
+ import json
+ import logging
+ from collections import defaultdict
+ from typing import Optional, Callable, List, Tuple
+
+ import torch
+ import torchaudio
+ from torch.utils.data import Dataset
+
+ logger = logging.getLogger(__name__)
+
+
+ def _label_from_filename(filename: str, file_pattern_to_label: dict) -> str:
+     for pattern, label in file_pattern_to_label.items():
+         if filename.startswith(pattern) or pattern in filename:
+             return label
+     return "normal"
+
+
+ class ElectricalOutletsAudioDataset(Dataset):
+     """Audio dataset from electrical_outlets_sounds_100 WAVs."""
+
+     def __init__(
+         self,
+         root: Path,
+         label_mapping_path: Path,
+         split: str = "train",
+         train_ratio: float = 0.7,
+         val_ratio: float = 0.15,
+         seed: int = 42,
+         transform: Optional[Callable] = None,
+         target_length_sec: float = 5.0,
+         sample_rate: int = 22050,
+     ):
+         self.root = Path(root)
+         self.transform = transform
+         self.target_length_sec = target_length_sec
+         self.sample_rate = sample_rate
+         with open(label_mapping_path) as f:
+             lm = json.load(f)
+         self.file_pattern_to_label = lm["audio"]["file_pattern_to_label"]
+         self.class_to_idx = lm["audio"]["class_to_idx"]
+         self.idx_to_label = lm["audio"]["idx_to_label"]
+         self.label_to_severity = lm["audio"]["label_to_severity"]
+         self.label_to_issue_type = lm["audio"]["label_to_issue_type"]
+         self.num_classes = len(self.class_to_idx)
+
+         self.samples: List[Tuple[Path, int]] = []
+         # rglob to search subfolders
+         for wav in self.root.rglob("*.wav"):
+             label = _label_from_filename(wav.stem, self.file_pattern_to_label)
+             if label not in self.class_to_idx:
+                 logger.warning(f"Unmatched audio file: {wav.name} → label '{label}' not in class_to_idx")
+                 continue
+             self.samples.append((wav, self.class_to_idx[label]))
+
+         # Stratified split
+         by_class = defaultdict(list)
+         for i, (_, cls) in enumerate(self.samples):
+             by_class[cls].append(i)
+
+         train_idx, val_idx, test_idx = [], [], []
+         for cls in sorted(by_class.keys()):
+             indices = by_class[cls]
+             g = torch.Generator().manual_seed(seed)
+             perm = torch.randperm(len(indices), generator=g).tolist()
+             n_cls = len(indices)
+             n_tr = int(n_cls * train_ratio)
+             n_va = int(n_cls * val_ratio)
+             train_idx.extend([indices[p] for p in perm[:n_tr]])
+             val_idx.extend([indices[p] for p in perm[n_tr:n_tr + n_va]])
+             test_idx.extend([indices[p] for p in perm[n_tr + n_va:]])
+
+         if split == "train":
+             self.indices = train_idx
+         elif split == "val":
+             self.indices = val_idx
+         else:
+             self.indices = test_idx
+
+     def __len__(self) -> int:
+         return len(self.indices)
+
+     def __getitem__(self, idx: int):
+         i = self.indices[idx]
+         path, cls = self.samples[i]
+         waveform, sr = torchaudio.load(str(path))
+         if sr != self.sample_rate:
+             waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
+         if waveform.shape[0] > 1:
+             waveform = waveform.mean(dim=0, keepdim=True)
+         target_len = int(self.target_length_sec * self.sample_rate)
+         if waveform.shape[1] >= target_len:
+             start = (waveform.shape[1] - target_len) // 2
+             waveform = waveform[:, start : start + target_len]
+         else:
+             waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
+         if self.transform:
+             waveform = self.transform(waveform)
+         return waveform, cls
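Both datasets share the same per-class 70/15/15 split so that every class appears in every split. A standalone sketch of that logic, using `random.Random` in place of `torch.Generator` for portability (function name is illustrative):

```python
import random
from collections import defaultdict

def stratified_split(labels, train_ratio=0.7, val_ratio=0.15, seed=42):
    # labels: one class id per sample; returns (train, val, test) index lists,
    # shuffled and sliced per class so split ratios hold within each class
    by_class = defaultdict(list)
    for i, c in enumerate(labels):
        by_class[c].append(i)
    train, val, test = [], [], []
    for c in sorted(by_class):
        idx = list(by_class[c])
        random.Random(seed).shuffle(idx)       # deterministic per-class shuffle
        n_tr = int(len(idx) * train_ratio)
        n_va = int(len(idx) * val_ratio)
        train += idx[:n_tr]
        val += idx[n_tr:n_tr + n_va]
        test += idx[n_tr + n_va:]
    return train, val, test
```

The test split receives the remainder, so no sample is dropped even when the ratios do not divide a class evenly.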
src/data/image_dataset.py ADDED
@@ -0,0 +1,155 @@
+ """
+ Image dataset for Electrical Outlets.
+ FINAL v5: Direct folder_to_class mapping - no pattern matching, no ambiguity.
+ """
+ from pathlib import Path
+ import json
+ import logging
+ from collections import defaultdict
+ from typing import Optional, Callable, List, Tuple
+
+ import torch
+ from torch.utils.data import Dataset
+ from PIL import Image
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
+
+
+ class ElectricalOutletsImageDataset(Dataset):
+
+     def __init__(
+         self,
+         root: Path,
+         label_mapping_path: Path,
+         split: str = "train",
+         train_ratio: float = 0.7,
+         val_ratio: float = 0.15,
+         seed: int = 42,
+         transform: Optional[Callable] = None,
+         extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"),
+     ):
+         self.root = Path(root)
+         self.transform = transform
+         self.extensions = extensions
+         self.split = split
+
+         with open(label_mapping_path) as f:
+             lm = json.load(f)
+
+         self.folder_to_class = lm["image"]["folder_to_class"]
+         self.class_to_idx = lm["image"]["class_to_idx"]
+         self.idx_to_issue_type = lm["image"]["idx_to_issue_type"]
+         self.idx_to_severity = lm["image"]["idx_to_severity"]
+         self.num_classes = len(self.class_to_idx)
+
+         # Build samples list
+         self.samples: List[Tuple[Path, int]] = []
+         class_counts = defaultdict(int)
+         matched_folders = []
+         unmatched_folders = []
+
+         for folder in sorted(self.root.iterdir()):
+             if not folder.is_dir():
+                 continue
+             # Direct lookup by exact folder name
+             class_key = self.folder_to_class.get(folder.name)
+             if class_key is None:
+                 unmatched_folders.append(folder.name)
+                 continue
+             cls_idx = self.class_to_idx[class_key]
+             count = 0
+             for f in folder.iterdir():
+                 if f.suffix.lower() in self.extensions:
+                     self.samples.append((f, cls_idx))
+                     count += 1
+             class_counts[cls_idx] += count
+             matched_folders.append(f"  ✓ {folder.name} → {class_key} (idx={cls_idx}): {count} images")
+
+         # Log results
+         logger.info(f"\n{'='*60}")
+         logger.info(f"Dataset loading from: {self.root}")
+         logger.info(f"{'='*60}")
+         for line in matched_folders:
+             logger.info(line)
+         for uf in unmatched_folders:
+             logger.warning(f"  ✗ SKIPPED: '{uf}' (not in folder_to_class)")
+         logger.info("\nClass distribution:")
+         for idx in sorted(class_counts.keys()):
+             name = [k for k, v in self.class_to_idx.items() if v == idx][0]
+             logger.info(f"  Class {idx} ({name}): {class_counts[idx]} images")
+         logger.info(f"Total: {len(self.samples)} images in {self.num_classes} classes")
+
+         if len(self.samples) == 0:
+             logger.error("NO SAMPLES FOUND! Check that data_root points to the folder containing your class subfolders.")
+             raise ValueError(f"No images found in {self.root}. Check folder names match label_mapping.json folder_to_class keys.")
+
+         # Stratified split
+         by_class = defaultdict(list)
+         for i, (_, cls) in enumerate(self.samples):
+             by_class[cls].append(i)
+
+         train_idx, val_idx, test_idx = [], [], []
+         for cls in sorted(by_class.keys()):
+             indices = by_class[cls]
+             g = torch.Generator().manual_seed(seed)
+             perm = torch.randperm(len(indices), generator=g).tolist()
+             n_cls = len(indices)
+             n_tr = int(n_cls * train_ratio)
+             n_va = int(n_cls * val_ratio)
+             train_idx.extend([indices[p] for p in perm[:n_tr]])
+             val_idx.extend([indices[p] for p in perm[n_tr:n_tr + n_va]])
+             test_idx.extend([indices[p] for p in perm[n_tr + n_va:]])
+
+         if split == "train":
+             self.indices = train_idx
+         elif split == "val":
+             self.indices = val_idx
+         else:
+             self.indices = test_idx
+
+         logger.info(f"Split '{split}': {len(self.indices)} samples\n")
+
+     def __len__(self):
+         return len(self.indices)
+
+     def __getitem__(self, idx):
+         i = self.indices[idx]
+         path, cls = self.samples[i]
+         img = Image.open(path).convert("RGB")
+         if self.transform:
+             img = self.transform(img)
+         return img, cls
+
+     def get_issue_type(self, class_idx: int) -> str:
+         return self.idx_to_issue_type[class_idx]
+
+     def get_severity(self, class_idx: int) -> str:
+         return self.idx_to_severity[class_idx]
+
+
+ def get_image_class_weights(label_mapping_path: Path, root: Path) -> torch.Tensor:
+     """Compute inverse frequency weights for class-weighted loss."""
+     with open(label_mapping_path) as f:
+         lm = json.load(f)
+     folder_to_class = lm["image"]["folder_to_class"]
+     class_to_idx = lm["image"]["class_to_idx"]
+     num_classes = len(class_to_idx)
+     counts = [0] * num_classes
+
+     root = Path(root)
+     for folder in root.iterdir():
+         if not folder.is_dir():
+             continue
+         class_key = folder_to_class.get(folder.name)
+         if class_key is None:
+             continue
+         cls_idx = class_to_idx[class_key]
+         n = sum(1 for f in folder.iterdir() if f.suffix.lower() in (".jpg", ".jpeg", ".png"))
+         counts[cls_idx] += n
+
+     total = sum(counts)
+     if total == 0:
+         return torch.ones(num_classes)
+     weights = [total / (num_classes * c) if c else 1.0 for c in counts]
+     return torch.tensor(weights, dtype=torch.float32)
src/fusion/fusion_logic.py ADDED
@@ -0,0 +1,129 @@
+ """
+ Decision-level fusion for Electrical Outlets. No early fusion.
+ Rules: final_severity = max(image_severity, audio_severity); result = issue_detected | normal | uncertain.
+ """
+ from typing import Optional, Dict, Any
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class ModalityOutput:
+     result: str  # issue_detected | normal | uncertain
+     issue_type: Optional[str] = None
+     severity: str = "low"
+     confidence: float = 0.0
+
+
+ def _severity_rank(s: str, order: list) -> int:
+     try:
+         return order.index(s)
+     except ValueError:
+         return 0
+
+
+ def fuse_modalities(
+     image_out: Optional[ModalityOutput],
+     audio_out: Optional[ModalityOutput],
+     confidence_issue_min: float = 0.6,
+     confidence_normal_min: float = 0.75,
+     uncertain_if_disagree: bool = True,
+     high_confidence_override: float = 0.92,
+     severity_order: Optional[list] = None,
+ ) -> Dict[str, Any]:
+     """
+     Fuse image and/or audio outputs into a single diagnostic result.
+     Prefer uncertain over normal when in doubt.
+     """
+     if severity_order is None:
+         severity_order = ["low", "medium", "high", "critical"]
+
+     modality_contributions = {}
+     outputs = []
+     if image_out is not None:
+         outputs.append(("image", image_out))
+         modality_contributions["image"] = {
+             "result": image_out.result,
+             "issue_type": image_out.issue_type,
+             "severity": image_out.severity,
+             "confidence": image_out.confidence,
+         }
+     if audio_out is not None:
+         outputs.append(("audio", audio_out))
+         modality_contributions["audio"] = {
+             "result": audio_out.result,
+             "issue_type": audio_out.issue_type,
+             "severity": audio_out.severity,
+             "confidence": audio_out.confidence,
+         }
+
+     if not outputs:
+         return {
+             "diagnostic_element": "electrical_outlets",
+             "result": "uncertain",
+             "issue_type": None,
+             "severity": "low",
+             "confidence": 0.0,
+             "modality_contributions": None,
+             "primary_issue": None,
+             "secondary_issue": None,
+         }
+
+     # Severity: max across modalities
+     max_severity_rank = -1
+     max_severity = "low"
+     for _, out in outputs:
+         r = _severity_rank(out.severity, severity_order)
+         if r > max_severity_rank:
+             max_severity_rank = r
+             max_severity = out.severity
+
+     # Result and issue_type
+     primary_issue = None
+     secondary_issue = None
+     has_issue = any(o.result == "issue_detected" for _, o in outputs)
+     all_normal = all(o.result == "normal" for _, o in outputs)
+     max_conf = max(o.confidence for _, o in outputs)
+     disagree = len(outputs) == 2 and (
+         (outputs[0][1].result == "issue_detected" and outputs[1][1].result == "normal")
+         or (outputs[0][1].result == "normal" and outputs[1][1].result == "issue_detected")
+     )
+
+     if has_issue and max_conf >= confidence_issue_min:
+         if disagree and uncertain_if_disagree:
+             override = any(o.confidence >= high_confidence_override and o.result == "issue_detected" for _, o in outputs)
+             if override:
+                 result = "issue_detected"
+                 issue_type = next(o.issue_type for _, o in outputs if o.result == "issue_detected" and o.confidence >= high_confidence_override)
+                 primary_issue = issue_type
+             else:
+                 result = "uncertain"
+                 issue_type = None
+         else:
+             result = "issue_detected"
+             defect_outs = [(n, o) for n, o in outputs if o.result == "issue_detected"]
+             if len(defect_outs) >= 2:
+                 defect_outs.sort(key=lambda x: _severity_rank(x[1].severity, severity_order), reverse=True)
+                 issue_type = defect_outs[0][1].issue_type
+                 primary_issue = defect_outs[0][1].issue_type
+                 secondary_issue = defect_outs[1][1].issue_type if defect_outs[0][1].issue_type != defect_outs[1][1].issue_type else None
+             else:
+                 issue_type = defect_outs[0][1].issue_type
+                 primary_issue = issue_type
+     elif all_normal and all(o.confidence >= confidence_normal_min for _, o in outputs):
+         result = "normal"
+         issue_type = "normal"
+     else:
+         result = "uncertain"
+         issue_type = None
+
+     confidence = max_conf if result != "uncertain" else min(o.confidence for _, o in outputs)
+
+     return {
+         "diagnostic_element": "electrical_outlets",
+         "result": result,
+         "issue_type": issue_type,
+         "severity": max_severity,
+         "confidence": round(confidence, 4),
+         "modality_contributions": modality_contributions if len(modality_contributions) > 1 else None,
+         "primary_issue": primary_issue if result == "issue_detected" else None,
+         "secondary_issue": secondary_issue,
+     }
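The max-over-modalities severity rule can be shown in isolation. An illustrative standalone helper that mirrors `_severity_rank`'s unknown-label-ranks-lowest behavior:

```python
SEVERITY_ORDER = ["low", "medium", "high", "critical"]

def max_severity(severities, order=SEVERITY_ORDER):
    # highest-ranked severity wins; strings outside the order rank lowest
    def rank(s):
        return order.index(s) if s in order else 0
    return max(severities, key=rank)
```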
src/inference/wrapper.py ADDED
@@ -0,0 +1,144 @@
+ """
+ Inference wrapper: load image + audio models, run modalities present, apply fusion, return schema.
+ """
+ from pathlib import Path
+ from typing import Optional, Dict, Any, BinaryIO
+ import torch
+ import torchaudio
+ from torchvision import transforms
+ from PIL import Image
+
+ # Make the repo root importable so src.* modules resolve
+ import sys
+ ROOT = Path(__file__).resolve().parent.parent.parent
+ sys.path.insert(0, str(ROOT))
+
+
+ def _load_image_model(weights_path: Path, label_mapping_path: Path, device: str):
+     from src.models.image_model import ElectricalOutletsImageModel
+     ckpt = torch.load(weights_path, map_location=device)
+     model = ElectricalOutletsImageModel(
+         num_classes=ckpt["num_classes"],
+         label_mapping_path=label_mapping_path,
+         pretrained=False,
+     )
+     model.load_state_dict(ckpt["model_state_dict"])
+     model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
+     model.idx_to_severity = ckpt.get("idx_to_severity")
+     model.eval()
+     return model.to(device), ckpt.get("temperature", 1.0)
+
+
+ def _load_audio_model(weights_path: Path, label_mapping_path: Path, device: str, config: dict):
+     from src.models.audio_model import ElectricalOutletsAudioModel
+     ckpt = torch.load(weights_path, map_location=device)
+     model = ElectricalOutletsAudioModel(
+         num_classes=ckpt["num_classes"],
+         label_mapping_path=label_mapping_path,
+         n_mels=config.get("n_mels", 64),
+         time_steps=config.get("time_steps", 128),
+     )
+     model.load_state_dict(ckpt["model_state_dict"])
+     model.idx_to_label = ckpt.get("idx_to_label")
+     model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
+     model.idx_to_severity = ckpt.get("idx_to_severity")
+     model.eval()
+     return model.to(device), ckpt.get("temperature", 1.0)
+
+
+ def run_electrical_outlets_inference(
+     image_path: Optional[Path] = None,
+     image_fp: Optional[BinaryIO] = None,
+     audio_path: Optional[Path] = None,
+     audio_fp: Optional[BinaryIO] = None,
+     weights_dir: Optional[Path] = None,
+     config_dir: Optional[Path] = None,
+     device: Optional[str] = None,
+ ) -> Dict[str, Any]:
+     """
+     Run image and/or audio model, then fuse. Returns canonical schema dict.
+     """
+     if weights_dir is None:
+         weights_dir = ROOT / "weights"
+     if config_dir is None:
+         config_dir = ROOT / "config"
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     label_mapping_path = config_dir / "label_mapping.json"
+     thresholds_path = config_dir / "thresholds.yaml"
+     import yaml
+     with open(thresholds_path) as f:
+         thresholds = yaml.safe_load(f)
+
+     image_out = None
+     if image_path or image_fp:
+         img = Image.open(image_path or image_fp).convert("RGB")
+         tf = transforms.Compose([
+             transforms.Resize(256),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ])
+         x = tf(img).unsqueeze(0).to(device)
+         model, T = _load_image_model(weights_dir / "electrical_outlets_image_best.pt", label_mapping_path, device)
+         with torch.no_grad():
+             logits = model(x) / T
+         from src.fusion.fusion_logic import ModalityOutput
+         pred = model.predict_to_schema(logits)
+         image_out = ModalityOutput(
+             result=pred["result"],
+             issue_type=pred.get("issue_type"),
+             severity=pred["severity"],
+             confidence=pred["confidence"],
+         )
+
+     audio_out = None
+     if (audio_path or audio_fp) and (weights_dir / "electrical_outlets_audio_best.pt").exists():
+         if audio_path:
+             waveform, sr = torchaudio.load(str(audio_path))
+         else:
+             import io
+             waveform, sr = torchaudio.load(io.BytesIO(audio_fp.read()))
+         if sr != 16000:
+             waveform = torchaudio.functional.resample(waveform, sr, 16000)
+         if waveform.shape[0] > 1:
+             waveform = waveform.mean(dim=0, keepdim=True)
+         target_len = int(5.0 * 16000)
+         if waveform.shape[1] >= target_len:
+             start = (waveform.shape[1] - target_len) // 2
+             waveform = waveform[:, start : start + target_len]
+         else:
+             waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
+         mel = torchaudio.transforms.MelSpectrogram(
+             sample_rate=16000, n_fft=512, hop_length=256, win_length=512, n_mels=64,
+         )(waveform)
+         log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)
+         model, T = _load_audio_model(
+             weights_dir / "electrical_outlets_audio_best.pt",
+             label_mapping_path,
+             device,
+             {"n_mels": 64, "time_steps": 128},
+         )
+         with torch.no_grad():
+             logits = model(log_mel) / T
+         from src.fusion.fusion_logic import ModalityOutput
+         pred = model.predict_to_schema(logits)
+         audio_out = ModalityOutput(
+             result=pred["result"],
+             issue_type=pred.get("issue_type"),
+             severity=pred["severity"],
+             confidence=pred["confidence"],
+         )
+
+     from src.fusion.fusion_logic import fuse_modalities
+     return fuse_modalities(
+         image_out,
+         audio_out,
+         confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
+         confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
+         uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
+         high_confidence_override=thresholds.get("high_confidence_override", 0.92),
+         severity_order=thresholds.get("severity_order"),
+     )
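Both loaders return a calibration temperature `T` that divides the logits before softmax. A standalone sketch (helper name is illustrative) of why `T > 1` lowers the reported confidence:

```python
import math

def softmax_conf(logits, T=1.0):
    # max softmax probability after temperature scaling (logits / T)
    scaled = [l / T for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]   # numerically stable softmax
    return max(exps) / sum(exps)
```

Dividing by `T > 1` flattens the distribution, so an over-confident model reports lower, better-calibrated probabilities; the test script additionally clamps degenerate stored values of `T` back to 1.0.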
src/models/audio_model.py ADDED
@@ -0,0 +1,87 @@
+ """
+ Audio classifier for Electrical Outlets. Expects spectrogram or waveform; outputs class logits.
+ Severity from label_mapping. Small CNN for 100-sample regime.
+ """
+ from pathlib import Path
+ from typing import Dict, Any, Optional
+ import json
+ import torch
+ import torch.nn as nn
+
+
+ class SpectrogramCNN(nn.Module):
+     """Lightweight CNN on mel spectrogram (n_mels x time)."""
+
+     def __init__(self, n_mels: int = 64, time_steps: int = 128, num_classes: int = 4):
+         super().__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(1, 32, 3, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(),
+             nn.MaxPool2d(2),
+             nn.Conv2d(32, 64, 3, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.MaxPool2d(2),
+             nn.Conv2d(64, 128, 3, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.AdaptiveAvgPool2d(1),
+         )
+         self.fc = nn.Linear(128, num_classes)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if x.dim() == 2:
+             x = x.unsqueeze(0).unsqueeze(0)
+         elif x.dim() == 3:
+             x = x.unsqueeze(1)
+         x = self.conv(x)
+         x = x.flatten(1)
+         return self.fc(x)
+
+
+ class ElectricalOutletsAudioModel(nn.Module):
+     """Wrapper: optional mel transform then SpectrogramCNN. Severity from mapping."""
+
+     def __init__(
+         self,
+         num_classes: int = 4,
+         label_mapping_path: Optional[Path] = None,
+         n_mels: int = 64,
+         time_steps: int = 128,
+     ):
+         super().__init__()
+         self.num_classes = num_classes
+         self.n_mels = n_mels
+         self.time_steps = time_steps
+         self.backbone = SpectrogramCNN(n_mels=n_mels, time_steps=time_steps, num_classes=num_classes)
+         self.idx_to_label = None
+         self.idx_to_issue_type = None
+         self.idx_to_severity = None
+         if label_mapping_path and Path(label_mapping_path).exists():
+             with open(label_mapping_path) as f:
+                 lm = json.load(f)
+             self.idx_to_label = lm["audio"]["idx_to_label"]
+             self.idx_to_issue_type = [lm["audio"]["label_to_issue_type"].get(lbl, "normal") for lbl in lm["audio"]["idx_to_label"]]
+             self.idx_to_severity = [lm["audio"]["label_to_severity"].get(lm["audio"]["idx_to_label"][i], "medium") for i in range(num_classes)]
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.backbone(x)
+
+     def predict_to_schema(self, logits: torch.Tensor) -> Dict[str, Any]:
+         probs = torch.softmax(logits, dim=-1)
+         if logits.dim() == 1:
+             probs = probs.unsqueeze(0)
+         conf, pred = probs.max(dim=-1)
+         pred = pred.item() if pred.numel() == 1 else pred
+         conf = conf.item() if conf.numel() == 1 else conf
+         issue_type = (self.idx_to_issue_type or ["normal"] * self.num_classes)[pred]
+         severity = (self.idx_to_severity or ["medium"] * self.num_classes)[pred]
+         result = "normal" if issue_type == "normal" else "issue_detected"
+         return {
+             "result": result,
+             "issue_type": issue_type,
+             "severity": severity,
+             "confidence": float(conf),
+             "class_idx": int(pred),
+         }
src/models/image_model.py ADDED
@@ -0,0 +1,67 @@
+ """
+ Image classifier for Electrical Outlets. EfficientNet-B0 backbone + MLP head.
+ FINAL v5: 5 classes (no GFCI).
+ """
+ from pathlib import Path
+ from typing import Dict, Any, Optional
+ import json
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+
+
+ class ElectricalOutletsImageModel(nn.Module):
+
+     def __init__(
+         self,
+         num_classes: int = 5,
+         label_mapping_path: Optional[Path] = None,
+         pretrained: bool = True,
+         head_hidden: int = 256,
+         head_dropout: float = 0.4,
+     ):
+         super().__init__()
+         self.num_classes = num_classes
+         self.backbone = models.efficientnet_b0(
+             weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
+         )
+         in_features = self.backbone.classifier[1].in_features  # 1280
+         self.backbone.classifier = nn.Identity()
+
+         self.head = nn.Sequential(
+             nn.Dropout(head_dropout),
+             nn.Linear(in_features, head_hidden),
+             nn.ReLU(),
+             nn.Dropout(head_dropout * 0.5),
+             nn.Linear(head_hidden, num_classes),
+         )
+
+         self.idx_to_issue_type = None
+         self.idx_to_severity = None
+         if label_mapping_path and Path(label_mapping_path).exists():
+             with open(label_mapping_path) as f:
+                 lm = json.load(f)
+             self.idx_to_issue_type = lm["image"]["idx_to_issue_type"]
+             self.idx_to_severity = lm["image"]["idx_to_severity"]
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         features = self.backbone(x)
+         return self.head(features)
+
+     def predict_to_schema(self, logits: torch.Tensor) -> Dict[str, Any]:
+         probs = torch.softmax(logits, dim=-1)
+         if logits.dim() == 1:
+             probs = probs.unsqueeze(0)
+         conf, pred = probs.max(dim=-1)
+         pred = pred.item() if pred.numel() == 1 else pred
+         conf = conf.item() if conf.numel() == 1 else conf
+         issue_type = (self.idx_to_issue_type or ["unknown"] * self.num_classes)[pred]
+         severity = (self.idx_to_severity or ["medium"] * self.num_classes)[pred]
+         result = "normal" if issue_type == "normal" else "issue_detected"
+         return {
+             "result": result,
+             "issue_type": issue_type,
+             "severity": severity,
+             "confidence": float(conf),
+             "class_idx": int(pred),
+         }
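The `predict_to_schema` mapping can be sketched without torch. An illustrative pure-Python version of the same softmax, argmax, and schema-lookup step:

```python
import math

def predict_to_schema(logits, idx_to_issue_type, idx_to_severity):
    # Pure-Python sketch of the models' predict_to_schema (no torch);
    # argument names mirror the modules above, values are illustrative.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]   # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    conf = max(probs)
    pred = probs.index(conf)
    issue_type = idx_to_issue_type[pred]
    return {
        "result": "normal" if issue_type == "normal" else "issue_detected",
        "issue_type": issue_type,
        "severity": idx_to_severity[pred],
        "confidence": round(conf, 4),
        "class_idx": pred,
    }
```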
test.py ADDED
@@ -0,0 +1,388 @@
+ """
+ Test script for Electrical Outlets diagnostic pipeline.
+
+ Usage:
+     python test.py --image path/to/outlet.jpg               # Test image only
+     python test.py --audio path/to/recording.wav            # Test audio only
+     python test.py --image photo.jpg --audio recording.wav  # Test both (fusion)
+     python test.py --list                                   # List sample images from dataset
+     python test.py --eval                                   # Run full validation set evaluation
+
+ Requirements:
+     pip install torch torchvision torchaudio Pillow PyYAML soundfile
+ """
+ from pathlib import Path
+ import sys
+ import argparse
+ import json
+ from collections import defaultdict
+
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+
+ ROOT = Path(__file__).resolve().parent
+ sys.path.insert(0, str(ROOT))
+
+
+ def load_image_model(weights_path, mapping_path, device):
+     from src.models.image_model import ElectricalOutletsImageModel
+
+     ckpt = torch.load(weights_path, map_location=device, weights_only=False)
+     # Infer head_hidden from saved weights (head.1 is the first Linear)
+     head_hidden = ckpt["model_state_dict"]["head.1.weight"].shape[0]
+     model = ElectricalOutletsImageModel(
+         num_classes=ckpt["num_classes"],
+         label_mapping_path=Path(mapping_path),
+         pretrained=False,
+         head_hidden=head_hidden,
+     )
+     model.load_state_dict(ckpt["model_state_dict"])
+     model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
+     model.idx_to_severity = ckpt.get("idx_to_severity")
+     model.eval().to(device)
+     T = ckpt.get("temperature", 1.0)
+     # Clamp bad temperature values
+     if T <= 0 or T > 10:
+         T = 1.0
+     return model, T
+
+
+ def load_audio_model(weights_path, mapping_path, device):
+     from src.models.audio_model import ElectricalOutletsAudioModel
+     import yaml
+
+     ckpt = torch.load(weights_path, map_location=device, weights_only=False)
+
+     # Load audio config for n_mels
+     audio_cfg_path = ROOT / "config" / "audio_train_config.yaml"
+     n_mels, time_steps = 128, 128
+     if audio_cfg_path.exists():
+         with open(audio_cfg_path) as f:
+             acfg = yaml.safe_load(f)
+         n_mels = acfg.get("model", {}).get("n_mels", 128)
+         time_steps = acfg.get("model", {}).get("time_steps", 128)
+
+     model = ElectricalOutletsAudioModel(
+         num_classes=ckpt["num_classes"],
+         label_mapping_path=Path(mapping_path),
+         n_mels=n_mels,
+         time_steps=time_steps,
+     )
+     model.load_state_dict(ckpt["model_state_dict"])
+     model.idx_to_label = ckpt.get("idx_to_label")
+     model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
+     model.idx_to_severity = ckpt.get("idx_to_severity")
+     model.eval().to(device)
+     T = ckpt.get("temperature", 1.0)
+     if T <= 0 or T > 10:
+         T = 1.0
+     return model, T
+
+
+ def predict_image(image_path, device="cuda"):
+     weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
+     mapping = ROOT / "config" / "label_mapping.json"
+
+     if not weights.exists():
+         print(f"ERROR: Image weights not found at {weights}")
+         return None
+
+     model, T = load_image_model(weights, mapping, device)
+
+     tf = transforms.Compose([
+         transforms.Resize(256),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])
+     img = Image.open(image_path).convert("RGB")
+     x = tf(img).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         logits = model(x) / T
+         probs = torch.softmax(logits, dim=-1)
+
+     pred = model.predict_to_schema(logits)
+
+     print(f"\n{'='*55}")
+     print(f"  IMAGE: {Path(image_path).name}")
+     print(f"{'='*55}")
+     print(f"  Prediction: {pred['issue_type']}")
+     print(f"  Severity:   {pred['severity']}")
+     print(f"  Confidence: {pred['confidence']:.1%}")
+     print(f"  Result:     {pred['result']}")
+     print(f"\n  Class probabilities:")
+     for i, p in enumerate(probs[0].tolist()):
+         name = model.idx_to_issue_type[i] if model.idx_to_issue_type else f"class_{i}"
+         bar = "█" * int(p * 30)
+         tag = "  ◄" if i == pred["class_idx"] else ""
+         print(f"    {name:20s} {p:6.1%} {bar}{tag}")
+
+     return pred
+
+
+ def predict_audio(audio_path, device="cuda"):
+     import torchaudio
+     import yaml
+
+     weights = ROOT / "weights" / "electrical_outlets_audio_best.pt"
+     mapping = ROOT / "config" / "label_mapping.json"
+
+     if not weights.exists():
+         print(f"ERROR: Audio weights not found at {weights}")
+         return None
+
+     model, T = load_audio_model(weights, mapping, device)
+
+     # Load audio config
+     audio_cfg_path = ROOT / "config" / "audio_train_config.yaml"
+     sample_rate, n_mels, n_fft, hop, win = 22050, 128, 1024, 512, 1024
+     target_sec = 5.0
+     if audio_cfg_path.exists():
+         with open(audio_cfg_path) as f:
+             acfg = yaml.safe_load(f)
+         sample_rate = acfg["data"].get("sample_rate", 22050)
+         target_sec = acfg["data"].get("target_length_sec", 5.0)
+         sc = acfg.get("spectrogram", {})
+         n_mels = sc.get("n_mels", 128)
+         n_fft = sc.get("n_fft", 1024)
+         hop = sc.get("hop_length", 512)
+         win = sc.get("win_length", 1024)
+
+     waveform, sr = torchaudio.load(str(audio_path))
+     if sr != sample_rate:
+         waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
+     if waveform.shape[0] > 1:
+         waveform = waveform.mean(dim=0, keepdim=True)
+
+     target_len = int(target_sec * sample_rate)
+     if waveform.shape[1] >= target_len:
+         start = (waveform.shape[1] - target_len) // 2
+         waveform = waveform[:, start:start + target_len]
+     else:
+         waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
+
+     mel = torchaudio.transforms.MelSpectrogram(
+         sample_rate=sample_rate, n_fft=n_fft, hop_length=hop,
+         win_length=win, n_mels=n_mels,
+     )(waveform)
+     log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         logits = model(log_mel) / T
+         probs = torch.softmax(logits, dim=-1)
+
+     pred = model.predict_to_schema(logits)
+
+     print(f"\n{'='*55}")
+     print(f"  AUDIO: {Path(audio_path).name}")
+     print(f"{'='*55}")
+     print(f"  Prediction: {pred['issue_type']}")
+     print(f"  Severity:   {pred['severity']}")
+     print(f"  Confidence: {pred['confidence']:.1%}")
+     print(f"  Result:     {pred['result']}")
+     print(f"\n  Class probabilities:")
+     labels = model.idx_to_label or [f"class_{i}" for i in range(model.num_classes)]
+     for i, p in enumerate(probs[0].tolist()):
+         bar = "█" * int(p * 30)
+         tag = "  ◄" if i == pred["class_idx"] else ""
+         print(f"    {labels[i]:20s} {p:6.1%} {bar}{tag}")
+
+     return pred
+
+
+ def run_fusion(image_pred, audio_pred):
+     from src.fusion.fusion_logic import fuse_modalities, ModalityOutput
+     import yaml
+
+     thresholds_path = ROOT / "config" / "thresholds.yaml"
+     thresholds = {}
+     if thresholds_path.exists():
+         with open(thresholds_path) as f:
+             thresholds = yaml.safe_load(f)
+
+     image_out = ModalityOutput(
+         result=image_pred["result"],
+         issue_type=image_pred.get("issue_type"),
+         severity=image_pred["severity"],
+         confidence=image_pred["confidence"],
+     ) if image_pred else None
+
+     audio_out = ModalityOutput(
+         result=audio_pred["result"],
+         issue_type=audio_pred.get("issue_type"),
+         severity=audio_pred["severity"],
+         confidence=audio_pred["confidence"],
+     ) if audio_pred else None
+
+     result = fuse_modalities(
+         image_out, audio_out,
+         confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
+         confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
+         uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
+         high_confidence_override=thresholds.get("high_confidence_override", 0.92),
+         severity_order=thresholds.get("severity_order"),
+     )
+
+     print(f"\n{'='*55}")
+     print(f"  FUSED RESULT")
+     print(f"{'='*55}")
+     print(f"  Result:     {result['result']}")
+     print(f"  Issue:      {result['issue_type']}")
+     print(f"  Severity:   {result['severity']}")
+     print(f"  Confidence: {result['confidence']:.1%}")
+     if result.get("primary_issue"):
+         print(f"  Primary:    {result['primary_issue']}")
+     if result.get("secondary_issue"):
+         print(f"  Secondary:  {result['secondary_issue']}")
+
+     return result
+
+
+ def list_samples():
+     mapping_path = ROOT / "config" / "label_mapping.json"
+     with open(mapping_path) as f:
+         lm = json.load(f)
+
+     data_root = ROOT / "ELECTRICAL OUTLETS-20260106T153508Z-3-001"
+     if not data_root.exists():
+         print(f"Dataset not found at {data_root}")
+         return
+
+     print(f"\nDataset: {data_root}")
+     print(f"{'='*60}")
+     for folder in sorted(data_root.iterdir()):
+         if not folder.is_dir():
+             continue
+         cls = lm["image"]["folder_to_class"].get(folder.name, "UNMAPPED")
+         imgs = list(folder.glob("*.jpg")) + list(folder.glob("*.jpeg")) + list(folder.glob("*.png"))
+         print(f"\n  {folder.name}")
+         print(f"    → class: {cls} | {len(imgs)} images")
+         for img in imgs[:3]:
+             print(f"      {img}")
+
+     # Audio
+     audio_root = ROOT / "electrical_outlets_sounds_100"
+     if audio_root.exists():
+         print(f"\n\nAudio: {audio_root}")
+         print(f"{'='*60}")
+         for folder in sorted(audio_root.iterdir()):
+             if folder.is_dir():
+                 wavs = list(folder.glob("*.wav"))
+                 print(f"  {folder.name}: {len(wavs)} files")
+                 for w in wavs[:2]:
+                     print(f"    {w}")
+
+
+ def run_eval(device="cuda"):
+     """Run full evaluation on validation split."""
+     weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
+     mapping = ROOT / "config" / "label_mapping.json"
+
+     if not weights.exists():
+         print("No image weights found.")
+         return
+
+     model, T = load_image_model(weights, mapping, device)
+
+     import yaml
+     cfg_path = ROOT / "config" / "image_train_config.yaml"
+     with open(cfg_path) as f:
+         cfg = yaml.safe_load(f)
+
+     from src.data.image_dataset import ElectricalOutletsImageDataset
+     val_tf = transforms.Compose([
+         transforms.Resize(256),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])
+     data_root = ROOT / cfg["data"]["root"]
+     val_ds = ElectricalOutletsImageDataset(
+         data_root, mapping, split="val",
+         train_ratio=cfg["data"]["train_ratio"],
+         val_ratio=cfg["data"]["val_ratio"],
+         seed=cfg["data"].get("seed", 42),
+         transform=val_tf,
+     )
+
+     with open(mapping) as f:
+         lm = json.load(f)
+     issue_names = lm["image"]["idx_to_issue_type"]
+
+     correct = 0
+     total = 0
+     class_correct = defaultdict(int)
+     class_total = defaultdict(int)
+     confusion = defaultdict(lambda: defaultdict(int))
+
+     model.eval()
+     with torch.no_grad():
+         for i in range(len(val_ds)):
+             x, y = val_ds[i]
+             logits = model(x.unsqueeze(0).to(device)) / T
+             pred = logits.argmax(1).item()
+             correct += (pred == y)
+             total += 1
+             class_correct[y] += (pred == y)
+             class_total[y] += 1
+             confusion[y][pred] += 1
+
+     print(f"\n{'='*55}")
+     print(f"  VALIDATION RESULTS ({total} samples)")
+     print(f"{'='*55}")
+     print(f"  Overall accuracy: {correct/total:.1%}")
+     print(f"\n  Per-class recall:")
+     for c in sorted(class_total.keys()):
+         name = issue_names[c] if c < len(issue_names) else f"class_{c}"
+         recall = class_correct[c] / class_total[c] if class_total[c] > 0 else 0
+         bar = "█" * int(recall * 20)
+         print(f"    {name:20s} {recall:6.1%} ({class_correct[c]}/{class_total[c]}) {bar}")
+
+     print(f"\n  Confusion matrix:")
+     classes = sorted(class_total.keys())
+     header = "  Actual \\ Pred " + "".join(f"{issue_names[c][:8]:>9s}" for c in classes)
+     print(header)
+     for actual in classes:
+         row = f"  {issue_names[actual][:14]:14s}"
+         for pred_c in classes:
+             count = confusion[actual][pred_c]
+             row += f" {count:6d}" if count > 0 else f" {'·':>6s}"
+         row += " "
+         print(row)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Test Electrical Outlets Diagnostic Pipeline")
+     parser.add_argument("--image", type=str, help="Path to image file")
+     parser.add_argument("--audio", type=str, help="Path to audio WAV file")
+     parser.add_argument("--list", action="store_true", help="List sample files from dataset")
+     parser.add_argument("--eval", action="store_true", help="Run full validation evaluation")
+     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
+     args = parser.parse_args()
+
+     if args.list:
+         list_samples()
+     elif args.eval:
+         run_eval(args.device)
+     elif args.image or args.audio:
+         img_pred = predict_image(args.image, args.device) if args.image else None
+         audio_pred = predict_audio(args.audio, args.device) if args.audio else None
+         if img_pred and audio_pred:
+             run_fusion(img_pred, audio_pred)
+         print()
+     else:
+         print("Electrical Outlets Diagnostic Pipeline — Test Script")
+         print("=" * 55)
+         print()
+         print("Usage:")
+         print("  python test.py --image path/to/photo.jpg")
+         print("  python test.py --audio path/to/recording.wav")
+         print("  python test.py --image photo.jpg --audio recording.wav")
+         print("  python test.py --list")
+         print("  python test.py --eval")
+         print()
+         print("Examples:")
+         print('  python test.py --image "ELECTRICAL OUTLETS-20260106T153508Z-3-001\\Burn marks - overheating 250\\img_001.jpg"')
+         print('  python test.py --audio "electrical_outlets_sounds_100\\buzzing_outlet\\buzzing_outlet_060.wav"')
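`predict_audio` normalizes every clip to a fixed duration before computing the mel spectrogram: center-crop when the recording is longer than `target_length_sec`, zero-pad at the end when it is shorter. A minimal pure-Python sketch of that crop/pad rule (operating on a plain list instead of a torch tensor, for illustration only):

```python
def fix_length(samples, target_len):
    """Center-crop if too long, zero-pad at the end if too short."""
    if len(samples) >= target_len:
        # Take a centered window of exactly target_len samples
        start = (len(samples) - target_len) // 2
        return samples[start:start + target_len]
    # Append zeros, matching F.pad(waveform, (0, deficit))
    return samples + [0.0] * (target_len - len(samples))

fix_length([1, 2, 3, 4, 5], 3)  # → [2, 3, 4]
fix_length([1, 2], 4)           # → [1, 2, 0.0, 0.0]
```

In the script itself `target_len` is `int(target_length_sec * sample_rate)`, so a 5 s target at 22050 Hz yields 110250 samples per clip.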
test_single_image.py ADDED
@@ -0,0 +1,90 @@
+ """
+ Quick test: classify a single image.
+     python test_single_image.py --image "path/to/image.jpg"
+     python test_single_image.py --list
+ """
+ from pathlib import Path
+ import sys
+ import argparse
+ import json
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+
+ ROOT = Path(__file__).resolve().parent
+ sys.path.insert(0, str(ROOT))
+ from src.models.image_model import ElectricalOutletsImageModel
+
+
+ def predict(image_path, weights="weights/electrical_outlets_image_best.pt",
+             mapping="config/label_mapping.json", device=None):
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     ckpt = torch.load(weights, map_location=device, weights_only=False)
+     head_hidden = ckpt["model_state_dict"]["head.1.weight"].shape[0]
+     model = ElectricalOutletsImageModel(
+         num_classes=ckpt["num_classes"],
+         label_mapping_path=Path(mapping),
+         pretrained=False,
+         head_hidden=head_hidden,
+     )
+     model.load_state_dict(ckpt["model_state_dict"])
+     model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
+     model.idx_to_severity = ckpt.get("idx_to_severity")
+     model.eval().to(device)
+     T = ckpt.get("temperature", 1.0)
+     if T <= 0 or T > 10:
+         T = 1.0
+
+     tf = transforms.Compose([
+         transforms.Resize(256), transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])
+     img = Image.open(image_path).convert("RGB")
+     x = tf(img).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         logits = model(x) / T
+         probs = torch.softmax(logits, dim=-1)
+     pred = model.predict_to_schema(logits)
+
+     print(f"\n{'='*50}")
+     print(f"  {Path(image_path).name}")
+     print(f"{'='*50}")
+     print(f"  -> {pred['issue_type']} ({pred['severity']} severity)")
+     print(f"  -> {pred['confidence']:.1%} confidence")
+     print(f"  -> {pred['result']}")
+     print()
+     for i, p in enumerate(probs[0].tolist()):
+         name = model.idx_to_issue_type[i]
+         bar = "█" * int(p * 30)
+         tag = "  ◄" if i == pred["class_idx"] else ""
+         print(f"    {name:20s} {p:6.1%} {bar}{tag}")
+     print()
+
+
+ if __name__ == "__main__":
+     p = argparse.ArgumentParser()
+     p.add_argument("--image", type=str)
+     p.add_argument("--list", action="store_true")
+     p.add_argument("--weights", default="weights/electrical_outlets_image_best.pt")
+     args = p.parse_args()
+
+     if args.list:
+         with open("config/label_mapping.json") as f:
+             lm = json.load(f)
+         root = Path("ELECTRICAL OUTLETS-20260106T153508Z-3-001")
+         for folder in sorted(root.iterdir()):
+             if folder.is_dir():
+                 imgs = list(folder.glob("*.jpg")) + list(folder.glob("*.jpeg")) + list(folder.glob("*.png"))
+                 cls = lm["image"]["folder_to_class"].get(folder.name, "UNMAPPED")
+                 print(f"\n{folder.name} -> {cls} ({len(imgs)} imgs)")
+                 for img in imgs[:2]:
+                     print(f"  {img}")
+     elif args.image:
+         predict(args.image, args.weights)
+     else:
+         print("python test_single_image.py --image path/to/img.jpg")
+         print("python test_single_image.py --list")
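Both test scripts divide the logits by the calibrated temperature `T` before the softmax. A quick pure-Python illustration of why (stdlib only, toy logits; not the checkpoint's actual values): `T > 1` flattens the probability distribution, so an over-confident model reports more honest confidences, while `T = 1` leaves it unchanged.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of floats
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

logits = [2.0, 1.0, 0.1]
p_raw = softmax(logits)                      # uncalibrated probabilities
p_cal = softmax([z / 2.0 for z in logits])   # temperature T = 2 flattens them

assert max(p_cal) < max(p_raw)  # calibrated top-class confidence is lower
```

The argmax (and thus the predicted class) is unchanged by temperature scaling; only the reported confidence shifts, which is what the fusion thresholds consume.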
tests/test_fusion.py ADDED
@@ -0,0 +1,60 @@
+ """Unit tests for decision-level fusion."""
+ import sys
+ from pathlib import Path
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
+
+ from src.fusion.fusion_logic import fuse_modalities, ModalityOutput
+
+
+ def test_image_only_issue():
+     out = fuse_modalities(
+         image_out=ModalityOutput("issue_detected", "burn_overheating", "high", 0.9),
+         audio_out=None,
+     )
+     assert out["result"] == "issue_detected"
+     assert out["severity"] == "high"
+     assert out["issue_type"] == "burn_overheating"
+
+
+ def test_both_normal_high_conf():
+     out = fuse_modalities(
+         image_out=ModalityOutput("normal", "normal", "low", 0.85),
+         audio_out=ModalityOutput("normal", "normal", "low", 0.8),
+         confidence_normal_min=0.75,
+     )
+     assert out["result"] == "normal"
+     assert out["severity"] == "low"
+
+
+ def test_severity_max():
+     out = fuse_modalities(
+         image_out=ModalityOutput("issue_detected", "cracked_faceplate", "medium", 0.88),
+         audio_out=ModalityOutput("issue_detected", "arcing_pop", "critical", 0.85),
+     )
+     assert out["severity"] == "critical"
+     assert out["result"] == "issue_detected"
+
+
+ def test_uncertain_low_confidence():
+     out = fuse_modalities(
+         image_out=ModalityOutput("issue_detected", "buzzing", "high", 0.5),
+         audio_out=None,
+         confidence_issue_min=0.6,
+     )
+     assert out["result"] == "uncertain"
+
+
+ def test_uncertain_disagree():
+     out = fuse_modalities(
+         image_out=ModalityOutput("issue_detected", "burn_overheating", "high", 0.7),
+         audio_out=ModalityOutput("normal", "normal", "low", 0.7),
+         uncertain_if_disagree=True,
+         high_confidence_override=0.92,
+     )
+     assert out["result"] == "uncertain"
+
+
+ def test_no_input():
+     out = fuse_modalities(None, None)
+     assert out["result"] == "uncertain"
+     assert out["confidence"] == 0.0
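`test_severity_max` pins down the fusion rule that when both modalities detect an issue, the fused severity is the *worse* of the two. A minimal sketch of that rule, assuming the default ordering implied by the test (the actual ordering is configurable via `severity_order` in `thresholds.yaml`):

```python
# Assumed default ordering, least to most severe
SEVERITY_ORDER = ["low", "medium", "high", "critical"]

def max_severity(a, b, order=SEVERITY_ORDER):
    # Compare severities by their position in the ordering list
    return max(a, b, key=order.index)

assert max_severity("medium", "critical") == "critical"
assert max_severity("high", "low") == "high"
```

This is why an image-side "medium" plus an audio-side "critical" fuses to "critical" in the test above: safety-critical signals must not be averaged away.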
training/train_audio.py ADDED
@@ -0,0 +1,202 @@
+ """
+ Train Electrical Outlets audio model. Spectrogram CNN, class weights, per-class recall, early stopping.
+ """
+ from pathlib import Path
+ import sys
+ import argparse
+ from typing import Dict
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+
+ ROOT = Path(__file__).resolve().parent.parent
+ sys.path.insert(0, str(ROOT))
+
+ from src.data.audio_dataset import ElectricalOutletsAudioDataset
+ from src.models.audio_model import ElectricalOutletsAudioModel
+
+
+ def load_config(config_path: Path) -> dict:
+     import yaml
+     with open(config_path) as f:
+         return yaml.safe_load(f)
+
+
+ def _wave_to_mel(waveform: torch.Tensor, n_mels: int, n_fft: int, hop: int, win: int) -> torch.Tensor:
+     import torchaudio
+     mel = torchaudio.transforms.MelSpectrogram(
+         sample_rate=16000, n_fft=n_fft, hop_length=hop, win_length=win, n_mels=n_mels,
+     )(waveform)
+     log_mel = torch.log(mel.clamp(min=1e-5))
+     return log_mel
+
+
+ def per_class_recall(logits: torch.Tensor, targets: torch.Tensor, num_classes: int) -> Dict[int, float]:
+     preds = logits.argmax(dim=1)
+     recall = {}
+     for c in range(num_classes):
+         mask = targets == c
+         if mask.sum() == 0:
+             recall[c] = 0.0
+         else:
+             recall[c] = (preds[mask] == c).float().mean().item()
+     return recall
+
+
+ def run_training(
+     data_root: Path,
+     label_mapping_path: Path,
+     config: dict,
+     weights_dir: Path,
+     device: str = "cuda",
+ ):
+     train_ratio = config["data"]["train_ratio"]
+     val_ratio = config["data"]["val_ratio"]
+     seed = config["data"].get("seed", 42)
+     batch_size = config["data"]["batch_size"]
+     num_workers = config["data"].get("num_workers", 0)
+     spec_cfg = config.get("spectrogram", {})
+     n_mels = spec_cfg.get("n_mels", 64)
+     n_fft = spec_cfg.get("n_fft", 512)
+     hop = spec_cfg.get("hop_length", 256)
+     win = spec_cfg.get("win_length", 512)
+
+     def to_mel(x):
+         return _wave_to_mel(x, n_mels, n_fft, hop, win)
+
+     train_ds = ElectricalOutletsAudioDataset(
+         data_root, label_mapping_path, split="train",
+         train_ratio=train_ratio, val_ratio=val_ratio, seed=seed, transform=to_mel,
+         target_length_sec=config["data"].get("target_length_sec", 5.0),
+         sample_rate=config["data"].get("sample_rate", 16000),
+     )
+     val_ds = ElectricalOutletsAudioDataset(
+         data_root, label_mapping_path, split="val",
+         train_ratio=train_ratio, val_ratio=val_ratio, seed=seed, transform=to_mel,
+         target_length_sec=config["data"].get("target_length_sec", 5.0),
+         sample_rate=config["data"].get("sample_rate", 16000),
+     )
+     train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+     val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+     num_classes = train_ds.num_classes
+     model = ElectricalOutletsAudioModel(
+         num_classes=num_classes,
+         label_mapping_path=label_mapping_path,
+         n_mels=config["model"].get("n_mels", 64),
+         time_steps=config["model"].get("time_steps", 128),
+     ).to(device)
+     opt = torch.optim.AdamW(
+         model.parameters(),
+         lr=config["training"]["lr"],
+         weight_decay=config["training"].get("weight_decay", 1e-4),
+     )
+     criterion = nn.CrossEntropyLoss()
+     epochs = config["training"]["epochs"]
+     patience = config["training"].get("early_stopping_patience", 12)
+     best_metric = -1.0
+     best_epoch = 0
+     wait = 0
+     recall = {}
+
+     for epoch in range(epochs):
+         model.train()
+         for x, y in train_loader:
+             x, y = x.to(device), y.to(device)
+             opt.zero_grad()
+             logits = model(x)
+             loss = criterion(logits, y)
+             loss.backward()
+             opt.step()
+
+         model.eval()
+         val_logits, val_targets = [], []
+         with torch.no_grad():
+             for x, y in val_loader:
+                 x = x.to(device)
+                 val_logits.append(model(x).cpu())
+                 val_targets.append(y)
+         val_logits = torch.cat(val_logits, dim=0)
+         val_targets = torch.cat(val_targets, dim=0)
+         recall = per_class_recall(val_logits, val_targets, num_classes)
+         min_recall = min(recall.values())
+         macro_recall = sum(recall.values()) / num_classes
+         metric = macro_recall
+         if metric > best_metric:
+             best_metric = metric
+             best_epoch = epoch
+             wait = 0
+             weights_dir.mkdir(parents=True, exist_ok=True)
+             torch.save({
+                 "model_state_dict": model.state_dict(),
+                 "num_classes": num_classes,
+                 "idx_to_label": model.idx_to_label,
+                 "idx_to_issue_type": model.idx_to_issue_type,
+                 "idx_to_severity": model.idx_to_severity,
+             }, weights_dir / config["output"]["best_name"])
+         else:
+             wait += 1
+         print(f"Epoch {epoch} min_recall={min_recall:.4f} macro_recall={macro_recall:.4f} best={best_metric:.4f}")
+         if wait >= patience:
+             print("Early stopping at epoch", epoch)
+             break
+
+     if config.get("calibration", {}).get("use_temperature_scaling", False):
+         model.load_state_dict(torch.load(weights_dir / config["output"]["best_name"], map_location=device)["model_state_dict"])
+         model.eval()
+         n_val = len(val_ds)
+         cal_size = max(1, int(n_val * config["calibration"].get("val_fraction_for_calibration", 0.5)))
+         cal_logits, cal_targets = [], []
+         for i in range(cal_size):
+             x, y = val_ds[i]
+             x = x.unsqueeze(0).to(device)
+             with torch.no_grad():
+                 cal_logits.append(model(x).cpu())
+             cal_targets.append(y)
+         cal_logits = torch.cat(cal_logits, dim=0)
+         cal_targets = torch.tensor(cal_targets)
+         temp = nn.Parameter(torch.ones(1) * 1.5)
+         opt_cal = torch.optim.LBFGS([temp], lr=0.01, max_iter=50)
+         def eval_cal():
+             opt_cal.zero_grad()
+             loss = F.cross_entropy(cal_logits / temp, cal_targets)
+             loss.backward()
+             return loss
+         opt_cal.step(eval_cal)
+         ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location="cpu")
+         ckpt["temperature"] = temp.item()
+         torch.save(ckpt, weights_dir / config["output"]["best_name"])
+
+     return {"best_epoch": best_epoch, "best_metric": best_metric, "recall_per_class": recall}
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", default="config/audio_train_config.yaml")
+     parser.add_argument("--data_root", default=None)
+     parser.add_argument("--weights_dir", default="weights")
+     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
+     args = parser.parse_args()
+     root = Path(__file__).resolve().parent.parent
+     config = load_config(root / args.config)
+     data_root = Path(args.data_root) if args.data_root else root / config["data"]["root"]
+     label_mapping_path = root / config["data"]["label_mapping"]
+     weights_dir = root / args.weights_dir
+     results = run_training(data_root, label_mapping_path, config, weights_dir, args.device)
+     report_path = root / "docs" / config["output"]["report_name"]
+     report_path.parent.mkdir(parents=True, exist_ok=True)
+     with open(report_path, "w") as f:
+         f.write("# Audio Model Report (Electrical Outlets)\n\n")
+         f.write("- **Preliminary model.** 100 samples is very small; recommend collecting more data.\n")
+         f.write(f"- Best epoch: {results['best_epoch']}, best metric: {results['best_metric']:.4f}\n\n")
+         f.write("## Per-class recall (validation)\n\n")
+         for c, r in results.get("recall_per_class", {}).items():
+             f.write(f"- Class {c}: {r:.4f}\n")
+         f.write("\n## Limitations\n- Small dataset; use audio as support in fusion. Do not rely on audio-only for critical decisions.\n")
+     print("Report written to", report_path)
+
+
+ if __name__ == "__main__":
+     main()
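Both trainers select checkpoints on `per_class_recall` rather than raw accuracy, so a majority class cannot mask a failing minority class. The same computation in pure Python, over index lists instead of tensors (illustrative mirror of the tensor version above):

```python
def per_class_recall(preds, targets, num_classes):
    """Fraction of samples of each class that were predicted correctly."""
    recall = {}
    for c in range(num_classes):
        idx = [i for i, t in enumerate(targets) if t == c]
        if not idx:
            recall[c] = 0.0  # class absent from this split
        else:
            recall[c] = sum(preds[i] == c for i in idx) / len(idx)
    return recall

r = per_class_recall(preds=[0, 1, 1, 2], targets=[0, 1, 2, 2], num_classes=3)
# class 2 has one of its two samples misclassified as class 1
macro = sum(r.values()) / 3
```

The audio trainer early-stops on the macro average of these values; the image trainer can optionally switch to the minimum via `early_stopping_metric: val_min_recall`, which is stricter still.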
training/train_image.py ADDED
@@ -0,0 +1,329 @@
+ """
+ Train Electrical Outlets image model.
+ FINAL v5: Frozen backbone → partial unfreeze. 5 classes, 1300 images.
+ """
+ from pathlib import Path
+ import sys
+ import argparse
+ from typing import Dict
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+
+ ROOT = Path(__file__).resolve().parent.parent
+ sys.path.insert(0, str(ROOT))
+
+ from src.data.image_dataset import ElectricalOutletsImageDataset, get_image_class_weights
+ from src.models.image_model import ElectricalOutletsImageModel
+
+
+ def load_config(path):
+     import yaml
+     with open(path) as f:
+         return yaml.safe_load(f)
+
+
+ def focal_loss(logits, targets, alpha=0.25, gamma=2.0, weight=None):
+     ce = F.cross_entropy(logits, targets, reduction="none", weight=weight)
+     pt = torch.exp(-ce)
+     return (alpha * (1 - pt) ** gamma * ce).mean()
+
+
+ def per_class_recall(logits, targets, num_classes):
+     preds = logits.argmax(dim=1)
+     recall = {}
+     for c in range(num_classes):
+         mask = targets == c
+         recall[c] = (preds[mask] == c).float().mean().item() if mask.sum() > 0 else 0.0
+     return recall
+
+
+ def run_training(data_root, label_mapping_path, config, weights_dir, device="cuda"):
+     cfg_data = config["data"]
+     cfg_train = config["training"]
+     cfg_aug = config["augmentation"]
+     cfg_model = config["model"]
+
+     # Transforms
+     train_tf = transforms.Compose([
+         transforms.Resize(cfg_aug["resize"]),
+         transforms.RandomResizedCrop(cfg_aug["crop"], scale=(0.65, 1.0)),
+         transforms.RandomHorizontalFlip(0.5),
+         transforms.RandomRotation(15),
+         transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
+         transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
+         transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         transforms.RandomErasing(p=0.15),
+     ])
+     val_tf = transforms.Compose([
+         transforms.Resize(cfg_aug["resize"]),
+         transforms.CenterCrop(cfg_aug["crop"]),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])
+
+     # Datasets
+     train_ds = ElectricalOutletsImageDataset(
+         data_root, label_mapping_path, split="train",
+         train_ratio=cfg_data["train_ratio"], val_ratio=cfg_data["val_ratio"],
+         seed=cfg_data.get("seed", 42), transform=train_tf,
+     )
+     val_ds = ElectricalOutletsImageDataset(
+         data_root, label_mapping_path, split="val",
+         train_ratio=cfg_data["train_ratio"], val_ratio=cfg_data["val_ratio"],
+         seed=cfg_data.get("seed", 42), transform=val_tf,
+     )
+     train_loader = DataLoader(train_ds, batch_size=cfg_data["batch_size"], shuffle=True,
+                               num_workers=cfg_data.get("num_workers", 4), pin_memory=True)
+     val_loader = DataLoader(val_ds, batch_size=cfg_data["batch_size"], shuffle=False,
+                             num_workers=cfg_data.get("num_workers", 4))
+
+     num_classes = train_ds.num_classes
+     print(f"\nTrain: {len(train_ds)}, Val: {len(val_ds)}, Classes: {num_classes}")
+
+     # Class weights
+     class_weights = None
+     if cfg_train.get("use_class_weights", True):
+         class_weights = get_image_class_weights(label_mapping_path, data_root).to(device)
+         print(f"Class weights: {[f'{w:.3f}' for w in class_weights.tolist()]}")
+
+     use_focal = cfg_train.get("use_focal", True)
+     criterion_ce = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
+
+     # Model
+     model = ElectricalOutletsImageModel(
+         num_classes=num_classes,
+         label_mapping_path=label_mapping_path,
+         pretrained=True,
+         head_hidden=cfg_model.get("head_hidden", 256),
+         head_dropout=cfg_model.get("head_dropout", 0.4),
+     ).to(device)
+
+     # ══════════════════════════════════════════════
+     # STAGE 1: Frozen backbone — train head only
+     # ══════════════════════════════════════════════
+     for p in model.backbone.parameters():
+         p.requires_grad = False
+
+     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     total_params = sum(p.numel() for p in model.parameters())
+     print(f"Params: {trainable:,} trainable / {total_params:,} total ({100*trainable/total_params:.1f}%)")
+
+     epochs = cfg_train["epochs"]
+     patience = cfg_train.get("early_stopping_patience", 20)
+     lr = cfg_train.get("lr", 3e-3)
+
+     opt = torch.optim.AdamW(
+         filter(lambda p: p.requires_grad, model.parameters()),
+         lr=lr, weight_decay=cfg_train.get("weight_decay", 1e-3),
+     )
+     sched = torch.optim.lr_scheduler.OneCycleLR(
+         opt, max_lr=lr, epochs=epochs,
+         steps_per_epoch=len(train_loader), pct_start=0.15,
+     )
+
+     print(f"\n{'='*60}")
+     print(f"  Stage 1: Frozen backbone, lr={lr}, {epochs} epochs max")
+     print(f"{'='*60}")
+
+     best_metric = -1.0
+     best_epoch = 0
+     wait = 0
+     recall = {}
+
+     for epoch in range(epochs):
+         model.train()
+         epoch_loss = 0
+         for x, y in train_loader:
+             x, y = x.to(device), y.to(device)
+             opt.zero_grad()
+             logits = model(x)
+             loss = focal_loss(logits, y, weight=class_weights) if use_focal else criterion_ce(logits, y)
+             loss.backward()
+             opt.step()
+             sched.step()
+             epoch_loss += loss.item()
+
+         # Validate
+         model.eval()
+         vl, vt = [], []
+         with torch.no_grad():
+             for x, y in val_loader:
+                 vl.append(model(x.to(device)).cpu())
+                 vt.append(y)
+         vl, vt = torch.cat(vl), torch.cat(vt)
+         recall = per_class_recall(vl, vt, num_classes)
+         min_r = min(recall.values())
+         macro_r = sum(recall.values()) / num_classes
+         val_acc = (vl.argmax(1) == vt).float().mean().item()
+         metric = min_r if cfg_train.get("early_stopping_metric") == "val_min_recall" else macro_r
+
+         star = ""
+         if metric > best_metric:
+             best_metric = metric
+             best_epoch = epoch
+             wait = 0
+             weights_dir.mkdir(parents=True, exist_ok=True)
+ weights_dir.mkdir(parents=True, exist_ok=True)
172
+ torch.save({
173
+ "model_state_dict": model.state_dict(),
174
+ "num_classes": num_classes,
175
+ "idx_to_issue_type": model.idx_to_issue_type,
176
+ "idx_to_severity": model.idx_to_severity,
177
+ }, weights_dir / config["output"]["best_name"])
178
+ star = " β˜…"
179
+ else:
180
+ wait += 1
181
+
182
+ print(f"E{epoch:3d} loss={epoch_loss/len(train_loader):.4f} acc={val_acc:.3f} "
183
+ f"min_r={min_r:.3f} macro={macro_r:.3f} best={best_metric:.3f}@{best_epoch}{star}")
184
+
185
+ if wait >= patience:
186
+ print(f"Early stop @ {epoch}")
187
+ break
188
+
+     # ══════════════════════════════════════════════
+     # STAGE 2: Unfreeze last 2 backbone blocks
+     # ══════════════════════════════════════════════
+     if cfg_train.get("finetune_last_blocks", True) and best_metric > 0.15:
+         print(f"\n{'='*60}")
+         print(f" Stage 2: Partial unfreeze (last 2 blocks)")
+         print(f"{'='*60}")
+
+         ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location=device)
+         model.load_state_dict(ckpt["model_state_dict"])
+
+         for p in model.backbone.parameters():
+             p.requires_grad = False
+         for name, p in model.backbone.named_parameters():
+             if "features.7" in name or "features.8" in name:
+                 p.requires_grad = True
+         # Head stays trainable
+         for p in model.head.parameters():
+             p.requires_grad = True
+
+         ft_lr = cfg_train.get("finetune_lr", 5e-5)
+         ft_epochs = cfg_train.get("finetune_epochs", 25)
+         opt2 = torch.optim.AdamW(
+             filter(lambda p: p.requires_grad, model.parameters()),
+             lr=ft_lr, weight_decay=1e-3,
+         )
+         sched2 = torch.optim.lr_scheduler.CosineAnnealingLR(opt2, T_max=ft_epochs, eta_min=1e-6)
+         wait2 = 0
+
+         for epoch in range(ft_epochs):
+             model.train()
+             el = 0
+             for x, y in train_loader:
+                 x, y = x.to(device), y.to(device)
+                 opt2.zero_grad()
+                 logits = model(x)
+                 loss = focal_loss(logits, y, weight=class_weights) if use_focal else criterion_ce(logits, y)
+                 loss.backward()
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                 opt2.step()
+                 el += loss.item()
+             sched2.step()
+
+             model.eval()
+             vl, vt = [], []
+             with torch.no_grad():
+                 for x, y in val_loader:
+                     vl.append(model(x.to(device)).cpu())
+                     vt.append(y)
+             vl, vt = torch.cat(vl), torch.cat(vt)
+             recall = per_class_recall(vl, vt, num_classes)
+             min_r = min(recall.values())
+             macro_r = sum(recall.values()) / num_classes
+             val_acc = (vl.argmax(1) == vt).float().mean().item()
+             metric = min_r if cfg_train.get("early_stopping_metric") == "val_min_recall" else macro_r
+
+             star = ""
+             if metric > best_metric:
+                 best_metric = metric
+                 best_epoch = epoch + 1000
+                 wait2 = 0
+                 torch.save({
+                     "model_state_dict": model.state_dict(),
+                     "num_classes": num_classes,
+                     "idx_to_issue_type": model.idx_to_issue_type,
+                     "idx_to_severity": model.idx_to_severity,
+                 }, weights_dir / config["output"]["best_name"])
+                 star = " ★"
+             else:
+                 wait2 += 1
+
+             print(f" FT{epoch:3d} loss={el/len(train_loader):.4f} acc={val_acc:.3f} "
+                   f"min_r={min_r:.3f} macro={macro_r:.3f} best={best_metric:.3f}{star}")
+             if wait2 >= 10:
+                 print(f" FT early stop @ {epoch}")
+                 break
+
+     # Temperature scaling
+     if config.get("calibration", {}).get("use_temperature_scaling", False):
+         ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location=device)
+         model.load_state_dict(ckpt["model_state_dict"])
+         model.eval()
+         cal_size = max(1, int(len(val_ds) * 0.5))
+         cl, ct = [], []
+         for i in range(cal_size):
+             x, y = val_ds[i]
+             with torch.no_grad():
+                 cl.append(model(x.unsqueeze(0).to(device)).cpu())
+             ct.append(y)
+         cl, ct = torch.cat(cl), torch.tensor(ct)
+         temp = nn.Parameter(torch.ones(1) * 1.5)
+         opt_c = torch.optim.LBFGS([temp], lr=0.01, max_iter=50)
+         def eval_c():
+             opt_c.zero_grad()
+             l = F.cross_entropy(cl / temp, ct)
+             l.backward()
+             return l
+         opt_c.step(eval_c)
+         ckpt["temperature"] = temp.item()
+         torch.save(ckpt, weights_dir / config["output"]["best_name"])
+         print(f"Temperature T={temp.item():.4f}")
+
+     print(f"\n{'='*60}")
+     print(f" DONE — Best: {best_metric:.4f}")
+     per_cls = " | ".join([f"C{c}={r:.2f}" for c, r in recall.items()])
+     print(f" Recall: {per_cls}")
+     print(f"{'='*60}\n")
+
+     return {"best_epoch": best_epoch, "best_metric": best_metric, "recall_per_class": recall}
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", default="config/image_train_config.yaml")
+     parser.add_argument("--data_root", default=None)
+     parser.add_argument("--weights_dir", default="weights")
+     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
+     args = parser.parse_args()
+     root = ROOT
+     config = load_config(root / args.config)
+     data_root = Path(args.data_root) if args.data_root else root / config["data"]["root"]
+     label_mapping_path = root / config["data"]["label_mapping"]
+     weights_dir = root / args.weights_dir
+     results = run_training(data_root, label_mapping_path, config, weights_dir, args.device)
+
+     report_path = root / "docs" / config["output"]["report_name"]
+     report_path.parent.mkdir(parents=True, exist_ok=True)
+     with open(report_path, "w") as f:
+         f.write("# Image Model Report (Electrical Outlets)\n\n")
+         f.write(f"- Best metric: {results['best_metric']:.4f}\n")
+         f.write(f"- Classes: 5 (burn, cracked, loose, normal, water)\n\n")
+         f.write("## Per-class recall\n\n")
+         issue_names = ["burn_overheating", "cracked_faceplate", "loose_outlet", "normal", "water_exposed"]
+         for c, r in results.get("recall_per_class", {}).items():
+             name = issue_names[c] if c < len(issue_names) else f"class_{c}"
+             f.write(f"- {name}: {r:.4f}\n")
+     print("Report:", report_path)
+
+
+ if __name__ == "__main__":
+     main()
weights/.gitkeep ADDED
@@ -0,0 +1 @@
+