import math
import os
from typing import List, Sequence, Tuple

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, sessionmaker

from app.core.config import settings
from app.ml.model import ECGClassifier
from app.models.ecg import Base, ECGSample
from app.ml.ast_adapter import load_ast_trainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class name -> integer index used as the training target.
LABEL_TO_IDX = {"normal": 0, "arrhythmia": 1}

# Optional Adaptive Sparse Training hooks, resolved once at import time.
# NOTE(review): presumably (trainer class, config class, error message or None);
# confirm against app.ml.ast_adapter.load_ast_trainer.
AST_TRAINER, AST_CONFIG, AST_ERROR = load_ast_trainer()


class ECGDataset(Dataset):
    """
    In-memory dataset built from ECGSample rows.

    Every sample with a non-empty signal is converted up front into a
    ``(tensor, label_idx)`` pair. Signals shorter than ``max_len`` are
    zero-padded on the right and longer ones are truncated, then a leading
    channel dimension is added.

    Samples whose signal is empty or ``None`` are skipped, so ``len(dataset)``
    may be smaller than ``len(samples)``.
    """

    def __init__(self, samples: Sequence[ECGSample], max_len: int):
        """
        Args:
            samples: Objects exposing ``signal`` and ``label`` attributes.
            max_len: Target number of values per signal after pad/truncate.
        """
        self.samples = samples
        self.max_len = max_len
        self.items: List[Tuple[torch.Tensor, int]] = []

        for sample in samples:
            signal = sample.signal or []
            if not signal:
                continue

            # assumes signal is a flat list of numbers -- TODO confirm
            tensor = torch.tensor(signal, dtype=torch.float32)
            length = tensor.numel()
            if length < self.max_len:
                tensor = torch.nn.functional.pad(tensor, (0, self.max_len - length))
            elif length > self.max_len:
                tensor = tensor[: self.max_len]

            # reshape to (channels=1, length)
            tensor = tensor.unsqueeze(0)

            # A missing label, or one not present in LABEL_TO_IDX, falls back
            # to index 0 ("normal").
            label_idx = LABEL_TO_IDX.get(sample.label or "normal", 0)
            self.items.append((tensor, label_idx))

    def __len__(self) -> int:
        """Return the number of usable (non-empty) samples."""
        return len(self.items)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return the precomputed ``(tensor, label_idx)`` pair at ``idx``."""
        return self.items[idx]


def load_samples() -> List[ECGSample]:
    """
    Load all ECGSample rows from the configured database.

    Ensures tables exist before querying. The engine is created for this call
    only and is always disposed, even when table creation or the query fails.

    Returns:
        A list of every ECGSample row (empty when the table has no rows).
    """
    engine = create_engine(settings.DATABASE_URL, future=True)
    try:
        Base.metadata.create_all(bind=engine)
        session_factory = sessionmaker(bind=engine)
        with session_factory() as session:
            rows = session.execute(select(ECGSample)).scalars().all()
    finally:
        # Release pooled connections whether or not the query succeeded.
        engine.dispose()
    return list(rows)


def build_dataloader(dataset: Dataset, batch_size: int = 8) -> DataLoader:
    """Wrap ``dataset`` in a shuffling DataLoader with the given batch size."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train_model(dataset: Dataset, epochs: int = 3, batch_size: int = 8, lr: float = 1e-3) -> ECGClassifier:
    """
    Train a fresh ECGClassifier on ``dataset`` with Adam and cross-entropy.

    Args:
        dataset: Dataset yielding ``(signal_tensor, label_idx)`` pairs.
        epochs: Number of full passes over the dataset.
        batch_size: Mini-batch size used by the DataLoader.
        lr: Learning rate for the Adam optimizer.

    Returns:
        The trained model, moved to ``device`` and switched to eval mode.
    """
    loader = build_dataloader(dataset, batch_size=batch_size)
    model = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            logits = model(batch_x)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; re-weight by batch size so the
            # epoch figure is a per-sample average even with a short last batch.
            running_loss += loss.item() * batch_x.size(0)

        epoch_loss = running_loss / max(len(dataset), 1)
        print(f"Epoch {epoch + 1}/{epochs} - loss: {epoch_loss:.4f}")

    model.eval()
    return model


def save_weights(model: ECGClassifier) -> str:
    """
    Save model weights to the configured path (or default).

    The parent directory is created when the path has one; a bare filename
    is written to the current working directory.

    Returns:
        The path the state dict was written to.
    """
    path = settings.MODEL_WEIGHTS_PATH or "./checkpoints/ecg_classifier.pt"
    directory = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), path)
    return path


def generate_synthetic_samples() -> List[ECGSample]:
    """
    Create a tiny synthetic dataset if the DB is empty (not persisted).

    Returns two 256-point sine-based signals, one labelled "normal" and one
    labelled "arrhythmia". The returned objects are lightweight stand-ins that
    only expose ``signal`` and ``label``; they are not real ECGSample rows.
    """

    class SyntheticSample:
        def __init__(self, signal: List[float], label: str):
            self.signal = signal
            self.label = label

    times = [i / 50.0 for i in range(256)]
    normal = [0.05 * math.sin(2 * math.pi * t) for t in times]
    arrhythmia = [
        0.3 * math.sin(2 * math.pi * t * 3) + 0.1 * math.sin(2 * math.pi * t * 7)
        for t in times
    ]
    return [
        SyntheticSample(normal, "normal"),
        SyntheticSample(arrhythmia, "arrhythmia"),
    ]


def _train_with_ast(dataset: Dataset) -> ECGClassifier:
    """Train a fresh ECGClassifier through the optional AST trainer."""
    train_loader = build_dataloader(dataset)
    model = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)

    cfg = AST_CONFIG(
        target_activation_rate=0.4,
        initial_threshold=2.5,
        adapt_kp=0.005,
        adapt_ki=0.0001,
        ema_alpha=0.1,
        energy_per_activation=1.0,
        energy_per_skip=0.01,
        use_amp=False,  # mixed precision is left off regardless of device
        device=device.type,
    )
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # NOTE(review): reduction="none" yields per-sample losses; presumably the
    # AST trainer needs them to decide which samples to skip -- confirm.
    criterion = nn.CrossEntropyLoss(reduction="none")

    # The same loader is passed twice.
    # NOTE(review): presumably the train and validation loaders -- confirm
    # against the trainer's signature.
    trainer = AST_TRAINER(model, train_loader, train_loader, cfg, optimizer=optimizer, criterion=criterion)
    trainer.train(epochs=3, warmup_epochs=0)
    print("Adaptive Sparse Training completed.")

    # Match train_model(), which also hands back an eval-mode model.
    model.eval()
    return model


def main() -> None:
    """
    Train an ECG classifier and save its weights.

    Loads samples from the database, falling back to synthetic samples when
    none exist. Trains with the AST trainer when it is available and with the
    plain training loop otherwise, then writes the weights to disk. Returns
    early, without saving, when no sample has a usable signal.
    """
    samples = load_samples()
    if not samples:
        print("No ECG samples found in the database. Using synthetic samples for a minimal run.")
        samples = generate_synthetic_samples()

    # Pad/truncate every signal to the longest one present.
    max_len = max(len(sample.signal or []) for sample in samples)
    if max_len == 0:
        print("ECG samples contain empty signals; cannot train.")
        return

    dataset = ECGDataset(samples, max_len=max_len)
    if len(dataset) == 0:
        print("Dataset is empty after filtering; cannot train.")
        return

    if AST_TRAINER and AST_CONFIG:
        model = _train_with_ast(dataset)
    else:
        if AST_ERROR:
            print(f"Adaptive Sparse Training not active (optional): {AST_ERROR}")
        model = train_model(dataset)

    weights_path = save_weights(model)
    print(f"Training complete. Weights saved to {weights_path}")


if __name__ == "__main__":
    main()