Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Any, Dict, List, Optional | |
| import torch | |
| from torch.nn import functional as F | |
| from app.core.config import settings | |
| from app.ml.model import ECGClassifier | |
| from app.ml.gating import gate_signal | |
# Module-level cache for the classifier; stays None until load_model() builds it.
_model: Optional[ECGClassifier] = None
# Target device for the model and input tensors: CUDA when available, else CPU.
_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model() -> ECGClassifier:
    """Return the module-wide ECG model, building it on first use.

    On the first call an ``ECGClassifier(num_classes=2)`` is created. If
    ``settings.MODEL_WEIGHTS_PATH`` is set and points to an existing path,
    a state dict is loaded from it. The model is then moved to ``_device``,
    switched to eval mode and cached in the global ``_model``. Later calls
    return that cached instance.

    When a weights path is configured but does not exist, the model keeps
    its freshly initialized (untrained) weights and a warning is logged so
    the misconfiguration is visible instead of silent.

    Note: initialization is not guarded by a lock; concurrent first calls
    may each construct a model.

    Returns:
        The cached ``ECGClassifier`` instance.
    """
    global _model
    if _model is None:
        model = ECGClassifier(num_classes=2)
        weights_path: Optional[str] = settings.MODEL_WEIGHTS_PATH
        if weights_path and os.path.exists(weights_path):
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot run arbitrary code.
            state = torch.load(weights_path, map_location=_device, weights_only=True)
            model.load_state_dict(state)
        elif weights_path:
            import logging

            logging.getLogger(__name__).warning(
                "Model weights not found at %r; using untrained ECGClassifier weights.",
                weights_path,
            )
        model.to(_device)
        model.eval()
        _model = model
    return _model
def infer_ecg(
    signal: List[float],
    original_len: Optional[int] = None,
    gating_meta: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Run model inference on a single ECG signal.

    Args:
        signal: Samples of one ECG trace. Must be non-empty and finite.
        original_len: Length of the signal before gating. Falls back to
            ``len(signal)`` when omitted or 0.
        gating_meta: Optional gating metadata, copied into the result under
            ``"gating"`` when truthy.

    Returns:
        A dict with:
        - ``"label"``: ``"arrhythmia"`` if the class-1 probability is >= 0.5,
          otherwise ``"normal"``.
        - ``"score"``: that class-1 probability.
        - ``"hr"``: a placeholder value derived from the score
          (``int(60 + 80 * score)``), not a real heart-rate estimate.
        - ``"gated_ratio"``: ``len(signal) / max(original_len, 1)``.
        - ``"gating"``: ``gating_meta``, only when it is provided.

    Raises:
        ValueError: If ``signal`` is empty or contains NaN/inf values.
    """
    # Validate cheap preconditions before touching (and possibly loading) the model.
    if not signal:
        raise ValueError("Signal cannot be empty.")

    model = load_model()
    # Shape (1, 1, len(signal)); presumably (batch, channels, time) for
    # ECGClassifier - TODO confirm against the model definition.
    tensor = torch.tensor(signal, dtype=torch.float32, device=_device).unsqueeze(0).unsqueeze(0)
    if not torch.isfinite(tensor).all():
        raise ValueError("Signal contains non-finite values.")

    # Pure inference: disable autograd so no graph or gradient buffers are kept.
    with torch.no_grad():
        logits = model(tensor)
        probs = F.softmax(logits, dim=1)
        # Assumes logits are (1, 2) with index 1 = arrhythmia (model built with num_classes=2).
        score = float(probs[0, 1].item())

    label = "arrhythmia" if score >= 0.5 else "normal"
    # Dummy heart rate estimation as placeholder
    hr_estimate = int(60 + 80 * score)

    base_len = original_len or len(signal)
    gating_ratio = len(signal) / max(base_len, 1)

    result: Dict[str, Any] = {
        "label": label,
        "score": score,
        "hr": hr_estimate,
        "gated_ratio": gating_ratio,
    }
    if gating_meta:
        result["gating"] = gating_meta
    return result