Commit Β·
5666923
1
Parent(s): 352967a
Electrical Outlets diagnostic pipeline v1.0
Browse files- .gitignore +38 -10
- README.md +237 -0
- app.py +262 -0
- config/audio_train_config.yaml +41 -0
- config/image_train_config.yaml +43 -0
- config/label_mapping.json +93 -0
- config/schema.yaml +81 -0
- config/thresholds.yaml +13 -0
- releases.md +8 -0
- requirements.txt +145 -0
- src/__init__.py +4 -0
- src/data/audio_dataset.py +105 -0
- src/data/image_dataset.py +155 -0
- src/fusion/fusion_logic.py +129 -0
- src/inference/wrapper.py +144 -0
- src/models/audio_model.py +87 -0
- src/models/image_model.py +67 -0
- test.py +388 -0
- test_single_image.py +90 -0
- tests/test_fusion.py +60 -0
- training/train_audio.py +202 -0
- training/train_image.py +329 -0
- weights/.gitkeep +1 -0
.gitignore
CHANGED
|
@@ -1,10 +1,38 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
*.
|
| 10 |
-
*.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ELECTRICAL OUTLETS-20260106T153508Z-3-001/
|
| 2 |
+
electrical_outlets_sounds_100/
|
| 3 |
+
111/
|
| 4 |
+
|
| 5 |
+
# Model weights (upload separately via LFS)
|
| 6 |
+
weights/*.pt
|
| 7 |
+
|
| 8 |
+
# Binary files
|
| 9 |
+
*.pdf
|
| 10 |
+
*.jpg
|
| 11 |
+
*.jpeg
|
| 12 |
+
*.png
|
| 13 |
+
*.wav
|
| 14 |
+
*.mp3
|
| 15 |
+
|
| 16 |
+
# Python
|
| 17 |
+
__pycache__/
|
| 18 |
+
*.py[cod]
|
| 19 |
+
*.egg-info/
|
| 20 |
+
venv/
|
| 21 |
+
.venv/
|
| 22 |
+
env/
|
| 23 |
+
|
| 24 |
+
# IDE
|
| 25 |
+
.vscode/
|
| 26 |
+
.idea/
|
| 27 |
+
|
| 28 |
+
# OS
|
| 29 |
+
.DS_Store
|
| 30 |
+
Thumbs.db
|
| 31 |
+
|
| 32 |
+
# Notebooks
|
| 33 |
+
notebooks/
|
| 34 |
+
|
| 35 |
+
# Misc
|
| 36 |
+
tmp/
|
| 37 |
+
*.log
|
| 38 |
+
wandb/
|
README.md
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Electrical Outlets & Switches Diagnostic Pipeline
|
| 2 |
+
|
| 3 |
+
Non-intrusive AI diagnostic system for electrical outlets and switches using **image classification** and **audio analysis** with decision-level fusion.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
This pipeline analyzes photos and/or audio recordings of electrical outlets to detect potential safety issues without requiring physical inspection. It uses two independent models fused at the decision level for robust predictions.
|
| 8 |
+
|
| 9 |
+
### Image Model
|
| 10 |
+
- **Architecture:** EfficientNet-B0 (frozen backbone) + MLP head (512 β 5 classes)
|
| 11 |
+
- **Classes:** burn/overheating, cracked faceplate, loose outlet, normal, water exposed
|
| 12 |
+
- **Performance:** 77.3% accuracy, 66.7% minimum per-class recall
|
| 13 |
+
- **Training data:** 1,299 images across 10 source categories merged into 5 classes
|
| 14 |
+
|
| 15 |
+
### Audio Model
|
| 16 |
+
- **Architecture:** 3-layer Spectrogram CNN (32β64β128 channels + adaptive pooling)
|
| 17 |
+
- **Classes:** normal, buzzing, crackling/arcing, arcing pop
|
| 18 |
+
- **Performance:** 100% macro recall on validation
|
| 19 |
+
- **Training data:** 100 WAV files (22050 Hz, mel spectrograms with SpecAugment)
|
| 20 |
+
|
| 21 |
+
### Fusion
|
| 22 |
+
- Decision-level fusion combining both modalities
|
| 23 |
+
- Safety-first: prefers "uncertain" over "normal" when in doubt
|
| 24 |
+
- Severity = max(image_severity, audio_severity)
|
| 25 |
+
- Configurable confidence thresholds in `config/thresholds.yaml`
|
| 26 |
+
|
| 27 |
+
## Project Structure
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
CV/
|
| 31 |
+
βββ config/
|
| 32 |
+
β βββ label_mapping.json # Class definitions & folderβclass mapping
|
| 33 |
+
β βββ image_train_config.yaml # Image training hyperparameters
|
| 34 |
+
β βββ audio_train_config.yaml # Audio training hyperparameters
|
| 35 |
+
β βββ thresholds.yaml # Fusion confidence thresholds
|
| 36 |
+
β βββ schema.yaml # API output schema
|
| 37 |
+
βββ src/
|
| 38 |
+
β βββ data/
|
| 39 |
+
β β βββ image_dataset.py # Image dataset with stratified splits
|
| 40 |
+
β β βββ audio_dataset.py # Audio dataset with stratified splits
|
| 41 |
+
β βββ models/
|
| 42 |
+
β β βββ image_model.py # EfficientNet-B0 + MLP classifier
|
| 43 |
+
β β βββ audio_model.py # Spectrogram CNN classifier
|
| 44 |
+
β βββ fusion/
|
| 45 |
+
β β βββ fusion_logic.py # Decision-level fusion
|
| 46 |
+
β βββ inference/
|
| 47 |
+
β βββ wrapper.py # End-to-end inference pipeline
|
| 48 |
+
βββ training/
|
| 49 |
+
β βββ train_image.py # Image model training (2-stage)
|
| 50 |
+
β βββ train_audio.py # Audio model training
|
| 51 |
+
βββ api/
|
| 52 |
+
β βββ main.py # FastAPI endpoint
|
| 53 |
+
βββ weights/
|
| 54 |
+
β βββ electrical_outlets_image_best.pt # Trained image model
|
| 55 |
+
β βββ electrical_outlets_audio_best.pt # Trained audio model
|
| 56 |
+
βββ tests/
|
| 57 |
+
β βββ test_fusion.py # Fusion logic tests
|
| 58 |
+
βββ test_single_image.py # Quick single-image testing
|
| 59 |
+
βββ requirements.txt
|
| 60 |
+
βββ README.md
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Setup
|
| 64 |
+
|
| 65 |
+
### Requirements
|
| 66 |
+
|
| 67 |
+
- Python 3.10+
|
| 68 |
+
- NVIDIA GPU with CUDA (recommended: RTX 3090 or better)
|
| 69 |
+
|
| 70 |
+
### Installation
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
git clone https://huggingface.co/<your-repo>/electrical-outlets-diagnostic
|
| 74 |
+
cd electrical-outlets-diagnostic
|
| 75 |
+
|
| 76 |
+
pip install -r requirements.txt
|
| 77 |
+
|
| 78 |
+
# If GPU: install CUDA-enabled PyTorch
|
| 79 |
+
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
| 80 |
+
# Also needed on Windows:
|
| 81 |
+
pip install soundfile
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
### Download Weights
|
| 85 |
+
|
| 86 |
+
Download the model weights from the HuggingFace repository and place them in `weights/`:
|
| 87 |
+
|
| 88 |
+
```
|
| 89 |
+
weights/
|
| 90 |
+
βββ electrical_outlets_image_best.pt (~ 17 MB)
|
| 91 |
+
βββ electrical_outlets_audio_best.pt (~ 2 MB)
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## Usage
|
| 95 |
+
|
| 96 |
+
### Test a Single Image
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
python test_single_image.py --image path/to/outlet_photo.jpg
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
Output:
|
| 103 |
+
```
|
| 104 |
+
==================================================
|
| 105 |
+
burned_outlet.jpg
|
| 106 |
+
==================================================
|
| 107 |
+
β burn_overheating (high severity)
|
| 108 |
+
β 87.3% confidence
|
| 109 |
+
β issue_detected
|
| 110 |
+
|
| 111 |
+
burn_overheating 87.3% ββββββββββββββββββββββββββ β
|
| 112 |
+
cracked_faceplate 5.2% β
|
| 113 |
+
loose_outlet 3.1% β
|
| 114 |
+
normal 2.8% β
|
| 115 |
+
water_exposed 1.6% β
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
### API Server
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
uvicorn api.main:app --host 0.0.0.0 --port 8000
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
#### Endpoints
|
| 125 |
+
|
| 126 |
+
**POST** `/v1/diagnose/electrical_outlets`
|
| 127 |
+
|
| 128 |
+
Upload image and/or audio for diagnosis:
|
| 129 |
+
```bash
|
| 130 |
+
# Image only
|
| 131 |
+
curl -X POST http://localhost:8000/v1/diagnose/electrical_outlets \
|
| 132 |
+
-F "image=@outlet_photo.jpg"
|
| 133 |
+
|
| 134 |
+
# Image + Audio
|
| 135 |
+
curl -X POST http://localhost:8000/v1/diagnose/electrical_outlets \
|
| 136 |
+
-F "image=@outlet_photo.jpg" \
|
| 137 |
+
-F "audio=@outlet_recording.wav"
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
Response:
|
| 141 |
+
```json
|
| 142 |
+
{
|
| 143 |
+
"diagnostic_element": "electrical_outlets",
|
| 144 |
+
"result": "issue_detected",
|
| 145 |
+
"issue_type": "burn_overheating",
|
| 146 |
+
"severity": "high",
|
| 147 |
+
"confidence": 0.873,
|
| 148 |
+
"modality_contributions": null,
|
| 149 |
+
"primary_issue": "burn_overheating",
|
| 150 |
+
"secondary_issue": null
|
| 151 |
+
}
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
**GET** `/health` β Check model availability
|
| 155 |
+
|
| 156 |
+
### Python API
|
| 157 |
+
|
| 158 |
+
```python
|
| 159 |
+
from src.inference.wrapper import run_electrical_outlets_inference
|
| 160 |
+
|
| 161 |
+
result = run_electrical_outlets_inference(
|
| 162 |
+
image_path="path/to/photo.jpg",
|
| 163 |
+
audio_path="path/to/recording.wav", # optional
|
| 164 |
+
)
|
| 165 |
+
print(result)
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
## Training
|
| 169 |
+
|
| 170 |
+
### Image Model
|
| 171 |
+
|
| 172 |
+
```bash
|
| 173 |
+
python training/train_image.py --device cuda
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
Two-stage training:
|
| 177 |
+
1. **Stage 1:** Frozen EfficientNet-B0 backbone, train MLP head only (80-100 epochs)
|
| 178 |
+
2. **Stage 2:** Unfreeze last 2 backbone blocks, fine-tune with low LR (25 epochs)
|
| 179 |
+
|
| 180 |
+
### Audio Model
|
| 181 |
+
|
| 182 |
+
```bash
|
| 183 |
+
python training/train_audio.py --device cuda
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
Single-stage with SpecAugment, class-weighted loss, cosine LR schedule.
|
| 187 |
+
|
| 188 |
+
## Class Mapping
|
| 189 |
+
|
| 190 |
+
### Image Classes (5)
|
| 191 |
+
|
| 192 |
+
| Class | Issue Type | Severity | Source Folders |
|
| 193 |
+
|-------|-----------|----------|----------------|
|
| 194 |
+
| 0 | burn_overheating | high | Burn marks (250), Discoloration (100), Sparking damage (150) |
|
| 195 |
+
| 1 | cracked_faceplate | medium | Cracked faceplate (150), Damaged switches (50) |
|
| 196 |
+
| 2 | loose_outlet | medium | Loose outlet (200), Exposed wiring (150) |
|
| 197 |
+
| 3 | normal | low | Normal outlets (50), Normal switches (50) |
|
| 198 |
+
| 4 | water_exposed | high | Water intrusion (150) |
|
| 199 |
+
|
| 200 |
+
### Audio Classes (4)
|
| 201 |
+
|
| 202 |
+
| Class | Issue Type | Severity |
|
| 203 |
+
|-------|-----------|----------|
|
| 204 |
+
| 0 | normal | low |
|
| 205 |
+
| 1 | buzzing | high |
|
| 206 |
+
| 2 | crackling_arcing | high |
|
| 207 |
+
| 3 | arcing_pop | critical |
|
| 208 |
+
|
| 209 |
+
## Severity Levels
|
| 210 |
+
|
| 211 |
+
| Level | Action Required |
|
| 212 |
+
|-------|----------------|
|
| 213 |
+
| **low** | Monitor β no immediate action |
|
| 214 |
+
| **medium** | Schedule repair |
|
| 215 |
+
| **high** | Shut off circuit immediately |
|
| 216 |
+
| **critical** | Shut off main breaker immediately |
|
| 217 |
+
|
| 218 |
+
## Fusion Logic
|
| 219 |
+
|
| 220 |
+
The fusion layer combines image and audio predictions:
|
| 221 |
+
|
| 222 |
+
- If **both agree** on issue β `issue_detected` with max severity
|
| 223 |
+
- If **both agree** on normal with high confidence β `normal`
|
| 224 |
+
- If **they disagree** β `uncertain` (unless one has >92% confidence)
|
| 225 |
+
- **Safety-first:** defaults to `uncertain` over `normal` when confidence is low
|
| 226 |
+
|
| 227 |
+
## Limitations
|
| 228 |
+
|
| 229 |
+
- Image model trained on web-sourced images (some watermarked/AI-generated)
|
| 230 |
+
- Audio model trained on 100 synthetic clips β use as supporting evidence only
|
| 231 |
+
- Water damage and cracked faceplate classes have lower recall (64-67%)
|
| 232 |
+
- No GFCI failure detection (no training data available)
|
| 233 |
+
- Real-world accuracy will be lower than validation metrics
|
| 234 |
+
|
| 235 |
+
## License
|
| 236 |
+
|
| 237 |
+
Proprietary β for use in the Electrical Outlets diagnostic pipeline only.
|
app.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Electrical Outlets Diagnostic β Gradio Demo
|
| 3 |
+
Install: pip install gradio
|
| 4 |
+
Run: python app.py
|
| 5 |
+
"""
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import sys
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from torchvision import transforms
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
ROOT = Path(__file__).resolve().parent
|
| 16 |
+
sys.path.insert(0, str(ROOT))
|
| 17 |
+
|
| 18 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
+
IMAGE_MODEL = None
|
| 20 |
+
IMAGE_TEMP = 1.0
|
| 21 |
+
AUDIO_MODEL = None
|
| 22 |
+
AUDIO_TEMP = 1.0
|
| 23 |
+
AUDIO_CFG = {}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_models():
    """Load image/audio checkpoints from weights/ into the module globals.

    Populates IMAGE_MODEL / AUDIO_MODEL (each left as None when its weight
    file is absent), their temperature-scaling factors IMAGE_TEMP /
    AUDIO_TEMP, and the parsed audio training config AUDIO_CFG.
    """
    global IMAGE_MODEL, IMAGE_TEMP, AUDIO_MODEL, AUDIO_TEMP, AUDIO_CFG

    img_weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"

    if img_weights.exists():
        from src.models.image_model import ElectricalOutletsImageModel
        # weights_only=False: the checkpoint carries Python metadata (label
        # maps, temperature), not just tensors -- only safe for trusted files.
        ckpt = torch.load(img_weights, map_location=DEVICE, weights_only=False)
        # Infer the MLP head width from the checkpoint itself so the rebuilt
        # architecture matches regardless of the current config file.
        head_hidden = ckpt["model_state_dict"]["head.1.weight"].shape[0]
        IMAGE_MODEL = ElectricalOutletsImageModel(
            num_classes=ckpt["num_classes"], label_mapping_path=mapping,
            pretrained=False, head_hidden=head_hidden,
        )
        IMAGE_MODEL.load_state_dict(ckpt["model_state_dict"])
        IMAGE_MODEL.idx_to_issue_type = ckpt.get("idx_to_issue_type")
        IMAGE_MODEL.idx_to_severity = ckpt.get("idx_to_severity")
        IMAGE_MODEL.eval().to(DEVICE)
        T = ckpt.get("temperature", 1.0)
        # Sanity-gate the calibration temperature: fall back to 1.0 when the
        # stored value is non-positive or absurdly large (bad calibration run).
        IMAGE_TEMP = T if 0 < T < 10 else 1.0
        print(f" Image model loaded ({ckpt['num_classes']} classes, head={head_hidden})")

    audio_weights = ROOT / "weights" / "electrical_outlets_audio_best.pt"
    if audio_weights.exists():
        from src.models.audio_model import ElectricalOutletsAudioModel
        import yaml
        ckpt = torch.load(audio_weights, map_location=DEVICE, weights_only=False)
        audio_cfg_path = ROOT / "config" / "audio_train_config.yaml"
        # Defaults used only when the YAML config is missing.
        n_mels, time_steps = 128, 128
        if audio_cfg_path.exists():
            with open(audio_cfg_path) as f:
                AUDIO_CFG = yaml.safe_load(f)
            n_mels = AUDIO_CFG.get("model", {}).get("n_mels", 128)
            time_steps = AUDIO_CFG.get("model", {}).get("time_steps", 128)
        AUDIO_MODEL = ElectricalOutletsAudioModel(
            num_classes=ckpt["num_classes"], label_mapping_path=mapping,
            n_mels=n_mels, time_steps=time_steps,
        )
        AUDIO_MODEL.load_state_dict(ckpt["model_state_dict"])
        AUDIO_MODEL.idx_to_label = ckpt.get("idx_to_label")
        AUDIO_MODEL.idx_to_issue_type = ckpt.get("idx_to_issue_type")
        AUDIO_MODEL.idx_to_severity = ckpt.get("idx_to_severity")
        AUDIO_MODEL.eval().to(DEVICE)
        T = ckpt.get("temperature", 1.0)
        AUDIO_TEMP = T if 0 < T < 10 else 1.0
        print(f" Audio model loaded ({ckpt['num_classes']} classes)")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
SEV_COLORS = {"low": "#22c55e", "medium": "#f59e0b", "high": "#ef4444", "critical": "#dc2626"}
|
| 75 |
+
SEV_ICONS = {"low": "β
", "medium": "β οΈ", "high": "π΄", "critical": "π¨"}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def make_bar_html(probs_dict, highlight=None):
    """Render class probabilities as horizontal HTML bars, highest first.

    The class named by *highlight* gets an amber bar; every other class is
    drawn in blue. Underscores in class names are shown as spaces.
    """
    ordered = sorted(probs_dict.items(), key=lambda kv: kv[1], reverse=True)
    parts = []
    for label, prob in ordered:
        pct = prob * 100
        bar_color = "#f59e0b" if label == highlight else "#60a5fa"
        parts.append(f"""
    <div style="display:flex;align-items:center;gap:8px;margin:3px 0;">
      <div style="width:140px;font-size:13px;text-align:right;color:#ccc;">{label.replace('_',' ')}</div>
      <div style="flex:1;background:#2a2a3e;border-radius:4px;height:20px;overflow:hidden;">
        <div style="width:{pct}%;background:{bar_color};height:100%;border-radius:4px;"></div>
      </div>
      <div style="width:55px;font-size:13px;color:#eee;">{pct:.1f}%</div>
    </div>""")
    return f'<div style="padding:8px 0;">{"".join(parts)}</div>'
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def make_result_html(pred, title, probs_dict=None):
    """Build the styled HTML result card for one prediction dict.

    *pred* follows the schema output (result / issue_type / severity /
    confidence). When *probs_dict* is given, per-class probability bars are
    appended with the predicted issue type highlighted.
    """
    severity = pred.get("severity", "low")
    accent = SEV_COLORS.get(severity, "#666")
    icon = SEV_ICONS.get(severity, "")
    confidence = pred.get("confidence", 0)
    issue_label = (pred.get("issue_type") or "uncertain").replace("_", " ").title()
    headline = pred.get("result", "").replace("_", " ").title()
    if probs_dict:
        bars = make_bar_html(probs_dict, pred.get("issue_type"))
    else:
        bars = ""

    return f"""
    <div style="background:#1a1a2e;border-radius:12px;padding:20px;margin:8px 0;
                border-left:4px solid {accent};color:#e0e0e0;font-family:system-ui;">
      <div style="font-size:12px;color:#888;text-transform:uppercase;letter-spacing:1px;margin-bottom:10px;">{title}</div>
      <div style="font-size:26px;font-weight:700;margin-bottom:6px;">{headline}</div>
      <div style="font-size:18px;color:{accent};font-weight:600;margin-bottom:14px;">{issue_label}</div>
      <div style="display:flex;gap:32px;">
        <div><div style="font-size:11px;color:#888;text-transform:uppercase;">Severity</div>
        <div style="font-size:15px;font-weight:600;color:{accent};">{icon} {severity.upper()}</div></div>
        <div><div style="font-size:11px;color:#888;text-transform:uppercase;">Confidence</div>
        <div style="font-size:15px;font-weight:600;">{confidence:.1%}</div></div>
      </div>
      {bars}
    </div>"""
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def predict_image_fn(img):
    """Classify a PIL image with the image model.

    Returns (schema_pred, {issue_type: prob}), or (None, None) when the
    image model has not been loaded.
    """
    if IMAGE_MODEL is None:
        return None, None
    # Standard ImageNet evaluation preprocessing.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    batch = preprocess(img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # Divide by the calibration temperature before softmax.
        logits = IMAGE_MODEL(batch) / IMAGE_TEMP
        probs = torch.softmax(logits, dim=-1)[0]
        pred = IMAGE_MODEL.predict_to_schema(logits)
    probs_dict = {
        IMAGE_MODEL.idx_to_issue_type[idx]: p
        for idx, p in enumerate(probs.tolist())
    }
    return pred, probs_dict
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def predict_audio_fn(audio_tuple):
    """Classify a Gradio audio input with the audio model.

    *audio_tuple* is Gradio's ``(sample_rate, ndarray)`` pair. Returns
    (schema_pred, {label: prob}), or (None, None) when the model is not
    loaded or the payload is not a numpy array.
    """
    if AUDIO_MODEL is None:
        return None, None
    import torchaudio
    sr_in, audio_data = audio_tuple
    if isinstance(audio_data, np.ndarray):
        waveform = torch.from_numpy(audio_data.astype(np.float32))
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)  # mono -> (1, samples)
        elif waveform.dim() == 2:
            # Heuristic: a trailing dim of <=2 is treated as the channel axis
            # (i.e. input was (samples, channels)); transpose to (channels, samples).
            if waveform.shape[1] <= 2:
                waveform = waveform.T
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
        mx = waveform.abs().max()
        if mx > 0:
            waveform = waveform / mx  # peak-normalize (skip on silence to avoid /0)
    else:
        return None, None

    # Resample to the training sample rate from the config.
    sample_rate = AUDIO_CFG.get("data", {}).get("sample_rate", 22050)
    if sr_in != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr_in, sample_rate)
    # Center-crop long clips / zero-pad short ones to the fixed clip length.
    target_len = int(AUDIO_CFG.get("data", {}).get("target_length_sec", 5.0) * sample_rate)
    if waveform.shape[1] >= target_len:
        s = (waveform.shape[1] - target_len) // 2
        waveform = waveform[:, s:s + target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))

    # Log-mel spectrogram; parameters come from the training config so the
    # features match what the model saw during training.
    sc = AUDIO_CFG.get("spectrogram", {})
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=sc.get("n_fft", 1024),
        hop_length=sc.get("hop_length", 512), win_length=sc.get("win_length", 1024),
        n_mels=sc.get("n_mels", 128),
    )(waveform)
    # Clamp before log to avoid log(0); add batch dim -> (1, 1, n_mels, time).
    log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = AUDIO_MODEL(log_mel) / AUDIO_TEMP  # temperature-scaled logits
        probs = torch.softmax(logits, dim=-1)[0]
        pred = AUDIO_MODEL.predict_to_schema(logits)
    # Fall back to generic class names if the checkpoint lacked labels.
    labels = AUDIO_MODEL.idx_to_label or [f"class_{i}" for i in range(AUDIO_MODEL.num_classes)]
    probs_dict = {labels[i]: p for i, p in enumerate(probs.tolist())}
    return pred, probs_dict
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def fuse_fn(image_pred, audio_pred):
    """Run decision-level fusion over the two modality prediction dicts.

    Thresholds are read from config/thresholds.yaml when present; otherwise
    the fusion layer's documented defaults are used.
    """
    from src.fusion.fusion_logic import fuse_modalities, ModalityOutput
    import yaml

    thresholds = {}
    th_path = ROOT / "config" / "thresholds.yaml"
    if th_path.exists():
        with open(th_path) as f:
            thresholds = yaml.safe_load(f) or {}

    def as_modality(p):
        # Adapt a schema prediction dict to the fusion layer's input type.
        return ModalityOutput(result=p["result"], issue_type=p.get("issue_type"),
                              severity=p["severity"], confidence=p["confidence"])

    return fuse_modalities(
        as_modality(image_pred), as_modality(audio_pred),
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
    )
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def diagnose(image, audio):
    """Produce the HTML diagnosis panel for the Gradio UI.

    Accepts an optional image (numpy array or PIL image) and an optional
    Gradio audio tuple. When both modalities yield predictions, the fused
    verdict is shown above the per-modality cards.
    """
    if image is None and audio is None:
        return ('<div style="padding:40px;color:#888;text-align:center;'
                'font-style:italic;">Upload an image or audio to begin diagnosis...</div>')

    img_pred = img_probs = aud_pred = aud_probs = None
    try:
        if image is not None:
            pil = Image.fromarray(image) if isinstance(image, np.ndarray) else image
            img_pred, img_probs = predict_image_fn(pil)
        if audio is not None:
            aud_pred, aud_probs = predict_audio_fn(audio)
    except Exception as e:
        # Surface inference failures directly in the UI panel.
        return f'<div style="padding:20px;color:#f87171;">Error: {e}</div>'

    if img_pred and aud_pred:
        panels = [
            make_result_html(fuse_fn(img_pred, aud_pred), "⚡ Fused Diagnosis"),
            '<div style="display:flex;gap:12px;">',
            f'<div style="flex:1;">{make_result_html(img_pred, "📷 Image", img_probs)}</div>',
            f'<div style="flex:1;">{make_result_html(aud_pred, "🎤 Audio", aud_probs)}</div>',
            '</div>',
        ]
        return "".join(panels)
    if img_pred:
        return make_result_html(img_pred, "📷 Image Diagnosis", img_probs)
    if aud_pred:
        return make_result_html(aud_pred, "🎤 Audio Diagnosis", aud_probs)
    return '<div style="padding:20px;color:#f87171;">Could not process input.</div>'
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
if __name__ == "__main__":
    # Gradio is imported lazily so the inference helpers above remain usable
    # without the demo dependency installed.
    import gradio as gr

    print("Loading models...")
    load_models()
    print(f"Device: {DEVICE}\n")

    with gr.Blocks(
        title="Electrical Outlets Diagnostic",
        theme=gr.themes.Base(primary_hue="red", secondary_hue="amber", neutral_hue="slate",
                             font=gr.themes.GoogleFont("Inter")),
        css=".gradio-container{max-width:960px!important} footer{display:none!important}"
    ) as demo:

        gr.Markdown("# ⚡ Electrical Outlets Diagnostic\nUpload a **photo** and/or **audio** to detect safety issues.")

        with gr.Row():
            with gr.Column(scale=1):
                # Inputs column: photo, recording, and the manual trigger.
                image_input = gr.Image(label="📷 Outlet Photo", type="numpy", height=300)
                audio_input = gr.Audio(label="🎤 Audio Recording", type="numpy")
                btn = gr.Button("🔍 Diagnose", variant="primary", size="lg")
            with gr.Column(scale=1):
                # Results column: diagnose() returns a ready-made HTML panel.
                output = gr.HTML(value='<div style="padding:40px;color:#888;text-align:center;font-style:italic;">Upload an image or audio to begin...</div>')

        # Re-run the diagnosis on button press and whenever either input changes.
        btn.click(fn=diagnose, inputs=[image_input, audio_input], outputs=[output])
        image_input.change(fn=diagnose, inputs=[image_input, audio_input], outputs=[output])
        audio_input.change(fn=diagnose, inputs=[image_input, audio_input], outputs=[output])

        gr.Markdown("---\n| Severity | Action |\n|--|--|\n| ✅ Low | Monitor |\n| ⚠️ Medium | Schedule repair |\n| 🔴 High | Shut off circuit |\n| 🚨 Critical | Shut off main breaker |")

    # Local-only demo; set share=True to expose a public Gradio link.
    demo.launch(server_name="127.0.0.1", server_port=7860, share=False, show_error=True)
|
config/audio_train_config.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Audio model training config - Electrical Outlets
|
| 2 |
+
# 100 samples: heavy augmentation, balanced batching, treat as preliminary
|
| 3 |
+
|
| 4 |
+
data:
|
| 5 |
+
root: "electrical_outlets_sounds_100"
|
| 6 |
+
label_mapping: "config/label_mapping.json"
|
| 7 |
+
train_ratio: 0.7
|
| 8 |
+
val_ratio: 0.15
|
| 9 |
+
seed: 42
|
| 10 |
+
batch_size: 16
|
| 11 |
+
num_workers: 0
|
| 12 |
+
target_length_sec: 5.0
|
| 13 |
+
sample_rate: 16000
|
| 14 |
+
|
| 15 |
+
spectrogram:
|
| 16 |
+
n_mels: 64
|
| 17 |
+
n_fft: 512
|
| 18 |
+
hop_length: 256
|
| 19 |
+
win_length: 512
|
| 20 |
+
|
| 21 |
+
model:
|
| 22 |
+
num_classes: 4
|
| 23 |
+
n_mels: 64
|
| 24 |
+
time_steps: 128
|
| 25 |
+
|
| 26 |
+
training:
|
| 27 |
+
epochs: 80
|
| 28 |
+
lr: 1.0e-3
|
| 29 |
+
weight_decay: 1.0e-4
|
| 30 |
+
use_class_weights: true
|
| 31 |
+
early_stopping_patience: 12
|
| 32 |
+
early_stopping_metric: "val_macro_recall"
|
| 33 |
+
|
| 34 |
+
calibration:
|
| 35 |
+
use_temperature_scaling: true
|
| 36 |
+
val_fraction_for_calibration: 0.5
|
| 37 |
+
|
| 38 |
+
output:
|
| 39 |
+
weights_dir: "weights"
|
| 40 |
+
best_name: "electrical_outlets_audio_best.pt"
|
| 41 |
+
report_name: "audio_model_report.md"
|
config/image_train_config.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# v5.1 β Push past 63% min recall
|
| 2 |
+
# Changes: higher finetune LR, bigger head, temperature scaling disabled (see calibration section)
|
| 3 |
+
|
| 4 |
+
data:
|
| 5 |
+
root: "ELECTRICAL OUTLETS-20260106T153508Z-3-001"
|
| 6 |
+
label_mapping: "config/label_mapping.json"
|
| 7 |
+
train_ratio: 0.7
|
| 8 |
+
val_ratio: 0.15
|
| 9 |
+
seed: 42
|
| 10 |
+
batch_size: 64
|
| 11 |
+
num_workers: 4
|
| 12 |
+
|
| 13 |
+
augmentation:
|
| 14 |
+
resize: 256
|
| 15 |
+
crop: 224
|
| 16 |
+
|
| 17 |
+
model:
|
| 18 |
+
num_classes: 5
|
| 19 |
+
pretrained: true
|
| 20 |
+
head_hidden: 512 # was 256 β more capacity with 1300 images
|
| 21 |
+
head_dropout: 0.5 # was 0.4 β stronger regularization
|
| 22 |
+
|
| 23 |
+
training:
|
| 24 |
+
epochs: 100 # was 80 β give head more time
|
| 25 |
+
lr: 3.0e-3
|
| 26 |
+
weight_decay: 1.0e-3
|
| 27 |
+
use_class_weights: true
|
| 28 |
+
use_focal: true
|
| 29 |
+
focal_alpha: 0.25
|
| 30 |
+
focal_gamma: 2.0
|
| 31 |
+
early_stopping_patience: 25 # was 20
|
| 32 |
+
early_stopping_metric: "val_min_recall"
|
| 33 |
+
finetune_last_blocks: true
|
| 34 |
+
finetune_lr: 2.0e-4 # was 5e-5 β 4x higher, backbone needs to adapt more
|
| 35 |
+
finetune_epochs: 30 # was 25
|
| 36 |
+
|
| 37 |
+
calibration:
|
| 38 |
+
use_temperature_scaling: false # DISABLED β was producing negative T
|
| 39 |
+
|
| 40 |
+
output:
|
| 41 |
+
weights_dir: "weights"
|
| 42 |
+
best_name: "electrical_outlets_image_best.pt"
|
| 43 |
+
report_name: "image_model_report.md"
|
config/label_mapping.json
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"image": {
|
| 3 |
+
"classes": [
|
| 4 |
+
{
|
| 5 |
+
"folder_key": "burn_marks_overheating",
|
| 6 |
+
"issue_type": "burn_overheating",
|
| 7 |
+
"severity": "high",
|
| 8 |
+
"description": "Fire, overheating, sparking, discoloration"
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"folder_key": "cracked_faceplates",
|
| 12 |
+
"issue_type": "cracked_faceplate",
|
| 13 |
+
"severity": "medium",
|
| 14 |
+
"description": "Cracked/broken faceplate, damaged switches"
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"folder_key": "loose_outlets",
|
| 18 |
+
"issue_type": "loose_outlet",
|
| 19 |
+
"severity": "medium",
|
| 20 |
+
"description": "Loose outlet, pulled from wall, exposed wiring"
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"folder_key": "normal_outlets",
|
| 24 |
+
"issue_type": "normal",
|
| 25 |
+
"severity": "low",
|
| 26 |
+
"description": "Normal outlet/switch condition"
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"folder_key": "water_exposed",
|
| 30 |
+
"issue_type": "water_exposed",
|
| 31 |
+
"severity": "high",
|
| 32 |
+
"description": "Water intrusion near outlet"
|
| 33 |
+
}
|
| 34 |
+
],
|
| 35 |
+
"folder_to_class": {
|
| 36 |
+
"Burn marks - overheating 250": "burn_marks_overheating",
|
| 37 |
+
"Discoloration (heat aging) 100": "burn_marks_overheating",
|
| 38 |
+
"Sparking damage evidence 150": "burn_marks_overheating",
|
| 39 |
+
"Cracked faceplate 150": "cracked_faceplates",
|
| 40 |
+
"Damaged switches 50": "cracked_faceplates",
|
| 41 |
+
"Loose outlet - pulled from wall 200": "loose_outlets",
|
| 42 |
+
"Exposed wiring 150": "loose_outlets",
|
| 43 |
+
"Normal outlets 50": "normal_outlets",
|
| 44 |
+
"Normal switches 50": "normal_outlets",
|
| 45 |
+
"Water intrusion near outlet 150": "water_exposed"
|
| 46 |
+
},
|
| 47 |
+
"class_to_idx": {
|
| 48 |
+
"burn_marks_overheating": 0,
|
| 49 |
+
"cracked_faceplates": 1,
|
| 50 |
+
"loose_outlets": 2,
|
| 51 |
+
"normal_outlets": 3,
|
| 52 |
+
"water_exposed": 4
|
| 53 |
+
},
|
| 54 |
+
"idx_to_issue_type": [
|
| 55 |
+
"burn_overheating",
|
| 56 |
+
"cracked_faceplate",
|
| 57 |
+
"loose_outlet",
|
| 58 |
+
"normal",
|
| 59 |
+
"water_exposed"
|
| 60 |
+
],
|
| 61 |
+
"idx_to_severity": ["high", "medium", "medium", "low", "high"]
|
| 62 |
+
},
|
| 63 |
+
"audio": {
|
| 64 |
+
"file_pattern_to_label": {
|
| 65 |
+
"normal_near_silent": "normal",
|
| 66 |
+
"plug_insert_remove_clicks": "normal",
|
| 67 |
+
"load_switching": "normal",
|
| 68 |
+
"buzzing_outlet": "buzzing",
|
| 69 |
+
"loose_contact_crackle": "crackling_arcing",
|
| 70 |
+
"arcing_pop": "arcing_pop"
|
| 71 |
+
},
|
| 72 |
+
"label_to_severity": {
|
| 73 |
+
"normal": "low",
|
| 74 |
+
"buzzing": "high",
|
| 75 |
+
"crackling_arcing": "high",
|
| 76 |
+
"arcing_pop": "critical"
|
| 77 |
+
},
|
| 78 |
+
"label_to_issue_type": {
|
| 79 |
+
"normal": "normal",
|
| 80 |
+
"buzzing": "buzzing",
|
| 81 |
+
"crackling_arcing": "crackling_arcing",
|
| 82 |
+
"arcing_pop": "arcing_pop"
|
| 83 |
+
},
|
| 84 |
+
"class_to_idx": {
|
| 85 |
+
"normal": 0,
|
| 86 |
+
"buzzing": 1,
|
| 87 |
+
"crackling_arcing": 2,
|
| 88 |
+
"arcing_pop": 3
|
| 89 |
+
},
|
| 90 |
+
"idx_to_label": ["normal", "buzzing", "crackling_arcing", "arcing_pop"],
|
| 91 |
+
"num_classes": 4
|
| 92 |
+
}
|
| 93 |
+
}
|
config/schema.yaml
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Canonical output schema for Electrical Outlets diagnostic element
|
| 2 |
+
# Used by image model, audio model, fusion layer, and API
|
| 3 |
+
# Aligned to client PDF: Electrical outlet & switches diagnostics
|
| 4 |
+
|
| 5 |
+
diagnostic_element: electrical_outlets
|
| 6 |
+
|
| 7 |
+
result:
|
| 8 |
+
type: string
|
| 9 |
+
enum:
|
| 10 |
+
- issue_detected
|
| 11 |
+
- normal
|
| 12 |
+
- uncertain
|
| 13 |
+
description: "Final outcome; uncertain triggers backend-guided adjustment or escalation"
|
| 14 |
+
|
| 15 |
+
issue_type:
|
| 16 |
+
type: string
|
| 17 |
+
nullable: true
|
| 18 |
+
enum:
|
| 19 |
+
# Image-derived (NOT OPEN, PDF diagnostics 1-38)
|
| 20 |
+
- burn_overheating
|
| 21 |
+
- cracked_faceplate
|
| 22 |
+
- gfci_failure
|
| 23 |
+
- loose_outlet
|
| 24 |
+
- water_exposed
|
| 25 |
+
# Audio-derived (PDF diagnostics 21-28)
|
| 26 |
+
- buzzing
|
| 27 |
+
- humming
|
| 28 |
+
- crackling_arcing
|
| 29 |
+
- arcing_pop
|
| 30 |
+
- sizzling
|
| 31 |
+
- clicking_idle
|
| 32 |
+
# Combined / generic
|
| 33 |
+
- normal
|
| 34 |
+
description: "Primary issue type when result is issue_detected; null for normal/uncertain when no single type"
|
| 35 |
+
|
| 36 |
+
severity:
|
| 37 |
+
type: string
|
| 38 |
+
enum:
|
| 39 |
+
- low
|
| 40 |
+
- medium
|
| 41 |
+
- high
|
| 42 |
+
- critical
|
| 43 |
+
description: "Per PDF: low=monitor, medium=repair, high=shut circuit, critical=shut main breaker"
|
| 44 |
+
|
| 45 |
+
confidence:
|
| 46 |
+
type: number
|
| 47 |
+
minimum: 0
|
| 48 |
+
maximum: 1
|
| 49 |
+
description: "Calibrated probability; drives uncertain path when below threshold"
|
| 50 |
+
|
| 51 |
+
modality_contributions:
|
| 52 |
+
type: object
|
| 53 |
+
nullable: true
|
| 54 |
+
properties:
|
| 55 |
+
image:
|
| 56 |
+
type: object
|
| 57 |
+
nullable: true
|
| 58 |
+
properties:
|
| 59 |
+
result: { type: string }
|
| 60 |
+
issue_type: { type: string, nullable: true }
|
| 61 |
+
severity: { type: string }
|
| 62 |
+
confidence: { type: number }
|
| 63 |
+
audio:
|
| 64 |
+
type: object
|
| 65 |
+
nullable: true
|
| 66 |
+
properties:
|
| 67 |
+
result: { type: string }
|
| 68 |
+
issue_type: { type: string, nullable: true }
|
| 69 |
+
severity: { type: string }
|
| 70 |
+
confidence: { type: number }
|
| 71 |
+
description: "Per-modality outputs for transparency; present when both image and audio provided"
|
| 72 |
+
|
| 73 |
+
# For fusion when both modalities detect different issues
|
| 74 |
+
primary_issue:
|
| 75 |
+
type: string
|
| 76 |
+
nullable: true
|
| 77 |
+
description: "Higher-severity issue when both modalities detect issues"
|
| 78 |
+
secondary_issue:
|
| 79 |
+
type: string
|
| 80 |
+
nullable: true
|
| 81 |
+
description: "Other issue when both modalities detect different issues"
|
config/thresholds.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Threshold and safety configuration - Electrical Outlets
|
| 2 |
+
# Prefer "uncertain" over "normal" when in doubt (minimize false negatives)
|
| 3 |
+
|
| 4 |
+
confidence_issue_min: 0.6 # below this -> result = uncertain when issue_detected
|
| 5 |
+
confidence_normal_min: 0.75 # both modalities must exceed this to return "normal"
|
| 6 |
+
uncertain_if_disagree: true # image defect + audio normal (or vice versa) -> uncertain unless one side very high
|
| 7 |
+
high_confidence_override: 0.92 # if one modality >= this and says issue_detected, can override disagree
|
| 8 |
+
|
| 9 |
+
severity_order:
|
| 10 |
+
- low
|
| 11 |
+
- medium
|
| 12 |
+
- high
|
| 13 |
+
- critical
|
releases.md
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Releases
|
| 2 |
+
|
| 3 |
+
## v1.0 – February 2026
|
| 4 |
+
- 5-class image model (EfficientNet-B0 + MLP head): 77% accuracy, 67% min recall
|
| 5 |
+
- 4-class audio model (Spectrogram CNN): 100% recall
|
| 6 |
+
- Decision-level fusion with configurable thresholds
|
| 7 |
+
- Gradio demo app
|
| 8 |
+
- FastAPI endpoint
|
requirements.txt
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles==23.2.1
|
| 2 |
+
aiohappyeyeballs==2.6.1
|
| 3 |
+
aiohttp==3.11.18
|
| 4 |
+
aiosignal==1.3.2
|
| 5 |
+
altair==5.5.0
|
| 6 |
+
annotated-doc==0.0.4
|
| 7 |
+
annotated-types==0.7.0
|
| 8 |
+
anyio==4.9.0
|
| 9 |
+
async-timeout==4.0.3
|
| 10 |
+
attrs==25.3.0
|
| 11 |
+
beautifulsoup4==4.13.4
|
| 12 |
+
Brotli @ file:///D:/bld/brotli-split_1725267609074/work
|
| 13 |
+
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1739515848642/work/certifi
|
| 14 |
+
cffi @ file:///D:/bld/cffi_1725560792189/work
|
| 15 |
+
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1746214863626/work
|
| 16 |
+
click==8.1.8
|
| 17 |
+
colorama==0.4.6
|
| 18 |
+
comtypes==1.4.10
|
| 19 |
+
contourpy==1.3.0
|
| 20 |
+
cycler==0.12.1
|
| 21 |
+
dataclasses-json==0.6.7
|
| 22 |
+
docopt==0.6.2
|
| 23 |
+
exceptiongroup==1.2.2
|
| 24 |
+
fastapi==0.95.2
|
| 25 |
+
ffmpy==1.0.0
|
| 26 |
+
filelock==3.18.0
|
| 27 |
+
fonttools==4.60.2
|
| 28 |
+
frozenlist==1.6.0
|
| 29 |
+
fsspec==2025.3.2
|
| 30 |
+
gradio==3.50.2
|
| 31 |
+
gradio_client==0.6.1
|
| 32 |
+
greenlet==3.2.1
|
| 33 |
+
h11==0.16.0
|
| 34 |
+
h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1738578511449/work
|
| 35 |
+
hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
|
| 36 |
+
httpcore==1.0.9
|
| 37 |
+
httpx==0.28.1
|
| 38 |
+
httpx-sse==0.4.0
|
| 39 |
+
huggingface-hub==0.31.1
|
| 40 |
+
hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
|
| 41 |
+
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
|
| 42 |
+
importlib_resources==6.5.2
|
| 43 |
+
Jinja2==3.1.6
|
| 44 |
+
joblib==1.5.3
|
| 45 |
+
Js2Py==0.74
|
| 46 |
+
jsonpatch==1.33
|
| 47 |
+
jsonpointer==3.0.0
|
| 48 |
+
jsonschema==4.25.1
|
| 49 |
+
jsonschema-specifications==2025.9.1
|
| 50 |
+
kiwisolver==1.4.7
|
| 51 |
+
langchain==0.3.25
|
| 52 |
+
langchain-community==0.3.23
|
| 53 |
+
langchain-core==0.3.58
|
| 54 |
+
langchain-ollama==0.3.2
|
| 55 |
+
langchain-text-splitters==0.3.8
|
| 56 |
+
langgraph==0.4.1
|
| 57 |
+
langgraph-checkpoint==2.0.25
|
| 58 |
+
langgraph-prebuilt==0.1.8
|
| 59 |
+
langgraph-sdk==0.1.66
|
| 60 |
+
langsmith==0.3.42
|
| 61 |
+
llvmlite==0.43.0
|
| 62 |
+
markdown-it-py==3.0.0
|
| 63 |
+
MarkupSafe==2.1.5
|
| 64 |
+
marshmallow==3.26.1
|
| 65 |
+
matplotlib==3.9.4
|
| 66 |
+
mdurl==0.1.2
|
| 67 |
+
more-itertools==10.7.0
|
| 68 |
+
mpmath==1.3.0
|
| 69 |
+
multidict==6.4.3
|
| 70 |
+
mypy_extensions==1.1.0
|
| 71 |
+
narwhals==2.17.0
|
| 72 |
+
networkx==3.2.1
|
| 73 |
+
numba==0.60.0
|
| 74 |
+
numpy==1.26.4
|
| 75 |
+
ollama==0.4.8
|
| 76 |
+
openai-whisper==20240930
|
| 77 |
+
orjson==3.10.18
|
| 78 |
+
ormsgpack==1.9.1
|
| 79 |
+
packaging==24.2
|
| 80 |
+
pandas==2.3.3
|
| 81 |
+
pillow==10.4.0
|
| 82 |
+
pipwin==0.5.2
|
| 83 |
+
propcache==0.3.1
|
| 84 |
+
PyAudio==0.2.14
|
| 85 |
+
pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
|
| 86 |
+
pydantic==1.10.13
|
| 87 |
+
pydantic-settings==2.9.1
|
| 88 |
+
pydantic_core==2.41.5
|
| 89 |
+
pydub==0.25.1
|
| 90 |
+
pygame==2.6.1
|
| 91 |
+
Pygments==2.19.2
|
| 92 |
+
pyjsparser==2.7.1
|
| 93 |
+
pyparsing==3.3.2
|
| 94 |
+
pypiwin32==223
|
| 95 |
+
PyPrind==2.11.3
|
| 96 |
+
pySmartDL==1.3.4
|
| 97 |
+
PySocks @ file:///D:/bld/pysocks_1733217287171/work
|
| 98 |
+
python-dateutil==2.9.0.post0
|
| 99 |
+
python-dotenv==1.1.0
|
| 100 |
+
python-multipart==0.0.20
|
| 101 |
+
pyttsx3==2.98
|
| 102 |
+
pytz==2025.2
|
| 103 |
+
pywin32==310
|
| 104 |
+
PyYAML==6.0.2
|
| 105 |
+
referencing==0.36.2
|
| 106 |
+
regex==2024.11.6
|
| 107 |
+
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1733217035951/work
|
| 108 |
+
requests-toolbelt==1.0.0
|
| 109 |
+
rich==14.3.3
|
| 110 |
+
rpds-py==0.27.1
|
| 111 |
+
ruff==0.15.2
|
| 112 |
+
scikit-learn==1.6.1
|
| 113 |
+
scipy==1.13.1
|
| 114 |
+
semantic-version==2.10.0
|
| 115 |
+
shellingham==1.5.4
|
| 116 |
+
six==1.17.0
|
| 117 |
+
sniffio==1.3.1
|
| 118 |
+
soundfile==0.13.1
|
| 119 |
+
soupsieve==2.7
|
| 120 |
+
SpeechRecognition @ file:///home/conda/feedstock_root/build_artifacts/speechrecognition_1742707644995/work
|
| 121 |
+
SQLAlchemy==2.0.40
|
| 122 |
+
starlette==0.27.0
|
| 123 |
+
sympy==1.13.1
|
| 124 |
+
tenacity==9.1.2
|
| 125 |
+
threadpoolctl==3.6.0
|
| 126 |
+
tiktoken==0.9.0
|
| 127 |
+
tomlkit==0.12.0
|
| 128 |
+
torch==2.6.0+cu124
|
| 129 |
+
torchaudio==2.6.0+cu124
|
| 130 |
+
torchvision==0.21.0+cu124
|
| 131 |
+
tqdm==4.67.1
|
| 132 |
+
typer==0.23.2
|
| 133 |
+
typing-inspect==0.9.0
|
| 134 |
+
typing-inspection==0.4.2
|
| 135 |
+
typing_extensions==4.15.0
|
| 136 |
+
tzdata==2025.2
|
| 137 |
+
tzlocal==5.3.1
|
| 138 |
+
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1744323578849/work
|
| 139 |
+
uvicorn==0.39.0
|
| 140 |
+
websockets==11.0.3
|
| 141 |
+
win_inet_pton @ file:///D:/bld/win_inet_pton_1733130564612/work
|
| 142 |
+
xxhash==3.5.0
|
| 143 |
+
yarl==1.20.0
|
| 144 |
+
zipp==3.23.0
|
| 145 |
+
zstandard==0.23.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .image_model import ElectricalOutletsImageModel
|
| 2 |
+
from .audio_model import ElectricalOutletsAudioModel
|
| 3 |
+
|
| 4 |
+
__all__ = ["ElectricalOutletsImageModel", "ElectricalOutletsAudioModel"]
|
src/data/audio_dataset.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio dataset for Electrical Outlets. Uses README/file naming and config/label_mapping.json.
|
| 3 |
+
PATCHED: rglob for subfolders, torchaudio import at module level, stratified splits.
|
| 4 |
+
"""
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from typing import Optional, Callable, List, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torchaudio
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _label_from_filename(filename: str, file_pattern_to_label: dict) -> str:
|
| 19 |
+
for pattern, label in file_pattern_to_label.items():
|
| 20 |
+
if filename.startswith(pattern) or pattern in filename:
|
| 21 |
+
return label
|
| 22 |
+
return "normal"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ElectricalOutletsAudioDataset(Dataset):
    """Audio dataset built from WAV files under ``electrical_outlets_sounds_100``.

    Labels are derived from each file's stem via the substring patterns in
    ``config/label_mapping.json`` (``audio.file_pattern_to_label``). Each item
    is a fixed-length mono waveform tensor plus an integer class index.
    """

    def __init__(
        self,
        root: Path,
        label_mapping_path: Path,
        split: str = "train",
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        seed: int = 42,
        transform: Optional[Callable] = None,
        target_length_sec: float = 5.0,
        sample_rate: int = 22050,
    ):
        """Scan *root* recursively for WAVs, label them, build a stratified split.

        Args:
            root: Directory searched recursively (``rglob``) for ``*.wav`` files.
            label_mapping_path: Path to ``label_mapping.json``; only its
                ``"audio"`` section is read.
            split: ``"train"``, ``"val"``, or anything else for the test split.
            train_ratio: Per-class fraction assigned to the train split.
            val_ratio: Per-class fraction assigned to the val split; the
                remainder of each class goes to test.
            seed: Seed for the deterministic per-class shuffle, so the same
                split is reproduced across instantiations.
            transform: Optional callable applied to the waveform tensor.
            target_length_sec: Clips are center-cropped / zero-padded to this
                duration in ``__getitem__``.
            sample_rate: Target sample rate; files are resampled when needed.
        """
        self.root = Path(root)
        self.transform = transform
        self.target_length_sec = target_length_sec
        self.sample_rate = sample_rate
        with open(label_mapping_path) as f:
            lm = json.load(f)
        self.file_pattern_to_label = lm["audio"]["file_pattern_to_label"]
        self.class_to_idx = lm["audio"]["class_to_idx"]
        self.idx_to_label = lm["audio"]["idx_to_label"]
        self.label_to_severity = lm["audio"]["label_to_severity"]
        self.label_to_issue_type = lm["audio"]["label_to_issue_type"]
        self.num_classes = len(self.class_to_idx)

        self.samples: List[Tuple[Path, int]] = []
        # rglob to search subfolders
        for wav in self.root.rglob("*.wav"):
            label = _label_from_filename(wav.stem, self.file_pattern_to_label)
            # NOTE: _label_from_filename falls back to "normal", which is in
            # class_to_idx — so this branch only fires if the mapping config
            # lists a label missing from class_to_idx.
            if label not in self.class_to_idx:
                logger.warning(f"Unmatched audio file: {wav.name} β label '{label}' not in class_to_idx")
                continue
            self.samples.append((wav, self.class_to_idx[label]))

        # Stratified split: shuffle indices within each class, then carve out
        # train/val/test contiguously from the permutation.
        by_class = defaultdict(list)
        for i, (_, cls) in enumerate(self.samples):
            by_class[cls].append(i)

        train_idx, val_idx, test_idx = [], [], []
        for cls in sorted(by_class.keys()):
            indices = by_class[cls]
            # Generator re-seeded per class with the same seed: deterministic,
            # so train/val/test instances agree on the partition.
            g = torch.Generator().manual_seed(seed)
            perm = torch.randperm(len(indices), generator=g).tolist()
            n_cls = len(indices)
            n_tr = int(n_cls * train_ratio)
            n_va = int(n_cls * val_ratio)
            train_idx.extend([indices[p] for p in perm[:n_tr]])
            val_idx.extend([indices[p] for p in perm[n_tr:n_tr + n_va]])
            test_idx.extend([indices[p] for p in perm[n_tr + n_va:]])

        if split == "train":
            self.indices = train_idx
        elif split == "val":
            self.indices = val_idx
        else:
            # Any other value (e.g. "test") selects the held-out remainder.
            self.indices = test_idx

    def __len__(self) -> int:
        """Number of samples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx: int):
        """Return ``(waveform, class_idx)`` for split-local index *idx*.

        The waveform is mono (channels averaged), resampled to
        ``self.sample_rate``, and center-cropped or right-zero-padded to
        exactly ``target_length_sec`` seconds.
        """
        i = self.indices[idx]
        path, cls = self.samples[i]
        waveform, sr = torchaudio.load(str(path))
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # Downmix multi-channel audio to mono.
            waveform = waveform.mean(dim=0, keepdim=True)
        target_len = int(self.target_length_sec * self.sample_rate)
        if waveform.shape[1] >= target_len:
            # Center crop longer clips.
            start = (waveform.shape[1] - target_len) // 2
            waveform = waveform[:, start : start + target_len]
        else:
            # Zero-pad shorter clips on the right.
            waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
        if self.transform:
            waveform = self.transform(waveform)
        return waveform, cls
|
src/data/image_dataset.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image dataset for Electrical Outlets.
|
| 3 |
+
FINAL v5: Direct folder_to_class mapping β no pattern matching, no ambiguity.
|
| 4 |
+
"""
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from typing import Optional, Callable, List, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ElectricalOutletsImageDataset(Dataset):
    """Image dataset for electrical outlet defect classification.

    Folders under *root* are mapped to classes via the exact-name
    ``image.folder_to_class`` table in ``label_mapping.json`` (no pattern
    matching). Each item is ``(PIL image or transformed tensor, class_idx)``.
    """

    def __init__(
        self,
        root: Path,
        label_mapping_path: Path,
        split: str = "train",
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        seed: int = 42,
        transform: Optional[Callable] = None,
        extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"),
    ):
        """Index images per class folder and build a stratified split.

        Args:
            root: Directory whose immediate subfolders are class folders.
            label_mapping_path: Path to ``label_mapping.json``; only its
                ``"image"`` section is read.
            split: ``"train"``, ``"val"``, or anything else for the test split.
            train_ratio: Per-class fraction assigned to the train split.
            val_ratio: Per-class fraction for val; the remainder is test.
            seed: Seed for the deterministic per-class shuffle.
            transform: Optional callable applied to the loaded PIL image.
            extensions: Accepted file suffixes (compared lowercased).

        Raises:
            ValueError: If no images were found (wrong root or folder names).
        """
        self.root = Path(root)
        self.transform = transform
        self.extensions = extensions
        self.split = split

        with open(label_mapping_path) as f:
            lm = json.load(f)

        self.folder_to_class = lm["image"]["folder_to_class"]
        self.class_to_idx = lm["image"]["class_to_idx"]
        self.idx_to_issue_type = lm["image"]["idx_to_issue_type"]
        self.idx_to_severity = lm["image"]["idx_to_severity"]
        self.num_classes = len(self.class_to_idx)

        # Build samples list
        self.samples: List[Tuple[Path, int]] = []
        class_counts = defaultdict(int)
        matched_folders = []
        unmatched_folders = []

        for folder in sorted(self.root.iterdir()):
            if not folder.is_dir():
                continue
            # Direct lookup by exact folder name
            class_key = self.folder_to_class.get(folder.name)
            if class_key is None:
                # Folder not declared in the mapping — recorded and skipped.
                unmatched_folders.append(folder.name)
                continue
            cls_idx = self.class_to_idx[class_key]
            count = 0
            for f in folder.iterdir():
                if f.suffix.lower() in self.extensions:
                    self.samples.append((f, cls_idx))
                    count += 1
            class_counts[cls_idx] += count
            matched_folders.append(f" β {folder.name} β {class_key} (idx={cls_idx}): {count} images")

        # Log results
        logger.info(f"\n{'='*60}")
        logger.info(f"Dataset loading from: {self.root}")
        logger.info(f"{'='*60}")
        for line in matched_folders:
            logger.info(line)
        for uf in unmatched_folders:
            logger.warning(f" β SKIPPED: '{uf}' (not in folder_to_class)")
        logger.info(f"\nClass distribution:")
        for idx in sorted(class_counts.keys()):
            # Reverse lookup: class name for this index (class_to_idx is 1:1).
            name = [k for k, v in self.class_to_idx.items() if v == idx][0]
            logger.info(f" Class {idx} ({name}): {class_counts[idx]} images")
        logger.info(f"Total: {len(self.samples)} images in {self.num_classes} classes")

        if len(self.samples) == 0:
            logger.error("NO SAMPLES FOUND! Check that data_root points to the folder containing your class subfolders.")
            raise ValueError(f"No images found in {self.root}. Check folder names match label_mapping.json folder_to_class keys.")

        # Stratified split: shuffle indices within each class, then carve out
        # train/val/test contiguously from the permutation.
        by_class = defaultdict(list)
        for i, (_, cls) in enumerate(self.samples):
            by_class[cls].append(i)

        train_idx, val_idx, test_idx = [], [], []
        for cls in sorted(by_class.keys()):
            indices = by_class[cls]
            # Generator re-seeded per class with the same seed: deterministic,
            # so train/val/test instances agree on the partition.
            g = torch.Generator().manual_seed(seed)
            perm = torch.randperm(len(indices), generator=g).tolist()
            n_cls = len(indices)
            n_tr = int(n_cls * train_ratio)
            n_va = int(n_cls * val_ratio)
            train_idx.extend([indices[p] for p in perm[:n_tr]])
            val_idx.extend([indices[p] for p in perm[n_tr:n_tr + n_va]])
            test_idx.extend([indices[p] for p in perm[n_tr + n_va:]])

        if split == "train":
            self.indices = train_idx
        elif split == "val":
            self.indices = val_idx
        else:
            # Any other value (e.g. "test") selects the held-out remainder.
            self.indices = test_idx

        logger.info(f"Split '{split}': {len(self.indices)} samples\n")

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, class_idx)`` for split-local index *idx*."""
        i = self.indices[idx]
        path, cls = self.samples[i]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, cls

    def get_issue_type(self, class_idx: int) -> str:
        """Canonical issue_type string (schema enum) for a class index."""
        return self.idx_to_issue_type[class_idx]

    def get_severity(self, class_idx: int) -> str:
        """Severity string (low/medium/high/critical) for a class index."""
        return self.idx_to_severity[class_idx]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_image_class_weights(
    label_mapping_path: Path,
    root: Path,
    extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"),
) -> torch.Tensor:
    """Compute inverse-frequency class weights for a class-weighted loss.

    Counts images per class by walking the immediate subfolders of *root*
    and mapping each folder name through ``image.folder_to_class`` in the
    label-mapping JSON (same convention as ``ElectricalOutletsImageDataset``).

    Args:
        label_mapping_path: Path to ``label_mapping.json``.
        root: Directory whose immediate subfolders are class folders.
        extensions: Accepted file suffixes (compared lowercased). Added as a
            parameter to stay consistent with the dataset class; the default
            preserves the original behavior.

    Returns:
        A float32 tensor of length ``num_classes`` where weight_i is
        ``total / (num_classes * count_i)``; 1.0 for empty classes, and all
        ones when no images are found at all.
    """
    with open(label_mapping_path) as f:
        lm = json.load(f)
    folder_to_class = lm["image"]["folder_to_class"]
    class_to_idx = lm["image"]["class_to_idx"]
    num_classes = len(class_to_idx)
    counts = [0] * num_classes

    root = Path(root)
    for folder in root.iterdir():
        if not folder.is_dir():
            continue
        class_key = folder_to_class.get(folder.name)
        if class_key is None:
            # Folder not declared in the mapping — ignored, matching the dataset.
            continue
        cls_idx = class_to_idx[class_key]
        n = sum(1 for f in folder.iterdir() if f.suffix.lower() in extensions)
        counts[cls_idx] += n

    total = sum(counts)
    if total == 0:
        # No images at all: fall back to uniform weights.
        return torch.ones(num_classes)
    weights = [total / (num_classes * c) if c else 1.0 for c in counts]
    return torch.tensor(weights, dtype=torch.float32)
|
src/fusion/fusion_logic.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Decision-level fusion for Electrical Outlets. No early fusion.
|
| 3 |
+
Rules: final_severity = max(image_severity, audio_severity); result = issue_detected | normal | uncertain.
|
| 4 |
+
"""
|
| 5 |
+
from typing import Optional, Dict, Any
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class ModalityOutput:
    """Per-modality (image or audio) diagnostic output fed into fusion."""

    result: str  # issue_detected | normal | uncertain
    issue_type: Optional[str] = None  # e.g. "buzzing", "cracked_faceplate"; None when not applicable
    severity: str = "low"  # one of: low | medium | high | critical
    confidence: float = 0.0  # model confidence in [0, 1] — presumably calibrated; confirm with model wrappers
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _severity_rank(s: str, order: list) -> int:
|
| 18 |
+
try:
|
| 19 |
+
return order.index(s)
|
| 20 |
+
except ValueError:
|
| 21 |
+
return 0
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def fuse_modalities(
    image_out: Optional[ModalityOutput],
    audio_out: Optional[ModalityOutput],
    confidence_issue_min: float = 0.6,
    confidence_normal_min: float = 0.75,
    uncertain_if_disagree: bool = True,
    high_confidence_override: float = 0.92,
    severity_order: Optional[list] = None,
) -> Dict[str, Any]:
    """
    Fuse image and/or audio outputs into single diagnostic result.
    Prefer uncertain over normal when in doubt.

    Decision rules (matching config/thresholds.yaml):
      - final severity = max of per-modality severities (by severity_order);
      - "issue_detected" requires max confidence >= confidence_issue_min;
      - "normal" requires every modality normal AND every confidence
        >= confidence_normal_min;
      - a defect/normal disagreement yields "uncertain" unless one defect
        modality reaches high_confidence_override;
      - anything else is "uncertain".

    Args:
        image_out: Image-model output, or None when no image was provided.
        audio_out: Audio-model output, or None when no audio was provided.
        confidence_issue_min: Floor for declaring an issue.
        confidence_normal_min: Floor (per modality) for declaring normal.
        uncertain_if_disagree: Whether a defect/normal split is downgraded.
        high_confidence_override: A defect modality at/above this wins a
            disagreement outright.
        severity_order: Ascending severity ranking; defaults to
            ["low", "medium", "high", "critical"].

    Returns:
        Dict in the canonical schema (see config/schema.yaml):
        diagnostic_element, result, issue_type, severity, confidence,
        modality_contributions, primary_issue, secondary_issue.
    """
    if severity_order is None:
        severity_order = ["low", "medium", "high", "critical"]

    modality_contributions = {}
    outputs = []
    if image_out is not None:
        outputs.append(("image", image_out))
        modality_contributions["image"] = {
            "result": image_out.result,
            "issue_type": image_out.issue_type,
            "severity": image_out.severity,
            "confidence": image_out.confidence,
        }
    if audio_out is not None:
        outputs.append(("audio", audio_out))
        modality_contributions["audio"] = {
            "result": audio_out.result,
            "issue_type": audio_out.issue_type,
            "severity": audio_out.severity,
            "confidence": audio_out.confidence,
        }

    # No modality at all: conservative "uncertain" with zero confidence.
    if not outputs:
        return {
            "diagnostic_element": "electrical_outlets",
            "result": "uncertain",
            "issue_type": None,
            "severity": "low",
            "confidence": 0.0,
            "modality_contributions": None,
            "primary_issue": None,
            "secondary_issue": None,
        }

    # Severity: max across modalities
    max_severity_rank = -1
    max_severity = "low"
    for _, out in outputs:
        r = _severity_rank(out.severity, severity_order)
        if r > max_severity_rank:
            max_severity_rank = r
            max_severity = out.severity

    # Result and issue_type
    primary_issue = None
    secondary_issue = None
    has_issue = any(o.result == "issue_detected" for _, o in outputs)
    all_normal = all(o.result == "normal" for _, o in outputs)
    max_conf = max(o.confidence for _, o in outputs)
    # Disagreement is only defined for exactly two modalities with opposite
    # defect/normal calls ("uncertain" on either side is not a disagreement).
    disagree = len(outputs) == 2 and (
        (outputs[0][1].result == "issue_detected" and outputs[1][1].result == "normal")
        or (outputs[0][1].result == "normal" and outputs[1][1].result == "issue_detected")
    )

    if has_issue and max_conf >= confidence_issue_min:
        if disagree and uncertain_if_disagree:
            # One modality may still win outright if its defect call is
            # extremely confident.
            override = any(o.confidence >= high_confidence_override and o.result == "issue_detected" for _, o in outputs)
            if override:
                result = "issue_detected"
                issue_type = next(o.issue_type for _, o in outputs if o.result == "issue_detected" and o.confidence >= high_confidence_override)
                primary_issue = issue_type
            else:
                result = "uncertain"
                issue_type = None
        else:
            result = "issue_detected"
            defect_outs = [(n, o) for n, o in outputs if o.result == "issue_detected"]
            if len(defect_outs) >= 2:
                # Both modalities found defects: the higher-severity one is
                # primary; the other is reported as secondary if distinct.
                defect_outs.sort(key=lambda x: _severity_rank(x[1].severity, severity_order), reverse=True)
                issue_type = defect_outs[0][1].issue_type
                primary_issue = defect_outs[0][1].issue_type
                secondary_issue = defect_outs[1][1].issue_type if defect_outs[0][1].issue_type != defect_outs[1][1].issue_type else None
            else:
                issue_type = defect_outs[0][1].issue_type
                primary_issue = issue_type
    elif all_normal and all(o.confidence >= confidence_normal_min for _, o in outputs):
        result = "normal"
        issue_type = "normal"
    else:
        result = "uncertain"
        issue_type = None
    # Uncertain results report the *minimum* confidence (conservative);
    # otherwise the maximum.
    confidence = max_conf if result != "uncertain" else min(o.confidence for _, o in outputs)

    return {
        "diagnostic_element": "electrical_outlets",
        "result": result,
        "issue_type": issue_type,
        "severity": max_severity,
        "confidence": round(confidence, 4),
        # Only populated when more than one modality contributed.
        "modality_contributions": modality_contributions if len(modality_contributions) > 1 else None,
        "primary_issue": primary_issue if result == "issue_detected" else None,
        "secondary_issue": secondary_issue,
    }
|
src/inference/wrapper.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference wrapper: load image + audio models, run modalities present, apply fusion, return schema.
|
| 3 |
+
"""
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional, Dict, Any, BinaryIO
|
| 6 |
+
import json
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
# Optional imports for models
|
| 13 |
+
import sys
|
| 14 |
+
ROOT = Path(__file__).resolve().parent.parent.parent
|
| 15 |
+
sys.path.insert(0, str(ROOT))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _load_image_model(weights_path: Path, label_mapping_path: Path, device: str):
    """Load the image-classifier checkpoint and return ``(model, temperature)``.

    The checkpoint dict is expected to contain ``num_classes`` and
    ``model_state_dict``; ``idx_to_issue_type`` / ``idx_to_severity`` /
    ``temperature`` are optional (temperature defaults to 1.0, i.e. no
    calibration scaling). The model is moved to *device* and set to eval mode.
    NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
    trusted weight files.
    """
    # Project model class imported lazily (see module-level "Optional imports" note).
    from src.models.image_model import ElectricalOutletsImageModel
    ckpt = torch.load(weights_path, map_location=device)
    model = ElectricalOutletsImageModel(
        num_classes=ckpt["num_classes"],
        label_mapping_path=label_mapping_path,
        pretrained=False,  # weights come from the checkpoint, not a pretrained backbone
    )
    model.load_state_dict(ckpt["model_state_dict"])
    # Attach label metadata from the checkpoint for downstream result mapping.
    model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
    model.idx_to_severity = ckpt.get("idx_to_severity")
    model.eval()
    return model.to(device), ckpt.get("temperature", 1.0)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _load_audio_model(weights_path: Path, label_mapping_path: Path, device: str, config: dict):
    """Load the audio-classifier checkpoint and return ``(model, temperature)``.

    *config* supplies spectrogram dimensions (``n_mels`` default 64,
    ``time_steps`` default 128). The checkpoint dict is expected to contain
    ``num_classes`` and ``model_state_dict``; label metadata and
    ``temperature`` are optional (temperature defaults to 1.0). The model is
    moved to *device* and set to eval mode.
    NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
    trusted weight files.
    """
    # Project model class imported lazily (see module-level "Optional imports" note).
    from src.models.audio_model import ElectricalOutletsAudioModel
    ckpt = torch.load(weights_path, map_location=device)
    model = ElectricalOutletsAudioModel(
        num_classes=ckpt["num_classes"],
        label_mapping_path=label_mapping_path,
        n_mels=config.get("n_mels", 64),
        time_steps=config.get("time_steps", 128),
    )
    model.load_state_dict(ckpt["model_state_dict"])
    # Attach label metadata from the checkpoint for downstream result mapping.
    model.idx_to_label = ckpt.get("idx_to_label")
    model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
    model.idx_to_severity = ckpt.get("idx_to_severity")
    model.eval()
    return model.to(device), ckpt.get("temperature", 1.0)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def run_electrical_outlets_inference(
    image_path: Optional[Path] = None,
    image_fp: Optional[BinaryIO] = None,
    audio_path: Optional[Path] = None,
    audio_fp: Optional[BinaryIO] = None,
    weights_dir: Optional[Path] = None,
    config_dir: Optional[Path] = None,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run image and/or audio model, then fuse. Returns canonical schema dict.

    Each modality is optional and may be supplied either as a path or an open
    binary file object; whichever modalities are present are classified and
    the per-modality results are combined by ``fuse_modalities`` using the
    thresholds in ``config/thresholds.yaml``.
    """
    # Default to the repository-relative weights/config layout.
    if weights_dir is None:
        weights_dir = ROOT / "weights"
    if config_dir is None:
        config_dir = ROOT / "config"
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    label_mapping_path = config_dir / "label_mapping.json"
    thresholds_path = config_dir / "thresholds.yaml"
    import yaml
    with open(thresholds_path) as f:
        thresholds = yaml.safe_load(f)

    # --- Image branch -----------------------------------------------------
    image_out = None
    if image_path or image_fp:
        img = Image.open(image_path or image_fp).convert("RGB")
        # Deterministic eval transform with ImageNet normalization; must match
        # the validation pipeline the checkpoint was trained with.
        tf = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        x = tf(img).unsqueeze(0).to(device)
        model, T = _load_image_model(weights_dir / "electrical_outlets_image_best.pt", label_mapping_path, device)
        with torch.no_grad():
            # Temperature-scale logits before softmax (calibration).
            logits = model(x) / T
        from src.fusion.fusion_logic import ModalityOutput
        pred = model.predict_to_schema(logits)
        image_out = ModalityOutput(
            result=pred["result"],
            issue_type=pred.get("issue_type"),
            severity=pred["severity"],
            confidence=pred["confidence"],
        )

    # --- Audio branch (only when the audio checkpoint exists) -------------
    audio_out = None
    if (audio_path or audio_fp) and (weights_dir / "electrical_outlets_audio_best.pt").exists():
        if audio_path:
            waveform, sr = torchaudio.load(str(audio_path))
        else:
            import io
            waveform, sr = torchaudio.load(io.BytesIO(audio_fp.read()))
        # NOTE(review): audio preprocessing here is hard-coded to 16 kHz,
        # n_fft=512, hop=256, n_mels=64, while test.py drives the same model
        # from config/audio_train_config.yaml (defaults 22050 Hz / 128 mels).
        # Confirm these parameters match what the checkpoint was trained with.
        if sr != 16000:
            waveform = torchaudio.functional.resample(waveform, sr, 16000)
        if waveform.shape[0] > 1:
            # Down-mix multi-channel audio to mono.
            waveform = waveform.mean(dim=0, keepdim=True)
        # Center-crop or zero-pad to a fixed 5-second window.
        target_len = int(5.0 * 16000)
        if waveform.shape[1] >= target_len:
            start = (waveform.shape[1] - target_len) // 2
            waveform = waveform[:, start : start + target_len]
        else:
            waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000, n_fft=512, hop_length=256, win_length=512, n_mels=64,
        )(waveform)
        # Log-compress with a floor to avoid log(0).
        log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)
        model, T = _load_audio_model(
            weights_dir / "electrical_outlets_audio_best.pt",
            label_mapping_path,
            device,
            {"n_mels": 64, "time_steps": 128},
        )
        with torch.no_grad():
            logits = model(log_mel) / T
        from src.fusion.fusion_logic import ModalityOutput
        pred = model.predict_to_schema(logits)
        audio_out = ModalityOutput(
            result=pred["result"],
            issue_type=pred.get("issue_type"),
            severity=pred["severity"],
            confidence=pred["confidence"],
        )

    # --- Fusion -----------------------------------------------------------
    from src.fusion.fusion_logic import fuse_modalities
    return fuse_modalities(
        image_out,
        audio_out,
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )
|
src/models/audio_model.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio classifier for Electrical Outlets. Expects spectrogram or waveform; outputs class logits.
|
| 3 |
+
Severity from label_mapping. Small CNN for 100-sample regime.
|
| 4 |
+
"""
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SpectrogramCNN(nn.Module):
    """Lightweight three-stage CNN over a mel spectrogram (n_mels x time).

    The parameter layout (``conv`` Sequential indices, ``fc``) is fixed so
    existing checkpoints keep loading.
    """

    def __init__(self, n_mels: int = 64, time_steps: int = 128, num_classes: int = 4):
        super().__init__()
        # First two stages: Conv -> BN -> ReLU -> MaxPool, doubling channels.
        stages = []
        for c_in, c_out in ((1, 32), (32, 64)):
            stages += [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
        # Final stage pools everything to a single 128-dim feature vector.
        stages += [
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept (mels, time), (batch, mels, time) or (batch, 1, mels, time):
        # promote to 4-D by adding batch and channel dims as needed.
        while x.dim() < 4:
            x = x.unsqueeze(0) if x.dim() == 2 else x.unsqueeze(1)
        features = self.conv(x)
        return self.fc(features.flatten(1))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ElectricalOutletsAudioModel(nn.Module):
    """Audio classifier wrapper: SpectrogramCNN plus label-mapping lookups.

    When a label mapping file is available the model can translate class
    indices into issue types and severities; otherwise ``predict_to_schema``
    falls back to generic defaults.
    """

    def __init__(
        self,
        num_classes: int = 4,
        label_mapping_path: Optional[Path] = None,
        n_mels: int = 64,
        time_steps: int = 128,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.n_mels = n_mels
        self.time_steps = time_steps
        self.backbone = SpectrogramCNN(n_mels=n_mels, time_steps=time_steps, num_classes=num_classes)
        # Lookup tables resolved from the label mapping; None when absent.
        self.idx_to_label = None
        self.idx_to_issue_type = None
        self.idx_to_severity = None
        if label_mapping_path and Path(label_mapping_path).exists():
            with open(label_mapping_path) as f:
                audio_map = json.load(f)["audio"]
            labels = audio_map["idx_to_label"]
            self.idx_to_label = labels
            self.idx_to_issue_type = [
                audio_map["label_to_issue_type"].get(name, "normal") for name in labels
            ]
            self.idx_to_severity = [
                audio_map["label_to_severity"].get(labels[i], "medium")
                for i in range(num_classes)
            ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate straight to the spectrogram CNN."""
        return self.backbone(x)

    def predict_to_schema(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Convert raw logits into the canonical diagnostic schema dict."""
        probs = torch.softmax(logits, dim=-1)
        if logits.dim() == 1:
            probs = probs.unsqueeze(0)
        conf, pred = probs.max(dim=-1)
        if pred.numel() == 1:
            pred = pred.item()
        if conf.numel() == 1:
            conf = conf.item()
        issue_types = self.idx_to_issue_type or ["normal"] * self.num_classes
        severities = self.idx_to_severity or ["medium"] * self.num_classes
        issue_type = issue_types[pred]
        severity = severities[pred]
        return {
            "result": "normal" if issue_type == "normal" else "issue_detected",
            "issue_type": issue_type,
            "severity": severity,
            "confidence": float(conf),
            "class_idx": int(pred),
        }
|
src/models/image_model.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image classifier for Electrical Outlets. EfficientNet-B0 backbone + MLP head.
|
| 3 |
+
FINAL v5: 5 classes (no GFCI).
|
| 4 |
+
"""
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torchvision import models
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ElectricalOutletsImageModel(nn.Module):
    """Image classifier: EfficientNet-B0 feature extractor + MLP head.

    FINAL v5: 5 classes (no GFCI). The parameter layout ("head.0".."head.4")
    is fixed so existing checkpoints keep loading.
    """

    def __init__(
        self,
        num_classes: int = 5,
        label_mapping_path: Optional[Path] = None,
        pretrained: bool = True,
        head_hidden: int = 256,
        head_dropout: float = 0.4,
    ):
        super().__init__()
        self.num_classes = num_classes
        # EfficientNet-B0 backbone; its built-in classifier is replaced with
        # Identity so the backbone emits the raw feature vector.
        backbone_weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.efficientnet_b0(weights=backbone_weights)
        feature_dim = self.backbone.classifier[1].in_features  # 1280 for B0
        self.backbone.classifier = nn.Identity()

        # Two-layer MLP head; the second dropout is half the first.
        self.head = nn.Sequential(
            nn.Dropout(head_dropout),
            nn.Linear(feature_dim, head_hidden),
            nn.ReLU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(head_hidden, num_classes),
        )

        # Index -> issue type / severity lookups; None when no mapping file.
        self.idx_to_issue_type = None
        self.idx_to_severity = None
        if label_mapping_path and Path(label_mapping_path).exists():
            with open(label_mapping_path) as f:
                image_map = json.load(f)["image"]
            self.idx_to_issue_type = image_map["idx_to_issue_type"]
            self.idx_to_severity = image_map["idx_to_severity"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Backbone features followed by the classification head."""
        return self.head(self.backbone(x))

    def predict_to_schema(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Convert raw logits into the canonical diagnostic schema dict."""
        probs = torch.softmax(logits, dim=-1)
        if logits.dim() == 1:
            probs = probs.unsqueeze(0)
        conf, pred = probs.max(dim=-1)
        if pred.numel() == 1:
            pred = pred.item()
        if conf.numel() == 1:
            conf = conf.item()
        issue_types = self.idx_to_issue_type or ["unknown"] * self.num_classes
        severities = self.idx_to_severity or ["medium"] * self.num_classes
        issue_type = issue_types[pred]
        severity = severities[pred]
        return {
            "result": "normal" if issue_type == "normal" else "issue_detected",
            "issue_type": issue_type,
            "severity": severity,
            "confidence": float(conf),
            "class_idx": int(pred),
        }
|
test.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test script for Electrical Outlets diagnostic pipeline.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python test.py --image path/to/outlet.jpg # Test image only
|
| 6 |
+
python test.py --audio path/to/recording.wav # Test audio only
|
| 7 |
+
python test.py --image photo.jpg --audio recording.wav # Test both (fusion)
|
| 8 |
+
python test.py --list # List sample images from dataset
|
| 9 |
+
python test.py --eval # Run full validation set evaluation
|
| 10 |
+
|
| 11 |
+
Requirements:
|
| 12 |
+
pip install torch torchvision torchaudio Pillow PyYAML soundfile
|
| 13 |
+
"""
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import sys
|
| 16 |
+
import argparse
|
| 17 |
+
import json
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torchvision import transforms
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
ROOT = Path(__file__).resolve().parent
|
| 25 |
+
sys.path.insert(0, str(ROOT))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_image_model(weights_path, mapping_path, device):
    """Rebuild the image model from a checkpoint; return (model, temperature)."""
    from src.models.image_model import ElectricalOutletsImageModel

    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
    # "head.1" is the first Linear of the head: its output width tells us the
    # hidden size this checkpoint was trained with.
    hidden = checkpoint["model_state_dict"]["head.1.weight"].shape[0]
    model = ElectricalOutletsImageModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=Path(mapping_path),
        pretrained=False,
        head_hidden=hidden,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    model.idx_to_severity = checkpoint.get("idx_to_severity")
    model.eval().to(device)
    # Reject implausible calibration temperatures.
    temperature = checkpoint.get("temperature", 1.0)
    if not 0 < temperature <= 10:
        temperature = 1.0
    return model, temperature
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_audio_model(weights_path, mapping_path, device):
    """Rebuild the audio model from a checkpoint; return (model, temperature)."""
    from src.models.audio_model import ElectricalOutletsAudioModel
    import yaml

    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)

    # Spectrogram geometry comes from the training config when available,
    # otherwise falls back to 128 mels x 128 time steps.
    n_mels, time_steps = 128, 128
    audio_cfg_path = ROOT / "config" / "audio_train_config.yaml"
    if audio_cfg_path.exists():
        with open(audio_cfg_path) as f:
            model_cfg = yaml.safe_load(f).get("model", {})
        n_mels = model_cfg.get("n_mels", 128)
        time_steps = model_cfg.get("time_steps", 128)

    model = ElectricalOutletsAudioModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=Path(mapping_path),
        n_mels=n_mels,
        time_steps=time_steps,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.idx_to_label = checkpoint.get("idx_to_label")
    model.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    model.idx_to_severity = checkpoint.get("idx_to_severity")
    model.eval().to(device)
    # Reject implausible calibration temperatures.
    temperature = checkpoint.get("temperature", 1.0)
    if not 0 < temperature <= 10:
        temperature = 1.0
    return model, temperature
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def predict_image(image_path, device="cuda"):
    """Classify a single outlet image and pretty-print the prediction.

    Returns the schema dict from ``predict_to_schema``, or None when the
    image checkpoint is missing.
    """
    weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"

    if not weights.exists():
        print(f"ERROR: Image weights not found at {weights}")
        return None

    model, T = load_image_model(weights, mapping, device)

    # Deterministic eval transform (ImageNet normalization); must match the
    # validation pipeline the checkpoint was trained with.
    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img = Image.open(image_path).convert("RGB")
    x = tf(img).unsqueeze(0).to(device)

    with torch.no_grad():
        # Temperature-scale logits before softmax (calibration).
        logits = model(x) / T
        probs = torch.softmax(logits, dim=-1)

    pred = model.predict_to_schema(logits)

    print(f"\n{'='*55}")
    print(f" IMAGE: {Path(image_path).name}")
    print(f"{'='*55}")
    print(f" Prediction: {pred['issue_type']}")
    print(f" Severity: {pred['severity']}")
    print(f" Confidence: {pred['confidence']:.1%}")
    print(f" Result: {pred['result']}")
    print(f"\n Class probabilities:")
    for i, p in enumerate(probs[0].tolist()):
        name = model.idx_to_issue_type[i] if model.idx_to_issue_type else f"class_{i}"
        # NOTE(review): "β" below looks like a mis-encoded bar/arrow glyph
        # (likely "█"/"←") — confirm against the original file's encoding.
        bar = "β" * int(p * 30)
        tag = " β" if i == pred["class_idx"] else ""
        print(f" {name:20s} {p:6.1%} {bar}{tag}")

    return pred
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def predict_audio(audio_path, device="cuda"):
    """Classify a single outlet audio recording and pretty-print the prediction.

    Returns the schema dict from ``predict_to_schema``, or None when the
    audio checkpoint is missing.
    """
    import torchaudio
    import yaml

    weights = ROOT / "weights" / "electrical_outlets_audio_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"

    if not weights.exists():
        print(f"ERROR: Audio weights not found at {weights}")
        return None

    model, T = load_audio_model(weights, mapping, device)

    # Load audio config; fall back to the same defaults the trainer uses.
    audio_cfg_path = ROOT / "config" / "audio_train_config.yaml"
    sample_rate, n_mels, n_fft, hop, win = 22050, 128, 1024, 512, 1024
    target_sec = 5.0
    if audio_cfg_path.exists():
        with open(audio_cfg_path) as f:
            acfg = yaml.safe_load(f)
        sample_rate = acfg["data"].get("sample_rate", 22050)
        target_sec = acfg["data"].get("target_length_sec", 5.0)
        sc = acfg.get("spectrogram", {})
        n_mels = sc.get("n_mels", 128)
        n_fft = sc.get("n_fft", 1024)
        hop = sc.get("hop_length", 512)
        win = sc.get("win_length", 1024)

    waveform, sr = torchaudio.load(str(audio_path))
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    if waveform.shape[0] > 1:
        # Down-mix multi-channel audio to mono.
        waveform = waveform.mean(dim=0, keepdim=True)

    # Center-crop or zero-pad to the fixed target window length.
    target_len = int(target_sec * sample_rate)
    if waveform.shape[1] >= target_len:
        start = (waveform.shape[1] - target_len) // 2
        waveform = waveform[:, start:start + target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))

    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=n_fft, hop_length=hop,
        win_length=win, n_mels=n_mels,
    )(waveform)
    # Log-compress with a floor to avoid log(0).
    log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)

    with torch.no_grad():
        # Temperature-scale logits before softmax (calibration).
        logits = model(log_mel) / T
        probs = torch.softmax(logits, dim=-1)

    pred = model.predict_to_schema(logits)

    print(f"\n{'='*55}")
    print(f" AUDIO: {Path(audio_path).name}")
    print(f"{'='*55}")
    print(f" Prediction: {pred['issue_type']}")
    print(f" Severity: {pred['severity']}")
    print(f" Confidence: {pred['confidence']:.1%}")
    print(f" Result: {pred['result']}")
    print(f"\n Class probabilities:")
    labels = model.idx_to_label or [f"class_{i}" for i in range(model.num_classes)]
    for i, p in enumerate(probs[0].tolist()):
        # NOTE(review): "β" below looks like a mis-encoded bar/arrow glyph
        # (likely "█"/"←") — confirm against the original file's encoding.
        bar = "β" * int(p * 30)
        tag = " β" if i == pred["class_idx"] else ""
        print(f" {labels[i]:20s} {p:6.1%} {bar}{tag}")

    return pred
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def run_fusion(image_pred, audio_pred):
    """Fuse per-modality schema dicts via ``fuse_modalities`` and print the result.

    Either argument may be None (modality absent). Fusion thresholds come
    from config/thresholds.yaml when present; hard-coded fallbacks otherwise.
    Returns the fused schema dict.
    """
    from src.fusion.fusion_logic import fuse_modalities, ModalityOutput
    import yaml

    thresholds_path = ROOT / "config" / "thresholds.yaml"
    thresholds = {}
    if thresholds_path.exists():
        with open(thresholds_path) as f:
            thresholds = yaml.safe_load(f)

    image_out = ModalityOutput(
        result=image_pred["result"],
        issue_type=image_pred.get("issue_type"),
        severity=image_pred["severity"],
        confidence=image_pred["confidence"],
    ) if image_pred else None

    audio_out = ModalityOutput(
        result=audio_pred["result"],
        issue_type=audio_pred.get("issue_type"),
        severity=audio_pred["severity"],
        confidence=audio_pred["confidence"],
    ) if audio_pred else None

    result = fuse_modalities(
        image_out, audio_out,
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )

    print(f"\n{'='*55}")
    print(f" FUSED RESULT")
    print(f"{'='*55}")
    print(f" Result: {result['result']}")
    print(f" Issue: {result['issue_type']}")
    print(f" Severity: {result['severity']}")
    print(f" Confidence: {result['confidence']:.1%}")
    # Primary/secondary issues are optional keys of the fused schema.
    if result.get("primary_issue"):
        print(f" Primary: {result['primary_issue']}")
    if result.get("secondary_issue"):
        print(f" Secondary: {result['secondary_issue']}")

    return result
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def list_samples():
    """Print a quick inventory of the image and audio datasets on disk."""
    mapping_path = ROOT / "config" / "label_mapping.json"
    with open(mapping_path) as f:
        lm = json.load(f)

    data_root = ROOT / "ELECTRICAL OUTLETS-20260106T153508Z-3-001"
    if not data_root.exists():
        print(f"Dataset not found at {data_root}")
        return

    print(f"\nDataset: {data_root}")
    print(f"{'='*60}")
    folder_to_class = lm["image"]["folder_to_class"]
    for class_dir in sorted(data_root.iterdir()):
        if not class_dir.is_dir():
            continue
        cls = folder_to_class.get(class_dir.name, "UNMAPPED")
        # Collect images with the same extension precedence as the trainer.
        imgs = [p for pattern in ("*.jpg", "*.jpeg", "*.png") for p in class_dir.glob(pattern)]
        print(f"\n {class_dir.name}")
        print(f" β class: {cls} | {len(imgs)} images")
        for img in imgs[:3]:
            print(f" {img}")

    # Audio
    audio_root = ROOT / "electrical_outlets_sounds_100"
    if audio_root.exists():
        print(f"\n\nAudio: {audio_root}")
        print(f"{'='*60}")
        for sound_dir in sorted(audio_root.iterdir()):
            if not sound_dir.is_dir():
                continue
            wavs = list(sound_dir.glob("*.wav"))
            print(f" {sound_dir.name}: {len(wavs)} files")
            for w in wavs[:2]:
                print(f" {w}")
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def run_eval(device="cuda"):
    """Run full evaluation on validation split.

    Loads the image checkpoint, rebuilds the validation dataset from
    config/image_train_config.yaml, and prints overall accuracy, per-class
    recall, and a confusion matrix.
    """
    weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"

    if not weights.exists():
        print("No image weights found.")
        return

    model, T = load_image_model(weights, mapping, device)

    import yaml
    cfg_path = ROOT / "config" / "image_train_config.yaml"
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)

    from src.data.image_dataset import ElectricalOutletsImageDataset
    # Deterministic eval transform; split/seed settings must match training
    # so the validation split is reproduced exactly.
    val_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    data_root = ROOT / cfg["data"]["root"]
    val_ds = ElectricalOutletsImageDataset(
        data_root, mapping, split="val",
        train_ratio=cfg["data"]["train_ratio"],
        val_ratio=cfg["data"]["val_ratio"],
        seed=cfg["data"].get("seed", 42),
        transform=val_tf,
    )

    with open(mapping) as f:
        lm = json.load(f)
    issue_names = lm["image"]["idx_to_issue_type"]

    # Accumulators: overall, per-class, and actual->predicted counts.
    # NOTE(review): assumes the dataset yields integer class labels `y` —
    # confirm against ElectricalOutletsImageDataset.
    correct = 0
    total = 0
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    confusion = defaultdict(lambda: defaultdict(int))

    model.eval()
    with torch.no_grad():
        for i in range(len(val_ds)):
            x, y = val_ds[i]
            # Temperature-scale logits (doesn't change the argmax).
            logits = model(x.unsqueeze(0).to(device)) / T
            pred = logits.argmax(1).item()
            correct += (pred == y)
            total += 1
            class_correct[y] += (pred == y)
            class_total[y] += 1
            confusion[y][pred] += 1

    print(f"\n{'='*55}")
    print(f" VALIDATION RESULTS ({total} samples)")
    print(f"{'='*55}")
    print(f" Overall accuracy: {correct/total:.1%}")
    print(f"\n Per-class recall:")
    for c in sorted(class_total.keys()):
        name = issue_names[c] if c < len(issue_names) else f"class_{c}"
        recall = class_correct[c] / class_total[c] if class_total[c] > 0 else 0
        # NOTE(review): "β" looks like a mis-encoded "█" bar glyph — confirm
        # against the original file's encoding.
        bar = "β" * int(recall * 20)
        print(f" {name:20s} {recall:6.1%} ({class_correct[c]}/{class_total[c]}) {bar}")

    print(f"\n Confusion matrix:")
    classes = sorted(class_total.keys())
    header = " Actual \\ Pred " + "".join(f"{issue_names[c][:8]:>9s}" for c in classes)
    print(header)
    for actual in classes:
        row = f" {issue_names[actual][:14]:14s}"
        for pred_c in classes:
            count = confusion[actual][pred_c]
            # "Β·" placeholder marks empty cells (likely mis-encoded "·").
            row += f" {count:6d}" if count > 0 else f" {'Β·':>6s}"
        row += " "
        print(row)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
if __name__ == "__main__":
    # CLI entry point: dispatch to listing, evaluation, or single-sample
    # prediction depending on the flags supplied.
    parser = argparse.ArgumentParser(description="Test Electrical Outlets Diagnostic Pipeline")
    parser.add_argument("--image", type=str, help="Path to image file")
    parser.add_argument("--audio", type=str, help="Path to audio WAV file")
    parser.add_argument("--list", action="store_true", help="List sample files from dataset")
    parser.add_argument("--eval", action="store_true", help="Run full validation evaluation")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    if args.list:
        list_samples()
    elif args.eval:
        run_eval(args.device)
    elif args.image or args.audio:
        # Run whichever modalities were supplied; fuse only when both ran.
        img_pred = predict_image(args.image, args.device) if args.image else None
        audio_pred = predict_audio(args.audio, args.device) if args.audio else None
        if img_pred and audio_pred:
            run_fusion(img_pred, audio_pred)
        print()
    else:
        # No arguments: print usage help.
        # NOTE(review): "β" in the banner looks like a mis-encoded dash ("—")
        # — confirm against the original file's encoding.
        print("Electrical Outlets Diagnostic Pipeline β Test Script")
        print("=" * 55)
        print()
        print("Usage:")
        print(" python test.py --image path/to/photo.jpg")
        print(" python test.py --audio path/to/recording.wav")
        print(" python test.py --image photo.jpg --audio recording.wav")
        print(" python test.py --list")
        print(" python test.py --eval")
        print()
        print("Examples:")
        print(' python test.py --image "ELECTRICAL OUTLETS-20260106T153508Z-3-001\\Burn marks - overheating 250\\img_001.jpg"')
        print(' python test.py --audio "electrical_outlets_sounds_100\\buzzing_outlet\\buzzing_outlet_060.wav"')
|
test_single_image.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick test: classify a single image.
|
| 3 |
+
python test_single_image.py --image "path/to/image.jpg"
|
| 4 |
+
python test_single_image.py --list
|
| 5 |
+
"""
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import sys
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import torch
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
ROOT = Path(__file__).resolve().parent
|
| 15 |
+
sys.path.insert(0, str(ROOT))
|
| 16 |
+
from src.models.image_model import ElectricalOutletsImageModel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def predict(image_path, weights="weights/electrical_outlets_image_best.pt",
            mapping="config/label_mapping.json", device=None):
    """Classify one outlet photo with the trained image model and print a report.

    Args:
        image_path: path of the photo to classify.
        weights: checkpoint written by training/train_image.py.
        mapping: label-mapping JSON used to build the model's class decoders.
        device: torch device string; auto-detects CUDA when None.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(weights, map_location=device, weights_only=False)
    state = ckpt["model_state_dict"]
    # Recover the head width from the checkpoint so the rebuilt network matches.
    head_hidden = state["head.1.weight"].shape[0]
    model = ElectricalOutletsImageModel(
        num_classes=ckpt["num_classes"],
        label_mapping_path=Path(mapping),
        pretrained=False,
        head_hidden=head_hidden,
    )
    model.load_state_dict(state)
    model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
    model.idx_to_severity = ckpt.get("idx_to_severity")
    model.eval().to(device)

    # Calibration temperature; fall back to 1.0 when missing or implausible.
    T = ckpt.get("temperature", 1.0)
    if not 0 < T <= 10:
        T = 1.0

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    rgb = Image.open(image_path).convert("RGB")
    batch = preprocess(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch) / T
        probs = torch.softmax(logits, dim=-1)
        pred = model.predict_to_schema(logits)

    print(f"\n{'='*50}")
    print(f" {Path(image_path).name}")
    print(f"{'='*50}")
    print(f" -> {pred['issue_type']} ({pred['severity']} severity)")
    print(f" -> {pred['confidence']:.1%} confidence")
    print(f" -> {pred['result']}")
    print()
    # Per-class probability bars (30 chars wide at 100%).
    for idx, prob in enumerate(probs[0].tolist()):
        label = model.idx_to_issue_type[idx]
        # NOTE(review): bar/marker glyphs were mojibake in the source; reconstructed.
        bar = "█" * int(prob * 30)
        marker = " ←" if idx == pred["class_idx"] else ""
        print(f" {label:20s} {prob:6.1%} {bar}{marker}")
    print()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
    # CLI: either list dataset folders with their mapped classes, or classify one image.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image", type=str)
    cli.add_argument("--list", action="store_true")
    cli.add_argument("--weights", default="weights/electrical_outlets_image_best.pt")
    opts = cli.parse_args()

    if opts.list:
        with open("config/label_mapping.json") as handle:
            mapping_cfg = json.load(handle)
        folder_to_class = mapping_cfg["image"]["folder_to_class"]
        dataset_root = Path("ELECTRICAL OUTLETS-20260106T153508Z-3-001")
        for entry in sorted(dataset_root.iterdir()):
            if not entry.is_dir():
                continue
            # Collect samples in the same extension order as before: jpg, jpeg, png.
            samples = []
            for pattern in ("*.jpg", "*.jpeg", "*.png"):
                samples.extend(entry.glob(pattern))
            mapped = folder_to_class.get(entry.name, "UNMAPPED")
            print(f"\n{entry.name} -> {mapped} ({len(samples)} imgs)")
            for sample in samples[:2]:
                print(f" {sample}")
    elif opts.image:
        predict(opts.image, opts.weights)
    else:
        print("python test_single_image.py --image path/to/img.jpg")
        print("python test_single_image.py --list")
|
tests/test_fusion.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for decision-level fusion."""
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 5 |
+
|
| 6 |
+
from src.fusion.fusion_logic import fuse_modalities, ModalityOutput
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_image_only_issue():
    """Image alone reports an issue; fusion must pass its verdict through."""
    image = ModalityOutput("issue_detected", "burn_overheating", "high", 0.9)
    out = fuse_modalities(image_out=image, audio_out=None)
    assert out["issue_type"] == "burn_overheating"
    assert out["severity"] == "high"
    assert out["result"] == "issue_detected"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_both_normal_high_conf():
    """Both modalities agree on 'normal' above the threshold -> fused normal."""
    image = ModalityOutput("normal", "normal", "low", 0.85)
    audio = ModalityOutput("normal", "normal", "low", 0.8)
    out = fuse_modalities(image_out=image, audio_out=audio, confidence_normal_min=0.75)
    assert out["severity"] == "low"
    assert out["result"] == "normal"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_severity_max():
    """Fused severity is the maximum of the two modality severities."""
    fused = fuse_modalities(
        image_out=ModalityOutput("issue_detected", "cracked_faceplate", "medium", 0.88),
        audio_out=ModalityOutput("issue_detected", "arcing_pop", "critical", 0.85),
    )
    assert fused["result"] == "issue_detected"
    assert fused["severity"] == "critical"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_uncertain_low_confidence():
    """An issue below the minimum issue confidence collapses to 'uncertain'."""
    weak = ModalityOutput("issue_detected", "buzzing", "high", 0.5)
    verdict = fuse_modalities(image_out=weak, audio_out=None, confidence_issue_min=0.6)
    assert verdict["result"] == "uncertain"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_uncertain_disagree():
    """Conflicting verdicts below the high-confidence override -> 'uncertain'."""
    verdict = fuse_modalities(
        image_out=ModalityOutput("issue_detected", "burn_overheating", "high", 0.7),
        audio_out=ModalityOutput("normal", "normal", "low", 0.7),
        uncertain_if_disagree=True,
        high_confidence_override=0.92,
    )
    assert verdict["result"] == "uncertain"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_no_input():
    """With no modality outputs at all, fusion is 'uncertain' at zero confidence."""
    verdict = fuse_modalities(None, None)
    assert verdict["confidence"] == 0.0
    assert verdict["result"] == "uncertain"
|
training/train_audio.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Electrical Outlets audio model. Spectrogram CNN, class weights, per-class recall, early stopping.
|
| 3 |
+
"""
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import sys
|
| 6 |
+
import argparse
|
| 7 |
+
from typing import Dict
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
|
| 14 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 15 |
+
sys.path.insert(0, str(ROOT))
|
| 16 |
+
|
| 17 |
+
from src.data.audio_dataset import ElectricalOutletsAudioDataset
|
| 18 |
+
from src.models.audio_model import ElectricalOutletsAudioModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_config(config_path: Path) -> dict:
    """Parse the YAML training configuration at *config_path* into a dict."""
    import yaml  # local import: yaml is only needed by the training entry point
    with open(config_path) as cfg_file:
        parsed = yaml.safe_load(cfg_file)
    return parsed
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _wave_to_mel(waveform: torch.Tensor, n_mels: int, n_fft: int, hop: int, win: int) -> torch.Tensor:
    """Convert a 16 kHz waveform tensor into a log-mel spectrogram.

    The mel energies are clamped at 1e-5 before the log so that silent
    frames do not produce -inf.
    """
    import torchaudio  # local import keeps torchaudio a training-only dependency
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_fft=n_fft, hop_length=hop, win_length=win, n_mels=n_mels,
    )
    return torch.log(mel_transform(waveform).clamp(min=1e-5))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def per_class_recall(logits: torch.Tensor, targets: torch.Tensor, num_classes: int) -> Dict[int, float]:
    """Compute recall for every class index from raw logits.

    Classes with no validation examples get recall 0.0 rather than NaN.

    Args:
        logits: (N, num_classes) raw scores.
        targets: (N,) integer ground-truth labels.
        num_classes: total number of classes to report on.

    Returns:
        Mapping of class index -> recall in [0, 1].
    """
    predicted = logits.argmax(dim=1)
    per_class: Dict[int, float] = {}
    for cls in range(num_classes):
        support = targets == cls
        n_support = int(support.sum())
        if n_support == 0:
            per_class[cls] = 0.0
        else:
            per_class[cls] = float((predicted[support] == cls).sum()) / n_support
    return per_class
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def run_training(
    data_root: Path,
    label_mapping_path: Path,
    config: dict,
    weights_dir: Path,
    device: str = "cuda",
):
    """Train the audio classifier and save the best checkpoint to *weights_dir*.

    Model selection: macro-averaged per-class recall on the validation split;
    early stopping after `training.early_stopping_patience` epochs without
    improvement. Optionally fits a temperature-scaling scalar on a fraction of
    the validation set and stores it in the checkpoint under "temperature".

    Args:
        data_root: root folder of per-class audio subfolders.
        label_mapping_path: JSON mapping of folder names to classes.
        config: parsed YAML training config (data / spectrogram / model /
            training / calibration / output sections).
        weights_dir: directory receiving the best checkpoint.
        device: torch device string.

    Returns:
        dict with "best_epoch", "best_metric", and "recall_per_class" from the
        last validated epoch.
    """
    # --- Data and spectrogram configuration ------------------------------
    train_ratio = config["data"]["train_ratio"]
    val_ratio = config["data"]["val_ratio"]
    seed = config["data"].get("seed", 42)
    batch_size = config["data"]["batch_size"]
    num_workers = config["data"].get("num_workers", 0)
    spec_cfg = config.get("spectrogram", {})
    n_mels = spec_cfg.get("n_mels", 64)
    n_fft = spec_cfg.get("n_fft", 512)
    hop = spec_cfg.get("hop_length", 256)
    win = spec_cfg.get("win_length", 512)

    # Waveform -> log-mel transform applied by the dataset on each item.
    def to_mel(x):
        return _wave_to_mel(x, n_mels, n_fft, hop, win)

    # Same seed/ratios for both splits so they partition the data consistently.
    train_ds = ElectricalOutletsAudioDataset(
        data_root, label_mapping_path, split="train",
        train_ratio=train_ratio, val_ratio=val_ratio, seed=seed, transform=to_mel,
        target_length_sec=config["data"].get("target_length_sec", 5.0),
        sample_rate=config["data"].get("sample_rate", 16000),
    )
    val_ds = ElectricalOutletsAudioDataset(
        data_root, label_mapping_path, split="val",
        train_ratio=train_ratio, val_ratio=val_ratio, seed=seed, transform=to_mel,
        target_length_sec=config["data"].get("target_length_sec", 5.0),
        sample_rate=config["data"].get("sample_rate", 16000),
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    num_classes = train_ds.num_classes
    model = ElectricalOutletsAudioModel(
        num_classes=num_classes,
        label_mapping_path=label_mapping_path,
        n_mels=config["model"].get("n_mels", 64),
        time_steps=config["model"].get("time_steps", 128),
    ).to(device)
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=config["training"]["lr"],
        weight_decay=config["training"].get("weight_decay", 1e-4),
    )
    # NOTE(review): the module docstring mentions class weights, but a plain
    # unweighted CrossEntropyLoss is used here — confirm which is intended.
    criterion = nn.CrossEntropyLoss()
    epochs = config["training"]["epochs"]
    patience = config["training"].get("early_stopping_patience", 12)
    best_metric = -1.0
    best_epoch = 0
    wait = 0  # epochs elapsed since the last improvement
    recall = {}

    for epoch in range(epochs):
        # --- One training pass -------------------------------------------
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            opt.step()

        # --- Validation ---------------------------------------------------
        # Assumes the validation split is non-empty (torch.cat would raise
        # on an empty list otherwise).
        model.eval()
        val_logits, val_targets = [], []
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device)
                val_logits.append(model(x).cpu())
                val_targets.append(y)
        val_logits = torch.cat(val_logits, dim=0)
        val_targets = torch.cat(val_targets, dim=0)
        recall = per_class_recall(val_logits, val_targets, num_classes)
        min_recall = min(recall.values())
        macro_recall = sum(recall.values()) / num_classes
        # Macro recall is the model-selection metric; min recall is only logged.
        metric = macro_recall
        if metric > best_metric:
            best_metric = metric
            best_epoch = epoch
            wait = 0
            weights_dir.mkdir(parents=True, exist_ok=True)
            # Checkpoint carries the label decoders so inference can rebuild
            # human-readable outputs without the mapping file.
            torch.save({
                "model_state_dict": model.state_dict(),
                "num_classes": num_classes,
                "idx_to_label": model.idx_to_label,
                "idx_to_issue_type": model.idx_to_issue_type,
                "idx_to_severity": model.idx_to_severity,
            }, weights_dir / config["output"]["best_name"])
        else:
            wait += 1
        print(f"Epoch {epoch} min_recall={min_recall:.4f} macro_recall={macro_recall:.4f} best={best_metric:.4f}")
        if wait >= patience:
            print("Early stopping at epoch", epoch)
            break

    # --- Optional temperature scaling on part of the validation set -------
    if config.get("calibration", {}).get("use_temperature_scaling", False):
        # Reload the best weights (the in-memory model may be past its best epoch).
        model.load_state_dict(torch.load(weights_dir / config["output"]["best_name"], map_location=device)["model_state_dict"])
        model.eval()
        n_val = len(val_ds)
        cal_size = max(1, int(n_val * config["calibration"].get("val_fraction_for_calibration", 0.5)))
        cal_logits, cal_targets = [], []
        for i in range(cal_size):
            x, y = val_ds[i]
            x = x.unsqueeze(0).to(device)
            with torch.no_grad():
                cal_logits.append(model(x).cpu())
            cal_targets.append(y)
        cal_logits = torch.cat(cal_logits, dim=0)
        cal_targets = torch.tensor(cal_targets)
        # Fit a single temperature T minimizing NLL of logits / T (LBFGS closure).
        temp = nn.Parameter(torch.ones(1) * 1.5)
        opt_cal = torch.optim.LBFGS([temp], lr=0.01, max_iter=50)
        def eval_cal():
            opt_cal.zero_grad()
            loss = F.cross_entropy(cal_logits / temp, cal_targets)
            loss.backward()
            return loss
        opt_cal.step(eval_cal)
        # Persist the temperature alongside the best weights.
        ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location="cpu")
        ckpt["temperature"] = temp.item()
        torch.save(ckpt, weights_dir / config["output"]["best_name"])

    return {"best_epoch": best_epoch, "best_metric": best_metric, "recall_per_class": recall}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def main():
    """CLI entry: parse args, train the audio model, and write a markdown report."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/audio_train_config.yaml")
    parser.add_argument("--data_root", default=None)
    parser.add_argument("--weights_dir", default="weights")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    repo_root = Path(__file__).resolve().parent.parent
    config = load_config(repo_root / args.config)
    data_root = Path(args.data_root) if args.data_root else repo_root / config["data"]["root"]
    label_mapping_path = repo_root / config["data"]["label_mapping"]
    weights_dir = repo_root / args.weights_dir

    results = run_training(data_root, label_mapping_path, config, weights_dir, args.device)

    # Assemble the markdown report, then write it in one shot.
    report_path = repo_root / "docs" / config["output"]["report_name"]
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_lines = [
        "# Audio Model Report (Electrical Outlets)\n\n",
        "- **Preliminary model.** 100 samples is very small; recommend collecting more data.\n",
        f"- Best epoch: {results['best_epoch']}, best metric: {results['best_metric']:.4f}\n\n",
        "## Per-class recall (validation)\n\n",
    ]
    report_lines.extend(
        f"- Class {c}: {r:.4f}\n" for c, r in results.get("recall_per_class", {}).items()
    )
    report_lines.append("\n## Limitations\n- Small dataset; use audio as support in fusion. Do not rely on audio-only for critical decisions.\n")
    with open(report_path, "w") as f:
        f.writelines(report_lines)
    print("Report written to", report_path)


if __name__ == "__main__":
    main()
|
training/train_image.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Electrical Outlets image model.
|
| 3 |
+
FINAL v5: Frozen backbone β partial unfreeze. 5 classes, 1300 images.
|
| 4 |
+
"""
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import sys
|
| 7 |
+
import argparse
|
| 8 |
+
from typing import Dict
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
|
| 16 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 17 |
+
sys.path.insert(0, str(ROOT))
|
| 18 |
+
|
| 19 |
+
from src.data.image_dataset import ElectricalOutletsImageDataset, get_image_class_weights
|
| 20 |
+
from src.models.image_model import ElectricalOutletsImageModel
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_config(path):
    """Load the YAML config file at *path* and return its contents as a dict."""
    import yaml  # local import: yaml only matters when training is invoked
    with open(path) as handle:
        return yaml.safe_load(handle)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def focal_loss(logits, targets, alpha=0.25, gamma=2.0, weight=None):
    """Focal loss (Lin et al., 2017) built on per-sample cross-entropy.

    Down-weights easy examples via (1 - pt)**gamma, where pt is the model's
    probability of the true class; *weight* is forwarded to cross_entropy as
    per-class weights.
    """
    per_sample_ce = F.cross_entropy(logits, targets, weight=weight, reduction="none")
    pt = (-per_sample_ce).exp()  # probability assigned to the true class
    focal = alpha * (1.0 - pt) ** gamma * per_sample_ce
    return focal.mean()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def per_class_recall(logits, targets, num_classes):
    """Recall per class index from argmax predictions; 0.0 for absent classes."""
    preds = logits.argmax(dim=1)
    out = {}
    for c in range(num_classes):
        selected = targets == c
        if selected.any():
            out[c] = (preds[selected] == c).float().mean().item()
        else:
            out[c] = 0.0
    return out
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def run_training(data_root, label_mapping_path, config, weights_dir, device="cuda"):
    """Two-stage training of the image classifier; saves the best checkpoint.

    Stage 1 trains only the classification head on a frozen backbone with
    OneCycleLR. Stage 2 (optional) unfreezes the last two backbone blocks and
    fine-tunes at a low LR with gradient clipping. Model selection uses either
    min or macro per-class recall (config `training.early_stopping_metric`).
    Optionally fits a temperature-scaling scalar afterwards.

    Args:
        data_root: root folder of per-class image subfolders.
        label_mapping_path: JSON mapping folders to classes.
        config: parsed YAML config (data / augmentation / model / training /
            calibration / output sections).
        weights_dir: directory receiving the best checkpoint.
        device: torch device string.

    Returns:
        dict with "best_epoch" (stage-2 epochs offset by +1000),
        "best_metric", and "recall_per_class" of the last validated epoch.
    """
    cfg_data = config["data"]
    cfg_train = config["training"]
    cfg_aug = config["augmentation"]
    cfg_model = config["model"]

    # Transforms: heavy augmentation for training, deterministic crop for val.
    train_tf = transforms.Compose([
        transforms.Resize(cfg_aug["resize"]),
        transforms.RandomResizedCrop(cfg_aug["crop"], scale=(0.65, 1.0)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet stats
        transforms.RandomErasing(p=0.15),
    ])
    val_tf = transforms.Compose([
        transforms.Resize(cfg_aug["resize"]),
        transforms.CenterCrop(cfg_aug["crop"]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Datasets: same seed/ratios on both splits so they partition consistently.
    train_ds = ElectricalOutletsImageDataset(
        data_root, label_mapping_path, split="train",
        train_ratio=cfg_data["train_ratio"], val_ratio=cfg_data["val_ratio"],
        seed=cfg_data.get("seed", 42), transform=train_tf,
    )
    val_ds = ElectricalOutletsImageDataset(
        data_root, label_mapping_path, split="val",
        train_ratio=cfg_data["train_ratio"], val_ratio=cfg_data["val_ratio"],
        seed=cfg_data.get("seed", 42), transform=val_tf,
    )
    train_loader = DataLoader(train_ds, batch_size=cfg_data["batch_size"], shuffle=True,
                              num_workers=cfg_data.get("num_workers", 4), pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=cfg_data["batch_size"], shuffle=False,
                            num_workers=cfg_data.get("num_workers", 4))

    num_classes = train_ds.num_classes
    print(f"\nTrain: {len(train_ds)}, Val: {len(val_ds)}, Classes: {num_classes}")

    # Class weights to counter class imbalance (used by both loss options).
    class_weights = None
    if cfg_train.get("use_class_weights", True):
        class_weights = get_image_class_weights(label_mapping_path, data_root).to(device)
        print(f"Class weights: {[f'{w:.3f}' for w in class_weights.tolist()]}")

    use_focal = cfg_train.get("use_focal", True)
    criterion_ce = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

    # Model
    model = ElectricalOutletsImageModel(
        num_classes=num_classes,
        label_mapping_path=label_mapping_path,
        pretrained=True,
        head_hidden=cfg_model.get("head_hidden", 256),
        head_dropout=cfg_model.get("head_dropout", 0.4),
    ).to(device)

    # ----------------------------------------------------------------------
    # STAGE 1: Frozen backbone — train head only
    # ----------------------------------------------------------------------
    for p in model.backbone.parameters():
        p.requires_grad = False

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {trainable:,} trainable / {total_params:,} total ({100*trainable/total_params:.1f}%)")

    epochs = cfg_train["epochs"]
    patience = cfg_train.get("early_stopping_patience", 20)
    lr = cfg_train.get("lr", 3e-3)

    # Optimizer sees only the (trainable) head parameters in stage 1.
    opt = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr, weight_decay=cfg_train.get("weight_decay", 1e-3),
    )
    # OneCycleLR is stepped per batch, so it needs steps_per_epoch.
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=lr, epochs=epochs,
        steps_per_epoch=len(train_loader), pct_start=0.15,
    )

    print(f"\n{'='*60}")
    print(f" Stage 1: Frozen backbone, lr={lr}, {epochs} epochs max")
    print(f"{'='*60}")

    best_metric = -1.0
    best_epoch = 0
    wait = 0  # epochs since last improvement
    recall = {}

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = focal_loss(logits, y, weight=class_weights) if use_focal else criterion_ce(logits, y)
            loss.backward()
            opt.step()
            sched.step()  # per-batch step, as OneCycleLR expects
            epoch_loss += loss.item()

        # Validate (assumes a non-empty validation split; torch.cat would fail otherwise)
        model.eval()
        vl, vt = [], []
        with torch.no_grad():
            for x, y in val_loader:
                vl.append(model(x.to(device)).cpu())
                vt.append(y)
        vl, vt = torch.cat(vl), torch.cat(vt)
        recall = per_class_recall(vl, vt, num_classes)
        min_r = min(recall.values())
        macro_r = sum(recall.values()) / num_classes
        val_acc = (vl.argmax(1) == vt).float().mean().item()
        # Selection metric is configurable: worst-class or macro recall.
        metric = min_r if cfg_train.get("early_stopping_metric") == "val_min_recall" else macro_r

        star = ""
        if metric > best_metric:
            best_metric = metric
            best_epoch = epoch
            wait = 0
            weights_dir.mkdir(parents=True, exist_ok=True)
            # Checkpoint carries the label decoders for standalone inference.
            torch.save({
                "model_state_dict": model.state_dict(),
                "num_classes": num_classes,
                "idx_to_issue_type": model.idx_to_issue_type,
                "idx_to_severity": model.idx_to_severity,
            }, weights_dir / config["output"]["best_name"])
            star = " ✅"  # NOTE(review): marker glyph was mojibake in source; reconstructed
        else:
            wait += 1

        print(f"E{epoch:3d} loss={epoch_loss/len(train_loader):.4f} acc={val_acc:.3f} "
              f"min_r={min_r:.3f} macro={macro_r:.3f} best={best_metric:.3f}@{best_epoch}{star}")

        if wait >= patience:
            print(f"Early stop @ {epoch}")
            break

    # ----------------------------------------------------------------------
    # STAGE 2: Unfreeze last 2 backbone blocks (only if stage 1 learned anything)
    # ----------------------------------------------------------------------
    if cfg_train.get("finetune_last_blocks", True) and best_metric > 0.15:
        print(f"\n{'='*60}")
        print(f" Stage 2: Partial unfreeze (last 2 blocks)")
        print(f"{'='*60}")

        # Resume from the stage-1 best, not from the possibly-overfit last epoch.
        ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])

        for p in model.backbone.parameters():
            p.requires_grad = False
        # Unfreeze only the deepest two feature blocks of the backbone.
        for name, p in model.backbone.named_parameters():
            if "features.7" in name or "features.8" in name:
                p.requires_grad = True
        # Head stays trainable
        for p in model.head.parameters():
            p.requires_grad = True

        ft_lr = cfg_train.get("finetune_lr", 5e-5)
        ft_epochs = cfg_train.get("finetune_epochs", 25)
        opt2 = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=ft_lr, weight_decay=1e-3,
        )
        # Cosine decay, stepped once per epoch (unlike stage 1's per-batch step).
        sched2 = torch.optim.lr_scheduler.CosineAnnealingLR(opt2, T_max=ft_epochs, eta_min=1e-6)
        wait2 = 0

        for epoch in range(ft_epochs):
            model.train()
            el = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                opt2.zero_grad()
                logits = model(x)
                loss = focal_loss(logits, y, weight=class_weights) if use_focal else criterion_ce(logits, y)
                loss.backward()
                # Clip to stabilize fine-tuning of the unfrozen backbone blocks.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt2.step()
                el += loss.item()
            sched2.step()

            model.eval()
            vl, vt = [], []
            with torch.no_grad():
                for x, y in val_loader:
                    vl.append(model(x.to(device)).cpu())
                    vt.append(y)
            vl, vt = torch.cat(vl), torch.cat(vt)
            recall = per_class_recall(vl, vt, num_classes)
            min_r = min(recall.values())
            macro_r = sum(recall.values()) / num_classes
            val_acc = (vl.argmax(1) == vt).float().mean().item()
            metric = min_r if cfg_train.get("early_stopping_metric") == "val_min_recall" else macro_r

            star = ""
            if metric > best_metric:
                best_metric = metric
                best_epoch = epoch + 1000  # +1000 marks "best came from stage 2"
                wait2 = 0
                torch.save({
                    "model_state_dict": model.state_dict(),
                    "num_classes": num_classes,
                    "idx_to_issue_type": model.idx_to_issue_type,
                    "idx_to_severity": model.idx_to_severity,
                }, weights_dir / config["output"]["best_name"])
                star = " ✅"  # NOTE(review): marker glyph was mojibake in source; reconstructed
            else:
                wait2 += 1

            print(f" FT{epoch:3d} loss={el/len(train_loader):.4f} acc={val_acc:.3f} "
                  f"min_r={min_r:.3f} macro={macro_r:.3f} best={best_metric:.3f}{star}")
            if wait2 >= 10:
                print(f" FT early stop @ {epoch}")
                break

    # Temperature scaling (optional): fit T minimizing NLL of logits / T on
    # the first half of the validation set, then store it in the checkpoint.
    if config.get("calibration", {}).get("use_temperature_scaling", False):
        ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        model.eval()
        cal_size = max(1, int(len(val_ds) * 0.5))
        cl, ct = [], []
        for i in range(cal_size):
            x, y = val_ds[i]
            with torch.no_grad():
                cl.append(model(x.unsqueeze(0).to(device)).cpu())
            ct.append(y)
        cl, ct = torch.cat(cl), torch.tensor(ct)
        temp = nn.Parameter(torch.ones(1) * 1.5)
        opt_c = torch.optim.LBFGS([temp], lr=0.01, max_iter=50)
        def eval_c():
            opt_c.zero_grad()
            l = F.cross_entropy(cl / temp, ct)
            l.backward()
            return l
        opt_c.step(eval_c)
        ckpt["temperature"] = temp.item()
        torch.save(ckpt, weights_dir / config["output"]["best_name"])
        print(f"Temperature T={temp.item():.4f}")

    print(f"\n{'='*60}")
    print(f" DONE — Best: {best_metric:.4f}")  # NOTE(review): dash was mojibake in source; reconstructed
    per_cls = " | ".join([f"C{c}={r:.2f}" for c, r in recall.items()])
    print(f" Recall: {per_cls}")
    print(f"{'='*60}\n")

    return {"best_epoch": best_epoch, "best_metric": best_metric, "recall_per_class": recall}
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def main():
    """CLI entry: train the image model, then write a markdown report to docs/."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/image_train_config.yaml")
    parser.add_argument("--data_root", default=None)
    parser.add_argument("--weights_dir", default="weights")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    config = load_config(ROOT / args.config)
    data_root = Path(args.data_root) if args.data_root else ROOT / config["data"]["root"]
    label_mapping_path = ROOT / config["data"]["label_mapping"]
    weights_dir = ROOT / args.weights_dir

    results = run_training(data_root, label_mapping_path, config, weights_dir, args.device)

    # Write the markdown report next to the other docs.
    report_path = ROOT / "docs" / config["output"]["report_name"]
    report_path.parent.mkdir(parents=True, exist_ok=True)
    issue_names = ["burn_overheating", "cracked_faceplate", "loose_outlet", "normal", "water_exposed"]
    with open(report_path, "w") as f:
        f.write("# Image Model Report (Electrical Outlets)\n\n")
        f.write(f"- Best metric: {results['best_metric']:.4f}\n")
        f.write(f"- Classes: 5 (burn, cracked, loose, normal, water)\n\n")
        f.write("## Per-class recall\n\n")
        for cls_idx, cls_recall in results.get("recall_per_class", {}).items():
            label = issue_names[cls_idx] if cls_idx < len(issue_names) else f"class_{cls_idx}"
            f.write(f"- {label}: {cls_recall:.4f}\n")
    print("Report:", report_path)


if __name__ == "__main__":
    main()
|
weights/.gitkeep
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|