| | """ |
| | Inference wrapper: load image + audio models, run modalities present, apply fusion, return schema. |
| | """ |
| | from pathlib import Path |
| | from typing import Optional, Dict, Any, BinaryIO |
| | import json |
| | import torch |
| | import torchaudio |
| | from torchvision import transforms |
| | from PIL import Image |
| |
|
| | |
# Make the repository root importable so the ``src.*`` absolute imports used
# below (models, fusion logic) resolve even when this module is executed
# outside an installed package (e.g. from a script or notebook).
import sys
ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(ROOT))
| |
|
| |
|
def _load_image_model(weights_path: Path, label_mapping_path: Path, device: str):
    """Restore the image classifier from a checkpoint file.

    The checkpoint dict is expected to carry ``num_classes`` and
    ``model_state_dict``; label-index maps and a calibration temperature
    are optional and default to ``None`` / ``1.0`` respectively.

    Returns:
        A ``(model, temperature)`` pair, with the model in eval mode on
        *device*.

    NOTE(review): ``torch.load`` unpickles arbitrary objects — only point
    this at trusted checkpoint files.
    """
    from src.models.image_model import ElectricalOutletsImageModel

    checkpoint = torch.load(weights_path, map_location=device)
    net = ElectricalOutletsImageModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=label_mapping_path,
        pretrained=False,
    )
    net.load_state_dict(checkpoint["model_state_dict"])
    # Re-attach the label maps captured at training time, if present.
    net.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    net.idx_to_severity = checkpoint.get("idx_to_severity")
    net.eval()
    return net.to(device), checkpoint.get("temperature", 1.0)
| |
|
| |
|
def _load_audio_model(weights_path: Path, label_mapping_path: Path, device: str, config: dict):
    """Restore the audio classifier from a checkpoint file.

    Mirrors ``_load_image_model``: the checkpoint dict must carry
    ``num_classes`` and ``model_state_dict``; label-index maps and the
    calibration temperature are optional. *config* supplies the mel
    front-end shape (``n_mels``, ``time_steps``), defaulting to 64 / 128.

    Returns:
        A ``(model, temperature)`` pair, with the model in eval mode on
        *device*.

    NOTE(review): ``torch.load`` unpickles arbitrary objects — only point
    this at trusted checkpoint files.
    """
    from src.models.audio_model import ElectricalOutletsAudioModel

    checkpoint = torch.load(weights_path, map_location=device)
    net = ElectricalOutletsAudioModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=label_mapping_path,
        n_mels=config.get("n_mels", 64),
        time_steps=config.get("time_steps", 128),
    )
    net.load_state_dict(checkpoint["model_state_dict"])
    # Re-attach the label maps captured at training time, if present.
    net.idx_to_label = checkpoint.get("idx_to_label")
    net.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    net.idx_to_severity = checkpoint.get("idx_to_severity")
    net.eval()
    return net.to(device), checkpoint.get("temperature", 1.0)
| |
|
| |
|
# Audio front-end constants shared by preprocessing and the model config.
_AUDIO_SAMPLE_RATE = 16000
_AUDIO_CLIP_SECONDS = 5.0


def _preprocess_image(src, device: str):
    """Decode *src* (a path or binary file object), apply standard ImageNet
    eval transforms, and return a (1, 3, 224, 224) float tensor on *device*."""
    img = Image.open(src).convert("RGB")
    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return tf(img).unsqueeze(0).to(device)


def _preprocess_audio(audio_path, audio_fp, device: str):
    """Load audio from a path or file object, resample to 16 kHz mono,
    center-crop/pad to a 5 s clip, and return a batched log-mel tensor
    on *device*."""
    if audio_path:
        waveform, sr = torchaudio.load(str(audio_path))
    else:
        import io
        waveform, sr = torchaudio.load(io.BytesIO(audio_fp.read()))
    if sr != _AUDIO_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, _AUDIO_SAMPLE_RATE)
    if waveform.shape[0] > 1:  # downmix multi-channel to mono
        waveform = waveform.mean(dim=0, keepdim=True)
    target_len = int(_AUDIO_CLIP_SECONDS * _AUDIO_SAMPLE_RATE)
    if waveform.shape[1] >= target_len:
        # Center-crop long clips; zero-pad short ones on the right.
        start = (waveform.shape[1] - target_len) // 2
        waveform = waveform[:, start : start + target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=_AUDIO_SAMPLE_RATE, n_fft=512, hop_length=256, win_length=512, n_mels=64,
    )(waveform)
    # NOTE(review): 5 s at hop 256 yields ~313 mel frames while the model is
    # configured with time_steps=128 — confirm the model pools/crops internally.
    return torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)


def _predict_modality(model, temperature, inputs):
    """Run *model* on *inputs* with temperature-scaled logits and wrap the
    schema prediction in a ModalityOutput for fusion."""
    from src.fusion.fusion_logic import ModalityOutput
    with torch.no_grad():
        logits = model(inputs) / temperature
    pred = model.predict_to_schema(logits)
    return ModalityOutput(
        result=pred["result"],
        issue_type=pred.get("issue_type"),
        severity=pred["severity"],
        confidence=pred["confidence"],
    )


def run_electrical_outlets_inference(
    image_path: Optional[Path] = None,
    image_fp: Optional[BinaryIO] = None,
    audio_path: Optional[Path] = None,
    audio_fp: Optional[BinaryIO] = None,
    weights_dir: Optional[Path] = None,
    config_dir: Optional[Path] = None,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """Run image and/or audio model, then fuse. Returns canonical schema dict.

    Each modality may be supplied as a filesystem path or an open binary
    file object; a modality with neither input is skipped. The audio model
    additionally requires its weight file to exist under *weights_dir*.

    Args:
        image_path / image_fp: image input (path takes precedence).
        audio_path / audio_fp: audio input (path takes precedence).
        weights_dir: directory holding ``*_best.pt`` checkpoints
            (defaults to ``ROOT / "weights"``).
        config_dir: directory holding ``label_mapping.json`` and
            ``thresholds.yaml`` (defaults to ``ROOT / "config"``).
        device: torch device string; auto-selects CUDA when available.

    Returns:
        The fused canonical schema dict produced by ``fuse_modalities``.
    """
    if weights_dir is None:
        weights_dir = ROOT / "weights"
    if config_dir is None:
        config_dir = ROOT / "config"
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    label_mapping_path = config_dir / "label_mapping.json"
    thresholds_path = config_dir / "thresholds.yaml"
    import yaml
    with open(thresholds_path) as f:
        # safe_load returns None for an empty file; fall back to an empty
        # dict so the .get(...) defaults below apply instead of crashing.
        thresholds = yaml.safe_load(f) or {}

    image_out = None
    if image_path or image_fp:
        x = _preprocess_image(image_path or image_fp, device)
        model, temperature = _load_image_model(
            weights_dir / "electrical_outlets_image_best.pt", label_mapping_path, device
        )
        image_out = _predict_modality(model, temperature, x)

    audio_out = None
    if (audio_path or audio_fp) and (weights_dir / "electrical_outlets_audio_best.pt").exists():
        log_mel = _preprocess_audio(audio_path, audio_fp, device)
        model, temperature = _load_audio_model(
            weights_dir / "electrical_outlets_audio_best.pt",
            label_mapping_path,
            device,
            {"n_mels": 64, "time_steps": 128},
        )
        audio_out = _predict_modality(model, temperature, log_mel)

    from src.fusion.fusion_logic import fuse_modalities
    return fuse_modalities(
        image_out,
        audio_out,
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )
| |
|