Spaces:
Running
Running
Upload 5 files
Browse files- app/ml/ast_adapter.py +28 -0
- app/ml/gating.py +116 -0
- app/ml/inference.py +68 -0
- app/ml/model.py +31 -0
- app/ml/train_ecg.py +173 -0
app/ml/ast_adapter.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, Type
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _apply_torch_amp_shims() -> None:
    """Backport torch.amp.GradScaler/autocast from torch.cuda.amp on torch 2.2.

    adaptive-sparse-training expects the torch.amp namespace (torch 2.3+);
    on older builds we alias the CUDA variants into place.
    """
    cuda_amp = getattr(torch.cuda, "amp", None)
    if cuda_amp is None:
        return
    for attr in ("GradScaler", "autocast"):
        if not hasattr(torch.amp, attr):
            setattr(torch.amp, attr, getattr(cuda_amp, attr))  # type: ignore[attr-defined]


def load_ast_trainer() -> Tuple[Optional[Type[object]], Optional[Type[object]], Optional[Exception]]:
    """Attempt to import the optional adaptive-sparse-training package.

    Returns:
        (AdaptiveSparseTrainer, ASTConfig, None) on success, or
        (None, None, error) when the package is missing or incompatible.
    """
    try:
        _apply_torch_amp_shims()
        from adaptive_sparse_training import AdaptiveSparseTrainer  # type: ignore
        from adaptive_sparse_training.config import ASTConfig  # type: ignore

        return AdaptiveSparseTrainer, ASTConfig, None
    except Exception as exc:  # pragma: no cover - optional dependency
        return None, None, exc
|
app/ml/gating.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Tuple
|
| 2 |
+
|
| 3 |
+
from sundew.gating import gate_probability_with_hysteresis, significance_score
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _clamp(value: float, low: float = 0.0, high: float = 1.0) -> float:
|
| 7 |
+
return max(low, min(high, value))
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _window_features(window: List[float]) -> Dict[str, float]:
|
| 11 |
+
"""
|
| 12 |
+
Compute basic features for a window to feed Sundew's significance score.
|
| 13 |
+
"""
|
| 14 |
+
if not window:
|
| 15 |
+
return {"magnitude": 0.0, "anomaly_score": 0.0, "context_relevance": 0.0, "urgency": 0.0}
|
| 16 |
+
|
| 17 |
+
length = float(len(window))
|
| 18 |
+
mean = sum(window) / length
|
| 19 |
+
mean_abs = sum(abs(x) for x in window) / length
|
| 20 |
+
max_abs = max(abs(x) for x in window)
|
| 21 |
+
variance = sum((x - mean) ** 2 for x in window) / length
|
| 22 |
+
|
| 23 |
+
magnitude = _clamp(max_abs * 10.0, 0.0, 10.0) * 10.0 # 0..100 scale
|
| 24 |
+
anomaly_score = _clamp(variance / (variance + 1.0))
|
| 25 |
+
context_relevance = _clamp(mean_abs)
|
| 26 |
+
urgency = _clamp(mean_abs * 0.5)
|
| 27 |
+
|
| 28 |
+
return {
|
| 29 |
+
"magnitude": magnitude,
|
| 30 |
+
"anomaly_score": anomaly_score,
|
| 31 |
+
"context_relevance": context_relevance,
|
| 32 |
+
"urgency": urgency,
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def gate_signal(
    signal: List[float],
    window_size: int = 128,
    step: int = 64,
    threshold: float = 0.55,
    temperature: float = 0.15,
    return_windows: bool = False,
    max_windows: int = 200,
) -> Tuple[List[float], Dict[str, Any]]:
    """Reduce a signal to its significant windows via Sundew gating.

    Args:
        signal: raw signal values.
        window_size: sliding window size.
        step: stride between windows.
        threshold: gate threshold.
        temperature: gate temperature (0 = hard gate).
        return_windows: include per-window metadata in the result.
        max_windows: cap on how many window records are returned (previews).

    Returns:
        gated_signal: flattened list of selected windows (the original
            signal when nothing is selected or it is too short to window).
        meta: gating metadata (counts, ratio, threshold, optional windows).
    """
    if len(signal) < window_size:
        # Too short to window: pass the signal through untouched.
        return signal, {
            "total_windows": 0,
            "selected_windows": 0,
            "ratio": 1.0,
            "threshold": threshold,
            "temperature": temperature,
        }

    kept: List[float] = []
    per_window: List[Dict[str, Any]] = []
    n_windows = 0
    n_kept = 0
    prev_active = False

    for offset in range(0, len(signal) - window_size + 1, step):
        chunk = signal[offset : offset + window_size]
        n_windows += 1

        feats = _window_features(chunk)
        score = significance_score(feats, w_mag=0.35, w_ano=0.4, w_ctx=0.15, w_urg=0.1)
        prob = gate_probability_with_hysteresis(
            score, threshold=threshold, temperature=temperature, last_activation=prev_active
        )

        keep = prob >= 0.5
        # Hysteresis state for the next window mirrors this decision.
        prev_active = keep
        if keep:
            kept.extend(chunk)
            n_kept += 1

        if return_windows and len(per_window) < max_windows:
            per_window.append(
                {
                    "start": offset,
                    "end": offset + window_size,
                    "significance": score,
                    "probability": prob,
                    "selected": keep,
                }
            )

    if not kept:
        kept = signal  # fall back to full signal

    meta: Dict[str, Any] = {
        "total_windows": n_windows,
        "selected_windows": n_kept,
        # NOTE: with overlapping windows (step < window_size) this ratio can exceed 1.
        "ratio": len(kept) / max(len(signal), 1),
        "threshold": threshold,
        "temperature": temperature,
    }
    if return_windows:
        meta["windows"] = per_window
    return kept, meta
|
app/ml/inference.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from app.core.config import settings
|
| 8 |
+
from app.ml.model import ECGClassifier
|
| 9 |
+
from app.ml.gating import gate_signal
|
| 10 |
+
|
| 11 |
+
# Process-wide model singleton; populated lazily by load_model().
_model: ECGClassifier | None = None
# Inference device: prefer CUDA when available, otherwise CPU.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_model() -> ECGClassifier:
    """Return the process-wide ECG model, loading it on first use.

    Builds an ECGClassifier, loads weights from settings.MODEL_WEIGHTS_PATH
    when that file exists, moves the model to the inference device and puts
    it in eval mode. The instance is cached in the module-level ``_model``.
    """
    global _model
    if _model is None:
        model = ECGClassifier(num_classes=2)
        weights_path: Optional[str] = settings.MODEL_WEIGHTS_PATH
        if weights_path and os.path.exists(weights_path):
            # weights_only=True restricts unpickling to tensors/containers,
            # preventing arbitrary-code execution from a tampered checkpoint.
            state = torch.load(weights_path, map_location=_device, weights_only=True)
            model.load_state_dict(state)
        model.to(_device)
        model.eval()
        _model = model
    return _model
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@torch.no_grad()
def infer_ecg(
    signal: List[float],
    original_len: Optional[int] = None,
    gating_meta: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Run the ECG classifier on a single (possibly gated) signal.

    Args:
        signal: signal values actually fed to the model.
        original_len: length of the raw signal before gating (defaults to
            len(signal) when not supplied).
        gating_meta: optional gating metadata echoed back under "gating".

    Returns:
        dict with "label", "score", "hr", "gated_ratio" and, when provided,
        "gating". (Return type is Dict[str, Any] because the gating value
        is itself a dict.)

    Raises:
        ValueError: if ``signal`` is empty.
    """
    # Validate before paying the (potentially lazy) model-load cost.
    if not signal:
        raise ValueError("Signal cannot be empty.")
    model = load_model()

    # Shape to (batch=1, channels=1, length) as the 1D CNN expects.
    tensor = torch.tensor(signal, dtype=torch.float32, device=_device).unsqueeze(0).unsqueeze(0)
    logits = model(tensor)
    probs = F.softmax(logits, dim=1)
    score = float(probs[0, 1].item())

    label = "arrhythmia" if score >= 0.5 else "normal"

    # Dummy heart rate estimation as placeholder
    hr_estimate = int(60 + 80 * score)

    original_len = original_len or len(signal)
    gating_ratio = len(signal) / max(original_len, 1)

    result: Dict[str, Any] = {
        "label": label,
        "score": score,
        "hr": hr_estimate,
        "gated_ratio": gating_ratio,
    }
    if gating_meta:
        result["gating"] = gating_meta
    return result
|
app/ml/model.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ECGClassifier(nn.Module):
    """Lightweight 1D convolutional network for ECG classification.

    Two Conv1d blocks (batch-norm + ReLU) feed a global average pool,
    followed by a single linear layer over the 32 pooled channels.
    """

    def __init__(self, num_classes: int = 2):
        super().__init__()
        # Attribute names and layer order fixed: checkpoints key on them.
        conv_stack = [
            nn.Conv1d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool1d(1),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(32, num_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, channels=1, length) input to (batch, num_classes) logits."""
        pooled = self.features(x)
        return self.classifier(pooled)
|
app/ml/train_ecg.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Sequence, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn, optim
|
| 6 |
+
from torch.utils.data import DataLoader, Dataset
|
| 7 |
+
from sqlalchemy import create_engine, select
|
| 8 |
+
from sqlalchemy.orm import Session, sessionmaker
|
| 9 |
+
|
| 10 |
+
from app.core.config import settings
|
| 11 |
+
from app.ml.model import ECGClassifier
|
| 12 |
+
from app.models.ecg import Base, ECGSample
|
| 13 |
+
from app.ml.ast_adapter import load_ast_trainer
|
| 14 |
+
|
| 15 |
+
# Training device: CUDA when present, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Class labels -> integer indices used by the classifier head.
LABEL_TO_IDX = {"normal": 0, "arrhythmia": 1}
# Optional Adaptive Sparse Training hooks; all three are (None, None, error)
# when the optional package is unavailable.
AST_TRAINER, AST_CONFIG, AST_ERROR = load_ast_trainer()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ECGDataset(Dataset):
    """In-memory dataset of fixed-length ECG tensors built from ECGSample rows.

    Signals are zero-padded or truncated to ``max_len`` and reshaped to
    (1, max_len); rows with empty signals are dropped. Labels map through
    LABEL_TO_IDX, with unknown labels falling back to the "normal" index.
    """

    def __init__(self, samples: Sequence[ECGSample], max_len: int):
        self.samples = samples
        self.max_len = max_len
        self.items: List[Tuple[torch.Tensor, int]] = []
        for sample in samples:
            raw = sample.signal or []
            if not raw:
                continue  # nothing to train on for this row
            tensor = torch.tensor(raw, dtype=torch.float32)
            n = tensor.numel()
            if n < max_len:
                tensor = torch.nn.functional.pad(tensor, (0, max_len - n))
            elif n > max_len:
                tensor = tensor[:max_len]
            label_idx = LABEL_TO_IDX.get(sample.label or "normal", 0)
            # (channels=1, length) layout expected by the 1D CNN
            self.items.append((tensor.unsqueeze(0), label_idx))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        return self.items[idx]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_samples() -> List[ECGSample]:
    """Fetch every ECGSample row from the configured database.

    Ensures the schema exists before querying. The engine is disposed in a
    ``finally`` block so connections are released even when the query fails
    (the original code leaked the engine on error).
    """
    engine = create_engine(settings.DATABASE_URL, future=True)
    try:
        Base.metadata.create_all(bind=engine)
        SessionLocal = sessionmaker(bind=engine)
        with SessionLocal() as session:
            rows = session.execute(select(ECGSample)).scalars().all()
    finally:
        engine.dispose()
    return list(rows)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def train_model(dataset: Dataset, epochs: int = 3, batch_size: int = 8, lr: float = 1e-3) -> ECGClassifier:
    """Train an ECGClassifier on *dataset* with Adam + cross-entropy.

    Prints the mean loss after each epoch and returns the model in eval mode.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    n_samples = max(len(dataset), 1)
    for epoch in range(epochs):
        total_loss = 0.0
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per-sample.
            total_loss += loss.item() * inputs.size(0)
        epoch_loss = total_loss / n_samples
        print(f"Epoch {epoch + 1}/{epochs} - loss: {epoch_loss:.4f}")

    model.eval()
    return model
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def save_weights(model: ECGClassifier) -> str:
    """Save the model's state_dict and return the path written.

    Uses settings.MODEL_WEIGHTS_PATH when set, otherwise a default
    checkpoint location.
    """
    path = settings.MODEL_WEIGHTS_PATH or "./checkpoints/ecg_classifier.pt"
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has a directory component.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), path)
    return path
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def build_dataloader(dataset: Dataset, batch_size: int = 8) -> DataLoader:
    """Create a shuffled DataLoader over *dataset* with the given batch size."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def generate_synthetic_samples() -> List[ECGSample]:
    """Build a two-row synthetic dataset (one normal, one arrhythmia).

    Used as a fallback when the database has no ECG samples; the rows are
    plain objects exposing ``signal`` and ``label`` and are never persisted.
    """
    import math

    class SyntheticSample:
        def __init__(self, signal: List[float], label: str):
            self.signal = signal
            self.label = label

    timeline = [i / 50.0 for i in range(256)]
    # Low-amplitude single tone vs. a larger two-tone mixture.
    low_amp = [0.05 * math.sin(2 * math.pi * f) for f in timeline]
    irregular = [
        0.3 * math.sin(2 * math.pi * f * 3) + 0.1 * math.sin(2 * math.pi * f * 7)
        for f in timeline
    ]
    return [
        SyntheticSample(low_amp, "normal"),
        SyntheticSample(irregular, "arrhythmia"),
    ]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def main() -> None:
    """
    Train the ECG classifier end-to-end and save its weights.

    Loads samples from the database (falling back to a tiny synthetic set),
    builds a fixed-length dataset, then trains either via the optional
    Adaptive Sparse Training path or the plain training loop.
    """
    samples = load_samples()
    if not samples:
        print("No ECG samples found in the database. Using synthetic samples for a minimal run.")
        samples = generate_synthetic_samples()

    # Pad/truncate every signal to the longest one present.
    max_len = max(len(sample.signal or []) for sample in samples)
    if max_len == 0:
        print("ECG samples contain empty signals; cannot train.")
        return

    dataset = ECGDataset(samples, max_len=max_len)
    if len(dataset) == 0:
        print("Dataset is empty after filtering; cannot train.")
        return

    train_loader = build_dataloader(dataset)
    model = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)

    if AST_TRAINER and AST_CONFIG:
        # Optional path: adaptive sparse training. The adapt_kp/adapt_ki
        # values are presumably controller gains for the library's
        # threshold adaptation — confirm against its documentation.
        cfg = AST_CONFIG(
            target_activation_rate=0.4,
            initial_threshold=2.5,
            adapt_kp=0.005,
            adapt_ki=0.0001,
            ema_alpha=0.1,
            energy_per_activation=1.0,
            energy_per_skip=0.01,
            use_amp=False,  # CPU-only by default here
            device=device.type,
        )
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        # reduction="none" keeps per-sample losses, which the sparse
        # trainer needs to gate individual samples.
        criterion = nn.CrossEntropyLoss(reduction="none")
        trainer = AST_TRAINER(model, train_loader, train_loader, cfg, optimizer=optimizer, criterion=criterion)
        trainer.train(epochs=3, warmup_epochs=0)
        print("Adaptive Sparse Training completed.")
    else:
        if AST_ERROR:
            print(f"Adaptive Sparse Training not active (optional): {AST_ERROR}")
        # Fallback: plain supervised training loop (builds its own loader).
        model = train_model(dataset)

    weights_path = save_weights(model)
    print(f"Training complete. Weights saved to {weights_path}")


if __name__ == "__main__":
    main()
|