"""
Inference wrapper: load image + audio models, run modalities present, apply fusion, return schema.
"""
from pathlib import Path
from typing import Optional, Dict, Any, BinaryIO
import json
import torch
import torchaudio
from torchvision import transforms
from PIL import Image
# Optional imports for models
import sys
# Repository root: three directory levels above this file.
ROOT = Path(__file__).resolve().parent.parent.parent
# Prepend the repo root to sys.path so the `src....` absolute imports used
# inside the loader helpers below resolve even when this module is invoked
# from outside the package (e.g. as a standalone script).
sys.path.insert(0, str(ROOT))
def _load_image_model(weights_path: Path, label_mapping_path: Path, device: str):
    """Load the image classifier from a checkpoint.

    Returns a ``(model, temperature)`` pair, where ``temperature`` is the
    calibration scalar stored in the checkpoint (1.0 when absent).
    """
    from src.models.image_model import ElectricalOutletsImageModel

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device)
    image_model = ElectricalOutletsImageModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=label_mapping_path,
        pretrained=False,
    )
    image_model.load_state_dict(checkpoint["model_state_dict"])
    # Attach the label-decoding tables saved alongside the weights so
    # predict_to_schema can map indices back to names.
    image_model.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    image_model.idx_to_severity = checkpoint.get("idx_to_severity")
    image_model.eval()
    return image_model.to(device), checkpoint.get("temperature", 1.0)
def _load_audio_model(weights_path: Path, label_mapping_path: Path, device: str, config: dict):
    """Load the audio classifier from a checkpoint.

    ``config`` supplies the spectrogram geometry (``n_mels``, ``time_steps``).
    Returns a ``(model, temperature)`` pair; temperature defaults to 1.0.
    """
    from src.models.audio_model import ElectricalOutletsAudioModel

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device)
    audio_model = ElectricalOutletsAudioModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=label_mapping_path,
        n_mels=config.get("n_mels", 64),
        time_steps=config.get("time_steps", 128),
    )
    audio_model.load_state_dict(checkpoint["model_state_dict"])
    # Attach the label-decoding tables saved alongside the weights.
    audio_model.idx_to_label = checkpoint.get("idx_to_label")
    audio_model.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    audio_model.idx_to_severity = checkpoint.get("idx_to_severity")
    audio_model.eval()
    return audio_model.to(device), checkpoint.get("temperature", 1.0)
def _to_modality_output(pred: Dict[str, Any]):
    """Wrap a model's ``predict_to_schema`` dict in a fusion ModalityOutput."""
    from src.fusion.fusion_logic import ModalityOutput

    return ModalityOutput(
        result=pred["result"],
        issue_type=pred.get("issue_type"),
        severity=pred["severity"],
        confidence=pred["confidence"],
    )


def _run_image_modality(
    image_path: Optional[Path],
    image_fp: Optional[BinaryIO],
    weights_dir: Path,
    label_mapping_path: Path,
    device: str,
):
    """Preprocess the image, run the image model, return a ModalityOutput."""
    img = Image.open(image_path or image_fp).convert("RGB")
    # Standard ImageNet eval transform (mean/std match the pretrained backbone).
    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    x = tf(img).unsqueeze(0).to(device)
    model, temperature = _load_image_model(
        weights_dir / "electrical_outlets_image_best.pt", label_mapping_path, device
    )
    with torch.no_grad():
        # Temperature scaling: divide logits by the calibration scalar.
        logits = model(x) / temperature
    return _to_modality_output(model.predict_to_schema(logits))


def _run_audio_modality(
    audio_path: Optional[Path],
    audio_fp: Optional[BinaryIO],
    weights_dir: Path,
    label_mapping_path: Path,
    device: str,
):
    """Preprocess audio into a log-mel spectrogram, run the audio model."""
    if audio_path:
        waveform, sr = torchaudio.load(str(audio_path))
    else:
        import io
        waveform, sr = torchaudio.load(io.BytesIO(audio_fp.read()))
    # Normalize to 16 kHz mono.
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Center-crop (or zero-pad at the end) to exactly 5 seconds.
    target_len = int(5.0 * 16000)
    if waveform.shape[1] >= target_len:
        start = (waveform.shape[1] - target_len) // 2
        waveform = waveform[:, start : start + target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_fft=512, hop_length=256, win_length=512, n_mels=64,
    )(waveform)
    # Clamp before log to avoid -inf on silent frames.
    log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)
    # NOTE(review): 5 s at hop 256 yields ~313 frames, but the model is built
    # with time_steps=128 — presumably it pools/crops internally; confirm.
    model, temperature = _load_audio_model(
        weights_dir / "electrical_outlets_audio_best.pt",
        label_mapping_path,
        device,
        {"n_mels": 64, "time_steps": 128},
    )
    with torch.no_grad():
        logits = model(log_mel) / temperature
    return _to_modality_output(model.predict_to_schema(logits))


def run_electrical_outlets_inference(
    image_path: Optional[Path] = None,
    image_fp: Optional[BinaryIO] = None,
    audio_path: Optional[Path] = None,
    audio_fp: Optional[BinaryIO] = None,
    weights_dir: Optional[Path] = None,
    config_dir: Optional[Path] = None,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """Run image and/or audio model, then fuse. Returns canonical schema dict.

    Args:
        image_path / image_fp: image input as a path or open binary stream.
        audio_path / audio_fp: audio input as a path or open binary stream.
        weights_dir: directory with model checkpoints (default ``ROOT/weights``).
        config_dir: directory with ``label_mapping.json`` and ``thresholds.yaml``
            (default ``ROOT/config``).
        device: torch device string; auto-selects CUDA when available.

    Returns:
        The dict produced by ``fuse_modalities`` (canonical result schema).
    """
    if weights_dir is None:
        weights_dir = ROOT / "weights"
    if config_dir is None:
        config_dir = ROOT / "config"
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    label_mapping_path = config_dir / "label_mapping.json"
    import yaml
    with open(config_dir / "thresholds.yaml") as f:
        thresholds = yaml.safe_load(f)

    image_out = None
    if image_path or image_fp:
        image_out = _run_image_modality(
            image_path, image_fp, weights_dir, label_mapping_path, device
        )

    audio_out = None
    # Audio is best-effort: silently skipped when its checkpoint is absent.
    if (audio_path or audio_fp) and (weights_dir / "electrical_outlets_audio_best.pt").exists():
        audio_out = _run_audio_modality(
            audio_path, audio_fp, weights_dir, label_mapping_path, device
        )

    from src.fusion.fusion_logic import fuse_modalities
    return fuse_modalities(
        image_out,
        audio_out,
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )