Spaces:
Sleeping
Sleeping
File size: 1,945 Bytes
5ec9e9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import os
from typing import Any, Dict, List, Optional
import torch
from torch.nn import functional as F
from app.core.config import settings
from app.ml.model import ECGClassifier
from app.ml.gating import gate_signal
# Torch device picked once at import time: CUDA if available, otherwise CPU.
_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Cached classifier instance; stays None until load_model() builds it.
_model: Optional[ECGClassifier] = None
def load_model() -> ECGClassifier:
    """Return the process-wide ECG classifier, building it on first use.

    On the first call a two-class ``ECGClassifier`` is constructed. If
    ``settings.MODEL_WEIGHTS_PATH`` is set and names an existing file, that
    file is loaded with ``torch.load`` and applied via ``load_state_dict``.
    Otherwise the model keeps its freshly initialized (untrained) weights and
    a warning is logged, so the situation is visible in the service logs.
    The model is then moved to the module-level device, put in eval mode and
    cached; subsequent calls return the same instance.

    Returns:
        The cached ``ECGClassifier`` instance.
    """
    global _model
    if _model is not None:
        return _model

    import logging

    model = ECGClassifier(num_classes=2)
    weights_path: Optional[str] = settings.MODEL_WEIGHTS_PATH
    if weights_path and os.path.exists(weights_path):
        # NOTE(review): torch.load unpickles the file, so only point
        # MODEL_WEIGHTS_PATH at trusted checkpoints (consider passing
        # weights_only=True if the installed torch version supports it).
        state = torch.load(weights_path, map_location=_device)
        model.load_state_dict(state)
    elif weights_path:
        # A path was configured but is missing: predictions would otherwise
        # silently come from random initialization.
        logging.getLogger(__name__).warning(
            "MODEL_WEIGHTS_PATH %r does not exist; using untrained ECG model weights.",
            weights_path,
        )
    else:
        logging.getLogger(__name__).warning(
            "MODEL_WEIGHTS_PATH is not set; using untrained ECG model weights."
        )

    model.to(_device)
    model.eval()  # inference only: disable dropout / use running BN stats
    _model = model
    return _model
@torch.no_grad()
def infer_ecg(
    signal: List[float],
    original_len: Optional[int] = None,
    gating_meta: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Run the ECG classifier on a single signal and summarize the result.

    Args:
        signal: Samples of one ECG trace (after any gating). Must be non-empty.
        original_len: Length of the signal before gating, used to report what
            fraction was kept. ``None`` or ``0`` is treated as
            ``len(signal)``.
        gating_meta: Optional gating metadata; copied into the result under
            the ``"gating"`` key when it is non-empty.

    Returns:
        A dict with:
          - ``"label"``: ``"arrhythmia"`` if ``score >= 0.5`` else ``"normal"``.
          - ``"score"``: softmax probability assigned to class index 1.
          - ``"hr"``: placeholder value ``int(60 + 80 * score)``; this is
            NOT a measured heart rate.
          - ``"gated_ratio"``: ``len(signal) / original_len``.
          - ``"gating"``: ``gating_meta``, present only when non-empty.

    Raises:
        ValueError: If ``signal`` is empty.
    """
    # Validate before touching the model so invalid input never triggers a
    # (potentially slow) first-time model load.
    if not signal:
        raise ValueError("Signal cannot be empty.")

    model = load_model()

    # Add batch and channel dims: a 1-D signal of length L becomes (1, 1, L).
    tensor = torch.tensor(signal, dtype=torch.float32, device=_device).unsqueeze(0).unsqueeze(0)
    logits = model(tensor)
    probs = F.softmax(logits, dim=1)
    score = float(probs[0, 1].item())
    label = "arrhythmia" if score >= 0.5 else "normal"

    # Placeholder only: maps the [0, 1] score linearly into [60, 140].
    hr_estimate = int(60 + 80 * score)

    # `or` also treats original_len == 0 as "unknown"; max() guards the division.
    original_len = original_len or len(signal)
    gating_ratio = len(signal) / max(original_len, 1)

    result: Dict[str, Any] = {
        "label": label,
        "score": score,
        "hr": hr_estimate,
        "gated_ratio": gating_ratio,
    }
    if gating_meta:
        result["gating"] = gating_meta
    return result
|