from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from einops import rearrange
from itertools import chain
from torchmetrics import (
    Accuracy,
    AUROC,
    AveragePrecision,
    CohenKappa,
    ConfusionMatrix,
    F1Score,
    MetricCollection,
    Precision,
    Recall,
)

from osf.models.balanced_losses import FocalLoss, BalancedSoftmax


def _create_pred_metrics(num_classes: int) -> MetricCollection:
    """Create metrics that take preds (class indices) as input."""
    metrics = {
        "acc": Accuracy(task="multiclass", num_classes=num_classes, average="micro"),
        "f1": F1Score(task="multiclass", num_classes=num_classes, average="macro"),
        "f1_w": F1Score(task="multiclass", num_classes=num_classes, average="weighted"),
        "rec_m": Recall(task="multiclass", num_classes=num_classes, average="macro"),
        "kappa": CohenKappa(task="multiclass", num_classes=num_classes, weights="quadratic"),
    }
    return MetricCollection(metrics)


def _create_prob_metrics(num_classes: int) -> MetricCollection:
    """Create metrics that take probs (probabilities) as input."""
    metrics = {
        "auc": AUROC(task="multiclass", num_classes=num_classes, average="macro"),
        "auprc": AveragePrecision(task="multiclass", num_classes=num_classes, average="macro"),
    }
    return MetricCollection(metrics)


def _create_perclass_pred_metrics(num_classes: int) -> MetricCollection:
    """Create per-class metrics that take preds as input."""
    metrics = {
        "acc_c": Accuracy(task="multiclass", num_classes=num_classes, average=None),
        "prec_c": Precision(task="multiclass", num_classes=num_classes, average=None),
        "rec_c": Recall(task="multiclass", num_classes=num_classes, average=None),
        "f1_c": F1Score(task="multiclass", num_classes=num_classes, average=None),
        "cm": ConfusionMatrix(task="multiclass", num_classes=num_classes, normalize=None),
    }
    return MetricCollection(metrics)


def _create_perclass_prob_metrics(num_classes: int) -> MetricCollection:
    """Create per-class metrics that take probs as input."""
    metrics = {
        "auc_c": AUROC(task="multiclass", num_classes=num_classes, average=None),
        "auprc_c": AveragePrecision(task="multiclass", num_classes=num_classes, average=None),
    }
    return MetricCollection(metrics)


class SSLFineTuner(LightningModule):
    """Linear-probe / full fine-tuning classifier on top of SSL backbone(s).

    Supports a single backbone selected by ``use_which_backbone``
    ('ecg' / 'resp' / 'elect' / 'all') or mean-fusion of the per-modality
    backbones ('fusion').  When ``finetune_backbone`` is False the backbone
    is frozen (eval mode, no grads) and only the linear head is trained.
    """

    def __init__(self,
                 backbones,
                 use_which_backbone,
                 config=None,
                 in_features: int = 256,
                 num_classes: int = 2,
                 epochs: int = 10,
                 dropout: float = 0.0,
                 lr: float = 1e-3,
                 weight_decay: float = 1e-4,
                 final_lr: float = 1e-5,
                 use_channel_bank: bool = True,
                 loss_type: str = "ce",
                 class_distribution: Optional[torch.Tensor] = None,
                 focal_gamma: float = 2.0,
                 focal_alpha: Optional[float | torch.Tensor] = None,
                 use_mean_pool: bool = False,
                 total_training_steps: Optional[int] = None,
                 finetune_backbone: bool = False,
                 *args, **kwargs) -> None:
        """
        Args:
            backbones: dict or nn.ModuleDict of pretrained encoders, keyed by
                modality ('ecg', 'resp', 'elect', ...).
            use_which_backbone: which encoder to use, or 'fusion' for
                mean-fusion of available modality encoders.
            config: optional config object; ``class_names`` is read from it
                if present.
            in_features: fallback feature dim when a backbone has no
                ``out_dim`` attribute.
            num_classes: number of output classes.
            loss_type: 'ce', 'focal' or 'balanced_softmax'.
            class_distribution: per-class sample counts, used to derive focal
                alpha (when focal_alpha is None) or balanced softmax priors.
            total_training_steps: if set (> 0), enables warmup + cosine LR
                schedule stepped per optimizer step.
            finetune_backbone: if True, backbone parameters are trainable.
        """
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.final_lr = final_lr
        self.use_channel_bank = use_channel_bank
        self.loss_type = loss_type
        self.focal_gamma = focal_gamma
        self.focal_alpha = focal_alpha
        self.use_mean_pool = use_mean_pool
        self.total_training_steps = total_training_steps
        self.finetune_backbone = finetune_backbone

        if loss_type == "ce":
            # Plain cross-entropy is applied directly in shared_step.
            self.criterion = None
        elif loss_type == "focal":
            alpha = focal_alpha
            if alpha is None and class_distribution is not None:
                # Inverse-frequency class weights, normalized to mean 1.
                class_dist = class_distribution.float()
                # Clamp to avoid division by zero for classes with no samples.
                class_dist = class_dist.clamp(min=1.0)
                total_samples = class_dist.sum()
                alpha = total_samples / (num_classes * class_dist)
                alpha = alpha / alpha.mean()
            self.criterion = FocalLoss(alpha=alpha, gamma=focal_gamma, reduction="mean")
        elif loss_type == "balanced_softmax":
            self.criterion = BalancedSoftmax(class_distribution, reduction="mean")
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}. Must be one of ['ce', 'focal', 'balanced_softmax']")

        if isinstance(backbones, nn.ModuleDict):
            self.backbones = backbones
        else:
            self.backbones = nn.ModuleDict(backbones)
        self.config = config
        self.use_which_backbone = use_which_backbone
        # For 'fusion' there is no single backbone; shared_step iterates the dict.
        self.backbone = self.backbones[self.use_which_backbone] if self.use_which_backbone != "fusion" else None

        # Freeze (or unfreeze) backbone parameters according to finetune_backbone.
        if self.use_which_backbone == "fusion":
            for k in ("ecg", "resp", "elect"):
                if k in self.backbones:
                    for p in self.backbones[k].parameters():
                        p.requires_grad = self.finetune_backbone
                    if not self.finetune_backbone:
                        self.backbones[k].eval()
        else:
            for p in self.backbone.parameters():
                p.requires_grad = self.finetune_backbone
            if not self.finetune_backbone:
                self.backbone.eval()
        if self.finetune_backbone:
            print("[INFO] Full finetuning mode: backbone parameters are TRAINABLE")

        # Determine the input dim of the linear head.
        if self.use_which_backbone == "fusion":
            dims = [getattr(self.backbones[k], "out_dim", in_features)
                    for k in ("ecg", "resp", "elect") if k in self.backbones]
            if len(dims) == 0:
                raise ValueError("fusion requires at least one of {'ecg','resp','elect'} in backbones.")
            if len(set(dims)) != 1:
                raise ValueError(f"Mean fusion requires equal output dims, got {dims}")
            final_in_features = dims[0]
        else:
            final_in_features = getattr(self.backbone, "out_dim", in_features)

        self.linear_layer = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(final_in_features, num_classes)
        )

        # Separate metric collections per split (torchmetrics accumulates state).
        self.train_pred_metrics = _create_pred_metrics(num_classes)
        self.val_pred_metrics = _create_pred_metrics(num_classes)
        self.test_pred_metrics = _create_pred_metrics(num_classes)
        self.train_prob_metrics = _create_prob_metrics(num_classes)
        self.val_prob_metrics = _create_prob_metrics(num_classes)
        self.test_prob_metrics = _create_prob_metrics(num_classes)
        self.train_pred_metrics_c = _create_perclass_pred_metrics(num_classes)
        self.val_pred_metrics_c = _create_perclass_pred_metrics(num_classes)
        self.test_pred_metrics_c = _create_perclass_pred_metrics(num_classes)
        self.train_prob_metrics_c = _create_perclass_prob_metrics(num_classes)
        self.val_prob_metrics_c = _create_perclass_prob_metrics(num_classes)
        self.test_prob_metrics_c = _create_perclass_prob_metrics(num_classes)

        self.class_names = getattr(self.config, "class_names", [str(i) for i in range(num_classes)])

    def on_train_epoch_start(self) -> None:
        # Lightning flips modules to train() each epoch; re-freeze frozen backbones.
        if not self.finetune_backbone:
            if self.use_which_backbone == "fusion":
                for k in ("ecg", "resp", "elect"):
                    if k in self.backbones:
                        self.backbones[k].eval()
            else:
                self.backbone.eval()

    def _log_perclass(self, prefix: str, pred_c: dict, prob_c: dict) -> None:
        """Log per-class acc/prec/rec/f1/auc/auprc and support under ``prefix``/."""
        cm = pred_c["cm"]
        # Row sums of the (unnormalized) confusion matrix = per-class support.
        support = cm.sum(dim=1) if cm is not None else None
        for i in range(len(pred_c["acc_c"])):
            name = self.class_names[i] if i < len(self.class_names) else str(i)
            self.log(f"{prefix}/acc_{name}", pred_c["acc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{prefix}/prec_{name}", pred_c["prec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{prefix}/rec_{name}", pred_c["rec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{prefix}/f1_{name}", pred_c["f1_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{prefix}/auc_{name}", prob_c["auc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{prefix}/auprc_{name}", prob_c["auprc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            if support is not None:
                self.log(f"{prefix}/support_{name}", support[i].to(pred_c["acc_c"][i].dtype), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        probs = logits.softmax(-1)
        preds = logits.argmax(-1)
        self.train_pred_metrics.update(preds, y)
        self.train_prob_metrics.update(probs, y)
        self.train_pred_metrics_c.update(preds, y)
        self.train_prob_metrics_c.update(probs, y)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        pred_agg = self.train_pred_metrics.compute()
        prob_agg = self.train_prob_metrics.compute()
        self.log("train_acc", pred_agg["acc"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_f1", pred_agg["f1"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_auc", prob_agg["auc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_auprc", prob_agg["auprc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self._log_perclass("train", self.train_pred_metrics_c.compute(), self.train_prob_metrics_c.compute())
        self.train_pred_metrics.reset()
        self.train_prob_metrics.reset()
        self.train_pred_metrics_c.reset()
        self.train_prob_metrics_c.reset()

    def validation_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        probs = logits.softmax(-1)
        preds = logits.argmax(-1)
        self.val_pred_metrics.update(preds, y)
        self.val_prob_metrics.update(probs, y)
        self.val_pred_metrics_c.update(preds, y)
        self.val_prob_metrics_c.update(probs, y)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        pred_agg = self.val_pred_metrics.compute()
        prob_agg = self.val_prob_metrics.compute()
        self.log("val_acc", pred_agg["acc"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_f1", pred_agg["f1"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_f1_w", pred_agg["f1_w"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_rec_m", pred_agg["rec_m"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_auc", prob_agg["auc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_auprc", prob_agg["auprc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_kappa", pred_agg["kappa"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self._log_perclass("val", self.val_pred_metrics_c.compute(), self.val_prob_metrics_c.compute())
        self.val_pred_metrics.reset()
        self.val_prob_metrics.reset()
        self.val_pred_metrics_c.reset()
        self.val_prob_metrics_c.reset()

    def test_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        probs = logits.softmax(-1)
        preds = logits.argmax(-1)
        self.test_pred_metrics.update(preds, y)
        self.test_prob_metrics.update(probs, y)
        self.test_pred_metrics_c.update(preds, y)
        self.test_prob_metrics_c.update(probs, y)
        self.log("test_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def on_test_epoch_end(self):
        pred_agg = self.test_pred_metrics.compute()
        prob_agg = self.test_prob_metrics.compute()
        self.log("test_acc", pred_agg["acc"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_f1", pred_agg["f1"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_f1_w", pred_agg["f1_w"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_rec_m", pred_agg["rec_m"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_auc", prob_agg["auc"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_auprc", prob_agg["auprc"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_kappa", pred_agg["kappa"], on_step=False, on_epoch=True, sync_dist=True)
        self._log_perclass("test", self.test_pred_metrics_c.compute(), self.test_prob_metrics_c.compute())
        self.test_pred_metrics.reset()
        self.test_prob_metrics.reset()
        self.test_pred_metrics_c.reset()
        self.test_prob_metrics_c.reset()

    def _extract_features(self, psg: torch.Tensor) -> torch.Tensor:
        """Route PSG channels to the configured backbone(s) and return features.

        Channel layout (inferred from slicing — confirm against the dataset):
        0 = ECG, 1:5 = respiratory, 5: = electrophysiological.
        """
        if self.use_which_backbone == 'ecg':
            return self._get_features(self.backbone, psg[:, 0:1, :])
        if self.use_which_backbone == 'resp':
            return self._get_features(self.backbone, psg[:, 1:5, :])
        if self.use_which_backbone == 'elect':
            return self._get_features(self.backbone, psg[:, 5:, :])
        if self.use_which_backbone == 'all':
            return self._get_features(self.backbone, psg)
        if self.use_which_backbone == 'fusion':
            feats_list = []
            if 'ecg' in self.backbones:
                feats_list.append(self._get_features(self.backbones['ecg'], psg[:, 0:1, :]))
            if 'resp' in self.backbones:
                feats_list.append(self._get_features(self.backbones['resp'], psg[:, 1:5, :]))
            if 'elect' in self.backbones:
                feats_list.append(self._get_features(self.backbones['elect'], psg[:, 5:, :]))
            # Mean fusion over the available modality features (equal dims enforced in __init__).
            return torch.stack(feats_list, dim=0).mean(dim=0)
        raise ValueError(f"Unknown use_which_backbone: {self.use_which_backbone}")

    def shared_step(self, batch):
        """Extract features, classify, and return (loss, logits, targets)."""
        # Backbone runs without grads in linear-probe mode.
        context = torch.no_grad() if not self.finetune_backbone else torch.enable_grad()
        with context:
            feats = self._extract_features(batch['psg'])
        y = batch["label"]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        # Labels arrive as [B, 1]; squeeze to [B] class indices.
        y = y.squeeze(1).long()
        if self.criterion is None:
            loss = F.cross_entropy(logits, y)
        else:
            loss = self.criterion(logits, y)
        return loss, logits, y

    def _get_features(self, backbone, x):
        """Get features from backbone. Uses mean pooling if use_mean_pool=True."""
        if self.use_mean_pool:
            if hasattr(backbone, 'forward_encoding_mean_pool'):
                return backbone.forward_encoding_mean_pool(x)
            elif hasattr(backbone, 'forward_avg_pool'):
                return backbone.forward_avg_pool(x)
        return backbone(x)

    def configure_optimizers(self):
        """AdamW over head (+ backbone when finetuning), optional warmup+cosine schedule."""
        if self.finetune_backbone:
            if self.use_which_backbone == "fusion":
                backbone_params = chain.from_iterable(
                    self.backbones[k].parameters()
                    for k in ("ecg", "resp", "elect") if k in self.backbones
                )
            else:
                backbone_params = self.backbone.parameters()
            params = chain(backbone_params, self.linear_layer.parameters())
        else:
            params = self.linear_layer.parameters()
        optimizer = torch.optim.AdamW(
            params,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        if self.total_training_steps is not None and self.total_training_steps > 0:
            # 10% linear warmup followed by cosine decay to final_lr.
            warmup_steps = int(0.1 * self.total_training_steps)
            cosine_steps = self.total_training_steps - warmup_steps
            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps
            )
            cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=cosine_steps, eta_min=self.final_lr
            )
            scheduler = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_steps]
            )
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        else:
            return [optimizer]


class SSLVitalSignsRegressor(SSLFineTuner):
    """SSL Finetuner for vital signs regression (HR, SPO2). Uses MSE loss."""

    def __init__(self,
                 backbones,
                 use_which_backbone,
                 config=None,
                 in_features: int = 256,
                 num_classes: int = 1,
                 target_names: list = None,
                 dropout: float = 0.0,
                 **kwargs) -> None:
        """
        Args:
            num_classes: number of regression targets (kept name for
                signature compatibility with the parent class).
            target_names: optional names for the targets, used in log keys.
        """
        # Force a loss_type the parent accepts; the criterion is replaced below.
        kwargs['loss_type'] = 'ce'
        # num_classes=2 only satisfies the parent's classification-metric setup;
        # those metrics and the head are replaced immediately after.
        super().__init__(
            backbones=backbones,
            use_which_backbone=use_which_backbone,
            config=config,
            in_features=in_features,
            num_classes=2,
            dropout=dropout,
            **kwargs
        )
        self.num_targets = num_classes
        self.target_names = target_names or [f"target_{i}" for i in range(num_classes)]
        self.criterion = nn.MSELoss()
        # Rebuild the head with the true number of regression outputs.
        in_feat = self.linear_layer[1].in_features
        self.linear_layer = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_feat, num_classes)
        )
        # Classification metrics do not apply to regression.
        del self.train_pred_metrics, self.val_pred_metrics, self.test_pred_metrics
        del self.train_prob_metrics, self.val_prob_metrics, self.test_prob_metrics
        del self.train_pred_metrics_c, self.val_pred_metrics_c, self.test_pred_metrics_c
        del self.train_prob_metrics_c, self.val_prob_metrics_c, self.test_prob_metrics_c

    def shared_step(self, batch):
        """Override: regression loss instead of classification."""
        context = torch.no_grad() if not self.finetune_backbone else torch.enable_grad()
        with context:
            feats = self._extract_features(batch['psg'])
        y = batch["label"].float()  # [B, num_targets]
        feats = feats.view(feats.size(0), -1)
        preds = self.linear_layer(feats)  # [B, num_targets]
        loss = self.criterion(preds, y)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        """Override: regression metrics."""
        loss, preds, y = self.shared_step(batch)
        with torch.no_grad():
            for i, name in enumerate(self.target_names):
                mae = F.l1_loss(preds[:, i], y[:, i])
                self.log(f"train_{name}_mae", mae, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Override: no classification metrics to compute."""
        pass

    def validation_step(self, batch, batch_idx):
        """Override: regression metrics."""
        loss, preds, y = self.shared_step(batch)
        for i, name in enumerate(self.target_names):
            mae = F.l1_loss(preds[:, i], y[:, i])
            self.log(f"val_{name}_mae", mae, on_step=False, on_epoch=True, sync_dist=True)
        overall_mae = F.l1_loss(preds, y)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_mae", overall_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Override: no classification metrics to compute."""
        pass

    def test_step(self, batch, batch_idx):
        """Override: regression metrics."""
        loss, preds, y = self.shared_step(batch)
        for i, name in enumerate(self.target_names):
            p, t = preds[:, i], y[:, i]
            mae = F.l1_loss(p, t)
            mse = F.mse_loss(p, t)
            rmse = torch.sqrt(mse)
            self.log(f"test_{name}_mae", mae, on_step=False, on_epoch=True, sync_dist=True)
            self.log(f"test_{name}_mse", mse, on_step=False, on_epoch=True, sync_dist=True)
            self.log(f"test_{name}_rmse", rmse, on_step=False, on_epoch=True, sync_dist=True)
        overall_mae = F.l1_loss(preds, y)
        overall_mse = F.mse_loss(preds, y)
        self.log("test_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_mae", overall_mae, on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_mse", overall_mse, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def on_test_epoch_end(self):
        """Override: no classification metrics to compute."""
        pass


class SupervisedVitalSignsRegressor(SSLVitalSignsRegressor):
    """Supervised from-scratch regression. Equivalent to SSLVitalSignsRegressor with finetune_backbone=True."""

    def __init__(self,
                 backbones,
                 use_which_backbone,
                 epochs: int = 100,
                 **kwargs):
        # Always train the backbone end-to-end (overrides any caller value).
        kwargs['finetune_backbone'] = True
        super().__init__(
            backbones=backbones,
            use_which_backbone=use_which_backbone,
            epochs=epochs,
            **kwargs
        )