Spaces:
Sleeping
Sleeping
| import os | |
| from typing import List, Sequence, Tuple | |
| import torch | |
| from torch import nn, optim | |
| from torch.utils.data import DataLoader, Dataset | |
| from sqlalchemy import create_engine, select | |
| from sqlalchemy.orm import Session, sessionmaker | |
| from app.core.config import settings | |
| from app.ml.model import ECGClassifier | |
| from app.models.ecg import Base, ECGSample | |
| from app.ml.ast_adapter import load_ast_trainer | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Maps a sample's string label to the integer class index used as the training target.
LABEL_TO_IDX = {"normal": 0, "arrhythmia": 1}
# Unpacked result of load_ast_trainer(), evaluated once at import time. main()
# takes the Adaptive Sparse Training path only when both AST_TRAINER and
# AST_CONFIG are truthy; otherwise it prints AST_ERROR (if set) and trains
# with train_model() instead.
AST_TRAINER, AST_CONFIG, AST_ERROR = load_ast_trainer()
class ECGDataset(Dataset):
    """
    Map-style dataset that eagerly converts ECGSample rows into tensors.

    Rows whose ``signal`` is empty or missing are skipped. Every remaining
    signal becomes a float32 tensor that is zero-padded on the right or
    truncated to ``max_len`` values and then given a leading axis with
    ``unsqueeze(0)`` (so a flat signal ends up with shape ``(1, max_len)``).
    Labels are looked up in ``LABEL_TO_IDX``; a missing or unrecognised label
    falls back to index 0.
    """

    def __init__(self, samples: Sequence[ECGSample], max_len: int):
        self.samples = samples
        self.max_len = max_len
        self.items: List[Tuple[torch.Tensor, int]] = []
        for row in samples:
            values = row.signal
            if not values:
                # Nothing to learn from an empty/None signal.
                continue
            self.items.append((self._fit(values), self._label_index(row.label)))

    def _fit(self, values) -> torch.Tensor:
        """Convert raw values to a float32 tensor of ``max_len`` elements plus a leading axis."""
        wave = torch.tensor(values, dtype=torch.float32)
        shortfall = self.max_len - wave.numel()
        if shortfall > 0:
            # Too short: append zeros on the right.
            wave = torch.nn.functional.pad(wave, (0, shortfall))
        elif shortfall < 0:
            # Too long: keep only the first max_len values.
            wave = wave[: self.max_len]
        # Add the leading (channel) axis.
        return wave.unsqueeze(0)

    @staticmethod
    def _label_index(label) -> int:
        """Translate a label into its class index, treating a falsy label as "normal"."""
        return LABEL_TO_IDX.get(label or "normal", 0)

    def __len__(self) -> int:
        """Number of usable (non-empty) samples."""
        return len(self.items)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return the ``(tensor, label_index)`` pair stored at ``idx``."""
        return self.items[idx]
def load_samples() -> List[ECGSample]:
    """
    Load all ECGSample rows from the database at ``settings.DATABASE_URL``.

    Any tables declared on ``Base.metadata`` that do not exist yet are created
    before querying. A dedicated engine is created for this call and is always
    disposed before returning, including when table creation or the query
    raises, so its connection pool is not leaked on failure.

    Returns:
        A list of the ECGSample objects returned by the query (empty when the
        table has no rows).
    """
    engine = create_engine(settings.DATABASE_URL, future=True)
    try:
        Base.metadata.create_all(bind=engine)
        SessionLocal = sessionmaker(bind=engine)
        with SessionLocal() as session:
            # Fetch everything while the session is still open.
            rows = session.execute(select(ECGSample)).scalars().all()
    finally:
        # Release pooled connections whether or not the query succeeded.
        engine.dispose()
    return list(rows)
def train_model(dataset: Dataset, epochs: int = 3, batch_size: int = 8, lr: float = 1e-3) -> ECGClassifier:
    """
    Train a fresh ECGClassifier on ``dataset`` with Adam and cross-entropy loss.

    Args:
        dataset: Dataset yielding ``(input, target)`` pairs.
        epochs: Number of full passes over the data.
        batch_size: Mini-batch size for the shuffling DataLoader.
        lr: Learning rate passed to Adam.

    Returns:
        The trained model, moved to the module-level ``device`` and switched
        to eval mode.
    """
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    net = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(net.parameters(), lr=lr)

    net.train()
    for epoch_no in range(1, epochs + 1):
        loss_sum = 0.0
        for inputs, targets in batches:
            inputs, targets = inputs.to(device), targets.to(device)
            adam.zero_grad()
            batch_loss = loss_fn(net(inputs), targets)
            batch_loss.backward()
            adam.step()
            # Weight the batch mean by its size so the epoch figure is a per-sample mean.
            loss_sum += batch_loss.item() * inputs.size(0)
        # max(..., 1) avoids dividing by zero for an empty dataset.
        mean_loss = loss_sum / max(len(dataset), 1)
        print(f"Epoch {epoch_no}/{epochs} - loss: {mean_loss:.4f}")

    net.eval()
    return net
def save_weights(model: ECGClassifier) -> str:
    """
    Save the model's ``state_dict`` to the configured weights path.

    Uses ``settings.MODEL_WEIGHTS_PATH`` when it is set, otherwise
    ``./checkpoints/ecg_classifier.pt``. The parent directory is created when
    missing.

    Args:
        model: Model whose parameters should be written.

    Returns:
        The path the weights were written to.
    """
    path = settings.MODEL_WEIGHTS_PATH or "./checkpoints/ecg_classifier.pt"
    parent_dir = os.path.dirname(path)
    # A bare filename such as "model.pt" has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError; the current directory
    # already exists, so there is nothing to create in that case.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)
    return path
def build_dataloader(dataset: Dataset, batch_size: int = 8) -> DataLoader:
    """
    Wrap ``dataset`` in a DataLoader that shuffles on every pass.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of items per batch.

    Returns:
        A shuffling DataLoader over ``dataset``.
    """
    loader_options = {"batch_size": batch_size, "shuffle": True}
    return DataLoader(dataset, **loader_options)
def generate_synthetic_samples() -> List[ECGSample]:
    """
    Build two in-memory synthetic samples for use when the database is empty.

    Both signals have 256 points evaluated at ``i / 50.0`` for ``i`` in
    ``range(256)``: a low-amplitude single sine labelled "normal" and a sum of
    two faster sines labelled "arrhythmia". The returned objects are plain
    stand-ins exposing ``signal`` and ``label`` (not real ECGSample rows), and
    nothing is written to the database.
    """
    import math

    # Lightweight stand-in that only carries the two attributes the dataset reads.
    class SyntheticSample:
        def __init__(self, signal: List[float], label: str):
            self.signal = signal
            self.label = label

    def calm_wave(x: float) -> float:
        return 0.05 * math.sin(2 * math.pi * x)

    def erratic_wave(x: float) -> float:
        return 0.3 * math.sin(2 * math.pi * x * 3) + 0.1 * math.sin(2 * math.pi * x * 7)

    timeline = [step / 50.0 for step in range(256)]
    waveforms = (("normal", calm_wave), ("arrhythmia", erratic_wave))
    return [
        SyntheticSample([wave(x) for x in timeline], label)
        for label, wave in waveforms
    ]
def main() -> None:
    """
    Train an ECG classifier end to end and save its weights.

    Steps: load samples from the database (substituting synthetic samples when
    none exist), build an ECGDataset padded/truncated to the longest signal,
    train either with the optional Adaptive Sparse Training trainer (when both
    AST_TRAINER and AST_CONFIG are available) or with ``train_model``, then
    write the weights via ``save_weights``. Returns early, without saving, if
    there is nothing usable to train on.
    """
    records = load_samples()
    if not records:
        print("No ECG samples found in the database. Using synthetic samples for a minimal run.")
        records = generate_synthetic_samples()

    longest = max(len(record.signal or []) for record in records)
    if longest == 0:
        print("ECG samples contain empty signals; cannot train.")
        return

    dataset = ECGDataset(records, max_len=longest)
    if len(dataset) == 0:
        print("Dataset is empty after filtering; cannot train.")
        return

    train_loader = build_dataloader(dataset)
    model = ECGClassifier(num_classes=len(LABEL_TO_IDX)).to(device)

    if not (AST_TRAINER and AST_CONFIG):
        # Standard training path; report why AST is unavailable when known.
        if AST_ERROR:
            print(f"Adaptive Sparse Training not active (optional): {AST_ERROR}")
        model = train_model(dataset)
    else:
        ast_options = {
            "target_activation_rate": 0.4,
            "initial_threshold": 2.5,
            "adapt_kp": 0.005,
            "adapt_ki": 0.0001,
            "ema_alpha": 0.1,
            "energy_per_activation": 1.0,
            "energy_per_skip": 0.01,
            "use_amp": False,  # CPU-only by default here
            "device": device.type,
        }
        cfg = AST_CONFIG(**ast_options)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        # reduction="none" yields one loss value per sample instead of a batch mean.
        criterion = nn.CrossEntropyLoss(reduction="none")
        # The same loader is supplied for both loader arguments of the trainer.
        trainer = AST_TRAINER(
            model,
            train_loader,
            train_loader,
            cfg,
            optimizer=optimizer,
            criterion=criterion,
        )
        trainer.train(epochs=3, warmup_epochs=0)
        print("Adaptive Sparse Training completed.")

    weights_path = save_weights(model)
    print(f"Training complete. Weights saved to {weights_path}")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()