File size: 5,604 Bytes
5666923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Inference wrapper: load image + audio models, run modalities present, apply fusion, return schema.
"""
from pathlib import Path
from typing import Optional, Dict, Any, BinaryIO
import json
import torch
import torchaudio
from torchvision import transforms
from PIL import Image

# Optional imports for models
import sys
ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(ROOT))


def _load_image_model(weights_path: Path, label_mapping_path: Path, device: str):
    """Restore the image classifier from a checkpoint.

    Returns a tuple of (model in eval mode on *device*, calibration
    temperature). Temperature defaults to 1.0 (no scaling) for checkpoints
    saved before calibration was added.
    """
    from src.models.image_model import ElectricalOutletsImageModel

    checkpoint = torch.load(weights_path, map_location=device)
    net = ElectricalOutletsImageModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=label_mapping_path,
        pretrained=False,
    )
    net.load_state_dict(checkpoint["model_state_dict"])
    # Attach the index->label decoding tables stored alongside the weights
    # (absent keys yield None).
    net.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    net.idx_to_severity = checkpoint.get("idx_to_severity")
    net.eval()
    return net.to(device), checkpoint.get("temperature", 1.0)


def _load_audio_model(weights_path: Path, label_mapping_path: Path, device: str, config: dict):
    """Restore the audio classifier from a checkpoint.

    *config* supplies the spectrogram geometry (``n_mels``, ``time_steps``);
    missing keys fall back to 64 / 128. Returns (model in eval mode on
    *device*, calibration temperature; 1.0 when absent from the checkpoint).
    """
    from src.models.audio_model import ElectricalOutletsAudioModel

    checkpoint = torch.load(weights_path, map_location=device)
    net = ElectricalOutletsAudioModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=label_mapping_path,
        n_mels=config.get("n_mels", 64),
        time_steps=config.get("time_steps", 128),
    )
    net.load_state_dict(checkpoint["model_state_dict"])
    # Attach the index->label decoding tables stored alongside the weights
    # (absent keys yield None).
    net.idx_to_label = checkpoint.get("idx_to_label")
    net.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    net.idx_to_severity = checkpoint.get("idx_to_severity")
    net.eval()
    return net.to(device), checkpoint.get("temperature", 1.0)


def _to_modality_output(pred: Dict[str, Any]):
    """Wrap a model's schema prediction dict in a fusion ModalityOutput."""
    from src.fusion.fusion_logic import ModalityOutput

    return ModalityOutput(
        result=pred["result"],
        issue_type=pred.get("issue_type"),
        severity=pred["severity"],
        confidence=pred["confidence"],
    )


def _infer_image(
    image_path: Optional[Path],
    image_fp: Optional[BinaryIO],
    weights_path: Path,
    label_mapping_path: Path,
    device: str,
):
    """Preprocess one image, run the image model, return a ModalityOutput."""
    img = Image.open(image_path or image_fp).convert("RGB")
    # Standard ImageNet preprocessing (resize/center-crop/normalize) —
    # presumably matches the backbone's training pipeline; confirm if changed.
    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    x = tf(img).unsqueeze(0).to(device)
    model, temperature = _load_image_model(weights_path, label_mapping_path, device)
    with torch.no_grad():
        # Temperature scaling calibrates the confidence of the softmax output.
        logits = model(x) / temperature
    return _to_modality_output(model.predict_to_schema(logits))


def _infer_audio(
    audio_path: Optional[Path],
    audio_fp: Optional[BinaryIO],
    weights_path: Path,
    label_mapping_path: Path,
    device: str,
):
    """Preprocess one audio clip, run the audio model, return a ModalityOutput."""
    if audio_path:
        waveform, sr = torchaudio.load(str(audio_path))
    else:
        import io
        waveform, sr = torchaudio.load(io.BytesIO(audio_fp.read()))
    # Normalize to 16 kHz mono, exactly 5 s: center-crop long clips,
    # zero-pad short ones.
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    target_len = int(5.0 * 16000)
    if waveform.shape[1] >= target_len:
        start = (waveform.shape[1] - target_len) // 2
        waveform = waveform[:, start : start + target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))
    # Log-mel features; geometry must match the model's (n_mels=64,
    # time_steps=128) training configuration.
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_fft=512, hop_length=256, win_length=512, n_mels=64,
    )(waveform)
    log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)
    model, temperature = _load_audio_model(
        weights_path,
        label_mapping_path,
        device,
        {"n_mels": 64, "time_steps": 128},
    )
    with torch.no_grad():
        # Temperature scaling calibrates the confidence of the softmax output.
        logits = model(log_mel) / temperature
    return _to_modality_output(model.predict_to_schema(logits))


def run_electrical_outlets_inference(
    image_path: Optional[Path] = None,
    image_fp: Optional[BinaryIO] = None,
    audio_path: Optional[Path] = None,
    audio_fp: Optional[BinaryIO] = None,
    weights_dir: Optional[Path] = None,
    config_dir: Optional[Path] = None,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """Run image and/or audio inference, then fuse the modality outputs.

    Each modality accepts either a filesystem path or an open binary
    file-like object; a modality with neither is skipped. ``weights_dir``,
    ``config_dir`` and ``device`` default to ``ROOT/weights``, ``ROOT/config``
    and CUDA-if-available respectively.

    Returns the canonical schema dict produced by ``fuse_modalities``.

    Raises:
        FileNotFoundError: if image input is given but the image weights
            file is missing. (The audio modality is silently skipped when
            its weights are absent — preserved pre-existing behavior.)
    """
    if weights_dir is None:
        weights_dir = ROOT / "weights"
    if config_dir is None:
        config_dir = ROOT / "config"
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    label_mapping_path = config_dir / "label_mapping.json"
    thresholds_path = config_dir / "thresholds.yaml"
    import yaml
    with open(thresholds_path) as f:
        # Guard against an empty YAML file: safe_load returns None, which
        # would break the .get() lookups below.
        thresholds = yaml.safe_load(f) or {}

    image_out = None
    if image_path or image_fp:
        image_weights = weights_dir / "electrical_outlets_image_best.pt"
        if not image_weights.exists():
            raise FileNotFoundError(f"Image model weights not found: {image_weights}")
        image_out = _infer_image(image_path, image_fp, image_weights, label_mapping_path, device)

    audio_out = None
    audio_weights = weights_dir / "electrical_outlets_audio_best.pt"
    if (audio_path or audio_fp) and audio_weights.exists():
        audio_out = _infer_audio(audio_path, audio_fp, audio_weights, label_mapping_path, device)

    from src.fusion.fusion_logic import fuse_modalities
    return fuse_modalities(
        image_out,
        audio_out,
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )