mgbam committed on
Commit
5ec9e9d
·
verified ·
1 Parent(s): cbecd45

Upload 5 files

Browse files
app/ml/ast_adapter.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type
2
+
3
+ import torch
4
+
5
+
6
+ def _apply_torch_amp_shims() -> None:
7
+ """
8
+ AST expects torch.amp.GradScaler/autocast (torch 2.3+); shim from torch.cuda.amp for 2.2.
9
+ """
10
+ if not hasattr(torch.amp, "GradScaler") and hasattr(torch.cuda, "amp"):
11
+ torch.amp.GradScaler = torch.cuda.amp.GradScaler # type: ignore[attr-defined]
12
+ if not hasattr(torch.amp, "autocast") and hasattr(torch.cuda, "amp"):
13
+ torch.amp.autocast = torch.cuda.amp.autocast # type: ignore[attr-defined]
14
+
15
+
16
def load_ast_trainer() -> Tuple[Optional[Type[object]], Optional[Type[object]], Optional[Exception]]:
    """
    Attempt to import the optional adaptive-sparse-training package.

    Returns:
        ``(trainer_cls, config_cls, error)`` — the two classes plus ``None``
        on success, or ``(None, None, exception)`` on any failure.
    """
    try:
        _apply_torch_amp_shims()
        from adaptive_sparse_training import AdaptiveSparseTrainer  # type: ignore
        from adaptive_sparse_training.config import ASTConfig  # type: ignore
    except Exception as exc:  # pragma: no cover - optional dependency
        return None, None, exc
    return AdaptiveSparseTrainer, ASTConfig, None
app/ml/gating.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Tuple
2
+
3
+ from sundew.gating import gate_probability_with_hysteresis, significance_score
4
+
5
+
6
+ def _clamp(value: float, low: float = 0.0, high: float = 1.0) -> float:
7
+ return max(low, min(high, value))
8
+
9
+
10
+ def _window_features(window: List[float]) -> Dict[str, float]:
11
+ """
12
+ Compute basic features for a window to feed Sundew's significance score.
13
+ """
14
+ if not window:
15
+ return {"magnitude": 0.0, "anomaly_score": 0.0, "context_relevance": 0.0, "urgency": 0.0}
16
+
17
+ length = float(len(window))
18
+ mean = sum(window) / length
19
+ mean_abs = sum(abs(x) for x in window) / length
20
+ max_abs = max(abs(x) for x in window)
21
+ variance = sum((x - mean) ** 2 for x in window) / length
22
+
23
+ magnitude = _clamp(max_abs * 10.0, 0.0, 10.0) * 10.0 # 0..100 scale
24
+ anomaly_score = _clamp(variance / (variance + 1.0))
25
+ context_relevance = _clamp(mean_abs)
26
+ urgency = _clamp(mean_abs * 0.5)
27
+
28
+ return {
29
+ "magnitude": magnitude,
30
+ "anomaly_score": anomaly_score,
31
+ "context_relevance": context_relevance,
32
+ "urgency": urgency,
33
+ }
34
+
35
+
36
def gate_signal(
    signal: List[float],
    window_size: int = 128,
    step: int = 64,
    threshold: float = 0.55,
    temperature: float = 0.15,
    return_windows: bool = False,
    max_windows: int = 200,
) -> Tuple[List[float], Dict[str, Any]]:
    """
    Run Sundew gating over sliding windows of *signal* to trim workload.

    Args:
        signal: raw signal values.
        window_size: length of each sliding window.
        step: stride between consecutive windows.
        threshold: gate threshold fed to the hysteresis gate.
        temperature: gate temperature (0 = hard gate).
        return_windows: attach per-window metadata to the result.
        max_windows: cap on windows included in that metadata (preview use).

    Returns:
        ``(gated_signal, meta)`` where gated_signal is the concatenation of
        selected windows (the original signal when nothing was selected),
        and meta reports counts, ratio, threshold/temperature and,
        optionally, per-window details.
    """
    # Signals shorter than one window pass through untouched.
    if len(signal) < window_size:
        return signal, {
            "total_windows": 0,
            "selected_windows": 0,
            "ratio": 1.0,
            "threshold": threshold,
            "temperature": temperature,
        }

    gated: List[float] = []
    per_window: List[Dict[str, Any]] = []
    total = 0
    kept = 0
    previously_active = False

    for offset in range(0, len(signal) - window_size + 1, step):
        chunk = signal[offset : offset + window_size]
        total += 1

        score = significance_score(
            _window_features(chunk), w_mag=0.35, w_ano=0.4, w_ctx=0.15, w_urg=0.1
        )
        probability = gate_probability_with_hysteresis(
            score,
            threshold=threshold,
            temperature=temperature,
            last_activation=previously_active,
        )
        keep = probability >= 0.5

        if keep:
            gated.extend(chunk)
            kept += 1
        # Hysteresis: next window sees whether this one fired.
        previously_active = keep

        if return_windows and len(per_window) < max_windows:
            per_window.append(
                {
                    "start": offset,
                    "end": offset + window_size,
                    "significance": score,
                    "probability": probability,
                    "selected": keep,
                }
            )

    if not gated:
        gated = signal  # fall back to full signal

    meta: Dict[str, Any] = {
        "total_windows": total,
        "selected_windows": kept,
        "ratio": len(gated) / max(len(signal), 1),
        "threshold": threshold,
        "temperature": temperature,
    }
    if return_windows:
        meta["windows"] = per_window
    return gated, meta
app/ml/inference.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+ from app.core.config import settings
8
+ from app.ml.model import ECGClassifier
9
+ from app.ml.gating import gate_signal
10
+
11
# Process-wide model cache (populated lazily by load_model) and the device
# inference runs on; CUDA is preferred when available.
_model: ECGClassifier | None = None
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+
15
def load_model() -> ECGClassifier:
    """
    Lazily construct and cache the ECG classifier on the module device.

    If ``settings.MODEL_WEIGHTS_PATH`` points at an existing file, the
    state dict is loaded from it; otherwise the model keeps its random
    initialization (useful for smoke tests).

    Returns:
        The cached, eval-mode ``ECGClassifier`` instance.
    """
    global _model
    if _model is None:
        model = ECGClassifier(num_classes=2)
        weights_path: Optional[str] = settings.MODEL_WEIGHTS_PATH
        if weights_path and os.path.exists(weights_path):
            # weights_only=True restricts unpickling to tensors/containers,
            # preventing arbitrary code execution from a tampered checkpoint.
            state = torch.load(weights_path, map_location=_device, weights_only=True)
            model.load_state_dict(state)
        model.to(_device)
        model.eval()
        _model = model
    return _model
31
+
32
+
33
@torch.no_grad()
def infer_ecg(
    signal: List[float],
    original_len: Optional[int] = None,
    gating_meta: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Run model inference on a single ECG signal.

    Args:
        signal: raw (possibly gated) signal values; must be non-empty.
        original_len: pre-gating length, used to report ``gated_ratio``.
        gating_meta: optional gating metadata passed through in the result.

    Returns:
        Dict with ``label``, ``score``, ``hr``, ``gated_ratio`` and, when
        gating_meta is provided, a nested ``gating`` entry. (Annotated as
        ``Dict[str, Any]`` because the gating entry is itself a dict.)

    Raises:
        ValueError: if *signal* is empty.
    """
    # Validate input before paying for model construction/loading.
    if not signal:
        raise ValueError("Signal cannot be empty.")
    model = load_model()

    # (batch=1, channels=1, length) layout expected by the Conv1d stack.
    tensor = torch.tensor(signal, dtype=torch.float32, device=_device).unsqueeze(0).unsqueeze(0)
    logits = model(tensor)
    probs = F.softmax(logits, dim=1)
    # Probability of the positive (arrhythmia) class.
    score = float(probs[0, 1].item())

    label = "arrhythmia" if score >= 0.5 else "normal"

    # Dummy heart rate estimation as placeholder
    hr_estimate = int(60 + 80 * score)

    original_len = original_len or len(signal)
    gating_ratio = len(signal) / max(original_len, 1)

    result: Dict[str, Any] = {
        "label": label,
        "score": score,
        "hr": hr_estimate,
        "gated_ratio": gating_ratio,
    }
    if gating_meta:
        result["gating"] = gating_meta
    return result
app/ml/model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class ECGClassifier(nn.Module):
    """
    Minimal 1D-CNN ECG classifier.

    Two Conv1d/BatchNorm/ReLU stages (with one max-pool in between)
    feed global average pooling and a single linear layer that emits
    class logits.
    """

    def __init__(self, num_classes: int = 2):
        super().__init__()
        # NOTE: layer order/shapes must stay stable so saved state dicts
        # keep loading.
        self.features = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool1d(1),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, 1, length) signal to (batch, num_classes) logits."""
        return self.classifier(self.features(x))
app/ml/train_ecg.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Sequence, Tuple
3
+
4
+ import torch
5
+ from torch import nn, optim
6
+ from torch.utils.data import DataLoader, Dataset
7
+ from sqlalchemy import create_engine, select
8
+ from sqlalchemy.orm import Session, sessionmaker
9
+
10
+ from app.core.config import settings
11
+ from app.ml.model import ECGClassifier
12
+ from app.models.ecg import Base, ECGSample
13
+ from app.ml.ast_adapter import load_ast_trainer
14
+
15
# Training device, label vocabulary and the optional AST integration,
# all resolved once at import time. AST_* are (None, None, error) when the
# adaptive-sparse-training package is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LABEL_TO_IDX = {"normal": 0, "arrhythmia": 1}
AST_TRAINER, AST_CONFIG, AST_ERROR = load_ast_trainer()
18
+
19
+
20
class ECGDataset(Dataset):
    """
    In-memory dataset materialized from ECGSample rows.

    Each non-empty signal is padded or truncated to ``max_len``, reshaped
    to ``(1, max_len)`` and paired with its integer label; rows with empty
    signals are dropped.
    """

    def __init__(self, samples: Sequence[ECGSample], max_len: int):
        self.samples = samples
        self.max_len = max_len
        self.items: List[Tuple[torch.Tensor, int]] = []
        for sample in samples:
            values = sample.signal or []
            if not values:
                continue
            tensor = torch.tensor(values, dtype=torch.float32)
            deficit = self.max_len - tensor.numel()
            if deficit > 0:
                tensor = torch.nn.functional.pad(tensor, (0, deficit))
            elif deficit < 0:
                tensor = tensor[: self.max_len]
            # (channels=1, length) layout for Conv1d; unknown labels map to 0.
            label_idx = LABEL_TO_IDX.get(sample.label or "normal", 0)
            self.items.append((tensor.unsqueeze(0), label_idx))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        return self.items[idx]
49
+
50
+
51
def load_samples() -> List[ECGSample]:
    """
    Fetch every ECGSample row from the configured database.

    Creates the tables first so a fresh database yields an empty list
    instead of an error.

    Returns:
        All rows as a list (possibly empty).
    """
    engine = create_engine(settings.DATABASE_URL, future=True)
    SessionLocal = sessionmaker(bind=engine)
    Base.metadata.create_all(bind=engine)

    try:
        with SessionLocal() as session:
            result = session.execute(select(ECGSample))
            rows = result.scalars().all()
    finally:
        # Dispose even when the query raises so the pool's connections
        # are not leaked.
        engine.dispose()
    return list(rows)
66
+
67
+
68
def train_model(dataset: Dataset, epochs: int = 3, batch_size: int = 8, lr: float = 1e-3) -> ECGClassifier:
    """
    Train the baseline CNN on *dataset* with Adam and cross-entropy loss.

    Prints the mean per-sample loss each epoch and returns the model
    switched to eval mode.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch_idx in range(1, epochs + 1):
        loss_sum = 0.0
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is per-sample.
            loss_sum += loss.item() * inputs.size(0)
        print(f"Epoch {epoch_idx}/{epochs} - loss: {loss_sum / max(len(dataset), 1):.4f}")

    model.eval()
    return model
91
+
92
+
93
def save_weights(model: ECGClassifier) -> str:
    """
    Persist the model's state dict to the configured weights path.

    Falls back to ``./checkpoints/ecg_classifier.pt`` when
    ``settings.MODEL_WEIGHTS_PATH`` is unset.

    Returns:
        The path the weights were written to.
    """
    path = settings.MODEL_WEIGHTS_PATH or "./checkpoints/ecg_classifier.pt"
    directory = os.path.dirname(path)
    if directory:
        # os.makedirs("") raises FileNotFoundError, so only create a
        # directory when the path actually contains one.
        os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), path)
    return path
101
+
102
+
103
def build_dataloader(dataset: Dataset, batch_size: int = 8) -> DataLoader:
    """Wrap *dataset* in a shuffling DataLoader."""
    return DataLoader(dataset, shuffle=True, batch_size=batch_size)
105
+
106
+
107
def generate_synthetic_samples() -> List[ECGSample]:
    """
    Build a tiny two-sample synthetic dataset (not persisted) so a
    minimal training run works against an empty database.
    """
    import math

    class SyntheticSample:
        # Duck-types ECGSample: only .signal and .label are read downstream.
        def __init__(self, signal: List[float], label: str):
            self.signal = signal
            self.label = label

    times = [i / 50.0 for i in range(256)]
    low_amp = [0.05 * math.sin(2 * math.pi * ts) for ts in times]
    mixed = [0.3 * math.sin(2 * math.pi * ts * 3) + 0.1 * math.sin(2 * math.pi * ts * 7) for ts in times]
    return [
        SyntheticSample(low_amp, "normal"),
        SyntheticSample(mixed, "arrhythmia"),
    ]
125
+
126
+
127
def main() -> None:
    """
    Train the ECG classifier end-to-end and save its weights.

    Loads samples from the DB (falling back to a tiny synthetic set),
    builds the dataset, trains either via Adaptive Sparse Training when
    the optional package imported successfully or via the plain loop
    otherwise, then writes the resulting weights to disk.
    """
    samples = load_samples()
    if not samples:
        print("No ECG samples found in the database. Using synthetic samples for a minimal run.")
        samples = generate_synthetic_samples()

    # Pad/truncate everything to the longest signal present.
    max_len = max(len(sample.signal or []) for sample in samples)
    if max_len == 0:
        print("ECG samples contain empty signals; cannot train.")
        return

    dataset = ECGDataset(samples, max_len=max_len)
    if len(dataset) == 0:
        print("Dataset is empty after filtering; cannot train.")
        return

    train_loader = build_dataloader(dataset)
    model = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)

    if AST_TRAINER and AST_CONFIG:
        # NOTE(review): the config values below are tuning choices for the
        # optional adaptive-sparse-training package — confirm against its docs.
        cfg = AST_CONFIG(
            target_activation_rate=0.4,
            initial_threshold=2.5,
            adapt_kp=0.005,
            adapt_ki=0.0001,
            ema_alpha=0.1,
            energy_per_activation=1.0,
            energy_per_skip=0.01,
            use_amp=False,  # CPU-only by default here
            device=device.type,
        )
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        # reduction="none" keeps per-sample losses — presumably what the AST
        # trainer expects; verify against the package API.
        criterion = nn.CrossEntropyLoss(reduction="none")
        # NOTE(review): train_loader is reused as the validation loader here.
        trainer = AST_TRAINER(model, train_loader, train_loader, cfg, optimizer=optimizer, criterion=criterion)
        trainer.train(epochs=3, warmup_epochs=0)
        print("Adaptive Sparse Training completed.")
    else:
        if AST_ERROR:
            print(f"Adaptive Sparse Training not active (optional): {AST_ERROR}")
        model = train_model(dataset)

    weights_path = save_weights(model)
    print(f"Training complete. Weights saved to {weights_path}")
170
+
171
+
172
if __name__ == "__main__":
    # Script entry point: allows direct execution / `python -m` invocation.
    main()