# File size: 1,945 Bytes
# 5ec9e9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import warnings
from typing import Any, Dict, List, Optional

import torch
from torch.nn import functional as F

from app.core.config import settings
from app.ml.model import ECGClassifier
from app.ml.gating import gate_signal

# Device used for the model and its input tensors: the GPU when torch can
# see one, otherwise the CPU.
_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Lazily-populated model cache; stays None until load_model() first runs.
_model: ECGClassifier | None = None


def load_model() -> ECGClassifier:
    """
    Return the process-wide ECG model, building it on first use.

    On the first call an ``ECGClassifier(num_classes=2)`` is constructed. If
    ``settings.MODEL_WEIGHTS_PATH`` is set and the file exists, its state dict
    is loaded into the model. If the path is set but the file is missing, a
    ``RuntimeWarning`` is emitted and the model keeps its freshly initialised
    (untrained) weights. The model is then moved to ``_device``, put in eval
    mode and cached in the module-level ``_model``. Later calls return the
    cached instance unchanged.

    Returns:
        The cached ``ECGClassifier`` in eval mode on ``_device``.
    """
    global _model
    if _model is not None:
        return _model

    model = ECGClassifier(num_classes=2)
    weights_path: Optional[str] = settings.MODEL_WEIGHTS_PATH
    if weights_path and os.path.exists(weights_path):
        # Only a state dict is expected here, so restrict unpickling to
        # tensors/primitive containers; a tampered checkpoint then cannot
        # execute arbitrary code during load.
        state = torch.load(weights_path, map_location=_device, weights_only=True)
        model.load_state_dict(state)
    elif weights_path:
        # A configured-but-missing path is almost certainly a deployment
        # mistake; keep the previous best-effort fallback but make it visible.
        warnings.warn(
            f"MODEL_WEIGHTS_PATH={weights_path!r} does not exist; "
            "using randomly initialised ECG model weights.",
            RuntimeWarning,
            stacklevel=2,
        )

    model.to(_device)
    model.eval()
    _model = model
    return _model


@torch.no_grad()
def infer_ecg(
    signal: List[float],
    original_len: Optional[int] = None,
    gating_meta: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Run model inference on a single ECG signal.

    Args:
        signal: Sample values to classify. Must be non-empty and contain only
            finite numbers. It is converted to a float32 tensor and given two
            leading singleton dimensions before being passed to the model.
        original_len: Signal length before gating, used for ``gated_ratio``.
            Falsy values (``None`` or ``0``) fall back to ``len(signal)``.
        gating_meta: Optional metadata from the gating step. It is attached
            under the ``"gating"`` key when provided and non-empty.

    Returns:
        A dict with:
          - ``"label"``: ``"arrhythmia"`` if ``score >= 0.5`` else ``"normal"``.
          - ``"score"``: softmax probability of class index 1.
          - ``"hr"``: placeholder value ``int(60 + 80 * score)``; this is not
            a real heart-rate estimate.
          - ``"gated_ratio"``: ``len(signal) / max(original_len, 1)``.
          - ``"gating"``: ``gating_meta``, only when it is truthy.

    Raises:
        ValueError: If ``signal`` is empty or contains NaN/inf values.
    """
    # Validate before touching the model so malformed requests never pay for
    # (or trigger) the first-time model load.
    if not signal:
        raise ValueError("Signal cannot be empty.")

    samples = torch.tensor(signal, dtype=torch.float32)
    if not torch.isfinite(samples).all():
        # Previously this surfaced as an opaque ValueError from int(nan) after
        # a full forward pass; reject it up front with a clear message.
        raise ValueError("Signal must contain only finite values.")

    model = load_model()
    batch = samples.to(_device).unsqueeze(0).unsqueeze(0)
    logits = model(batch)
    probs = F.softmax(logits, dim=1)
    score = float(probs[0, 1].item())

    label = "arrhythmia" if score >= 0.5 else "normal"

    # Dummy heart rate estimation as placeholder
    hr_estimate = int(60 + 80 * score)

    original_len = original_len or len(signal)
    gating_ratio = len(signal) / max(original_len, 1)

    result: Dict[str, Any] = {
        "label": label,
        "score": score,
        "hr": hr_estimate,
        "gated_ratio": gating_ratio,
    }
    if gating_meta:
        result["gating"] = gating_meta
    return result