| """ |
| Test script for Electrical Outlets diagnostic pipeline. |
| |
| Usage: |
| python test.py --image path/to/outlet.jpg # Test image only |
| python test.py --audio path/to/recording.wav # Test audio only |
| python test.py --image photo.jpg --audio recording.wav # Test both (fusion) |
| python test.py --list # List sample images from dataset |
| python test.py --eval # Run full validation set evaluation |
| |
| Requirements: |
| pip install torch torchvision torchaudio Pillow PyYAML soundfile |
| """ |
| from pathlib import Path |
| import sys |
| import argparse |
| import json |
| from collections import defaultdict |
|
|
| import torch |
| from torchvision import transforms |
| from PIL import Image |
|
|
| ROOT = Path(__file__).resolve().parent |
| sys.path.insert(0, str(ROOT)) |
|
|
|
|
def load_image_model(weights_path, mapping_path, device):
    """Restore the trained image classifier from a checkpoint.

    Args:
        weights_path: Path to the ``.pt`` checkpoint file.
        mapping_path: Path to the label-mapping JSON passed to the model.
        device: Torch device string the model is moved to.

    Returns:
        ``(model, temperature)`` — the model in eval mode on *device*, and the
        calibration temperature from the checkpoint (reset to 1.0 when the
        stored value falls outside the plausible (0, 10] range).
    """
    from src.models.image_model import ElectricalOutletsImageModel

    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)

    # Recover the classifier head's hidden width from the saved weights so the
    # rebuilt architecture matches the checkpoint exactly.
    hidden_width = checkpoint["model_state_dict"]["head.1.weight"].shape[0]
    model = ElectricalOutletsImageModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=Path(mapping_path),
        pretrained=False,
        head_hidden=hidden_width,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    model.idx_to_severity = checkpoint.get("idx_to_severity")
    model.eval().to(device)

    temperature = checkpoint.get("temperature", 1.0)
    if not (0 < temperature <= 10):
        temperature = 1.0
    return model, temperature
|
|
|
|
def load_audio_model(weights_path, mapping_path, device):
    """Restore the trained audio classifier from a checkpoint.

    Reads ``config/audio_train_config.yaml`` (when present) so the rebuilt
    network's mel/time dimensions match the ones used in training; otherwise
    falls back to 128x128.

    Returns:
        ``(model, temperature)`` — the model in eval mode on *device*, and the
        calibration temperature from the checkpoint (reset to 1.0 when the
        stored value falls outside the plausible (0, 10] range).
    """
    from src.models.audio_model import ElectricalOutletsAudioModel
    import yaml

    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)

    mel_bins, frame_count = 128, 128
    config_file = ROOT / "config" / "audio_train_config.yaml"
    if config_file.exists():
        with open(config_file) as f:
            config = yaml.safe_load(f)
        model_section = config.get("model", {})
        mel_bins = model_section.get("n_mels", 128)
        frame_count = model_section.get("time_steps", 128)

    model = ElectricalOutletsAudioModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=Path(mapping_path),
        n_mels=mel_bins,
        time_steps=frame_count,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.idx_to_label = checkpoint.get("idx_to_label")
    model.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    model.idx_to_severity = checkpoint.get("idx_to_severity")
    model.eval().to(device)

    temperature = checkpoint.get("temperature", 1.0)
    if not (0 < temperature <= 10):
        temperature = 1.0
    return model, temperature
|
|
|
|
def predict_image(image_path, device="cuda"):
    """Classify a single outlet photo and print a per-class breakdown.

    Args:
        image_path: Path to the image file.
        device: Torch device string (defaults to "cuda"; the CLI passes a
            CPU fallback when CUDA is unavailable).

    Returns:
        The schema dict from ``model.predict_to_schema`` (issue_type, severity,
        confidence, result, class_idx), or None when the weights or the input
        image are missing.
    """
    weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"

    if not weights.exists():
        print(f"ERROR: Image weights not found at {weights}")
        return None

    # Fail fast with a clear message instead of an unhandled
    # FileNotFoundError from PIL further down (mirrors the weights check).
    if not Path(image_path).exists():
        print(f"ERROR: Image file not found at {image_path}")
        return None

    model, T = load_image_model(weights, mapping, device)

    # Standard ImageNet eval preprocessing — must match training normalization.
    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img = Image.open(image_path).convert("RGB")
    x = tf(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(x) / T  # temperature-scaled for calibrated confidences
        probs = torch.softmax(logits, dim=-1)

    pred = model.predict_to_schema(logits)

    print(f"\n{'='*55}")
    print(f" IMAGE: {Path(image_path).name}")
    print(f"{'='*55}")
    print(f" Prediction: {pred['issue_type']}")
    print(f" Severity: {pred['severity']}")
    print(f" Confidence: {pred['confidence']:.1%}")
    print(f" Result: {pred['result']}")
    print(f"\n Class probabilities:")
    for i, p in enumerate(probs[0].tolist()):
        name = model.idx_to_issue_type[i] if model.idx_to_issue_type else f"class_{i}"
        bar = "█" * int(p * 30)
        tag = " ◄" if i == pred["class_idx"] else ""
        print(f" {name:20s} {p:6.1%} {bar}{tag}")

    return pred
|
|
|
|
def predict_audio(audio_path, device="cuda"):
    """Classify a single outlet sound recording and print a per-class breakdown.

    Loads the audio, resamples/downmixes to the training configuration,
    center-crops or zero-pads to the target duration, converts to a log-mel
    spectrogram, and runs the audio model.

    Args:
        audio_path: Path to the WAV file.
        device: Torch device string.

    Returns:
        The schema dict from ``model.predict_to_schema``, or None when the
        weights or the input file are missing.
    """
    import torchaudio
    import yaml

    weights = ROOT / "weights" / "electrical_outlets_audio_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"

    if not weights.exists():
        print(f"ERROR: Audio weights not found at {weights}")
        return None

    # Fail fast with a clear message instead of a raw torchaudio load error
    # (mirrors the weights check above).
    if not Path(audio_path).exists():
        print(f"ERROR: Audio file not found at {audio_path}")
        return None

    model, T = load_audio_model(weights, mapping, device)

    # Preprocessing must mirror training; fall back to the training defaults
    # when the config file is absent or partial.
    audio_cfg_path = ROOT / "config" / "audio_train_config.yaml"
    sample_rate, n_mels, n_fft, hop, win = 22050, 128, 1024, 512, 1024
    target_sec = 5.0
    if audio_cfg_path.exists():
        with open(audio_cfg_path) as f:
            acfg = yaml.safe_load(f)
        # Use .get with defaults throughout (the original indexed
        # acfg["data"] directly and would KeyError on a partial config).
        data_cfg = acfg.get("data", {})
        sample_rate = data_cfg.get("sample_rate", 22050)
        target_sec = data_cfg.get("target_length_sec", 5.0)
        sc = acfg.get("spectrogram", {})
        n_mels = sc.get("n_mels", 128)
        n_fft = sc.get("n_fft", 1024)
        hop = sc.get("hop_length", 512)
        win = sc.get("win_length", 1024)

    waveform, sr = torchaudio.load(str(audio_path))
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    if waveform.shape[0] > 1:  # downmix multi-channel audio to mono
        waveform = waveform.mean(dim=0, keepdim=True)

    # Center-crop long clips / right-pad short ones to a fixed length.
    target_len = int(target_sec * sample_rate)
    if waveform.shape[1] >= target_len:
        start = (waveform.shape[1] - target_len) // 2
        waveform = waveform[:, start:start + target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))

    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=n_fft, hop_length=hop,
        win_length=win, n_mels=n_mels,
    )(waveform)
    # Clamp before log to avoid -inf on silent frames.
    log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(log_mel) / T  # temperature-scaled for calibrated confidences
        probs = torch.softmax(logits, dim=-1)

    pred = model.predict_to_schema(logits)

    print(f"\n{'='*55}")
    print(f" AUDIO: {Path(audio_path).name}")
    print(f"{'='*55}")
    print(f" Prediction: {pred['issue_type']}")
    print(f" Severity: {pred['severity']}")
    print(f" Confidence: {pred['confidence']:.1%}")
    print(f" Result: {pred['result']}")
    print(f"\n Class probabilities:")
    labels = model.idx_to_label or [f"class_{i}" for i in range(model.num_classes)]
    for i, p in enumerate(probs[0].tolist()):
        # Guard against a stored label list shorter than the probability vector.
        name = labels[i] if i < len(labels) else f"class_{i}"
        bar = "█" * int(p * 30)
        tag = " ◄" if i == pred["class_idx"] else ""
        print(f" {name:20s} {p:6.1%} {bar}{tag}")

    return pred
|
|
|
|
def run_fusion(image_pred, audio_pred):
    """Fuse the image and audio predictions into one diagnosis and print it.

    Args:
        image_pred: Schema dict from ``predict_image`` (or a falsy value).
        audio_pred: Schema dict from ``predict_audio`` (or a falsy value).

    Returns:
        The fused-result dict produced by ``fuse_modalities``.
    """
    from src.fusion.fusion_logic import fuse_modalities, ModalityOutput
    import yaml

    # Fusion thresholds are optional; defaults below apply when absent.
    thresholds = {}
    thresholds_path = ROOT / "config" / "thresholds.yaml"
    if thresholds_path.exists():
        with open(thresholds_path) as f:
            thresholds = yaml.safe_load(f)

    def as_modality(pred):
        # Wrap a schema dict as a ModalityOutput; falsy input passes through as None.
        if not pred:
            return None
        return ModalityOutput(
            result=pred["result"],
            issue_type=pred.get("issue_type"),
            severity=pred["severity"],
            confidence=pred["confidence"],
        )

    result = fuse_modalities(
        as_modality(image_pred), as_modality(audio_pred),
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )

    print(f"\n{'='*55}")
    print(f" FUSED RESULT")
    print(f"{'='*55}")
    print(f" Result: {result['result']}")
    print(f" Issue: {result['issue_type']}")
    print(f" Severity: {result['severity']}")
    print(f" Confidence: {result['confidence']:.1%}")
    if result.get("primary_issue"):
        print(f" Primary: {result['primary_issue']}")
    if result.get("secondary_issue"):
        print(f" Secondary: {result['secondary_issue']}")

    return result
|
|
|
|
def list_samples():
    """Print an overview of the image and audio datasets next to this script.

    For each image class folder: the mapped class name, the image count, and
    up to three example paths. For each audio folder: the WAV count and up to
    two example paths. Prints a message and returns early when required files
    are missing.
    """
    mapping_path = ROOT / "config" / "label_mapping.json"
    # Handle a missing mapping gracefully (the original crashed with a raw
    # FileNotFoundError here while a missing dataset was handled politely).
    if not mapping_path.exists():
        print(f"Label mapping not found at {mapping_path}")
        return
    with open(mapping_path) as f:
        lm = json.load(f)

    data_root = ROOT / "ELECTRICAL OUTLETS-20260106T153508Z-3-001"
    if not data_root.exists():
        print(f"Dataset not found at {data_root}")
        return

    folder_to_class = lm["image"]["folder_to_class"]
    print(f"\nDataset: {data_root}")
    print(f"{'='*60}")
    for folder in sorted(data_root.iterdir()):
        if not folder.is_dir():
            continue
        cls = folder_to_class.get(folder.name, "UNMAPPED")
        imgs = list(folder.glob("*.jpg")) + list(folder.glob("*.jpeg")) + list(folder.glob("*.png"))
        print(f"\n {folder.name}")
        print(f" → class: {cls} | {len(imgs)} images")
        for img in imgs[:3]:
            print(f" {img}")

    # Audio samples are optional — only listed when the folder exists.
    audio_root = ROOT / "electrical_outlets_sounds_100"
    if audio_root.exists():
        print(f"\n\nAudio: {audio_root}")
        print(f"{'='*60}")
        for folder in sorted(audio_root.iterdir()):
            if folder.is_dir():
                wavs = list(folder.glob("*.wav"))
                print(f" {folder.name}: {len(wavs)} files")
                for w in wavs[:2]:
                    print(f" {w}")
|
|
|
|
def run_eval(device="cuda"):
    """Run full evaluation on the validation split of the image dataset.

    Prints overall accuracy, per-class recall, and a confusion matrix for the
    image model, using the split parameters from the training config so the
    evaluation set matches training.

    Args:
        device: Torch device string.
    """
    weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"

    if not weights.exists():
        print("No image weights found.")
        return

    model, T = load_image_model(weights, mapping, device)

    import yaml
    cfg_path = ROOT / "config" / "image_train_config.yaml"
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)

    from src.data.image_dataset import ElectricalOutletsImageDataset
    # Must match the eval-time preprocessing used in predict_image().
    val_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    data_root = ROOT / cfg["data"]["root"]
    val_ds = ElectricalOutletsImageDataset(
        data_root, mapping, split="val",
        train_ratio=cfg["data"]["train_ratio"],
        val_ratio=cfg["data"]["val_ratio"],
        seed=cfg["data"].get("seed", 42),
        transform=val_tf,
    )

    with open(mapping) as f:
        lm = json.load(f)
    issue_names = lm["image"]["idx_to_issue_type"]

    def class_name(c):
        # Fall back to a generic name when the mapping is shorter than the
        # class index (previously only the recall table had this guard).
        return issue_names[c] if c < len(issue_names) else f"class_{c}"

    correct = 0
    total = 0
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    confusion = defaultdict(lambda: defaultdict(int))

    model.eval()
    with torch.no_grad():
        for i in range(len(val_ds)):
            x, y = val_ds[i]
            logits = model(x.unsqueeze(0).to(device)) / T  # temperature-scaled
            pred = logits.argmax(1).item()
            correct += (pred == y)
            total += 1
            class_correct[y] += (pred == y)
            class_total[y] += 1
            confusion[y][pred] += 1

    # Guard against an empty split (the original divided by zero here).
    if total == 0:
        print("Validation split is empty — nothing to evaluate.")
        return

    print(f"\n{'='*55}")
    print(f" VALIDATION RESULTS ({total} samples)")
    print(f"{'='*55}")
    print(f" Overall accuracy: {correct/total:.1%}")
    print(f"\n Per-class recall:")
    for c in sorted(class_total.keys()):
        recall = class_correct[c] / class_total[c] if class_total[c] > 0 else 0
        bar = "█" * int(recall * 20)
        print(f" {class_name(c):20s} {recall:6.1%} ({class_correct[c]}/{class_total[c]}) {bar}")

    print(f"\n Confusion matrix:")
    classes = sorted(class_total.keys())
    header = " Actual \\ Pred " + "".join(f"{class_name(c)[:8]:>9s}" for c in classes)
    print(header)
    for actual in classes:
        row = f" {class_name(actual)[:14]:14s}"
        for pred_c in classes:
            count = confusion[actual][pred_c]
            row += f" {count:6d}" if count > 0 else f" {'·':>6s}"
        row += " "
        print(row)
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: dispatch to listing, evaluation, single-sample
    # prediction (with optional fusion), or the usage banner.
    parser = argparse.ArgumentParser(description="Test Electrical Outlets Diagnostic Pipeline")
    parser.add_argument("--image", type=str, help="Path to image file")
    parser.add_argument("--audio", type=str, help="Path to audio WAV file")
    parser.add_argument("--list", action="store_true", help="List sample files from dataset")
    parser.add_argument("--eval", action="store_true", help="Run full validation evaluation")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    cli = parser.parse_args()

    if cli.list:
        list_samples()
    elif cli.eval:
        run_eval(cli.device)
    elif cli.image or cli.audio:
        image_result = predict_image(cli.image, cli.device) if cli.image else None
        audio_result = predict_audio(cli.audio, cli.device) if cli.audio else None
        # Fusion only makes sense when both modalities produced a prediction.
        if image_result and audio_result:
            run_fusion(image_result, audio_result)
        print()
    else:
        banner = [
            "Electrical Outlets Diagnostic Pipeline — Test Script",
            "=" * 55,
            "",
            "Usage:",
            " python test.py --image path/to/photo.jpg",
            " python test.py --audio path/to/recording.wav",
            " python test.py --image photo.jpg --audio recording.wav",
            " python test.py --list",
            " python test.py --eval",
            "",
            "Examples:",
            ' python test.py --image "ELECTRICAL OUTLETS-20260106T153508Z-3-001\\Burn marks - overheating 250\\img_001.jpg"',
            ' python test.py --audio "electrical_outlets_sounds_100\\buzzing_outlet\\buzzing_outlet_060.wav"',
        ]
        print("\n".join(banner))
|