File size: 5,845 Bytes
5ec9e9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
from typing import List, Sequence, Tuple

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, sessionmaker

from app.core.config import settings
from app.ml.model import ECGClassifier
from app.models.ecg import Base, ECGSample
from app.ml.ast_adapter import load_ast_trainer

# Run on the CUDA device when PyTorch reports one is available, otherwise on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Class-name -> class-index mapping; its size also fixes the classifier's output width.
LABEL_TO_IDX = dict(normal=0, arrhythmia=1)
# Optional Adaptive Sparse Training hooks, resolved once at import time.
# load_ast_trainer() yields three values; presumably (trainer class, config class,
# error) — main() only uses AST when the first two are truthy and otherwise
# reports the third. Confirm against app.ml.ast_adapter.
AST_TRAINER, AST_CONFIG, AST_ERROR = load_ast_trainer()


class ECGDataset(Dataset):
    """
    In-memory dataset built from ECGSample rows.

    Every row with a non-empty ``signal`` becomes one ``(tensor, label_idx)``
    item, precomputed once in ``__init__``:

    * the signal is converted to float32, right-padded with zeros or truncated
      to ``max_len`` values, then given a leading channel axis, giving shape
      ``(1, max_len)`` (assumes ``signal`` is a flat sequence of numbers —
      confirm against the ECGSample schema);
    * the label is looked up in LABEL_TO_IDX after trimming whitespace and
      lower-casing, so "Arrhythmia" and "arrhythmia" map to the same class.
      A missing/empty label is treated as "normal".

    Rows whose signal is missing or empty are skipped, so ``len(dataset)`` can
    be smaller than ``len(samples)``.
    """

    def __init__(self, samples: Sequence[ECGSample], max_len: int):
        self.samples = samples
        self.max_len = max_len
        self.items: List[Tuple[torch.Tensor, int]] = []
        for sample in samples:
            signal = sample.signal or []
            if not signal:
                continue
            self.items.append((self._fit_length(signal), self._label_index(sample.label)))

    def _fit_length(self, signal: Sequence[float]) -> torch.Tensor:
        """Return ``signal`` as a float32 tensor zero-padded/truncated to ``max_len``, with a channel axis."""
        tensor = torch.tensor(signal, dtype=torch.float32)
        length = tensor.numel()
        if length < self.max_len:
            tensor = torch.nn.functional.pad(tensor, (0, self.max_len - length))
        elif length > self.max_len:
            tensor = tensor[: self.max_len]
        # reshape to (channels=1, length)
        return tensor.unsqueeze(0)

    @staticmethod
    def _label_index(label: object) -> int:
        """Map a raw label to its class index, ignoring surrounding whitespace and case."""
        # Normalise so DB values such as "Arrhythmia" or " normal " hit the right
        # class instead of silently falling through to the default below.
        key = str(label or "normal").strip().lower()
        # NOTE(review): labels absent from LABEL_TO_IDX still fall back to 0
        # ("normal"), which silently mislabels unexpected classes — consider
        # rejecting or counting them instead.
        return LABEL_TO_IDX.get(key, 0)

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        return self.items[idx]


def load_samples() -> List[ECGSample]:
    """
    Load all ECGSample rows from the configured database.

    Missing tables are created first (``Base.metadata.create_all``), so a fresh
    database yields an empty list rather than a query error. The engine is
    disposed in a ``finally`` block, so its connection pool is released even if
    table creation or the query raises.

    Returns:
        Every ECGSample row. The session is closed before returning, so the
        objects are detached; only attributes loaded by the query are safe to read.
    """
    engine = create_engine(settings.DATABASE_URL, future=True)
    try:
        Base.metadata.create_all(bind=engine)
        session_factory = sessionmaker(bind=engine)
        with session_factory() as session:
            rows = session.execute(select(ECGSample)).scalars().all()
    finally:
        engine.dispose()
    return list(rows)


def train_model(dataset: Dataset, epochs: int = 3, batch_size: int = 8, lr: float = 1e-3) -> ECGClassifier:
    """
    Train a freshly initialised ECGClassifier on ``dataset``.

    Uses a shuffling DataLoader, Adam and mean cross-entropy loss, and prints the
    per-sample average loss after every epoch.

    Args:
        dataset: yields ``(input_tensor, label_idx)`` pairs.
        epochs: number of full passes over ``dataset``.
        batch_size: mini-batch size.
        lr: Adam learning rate.

    Returns:
        The trained model, left on ``device`` and switched to eval mode.
    """
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    net = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(net.parameters(), lr=lr)

    net.train()
    for epoch_no in range(1, epochs + 1):
        total = 0.0
        for inputs, targets in batches:
            inputs, targets = inputs.to(device), targets.to(device)
            opt.zero_grad()
            batch_loss = loss_fn(net(inputs), targets)
            batch_loss.backward()
            opt.step()
            # The loss is a batch mean; weight it by batch size so the epoch
            # figure is a per-sample mean even when the last batch is short.
            total += batch_loss.item() * inputs.size(0)
        # max(..., 1) guards the division for an empty dataset.
        mean_loss = total / max(len(dataset), 1)
        print(f"Epoch {epoch_no}/{epochs} - loss: {mean_loss:.4f}")

    net.eval()
    return net


def save_weights(model: ECGClassifier) -> str:
    """
    Save model weights to the configured path (or default).

    Writes ``model.state_dict()`` with ``torch.save`` to
    ``settings.MODEL_WEIGHTS_PATH``, or to ``./checkpoints/ecg_classifier.pt``
    when that setting is empty. The parent directory is created if needed.

    Returns:
        The path the weights were written to.
    """
    path = settings.MODEL_WEIGHTS_PATH or "./checkpoints/ecg_classifier.pt"
    parent = os.path.dirname(path)
    # A bare filename such as "model.pt" has no directory part; os.makedirs("")
    # raises FileNotFoundError, so only create a directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)
    return path


def build_dataloader(dataset: Dataset, batch_size: int = 8) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader that reshuffles its order every epoch."""
    return DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)


def generate_synthetic_samples() -> List[ECGSample]:
    """
    Build a two-item in-memory dataset (one "normal", one "arrhythmia") for smoke runs.

    Intended for when the database is empty; nothing is persisted. The returned
    objects are not real ECGSample rows. They only carry the ``signal`` and
    ``label`` attributes that ECGDataset and main() read.

    Both signals are 256 points long with time steps of 1/50: a low-amplitude
    sine for "normal", and a larger two-tone mix for "arrhythmia".
    """
    import math

    class SyntheticSample:
        def __init__(self, signal: List[float], label: str):
            self.signal = signal
            self.label = label

    times = [step / 50.0 for step in range(256)]
    two_pi = 2 * math.pi
    normal_wave = [0.05 * math.sin(two_pi * t) for t in times]
    irregular_wave = [
        0.3 * math.sin(two_pi * t * 3) + 0.1 * math.sin(two_pi * t * 7)
        for t in times
    ]
    specs = ((normal_wave, "normal"), (irregular_wave, "arrhythmia"))
    return [SyntheticSample(wave, name) for wave, name in specs]


def _train_with_ast(dataset: Dataset) -> ECGClassifier:
    """Train a fresh ECGClassifier on ``dataset`` using the optional Adaptive Sparse Training trainer."""
    train_loader = build_dataloader(dataset)
    model = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)
    cfg = AST_CONFIG(
        target_activation_rate=0.4,
        initial_threshold=2.5,
        adapt_kp=0.005,
        adapt_ki=0.0001,
        ema_alpha=0.1,
        energy_per_activation=1.0,
        energy_per_skip=0.01,
        use_amp=False,  # AMP stays off even when `device` is CUDA; enable deliberately if wanted
        device=device.type,
    )
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # reduction="none" hands the trainer per-sample losses rather than a batch
    # mean (presumably what AST needs to pick samples — confirm in ast_adapter).
    criterion = nn.CrossEntropyLoss(reduction="none")
    # NOTE(review): the same loader is passed in both loader positions
    # (presumably train and validation), so any validation metric the trainer
    # reports is measured on training data.
    trainer = AST_TRAINER(model, train_loader, train_loader, cfg, optimizer=optimizer, criterion=criterion)
    trainer.train(epochs=3, warmup_epochs=0)
    return model


def main() -> None:
    """
    Train an ECG classifier and save its weights.

    Steps:
    1. Load samples from the database, or use two synthetic samples if it is empty.
    2. Pad or truncate every signal to the longest one.
    3. Train with the AST trainer when load_ast_trainer provided both a trainer
       and a config; otherwise use the plain loop in ``train_model``.

    Returns early without saving if there is no usable signal data.
    """
    samples = load_samples()
    if not samples:
        print("No ECG samples found in the database. Using synthetic samples for a minimal run.")
        samples = generate_synthetic_samples()

    # Every item is padded/truncated to the longest signal so batches stack.
    max_len = max(len(sample.signal or []) for sample in samples)
    if max_len == 0:
        print("ECG samples contain empty signals; cannot train.")
        return

    dataset = ECGDataset(samples, max_len=max_len)
    if len(dataset) == 0:
        print("Dataset is empty after filtering; cannot train.")
        return

    # The loader and model are built only on the AST path; train_model creates
    # its own, so building them up front wasted a model allocation on `device`.
    if AST_TRAINER and AST_CONFIG:
        model = _train_with_ast(dataset)
        print("Adaptive Sparse Training completed.")
    else:
        if AST_ERROR:
            print(f"Adaptive Sparse Training not active (optional): {AST_ERROR}")
        model = train_model(dataset)

    weights_path = save_weights(model)
    print(f"Training complete. Weights saved to {weights_path}")


if __name__ == "__main__":
    # Train only when run as a script, not when this module is imported.
    main()