"""
Inference wrapper: load image + audio models, run modalities present, apply fusion, return schema.
"""
from pathlib import Path
from typing import Optional, Dict, Any, BinaryIO
import io
import json
import logging

import torch
import torchaudio
from torchvision import transforms
from PIL import Image

# Optional imports for models
import sys

# The project-local ``src.*`` modules are imported lazily inside the functions
# below; putting the repository root on sys.path makes those imports resolve.
ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(ROOT))

_logger = logging.getLogger(__name__)

# Checkpoint file names looked up inside ``weights_dir``.
_IMAGE_WEIGHTS_NAME = "electrical_outlets_image_best.pt"
_AUDIO_WEIGHTS_NAME = "electrical_outlets_audio_best.pt"

# Audio front-end settings used to build the log-mel input.
_SAMPLE_RATE = 16000
_CLIP_SECONDS = 5.0
_N_FFT = 512
_HOP_LENGTH = 256
_N_MELS = 64
# NOTE(review): a 5 s clip at hop 256 yields roughly 313 mel frames, while the
# audio model is constructed with time_steps=128 -- confirm the model copes
# with the longer time axis (e.g. pools or crops internally).
_TIME_STEPS = 128
_LOG_MEL_FLOOR = 1e-5  # lower clamp applied before taking the log


def _load_checkpoint(weights_path: Path, device: str):
    """Load a checkpoint dict from ``weights_path`` onto ``device``."""
    # SECURITY: torch.load unpickles the file, so it can execute arbitrary
    # code. Only point this at checkpoints from a trusted source.
    return torch.load(weights_path, map_location=device)


def _load_image_model(weights_path: Path, label_mapping_path: Path, device: str):
    """
    Build the image model from a checkpoint.

    The checkpoint must contain ``num_classes`` and ``model_state_dict``;
    ``idx_to_issue_type``, ``idx_to_severity`` and ``temperature`` are optional.

    Returns:
        ``(model, temperature)`` with the model in eval mode on ``device``;
        ``temperature`` falls back to 1.0 when the checkpoint has none.
    """
    from src.models.image_model import ElectricalOutletsImageModel

    ckpt = _load_checkpoint(weights_path, device)
    model = ElectricalOutletsImageModel(
        num_classes=ckpt["num_classes"],
        label_mapping_path=label_mapping_path,
        pretrained=False,
    )
    model.load_state_dict(ckpt["model_state_dict"])
    # Index -> label lookups stored with the checkpoint (None when absent).
    model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
    model.idx_to_severity = ckpt.get("idx_to_severity")
    model.eval()
    return model.to(device), ckpt.get("temperature", 1.0)


def _load_audio_model(weights_path: Path, label_mapping_path: Path, device: str, config: dict):
    """
    Build the audio model from a checkpoint.

    ``config`` may supply ``n_mels`` (default 64) and ``time_steps``
    (default 128), which are forwarded to the model constructor.

    Returns:
        ``(model, temperature)`` with the model in eval mode on ``device``;
        ``temperature`` falls back to 1.0 when the checkpoint has none.
    """
    from src.models.audio_model import ElectricalOutletsAudioModel

    ckpt = _load_checkpoint(weights_path, device)
    model = ElectricalOutletsAudioModel(
        num_classes=ckpt["num_classes"],
        label_mapping_path=label_mapping_path,
        n_mels=config.get("n_mels", 64),
        time_steps=config.get("time_steps", 128),
    )
    model.load_state_dict(ckpt["model_state_dict"])
    # Index -> label lookups stored with the checkpoint (None when absent).
    model.idx_to_label = ckpt.get("idx_to_label")
    model.idx_to_issue_type = ckpt.get("idx_to_issue_type")
    model.idx_to_severity = ckpt.get("idx_to_severity")
    model.eval()
    return model.to(device), ckpt.get("temperature", 1.0)


def _to_modality_output(pred: dict):
    """Wrap a ``predict_to_schema`` result dict in a fusion ``ModalityOutput``."""
    from src.fusion.fusion_logic import ModalityOutput

    return ModalityOutput(
        result=pred["result"],
        issue_type=pred.get("issue_type"),  # optional key in the prediction
        severity=pred["severity"],
        confidence=pred["confidence"],
    )


def _run_image_model(image_source, weights_path: Path, label_mapping_path: Path, device: str):
    """Preprocess one image (path or file object), run the image model, return its ModalityOutput."""
    img = Image.open(image_source).convert("RGB")
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    batch = preprocess(img).unsqueeze(0).to(device)  # add a leading batch axis

    model, temperature = _load_image_model(weights_path, label_mapping_path, device)
    with torch.no_grad():
        # Logits are divided by the checkpoint's temperature before decoding.
        logits = model(batch) / temperature
    return _to_modality_output(model.predict_to_schema(logits))


def _run_audio_model(audio_path, audio_fp, weights_path: Path, label_mapping_path: Path, device: str):
    """Load one clip, convert it to a fixed-length log-mel input, run the audio model."""
    if audio_path:
        waveform, sr = torchaudio.load(str(audio_path))
    else:
        # Buffer the stream so torchaudio receives a seekable file object.
        waveform, sr = torchaudio.load(io.BytesIO(audio_fp.read()))

    if sr != _SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, _SAMPLE_RATE)
    if waveform.shape[0] > 1:
        # Down-mix multi-channel audio to a single channel.
        waveform = waveform.mean(dim=0, keepdim=True)

    # Force a fixed clip length: centre-crop long clips, right-pad short ones.
    target_len = int(_CLIP_SECONDS * _SAMPLE_RATE)
    num_samples = waveform.shape[1]
    if num_samples >= target_len:
        start = (num_samples - target_len) // 2
        waveform = waveform[:, start : start + target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - num_samples))

    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=_SAMPLE_RATE,
        n_fft=_N_FFT,
        hop_length=_HOP_LENGTH,
        win_length=_N_FFT,
        n_mels=_N_MELS,
    )(waveform)
    # Clamp before the log so silent bins do not produce -inf.
    log_mel = torch.log(mel.clamp(min=_LOG_MEL_FLOOR)).unsqueeze(0).to(device)

    model, temperature = _load_audio_model(
        weights_path,
        label_mapping_path,
        device,
        {"n_mels": _N_MELS, "time_steps": _TIME_STEPS},
    )
    with torch.no_grad():
        # Logits are divided by the checkpoint's temperature before decoding.
        logits = model(log_mel) / temperature
    return _to_modality_output(model.predict_to_schema(logits))


def run_electrical_outlets_inference(
    image_path: Optional[Path] = None,
    image_fp: Optional[BinaryIO] = None,
    audio_path: Optional[Path] = None,
    audio_fp: Optional[BinaryIO] = None,
    weights_dir: Optional[Path] = None,
    config_dir: Optional[Path] = None,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run image and/or audio model, then fuse. Returns canonical schema dict.

    Args:
        image_path: Image file to classify; takes precedence over ``image_fp``.
        image_fp: Open binary stream containing an image.
        audio_path: Audio file to classify; takes precedence over ``audio_fp``.
        audio_fp: Open binary stream containing audio; it is read to the end.
        weights_dir: Directory holding the ``*_best.pt`` checkpoints.
            Defaults to ``ROOT / "weights"``.
        config_dir: Directory holding ``label_mapping.json`` and
            ``thresholds.yaml``. Defaults to ``ROOT / "config"``.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        Whatever ``fuse_modalities`` returns for the modalities that ran.
        A modality that was not supplied is passed to fusion as ``None``.

    Raises:
        FileNotFoundError: if ``thresholds.yaml`` is missing from ``config_dir``.

    Notes:
        Audio input is skipped (with a logged warning) when the audio
        checkpoint file does not exist. Models are re-loaded from disk on
        every call, so long-running callers may want to cache around this
        function.
    """
    weights_dir = ROOT / "weights" if weights_dir is None else Path(weights_dir)
    config_dir = ROOT / "config" if config_dir is None else Path(config_dir)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    label_mapping_path = config_dir / "label_mapping.json"
    thresholds_path = config_dir / "thresholds.yaml"

    import yaml

    with open(thresholds_path, encoding="utf-8") as f:
        # safe_load returns None for an empty document; fall back to an empty
        # mapping so the defaults below apply instead of raising.
        thresholds = yaml.safe_load(f) or {}

    image_out = None
    if image_path or image_fp:
        image_out = _run_image_model(
            image_path or image_fp,
            weights_dir / _IMAGE_WEIGHTS_NAME,
            label_mapping_path,
            device,
        )

    audio_out = None
    if audio_path or audio_fp:
        audio_weights = weights_dir / _AUDIO_WEIGHTS_NAME
        if audio_weights.exists():
            audio_out = _run_audio_model(
                audio_path, audio_fp, audio_weights, label_mapping_path, device
            )
        else:
            # Best-effort: keep going with the image modality only, but make
            # the skipped audio visible to operators.
            _logger.warning(
                "Audio input supplied but %s not found; skipping audio modality",
                audio_weights,
            )

    from src.fusion.fusion_logic import fuse_modalities

    return fuse_modalities(
        image_out,
        audio_out,
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )