import logging
import os
from typing import Any, Dict, List, Optional

import torch
from torch.nn import functional as F

from app.core.config import settings
from app.ml.model import ECGClassifier
from app.ml.gating import gate_signal  # noqa: F401  (unused here; kept in case other modules import it from this one)

logger = logging.getLogger(__name__)

# Process-wide cache for the lazily built model; populated by load_model().
_model: ECGClassifier | None = None
# Device chosen once at import time: CUDA when available, otherwise CPU.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model() -> ECGClassifier:
    """Return the shared ECG classifier, building it on first use.

    On the first call the model is constructed with two output classes,
    optionally initialized from ``settings.MODEL_WEIGHTS_PATH`` (only when
    that path is set and exists on disk), moved to the module device, put
    into eval mode, and cached. Later calls return the cached instance.

    Returns:
        The cached ``ECGClassifier`` instance.
    """
    global _model
    if _model is not None:
        return _model

    model = ECGClassifier(num_classes=2)
    weights_path: Optional[str] = settings.MODEL_WEIGHTS_PATH
    if weights_path and os.path.exists(weights_path):
        # NOTE(security): torch.load unpickles the file, so the weights path
        # must point at a trusted artifact. Consider passing
        # ``weights_only=True`` if the deployed torch version supports it.
        state = torch.load(weights_path, map_location=_device)
        model.load_state_dict(state)
    elif weights_path:
        # A path was configured but nothing is there: keep the original
        # best-effort fallback, but make it visible instead of silent.
        logger.warning(
            "Model weights not found at %s; serving the model without loaded weights.",
            weights_path,
        )

    model.to(_device)
    model.eval()
    _model = model
    return _model


@torch.no_grad()
def infer_ecg(
    signal: List[float],
    original_len: Optional[int] = None,
    gating_meta: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Run model inference on a single ECG signal.

    Args:
        signal: Samples of the (possibly gated) ECG signal. Must be non-empty.
        original_len: Length of the signal before gating. When ``None`` or
            ``0``, the length of ``signal`` is used instead.
        gating_meta: Optional gating metadata; when non-empty it is attached
            to the result unchanged under the ``"gating"`` key.

    Returns:
        A dict with ``"label"`` (``"arrhythmia"`` or ``"normal"``),
        ``"score"`` (softmax probability of class index 1), ``"hr"``
        (placeholder heart-rate estimate), ``"gated_ratio"``
        (``len(signal)`` divided by the original length), and, when
        ``gating_meta`` is provided and non-empty, ``"gating"``.

    Raises:
        ValueError: If ``signal`` is empty.
    """
    # Validate before touching the model so bad input never triggers the
    # (potentially expensive) lazy model load.
    if not signal:
        raise ValueError("Signal cannot be empty.")

    model = load_model()

    # Two unsqueeze(0) calls turn the 1-D sample vector into shape (1, 1, N).
    tensor = (
        torch.tensor(signal, dtype=torch.float32, device=_device)
        .unsqueeze(0)
        .unsqueeze(0)
    )
    logits = model(tensor)
    probs = F.softmax(logits, dim=1)

    # Class index 1 is treated as the "arrhythmia" class.
    score = float(probs[0, 1].item())
    label = "arrhythmia" if score >= 0.5 else "normal"

    # Dummy heart rate estimation as placeholder
    hr_estimate = int(60 + 80 * score)

    original_len = original_len or len(signal)
    # max(..., 1) guards the division against a non-positive original length.
    gating_ratio = len(signal) / max(original_len, 1)

    result: Dict[str, Any] = {
        "label": label,
        "score": score,
        "hr": hr_estimate,
        "gated_ratio": gating_ratio,
    }
    if gating_meta:
        result["gating"] = gating_meta
    return result