# Author: Asadrizvi64
# Electrical Outlets diagnostic pipeline v1.0
# Commit: 5666923
"""
Inference wrapper: load image + audio models, run modalities present, apply fusion, return schema.
"""
from pathlib import Path
from typing import Optional, Dict, Any, BinaryIO
import json
import torch
import torchaudio
from torchvision import transforms
from PIL import Image
# Make the repository root importable; the `src.*` model/fusion modules are imported lazily inside the functions below.
import sys
# Repository root: three directory levels above this (resolved) file.
_this_file = Path(__file__).resolve()
ROOT = _this_file.parent.parent.parent
del _this_file
# Prepend the root so the lazily imported `src.*` packages can be resolved.
sys.path.insert(0, f"{ROOT}")
def _load_image_model(weights_path: Path, label_mapping_path: Path, device: str):
    """Rebuild the image classifier from a saved checkpoint.

    Args:
        weights_path: Checkpoint file written by training; must contain
            ``num_classes`` and ``model_state_dict`` entries.
        label_mapping_path: Label-mapping JSON forwarded to the model constructor.
        device: Torch device string used both for ``map_location`` and ``.to``.

    Returns:
        ``(model, temperature)``: the model in eval mode on ``device``, and the
        checkpoint's ``temperature`` entry (``1.0`` when absent).
    """
    from src.models.image_model import ElectricalOutletsImageModel

    checkpoint = torch.load(weights_path, map_location=device)
    net = ElectricalOutletsImageModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=label_mapping_path,
        pretrained=False,  # weights come from the checkpoint, not a hub download
    )
    net.load_state_dict(checkpoint["model_state_dict"])
    # Index->label decoding tables are stored alongside the weights; a missing
    # key leaves the attribute set to None.
    for attr_name in ("idx_to_issue_type", "idx_to_severity"):
        setattr(net, attr_name, checkpoint.get(attr_name))
    net.eval()
    temperature = checkpoint.get("temperature", 1.0)
    return net.to(device), temperature
def _load_audio_model(weights_path: Path, label_mapping_path: Path, device: str, config: dict):
    """Rebuild the audio classifier from a saved checkpoint.

    Args:
        weights_path: Checkpoint file; must contain ``num_classes`` and
            ``model_state_dict`` entries.
        label_mapping_path: Label-mapping JSON forwarded to the model constructor.
        device: Torch device string used both for ``map_location`` and ``.to``.
        config: Input-geometry overrides; reads ``n_mels`` (default 64) and
            ``time_steps`` (default 128).

    Returns:
        ``(model, temperature)``: the model in eval mode on ``device``, and the
        checkpoint's ``temperature`` entry (``1.0`` when absent).
    """
    from src.models.audio_model import ElectricalOutletsAudioModel

    checkpoint = torch.load(weights_path, map_location=device)
    num_classes = checkpoint["num_classes"]
    n_mels = config.get("n_mels", 64)
    time_steps = config.get("time_steps", 128)
    net = ElectricalOutletsAudioModel(
        num_classes=num_classes,
        label_mapping_path=label_mapping_path,
        n_mels=n_mels,
        time_steps=time_steps,
    )
    net.load_state_dict(checkpoint["model_state_dict"])
    # Decoding tables travel with the checkpoint; absent keys become None.
    for attr_name in ("idx_to_label", "idx_to_issue_type", "idx_to_severity"):
        setattr(net, attr_name, checkpoint.get(attr_name))
    net.eval()
    temperature = checkpoint.get("temperature", 1.0)
    return net.to(device), temperature
_IMAGE_WEIGHTS_NAME = "electrical_outlets_image_best.pt"
_AUDIO_WEIGHTS_NAME = "electrical_outlets_audio_best.pt"
_AUDIO_SAMPLE_RATE = 16000  # Hz; inputs at other rates are resampled to this
_AUDIO_CLIP_SECONDS = 5.0  # clips are center-cropped / right-padded to this length
_AUDIO_N_MELS = 64
# NOTE(review): with hop_length=256 and torchaudio's default center=True, a
# 5 s clip at 16 kHz yields 1 + 80000 // 256 = 313 mel frames, not 128 —
# confirm the audio model handles the longer time axis (or that training
# used this exact preprocessing).
_AUDIO_TIME_STEPS = 128


def _load_thresholds(thresholds_path: Path) -> Dict[str, Any]:
    """Read the fusion thresholds YAML, returning ``{}`` for an empty file."""
    import yaml

    with open(thresholds_path, encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    # safe_load returns None for an empty document; an empty dict lets the
    # .get(...) defaults at fusion time take effect instead of crashing.
    return loaded or {}


def _preprocess_image(source: Any) -> torch.Tensor:
    """Decode an image (path or binary file object) into a (1, 3, 224, 224) normalized batch."""
    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # The context manager releases the decoder's file handle deterministically;
    # Pillow does not close caller-supplied file objects on exit.
    with Image.open(source) as im:
        rgb = im.convert("RGB")
    return tf(rgb).unsqueeze(0)


def _preprocess_audio(audio_path: Optional[Path], audio_fp: Optional[BinaryIO]) -> torch.Tensor:
    """Load audio (path preferred over file object) and return a batched log-mel spectrogram."""
    if audio_path:
        waveform, sr = torchaudio.load(str(audio_path))
    else:
        import io

        waveform, sr = torchaudio.load(io.BytesIO(audio_fp.read()))
    if sr != _AUDIO_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, _AUDIO_SAMPLE_RATE)
    if waveform.shape[0] > 1:
        # Down-mix multi-channel audio to mono.
        waveform = waveform.mean(dim=0, keepdim=True)
    target_len = int(_AUDIO_CLIP_SECONDS * _AUDIO_SAMPLE_RATE)
    n_samples = waveform.shape[1]
    if n_samples >= target_len:
        start = (n_samples - target_len) // 2
        waveform = waveform[:, start:start + target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - n_samples))
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=_AUDIO_SAMPLE_RATE, n_fft=512, hop_length=256, win_length=512, n_mels=_AUDIO_N_MELS,
    )(waveform)
    # Clamp before log to avoid -inf on silent frames; add a leading batch dim.
    return torch.log(mel.clamp(min=1e-5)).unsqueeze(0)


def _to_modality_output(pred: Dict[str, Any]):
    """Wrap a model's ``predict_to_schema`` dict in a fusion ``ModalityOutput``."""
    from src.fusion.fusion_logic import ModalityOutput

    return ModalityOutput(
        result=pred["result"],
        issue_type=pred.get("issue_type"),
        severity=pred["severity"],
        confidence=pred["confidence"],
    )


def run_electrical_outlets_inference(
    image_path: Optional[Path] = None,
    image_fp: Optional[BinaryIO] = None,
    audio_path: Optional[Path] = None,
    audio_fp: Optional[BinaryIO] = None,
    weights_dir: Optional[Path] = None,
    config_dir: Optional[Path] = None,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """Run the image and/or audio model on the supplied inputs and fuse the results.

    Args:
        image_path, image_fp: Image to classify, as a path or an open binary
            file object. The path is used when both are given. When neither is
            given, the image modality is skipped.
        audio_path, audio_fp: Audio clip, as a path or an open binary file
            object. The path is used when both are given. The audio modality is
            skipped (with a RuntimeWarning) if the audio checkpoint is missing.
        weights_dir: Directory holding the ``*_best.pt`` checkpoints. Defaults
            to ``ROOT / "weights"``. ``str`` paths are accepted.
        config_dir: Directory holding ``label_mapping.json`` and
            ``thresholds.yaml``. Defaults to ``ROOT / "config"``.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        The canonical schema dict produced by ``fuse_modalities``.

    Raises:
        FileNotFoundError: if ``thresholds.yaml`` is missing, or if an image is
            supplied but the image checkpoint is missing.
    """
    weights_dir = ROOT / "weights" if weights_dir is None else Path(weights_dir)
    config_dir = ROOT / "config" if config_dir is None else Path(config_dir)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    label_mapping_path = config_dir / "label_mapping.json"
    # Load thresholds first so a missing config fails before any model work.
    thresholds = _load_thresholds(config_dir / "thresholds.yaml")

    image_out = None
    if image_path or image_fp:
        x = _preprocess_image(image_path or image_fp).to(device)
        model, T = _load_image_model(weights_dir / _IMAGE_WEIGHTS_NAME, label_mapping_path, device)
        with torch.no_grad():
            logits = model(x) / T  # temperature-scale the logits
            pred = model.predict_to_schema(logits)
        image_out = _to_modality_output(pred)

    audio_out = None
    audio_weights = weights_dir / _AUDIO_WEIGHTS_NAME
    if audio_path or audio_fp:
        if audio_weights.exists():
            log_mel = _preprocess_audio(audio_path, audio_fp).to(device)
            model, T = _load_audio_model(
                audio_weights,
                label_mapping_path,
                device,
                {"n_mels": _AUDIO_N_MELS, "time_steps": _AUDIO_TIME_STEPS},
            )
            with torch.no_grad():
                logits = model(log_mel) / T  # temperature-scale the logits
                pred = model.predict_to_schema(logits)
            audio_out = _to_modality_output(pred)
        else:
            # The audio model is optional, but tell the caller their input was dropped.
            import warnings

            warnings.warn(
                f"Audio input ignored: checkpoint not found at {audio_weights}",
                RuntimeWarning,
                stacklevel=2,
            )

    from src.fusion.fusion_logic import fuse_modalities

    return fuse_modalities(
        image_out,
        audio_out,
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )