SundewAIHealth / app /ml /inference.py
mgbam's picture
Upload 5 files
5ec9e9d verified
raw
history blame
1.95 kB
import os
from typing import Any, Dict, List, Optional
import torch
from torch.nn import functional as F
from app.core.config import settings
from app.ml.model import ECGClassifier
from app.ml.gating import gate_signal
# Cached model instance; stays None until load_model() builds it.
_model: Optional[ECGClassifier] = None
# Execute on CUDA when PyTorch reports it as available, otherwise on the CPU.
_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model() -> ECGClassifier:
    """
    Return the module-wide ECG model, building it on the first call.

    The model is constructed with ``num_classes=2``. If
    ``settings.MODEL_WEIGHTS_PATH`` names an existing file, its state dict is
    loaded. The model is then moved to ``_device``, switched to eval mode and
    cached in the module-level ``_model``, so later calls reuse it.

    If no weights path is configured, or the configured file does not exist,
    the freshly initialised (untrained) parameters are kept and a warning is
    logged. Predictions from such a model are not meaningful.

    Returns:
        The cached ``ECGClassifier`` instance.
    """
    global _model
    if _model is not None:
        return _model

    import logging

    logger = logging.getLogger(__name__)

    model = ECGClassifier(num_classes=2)
    weights_path: Optional[str] = settings.MODEL_WEIGHTS_PATH
    if weights_path and os.path.exists(weights_path):
        # weights_only=True restricts the unpickler to tensors, primitive types
        # and containers. That is all a state dict needs, and it prevents a
        # malicious checkpoint from running arbitrary code on load.
        state = torch.load(weights_path, map_location=_device, weights_only=True)
        model.load_state_dict(state)
    elif weights_path:
        # Previously silent: a typo in the path meant serving random weights.
        logger.warning(
            "Model weights file %r not found; using untrained ECGClassifier weights.",
            weights_path,
        )
    else:
        logger.warning(
            "MODEL_WEIGHTS_PATH is not set; using untrained ECGClassifier weights."
        )

    model.to(_device)
    model.eval()
    _model = model
    return _model
@torch.no_grad()
def infer_ecg(
    signal: List[float],
    original_len: Optional[int] = None,
    gating_meta: Optional[Dict[str, Any]] = None,
    *,
    threshold: float = 0.5,
) -> Dict[str, Any]:
    """
    Run model inference on a single ECG signal.

    Args:
        signal: Samples of one ECG trace. A flat list is wrapped into a tensor
            of shape ``(1, 1, len(signal))`` before being passed to the model
            (presumably batch, channel, time; confirm against ECGClassifier).
        original_len: Signal length before gating. It is used only to compute
            ``gated_ratio``. Falsy values (None or 0) fall back to
            ``len(signal)``.
        gating_meta: Optional metadata, stored in the result under
            ``"gating"`` when truthy.
        threshold: Minimum probability of class index 1 for the
            ``"arrhythmia"`` label. Defaults to 0.5, the previous
            hard-coded value.

    Returns:
        A dict with these keys:

        - ``label``: ``"arrhythmia"`` or ``"normal"``.
        - ``score``: softmax probability of class index 1.
        - ``hr``: placeholder integer derived from ``score``, not from the
          signal.
        - ``gated_ratio``: ``len(signal) / original_len``.
        - ``gating``: present only when ``gating_meta`` is truthy.

    Raises:
        ValueError: If ``signal`` is empty or contains NaN/inf values.
    """
    # Validate first, so that malformed requests never trigger the
    # (potentially slow) first-time model load.
    if signal is None or len(signal) == 0:
        raise ValueError("Signal cannot be empty.")
    tensor = torch.tensor(signal, dtype=torch.float32, device=_device)
    if not bool(torch.isfinite(tensor).all()):
        # NaN/inf would otherwise propagate to a NaN score and fail later with
        # an unrelated int() conversion error.
        raise ValueError("Signal contains non-finite values (NaN or inf).")

    model = load_model()
    # Add batch and channel dimensions: for a flat list, (L,) -> (1, 1, L).
    logits = model(tensor.unsqueeze(0).unsqueeze(0))
    probs = F.softmax(logits, dim=1)
    score = probs[0, 1].item()
    label = "arrhythmia" if score >= threshold else "normal"

    # Placeholder only: this is NOT a heart-rate estimate from the signal.
    # It is a linear map of the score into the range [60, 140].
    hr_estimate = int(60 + 80 * score)

    original_len = original_len or len(signal)
    gating_ratio = len(signal) / max(original_len, 1)

    result: Dict[str, Any] = {
        "label": label,
        "score": score,
        "hr": hr_estimate,
        "gated_ratio": gating_ratio,
    }
    if gating_meta:
        result["gating"] = gating_meta
    return result