# Asadrizvi64 — Electrical Outlets diagnostic pipeline v1.0 (upload 5666923)
"""
Test script for Electrical Outlets diagnostic pipeline.
Usage:
python test.py --image path/to/outlet.jpg # Test image only
python test.py --audio path/to/recording.wav # Test audio only
python test.py --image photo.jpg --audio recording.wav # Test both (fusion)
python test.py --list # List sample images from dataset
python test.py --eval # Run full validation set evaluation
Requirements:
pip install torch torchvision torchaudio Pillow PyYAML soundfile
"""
from pathlib import Path
import sys
import argparse
import json
from collections import defaultdict
import torch
from torchvision import transforms
from PIL import Image
# Repo root (directory containing this script); prepend it to sys.path so
# the local `src.*` packages resolve regardless of the current working dir.
ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(ROOT))
def load_image_model(weights_path, mapping_path, device):
    """Load the calibrated image classifier from a training checkpoint.

    Args:
        weights_path: Path to the image model ``.pt`` checkpoint.
        mapping_path: Path to ``label_mapping.json`` used by the model.
        device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        Tuple ``(model, T)`` — the model in eval mode on ``device`` and the
        softmax temperature (reset to 1.0 if the stored value is unusable).
    """
    from src.models.image_model import ElectricalOutletsImageModel
    # weights_only=False: the checkpoint carries Python metadata
    # (num_classes, label maps, temperature) alongside the tensors.
    ckpt = torch.load(weights_path, map_location=device, weights_only=False)
    # Infer head_hidden from saved weights (head.1 is the first Linear)
    head_hidden = ckpt["model_state_dict"]["head.1.weight"].shape[0]
    model = ElectricalOutletsImageModel(
        num_classes=ckpt["num_classes"],
        label_mapping_path=Path(mapping_path),
        pretrained=False,  # weights come from the checkpoint, not torchvision
        head_hidden=head_hidden,
    )
    model.load_state_dict(ckpt["model_state_dict"])
    model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
    model.idx_to_severity = ckpt.get("idx_to_severity")
    model.eval().to(device)
    # Coerce to float: temperature may have been saved as a 0-d tensor,
    # which would otherwise leak tensor arithmetic into callers.
    T = float(ckpt.get("temperature", 1.0))
    # Clamp bad temperature values
    if T <= 0 or T > 10:
        T = 1.0
    return model, T
def load_audio_model(weights_path, mapping_path, device):
    """Load the calibrated audio classifier from a training checkpoint.

    Args:
        weights_path: Path to the audio model ``.pt`` checkpoint.
        mapping_path: Path to ``label_mapping.json`` used by the model.
        device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        Tuple ``(model, T)`` — the model in eval mode on ``device`` and the
        softmax temperature (reset to 1.0 if the stored value is unusable).
    """
    from src.models.audio_model import ElectricalOutletsAudioModel
    import yaml
    # weights_only=False: checkpoint stores Python metadata with the tensors.
    ckpt = torch.load(weights_path, map_location=device, weights_only=False)
    # Load audio config for the spectrogram geometry; fall back to 128x128.
    audio_cfg_path = ROOT / "config" / "audio_train_config.yaml"
    n_mels, time_steps = 128, 128
    if audio_cfg_path.exists():
        with open(audio_cfg_path) as f:
            acfg = yaml.safe_load(f)
        n_mels = acfg.get("model", {}).get("n_mels", 128)
        time_steps = acfg.get("model", {}).get("time_steps", 128)
    model = ElectricalOutletsAudioModel(
        num_classes=ckpt["num_classes"],
        label_mapping_path=Path(mapping_path),
        n_mels=n_mels,
        time_steps=time_steps,
    )
    model.load_state_dict(ckpt["model_state_dict"])
    model.idx_to_label = ckpt.get("idx_to_label")
    model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
    model.idx_to_severity = ckpt.get("idx_to_severity")
    model.eval().to(device)
    # Coerce to float in case temperature was saved as a 0-d tensor,
    # then clamp clearly-bad values (mirrors load_image_model).
    T = float(ckpt.get("temperature", 1.0))
    if T <= 0 or T > 10:
        T = 1.0
    return model, T
def predict_image(image_path, device="cuda"):
    """Classify a single outlet photo and print a per-class breakdown.

    Args:
        image_path: Path to the image file to score.
        device: Torch device string (default ``"cuda"``).

    Returns:
        The schema dict from ``model.predict_to_schema``, or ``None`` when
        the weights or the image file are missing.
    """
    weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"
    if not weights.exists():
        print(f"ERROR: Image weights not found at {weights}")
        return None
    # Guard the input path too, so a typo prints an error instead of raising.
    if not Path(image_path).exists():
        print(f"ERROR: Image file not found at {image_path}")
        return None
    model, T = load_image_model(weights, mapping, device)
    # Standard ImageNet eval preprocessing (resize/center-crop/normalize);
    # must match the validation transform used during training.
    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img = Image.open(image_path).convert("RGB")
    x = tf(img).unsqueeze(0).to(device)
    with torch.no_grad():
        # Temperature-scaled logits give calibrated confidences.
        logits = model(x) / T
        probs = torch.softmax(logits, dim=-1)
        pred = model.predict_to_schema(logits)
    print(f"\n{'='*55}")
    print(f" IMAGE: {Path(image_path).name}")
    print(f"{'='*55}")
    print(f" Prediction: {pred['issue_type']}")
    print(f" Severity: {pred['severity']}")
    print(f" Confidence: {pred['confidence']:.1%}")
    print(f" Result: {pred['result']}")
    print(f"\n Class probabilities:")
    for i, p in enumerate(probs[0].tolist()):
        name = model.idx_to_issue_type[i] if model.idx_to_issue_type else f"class_{i}"
        bar = "█" * int(p * 30)
        tag = " ◄" if i == pred["class_idx"] else ""
        print(f" {name:20s} {p:6.1%} {bar}{tag}")
    return pred
def predict_audio(audio_path, device="cuda"):
    """Classify a single outlet recording and print a per-class breakdown.

    The waveform is resampled, mixed to mono, center-cropped or zero-padded
    to the configured duration, converted to a log-mel spectrogram, and
    scored with the calibrated audio model.

    Args:
        audio_path: Path to the WAV file to score.
        device: Torch device string (default ``"cuda"``).

    Returns:
        The schema dict from ``model.predict_to_schema``, or ``None`` when
        the weights or the audio file are missing.
    """
    import torchaudio
    import yaml
    weights = ROOT / "weights" / "electrical_outlets_audio_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"
    if not weights.exists():
        print(f"ERROR: Audio weights not found at {weights}")
        return None
    # Guard the input path too, so a typo prints an error instead of raising.
    if not Path(audio_path).exists():
        print(f"ERROR: Audio file not found at {audio_path}")
        return None
    model, T = load_audio_model(weights, mapping, device)
    # Load audio config; fall back to the defaults below when absent.
    audio_cfg_path = ROOT / "config" / "audio_train_config.yaml"
    sample_rate, n_mels, n_fft, hop, win = 22050, 128, 1024, 512, 1024
    target_sec = 5.0
    if audio_cfg_path.exists():
        with open(audio_cfg_path) as f:
            acfg = yaml.safe_load(f)
        sample_rate = acfg["data"].get("sample_rate", 22050)
        target_sec = acfg["data"].get("target_length_sec", 5.0)
        sc = acfg.get("spectrogram", {})
        n_mels = sc.get("n_mels", 128)
        n_fft = sc.get("n_fft", 1024)
        hop = sc.get("hop_length", 512)
        win = sc.get("win_length", 1024)
    waveform, sr = torchaudio.load(str(audio_path))
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    if waveform.shape[0] > 1:
        # Mix multi-channel recordings down to mono.
        waveform = waveform.mean(dim=0, keepdim=True)
    target_len = int(target_sec * sample_rate)
    if waveform.shape[1] >= target_len:
        # Center-crop long clips ...
        start = (waveform.shape[1] - target_len) // 2
        waveform = waveform[:, start:start + target_len]
    else:
        # ... and right-pad short ones with silence.
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=n_fft, hop_length=hop,
        win_length=win, n_mels=n_mels,
    )(waveform)
    # Clamp before log to avoid -inf on silent frames; add the batch dim.
    log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)
    with torch.no_grad():
        # Temperature-scaled logits give calibrated confidences.
        logits = model(log_mel) / T
        probs = torch.softmax(logits, dim=-1)
        pred = model.predict_to_schema(logits)
    print(f"\n{'='*55}")
    print(f" AUDIO: {Path(audio_path).name}")
    print(f"{'='*55}")
    print(f" Prediction: {pred['issue_type']}")
    print(f" Severity: {pred['severity']}")
    print(f" Confidence: {pred['confidence']:.1%}")
    print(f" Result: {pred['result']}")
    print(f"\n Class probabilities:")
    labels = model.idx_to_label or [f"class_{i}" for i in range(model.num_classes)]
    for i, p in enumerate(probs[0].tolist()):
        bar = "█" * int(p * 30)
        tag = " ◄" if i == pred["class_idx"] else ""
        print(f" {labels[i]:20s} {p:6.1%} {bar}{tag}")
    return pred
def run_fusion(image_pred, audio_pred):
    """Fuse the per-modality predictions into one verdict and print it.

    Either argument may be None/empty; a missing modality is passed through
    as None to the fusion logic. Returns the fused result dict.
    """
    from src.fusion.fusion_logic import fuse_modalities, ModalityOutput
    import yaml

    # Thresholds file is optional; the .get() defaults below apply otherwise.
    thresholds = {}
    thresholds_path = ROOT / "config" / "thresholds.yaml"
    if thresholds_path.exists():
        with open(thresholds_path) as f:
            thresholds = yaml.safe_load(f)

    def as_modality(pred):
        # Wrap a prediction dict as a ModalityOutput; falsy stays None.
        if not pred:
            return None
        return ModalityOutput(
            result=pred["result"],
            issue_type=pred.get("issue_type"),
            severity=pred["severity"],
            confidence=pred["confidence"],
        )

    result = fuse_modalities(
        as_modality(image_pred), as_modality(audio_pred),
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )

    banner = "=" * 55
    print(f"\n{banner}")
    print(" FUSED RESULT")
    print(banner)
    print(f" Result: {result['result']}")
    print(f" Issue: {result['issue_type']}")
    print(f" Severity: {result['severity']}")
    print(f" Confidence: {result['confidence']:.1%}")
    for key, label in (("primary_issue", "Primary"), ("secondary_issue", "Secondary")):
        if result.get(key):
            print(f" {label}: {result[key]}")
    return result
def list_samples():
    """Print an overview of the local image and audio datasets.

    Lists each class folder, its mapped class name, file counts, and a few
    example paths the user can copy into ``--image`` / ``--audio``.
    Returns None; prints an error and exits early if required files are
    missing.
    """
    mapping_path = ROOT / "config" / "label_mapping.json"
    # Guard the mapping file the same way the dataset root is guarded below,
    # instead of raising FileNotFoundError.
    if not mapping_path.exists():
        print(f"Label mapping not found at {mapping_path}")
        return
    with open(mapping_path) as f:
        lm = json.load(f)
    data_root = ROOT / "ELECTRICAL OUTLETS-20260106T153508Z-3-001"
    if not data_root.exists():
        print(f"Dataset not found at {data_root}")
        return
    print(f"\nDataset: {data_root}")
    print(f"{'='*60}")
    for folder in sorted(data_root.iterdir()):
        if not folder.is_dir():
            continue
        cls = lm["image"]["folder_to_class"].get(folder.name, "UNMAPPED")
        imgs = list(folder.glob("*.jpg")) + list(folder.glob("*.jpeg")) + list(folder.glob("*.png"))
        print(f"\n {folder.name}")
        print(f" → class: {cls} | {len(imgs)} images")
        for img in imgs[:3]:
            print(f" {img}")
    # Audio samples live in a separate flat folder tree.
    audio_root = ROOT / "electrical_outlets_sounds_100"
    if audio_root.exists():
        print(f"\n\nAudio: {audio_root}")
        print(f"{'='*60}")
        for folder in sorted(audio_root.iterdir()):
            if folder.is_dir():
                wavs = list(folder.glob("*.wav"))
                print(f" {folder.name}: {len(wavs)} files")
                for w in wavs[:2]:
                    print(f" {w}")
def run_eval(device="cuda"):
    """Run full evaluation on validation split.

    Loads the calibrated image model, scores every validation sample one at
    a time, and prints overall accuracy, per-class recall, and a confusion
    matrix. Returns None; exits early if no weights are present.
    """
    weights = ROOT / "weights" / "electrical_outlets_image_best.pt"
    mapping = ROOT / "config" / "label_mapping.json"
    if not weights.exists():
        print("No image weights found.")
        return
    model, T = load_image_model(weights, mapping, device)
    import yaml
    cfg_path = ROOT / "config" / "image_train_config.yaml"
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)
    from src.data.image_dataset import ElectricalOutletsImageDataset
    # Same eval-time preprocessing as predict_image (ImageNet normalization).
    val_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    data_root = ROOT / cfg["data"]["root"]
    # Reuse the training config's split ratios/seed so "val" here is the
    # same partition the model was validated against during training.
    val_ds = ElectricalOutletsImageDataset(
        data_root, mapping, split="val",
        train_ratio=cfg["data"]["train_ratio"],
        val_ratio=cfg["data"]["val_ratio"],
        seed=cfg["data"].get("seed", 42),
        transform=val_tf,
    )
    with open(mapping) as f:
        lm = json.load(f)
    issue_names = lm["image"]["idx_to_issue_type"]
    correct = 0
    total = 0
    # Per-class tallies; confusion[actual][pred] counts each (truth, guess) pair.
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    confusion = defaultdict(lambda: defaultdict(int))
    model.eval()
    with torch.no_grad():
        # Score one sample at a time (no DataLoader needed at this scale).
        for i in range(len(val_ds)):
            x, y = val_ds[i]
            logits = model(x.unsqueeze(0).to(device)) / T
            pred = logits.argmax(1).item()
            correct += (pred == y)
            total += 1
            class_correct[y] += (pred == y)
            class_total[y] += 1
            confusion[y][pred] += 1
    print(f"\n{'='*55}")
    print(f" VALIDATION RESULTS ({total} samples)")
    print(f"{'='*55}")
    print(f" Overall accuracy: {correct/total:.1%}")
    print(f"\n Per-class recall:")
    for c in sorted(class_total.keys()):
        # Fall back to a generic name if the mapping is shorter than expected.
        name = issue_names[c] if c < len(issue_names) else f"class_{c}"
        recall = class_correct[c] / class_total[c] if class_total[c] > 0 else 0
        bar = "█" * int(recall * 20)
        print(f" {name:20s} {recall:6.1%} ({class_correct[c]}/{class_total[c]}) {bar}")
    print(f"\n Confusion matrix:")
    classes = sorted(class_total.keys())
    # Column headers are truncated to 8 chars to keep the table narrow.
    header = " Actual \\ Pred " + "".join(f"{issue_names[c][:8]:>9s}" for c in classes)
    print(header)
    for actual in classes:
        row = f" {issue_names[actual][:14]:14s}"
        for pred_c in classes:
            count = confusion[actual][pred_c]
            # Print a middle dot for zero cells so non-zero counts stand out.
            row += f" {count:6d}" if count > 0 else f" {'·':>6s}"
            row += " "
        print(row)
if __name__ == "__main__":
    # CLI entry point: dispatch to listing, evaluation, or single-sample
    # prediction (image, audio, or both fused) based on the flags given.
    parser = argparse.ArgumentParser(description="Test Electrical Outlets Diagnostic Pipeline")
    parser.add_argument("--image", type=str, help="Path to image file")
    parser.add_argument("--audio", type=str, help="Path to audio WAV file")
    parser.add_argument("--list", action="store_true", help="List sample files from dataset")
    parser.add_argument("--eval", action="store_true", help="Run full validation evaluation")
    # Default to GPU when available; every predict/eval helper accepts this.
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()
    if args.list:
        list_samples()
    elif args.eval:
        run_eval(args.device)
    elif args.image or args.audio:
        # Run whichever modalities were supplied; fuse only when both ran.
        img_pred = predict_image(args.image, args.device) if args.image else None
        audio_pred = predict_audio(args.audio, args.device) if args.audio else None
        if img_pred and audio_pred:
            run_fusion(img_pred, audio_pred)
        print()
    else:
        # No flags: print usage help with copy-pasteable examples.
        print("Electrical Outlets Diagnostic Pipeline — Test Script")
        print("=" * 55)
        print()
        print("Usage:")
        print(" python test.py --image path/to/photo.jpg")
        print(" python test.py --audio path/to/recording.wav")
        print(" python test.py --image photo.jpg --audio recording.wav")
        print(" python test.py --list")
        print(" python test.py --eval")
        print()
        print("Examples:")
        print(' python test.py --image "ELECTRICAL OUTLETS-20260106T153508Z-3-001\\Burn marks - overheating 250\\img_001.jpg"')
        print(' python test.py --audio "electrical_outlets_sounds_100\\buzzing_outlet\\buzzing_outlet_060.wav"')