Upload folder using huggingface_hub
Browse files- LICENSE +21 -0
- README.md +256 -3
- config.py +231 -0
- demo.ipynb +0 -0
- environment.yml +187 -0
- finetune.bash +35 -0
- main_pipelines/main_finetune.py +261 -0
- main_pipelines/main_pretrain.py +196 -0
- osf/__init__.py +0 -0
- osf/backbone/__init__.py +0 -0
- osf/backbone/pos_embed.py +76 -0
- osf/backbone/vit1d.py +209 -0
- osf/backbone/vit1d_cls.py +363 -0
- osf/datasets/__init__.py +0 -0
- osf/datasets/augmentations.py +81 -0
- osf/datasets/pretrain_datamodule.py +303 -0
- osf/datasets/pretrain_dataset.py +381 -0
- osf/datasets/simclr_aug_registry.py +258 -0
- osf/models/__init__.py +0 -0
- osf/models/balanced_losses.py +89 -0
- osf/models/base_pretrain_model.py +144 -0
- osf/models/base_pretrain_model_cls.py +56 -0
- osf/models/dino_model_cls.py +311 -0
- osf/models/dino_utils/dino_clstoken_loss.py +96 -0
- osf/models/dino_utils/ibot_patch_loss.py +134 -0
- osf/models/dino_utils/koleo_loss.py +46 -0
- osf/models/ssl_finetuner.py +568 -0
- osf/utils/openclip_loss.py +472 -0
- osf/utils/results_utils.py +289 -0
- osf_backbone.pth +3 -0
- pretrained_weights/readme.md +1 -0
- requirements.txt +193 -0
- train_config.py +32 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Health Intelligence Lab @ UCLA (https://github.com/yang-ai-lab)
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,256 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: mit
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- sleep
|
| 5 |
+
- eeg
|
| 6 |
+
- polysomnography
|
| 7 |
+
- foundation-model
|
| 8 |
+
- self-supervised
|
| 9 |
+
- vit
|
| 10 |
+
- biosignals
|
| 11 |
+
pipeline_tag: feature-extraction
|
| 12 |
+
library_name: pytorch
|
| 13 |
+
language:
|
| 14 |
+
- en
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# OSF: On Pre-training and Scaling of Sleep Foundation Models
|
| 18 |
+
|
| 19 |
+
[](#citation)
|
| 20 |
+
[](https://yang-ai-lab.github.io/osf/)
|
| 21 |
+
[](LICENSE)
|
| 22 |
+
[](#installation)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## 🔥 News
|
| 26 |
+
|
| 27 |
+
- [2026-2-24] Our codebase and checkpoint are released. The full codebase for benchmarking will be publicly available after acceptance.
|
| 28 |
+
- [2026-2-22] Our paper is out.
|
| 29 |
+
|
| 30 |
+
## 📖 Introduction
|
| 31 |
+
|
| 32 |
+
Polysomnography (PSG) provides the gold standard for sleep assessment but suffers from substantial heterogeneity across recording devices and cohorts.
|
| 33 |
+
There have been growing efforts to build general-purpose foundation models (FMs) for sleep physiology, but the field lacks an in-depth understanding of the pre-training process and scaling patterns that lead to more generalizable sleep FMs.
|
| 34 |
+
To fill this gap, we curate a massive corpus of 166,500 hours of sleep recordings from nine public sources and establish SleepBench, a comprehensive, fully open-source benchmark.
|
| 35 |
+
Leveraging SleepBench, we systematically evaluate four families of self-supervised pre-training objectives and uncover three critical findings:
|
| 36 |
+
(1) existing FMs fail to generalize to missing channels at inference;
|
| 37 |
+
(2) channel-invariant feature learning is essential for pre-training;
|
| 38 |
+
and (3) scaling sample size, model capacity, and multi-source data mixture consistently improves downstream performance.
|
| 39 |
+
With an enhanced pre-training and scaling recipe, we introduce OSF, a family of sleep FMs that achieves state-of-the-art performance across nine datasets on diverse sleep and disease prediction tasks.
|
| 40 |
+
Further analysis of OSF also reveals intriguing properties in sample efficiency, hierarchical aggregation, and cross-dataset scaling.
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## 📖 Table of Contents
|
| 44 |
+
|
| 45 |
+
1. [Installation](#-installation)
|
| 46 |
+
2. [Quick Start](#-quick-start)
|
| 47 |
+
3. [Pretrained Weights](#-pretrained-weights)
|
| 48 |
+
4. [Usage](#-usage)
|
| 49 |
+
5. [Benchmark Evaluations](#-benchmark-evaluations)
|
| 50 |
+
6. [Supported Datasets](#-supported-datasets)
|
| 51 |
+
7. [Citation](#-citation)
|
| 52 |
+
|
| 53 |
+
## 💿 Installation
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
git clone https://huggingface.co/yang-ai-lab/OSF-Base
|
| 57 |
+
cd OSF-Base
|
| 58 |
+
conda env create -f environment.yml
|
| 59 |
+
conda activate myenv
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
### Dependencies
|
| 64 |
+
|
| 65 |
+
- Python >= 3.10
|
| 66 |
+
- PyTorch >= 2.9.0
|
| 67 |
+
- PyTorch Lightning >= 2.5.5
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
## 🚀 Quick Start
|
| 71 |
+
|
| 72 |
+
We provide a demo notebook (`demo.ipynb`) demonstrating how to extract embeddings from PSG signals using the pretrained model.
|
| 73 |
+
|
| 74 |
+
```python
|
| 75 |
+
import torch
|
| 76 |
+
from osf.backbone.vit1d_cls import vit_base
|
| 77 |
+
|
| 78 |
+
# Load pretrained weights (included in this repo)
|
| 79 |
+
payload = torch.load("osf_backbone.pth", map_location="cpu", weights_only=False)
|
| 80 |
+
meta = payload["metadata"]
|
| 81 |
+
|
| 82 |
+
# Initialize model
|
| 83 |
+
backbone = vit_base(
|
| 84 |
+
num_leads=meta["num_leads"], # 12 channels
|
| 85 |
+
seq_len=meta["seq_len"], # 1920 (64 Hz × 30 s)
|
| 86 |
+
patch_size=meta["patch_size_time"],
|
| 87 |
+
lead_wise=meta["lead_wise"],
|
| 88 |
+
patch_size_ch=meta["patch_size_ch"],
|
| 89 |
+
)
|
| 90 |
+
backbone.load_state_dict(payload["state_dict"])
|
| 91 |
+
backbone.eval()
|
| 92 |
+
|
| 93 |
+
# Extract embeddings
|
| 94 |
+
# x: [B, 12, 1920] - 12-channel PSG, 64 Hz × 30 seconds
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
cls_embs, patch_embs = backbone.forward_encoding(x, return_sequence=False)
|
| 97 |
+
# cls_embs: [B, 768] - Global epoch-level representation
|
| 98 |
+
# patch_embs: [B, 90, 768] - Local patch representations
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## 📦 Pretrained Weights
|
| 102 |
+
|
| 103 |
+
| Model | Backbone | Channels |
|
| 104 |
+
|-------|----------|----------|
|
| 105 |
+
| OSF | ViT-Base | 12-ch |
|
| 106 |
+
|
| 107 |
+
The pretrained weights are included in this repository. You can download them via the Hugging Face Hub:
|
| 108 |
+
|
| 109 |
+
```python
|
| 110 |
+
from huggingface_hub import hf_hub_download
|
| 111 |
+
checkpoint_path = hf_hub_download(repo_id="yang-ai-lab/OSF-Base", filename="osf_backbone.pth")
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
Or via the CLI:
|
| 115 |
+
|
| 116 |
+
```bash
|
| 117 |
+
huggingface-cli download yang-ai-lab/OSF-Base osf_backbone.pth
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
## 👩💻 Usage
|
| 121 |
+
|
| 122 |
+
### Input Format
|
| 123 |
+
|
| 124 |
+
Expected input format:
|
| 125 |
+
- **12 PSG Channels**: ECG, EMG_Chin, EMG_LLeg, EMG_RLeg, ABD, THX, NP, SN, EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1
|
| 126 |
+
- **Sample Rate**: 64 Hz
|
| 127 |
+
- **Epoch Length**: 30 seconds
|
| 128 |
+
- **Input Shape**: `[B, 12, 1920]`
|
| 129 |
+
|
| 130 |
+
### Pretraining
|
| 131 |
+
|
| 132 |
+
We support multiple self-supervised pretraining methods. For example, to launch pre-training with our OSF method, run:
|
| 133 |
+
|
| 134 |
+
```bash
|
| 135 |
+
python main_pretrain.py \
|
| 136 |
+
--model_name "dino_ours" \
|
| 137 |
+
--psg_encoder_name "vit_base" \
|
| 138 |
+
--batch_size 256 \
|
| 139 |
+
--lr 5e-5 \
|
| 140 |
+
--max_epochs 30 \
|
| 141 |
+
--num_devices 4 \
|
| 142 |
+
--patch_size_time 64 \
|
| 143 |
+
--patch_size_ch 4 \
|
| 144 |
+
--precision "bf16-mixed"
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
See `main_pipelines/main_pretrain.py` for more detailed settings.
|
| 148 |
+
|
| 149 |
+
### Fine-tuning
|
| 150 |
+
|
| 151 |
+
Fine-tune the pretrained model on downstream tasks:
|
| 152 |
+
|
| 153 |
+
```bash
|
| 154 |
+
python main_finetune.py \
|
| 155 |
+
--model_name "dino_ours" \
|
| 156 |
+
--ckpt_path "/path/to/pretrained/checkpoint.ckpt" \
|
| 157 |
+
--downstream_dataset_name "shhs" \
|
| 158 |
+
--eval_label "Stage" \
|
| 159 |
+
--train_data_pct 1.0 \
|
| 160 |
+
--max_steps 500 \
|
| 161 |
+
--lr 0.1 \
|
| 162 |
+
--num_devices 4
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
## 📊 Benchmark Evaluations
|
| 167 |
+
|
| 168 |
+
### Benchmarked SSL Methods
|
| 169 |
+
|
| 170 |
+
| Method | Type | Original Paper |
|
| 171 |
+
|--------|------|-------------|
|
| 172 |
+
| SleepFM | Contrastive | [Leave-one-out multi-modal contrastive learning](https://www.nature.com/articles/s41591-025-04133-4.pdf) |
|
| 173 |
+
| SimCLR | Contrastive | [Simple Contrastive Learning](https://proceedings.mlr.press/v119/chen20j/chen20j.pdf) |
|
| 174 |
+
| DINO | Self-distillation | [DINO](https://arxiv.org/pdf/2304.07193) |
|
| 175 |
+
| VQ-VAE | Reconstruction | [Vector-quantized variational autoencoder](https://proceedings.neurips.cc/paper/2017/file/7a98af17e63a0ac09ce2e96d03992fbc-Paper.pdf) |
|
| 176 |
+
| MAE | Reconstruction | [Masked Autoencoding](https://openaccess.thecvf.com/content/CVPR2022/papers/He_Masked_Autoencoders_Are_Scalable_Vision_Learners_CVPR_2022_paper.pdf) |
|
| 177 |
+
| AR | Autoregressive | [Autoregressive Next-Token prediction](https://storage.prod.researchhub.com/uploads/papers/2020/06/01/language-models.pdf) |
|
| 178 |
+
| OSF | Self-distillation | ours |
|
| 179 |
+
|
| 180 |
+
### Downstream Tasks
|
| 181 |
+
|
| 182 |
+
**Epoch-level Classification Tasks:**
|
| 183 |
+
|
| 184 |
+
| Task | Classes | Description |
|
| 185 |
+
|------|---------|-------------|
|
| 186 |
+
| Sleep Stage | 4 | Awake, Light Sleep, Deep Sleep, REM classification |
|
| 187 |
+
| Arousal | 2 | Arousal event detection |
|
| 188 |
+
| Hypopnea | 2 | Hypopnea event detection |
|
| 189 |
+
| Oxygen Desaturation | 2 | Oxygen desaturation detection |
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
### Evaluation Settings
|
| 193 |
+
|
| 194 |
+
| Setting | Description |
|
| 195 |
+
|---------|-------------|
|
| 196 |
+
| Linear Probing | Freeze backbone, train linear classifier |
|
| 197 |
+
| Full Fine-tuning | Fine-tune entire model end-to-end |
|
| 198 |
+
| Few-shot (k-shot) | Train with limited labeled samples |
|
| 199 |
+
|
| 200 |
+
For example scripts, see `main_pipelines` and `bash_scripts` folders.
|
| 201 |
+
|
| 202 |
+
## 📊 Supported Datasets
|
| 203 |
+
|
| 204 |
+
We aggregated nine large-scale datasets from the National Sleep Research Resource platform.
|
| 205 |
+
|
| 206 |
+
| Dataset | Full Name | Source |
|
| 207 |
+
|---------|-----------|--------|
|
| 208 |
+
| SHHS | Sleep Heart Health Study | NSRR |
|
| 209 |
+
| CHAT | Childhood Adenotonsillectomy Trial | NSRR |
|
| 210 |
+
| MROS | MrOS Sleep Study | NSRR |
|
| 211 |
+
| CCSHS | Cleveland Children's Sleep and Health Study | NSRR |
|
| 212 |
+
| CFS | Cleveland Family Study | NSRR |
|
| 213 |
+
| MESA | Multi-Ethnic Study of Atherosclerosis | NSRR |
|
| 214 |
+
| SOF | Study of Osteoporotic Fractures | NSRR |
|
| 215 |
+
| WSC | Wisconsin Sleep Cohort | NSRR |
|
| 216 |
+
| STAGES | Stanford Technology Analytics and Genomics in Sleep | NSRR |
|
| 217 |
+
| NCHSDB | NCH Sleep DataBank | NSRR |
|
| 218 |
+
|
| 219 |
+
For new users, please apply for an account and request access to each of these datasets by following the instructions at [NSRR Registration](https://sleepdata.org/join).
|
| 220 |
+
|
| 221 |
+
## 📁 Project Structure
|
| 222 |
+
|
| 223 |
+
```
|
| 224 |
+
OSF-Open-Sleep-Foundation-Model/
|
| 225 |
+
├── osf/
|
| 226 |
+
│ ├── backbone/ # ViT backbone implementations
|
| 227 |
+
│ │ └── vit1d_cls.py
|
| 228 |
+
│ ├── models/ # SSL model implementations
|
| 229 |
+
│ │ └── dino_model_cls.py
|
| 230 |
+
│ │
|
| 231 |
+
│ ├── datasets/ # Data loading utilities
|
| 232 |
+
│ └── utils/ # Helper functions
|
| 233 |
+
├── main_pipelines/ # Training scripts
|
| 234 |
+
│ ├── main_pretrain.py
|
| 235 |
+
│ └── ...
|
| 236 |
+
├── bash_scripts/ # Example bash scripts
|
| 237 |
+
├── osf_backbone.pth # Pretrained model weights
|
| 238 |
+
├── demo.ipynb # Quick start demo
|
| 239 |
+
├── config.py # Dataset and channel configurations
|
| 240 |
+
└── train_config.py # Training configurations
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
## 📝 Citation
|
| 245 |
+
|
| 246 |
+
If you use this code or models in your research, please cite our paper:
|
| 247 |
+
|
| 248 |
+
```bibtex
|
| 249 |
+
@article{shuai2026osf,
|
| 250 |
+
title={OSF: On Pre-training and Scaling of Sleep Foundation Models},
|
| 251 |
+
author={Shuai, Zitao and Xu, Zongzhe and Yang, David and Wang, Wei and Yang, Yuzhe},
|
| 252 |
+
journal={arXiv preprint},
|
| 253 |
+
year={2026}
|
| 254 |
+
}
|
| 255 |
+
```
|
| 256 |
+
|
config.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration constants for sleep data processing.
|
| 3 |
+
Contains dataset names, paths, channel definitions, and event labels.
|
| 4 |
+
"""
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
# =============================================================================
|
| 9 |
+
# Dataset name constants
|
| 10 |
+
# =============================================================================
|
| 11 |
+
SHHS = 'shhs'
|
| 12 |
+
CHAT = 'chat'
|
| 13 |
+
MROS = 'mros'
|
| 14 |
+
CCSHS = 'ccshs'
|
| 15 |
+
CFS = 'cfs'
|
| 16 |
+
MESA = 'mesa'
|
| 17 |
+
SOF = 'sof'
|
| 18 |
+
WSC = 'wsc'
|
| 19 |
+
HSP = 'hsp'
|
| 20 |
+
NCHSDB = 'nchsdb'
|
| 21 |
+
STAGES = 'stages'
|
| 22 |
+
PATS = 'pats'
|
| 23 |
+
SHHS2 = 'shhs2'
|
| 24 |
+
NUMOM2B = 'numom2b'
|
| 25 |
+
|
| 26 |
+
# =============================================================================
|
| 27 |
+
# Data paths
|
| 28 |
+
# =============================================================================
|
| 29 |
+
META_PATH = '/path/to/your/nsrr/data'
|
| 30 |
+
|
| 31 |
+
MASTER_SHHS = [META_PATH + "/" + SHHS + "/datasets/shhs-harmonized-dataset-0.21.0.csv"]
|
| 32 |
+
MASTER_CHAT = [META_PATH + "/" + CHAT + "/datasets/chat-harmonized-dataset-0.14.0.csv"]
|
| 33 |
+
MASTER_MROS = [META_PATH + "/" + MROS + "/datasets/mros-visit1-harmonized-0.6.0.csv"]
|
| 34 |
+
MASTER_CCSHS = [META_PATH + "/" + CCSHS + "/datasets/ccshs-trec-harmonized-0.8.0.csv"]
|
| 35 |
+
MASTER_CFS = [META_PATH + "/" + CFS + "/datasets/cfs-visit5-harmonized-dataset-0.7.0.csv"]
|
| 36 |
+
MASTER_MESA = [META_PATH + "/" + MESA + "/datasets/mesa-sleep-harmonized-dataset-0.7.0.csv"]
|
| 37 |
+
MASTER_SOF = [META_PATH + "/" + SOF + "/datasets/sof-visit-8-harmonized-dataset-0.8.0.csv"]
|
| 38 |
+
MASTER_WSC = [META_PATH + "/" + WSC + "/datasets/wsc-harmonized-dataset-0.7.0.csv"]
|
| 39 |
+
MASTER_HSP = [
|
| 40 |
+
META_PATH + "/" + HSP + "/psg-metadata/I0001_psg_metadata_2025-05-06.csv",
|
| 41 |
+
META_PATH + "/" + HSP + "/psg-metadata/I0002_psg_metadata_2025-05-06.csv",
|
| 42 |
+
META_PATH + "/" + HSP + "/psg-metadata/I0003_psg_metadata_2025-05-06.csv",
|
| 43 |
+
META_PATH + "/" + HSP + "/psg-metadata/I0004_psg_metadata_2025-05-06.csv",
|
| 44 |
+
META_PATH + "/" + HSP + "/psg-metadata/I0006_psg_metadata_2025-05-06.csv",
|
| 45 |
+
]
|
| 46 |
+
MASTER_STAGES = [META_PATH + "/" + STAGES + "/metadata/stages-harmonized-dataset-0.3.0.csv"]
|
| 47 |
+
MASTER_NCHSDB = [META_PATH + "/" + NCHSDB + "/datasets/nchsdb-dataset-harmonized-0.3.0.csv"]
|
| 48 |
+
MASTER_PATS = [META_PATH + "/" + PATS + "/datasets/pats-harmonized-dataset-0.1.0.csv"]
|
| 49 |
+
|
| 50 |
+
MASTER_CSV_LIST = {
|
| 51 |
+
'shhs': MASTER_SHHS,
|
| 52 |
+
'chat': MASTER_CHAT,
|
| 53 |
+
'mros': MASTER_MROS,
|
| 54 |
+
'ccshs': MASTER_CCSHS,
|
| 55 |
+
'cfs': MASTER_CFS,
|
| 56 |
+
'mesa': MASTER_MESA,
|
| 57 |
+
'sof': MASTER_SOF,
|
| 58 |
+
'wsc': MASTER_WSC,
|
| 59 |
+
'hsp': MASTER_HSP,
|
| 60 |
+
'stages': MASTER_STAGES,
|
| 61 |
+
'pats': MASTER_PATS,
|
| 62 |
+
'nchsdb': MASTER_NCHSDB,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
# =============================================================================
|
| 66 |
+
# Channel name constants
|
| 67 |
+
# =============================================================================
|
| 68 |
+
# ECG channels
|
| 69 |
+
ECG = 'ECG'
|
| 70 |
+
ECG1 = 'ECG1'
|
| 71 |
+
ECG2 = 'ECG2'
|
| 72 |
+
ECG3 = 'ECG3'
|
| 73 |
+
HR = 'HR'
|
| 74 |
+
PPG = 'PPG'
|
| 75 |
+
|
| 76 |
+
# Respiratory channels
|
| 77 |
+
SPO2 = 'SPO2'
|
| 78 |
+
OX = 'OX'
|
| 79 |
+
ABD = 'ABD'
|
| 80 |
+
THX = 'THX'
|
| 81 |
+
AF = 'AF'
|
| 82 |
+
NP = 'NP'
|
| 83 |
+
SN = 'SN'
|
| 84 |
+
|
| 85 |
+
# EOG channels
|
| 86 |
+
EOG_L = 'EOG_L'
|
| 87 |
+
EOG_R = 'EOG_R'
|
| 88 |
+
EOG_E1_A2 = 'EOG_E1_A2'
|
| 89 |
+
EOG_E2_A1 = 'EOG_E2_A1'
|
| 90 |
+
|
| 91 |
+
# EMG Leg channels
|
| 92 |
+
EMG_LLeg = 'EMG_LLeg'
|
| 93 |
+
EMG_RLeg = 'EMG_RLeg'
|
| 94 |
+
EMG_LLeg1 = 'EMG_LLeg1'
|
| 95 |
+
EMG_LLeg2 = 'EMG_LLeg2'
|
| 96 |
+
EMG_RLeg1 = 'EMG_RLeg1'
|
| 97 |
+
EMG_RLeg2 = 'EMG_RLeg2'
|
| 98 |
+
EMG_Leg = 'EMG_Leg'
|
| 99 |
+
|
| 100 |
+
# Sensor Leg channels
|
| 101 |
+
SENSOR_Leg = 'SENSOR_Leg'
|
| 102 |
+
SENSOR_LLeg = 'SENSOR_LLeg'
|
| 103 |
+
SENSOR_LLeg1 = 'SENSOR_LLeg1'
|
| 104 |
+
SENSOR_LLeg2 = 'SENSOR_LLeg2'
|
| 105 |
+
SENSOR_RLeg = 'SENSOR_RLeg'
|
| 106 |
+
SENSOR_RLeg1 = 'SENSOR_RLeg1'
|
| 107 |
+
SENSOR_RLeg2 = 'SENSOR_RLeg2'
|
| 108 |
+
|
| 109 |
+
# EMG Chin channels
|
| 110 |
+
EMG_Chin = 'EMG_Chin'
|
| 111 |
+
EMG_RChin = 'EMG_RChin'
|
| 112 |
+
EMG_LChin = 'EMG_LChin'
|
| 113 |
+
EMG_CChin = 'EMG_CChin'
|
| 114 |
+
|
| 115 |
+
# EEG channels (unipolar)
|
| 116 |
+
EEG_C3 = 'EEG_C3'
|
| 117 |
+
EEG_C4 = 'EEG_C4'
|
| 118 |
+
EEG_A1 = 'EEG_A1'
|
| 119 |
+
EEG_A2 = 'EEG_A2'
|
| 120 |
+
EEG_O1 = 'EEG_O1'
|
| 121 |
+
EEG_O2 = 'EEG_O2'
|
| 122 |
+
EEG_F3 = 'EEG_F3'
|
| 123 |
+
EEG_F4 = 'EEG_F4'
|
| 124 |
+
|
| 125 |
+
# EEG channels (bipolar/referenced)
|
| 126 |
+
EEG_C3_A2 = 'EEG_C3_A2'
|
| 127 |
+
EEG_C4_A1 = 'EEG_C4_A1'
|
| 128 |
+
EEG_F3_A2 = 'EEG_F3_A2'
|
| 129 |
+
EEG_F4_A1 = 'EEG_F4_A1'
|
| 130 |
+
EEG_O1_A2 = 'EEG_O1_A2'
|
| 131 |
+
EEG_O2_A1 = 'EEG_O2_A1'
|
| 132 |
+
|
| 133 |
+
# Other channels
|
| 134 |
+
FPZ = 'FPZ'
|
| 135 |
+
GROUND = 'GROUND'
|
| 136 |
+
POS = 'POS'
|
| 137 |
+
|
| 138 |
+
# =============================================================================
|
| 139 |
+
# Sampling frequencies (Hz)
|
| 140 |
+
# =============================================================================
|
| 141 |
+
FREQ_ECG = 128
|
| 142 |
+
FREQ_ECG1 = 128
|
| 143 |
+
FREQ_ECG2 = 128
|
| 144 |
+
FREQ_ECG3 = 128
|
| 145 |
+
FREQ_HR = 1
|
| 146 |
+
FREQ_PPG = 128
|
| 147 |
+
|
| 148 |
+
FREQ_SPO2 = 1
|
| 149 |
+
FREQ_OX = 1
|
| 150 |
+
FREQ_ABD = 8
|
| 151 |
+
FREQ_THX = 8
|
| 152 |
+
FREQ_AF = 8
|
| 153 |
+
FREQ_NP = 8
|
| 154 |
+
FREQ_SN = 32
|
| 155 |
+
|
| 156 |
+
FREQ_EOG_L = 64
|
| 157 |
+
FREQ_EOG_R = 64
|
| 158 |
+
FREQ_EOG_E1_A2 = 64
|
| 159 |
+
FREQ_EOG_E2_A1 = 64
|
| 160 |
+
|
| 161 |
+
FREQ_EMG_Leg = 64
|
| 162 |
+
FREQ_EMG_LLeg = 64
|
| 163 |
+
FREQ_EMG_RLeg = 64
|
| 164 |
+
FREQ_EMG_LLeg1 = 64
|
| 165 |
+
FREQ_EMG_LLeg2 = 64
|
| 166 |
+
FREQ_EMG_RLeg1 = 64
|
| 167 |
+
FREQ_EMG_RLeg2 = 64
|
| 168 |
+
|
| 169 |
+
FREQ_SENSOR_Leg = 64
|
| 170 |
+
FREQ_SENSOR_LLeg = 64
|
| 171 |
+
FREQ_SENSOR_LLeg1 = 64
|
| 172 |
+
FREQ_SENSOR_LLeg2 = 64
|
| 173 |
+
FREQ_SENSOR_RLeg = 64
|
| 174 |
+
FREQ_SENSOR_RLeg1 = 64
|
| 175 |
+
FREQ_SENSOR_RLeg2 = 64
|
| 176 |
+
|
| 177 |
+
FREQ_EMG_Chin = 64
|
| 178 |
+
FREQ_EMG_LChin = 64
|
| 179 |
+
FREQ_EMG_RChin = 64
|
| 180 |
+
FREQ_EMG_CChin = 64
|
| 181 |
+
|
| 182 |
+
FREQ_EEG_C3 = 64
|
| 183 |
+
FREQ_EEG_C4 = 64
|
| 184 |
+
FREQ_EEG_A1 = 64
|
| 185 |
+
FREQ_EEG_A2 = 64
|
| 186 |
+
FREQ_EEG_O1 = 64
|
| 187 |
+
FREQ_EEG_O2 = 64
|
| 188 |
+
FREQ_EEG_F3 = 64
|
| 189 |
+
FREQ_EEG_F4 = 64
|
| 190 |
+
|
| 191 |
+
FREQ_EEG_C3_A2 = 64
|
| 192 |
+
FREQ_EEG_C4_A1 = 64
|
| 193 |
+
FREQ_EEG_F3_A2 = 64
|
| 194 |
+
FREQ_EEG_F4_A1 = 64
|
| 195 |
+
FREQ_EEG_O1_A2 = 64
|
| 196 |
+
FREQ_EEG_O2_A1 = 64
|
| 197 |
+
|
| 198 |
+
FREQ_POS = 1
|
| 199 |
+
|
| 200 |
+
# =============================================================================
|
| 201 |
+
# Event annotation column names
|
| 202 |
+
# =============================================================================
|
| 203 |
+
EVENT_NAME_COLUMN = 'EVENT'
|
| 204 |
+
START_TIME_COLUMN = 'START_SEC'
|
| 205 |
+
END_TIME_COLUMN = 'END_SEC'
|
| 206 |
+
|
| 207 |
+
# =============================================================================
|
| 208 |
+
# Respiratory event names
|
| 209 |
+
# =============================================================================
|
| 210 |
+
RESPIRATORY_EVENT_CENTRAL_APNEA = 'Central Apnea'
|
| 211 |
+
RESPIRATORY_EVENT_OBSTRUCTIVE_APNEA = 'Obstructive Apnea'
|
| 212 |
+
RESPIRATORY_EVENT_MIXED_APNEA = 'Mixed Apnea'
|
| 213 |
+
RESPIRATORY_EVENT_HYPOPNEA = 'Hypopnea'
|
| 214 |
+
RESPIRATORY_EVENT_DESATURATION = 'Oxygen Desaturation'
|
| 215 |
+
|
| 216 |
+
# =============================================================================
|
| 217 |
+
# Limb movement event names
|
| 218 |
+
# =============================================================================
|
| 219 |
+
LIMB_MOVEMENT_ISOLATED = 'Limb Movement Isolated'
|
| 220 |
+
LIMB_MOVEMENT_PERIODIC = 'Limb Movement Periodic'
|
| 221 |
+
LIMB_MOVEMENT_ISOLATED_LEFT = 'Left Limb Movement Isolated'
|
| 222 |
+
LIMB_MOVEMENT_ISOLATED_RIGHT = 'Right Limb Movement Isolated'
|
| 223 |
+
LIMB_MOVEMENT_PERIODIC_LEFT = 'Left Limb Movement Periodic'
|
| 224 |
+
LIMB_MOVEMENT_PERIODIC_RIGHT = 'Right Limb Movement Periodic'
|
| 225 |
+
|
| 226 |
+
# =============================================================================
|
| 227 |
+
# Arousal event names
|
| 228 |
+
# =============================================================================
|
| 229 |
+
AROUSAL_EVENT_CLASSIC = 'Arousal'
|
| 230 |
+
AROUSAL_EVENT_RESPIRATORY = 'RERA'
|
| 231 |
+
AROUSAL_EVENT_EMG = 'EMG-Related Arousal'
|
demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
environment.yml
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: myenv
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
dependencies:
|
| 5 |
+
- _openmp_mutex=4.5=2_gnu
|
| 6 |
+
- bzip2=1.0.8=h4777abc_8
|
| 7 |
+
- ca-certificates=2025.8.3=hbd8a1cb_0
|
| 8 |
+
- ld_impl_linux-aarch64=2.44=h5e2c951_1
|
| 9 |
+
- libexpat=2.7.1=hfae3067_0
|
| 10 |
+
- libffi=3.4.6=he21f813_1
|
| 11 |
+
- libgcc=15.1.0=he277a41_5
|
| 12 |
+
- libgcc-ng=15.1.0=he9431aa_5
|
| 13 |
+
- libgomp=15.1.0=he277a41_5
|
| 14 |
+
- liblzma=5.8.1=h86ecc28_2
|
| 15 |
+
- libnsl=2.0.1=h86ecc28_1
|
| 16 |
+
- libsqlite=3.50.4=h022381a_0
|
| 17 |
+
- libuuid=2.41.1=h3e4203c_0
|
| 18 |
+
- libxcrypt=4.4.36=h31becfc_1
|
| 19 |
+
- libzlib=1.3.1=h86ecc28_2
|
| 20 |
+
- ncurses=6.5=ha32ae93_3
|
| 21 |
+
- openssl=3.5.2=h8e36d6e_0
|
| 22 |
+
- pip=25.2=pyh8b19718_0
|
| 23 |
+
- python=3.10.18=h256493d_0_cpython
|
| 24 |
+
- readline=8.2=h8382b9d_2
|
| 25 |
+
- setuptools=80.9.0=pyhff2d567_0
|
| 26 |
+
- tk=8.6.13=noxft_h5688188_102
|
| 27 |
+
- wheel=0.45.1=pyhd8ed1ab_1
|
| 28 |
+
- pip:
|
| 29 |
+
- absl-py==2.3.1
|
| 30 |
+
- accelerate==1.2.1
|
| 31 |
+
- aiohappyeyeballs==2.6.1
|
| 32 |
+
- aiohttp==3.12.15
|
| 33 |
+
- aiosignal==1.4.0
|
| 34 |
+
- albucore==0.0.24
|
| 35 |
+
- albumentations==2.0.8
|
| 36 |
+
- annotated-types==0.7.0
|
| 37 |
+
- asttokens==3.0.0
|
| 38 |
+
- async-timeout==5.0.1
|
| 39 |
+
- attrs==25.3.0
|
| 40 |
+
- beartype==0.22.2
|
| 41 |
+
- braceexpand==0.1.7
|
| 42 |
+
- certifi==2025.10.5
|
| 43 |
+
- cffi==2.0.0
|
| 44 |
+
- charset-normalizer==3.4.4
|
| 45 |
+
- click==8.2.1
|
| 46 |
+
- coloredlogs==15.0.1
|
| 47 |
+
- comm==0.2.3
|
| 48 |
+
- contourpy==1.3.2
|
| 49 |
+
- cosine-annealing-warmup==2.0
|
| 50 |
+
- cycler==0.12.1
|
| 51 |
+
- debugpy==1.8.17
|
| 52 |
+
- decorator==5.2.1
|
| 53 |
+
- easydict==1.13
|
| 54 |
+
- einops==0.8.1
|
| 55 |
+
- ema-pytorch==0.7.7
|
| 56 |
+
- et-xmlfile==2.0.0
|
| 57 |
+
- exceptiongroup==1.3.0
|
| 58 |
+
- executing==2.2.1
|
| 59 |
+
- filelock==3.13.1
|
| 60 |
+
- flatbuffers==25.9.23
|
| 61 |
+
- fonttools==4.59.2
|
| 62 |
+
- frozenlist==1.7.0
|
| 63 |
+
- fsspec==2024.6.1
|
| 64 |
+
- gitdb==4.0.12
|
| 65 |
+
- gitpython==3.1.45
|
| 66 |
+
- grpcio==1.75.1
|
| 67 |
+
- h5py==3.14.0
|
| 68 |
+
- hf-xet==1.1.10
|
| 69 |
+
- huggingface-hub==0.35.3
|
| 70 |
+
- humanfriendly==10.0
|
| 71 |
+
- idna==3.11
|
| 72 |
+
- imageio==2.37.0
|
| 73 |
+
- importlib-metadata==8.7.0
|
| 74 |
+
- insightface==0.7.3
|
| 75 |
+
- ipdb==0.13.13
|
| 76 |
+
- ipykernel==6.30.1
|
| 77 |
+
- ipython==8.37.0
|
| 78 |
+
- jedi==0.19.2
|
| 79 |
+
- jinja2==3.1.6
|
| 80 |
+
- joblib==1.5.2
|
| 81 |
+
- jupyter-client==8.6.3
|
| 82 |
+
- jupyter-core==5.8.1
|
| 83 |
+
- kiwisolver==1.4.9
|
| 84 |
+
- kornia==0.8.1
|
| 85 |
+
- kornia-rs==0.1.9
|
| 86 |
+
- lazy-loader==0.4
|
| 87 |
+
- lightning-utilities==0.15.2
|
| 88 |
+
- llvmlite==0.46.0
|
| 89 |
+
- loguru==0.7.3
|
| 90 |
+
- markdown==3.9
|
| 91 |
+
- markupsafe==2.1.5
|
| 92 |
+
- matplotlib==3.10.6
|
| 93 |
+
- matplotlib-inline==0.1.7
|
| 94 |
+
- ml-dtypes==0.5.3
|
| 95 |
+
- mne==1.10.1
|
| 96 |
+
- mpmath==1.3.0
|
| 97 |
+
- multidict==6.6.4
|
| 98 |
+
- munch==4.0.0
|
| 99 |
+
- nest-asyncio==1.6.0
|
| 100 |
+
- networkx==3.4.2
|
| 101 |
+
- neurokit2==0.2.12
|
| 102 |
+
- ninja==1.13.0
|
| 103 |
+
- numba==0.63.1
|
| 104 |
+
- numpy==2.2.6
|
| 105 |
+
- onnx==1.19.1
|
| 106 |
+
- onnx2torch==1.5.15
|
| 107 |
+
- onnxruntime==1.23.1
|
| 108 |
+
- opencv-python==4.12.0.88
|
| 109 |
+
- opencv-python-headless==4.12.0.88
|
| 110 |
+
- openpyxl==3.1.5
|
| 111 |
+
- packaging==24.2
|
| 112 |
+
- pandas==2.3.2
|
| 113 |
+
- parso==0.8.5
|
| 114 |
+
- pexpect==4.9.0
|
| 115 |
+
- pillow==11.0.0
|
| 116 |
+
- platformdirs==4.5.0
|
| 117 |
+
- pooch==1.8.2
|
| 118 |
+
- prettytable==3.16.0
|
| 119 |
+
- prompt-toolkit==3.0.52
|
| 120 |
+
- propcache==0.3.2
|
| 121 |
+
- protobuf==6.32.1
|
| 122 |
+
- psutil==7.1.0
|
| 123 |
+
- ptyprocess==0.7.0
|
| 124 |
+
- pure-eval==0.2.3
|
| 125 |
+
- pyarrow==21.0.0
|
| 126 |
+
- pycparser==2.23
|
| 127 |
+
- pydantic==2.11.7
|
| 128 |
+
- pydantic-core==2.33.2
|
| 129 |
+
- pygments==2.19.2
|
| 130 |
+
- pynndescent==0.5.13
|
| 131 |
+
- pyparsing==3.2.3
|
| 132 |
+
- pysam==0.23.3
|
| 133 |
+
- python-dateutil==2.9.0.post0
|
| 134 |
+
- pytorch-lightning==2.5.5
|
| 135 |
+
- pytorch-warmup==0.2.0
|
| 136 |
+
- pytz==2025.2
|
| 137 |
+
- pyyaml==6.0.3
|
| 138 |
+
- pyzmq==27.1.0
|
| 139 |
+
- regex==2025.9.1
|
| 140 |
+
- requests==2.32.5
|
| 141 |
+
- safetensors==0.6.2
|
| 142 |
+
- scikit-image==0.25.2
|
| 143 |
+
- scikit-learn==1.7.2
|
| 144 |
+
- scipy==1.15.3
|
| 145 |
+
- seaborn==0.13.2
|
| 146 |
+
- sentencepiece==0.2.1
|
| 147 |
+
- sentry-sdk==2.37.1
|
| 148 |
+
- simsimd==6.5.3
|
| 149 |
+
- six==1.17.0
|
| 150 |
+
- smmap==5.0.2
|
| 151 |
+
- soundfile==0.13.1
|
| 152 |
+
- stack-data==0.6.3
|
| 153 |
+
- stringzilla==4.2.1
|
| 154 |
+
- sympy==1.13.1
|
| 155 |
+
- tabulate==0.9.0
|
| 156 |
+
- tensorboard==2.20.0
|
| 157 |
+
- tensorboard-data-server==0.7.2
|
| 158 |
+
- tensorboardx==2.6.4
|
| 159 |
+
- threadpoolctl==3.6.0
|
| 160 |
+
- tifffile==2025.5.10
|
| 161 |
+
- timm==1.0.19
|
| 162 |
+
- tokenizers==0.22.0
|
| 163 |
+
- tomli==2.2.1
|
| 164 |
+
- torch==2.5.1
|
| 165 |
+
- torchdiffeq==0.2.5
|
| 166 |
+
- torchmetrics==1.8.2
|
| 167 |
+
- torchtools==0.3.5
|
| 168 |
+
- torchvision==0.20.1
|
| 169 |
+
- tornado==6.5.2
|
| 170 |
+
- tqdm==4.67.1
|
| 171 |
+
- traitlets==5.14.3
|
| 172 |
+
- transformers==4.56.1
|
| 173 |
+
- typing-extensions==4.15.0
|
| 174 |
+
- typing-inspection==0.4.1
|
| 175 |
+
- tzdata==2025.2
|
| 176 |
+
- umap-learn==0.5.9.post2
|
| 177 |
+
- urllib3==2.5.0
|
| 178 |
+
- vitaldb==1.5.8
|
| 179 |
+
- wandb==0.22.1
|
| 180 |
+
- warmup-scheduler==0.3
|
| 181 |
+
- wcwidth==0.2.13
|
| 182 |
+
- webdataset==1.0.2
|
| 183 |
+
- werkzeug==3.1.3
|
| 184 |
+
- wfdb==4.3.0
|
| 185 |
+
- xxhash==3.5.0
|
| 186 |
+
- yarl==1.20.1
|
| 187 |
+
- zipp==3.23.0
|
finetune.bash
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Sweep linear-probe finetuning over every (model, dataset, label, pct) combination.

DATASETS=("shhs" "mros")
LABELS=("Stage" "Arousal" "Hypopnea" "Oxygen Desaturation")

TRAIN_PCTS=(1.0)

# Map: model name -> "checkpoint path|backbone key"
declare -A MODELS

MODELS["dino_ours"]="osf_vit_base.ckpt|all"

for name in "${!MODELS[@]}"; do

    # Split the "ckpt|backbone" spec into its two fields.
    IFS='|' read -r ckpt backbone <<< "${MODELS[$name]}"

    for ds in "${DATASETS[@]}"; do
        for lbl in "${LABELS[@]}"; do
            for frac in "${TRAIN_PCTS[@]}"; do
                echo "===== Model: ${name}, Dataset: ${ds}, Label: ${lbl}, Pct: ${frac} ====="

                CUDA_VISIBLE_DEVICES=0,1,2,3 python main_finetune.py \
                    --train_data_pct ${frac} \
                    --max_steps 500 \
                    --use_which_backbone "${backbone}" \
                    --model_name "${name}" \
                    --ckpt_path "${ckpt}" \
                    --lr 0.1 \
                    --eval_label "${lbl}" \
                    --num_devices 4 \
                    --data_source both \
                    --include_datasets "${ds}" \
                    --downstream_dataset_name "${ds}"
            done
        done
    done
done
|
main_pipelines/main_finetune.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pprint import pprint
|
| 2 |
+
import os
|
| 3 |
+
from argparse import ArgumentParser, Namespace
|
| 4 |
+
import datetime
|
| 5 |
+
from dateutil import tz
|
| 6 |
+
import random
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import warnings
|
| 10 |
+
from pytorch_lightning import seed_everything, Trainer
|
| 11 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
|
| 12 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 13 |
+
|
| 14 |
+
from osf.datasets.pretrain_datamodule import SleepDataModule
|
| 15 |
+
from osf.models.dino_model_cls import DINOCLSModel
|
| 16 |
+
from config import *
|
| 17 |
+
from train_config import *
|
| 18 |
+
from osf.models.ssl_finetuner import SSLFineTuner, SSLVitalSignsRegressor
|
| 19 |
+
from osf.utils.results_utils import save_results_to_json
|
| 20 |
+
|
| 21 |
+
warnings.filterwarnings("ignore")
|
| 22 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 23 |
+
torch.backends.cudnn.deterministic = True
|
| 24 |
+
torch.backends.cudnn.benchmark = True
|
| 25 |
+
torch.set_float32_matmul_precision('high')
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main(hparams: Namespace):
    """Fine-tune a pretrained DINO backbone on a downstream sleep task.

    Builds the run name and checkpoint directory, constructs the Lightning
    Trainer, the SleepDataModule and the finetuning head, then runs
    ``trainer.fit`` followed by ``trainer.test`` on the best checkpoint.

    Args:
        hparams: parsed CLI arguments (see the ArgumentParser at the bottom
            of this file). Mutated in place with derived fields such as
            ``num_classes``, ``in_features`` and ``total_training_steps``.
    """
    # Millisecond-resolution timestamp so concurrent launches get unique run names.
    now = datetime.datetime.now(tz.tzlocal())
    timestamp = now.strftime("%Y_%m_%d_%H_%M_%S") + f"_{now.microsecond // 1000:03d}"

    if hparams.monitor_type == "main":
        exp_name = "finetune_12ch"
    else:
        exp_name = f"finetune_{hparams.monitor_type}"

    if hparams.finetune_backbone:
        exp_name = f"{exp_name}_full"

    # Encode the training-set size in the run name: k<N> for few-shot,
    # <P>pct for fractional subsets, "full" otherwise.
    if hasattr(hparams, 'n_train_samples') and hparams.n_train_samples is not None and hparams.n_train_samples > 0:
        pct_str = f"k{hparams.n_train_samples}"
    elif hparams.train_data_pct < 1:
        pct_str = f"{int(hparams.train_data_pct * 100)}pct"
    else:
        pct_str = "full"
    if hparams.task_type == "classification":
        task_label = hparams.eval_label
    elif hparams.task_type == "regression":
        task_label = "_".join(hparams.regression_targets)
    else:
        raise NotImplementedError(f"Unknown task_type: {hparams.task_type}")
    run_name = f"{task_label}_{hparams.downstream_dataset_name}_{hparams.model_name}_{pct_str}_{timestamp}"

    ckpt_dir = os.path.join(
        CKPT_PATH, f"logs/{exp_name}/ckpts/{run_name}")
    os.makedirs(ckpt_dir, exist_ok=True)

    # Regression tracks MAE (lower is better); classification tracks AUC.
    if hparams.task_type == "regression":
        ckpt_monitor = "val_mae"
        ckpt_mode = "min"
    else:
        ckpt_monitor = "val_auc"
        ckpt_mode = "max"

    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        ModelCheckpoint(monitor=ckpt_monitor, dirpath=ckpt_dir,
                        save_last=False, mode=ckpt_mode, save_top_k=1,
                        auto_insert_metric_name=True),
    ]
    if getattr(hparams, 'early_stopping', False):
        early_stop_callback = EarlyStopping(
            monitor=ckpt_monitor,
            patience=getattr(hparams, 'early_stopping_patience', 10),
            mode=ckpt_mode,
            verbose=True,
        )
        callbacks.append(early_stop_callback)
        print(f"[INFO] Early stopping enabled: monitor={ckpt_monitor}, patience={hparams.early_stopping_patience}")
    logger_dir = os.path.join(CKPT_PATH, f"logs/{exp_name}")
    os.makedirs(logger_dir, exist_ok=True)
    wandb_logger = WandbLogger(
        project=f"{exp_name}_sleepuni", save_dir=logger_dir, name=run_name)
    trainer = Trainer(
        max_steps=hparams.max_steps,
        accelerator="gpu",
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        deterministic=True,
        devices=hparams.num_devices,
        strategy="ddp_find_unused_parameters_true",
        precision=hparams.precision,
        callbacks=callbacks,
        logger=wandb_logger
    )

    hparams.exp_log_dir = os.path.join(
        CKPT_PATH, f"data/{run_name}/exp_logs")
    train_edf_cols = MONITOR_TYPE_MAP.get(hparams.monitor_type, TRAIN_EDF_COLS_UNI_ENC)

    if hparams.task_type == "regression":
        event_cols = None
        regression_targets = hparams.regression_targets
        print(f"[INFO] Regression task with targets: {regression_targets}")
    else:  # classification
        event_cols = hparams.eval_label
        regression_targets = None

    # Optional SPO2 range filter (only meaningful when SPO2 is regressed).
    regression_filter_config = None
    if hparams.task_type == "regression" and "SPO2" in hparams.regression_targets:
        if hparams.filter_spo2_min is not None or hparams.filter_spo2_max is not None:
            spo2_filter = {}
            if hparams.filter_spo2_min is not None:
                spo2_filter["min"] = hparams.filter_spo2_min
            if hparams.filter_spo2_max is not None:
                spo2_filter["max"] = hparams.filter_spo2_max
            regression_filter_config = {"SPO2_mean": spo2_filter}
            print(f"[INFO] Will filter SPO2_mean with: {spo2_filter}")

    datamodule = SleepDataModule(
        is_pretrain=0,
        data_pct=hparams.train_data_pct,
        downstream_dataset_name=hparams.downstream_dataset_name,
        csv_dir=SPLIT_DATA_FOLDER,
        train_edf_cols=train_edf_cols,
        event_cols=event_cols,
        batch_size=hparams.batch_size,
        num_workers=hparams.num_workers,
        sample_rate=hparams.sample_rate,
        window_size=30,
        data_source=hparams.data_source,
        include_datasets=hparams.include_datasets,
        regression_targets=regression_targets,
        regression_filter_config=regression_filter_config,
        n_train_samples=getattr(hparams, 'n_train_samples', None),
        val_batch_size=getattr(hparams, 'val_batch_size', None),
        val_data_pct=getattr(hparams, 'val_data_pct', None),
        random_seed=hparams.seed,
    )
    if hparams.task_type == "regression":
        hparams.num_classes = len(hparams.regression_targets)  # output dim
        hparams.target_names = hparams.regression_targets
        print(f"[INFO] Regression targets: {hparams.target_names}, num_classes={hparams.num_classes}")
    else:  # classification
        train_dataset = datamodule.train_dataloader().dataset
        if hasattr(train_dataset, 'dataset'):  # It's a Subset
            hparams.num_classes = train_dataset.dataset.num_classes
        else:
            hparams.num_classes = train_dataset.num_classes
        print(f"[INFO] Classification num_classes: {hparams.num_classes}")

    # Optimizer steps per epoch, accounting for grad accumulation and DDP replicas.
    hparams.training_steps_per_epoch = len(datamodule.train_dataloader()) // hparams.accumulate_grad_batches // hparams.num_devices

    if hparams.max_steps > 0:
        hparams.total_training_steps = hparams.max_steps
    else:
        hparams.total_training_steps = hparams.training_steps_per_epoch * hparams.max_epochs

    print(f"Total training steps: {hparams.total_training_steps}")
    print(f"Steps per epoch: {hparams.training_steps_per_epoch}")

    class_distribution = datamodule.get_class_distribution()
    if class_distribution is not None:
        print(f"Class distribution: {class_distribution}")
        hparams.class_distribution = class_distribution

    # Load pretrained DINO model
    pretrain_model = DINOCLSModel.load_from_checkpoint(hparams.ckpt_path)
    pprint(vars(hparams))

    hparams.epochs = hparams.max_epochs

    def create_finetuner(backbones, hparams, train_edf_cols=None):
        # Strip hparams fields the finetuner constructors do not accept.
        exclude_keys = {'train_edf_cols', 'regression_targets'}
        hparams_dict = {k: v for k, v in vars(hparams).items() if k not in exclude_keys}

        if hparams.task_type == "regression":
            return SSLVitalSignsRegressor(backbones=backbones, **hparams_dict)
        else:
            return SSLFineTuner(backbones=backbones, **hparams_dict)

    # Extract ViT backbone from DINO model
    vit = pretrain_model.encoders["all"].backbone
    hparams.in_features = vit.width
    print(f"[INFO] Extracted ViT backbone for dino_ours, in_features={hparams.in_features}")
    model = create_finetuner(backbones={"all": vit}, hparams=hparams, train_edf_cols=train_edf_cols)

    trainer.fit(model, datamodule=datamodule)
    # BUGFIX: the ModelCheckpoint above is configured with save_last=False,
    # so a "last" checkpoint is never written and ckpt_path="last" cannot
    # resolve. Evaluate the best (top-1) checkpoint instead.
    trainer.test(model, datamodule=datamodule, ckpt_path="best")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
if __name__ == '__main__':
    # CLI for downstream finetuning; every derived field is attached to
    # `hparams` inside main().
    parser = ArgumentParser(description="Fine-tune pretrained model for downstream tasks.")

    # --- model / task selection ---
    parser.add_argument("--model_name", type=str, default="dino_ours")
    parser.add_argument("--eval_label", type=str, default="Stage")
    parser.add_argument("--downstream_dataset_name", type=str, default="mros")
    parser.add_argument("--use_which_backbone", type=str, default="all")
    parser.add_argument("--monitor_type", type=str, default="main",
                        choices=["main", "type3", "type4"],
                        help="Channel configuration: main (12ch), type3 (5ch), type4 (3ch)")

    # --- data selection ---
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train_data_pct", type=float, default=1.)
    parser.add_argument("--n_train_samples", type=int, default=None,
                        help="If set, use exactly this many training samples (overrides train_data_pct for few-shot)")
    parser.add_argument("--data_source", type=str, default="auto",
                        choices=["auto", "pretrain", "downstream", "both"],
                        help="Which CSV source to use: auto (default), pretrain, downstream, or both")
    parser.add_argument("--include_datasets", type=str, nargs="*", default=None,
                        help="Filter by dataset names, e.g., --include_datasets shhs mros")
    parser.add_argument("--batch_size", type=int, default=800)
    parser.add_argument("--val_batch_size", type=int, default=None,
                        help="Batch size for val/test (defaults to batch_size if not set, useful for few-shot)")
    parser.add_argument("--val_data_pct", type=float, default=None,
                        help="Percentage of val data to use (0-1, useful for few-shot to speed up validation)")

    # --- architecture ---
    parser.add_argument("--patch_size_time", type=int, default=64)
    parser.add_argument("--patch_size_ch", type=int, default=4,
                        help="Channel patch size for 2D patchify (default: 4)")

    # --- training schedule / hardware ---
    parser.add_argument("--num_workers", type=int, default=32)
    parser.add_argument("--num_devices", type=int, default=1)
    parser.add_argument("--max_epochs", type=int, default=10)
    parser.add_argument("--max_steps", type=int, default=2500)
    parser.add_argument("--early_stopping", action="store_true",
                        help="Enable early stopping based on val metric (useful for few-shot)")
    parser.add_argument("--early_stopping_patience", type=int, default=10,
                        help="Patience for early stopping (number of val checks without improvement)")
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--ckpt_path", type=str, default="")

    # --- optimization / loss ---
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--num_classes", type=int, default=2)
    parser.add_argument("--in_features", type=int, default=256)
    parser.add_argument("--loss_type", type=str, default="ce", choices=["ce", "focal", "balanced_softmax"],
                        help="Loss type: 'ce' (cross-entropy), 'focal' (Focal Loss), or 'balanced_softmax' (Balanced Softmax)")
    parser.add_argument("--focal_gamma", type=float, default=1.0,
                        help="Gamma parameter for Focal Loss (focusing parameter)")
    parser.add_argument("--focal_alpha", type=float, default=None,
                        help="Alpha parameter for Focal Loss (class weighting). If None, computed from class distribution.")
    parser.add_argument("--final_lr", type=float, default=0,
                        help="Final learning rate for cosine annealing scheduler")
    parser.add_argument("--use_mean_pool", action="store_true",
                        help="Use mean pooling of all patches instead of CLS token for feature extraction")

    # --- task configuration ---
    parser.add_argument("--task_type", type=str, default="classification",
                        choices=["classification", "regression"],
                        help="Task type: classification or regression")
    parser.add_argument("--regression_targets", type=str, nargs="*", default=["HR", "SPO2"],
                        help="Regression targets, e.g., --regression_targets HR SPO2")
    parser.add_argument("--filter_spo2_min", type=float, default=None,
                        help="Filter out SPO2 values below this threshold (e.g., 70). Only applies when SPO2 is a regression target.")
    parser.add_argument("--filter_spo2_max", type=float, default=None,
                        help="Filter out SPO2 values above this threshold (e.g., 100). Only applies when SPO2 is a regression target.")
    parser.add_argument("--finetune_backbone", action="store_true",
                        help="If set, finetune the entire backbone (full finetuning); otherwise linear probing only")
    parser.add_argument("--precision", type=str, default="32-true",
                        choices=["32-true", "16-mixed", "bf16-mixed"],
                        help="Training precision: 32-true (full), 16-mixed (FP16), bf16-mixed (BF16)")
    parser.add_argument("--sample_rate", type=int, default=64,
                        help="Input sample rate in Hz (default: 64). Use 32 for half resolution.")

    hparams = parser.parse_args()

    seed_everything(hparams.seed)
    main(hparams)
|
main_pipelines/main_pretrain.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pprint import pprint
|
| 2 |
+
import os
|
| 3 |
+
from argparse import ArgumentParser, Namespace
|
| 4 |
+
import datetime
|
| 5 |
+
from dateutil import tz
|
| 6 |
+
import random
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import warnings
|
| 10 |
+
from datetime import timedelta
|
| 11 |
+
from pytorch_lightning import seed_everything, Trainer
|
| 12 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback
|
| 13 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 14 |
+
from pytorch_lightning.strategies import DDPStrategy
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DenseStepCheckpoint(Callback):
    """Save checkpoints at specific training steps.

    Useful for scaling-law style analyses where checkpoints are wanted at
    hand-picked global steps rather than on an epoch cadence.
    """

    def __init__(self, dirpath: str, save_steps: list = None):
        super().__init__()
        self.dirpath = dirpath
        # Fall back to a log-spaced default schedule when none is given.
        if save_steps:
            self.save_steps = set(save_steps)
        else:
            self.save_steps = {1, 10, 100, 1000, 10000, 100000}
        # Steps already written this run, so a step is never saved twice.
        self.saved_steps = set()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        step = trainer.global_step
        if step not in self.save_steps or step in self.saved_steps:
            return
        path = os.path.join(self.dirpath, f"step={step}.ckpt")
        trainer.save_checkpoint(path)
        self.saved_steps.add(step)
        if trainer.is_global_zero:
            print(f"[DenseStepCheckpoint] Saved checkpoint at step {step}: {path}")
+
from osf.datasets.pretrain_datamodule import SleepDataModule
|
| 36 |
+
from osf.models.dino_model_cls import DINOCLSModel
|
| 37 |
+
from config import *
|
| 38 |
+
from train_config import *
|
| 39 |
+
|
| 40 |
+
warnings.filterwarnings("ignore")
|
| 41 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 42 |
+
torch.backends.cudnn.deterministic = True
|
| 43 |
+
torch.backends.cudnn.benchmark = True
|
| 44 |
+
torch.set_float32_matmul_precision('high')
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
torch._dynamo.config.cache_size_limit = 128
|
| 48 |
+
torch._dynamo.config.optimize_ddp = False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def param_stats(model: torch.nn.Module, verbose: bool = False):
    """Print parameter counts of ``model`` and return them.

    Args:
        model: module to inspect.
        verbose: if True, also print a per-parameter table (name, shape,
            count, trainable flag).

    Returns:
        Tuple ``(total, trainable)`` of raw parameter counts. The previous
        version only printed the counts; returning them as well keeps the
        printing behavior while making the numbers usable programmatically
        (e.g. for logging to wandb).
    """
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if verbose:
        print(f"{'Name':40s} {'Shape':20s} {'#Params':>10s} {'Train?':>6s}")
        print("-" * 80)
        for name, p in model.named_parameters():
            print(f"{name:40s} {str(list(p.shape)):20s} {p.numel():10d} {str(p.requires_grad):>6s}")
        print("-" * 80)
    # NOTE(review): in the garbled source the indentation of the summary
    # prints is ambiguous; they are treated as unconditional here.
    print(f"Total parameters: {total / 1e6:.3f} M ({total})")
    print(f"  Trainable params: {trainable / 1e6:.3f} M ({trainable})")
    print(f"  Frozen params: {(total-trainable) / 1e6:.3f} M ({total-trainable})")
    return total, trainable
+
def main(hparams: Namespace):
    """Pretrain the DINO model on sleep PSG data.

    Assembles run naming, checkpoint callbacks, the wandb logger, a DDP
    trainer and the pretraining datamodule, then launches ``trainer.fit``
    (optionally resuming from ``hparams.ckpt_path``).
    """
    # Run name: final_sleep_unimodal_<model>_<encoder>_bz<batch>_<timestamp>
    now = datetime.datetime.now(tz.tzlocal())
    stamp = now.strftime("%Y_%m_%d_%H_%M_%S")
    run_name = f"final_sleep_unimodal_{hparams.model_name}_{hparams.psg_encoder_name}_bz{hparams.batch_size}_{stamp}"
    ckpt_dir = os.path.join(
        CKPT_PATH, f"logs/sleepuni/ckpts/{run_name}")
    os.makedirs(ckpt_dir, exist_ok=True)

    if hparams.model_name not in MODEL_LIST:
        raise NotImplementedError

    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        # Keep every epoch-level checkpoint (save_top_k=-1) plus "last".
        ModelCheckpoint(monitor="val/loss", dirpath=ckpt_dir,
                        save_last=True, every_n_epochs=2, mode="min", save_top_k=-1,
                        save_on_train_epoch_end=False, auto_insert_metric_name=True),
    ]
    if hparams.dense_ckpt:
        # Extra checkpoints at hand-picked global steps (see DenseStepCheckpoint).
        dense_dir = os.path.join(ckpt_dir, "dense_steps")
        os.makedirs(dense_dir, exist_ok=True)
        callbacks.append(DenseStepCheckpoint(
            dirpath=dense_dir,
            save_steps=hparams.dense_ckpt_steps
        ))

    logger_dir = os.path.join(CKPT_PATH, "logs/sleepuni")
    os.makedirs(logger_dir, exist_ok=True)
    print("wandb logger dir: ", logger_dir)
    wandb_logger = WandbLogger(
        project=hparams.wandb_proj_name + f'final_{hparams.model_name}_{hparams.psg_encoder_name}_bz{hparams.batch_size}',
        save_dir=logger_dir,
        name=run_name)

    ddp = DDPStrategy(
        find_unused_parameters=True,
        static_graph=False,
        timeout=timedelta(minutes=15),
    )

    trainer = Trainer(
        max_epochs=hparams.max_epochs,
        accelerator="gpu",
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        devices=hparams.num_devices,
        num_nodes=hparams.num_nodes,
        precision=hparams.precision,
        gradient_clip_val=3.0,
        gradient_clip_algorithm="norm",
        strategy=ddp,
        callbacks=callbacks,
        logger=wandb_logger,
        log_every_n_steps=10,
    )

    hparams.exp_log_dir = os.path.join(
        CKPT_PATH, f"data/{run_name}/exp_logs")
    train_edf_cols = MONITOR_TYPE_MAP.get(hparams.monitor_type, TRAIN_EDF_COLS_UNI_ENC)
    hparams.num_leads = len(train_edf_cols)

    dm = SleepDataModule(
        is_pretrain=1,
        csv_dir=SPLIT_DATA_FOLDER,
        train_edf_cols=train_edf_cols,
        batch_size=hparams.batch_size,
        num_workers=hparams.num_workers,
        data_pct=hparams.train_data_pct,
        window_size=30,
        sample_rate=64,
        val_dataset_list=hparams.val_dataset_list,
        data_source=hparams.data_source,
        include_datasets=hparams.include_datasets,
    )

    hparams.simclr_augmentation = AUGMENTATION_MAP.get(hparams.model_name, "none")

    # Create DINO model
    model = DINOCLSModel(**vars(hparams))
    # Optimizer steps per epoch, accounting for grad accumulation and DDP replicas.
    model.training_steps_per_epoch = len(dm.train_dataloader()) // hparams.accumulate_grad_batches // hparams.num_devices
    # Teacher temperature warms up over the first 10% of training.
    model.teacher_temp_warmup_iters = model.training_steps_per_epoch * 0.1 * hparams.max_epochs
    print(f"[INFO] DINO teacher warmup steps: {model.teacher_temp_warmup_iters}")
    pprint(vars(hparams))

    if hparams.ckpt_path:
        trainer.fit(model, datamodule=dm, ckpt_path=hparams.ckpt_path)
    else:
        trainer.fit(model, datamodule=dm)
+
if __name__ == '__main__':
    parser = ArgumentParser(description="Pretraining DINO model for sleep PSG data.")

    def _str2bool(value):
        """Parse a boolean CLI value ('true'/'false', '1'/'0', ...).

        BUGFIX: the original used ``type=bool``, but ``bool("False")`` is
        True because any non-empty string is truthy — so passing
        ``--use_2d_pos_embed False`` silently enabled the flag. This parser
        keeps the value-taking interface while interpreting the text.
        """
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ("1", "true", "t", "yes", "y")

    parser.add_argument("--model_name", type=str, default="dino_ours",
                        choices=MODEL_LIST)

    parser.add_argument("--psg_encoder_name", type=str, default="vit_base")
    parser.add_argument("--val_dataset_list", default=PRETRAIN_VAL_DATASET_LIST)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train_data_pct", type=float, default=1.)
    parser.add_argument("--data_source", type=str, default="auto",
                        choices=["auto", "pretrain", "downstream", "both"])
    parser.add_argument("--include_datasets", type=str, nargs="*", default=None)
    parser.add_argument("--monitor_type", type=str, default="main",
                        choices=["main", "type3", "type4"],
                        help="Channel configuration: main (12ch), type3 (5ch), type4 (3ch)")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--patch_size_time", type=int, default=4)
    parser.add_argument("--patch_size_ch", type=int, default=4)
    parser.add_argument("--use_2d_pos_embed", type=_str2bool, default=True)
    parser.add_argument("--sample_rate", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=64)
    parser.add_argument("--num_devices", type=int, default=4)
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--max_epochs", type=int, default=30)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--precision", type=str, default="32-true")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--text_encoder_name", type=str, default="google/flan-t5-base")
    parser.add_argument("--lead_wise", type=int, default=0)
    parser.add_argument("--use_lead_embedding", type=int, default=1)
    # DINO-specific args
    parser.add_argument("--koleo_lambda", type=float, default=0.0)
    parser.add_argument("--ibot_lambda", type=float, default=0.0)
    parser.add_argument("--dino_out_dim", type=int, default=2048)
    parser.add_argument("--dino_patch_out_dim", type=int, default=2048)
    parser.add_argument("--dino_hidden_dim", type=int, default=2048)
    parser.add_argument("--dino_bottleneck_dim", type=int, default=256)
    parser.add_argument("--wandb_proj_name", type=str, default="sleepuni")
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--dense_ckpt", action="store_true")
    parser.add_argument("--dense_ckpt_steps", type=int, nargs="+", default=[10, 100, 200, 400, 500, 800, 1000, 1600, 2500, 3200, 6400, 10000, 12500, 12800, 25600, 51200, 62500, 100000])

    hparams = parser.parse_args()

    seed_everything(hparams.seed)
    main(hparams)
|
osf/__init__.py
ADDED
|
File without changes
|
osf/backbone/__init__.py
ADDED
|
File without changes
|
osf/backbone/pos_embed.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine-cosine positional embedding.

    Args:
        embed_dim: embedding dimension (must be even).
        grid_size: height and width of the (square) patch grid.
        cls_token: if True, prepend an all-zero row for the class token.

    Returns:
        ndarray of shape [grid_size*grid_size, embed_dim], or
        [1 + grid_size*grid_size, embed_dim] when cls_token is True.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid(w, h): channel 0 varies along width, channel 1 along height.
    mesh = np.stack(np.meshgrid(coords, coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return embed
    return np.concatenate([np.zeros([1, embed_dim]), embed], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-channel coordinate grid; half the dims per axis."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2

    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Classic transformer sin/cos encoding of scalar positions.

    Args:
        embed_dim: output dimension per position (must be even).
        pos: array of positions, any shape; flattened to (M,).

    Returns:
        ndarray of shape (M, embed_dim): [sin | cos] halves.
    """
    assert embed_dim % 2 == 0
    # Geometric frequency ladder: 1/10000^(2i/D), i = 0..D/2-1.
    freqs = np.arange(embed_dim // 2, dtype=float)
    freqs /= embed_dim / 2.
    freqs = 1. / 10000**freqs

    angles = np.einsum('m,d->md', pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 58 |
+
|
| 59 |
+
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's 2D ViT position embedding to match `model`.

    Mutates ``checkpoint_model['pos_embed']`` in place when the checkpoint's
    patch grid differs from the model's; extra (e.g. class) tokens are kept
    unchanged and only the patch-grid tokens are bicubically interpolated.
    Assumes a square patch grid on both sides (sqrt of num_patches).
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        # Tokens that are not part of the patch grid (class/dist tokens).
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # Side length of the checkpoint's (square) patch grid.
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # Side length of the target (square) patch grid.
        new_size = int(num_patches ** 0.5)
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            # Extra tokens pass through untouched.
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # Reshape flat patch tokens to an image-like [1, D, H, W] layout
            # so that 2D interpolation can be applied.
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            # Back to flat [1, H*W, D] token layout.
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
|
osf/backbone/vit1d.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
1D Vision Transformer for time-series signals.
|
| 3 |
+
|
| 4 |
+
Patchify modes:
|
| 5 |
+
- lead_wise=0: 1D patchify (all channels in one patch), no lead embedding
|
| 6 |
+
- lead_wise=1: 2D patchify (channel groups), with lead embedding by default
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DropPath(nn.Module):
    """Stochastic depth: randomly zero an entire residual branch per sample.

    In training mode with ``drop_prob > 0``, each sample's branch output is
    kept with probability ``1 - drop_prob`` and, when ``scale_by_keep`` is
    set, rescaled so the expectation matches the identity path.
    """

    def __init__(self, drop_prob: float, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Identity in eval mode or when dropping is disabled.
        if not self.training or self.drop_prob <= 0.:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            mask = mask / keep_prob
        return x * mask
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class PreNorm(nn.Module):
    """Apply LayerNorm to the input before delegating to a wrapped module."""

    def __init__(self, dim: int, fn: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, drop_out_rate=0.):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_out_rate),
            nn.Linear(hidden_dim, output_dim),
            nn.Dropout(drop_out_rate),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Attention(nn.Module):
    """Multi-head self-attention.

    A single linear layer produces Q/K/V; scaled dot-product attention is
    applied per head, then the heads are merged and projected to
    ``output_dim`` (skipped when a single head already matches the input
    dimension).
    """

    def __init__(self, input_dim: int, output_dim: int, heads: int = 8, dim_head: int = 64,
                 qkv_bias: bool = True, drop_out_rate: float = 0., attn_drop_out_rate: float = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        # Output projection is redundant for a single head of input width.
        project_out = not (heads == 1 and dim_head == input_dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attn_drop_out_rate)
        self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=qkv_bias)

        if project_out:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, output_dim), nn.Dropout(drop_out_rate))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        b, n, _ = x.shape
        # Split heads: [b, n, h*d] -> [b, h, n, d] (pure-torch equivalent of
        # rearrange 'b n (h d) -> b h n d').
        q, k, v = (t.view(b, n, self.heads, -1).transpose(1, 2)
                   for t in self.to_qkv(x).chunk(3, dim=-1))
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.dropout(self.attend(scores))
        # Merge heads: [b, h, n, d] -> [b, n, h*d].
        ctx = (weights @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(ctx)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual
    connection and optional stochastic-depth (drop-path)."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, heads: int = 8,
                 dim_head: int = 32, qkv_bias: bool = True, drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0., drop_path_rate: float = 0.):
        super().__init__()
        self.attn = PreNorm(input_dim,
                            Attention(input_dim, output_dim, heads, dim_head,
                                      qkv_bias, drop_out_rate, attn_drop_out_rate))
        self.droppath1 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

        self.ff = PreNorm(output_dim,
                          FeedForward(output_dim, output_dim, hidden_dim, drop_out_rate))
        self.droppath2 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        x = x + self.droppath1(self.attn(x))
        x = x + self.droppath2(self.ff(x))
        return x
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ViT(nn.Module):
    """1D Vision Transformer encoder for multi-lead time-series (no CLS token).

    Patchify modes (selected by ``lead_wise``):
      - 0: ``Conv1d`` over all leads jointly -> ``seq_len // patch_size`` tokens.
      - otherwise: ``Conv2d`` over (lead-group, time) cells, optionally adding a
        learned per-lead-group embedding.

    ``forward`` returns the mean-pooled, layer-normed embedding passed through
    ``self.head`` (Identity until ``reset_head`` is called).
    """

    def __init__(self,
                 num_leads: int,
                 seq_len: int,
                 patch_size: int,
                 lead_wise=0,
                 patch_size_ch=4,
                 use_lead_embedding: bool = True,
                 width: int = 768,
                 depth: int = 12,
                 mlp_dim: int = 3072,
                 heads: int = 12,
                 dim_head: int = 64,
                 qkv_bias: bool = True,
                 drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 **kwargs):
        """Build the encoder.

        Args:
            num_leads: number of input channels (leads).
            seq_len: input length in samples; must divide evenly by patch_size.
            patch_size: temporal patch length.
            lead_wise: 0 for 1D patchify, otherwise 2D lead-group patchify.
            patch_size_ch: leads per group (lead_wise != 0 only).
            use_lead_embedding: add a learned per-lead-group embedding
                (lead_wise != 0 only).
            width/depth/mlp_dim/heads/dim_head: transformer dimensions.
            qkv_bias: add bias to the QKV projection.
            drop_out_rate/attn_drop_out_rate/drop_path_rate: regularization.
            **kwargs: absorbed silently — NOTE(review): the vit_* factories
                pass num_classes here, but no classification head is built
                until reset_head() is called.
        """
        super().__init__()
        assert seq_len % patch_size == 0
        num_patches = seq_len // patch_size
        self.lead_wise = lead_wise
        self.use_lead_embedding = use_lead_embedding

        if lead_wise == 0:
            # One token per time patch; all leads folded into the conv input.
            self.to_patch_embedding = nn.Conv1d(num_leads, width, kernel_size=patch_size, stride=patch_size, bias=False)
            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, width))
            # NOTE(review): self.lead_emb is only defined on the lead-wise
            # branch below; code touching it must guard on lead_wise.
        else:
            # One token per (lead-group, time-patch) cell.
            self.to_patch_embedding = nn.Conv2d(1, width, kernel_size=(patch_size_ch, patch_size),
                                                stride=(patch_size_ch, patch_size), bias=False)
            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches * num_leads // patch_size_ch, width))
            if use_lead_embedding:
                self.lead_emb = nn.Embedding(num_leads // patch_size_ch, width)
            else:
                self.lead_emb = None

        self.dropout = nn.Dropout(drop_out_rate)
        self.depth = depth
        self.width = width

        # Linearly ramp the stochastic-depth rate across blocks.
        drop_path_rate_list = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        for i in range(depth):
            block = TransformerBlock(width, width, mlp_dim, heads, dim_head, qkv_bias,
                                     drop_out_rate, attn_drop_out_rate, drop_path_rate_list[i])
            self.add_module(f'block{i}', block)

        self.norm = nn.LayerNorm(width)
        self.head = nn.Identity()

    def _patchify_and_embed(self, series: torch.Tensor) -> torch.Tensor:
        """Patchify input and add positional/lead embeddings. [B,C,T] -> [B,N,D]"""
        if self.lead_wise == 0:
            x = self.to_patch_embedding(series)  # [B, D, N]
            x = rearrange(x, 'b c n -> b n c')  # [B, N, D]
            # Truncate the learned table to the actual token count.
            x = x + self.pos_embedding[:, :x.size(1), :].to(x.device)
        else:
            x = self.to_patch_embedding(series.unsqueeze(1))  # [B, D, Lr, Nt]
            Lr, Nt = x.shape[-2], x.shape[-1]
            x = rearrange(x, 'b c lr nt -> b (lr nt) c')  # [B, N, D]
            x = x + self.pos_embedding[:, :x.size(1), :].to(x.device)
            if self.use_lead_embedding and self.lead_emb is not None:
                # Same lead-group embedding repeated across its time patches.
                row_ids = torch.arange(Lr, device=x.device).repeat_interleave(Nt)
                x = x + self.lead_emb(row_ids)[None, :, :]
        return x

    def forward_encoding(self, series: torch.Tensor) -> torch.Tensor:
        """Encode series. Returns [B,D] (mean pooled)."""
        x = self._patchify_and_embed(series)
        x = self.dropout(x)
        for i in range(self.depth):
            x = getattr(self, f'block{i}')(x)
        # Mean-pool tokens, then normalize the pooled vector.
        x = x.mean(dim=1)
        return self.norm(x)

    def forward(self, series):
        x = self.forward_encoding(series)
        return self.head(x)

    def reset_head(self, num_classes=1):
        """Replace the Identity head with a Linear classifier."""
        del self.head
        self.head = nn.Linear(self.width, num_classes)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def vit_nano(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Nano: width 128, depth 6, 4 heads, MLP dim 512."""
    model_args = dict(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
                      patch_size=patch_size, width=128, depth=6, heads=4,
                      mlp_dim=512, **kwargs)
    return ViT(**model_args)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def vit_tiny(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Tiny: width 192, depth 12, 3 heads, MLP dim 768."""
    model_args = dict(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
                      patch_size=patch_size, width=192, depth=12, heads=3,
                      mlp_dim=768, **kwargs)
    return ViT(**model_args)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def vit_small(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Small: width 384, depth 12, 6 heads, MLP dim 1536."""
    model_args = dict(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
                      patch_size=patch_size, width=384, depth=12, heads=6,
                      mlp_dim=1536, **kwargs)
    return ViT(**model_args)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def vit_middle(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Middle: width 512, depth 12, 8 heads, MLP dim 2048."""
    model_args = dict(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
                      patch_size=patch_size, width=512, depth=12, heads=8,
                      mlp_dim=2048, **kwargs)
    return ViT(**model_args)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def vit_base(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Base: width 768, depth 12, 12 heads, MLP dim 3072."""
    model_args = dict(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
                      patch_size=patch_size, width=768, depth=12, heads=12,
                      mlp_dim=3072, **kwargs)
    return ViT(**model_args)
|
osf/backbone/vit1d_cls.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
1D Vision Transformer with CLS token support.
|
| 3 |
+
|
| 4 |
+
Patchify modes:
|
| 5 |
+
- lead_wise=0: 1D patchify (all channels in one patch)
|
| 6 |
+
- lead_wise=1: 2D patchify (channel groups)
|
| 7 |
+
|
| 8 |
+
Note: lead_emb is DEPRECATED and not used in data flow. It is kept only for
|
| 9 |
+
checkpoint compatibility. Do NOT add lead_emb usage without careful consideration.
|
| 10 |
+
"""
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DropPath(nn.Module):
    """Drop paths (stochastic depth) per sample, applied on the main path of
    residual blocks: the whole branch output is zeroed with probability
    ``drop_prob`` during training, and rescaled by ``1 / keep_prob`` when
    ``scale_by_keep`` is set."""

    def __init__(self, drop_prob: float, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # No-op in eval mode or when dropping is disabled.
        if not self.training or self.drop_prob <= 0.:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast across remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            mask = mask / keep_prob
        return x * mask
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class PreNorm(nn.Module):
    """Wrap a module so its input is LayerNorm-ed first (pre-norm style)."""

    def __init__(self, dim: int, fn: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class FeedForward(nn.Module):
    """MLP module: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: int,
                 drop_out_rate=0.):
        super().__init__()
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_out_rate),
            nn.Linear(hidden_dim, output_dim),
            nn.Dropout(drop_out_rate),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    The output projection to ``output_dim`` is skipped when a single head
    already matches the input dimension.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 heads: int = 8,
                 dim_head: int = 64,
                 qkv_bias: bool = True,
                 drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == input_dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attn_drop_out_rate)
        self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=qkv_bias)

        if project_out:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, output_dim),
                                        nn.Dropout(drop_out_rate))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        b, n, _ = x.shape
        # [b, n, h*d] -> [b, h, n, d] per Q/K/V (pure-torch equivalent of
        # rearrange 'b n (h d) -> b h n d').
        q, k, v = (t.view(b, n, self.heads, -1).transpose(1, 2)
                   for t in self.to_qkv(x).chunk(3, dim=-1))

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Merge heads back: [b, h, n, d] -> [b, n, h*d].
        ctx = (weights @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(ctx)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each wrapped
    in a residual connection with optional stochastic depth."""

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: int,
                 heads: int = 8,
                 dim_head: int = 32,
                 qkv_bias: bool = True,
                 drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0.,
                 drop_path_rate: float = 0.):
        super().__init__()
        self.attn = PreNorm(
            dim=input_dim,
            fn=Attention(input_dim=input_dim,
                         output_dim=output_dim,
                         heads=heads,
                         dim_head=dim_head,
                         qkv_bias=qkv_bias,
                         drop_out_rate=drop_out_rate,
                         attn_drop_out_rate=attn_drop_out_rate))
        self.droppath1 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

        self.ff = PreNorm(
            dim=output_dim,
            fn=FeedForward(input_dim=output_dim,
                           output_dim=output_dim,
                           hidden_dim=hidden_dim,
                           drop_out_rate=drop_out_rate))
        self.droppath2 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        x = x + self.droppath1(self.attn(x))
        x = x + self.droppath2(self.ff(x))
        return x
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class ViT(nn.Module):
    """1D Vision Transformer with a prepended CLS token.

    Patchify modes (selected by ``lead_wise``):
      - 0: ``Conv1d`` patchify across all leads jointly.
      - 1: ``Conv2d`` patchify over (lead-group, time) cells.

    ``forward`` returns the CLS embedding; ``forward_avg_pool`` the mean of
    the patch embeddings. ``head`` is Identity until ``reset_head`` is called.
    Per the module docstring, ``lead_emb`` is deprecated: it is registered for
    checkpoint compatibility but never added to the token stream.
    """

    def __init__(self,
                 num_leads: int,
                 seq_len: int,
                 patch_size: int,
                 lead_wise: int = 0,
                 patch_size_ch: int = 4,
                 width: int = 768,
                 depth: int = 12,
                 mlp_dim: int = 3072,
                 heads: int = 12,
                 dim_head: int = 64,
                 qkv_bias: bool = True,
                 drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 **kwargs):
        """Build the encoder.

        Args:
            num_leads: number of input channels (leads).
            seq_len: input length in samples; must divide evenly by patch_size.
            patch_size: temporal patch length.
            lead_wise: 0 for 1D patchify, 1 for 2D lead-group patchify.
            patch_size_ch: leads per group (lead_wise == 1 only).
            width/depth/mlp_dim/heads/dim_head: transformer dimensions.
            qkv_bias: add bias to the QKV projection.
            drop_out_rate/attn_drop_out_rate/drop_path_rate: regularization.
            **kwargs: absorbed silently — NOTE(review): the vit_* factories
                pass num_classes here; a classification head only exists
                after reset_head().
        """
        super().__init__()
        assert seq_len % patch_size == 0
        num_patches_time = seq_len // patch_size

        self.lead_wise = lead_wise
        self.width = width
        self.depth = depth

        if lead_wise == 0:
            self.to_patch_embedding = nn.Conv1d(num_leads, width, kernel_size=patch_size,
                                                stride=patch_size, bias=False)
            N_max = num_patches_time
            self.lead_emb = None
        else:
            self.to_patch_embedding = nn.Conv2d(1, width,
                                                kernel_size=(patch_size_ch, patch_size),
                                                stride=(patch_size_ch, patch_size),
                                                bias=False)
            Lr = num_leads // patch_size_ch  # number of lead groups
            N_max = Lr * num_patches_time
            # DEPRECATED: kept only for checkpoint compatibility; not used in
            # the forward pass (see module docstring).
            self.lead_emb = nn.Embedding(Lr, width)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, width))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # +1 position slot for the CLS token.
        self.pos_embedding = nn.Parameter(torch.zeros(1, N_max + 1, width))
        nn.init.trunc_normal_(self.pos_embedding, std=0.02)

        self.dropout = nn.Dropout(drop_out_rate)

        # Linearly ramp stochastic-depth rate over the block stack.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        for i in range(depth):
            block = TransformerBlock(input_dim=width, output_dim=width,
                                     hidden_dim=mlp_dim, heads=heads, dim_head=dim_head,
                                     qkv_bias=qkv_bias, drop_out_rate=drop_out_rate,
                                     attn_drop_out_rate=attn_drop_out_rate,
                                     drop_path_rate=dpr[i])
            self.add_module(f'block{i}', block)

        self.norm = nn.LayerNorm(width)
        self.head = nn.Identity()


    def to_tokens_2d(self, series: torch.Tensor,
                     patch_size_ch: int | None = None,
                     patch_size_time: int | None = None):
        """Patchify only (no pos embedding). Returns (tokens, meta).

        Args:
            series: [B, L, T] input batch.
            patch_size_ch / patch_size_time: optional overrides for the
                assertion; default to the conv's kernel size (lead_wise == 1).
        """
        B, L, T = series.shape

        if self.lead_wise == 0:
            x = self.to_patch_embedding(series)  # [B,C,Nt]
            Nt = x.shape[-1]
            x = rearrange(x, 'b c n -> b n c')  # [B,Nt,C]
            meta = dict(lead_wise=0, L=L, Nt=Nt, pz_ch=1)
            return x, meta

        # lead_wise == 1
        if patch_size_ch is None or patch_size_time is None:
            kch, kt = self.to_patch_embedding.kernel_size
            patch_size_ch = patch_size_ch or kch
            patch_size_time = patch_size_time or kt
        assert L % patch_size_ch == 0 and T % patch_size_time == 0

        x = series.unsqueeze(1)  # [B,1,L,T]
        x = self.to_patch_embedding(x)  # [B,C,Lr,Nt]
        Lr, Nt = x.shape[-2], x.shape[-1]
        x = rearrange(x, 'b c lr nt -> b (lr nt) c')  # [B,Lr*Nt,C]
        meta = dict(lead_wise=1, L=L, Nt=Nt, pz_ch=patch_size_ch)
        return x, meta

    def forward_encoding(self, series: torch.Tensor,
                         return_sequence: bool = False):
        """Encode with CLS token. Returns (cls, patches) or full sequence if return_sequence=True."""
        tokens, meta = self.to_tokens_2d(series)
        B = tokens.size(0)
        cls_tok = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tok, tokens], dim=1)  # [B,N+1,C]

        # Truncate the learned table to the actual sequence length.
        pe = self.pos_embedding[:, :x.size(1), :].to(x.device)

        x = x + pe

        x = self._run_blocks(x)
        if return_sequence:
            return x
        cls, patches = x[:, 0], x[:, 1:]

        return cls, patches


    def _run_blocks(self, x: torch.Tensor):
        # Dropout -> transformer blocks -> final norm -> head.
        # NOTE(review): `head` is applied to every token here; with the
        # default Identity that is a no-op, but after reset_head() the Linear
        # is applied per-token before the CLS/patch split — confirm intended.
        x = self.dropout(x)
        for i in range(self.depth):
            x = getattr(self, f'block{i}')(x)
        x = self.norm(x)
        return self.head(x)

    def forward(self, series: torch.Tensor):
        """Return the CLS embedding. series: [B,C,T] -> [B,D]."""
        cls, _ = self.forward_encoding(series, return_sequence=False)
        return cls

    def forward_avg_pool(self, series: torch.Tensor):
        """Returns avg-pooled patch embeddings. series: [B,C,T] -> [B,D]"""
        _, patches = self.forward_encoding(series, return_sequence=False)  # [B,N,D]
        return patches.mean(dim=1)  # [B,D]

    def reset_head(self, num_classes=1):
        """Replace the Identity head with a Linear classifier."""
        del self.head
        self.head = nn.Linear(self.width, num_classes)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def vit_nano(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Nano (CLS variant): width 128, depth 6, 4 heads, MLP dim 512."""
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, width=128, depth=6, heads=4, mlp_dim=512,
               **kwargs)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def vit_tiny(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Tiny (CLS variant): width 192, depth 12, 3 heads, MLP dim 768."""
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, width=192, depth=12, heads=3, mlp_dim=768,
               **kwargs)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def vit_small(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Small (CLS variant): width 384, depth 12, 6 heads, MLP dim 1536."""
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, width=384, depth=12, heads=6, mlp_dim=1536,
               **kwargs)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def vit_middle(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Middle (CLS variant): width 512, depth 12, 8 heads, MLP dim 2048."""
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, width=512, depth=12, heads=8, mlp_dim=2048,
               **kwargs)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def vit_base(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Base (CLS variant): width 768, depth 12, 12 heads, MLP dim 3072."""
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, width=768, depth=12, heads=12, mlp_dim=3072,
               **kwargs)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def vit_large(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Large (CLS variant): width 1024, depth 24, 16 heads, MLP dim 4096."""
    model_args = dict(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
                      patch_size=patch_size, width=1024, depth=24, heads=16,
                      mlp_dim=4096, **kwargs)
    return ViT(**model_args)
|
| 351 |
+
|
| 352 |
+
def vit_xl(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-XL (CLS variant): width 1536, depth 24, 24 heads, MLP dim 6144."""
    model_args = dict(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
                      patch_size=patch_size, width=1536, depth=24, heads=24,
                      mlp_dim=6144, **kwargs)
    return ViT(**model_args)
|
osf/datasets/__init__.py
ADDED
|
File without changes
|
osf/datasets/augmentations.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data augmentations for SSL pretraining (SimCLR, DINO).
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from typing import Tuple
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@torch.no_grad()
def random_time_crop(
    x: torch.Tensor,
    ratio: Tuple[float, float] | float = (0.6, 0.9),
    *,
    resize_back: bool = True,
    align_to: int | None = 40
) -> torch.Tensor:
    """
    Randomly crop a contiguous sub-sequence per sample, optionally resize back to original T.

    Args:
        x: (B, C, T)
        ratio: crop length ratio in [low, high] or a float
        resize_back: if True, linearly interpolate the cropped view back to length T
        align_to: if not None, crop length is rounded to a multiple of align_to (>= align_to)

    Returns:
        (B, C, T) tensor when resize_back is True.
        NOTE(review): with resize_back=False, torch.stack at the end requires
        every crop to have the same length — only safe when ratio is a fixed
        float (or alignment forces equal lengths); confirm callers.
    """
    assert x.dim() == 3, f"expected (B,C,T), got {tuple(x.shape)}"
    B, C, T = x.shape
    dev = x.device

    def _sample_L() -> int:
        # Draw one crop length (in samples) for a single batch element.
        if isinstance(ratio, (tuple, list)):
            a, b = float(ratio[0]), float(ratio[1])
            r = torch.empty((), device=dev).uniform_(a, b).item()
        else:
            r = float(ratio)
        L = max(2, int(round(T * r)))
        if align_to and align_to > 1:
            # Round to a multiple of align_to (e.g. the model's patch size),
            # but never below align_to itself.
            L = max(align_to, int(round(L / align_to)) * align_to)
        return min(L, T)

    Ls = [_sample_L() for _ in range(B)]
    outs = []
    for b in range(B):
        L = Ls[b]
        max_start = max(0, T - L)
        # Uniform random start so the crop fits inside [0, T).
        s = int(torch.randint(0, max_start + 1, (1,), device=dev).item())
        v = x[b, :, s:s+L]  # (C, L)
        if resize_back and v.shape[-1] != T:
            # Stretch the crop back to the original length.
            v = F.interpolate(v[None], size=T, mode="linear", align_corners=False)[0]
        outs.append(v)
    return torch.stack(outs, dim=0)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@torch.no_grad()
def channel_dropout(
    x: torch.Tensor,
    drop_prob: float = 0.2,
    min_keep: int = 1
) -> torch.Tensor:
    """
    Zero out whole channels independently with probability `drop_prob`,
    guaranteeing every sample keeps at least `min_keep` active channels.

    Args:
        x: (B, C, T)
        drop_prob: per-channel drop probability
        min_keep: lower bound on surviving channels per sample
    """
    assert x.dim() == 3
    B, C, T = x.shape
    # One uniform draw per (sample, channel); keep when it exceeds drop_prob.
    mask = (torch.rand(B, C, 1, device=x.device, dtype=x.dtype) > drop_prob).to(x.dtype)

    # Repair samples that ended up with fewer than min_keep active channels
    # by force-reviving min_keep randomly chosen channels.
    kept = mask.sum(dim=1, keepdim=True)  # (B, 1, 1)
    short = (kept < min_keep).squeeze(-1).squeeze(-1)  # (B,)
    if short.any():
        for b in torch.where(short)[0]:
            revive = torch.randperm(C, device=x.device)[:min_keep]
            mask[b, revive, 0] = 1.0

    return x * mask
|
osf/datasets/pretrain_datamodule.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Sequence, Optional, Dict, Union
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
import pytorch_lightning as pl
|
| 9 |
+
from pytorch_lightning import LightningDataModule
|
| 10 |
+
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
|
| 11 |
+
from torch.utils.data import DataLoader, Subset
|
| 12 |
+
|
| 13 |
+
from osf.datasets.pretrain_dataset import SleepEpochDataset
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SleepDataModule(LightningDataModule):
    """LightningDataModule wrapping :class:`SleepEpochDataset`.

    Serves three flavours of dataloaders:

    * pretraining (``is_pretrain == 1``): self-supervised splits
      ("pretrain" / "pretrain-val" / "pretrain-test"), with ``transforms``
      applied on the training split only;
    * downstream fine-tuning (``is_pretrain != 1``): labelled
      "train" / "val" / "test" splits;
    * optional few-shot subsampling of the training set
      (``n_train_samples``) and fractional random subsampling of each
      validation set (``val_data_pct``).
    """

    def __init__(
        self,
        csv_dir: str | Path,
        *,
        is_pretrain,                                   # 1 -> self-supervised splits; anything else -> downstream splits
        data_pct = 1,                                  # patient-level fraction of data to use, in (0, 1]
        val_dataset_list: Optional[List[str]] = None,  # if given, build one val loader per named dataset
        downstream_dataset_name = None,
        batch_size: int = 128,
        num_workers: int = 4,
        patient_cols: Optional[Union[str, Sequence[str]]] = None,
        event_cols: Optional[Union[str, Sequence[str]]] = None,
        train_edf_cols: Sequence[str] | None,          # channel names to load (required keyword, no default)
        transforms=None,                               # augmentation pipeline; applied to the train split only
        n_views: int = 1,
        cache_size: int = 8,
        sample_rate: int = 128,
        window_size: int = 30,
        pin_memory: bool = False,
        persistent_workers: bool = False,
        data_source: str = "auto",
        include_datasets: Optional[List[str]] = None,
        regression_targets: Optional[List[str]] = None,
        regression_filter_config: Optional[Dict] = None,
        n_train_samples: Optional[int] = None,         # if > 0, few-shot subsample the training set
        val_batch_size: Optional[int] = None,          # falls back to batch_size when None
        val_data_pct: Optional[float] = None,          # if in (0, 1), keep that random fraction of each val set
        return_all_event_cols: bool = False,
        return_nsrrid: bool = False,
        random_seed: int = 42,
    ):
        super().__init__()
        # transforms may hold arbitrary callables; keep them out of hparams.
        self.save_hyperparameters(ignore=["transforms"])
        self.downstream_dataset_name = downstream_dataset_name
        self.csv_dir = csv_dir
        self.transforms = transforms
        self.n_views = n_views
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.is_pretrain = is_pretrain
        self.patient_cols = patient_cols
        self.event_cols = event_cols
        self.data_pct = data_pct
        self.data_source = data_source
        self.include_datasets = include_datasets
        self.regression_targets = regression_targets
        self.regression_filter_config = regression_filter_config
        self.n_train_samples = n_train_samples
        self.val_batch_size = val_batch_size
        self.val_data_pct = val_data_pct
        self.return_all_event_cols = return_all_event_cols
        self.return_nsrrid = return_nsrrid
        self.random_seed = random_seed

    def _few_shot_subset(self, train_set):
        """Return a few-shot :class:`Subset` of ``train_set``.

        If the dataset exposes labels (``event_cols`` plus ``all_epoch_df``
        with a known ``num_classes``), sample up to ``n_train_samples``
        indices *per class* (stratified); otherwise fall back to a plain
        random sample of ``n_train_samples`` items. Sampling is seeded with
        ``random_seed`` for reproducibility.

        NOTE(review): the stratified path indexes into
        ``train_set.all_epoch_df`` while ``len(train_set)`` is the length of
        the deduplicated ``epoch_df``; for multi-epoch segments these can
        differ — behavior kept as in the original, but worth confirming.
        """
        n_total = len(train_set)
        rng = np.random.default_rng(seed=self.random_seed)

        labels = None
        num_classes = None
        if hasattr(train_set, 'event_cols') and train_set.event_cols and hasattr(train_set, 'all_epoch_df'):
            label_col = train_set.event_cols[0]
            if label_col in train_set.all_epoch_df.columns:
                labels = train_set.all_epoch_df[label_col].values
                num_classes = getattr(train_set, 'num_classes', None)

        if labels is not None and num_classes is not None:
            # Class-stratified few-shot sampling: up to n_train_samples per class.
            all_indices = []
            for c in range(num_classes):
                class_indices = np.where(labels == c)[0]
                n_per_class = min(self.n_train_samples, len(class_indices))
                if n_per_class > 0:
                    sampled = rng.choice(class_indices, size=n_per_class, replace=False)
                    all_indices.extend(sampled.tolist())
                    print(f"[Few-shot] Class {c}: sampled {n_per_class}/{len(class_indices)} samples")
            print(f"[Few-shot] Total: {len(all_indices)}/{n_total} samples ({self.n_train_samples}-shot per class)")
            return Subset(train_set, all_indices)

        # Fallback: unstratified random subsample (no usable labels).
        # Previously this identical branch was duplicated three times.
        n_keep = min(self.n_train_samples, n_total)
        indices = rng.choice(n_total, size=n_keep, replace=False).tolist()
        print(f"[Few-shot] Using {n_keep}/{n_total} training samples (random, n_train_samples={self.n_train_samples})")
        return Subset(train_set, indices)

    def train_dataloader(self):
        """Build the training DataLoader (pretrain or downstream split)."""
        if self.is_pretrain == 1:
            train_set = SleepEpochDataset(
                csv_dir=self.csv_dir,
                split="pretrain",
                data_pct=self.data_pct,
                train_edf_cols=self.hparams.train_edf_cols,
                transform=self.transforms,
                sample_rate=self.hparams.sample_rate,
                window_size=self.hparams.window_size,
                cache_size=self.hparams.cache_size,
                data_source=self.data_source,
                include_datasets=self.include_datasets,
            )
            persistent_workers = self.persistent_workers
        else:
            train_set = SleepEpochDataset(
                csv_dir=self.csv_dir,
                split="train",
                data_pct=self.data_pct,
                patient_cols=self.patient_cols,
                event_cols=self.event_cols,
                train_edf_cols=self.hparams.train_edf_cols,
                transform=self.transforms,
                sample_rate=self.hparams.sample_rate,
                window_size=self.hparams.window_size,
                cache_size=self.hparams.cache_size,
                downstream_dataset_name=self.downstream_dataset_name,
                data_source=self.data_source,
                include_datasets=self.include_datasets,
                regression_targets=self.regression_targets,
                regression_filter_config=self.regression_filter_config,
                return_all_event_cols=self.return_all_event_cols,
                return_nsrrid=self.return_nsrrid,
            )
            # Keep a handle on the FULL dataset so get_class_distribution()
            # reflects label counts before any few-shot subsampling.
            self._train_dataset = train_set
            persistent_workers = True

        if self.n_train_samples is not None and self.n_train_samples > 0:
            train_set = self._few_shot_subset(train_set)

        return DataLoader(
            train_set,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=persistent_workers,
            drop_last=True,
        )

    def get_class_distribution(self) -> Optional[torch.Tensor]:
        """
        Get class distribution from training dataset.
        Returns [num_classes] tensor of class counts, or None if not available
        (e.g. train_dataloader() has not been called, or the dataset has no labels).
        """
        if hasattr(self, '_train_dataset'):
            counts = self._train_dataset.get_class_counts()
            if counts is not None:
                return torch.from_numpy(counts).float()
        return None

    def val_dataloader(self):
        """Build one validation DataLoader per validation dataset.

        NOTE(review): ``drop_last=True`` on validation loaders discards tail
        samples during evaluation — presumably intentional for batch-shaped
        metrics; confirm.
        """
        if self.hparams.val_dataset_list:
            if self.is_pretrain == 1:
                val_sets = [
                    SleepEpochDataset(
                        csv_dir=self.csv_dir,
                        split="pretrain-val",
                        data_pct=self.data_pct,
                        patient_cols=self.patient_cols,
                        event_cols=self.event_cols,
                        train_edf_cols=self.hparams.train_edf_cols,
                        transform=None,
                        sample_rate=self.hparams.sample_rate,
                        window_size=self.hparams.window_size,
                        cache_size=self.hparams.cache_size,
                        downstream_dataset_name=ds_name,
                        data_source=self.data_source,
                        include_datasets=self.include_datasets,
                    )
                    for ds_name in self.hparams.val_dataset_list
                ]
                persistent_workers = self.persistent_workers
            else:
                # BUGFIX: this branch previously bound neither `val_sets` nor
                # `persistent_workers`, so val_dataloader() raised NameError
                # whenever val_dataset_list was set during fine-tuning.
                # Build one downstream "val" dataset per requested name.
                val_sets = [
                    SleepEpochDataset(
                        csv_dir=self.csv_dir,
                        split="val",
                        data_pct=self.data_pct,
                        patient_cols=self.patient_cols,
                        event_cols=self.event_cols,
                        train_edf_cols=self.hparams.train_edf_cols,
                        transform=None,
                        sample_rate=self.hparams.sample_rate,
                        window_size=self.hparams.window_size,
                        cache_size=self.hparams.cache_size,
                        downstream_dataset_name=ds_name,
                        data_source=self.data_source,
                        include_datasets=self.include_datasets,
                        regression_targets=self.regression_targets,
                        regression_filter_config=self.regression_filter_config,
                    )
                    for ds_name in self.hparams.val_dataset_list
                ]
                persistent_workers = True
        else:
            if self.is_pretrain == 1:
                val_sets = [
                    SleepEpochDataset(
                        csv_dir=self.csv_dir,
                        split="pretrain-val",
                        data_pct=self.data_pct,
                        patient_cols=self.patient_cols,
                        event_cols=self.event_cols,
                        train_edf_cols=self.hparams.train_edf_cols,
                        transform=None,
                        sample_rate=self.hparams.sample_rate,
                        window_size=self.hparams.window_size,
                        cache_size=self.hparams.cache_size,
                        data_source=self.data_source,
                        include_datasets=self.include_datasets,
                    )
                ]
                persistent_workers = self.persistent_workers
            else:
                val_sets = [
                    SleepEpochDataset(
                        csv_dir=self.csv_dir,
                        split="val",
                        data_pct=self.data_pct,
                        patient_cols=self.patient_cols,
                        event_cols=self.event_cols,
                        train_edf_cols=self.hparams.train_edf_cols,
                        transform=None,
                        sample_rate=self.hparams.sample_rate,
                        window_size=self.hparams.window_size,
                        cache_size=self.hparams.cache_size,
                        downstream_dataset_name=self.downstream_dataset_name,
                        data_source=self.data_source,
                        include_datasets=self.include_datasets,
                        regression_targets=self.regression_targets,
                        regression_filter_config=self.regression_filter_config,
                    )
                ]
                persistent_workers = True

        # Optionally keep only a seeded random fraction of each val set.
        if self.val_data_pct is not None and 0 < self.val_data_pct < 1.0:
            subsampled_val_sets = []
            for ds in val_sets:
                n_total = len(ds)
                n_keep = max(1, int(n_total * self.val_data_pct))
                rng = np.random.default_rng(seed=self.random_seed)
                indices = rng.choice(n_total, size=n_keep, replace=False).tolist()
                subsampled_val_sets.append(Subset(ds, indices))
                print(f"[Val subsample] Using {n_keep}/{n_total} val samples ({self.val_data_pct*100:.1f}%)")
            val_sets = subsampled_val_sets

        val_bs = self.val_batch_size if self.val_batch_size is not None else self.hparams.batch_size
        return [
            DataLoader(
                ds,
                batch_size=val_bs,
                shuffle=False,
                num_workers=self.hparams.num_workers,
                pin_memory=self.pin_memory,
                persistent_workers=persistent_workers,
                drop_last=True,
            )
            for ds in val_sets
        ]

    def test_dataloader(self):
        """Build the test DataLoader (pretrain-test or downstream test split)."""
        if self.is_pretrain == 1:
            test_set = SleepEpochDataset(
                csv_dir=self.csv_dir,
                split="pretrain-test",
                patient_cols=self.patient_cols,
                event_cols=self.event_cols,
                train_edf_cols=self.hparams.train_edf_cols,
                transform=None,
                sample_rate=self.hparams.sample_rate,
                window_size=self.hparams.window_size,
                cache_size=self.hparams.cache_size,
                data_source=self.data_source,
                include_datasets=self.include_datasets,
            )
            persistent_workers = self.persistent_workers
        else:
            test_set = SleepEpochDataset(
                csv_dir=self.csv_dir,
                split="test",
                patient_cols=self.patient_cols,
                event_cols=self.event_cols,
                train_edf_cols=self.hparams.train_edf_cols,
                transform=None,
                sample_rate=self.hparams.sample_rate,
                window_size=self.hparams.window_size,
                cache_size=self.hparams.cache_size,
                downstream_dataset_name=self.downstream_dataset_name,
                data_source=self.data_source,
                include_datasets=self.include_datasets,
                regression_targets=self.regression_targets,
                regression_filter_config=self.regression_filter_config,
            )
            persistent_workers = True
        test_bs = self.val_batch_size if self.val_batch_size is not None else self.hparams.batch_size
        return DataLoader(
            test_set,
            batch_size=test_bs,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
            persistent_workers=persistent_workers,
        )
|
| 303 |
+
|
osf/datasets/pretrain_dataset.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sleep Epoch Dataset for pretraining and downstream tasks
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from contextlib import suppress
|
| 9 |
+
from typing import Sequence, Optional, Dict, Union, List
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
from train_config import NEED_NORM_COL
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def to_pm1(s: pd.Series) -> pd.Series:
    """Min-max scale a series to [-1, 1].

    Non-numeric entries are coerced to NaN first. If the series is all-NaN
    or constant (max <= min), an all-zero series with the same index is
    returned; any remaining NaNs after scaling also map to 0.0.
    """
    numeric = pd.to_numeric(s, errors="coerce")
    lo = numeric.min(skipna=True)
    hi = numeric.max(skipna=True)
    # Degenerate input: nothing to scale against.
    if pd.isna(lo) or pd.isna(hi) or hi <= lo:
        return pd.Series(0.0, index=numeric.index)
    unit = (numeric - lo) / (hi - lo)          # [0, 1]
    return (unit * 2 - 1).fillna(0.0)          # [-1, 1], NaN -> 0
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SleepEpochDataset(Dataset):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
csv_dir='/path/to/your/postprocessed/data',
|
| 26 |
+
split: str = "train",
|
| 27 |
+
*,
|
| 28 |
+
data_pct=1,
|
| 29 |
+
patient_cols: Optional[Union[str, Sequence[str]]] = None,
|
| 30 |
+
event_cols: Optional[Union[str, Sequence[str]]] = None,
|
| 31 |
+
train_edf_cols=None,
|
| 32 |
+
test_size: float = 0.15,
|
| 33 |
+
val_size: float = 0.15,
|
| 34 |
+
random_state: int = 1337,
|
| 35 |
+
sample_rate: int = 128,
|
| 36 |
+
window_size: int = 300,
|
| 37 |
+
epoch_length: int = 30,
|
| 38 |
+
cache_size: int = 8,
|
| 39 |
+
transform=None,
|
| 40 |
+
downstream_dataset_name=None,
|
| 41 |
+
data_source: str = "auto",
|
| 42 |
+
include_datasets: Optional[List[str]] = None,
|
| 43 |
+
regression_targets: Optional[List[str]] = None,
|
| 44 |
+
regression_filter_config: Optional[Dict] = None,
|
| 45 |
+
return_all_event_cols: bool = False,
|
| 46 |
+
return_nsrrid: bool = False,
|
| 47 |
+
):
|
| 48 |
+
assert split in {"pretrain", "pretrain-val", "pretrain-test", "train", "val", "test"}
|
| 49 |
+
assert data_source in {"auto", "pretrain", "downstream", "both"}
|
| 50 |
+
|
| 51 |
+
self.transform = transform
|
| 52 |
+
self.sample_rate = sample_rate
|
| 53 |
+
self.window_size = window_size
|
| 54 |
+
self.epoch_length = epoch_length
|
| 55 |
+
self.patient_cols = [patient_cols] if isinstance(patient_cols, str) else patient_cols
|
| 56 |
+
self.event_cols = [event_cols] if isinstance(event_cols, str) else event_cols
|
| 57 |
+
self.train_edf_cols = train_edf_cols
|
| 58 |
+
self.split = split
|
| 59 |
+
self.data_pct = float(data_pct)
|
| 60 |
+
self.data_source = data_source
|
| 61 |
+
self.regression_targets = regression_targets
|
| 62 |
+
self.regression_filter_config = regression_filter_config
|
| 63 |
+
self.return_all_event_cols = return_all_event_cols
|
| 64 |
+
self.return_nsrrid = return_nsrrid
|
| 65 |
+
|
| 66 |
+
patient_df, epoch_df = self._load_csvs(
|
| 67 |
+
csv_dir, split, data_source, include_datasets, self.event_cols,
|
| 68 |
+
regression_targets=self.regression_targets,
|
| 69 |
+
regression_filter_config=self.regression_filter_config,
|
| 70 |
+
return_all_event_cols=self.return_all_event_cols,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
if downstream_dataset_name and include_datasets is None:
|
| 74 |
+
if downstream_dataset_name != "all":
|
| 75 |
+
mask = epoch_df['dataset_name'].astype(str).str.lower().str.startswith(downstream_dataset_name)
|
| 76 |
+
epoch_df = epoch_df.loc[mask].copy()
|
| 77 |
+
ids = epoch_df["nsrrid"].astype(str).unique()
|
| 78 |
+
patient_df = patient_df[patient_df["nsrrid"].astype(str).isin(ids)].copy()
|
| 79 |
+
|
| 80 |
+
# Determine num_classes
|
| 81 |
+
if self.event_cols:
|
| 82 |
+
if self.event_cols[0] in ['Hypopnea', 'Arousal', 'Oxygen Desaturation']:
|
| 83 |
+
self.num_classes = 2
|
| 84 |
+
elif self.event_cols[0] == 'Stage':
|
| 85 |
+
self.num_classes = 4
|
| 86 |
+
mapping = {0: 0, 1: 1, 2: 1, 3: 2, 4: 3}
|
| 87 |
+
epoch_df['Stage'] = epoch_df['Stage'].replace(mapping)
|
| 88 |
+
else:
|
| 89 |
+
self.num_classes = 2
|
| 90 |
+
else:
|
| 91 |
+
self.num_classes = 2
|
| 92 |
+
|
| 93 |
+
# Drop Stage == -1
|
| 94 |
+
if self.event_cols and ('Stage' in self.event_cols) and ('Stage' in epoch_df.columns):
|
| 95 |
+
epoch_df = epoch_df.loc[epoch_df['Stage'] != -1].copy()
|
| 96 |
+
|
| 97 |
+
# Build tables
|
| 98 |
+
if split in ("pretrain", "pretrain-val"):
|
| 99 |
+
sort_cols = [c for c in ['nsrrid', 'seg_id', 'epoch_id'] if c in epoch_df.columns]
|
| 100 |
+
self.all_epoch_df = epoch_df.sort_values(sort_cols).reset_index(drop=True)
|
| 101 |
+
|
| 102 |
+
idx_keep_cols = [c for c in ['nsrrid', 'seg_id', 'path_head'] if c in self.all_epoch_df.columns]
|
| 103 |
+
if self.regression_targets:
|
| 104 |
+
for t in self.regression_targets:
|
| 105 |
+
col = f"{t}_mean"
|
| 106 |
+
if col in self.all_epoch_df.columns:
|
| 107 |
+
idx_keep_cols.append(col)
|
| 108 |
+
self.epoch_df = (
|
| 109 |
+
self.all_epoch_df[idx_keep_cols]
|
| 110 |
+
.drop_duplicates(['nsrrid', 'seg_id'], keep='first')
|
| 111 |
+
.reset_index(drop=True)
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
expected_len = self.window_size // self.epoch_length
|
| 115 |
+
grp = epoch_df.groupby(['nsrrid', 'seg_id']).size().rename('n').reset_index()
|
| 116 |
+
valid_keys = grp.loc[grp['n'] == expected_len, ['nsrrid', 'seg_id']]
|
| 117 |
+
epoch_df_valid = epoch_df.merge(valid_keys, on=['nsrrid', 'seg_id'], how='inner')
|
| 118 |
+
|
| 119 |
+
sort_cols = [c for c in ['nsrrid', 'seg_id', 'epoch_id'] if c in epoch_df_valid.columns]
|
| 120 |
+
self.all_epoch_df = epoch_df_valid.sort_values(sort_cols).reset_index(drop=True)
|
| 121 |
+
|
| 122 |
+
idx_keep_cols = [c for c in ['nsrrid', 'seg_id', 'path_head'] if c in self.all_epoch_df.columns]
|
| 123 |
+
if self.regression_targets:
|
| 124 |
+
for t in self.regression_targets:
|
| 125 |
+
col = f"{t}_mean"
|
| 126 |
+
if col in self.all_epoch_df.columns:
|
| 127 |
+
idx_keep_cols.append(col)
|
| 128 |
+
self.epoch_df = (
|
| 129 |
+
self.all_epoch_df[idx_keep_cols]
|
| 130 |
+
.drop_duplicates(['nsrrid', 'seg_id'], keep='first')
|
| 131 |
+
.reset_index(drop=True)
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# Patient-level sampling
|
| 135 |
+
if not (0 < self.data_pct <= 1.0):
|
| 136 |
+
raise ValueError(f"data_pct must be in (0,1], got {self.data_pct}")
|
| 137 |
+
|
| 138 |
+
if self.data_pct < 1.0:
|
| 139 |
+
eligible_patients = pd.Index(self.epoch_df['nsrrid'].unique())
|
| 140 |
+
n_keep = max(1, int(len(eligible_patients) * self.data_pct))
|
| 141 |
+
sampled_nsrrids = pd.Series(eligible_patients).sample(n=n_keep, random_state=random_state).to_list()
|
| 142 |
+
self.epoch_df = self.epoch_df.loc[self.epoch_df['nsrrid'].isin(sampled_nsrrids)].reset_index(drop=True)
|
| 143 |
+
self.all_epoch_df = self.all_epoch_df.loc[self.all_epoch_df['nsrrid'].isin(sampled_nsrrids)].reset_index(drop=True)
|
| 144 |
+
patient_df = patient_df.loc[patient_df['nsrrid'].isin(sampled_nsrrids)].copy()
|
| 145 |
+
|
| 146 |
+
self.patient_df = patient_df.set_index("nsrrid")
|
| 147 |
+
|
| 148 |
+
# Build segment indices
|
| 149 |
+
self._seg_indices = None
|
| 150 |
+
if hasattr(self, "all_epoch_df") and {'nsrrid', 'seg_id'}.issubset(self.all_epoch_df.columns):
|
| 151 |
+
grp_indices = self.all_epoch_df.groupby(['nsrrid', 'seg_id'], sort=False).indices
|
| 152 |
+
self._seg_indices = {}
|
| 153 |
+
has_epoch_id = 'epoch_id' in self.all_epoch_df.columns
|
| 154 |
+
epoch_id_values = self.all_epoch_df['epoch_id'].to_numpy() if has_epoch_id else None
|
| 155 |
+
for key, idx_list in grp_indices.items():
|
| 156 |
+
idx_arr = np.fromiter(idx_list, dtype=np.int64)
|
| 157 |
+
if has_epoch_id:
|
| 158 |
+
order = np.argsort(epoch_id_values[idx_arr])
|
| 159 |
+
idx_arr = idx_arr[order]
|
| 160 |
+
self._seg_indices[key] = idx_arr
|
| 161 |
+
|
| 162 |
+
# Compute class distribution
|
| 163 |
+
self._class_counts = None
|
| 164 |
+
if self.event_cols and self.event_cols[0] in self.all_epoch_df.columns:
|
| 165 |
+
label_col = self.event_cols[0]
|
| 166 |
+
value_counts = self.all_epoch_df[label_col].value_counts().sort_index()
|
| 167 |
+
class_counts = np.zeros(self.num_classes, dtype=np.int64)
|
| 168 |
+
for cls_idx, count in value_counts.items():
|
| 169 |
+
if 0 <= int(cls_idx) < self.num_classes:
|
| 170 |
+
class_counts[int(cls_idx)] = int(count)
|
| 171 |
+
self._class_counts = class_counts
|
| 172 |
+
|
| 173 |
+
def _load_csvs(self, csv_dir, split, data_source, include_datasets, event_cols,
|
| 174 |
+
regression_targets=None, regression_filter_config=None, return_all_event_cols=False):
|
| 175 |
+
split_suffix_map = {
|
| 176 |
+
"pretrain": "train", "pretrain-val": "valid", "pretrain-test": "test",
|
| 177 |
+
"train": "train", "val": "valid", "test": "test"
|
| 178 |
+
}
|
| 179 |
+
split_suffix = split_suffix_map[split]
|
| 180 |
+
|
| 181 |
+
if data_source == "auto":
|
| 182 |
+
sources = ["pretrain"] if split.startswith("pretrain") else ["downstream"]
|
| 183 |
+
elif data_source == "both":
|
| 184 |
+
sources = ["pretrain", "downstream"]
|
| 185 |
+
else:
|
| 186 |
+
sources = [data_source]
|
| 187 |
+
|
| 188 |
+
patient_dfs = []
|
| 189 |
+
epoch_dfs = []
|
| 190 |
+
csv_prefix = "epoch_regression" if regression_targets else "epoch"
|
| 191 |
+
|
| 192 |
+
for source in sources:
|
| 193 |
+
patient_csv = f"{csv_dir}/patient_{source}_{split_suffix}.csv"
|
| 194 |
+
epoch_csv = f"{csv_dir}/{csv_prefix}_{source}_{split_suffix}.csv"
|
| 195 |
+
|
| 196 |
+
if Path(patient_csv).is_file() and Path(epoch_csv).is_file():
|
| 197 |
+
patient_dfs.append(pd.read_csv(patient_csv))
|
| 198 |
+
epoch_dfs.append(pd.read_csv(epoch_csv))
|
| 199 |
+
|
| 200 |
+
patient_df = pd.concat(patient_dfs, ignore_index=True).drop_duplicates(subset=['nsrrid'])
|
| 201 |
+
epoch_df = pd.concat(epoch_dfs, ignore_index=True)
|
| 202 |
+
|
| 203 |
+
base_cols = ['nsrrid', 'seg_id', 'dataset_name', 'epoch_id', 'path_head']
|
| 204 |
+
if event_cols:
|
| 205 |
+
if return_all_event_cols:
|
| 206 |
+
for col in event_cols:
|
| 207 |
+
if col and col not in base_cols:
|
| 208 |
+
base_cols.append(col)
|
| 209 |
+
elif event_cols[0]:
|
| 210 |
+
base_cols.append(event_cols[0])
|
| 211 |
+
|
| 212 |
+
if regression_targets:
|
| 213 |
+
for t in regression_targets:
|
| 214 |
+
col_name = f"{t}_mean"
|
| 215 |
+
if col_name in epoch_df.columns:
|
| 216 |
+
base_cols.append(col_name)
|
| 217 |
+
|
| 218 |
+
keep_cols = [c for c in base_cols if c in epoch_df.columns]
|
| 219 |
+
epoch_df = epoch_df[keep_cols].copy()
|
| 220 |
+
|
| 221 |
+
if regression_targets:
|
| 222 |
+
label_cols = [f"{t}_mean" for t in regression_targets]
|
| 223 |
+
existing = [c for c in label_cols if c in epoch_df.columns]
|
| 224 |
+
if existing:
|
| 225 |
+
epoch_df = epoch_df.dropna(subset=existing).reset_index(drop=True)
|
| 226 |
+
|
| 227 |
+
if regression_filter_config:
|
| 228 |
+
for col_name, filter_rules in regression_filter_config.items():
|
| 229 |
+
if col_name in epoch_df.columns:
|
| 230 |
+
mask = pd.Series([True] * len(epoch_df))
|
| 231 |
+
if "min" in filter_rules:
|
| 232 |
+
mask = mask & (epoch_df[col_name] >= filter_rules["min"])
|
| 233 |
+
if "max" in filter_rules:
|
| 234 |
+
mask = mask & (epoch_df[col_name] <= filter_rules["max"])
|
| 235 |
+
epoch_df = epoch_df[mask].reset_index(drop=True)
|
| 236 |
+
|
| 237 |
+
if include_datasets is not None and 'dataset_name' in epoch_df.columns:
|
| 238 |
+
include_lower = [d.lower() for d in include_datasets]
|
| 239 |
+
mask = epoch_df['dataset_name'].astype(str).str.lower().isin(include_lower)
|
| 240 |
+
epoch_df = epoch_df[mask].copy()
|
| 241 |
+
patient_df = patient_df[patient_df['nsrrid'].isin(epoch_df['nsrrid'].unique())].copy()
|
| 242 |
+
|
| 243 |
+
return patient_df, epoch_df
|
| 244 |
+
|
| 245 |
+
def __len__(self) -> int:
|
| 246 |
+
return len(self.epoch_df)
|
| 247 |
+
|
| 248 |
+
def get_class_counts(self) -> Optional[np.ndarray]:
|
| 249 |
+
return self._class_counts
|
| 250 |
+
|
| 251 |
+
def _resample_df(self, df: pd.DataFrame, target_hz: int) -> pd.DataFrame:
|
| 252 |
+
if not np.issubdtype(df.index.dtype, np.number):
|
| 253 |
+
t = np.arange(len(df)) / float(target_hz)
|
| 254 |
+
df = df.copy()
|
| 255 |
+
df.index = t
|
| 256 |
+
|
| 257 |
+
t0 = float(df.index.min())
|
| 258 |
+
t1 = float(df.index.max())
|
| 259 |
+
t_target = np.arange(t0, t0 + self.window_size, 1.0 / target_hz)
|
| 260 |
+
if t_target[-1] > t1:
|
| 261 |
+
t_target = t_target[t_target <= t1 + 1e-9]
|
| 262 |
+
out = df.reindex(t_target).interpolate(method="linear", limit_direction="both")
|
| 263 |
+
return out.fillna(0.0)
|
| 264 |
+
|
| 265 |
+
def __getitem__(self, idx: int):
|
| 266 |
+
row = self.epoch_df.iloc[idx]
|
| 267 |
+
nsrrid = row["nsrrid"]
|
| 268 |
+
seg_id = int(row["seg_id"])
|
| 269 |
+
cols = list(self.train_edf_cols) if self.train_edf_cols is not None else None
|
| 270 |
+
|
| 271 |
+
if self.split == "pretrain":
|
| 272 |
+
df_epoch = self._load_epoch_all_df(row["path_head"], seg_id, columns=cols)
|
| 273 |
+
df_epoch = self._resample_df(df_epoch, self.sample_rate)
|
| 274 |
+
|
| 275 |
+
if cols is not None:
|
| 276 |
+
for ch in cols:
|
| 277 |
+
if ch not in df_epoch.columns:
|
| 278 |
+
df_epoch[ch] = 0.0
|
| 279 |
+
elif ch in NEED_NORM_COL:
|
| 280 |
+
df_epoch[ch] = to_pm1(df_epoch[ch])
|
| 281 |
+
df_epoch = df_epoch[cols]
|
| 282 |
+
|
| 283 |
+
samples_per_epoch = int(self.window_size * self.sample_rate)
|
| 284 |
+
if len(df_epoch) < samples_per_epoch:
|
| 285 |
+
pad = samples_per_epoch - len(df_epoch)
|
| 286 |
+
tail = pd.DataFrame({c: 0.0 for c in df_epoch.columns},
|
| 287 |
+
index=df_epoch.index[-1] + (np.arange(1, pad + 1) / self.sample_rate))
|
| 288 |
+
df_epoch = pd.concat([df_epoch, tail], axis=0)
|
| 289 |
+
elif len(df_epoch) > samples_per_epoch:
|
| 290 |
+
df_epoch = df_epoch.iloc[:samples_per_epoch]
|
| 291 |
+
|
| 292 |
+
x = torch.tensor(df_epoch.to_numpy(copy=False), dtype=torch.float32).t().contiguous()
|
| 293 |
+
x = torch.clamp(x, min=-6, max=6)
|
| 294 |
+
|
| 295 |
+
output = {"psg": x}
|
| 296 |
+
if self.return_nsrrid:
|
| 297 |
+
output["nsrrid"] = nsrrid
|
| 298 |
+
output["seg_id"] = seg_id
|
| 299 |
+
|
| 300 |
+
if self.patient_cols:
|
| 301 |
+
y = torch.tensor(self.patient_df.loc[nsrrid, self.patient_cols].values.astype(float), dtype=torch.float32)
|
| 302 |
+
output["label"] = y.long() if not self.return_nsrrid else y
|
| 303 |
+
elif self.event_cols:
|
| 304 |
+
if self.return_all_event_cols:
|
| 305 |
+
available_cols = [c for c in self.event_cols if c in row.index]
|
| 306 |
+
y = torch.tensor([row[c] for c in available_cols], dtype=torch.float32)
|
| 307 |
+
else:
|
| 308 |
+
y = torch.tensor([row[self.event_cols[0]]], dtype=torch.float32)
|
| 309 |
+
output["label"] = y
|
| 310 |
+
|
| 311 |
+
return output
|
| 312 |
+
else:
|
| 313 |
+
# Downstream split
|
| 314 |
+
if self._seg_indices is None:
|
| 315 |
+
seg_df = self.all_epoch_df[
|
| 316 |
+
(self.all_epoch_df['nsrrid'] == nsrrid) & (self.all_epoch_df['seg_id'] == seg_id)
|
| 317 |
+
].sort_values('epoch_id')
|
| 318 |
+
else:
|
| 319 |
+
idx_arr = self._seg_indices.get((nsrrid, seg_id))
|
| 320 |
+
seg_df = self.all_epoch_df.iloc[idx_arr] if idx_arr is not None else \
|
| 321 |
+
self.all_epoch_df[(self.all_epoch_df['nsrrid'] == nsrrid) & (self.all_epoch_df['seg_id'] == seg_id)].sort_values('epoch_id')
|
| 322 |
+
|
| 323 |
+
df_epoch = self._load_epoch_all_df(row["path_head"], seg_id, columns=cols)
|
| 324 |
+
df_epoch = self._resample_df(df_epoch, self.sample_rate)
|
| 325 |
+
|
| 326 |
+
if cols is not None:
|
| 327 |
+
for ch in cols:
|
| 328 |
+
if ch not in df_epoch.columns:
|
| 329 |
+
df_epoch[ch] = 0.0
|
| 330 |
+
elif ch in NEED_NORM_COL:
|
| 331 |
+
df_epoch[ch] = to_pm1(df_epoch[ch])
|
| 332 |
+
df_epoch = df_epoch[cols]
|
| 333 |
+
|
| 334 |
+
samples_per_epoch = int(self.window_size * self.sample_rate)
|
| 335 |
+
if len(df_epoch) < samples_per_epoch:
|
| 336 |
+
pad = samples_per_epoch - len(df_epoch)
|
| 337 |
+
tail = pd.DataFrame({c: 0.0 for c in df_epoch.columns},
|
| 338 |
+
index=df_epoch.index[-1] + (np.arange(1, pad + 1) / self.sample_rate))
|
| 339 |
+
df_epoch = pd.concat([df_epoch, tail], axis=0)
|
| 340 |
+
elif len(df_epoch) > samples_per_epoch:
|
| 341 |
+
df_epoch = df_epoch.iloc[:samples_per_epoch]
|
| 342 |
+
|
| 343 |
+
x = torch.tensor(df_epoch.to_numpy(copy=False), dtype=torch.float32).t().contiguous()
|
| 344 |
+
x = torch.clamp(x, min=-6, max=6)
|
| 345 |
+
|
| 346 |
+
output = {"psg": x}
|
| 347 |
+
if self.return_nsrrid:
|
| 348 |
+
output["nsrrid"] = nsrrid
|
| 349 |
+
output["seg_id"] = seg_id
|
| 350 |
+
|
| 351 |
+
if self.patient_cols:
|
| 352 |
+
y = torch.tensor(self.patient_df.loc[nsrrid, self.patient_cols].values.astype(float), dtype=torch.float32)
|
| 353 |
+
y = y.repeat(self.window_size // self.epoch_length)
|
| 354 |
+
output["label"] = y
|
| 355 |
+
elif self.event_cols:
|
| 356 |
+
if self.return_all_event_cols:
|
| 357 |
+
available_cols = [c for c in self.event_cols if c in seg_df.columns]
|
| 358 |
+
y = torch.tensor(seg_df[available_cols].values.astype(float), dtype=torch.float32).squeeze(0)
|
| 359 |
+
else:
|
| 360 |
+
y = torch.tensor(seg_df[self.event_cols].values.astype(float), dtype=torch.float32).squeeze(1)
|
| 361 |
+
output["label"] = y
|
| 362 |
+
elif self.regression_targets:
|
| 363 |
+
label_cols = [f"{t}_mean" for t in self.regression_targets]
|
| 364 |
+
y = torch.tensor([row[c] for c in label_cols], dtype=torch.float32)
|
| 365 |
+
output["label"] = y
|
| 366 |
+
|
| 367 |
+
return output
|
| 368 |
+
|
| 369 |
+
def _build_epoch_all_path(self, path_head: str, epoch_id: int) -> Path:
|
| 370 |
+
return Path(f"{path_head}/epoch-{epoch_id:05d}_all.parquet")
|
| 371 |
+
|
| 372 |
+
def _load_epoch_all_df(self, path_head: str, epoch_id: int, columns=None) -> pd.DataFrame:
    """Load one epoch-level parquet file and coerce non-float columns to float32.

    Args:
        path_head: Directory prefix for the segment's parquet files.
        epoch_id: Epoch index used to build the file name.
        columns: Optional list of column names to read; None reads all columns.

    Returns:
        DataFrame whose columns are (best-effort) float32.

    Raises:
        FileNotFoundError: If the expected parquet file does not exist.
    """
    fp = self._build_epoch_all_path(path_head, epoch_id)
    if not fp.is_file():
        raise FileNotFoundError(f"Parquet missing: {fp}")
    # Fix: honor the `columns` argument (it was previously accepted but ignored).
    # Pushing the selection into read_parquet also avoids loading unused columns.
    df = pd.read_parquet(fp, columns=columns)
    for c in df.columns:
        if not np.issubdtype(df[c].dtype, np.floating):
            # Best-effort cast; non-numeric columns are deliberately left as-is.
            with suppress(Exception):
                df[c] = df[c].astype(np.float32)
    return df
|
osf/datasets/simclr_aug_registry.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Two-view augmentation registry for SSL pretraining (SimCLR, DINO).
|
| 3 |
+
Provides multi-view generation pipelines for contrastive and self-distillation methods.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
from typing import Callable, Dict
|
| 7 |
+
import torch
|
| 8 |
+
from osf.datasets import augmentations as A
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _two_view(pipe1: Callable, pipe2: Callable | None = None) -> Callable:
|
| 12 |
+
"""Wrap one/two single-view pipelines into a two-view augmentation maker."""
|
| 13 |
+
if pipe2 is None:
|
| 14 |
+
pipe2 = pipe1
|
| 15 |
+
def make(x: torch.Tensor):
|
| 16 |
+
return pipe1(x), pipe2(x)
|
| 17 |
+
return make
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Preset name -> ready-to-use two-view augmentor. Each entry is built once at import
# time via `_two_view` and applied to a [B, C, T] tensor by the dataloader.
SIMCLR_AUG_REGISTRY: Dict[str, Callable] = {
    "none": _two_view(lambda x: x),  # identity: both views equal the input

    # Channel dropout presets; see osf.datasets.augmentations.channel_dropout.
    # NOTE(review): "channel_dropout" uses drop_prob=0.2 while the "_light" variant
    # uses 0.25 — the "light" name looks inverted relative to the probability; confirm.
    "channel_dropout": _two_view(lambda x: A.channel_dropout(x, drop_prob=0.2, min_keep=1)),
    "channel_dropout_light": _two_view(lambda x: A.channel_dropout(x, drop_prob=0.25, min_keep=1)),
    "channel_dropout_aligned": _two_view(lambda x: A.channel_dropout(x, drop_prob=0.5, min_keep=1)),
}

# Name -> factory producing a configured augmentor; populated further below.
SIMCLR_AUG_FACTORIES: Dict[str, Callable[..., Callable]] = {}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def build_simclr_augmentor(name: str, **kwargs) -> Callable:
    """Resolve an augmentor by name: presets first, then configurable factories.

    Falls back to "none" when `name` is empty/None; lookup is case-insensitive.
    Raises ValueError for unknown names, listing all available options.
    """
    key = (name or "none").lower()
    preset = SIMCLR_AUG_REGISTRY.get(key)
    if preset is not None:
        return preset
    factory = SIMCLR_AUG_FACTORIES.get(key)
    if factory is not None:
        return factory(**kwargs)
    raise ValueError(
        f"Unknown simclr_augmentation '{name}'. "
        f"Available presets: {list(SIMCLR_AUG_REGISTRY.keys())} | "
        f"factories: {list(SIMCLR_AUG_FACTORIES.keys())}"
    )
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _per_channel_span_mask_factory(
    ratio: tuple[float, float] = (0.10, 0.30),
    n_spans: int = 1,
    fill: str | torch.Tensor = "zero",
    noise_scale: float = 0.05,
    same_mask_for_batch: bool = False,
):
    """Build a two-view augmentor that masks random time spans per (batch, channel).

    Args:
        ratio: (min, max) span length as a fraction of the time dimension T.
        n_spans: spans drawn per (batch, channel); individual masks are OR-ed.
        fill: "zero", "mean", "noise", or a tensor of replacement values.
        noise_scale: std multiplier, used only when fill == "noise".
        same_mask_for_batch: if True, one per-channel mask is shared across the batch.

    Returns:
        Callable mapping x -> (view1, view2); each view is masked independently.
    """
    assert 0.0 <= ratio[0] <= ratio[1] <= 1.0

    def _single_view(x: torch.Tensor) -> torch.Tensor:
        # x is expected to be [B, C, T]; the unpack below fixes the rank.
        B, C, T = x.shape
        device, dtype = x.device, x.dtype

        # Span lengths in samples, clamped to [1, T] via the ratio bounds.
        min_len = max(1, int(round(ratio[0] * T)))
        max_len = max(min_len, int(round(ratio[1] * T)))
        arange_T = torch.arange(T, device=device)
        mask = torch.zeros((B, C, T), device=device, dtype=torch.bool)
        # With same_mask_for_batch, lengths/starts are drawn once per channel only.
        shape_bc = (1, C) if same_mask_for_batch else (B, C)

        for _ in range(max(1, int(n_spans))):
            if max_len == min_len:
                lengths = torch.full(shape_bc, max_len, device=device, dtype=torch.long)
            else:
                lengths = torch.randint(min_len, max_len + 1, shape_bc, device=device)
            # Latest valid start so the span fits inside [0, T).
            max_start = (T - lengths).clamp_min(0)
            if (max_start > 0).any():
                # Uniform integer start in [0, max_start] via scaled uniform floats.
                rnd = torch.rand_like(max_start, dtype=torch.float32)
                starts = torch.floor(rnd * (max_start.to(torch.float32) + 1)).to(torch.long)
            else:
                starts = torch.zeros_like(max_start)
            if same_mask_for_batch and B > 1:
                starts = starts.expand(B, C)
                lengths = lengths.expand(B, C)
            # Boolean window [start, start+length) broadcast over the time axis.
            span_mask = (arange_T.view(1, 1, T) >= starts.unsqueeze(-1)) & \
                        (arange_T.view(1, 1, T) < (starts + lengths).unsqueeze(-1))
            mask |= span_mask

        y = x.clone()
        if isinstance(fill, torch.Tensor):
            fill_t = fill.to(device=device, dtype=dtype)
            if fill_t.dim() == 0:
                fill_t = fill_t.view(1, 1, 1)
            # Normalize a few accepted fill shapes toward [B, C, T]-compatible tensors.
            if fill_t.shape[-1] == 1 and fill_t.dim() == 3 and fill_t.shape[0] in (1, B):
                fill_t = fill_t if fill_t.shape[0] == B else fill_t.expand(B, -1, -1)
            elif fill_t.dim() == 3 and fill_t.shape == (B, C, T):
                pass
            elif fill_t.dim() == 3 and fill_t.shape == (1, C, 1):
                fill_t = fill_t.expand(B, -1, T)
            # NOTE(review): `mask.expand_as(fill_t)` expands the [B, C, T] mask to
            # fill_t's shape — for fills that are still smaller than [B, C, T]
            # (e.g. a scalar reshaped to [1, 1, 1]) this looks like it would raise;
            # confirm with a unit test before relying on tensor fills.
            y[mask] = fill_t[mask.expand_as(fill_t)]
        elif fill == "zero":
            y[mask] = 0.0
        elif fill == "mean":
            # Replace masked samples with the per-channel temporal mean.
            m = x.mean(dim=-1, keepdim=True)
            y = torch.where(mask, m.expand_as(x), y)
        elif fill == "noise":
            # Gaussian noise matched to each channel's mean/scaled std.
            m = x.mean(dim=-1, keepdim=True)
            s = x.std(dim=-1, keepdim=True, unbiased=False).clamp_min(1e-8)
            noise = torch.randn_like(x) * (s * noise_scale) + m
            y = torch.where(mask, noise, y)
        else:
            raise ValueError(f"Unknown fill mode: {fill!r}")
        return y

    return _two_view(_single_view)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Expose the span-mask augmentation as a configurable factory ...
SIMCLR_AUG_FACTORIES["pc_span_mask"] = _per_channel_span_mask_factory

# ... and as three fixed presets with increasing masking strength.
SIMCLR_AUG_REGISTRY.update({
    "pc_span_mask_light": _per_channel_span_mask_factory(
        ratio=(0.1, 0.3), n_spans=1, fill="zero", noise_scale=0.05, same_mask_for_batch=False
    ),
    "pc_span_mask_heavy": _per_channel_span_mask_factory(
        ratio=(0.20, 0.6), n_spans=2, fill="zero", noise_scale=0.05, same_mask_for_batch=False
    ),
    "pc_span_mask_aligned": _per_channel_span_mask_factory(
        ratio=(0.3, 0.6), n_spans=1, fill="zero", noise_scale=0, same_mask_for_batch=False
    ),
})
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _channel_then_pcspan_factory(
    drop_prob: float = 0.3,
    min_keep: int = 1,
    ratio: tuple[float, float] = (0.10, 0.30),
    n_spans: int = 1,
    fill: str = "zero",
    noise_scale: float = 0.05,
    same_mask_for_batch: bool = False,
):
    """Build a two-view augmentor: channel dropout followed by per-channel span masking.

    Args:
        drop_prob: per-channel dropout probability (see A.channel_dropout).
        min_keep: minimum number of channels kept by the dropout step.
        ratio: (min, max) span length as a fraction of the time dimension T.
        n_spans: spans drawn per (batch, channel); masks are OR-ed together.
        fill: "zero", "mean", or "noise" replacement for masked samples.
        noise_scale: std multiplier, used only when fill == "noise".
        same_mask_for_batch: if True, one per-channel mask is shared across the batch.
    """
    def single_view(x: torch.Tensor) -> torch.Tensor:
        # Stage 1: stochastically drop whole channels.
        y = A.channel_dropout(x, drop_prob=drop_prob, min_keep=min_keep)
        B, C, T = y.shape
        device = y.device

        # Stage 2: per-(batch, channel) span mask (same scheme as
        # _per_channel_span_mask_factory, duplicated here for the combined pipeline).
        min_len = max(1, int(round(ratio[0] * T)))
        max_len = max(min_len, int(round(ratio[1] * T)))
        arange_T = torch.arange(T, device=device)
        mask = torch.zeros((B, C, T), device=device, dtype=torch.bool)
        shape_bc = (1, C) if same_mask_for_batch else (B, C)

        for _ in range(max(1, int(n_spans))):
            lengths = torch.full(shape_bc, max_len, device=device, dtype=torch.long) \
                if max_len == min_len else torch.randint(min_len, max_len + 1, shape_bc, device=device)
            max_start = (T - lengths).clamp_min(0)
            if (max_start > 0).any():
                # Uniform integer start in [0, max_start].
                rnd = torch.rand_like(max_start, dtype=torch.float32)
                starts = torch.floor(rnd * (max_start.to(torch.float32) + 1)).to(torch.long)
            else:
                starts = torch.zeros_like(max_start)
            if same_mask_for_batch and B > 1:
                starts = starts.expand(B, C)
                lengths = lengths.expand(B, C)
            span_mask = (arange_T.view(1, 1, T) >= starts.unsqueeze(-1)) & \
                        (arange_T.view(1, 1, T) < (starts + lengths).unsqueeze(-1))
            mask |= span_mask

        out = y.clone()
        if fill == "zero":
            out[mask] = 0.0
        elif fill == "mean":
            # Replace masked samples with the per-channel temporal mean.
            m = y.mean(dim=-1, keepdim=True)
            out = torch.where(mask, m.expand_as(y), out)
        elif fill == "noise":
            # Gaussian noise matched to each channel's mean/scaled std.
            m = y.mean(dim=-1, keepdim=True)
            s = y.std(dim=-1, keepdim=True, unbiased=False).clamp_min(1e-8)
            noise = torch.randn_like(y) * (s * noise_scale) + m
            out = torch.where(mask, noise, out)
        else:
            raise ValueError(f"Unknown fill: {fill!r}")
        return out

    return _two_view(single_view)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# Expose the channel-dropout + span-mask combination as a configurable factory.
SIMCLR_AUG_FACTORIES["chan_then_pcspan"] = _channel_then_pcspan_factory
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _crop_then_chan_pcspan_factory(
    crop_ratio: tuple[float, float] = (0.25, 0.75),
    align_to: int = 40,
    drop_prob: float = 0.5,
    min_keep: int = 1,
    span_ratio: tuple[float, float] = (0.3, 0.6),
    n_spans: int = 1,
    fill: str = "zero",
    noise_scale: float = 0.0,
    same_mask_for_batch: bool = False,
):
    """Build a two-view augmentor: random time crop -> channel dropout -> span masking.

    Args:
        crop_ratio: (min, max) crop length as a fraction of T (see A.random_time_crop).
        align_to: crop boundary alignment in samples — presumably the patch size of
            the downstream ViT; confirm against the backbone configuration.
        drop_prob: per-channel dropout probability.
        min_keep: minimum number of channels kept by the dropout step.
        span_ratio: (min, max) masked-span length as a fraction of T.
        n_spans: spans drawn per (batch, channel); masks are OR-ed together.
        fill: "zero", "mean", or "noise" replacement for masked samples.
        noise_scale: std multiplier, used only when fill == "noise".
        same_mask_for_batch: if True, one per-channel mask is shared across the batch.
    """
    def single_view(x: torch.Tensor) -> torch.Tensor:
        # Stage 1: crop a random time window, resized back to the original length.
        y = A.random_time_crop(x, ratio=crop_ratio, resize_back=True, align_to=align_to)
        # Stage 2: stochastically drop whole channels.
        y = A.channel_dropout(y, drop_prob=drop_prob, min_keep=min_keep)

        # Stage 3: per-(batch, channel) span mask (same scheme as
        # _per_channel_span_mask_factory, duplicated for the combined pipeline).
        B, C, T = y.shape
        device = y.device
        min_len = max(1, int(round(span_ratio[0] * T)))
        max_len = max(min_len, int(round(span_ratio[1] * T)))
        arange_T = torch.arange(T, device=device)
        mask = torch.zeros((B, C, T), device=device, dtype=torch.bool)
        shape_bc = (1, C) if same_mask_for_batch else (B, C)

        for _ in range(max(1, int(n_spans))):
            lengths = torch.full(shape_bc, max_len, device=device, dtype=torch.long) \
                if max_len == min_len else torch.randint(min_len, max_len + 1, shape_bc, device=device)
            max_start = (T - lengths).clamp_min(0)
            if (max_start > 0).any():
                # Uniform integer start in [0, max_start].
                rnd = torch.rand_like(max_start, dtype=torch.float32)
                starts = torch.floor(rnd * (max_start.to(torch.float32) + 1)).to(torch.long)
            else:
                starts = torch.zeros_like(max_start)
            if same_mask_for_batch and B > 1:
                starts = starts.expand(B, C)
                lengths = lengths.expand(B, C)
            span_mask = (arange_T.view(1, 1, T) >= starts.unsqueeze(-1)) & \
                        (arange_T.view(1, 1, T) < (starts + lengths).unsqueeze(-1))
            mask |= span_mask

        out = y.clone()
        if fill == "zero":
            out[mask] = 0.0
        elif fill == "mean":
            # Replace masked samples with the per-channel temporal mean.
            m = y.mean(dim=-1, keepdim=True)
            out = torch.where(mask, m.expand_as(y), out)
        elif fill == "noise":
            # Gaussian noise matched to each channel's mean/scaled std.
            m = y.mean(dim=-1, keepdim=True)
            s = y.std(dim=-1, keepdim=True, unbiased=False).clamp_min(1e-8)
            noise = torch.randn_like(y) * (s * noise_scale) + m
            out = torch.where(mask, noise, out)
        else:
            raise ValueError(f"Unknown fill: {fill!r}")
        return out

    return _two_view(single_view)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# Expose the crop + channel-dropout + span-mask combination as a configurable factory.
SIMCLR_AUG_FACTORIES["crop_then_chan_pcspan"] = _crop_then_chan_pcspan_factory

# Fixed presets for the combined pipelines; "_light" halves the channel drop_prob.
SIMCLR_AUG_REGISTRY.update({
    "chan_then_pcspan": _channel_then_pcspan_factory(
        drop_prob=0.5, min_keep=1, ratio=(0.3, 0.6), n_spans=1, fill="zero",
        noise_scale=0, same_mask_for_batch=False
    ),
    "chan_then_pcspan_light": _channel_then_pcspan_factory(
        drop_prob=0.25, min_keep=1, ratio=(0.3, 0.6), n_spans=1, fill="zero",
        noise_scale=0, same_mask_for_batch=False
    ),
    "crop_then_chan_pcspan": _crop_then_chan_pcspan_factory(
        crop_ratio=(0.25, 0.75), align_to=40, drop_prob=0.5, min_keep=1,
        span_ratio=(0.3, 0.6), n_spans=1, fill="zero", noise_scale=0, same_mask_for_batch=False
    ),
    "crop_then_chan_pcspan_light": _crop_then_chan_pcspan_factory(
        crop_ratio=(0.25, 0.75), align_to=40, drop_prob=0.25, min_keep=1,
        span_ratio=(0.3, 0.6), n_spans=1, fill="zero", noise_scale=0, same_mask_for_batch=False
    ),
})
|
osf/models/__init__.py
ADDED
|
File without changes
|
osf/models/balanced_losses.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Balanced/imbalanced learning losses.
|
| 3 |
+
Reference: https://github.com/YyzHarry/SubpopBench
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class FocalLoss(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Focal Loss: FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)
|
| 14 |
+
Paper: https://arxiv.org/abs/1708.02002
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
alpha: Weighting factor (float or [num_classes] tensor)
|
| 18 |
+
gamma: Focusing parameter (higher = more focus on hard examples)
|
| 19 |
+
reduction: 'mean' or 'none'
|
| 20 |
+
"""
|
| 21 |
+
def __init__(self, alpha: Optional[float | torch.Tensor] = None, gamma: float = 2.0, reduction: str = "mean"):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.gamma = gamma
|
| 24 |
+
self.reduction = reduction
|
| 25 |
+
|
| 26 |
+
if isinstance(alpha, (float, int)):
|
| 27 |
+
self.register_buffer("alpha", torch.tensor([alpha], dtype=torch.float32))
|
| 28 |
+
elif isinstance(alpha, torch.Tensor):
|
| 29 |
+
self.register_buffer("alpha", alpha.float())
|
| 30 |
+
elif alpha is None:
|
| 31 |
+
self.alpha = None
|
| 32 |
+
else:
|
| 33 |
+
raise ValueError(f"alpha must be float, Tensor, or None, got {type(alpha)}")
|
| 34 |
+
|
| 35 |
+
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
"""
|
| 37 |
+
Args:
|
| 38 |
+
logits: [B, C] unnormalized logits
|
| 39 |
+
targets: [B] class indices
|
| 40 |
+
"""
|
| 41 |
+
ce_loss = F.cross_entropy(logits, targets, reduction="none")
|
| 42 |
+
pt = torch.exp(-ce_loss) # p_t
|
| 43 |
+
focal_loss = ((1 - pt) ** self.gamma) * ce_loss
|
| 44 |
+
|
| 45 |
+
if self.alpha is not None:
|
| 46 |
+
if self.alpha.dim() == 0 or len(self.alpha) == 1:
|
| 47 |
+
alpha_t = self.alpha.squeeze()
|
| 48 |
+
else:
|
| 49 |
+
alpha_t = self.alpha[targets] # [B]
|
| 50 |
+
focal_loss = alpha_t * focal_loss
|
| 51 |
+
|
| 52 |
+
if self.reduction == "mean":
|
| 53 |
+
return focal_loss.mean()
|
| 54 |
+
elif self.reduction == "none":
|
| 55 |
+
return focal_loss
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError(f"reduction must be 'mean' or 'none', got {self.reduction}")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class BalancedSoftmax(nn.Module):
    """
    Balanced Softmax: adjusted_logits = logits + log(class_counts)
    Paper: https://arxiv.org/abs/2007.10740

    Compensates for class imbalance by shifting each logit by the log of its
    class's training-set frequency before cross-entropy.

    Args:
        class_counts: [C] sample counts per class (tensor or sequence); must all
            be non-zero since their log is taken.
        reduction: 'mean' or 'none'.
    """

    def __init__(self, class_counts: torch.Tensor, reduction: str = "mean"):
        super().__init__()
        counts = class_counts if isinstance(class_counts, torch.Tensor) \
            else torch.tensor(class_counts, dtype=torch.float32)
        counts = counts.float()

        zero_mask = counts == 0
        if zero_mask.any():
            # log(0) would be -inf; fail loudly with the offending class indices.
            zero_classes = zero_mask.nonzero(as_tuple=True)[0].tolist()
            raise ValueError(f"BalancedSoftmax requires non-zero class counts. Zero counts: {zero_classes}")

        # Stored as a buffer so it moves with the module across devices.
        self.register_buffer("log_class_counts", torch.log(counts))
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: [B, C] unnormalized logits
            targets: [B] class indices
        """
        shifted = logits + self.log_class_counts.unsqueeze(0)
        return F.cross_entropy(shifted, targets, reduction=self.reduction)
|
osf/models/base_pretrain_model.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from pytorch_lightning import LightningModule
|
| 5 |
+
|
| 6 |
+
from osf.backbone.vit1d import vit_nano, vit_tiny, vit_small, vit_middle, vit_base
|
| 7 |
+
|
| 8 |
+
# Encoder name -> constructor for the 1-D ViT backbones (osf.backbone.vit1d).
# Used by PSGModalityEncoder below to instantiate the requested size.
VIT_FACTORIES = {
    "vit_nano": vit_nano,
    "vit_tiny": vit_tiny,
    "vit_small": vit_small,
    "vit_middle": vit_middle,
    "vit_base": vit_base,
}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PSGModalityEncoder(nn.Module):
    """ViT encoder for PSG signals: backbone -> optional projection -> L2-norm"""

    def __init__(self, *,
                 encoder_name: str,
                 proj_out: int = 256,
                 proj_hidden: int = 512,
                 freq: int = 64,
                 win_sec: int = 30,
                 channel: int = 11,
                 lead_wise=0,
                 patch_size=40,
                 patch_size_ch=4,
                 use_lead_embedding: bool = True,
                 is_proj_head=1):
        super().__init__()

        if encoder_name not in VIT_FACTORIES:
            raise ValueError(f"Unknown encoder_name: {encoder_name}. Choose from {list(VIT_FACTORIES.keys())}")

        # One "token" covers win_sec seconds sampled at `freq` Hz.
        token_len = freq * win_sec
        self.token_len = token_len
        self.patch_size = patch_size

        backbone_factory = VIT_FACTORIES[encoder_name]
        self.backbone = backbone_factory(
            num_leads=channel, seq_len=token_len, patch_size=patch_size,
            lead_wise=lead_wise, patch_size_ch=patch_size_ch,
            use_lead_embedding=use_lead_embedding,
        )

        # Optional 2-layer MLP projection on top of the backbone features.
        if is_proj_head == 1:
            width = self.backbone.width
            self.proj_head = nn.Sequential(
                nn.Linear(width, proj_hidden),
                nn.LayerNorm(proj_hidden),
                nn.ReLU(inplace=True),
                nn.Linear(proj_hidden, proj_out),
                nn.LayerNorm(proj_out),
            )
        else:
            self.proj_head = None

    def forward(self, x, normalize=True):
        """Encode a [B, C, T] batch; returns L2-normalized embeddings unless normalize=False."""
        feats = self.backbone(x)            # [B, D]
        if self.proj_head is not None:
            feats = self.proj_head(feats)   # [B, proj_out]
        return F.normalize(feats, dim=-1) if normalize else feats
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class BasePretrainModel(LightningModule):
    """Shared LightningModule base for SSL pretraining: stores common hparams,
    builds the AdamW + warmup/cosine schedule, and routes train/val/test steps
    through a subclass-provided `shared_step(batch, batch_idx)` that must return
    (loss_dict, metrics_dict) with loss_dict['loss'] as the optimized scalar.
    """

    def __init__(self,
                 psg_encoder_name: str = "vit_base",
                 text_encoder_name: str = "google/flan-t5-base",
                 fusion_decoder_name: str = 'cross-attn',
                 shared_emb_dim: int = 256,
                 lr: float = 2e-4,
                 weight_decay: float = 0.2,
                 training_steps_per_epoch: int = 7000,
                 max_epochs: int = 100,
                 *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.psg_encoder_name = psg_encoder_name
        self.text_encoder_name = text_encoder_name
        self.fusion_decoder_name = fusion_decoder_name
        self.shared_emb_dim = shared_emb_dim
        self.lr = lr
        self.weight_decay = weight_decay
        self.training_steps_per_epoch = training_steps_per_epoch
        self.max_epochs = max_epochs
        # Warmup spans the first 10% of training (in epochs; converted to steps below).
        self.warmup_epochs = 0.1 * self.max_epochs
        self.proj_out = shared_emb_dim
        self.proj_hidden = 256

        # Scheduler math below divides training into per-step intervals.
        assert self.training_steps_per_epoch > 1

    def configure_optimizers(self):
        """AdamW with linear warmup followed by cosine decay, stepped per batch."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.95),
        )

        total_steps = int(self.training_steps_per_epoch * self.max_epochs)
        warmup_steps = int(round(self.training_steps_per_epoch * self.warmup_epochs))
        warmup_steps = max(0, warmup_steps)
        decay_steps = max(1, total_steps - warmup_steps)

        if warmup_steps > 0:
            # Linear ramp 1% -> 100% of lr, then cosine decay to ~0.
            warmup = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
            cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)
            sched = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])
        else:
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)

        # "interval": "step" makes Lightning advance the scheduler every batch.
        return [optimizer], [{"scheduler": sched, "interval": "step", "frequency": 1}]

    def training_step(self, batch, batch_idx):
        """Log all losses/metrics under train/ and return the scalar loss."""
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict['loss']

    def validation_step(self, batch, batch_idx):
        """Log epoch-level val/ losses and metrics (no gradients)."""
        with torch.no_grad():
            loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
            for k, v in loss_dict.items():
                self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
            for k, v in metrics_dict.items():
                self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict

    def test_step(self, batch, batch_idx):
        """Log epoch-level test/ losses and metrics."""
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict
|
osf/models/base_pretrain_model_cls.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from osf.backbone.vit1d_cls import vit_nano, vit_tiny, vit_small, vit_middle, vit_base, vit_large, vit_xl
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PSGModalityEncoderCLS(nn.Module):
    """
    Init helper for ViT with CLS token. No forward() - access .backbone directly.

    Used by DINO to initialize encoder, then DINO accesses self.encoders["all"].backbone.
    """
    def __init__(self, *,
                 encoder_name: str,
                 proj_out: int = 256,
                 proj_hidden: int = 512,
                 freq: int = 64,
                 win_sec: int = 30,
                 channel: int = 12,
                 lead_wise = 0,
                 patch_size = 40,
                 patch_size_ch = 4,
                 is_proj_head = 1,
                 ):
        super().__init__()
        # One "token" covers win_sec seconds sampled at `freq` Hz.
        token_len = freq * win_sec

        self.token_len = token_len
        self.patch_size = patch_size

        # Consistency: dispatch through a name->factory mapping (mirrors
        # VIT_FACTORIES in base_pretrain_model.py) instead of a 7-way if/elif
        # chain; every branch passed identical kwargs, so one call site suffices.
        factories = {
            "vit_nano": vit_nano,
            "vit_tiny": vit_tiny,
            "vit_small": vit_small,
            "vit_middle": vit_middle,
            "vit_base": vit_base,
            "vit_large": vit_large,
            "vit_xl": vit_xl,
        }
        if encoder_name not in factories:
            raise ValueError(f"Unknown encoder_name for CLS variant: {encoder_name}")
        self.backbone = factories[encoder_name](
            num_leads=channel, seq_len=token_len, patch_size=patch_size,
            lead_wise=lead_wise, patch_size_ch=patch_size_ch,
        )

        # Optional 2-layer MLP projection head (DINO may bypass it and use
        # .backbone directly, per the class docstring).
        d_model = self.backbone.width
        if is_proj_head == 1:
            self.proj_head = nn.Sequential(
                nn.Linear(d_model, proj_hidden),
                nn.LayerNorm(proj_hidden),
                nn.ReLU(inplace=True),
                nn.Linear(proj_hidden, proj_out),
                nn.LayerNorm(proj_out),
            )
        else:
            self.proj_head = None
|
osf/models/dino_model_cls.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from osf.models.dino_utils.dino_clstoken_loss import DINOLoss
|
| 9 |
+
from osf.models.dino_utils.ibot_patch_loss import iBOTPatchLoss
|
| 10 |
+
from osf.models.dino_utils.koleo_loss import KoLeoLoss
|
| 11 |
+
from osf.models.base_pretrain_model import BasePretrainModel
|
| 12 |
+
from osf.models.base_pretrain_model_cls import PSGModalityEncoderCLS
|
| 13 |
+
from osf.datasets.simclr_aug_registry import build_simclr_augmentor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DINOHead(nn.Module):
    """Projection head mapping trunk features to prototype logits (DINO-style)."""

    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=256, nlayers=3):
        super().__init__()
        depth = max(nlayers, 1)
        if depth == 1:
            # Single-layer head: straight projection into the bottleneck.
            self.mlp = nn.Sequential(nn.Linear(in_dim, bottleneck_dim))
        else:
            modules = [nn.Linear(in_dim, hidden_dim), nn.GELU()]
            for _ in range(depth - 2):
                modules.append(nn.Linear(hidden_dim, hidden_dim))
                modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*modules)

        self.apply(self._init_weights)
        # Weight-normalized prototype layer; the magnitude (weight_g) is pinned
        # to 1 so prototype logits behave like cosine similarities.
        self.prototypes = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.prototypes.weight_g.data.fill_(1.0)

    @staticmethod
    def _init_weights(m):
        """Truncated-normal init for linear weights, zeros for biases."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Project *x* through the MLP, L2-normalize, and score against prototypes."""
        z = F.normalize(self.mlp(x), dim=-1)
        return self.prototypes(z)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class DINOCLSModel(BasePretrainModel):
    """Self-distillation pretraining model combining DINO ([CLS]-token), iBOT
    (masked-patch) and optional KoLeo losses over a single "all"-channel PSG
    encoder, with an EMA teacher copy of encoder and heads.
    """

    def __init__(
        self,
        psg_encoder_name: str = "vit_base",
        text_encoder_name: Optional[str] = None,
        shared_emb_dim: int = 768,
        out_dim: int = 2048,
        patch_out_dim: int = 2048,
        dino_out_dim: int = None,
        dino_patch_out_dim: int = None,
        dino_hidden_dim: int = 2048,
        dino_bottleneck_dim: int = 256,
        student_temp: float = 0.1,
        teacher_temp_warmup: float = 0.04,
        teacher_temp_final: float = 0.07,
        teacher_temp_warmup_iters: int = 10000,
        base_momentum: float = 0.996,
        use_koleo: bool = True,
        koleo_lambda: float = 0.0,
        ibot_lambda: float = 0.0,
        lr: float = 2e-4,
        weight_decay: float = 0.2,
        num_freeze_layers: int = 6,
        simclr_augmentation: dict | None = None,
        n_local_crops: int = 2,
        *args, **kwargs
    ):
        # Text encoder is deliberately disabled: DINO pretraining is signal-only.
        super().__init__(
            psg_encoder_name=psg_encoder_name,
            text_encoder_name=None,
            shared_emb_dim=shared_emb_dim,
            lr=lr,
            weight_decay=weight_decay,
            *args, **kwargs
        )
        self.save_hyperparameters()

        self.proj_out = shared_emb_dim
        self.proj_hidden = 256
        self.num_freeze_layers = num_freeze_layers

        num_leads = kwargs.get('num_leads', 12)
        self.num_leads = num_leads

        # One modality entry covering all leads; freq/win_sec are the encoder's
        # expected sampling rate and window length (seconds).
        self.cfg = [dict(name="all", freq=64, win_sec=30, in_ch=num_leads)]
        self.encoders = nn.ModuleDict()
        for mod in self.cfg:
            self.encoders[mod["name"]] = PSGModalityEncoderCLS(
                encoder_name=psg_encoder_name,
                proj_out=shared_emb_dim,
                proj_hidden=256,
                freq=mod["freq"],
                win_sec=mod["win_sec"],
                channel=mod["in_ch"],
                patch_size=kwargs['patch_size_time'],
                lead_wise=kwargs['lead_wise'],
                # lead_wise == 0 means patches span all leads at once.
                patch_size_ch=(num_leads if kwargs['lead_wise'] == 0 else kwargs['patch_size_ch']),
                is_proj_head=0,
            )
        self.lead_wise = kwargs['lead_wise']
        self.patch_size_time = kwargs['patch_size_time']
        self.patch_size_ch = (num_leads if self.lead_wise == 0 else kwargs['patch_size_ch'])
        trunk_dim = self.encoders['all'].backbone.width
        # dino_out_dim / dino_patch_out_dim override out_dim / patch_out_dim when given.
        out_dim = dino_out_dim if dino_out_dim is not None else out_dim
        patch_out_dim = dino_patch_out_dim if dino_patch_out_dim is not None else patch_out_dim
        self.out_dim = out_dim
        self.patch_out_dim = patch_out_dim

        # Student heads (trainable).
        self.student_global_head = DINOHead(trunk_dim, out_dim, dino_hidden_dim, dino_bottleneck_dim, 3)
        self.student_patch_head = DINOHead(trunk_dim, patch_out_dim, dino_hidden_dim, dino_bottleneck_dim, 3)
        # Teacher encoder: frozen EMA copy of the student encoder.
        self.teacher_encoder = copy.deepcopy(self.encoders["all"])
        for p in self.teacher_encoder.parameters():
            p.requires_grad = False
        self.teacher_encoder.eval()

        # Teacher heads: start as exact copies of the student heads, then frozen
        # and updated only via EMA in _ema_update.
        self.teacher_global_head = DINOHead(trunk_dim, out_dim, dino_hidden_dim, dino_bottleneck_dim, 3)
        self.teacher_patch_head = DINOHead(trunk_dim, patch_out_dim, dino_hidden_dim, dino_bottleneck_dim, 3)
        self.teacher_global_head.load_state_dict(self.student_global_head.state_dict(), strict=True)
        self.teacher_patch_head.load_state_dict(self.student_patch_head.state_dict(), strict=True)
        for p in self.teacher_global_head.parameters():
            p.requires_grad = False
        for p in self.teacher_patch_head.parameters():
            p.requires_grad = False
        self.teacher_global_head.eval()
        self.teacher_patch_head.eval()
        self.dino_loss = DINOLoss(out_dim=out_dim, student_temp=student_temp, center_momentum=0.9)
        self.ibot_loss = iBOTPatchLoss(patch_out_dim=patch_out_dim, student_temp=student_temp, center_momentum=0.9)
        self.koleo = KoLeoLoss() if use_koleo else None
        self.koleo_lambda = float(koleo_lambda)
        self.ibot_lambda = float(ibot_lambda)
        self.teacher_temp_warmup = float(teacher_temp_warmup)
        self.teacher_temp_final = float(teacher_temp_final)
        self.teacher_temp_warmup_iters = int(teacher_temp_warmup_iters)
        self.base_momentum = float(base_momentum)

        # Persisted step counter (saved with checkpoints); not otherwise used here.
        self.register_buffer("seen_steps", torch.tensor(0, dtype=torch.long))

        if simclr_augmentation is None:
            simclr_augmentation = {}
        self.simclr_augmentation = simclr_augmentation
        self.augmentor = build_simclr_augmentor(self.simclr_augmentation)
        self.n_local_crops = int(n_local_crops)

        # Learnable token substituted into masked patch positions for iBOT.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, trunk_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

    def _teacher_temp(self, step: int) -> float:
        """Linearly warm the teacher temperature from warmup to final value."""
        if step < self.teacher_temp_warmup_iters:
            alpha = step / float(max(1, self.teacher_temp_warmup_iters))
            return self.teacher_temp_warmup * (1 - alpha) + self.teacher_temp_final * alpha
        return self.teacher_temp_final

    def _momentum(self, step: int, max_steps: int) -> float:
        """Cosine-schedule the EMA momentum from base_momentum toward 1.0."""
        return 1.0 - (1.0 - self.base_momentum) * (math.cos(math.pi * step / max_steps) + 1) / 2

    @torch.no_grad()
    def _ema_update(self, m: float):
        """EMA-update teacher encoder and heads from the student: k = m*k + (1-m)*q."""
        for param_q, param_k in zip(self.encoders['all'].parameters(), self.teacher_encoder.parameters()):
            param_k.data.mul_(m).add_(param_q.data, alpha=1.0 - m)
        for param_q, param_k in zip(self.student_global_head.parameters(), self.teacher_global_head.parameters()):
            param_k.data.mul_(m).add_(param_q.data, alpha=1.0 - m)
        for param_q, param_k in zip(self.student_patch_head.parameters(), self.teacher_patch_head.parameters()):
            param_k.data.mul_(m).add_(param_q.data, alpha=1.0 - m)
        # Keep teacher modules in eval mode (Lightning may flip modes elsewhere).
        self.teacher_encoder.eval()
        self.teacher_global_head.eval()
        self.teacher_patch_head.eval()

    def _forward_encoder(self, encoder, x, return_tokens=True):
        """Run the backbone; return (cls, patch_tokens) or (cls, None)."""
        # x: [B, C, T]
        if return_tokens:
            cls, patches = encoder.backbone.forward_encoding(x, return_sequence=False)
            return cls, patches  # [B, D], [B, N, D]
        else:
            cls = encoder.backbone(x)
            return cls, None  # [B, D], None

    def _make_views_aug(self, x: torch.Tensor):
        """Build 2 global views plus n_local_crops extra views via the SimCLR augmentor."""
        v1, v2 = self.augmentor(x)
        globals_x = [v1, v2]
        locals_x = []
        for _ in range(self.n_local_crops):
            # Each call yields a fresh random pair; only the first view is used.
            lv1, _ = self.augmentor(x)
            locals_x.append(lv1)
        return globals_x, locals_x

    def shared_step(self, batch, batch_idx):
        """Compute DINO + iBOT + KoLeo losses for one batch; returns (loss_dict, metrics)."""
        x = batch["psg"]
        globals_x, locals_x = self._make_views_aug(x)
        tt = self._teacher_temp(int(self.global_step))

        # ---- Teacher global forward (no grad): centered, sharpened targets. ----
        with torch.no_grad():
            teacher_out_soft_list = []
            teacher_global_logits_cache = []
            teacher_patch_logits_cache = []

            if len(globals_x) > 0:
                g_sizes = [gx.size(0) for gx in globals_x]
                # Batch all global views through the teacher in one pass.
                g_cat = torch.cat(globals_x, dim=0)
                cls_t_cat, _ = self._forward_encoder(self.teacher_encoder, g_cat, return_tokens=True)
                g_logits_cat = self.teacher_global_head(cls_t_cat)
                g_logits_split = list(torch.split(g_logits_cat, g_sizes, dim=0))
                teacher_out_soft_list = [self.dino_loss.softmax_center_teacher(gl, tt) for gl in g_logits_split]
                # Raw logits kept for the post-loss center update.
                teacher_global_logits_cache = g_logits_split

        # ---- Student global forward (with grad) over all views. ----
        student_global_logits = []
        student_cls_tokens = []
        all_student_views = globals_x + locals_x
        if len(all_student_views) > 0:
            s_sizes = [sx.size(0) for sx in all_student_views]
            s_cat = torch.cat(all_student_views, dim=0)
            cls_s_cat, _ = self._forward_encoder(self.encoders["all"], s_cat, return_tokens=False)
            sg_logits_cat = self.student_global_head(cls_s_cat)
            student_global_logits = list(torch.split(sg_logits_cat, s_sizes, dim=0))
            student_cls_tokens = list(torch.split(cls_s_cat, s_sizes, dim=0))

        # ---- iBOT masked-patch branch on the first global view. ----
        ibot_loss_val = torch.tensor(0.0, device=x.device)
        if len(globals_x) > 0:
            with torch.no_grad():
                # Teacher sees the unmasked token sequence.
                t_tokens, _ = self.teacher_encoder.backbone.to_tokens_2d(
                    globals_x[0], patch_size_ch=self.patch_size_ch, patch_size_time=self.patch_size_time)
                B2 = t_tokens.size(0)
                cls_tok = self.teacher_encoder.backbone.cls_token.expand(B2, -1, -1)
                t_full = torch.cat([cls_tok, t_tokens], dim=1)
                pe_full = self.teacher_encoder.backbone.pos_embedding[:, :t_full.size(1), :].to(t_full.device)
                t_full = t_full + pe_full
                t_full = self.teacher_encoder.backbone._run_blocks(t_full)
                # Drop the CLS token; keep per-patch outputs.
                _, t_patches = t_full[:, 0], t_full[:, 1:]
                t_logits_all = self.teacher_patch_head(t_patches)
                t_soft = self.ibot_loss.softmax_center_teacher(t_logits_all, tt)

            # Student sees the same view, with a random subset of tokens masked.
            s_tokens, _ = self.encoders["all"].backbone.to_tokens_2d(
                globals_x[0], patch_size_ch=self.patch_size_ch, patch_size_time=self.patch_size_time)
            B2, N, Dtok = s_tokens.shape

            # Uniform random masking; ibot_mask_ratio may be set externally (default 0.3).
            mask_ratio = float(getattr(self, "ibot_mask_ratio", 0.3))
            n_mask = max(1, int(round(N * mask_ratio)))
            rand = torch.rand(B2, N, device=x.device)
            topk_idx = rand.topk(k=n_mask, dim=1, largest=True).indices
            masks = torch.zeros(B2, N, dtype=torch.bool, device=x.device)
            masks.scatter_(1, topk_idx, True)

            # Replace masked positions with the learnable mask token.
            s_tokens_masked = torch.where(
                masks.unsqueeze(-1),
                self.mask_token.expand_as(s_tokens),
                s_tokens
            )

            cls_tok_s = self.encoders["all"].backbone.cls_token.expand(B2, -1, -1)
            s_full = torch.cat([cls_tok_s, s_tokens_masked], dim=1)
            pe_full_s = self.encoders["all"].backbone.pos_embedding[:, :s_full.size(1), :].to(s_full.device)
            s_full = s_full + pe_full_s
            s_full = self.encoders["all"].backbone._run_blocks(s_full)
            _, s_patches = s_full[:, 0], s_full[:, 1:]
            s_logits_all = self.student_patch_head(s_patches)

            # Loss only over masked positions.
            ibot_loss_val = self.ibot_loss.forward_masked(
                student_patch_tokens_masked=s_logits_all[masks],
                teacher_patch_tokens_masked=t_soft[masks],
                student_masks_flat=masks,
            )

            with torch.no_grad():
                teacher_patch_logits_cache.append(t_logits_all)

        # DINOLoss sums over all (student view, teacher view) pairs;
        # normalize by the pair count to get a mean.
        dino_loss_val = self.dino_loss(student_global_logits, teacher_out_soft_list)
        pair_norm = max(1, len(student_global_logits) * len(teacher_out_soft_list))
        dino_loss_val = dino_loss_val / pair_norm
        koleo_val = torch.tensor(0.0, device=x.device)
        if self.koleo is not None and len(student_cls_tokens) > 0:
            # KoLeo spreads the first view's CLS embeddings on the unit sphere.
            koleo_val = self.koleo(F.normalize(student_cls_tokens[0], dim=-1))

        total_loss = dino_loss_val + self.ibot_lambda * ibot_loss_val + self.koleo_lambda * koleo_val

        # Center updates only during training (teacher statistics).
        with torch.no_grad():
            if self.training:
                if len(teacher_global_logits_cache) > 0:
                    self.dino_loss.update_center(torch.cat(teacher_global_logits_cache, dim=0))
                if len(teacher_patch_logits_cache) > 0:
                    self.ibot_loss.update_center(torch.cat(teacher_patch_logits_cache, dim=0))

        metrics = {
            "loss": total_loss,
            "loss/dino": dino_loss_val,
            "loss/ibot": ibot_loss_val,
            "loss/koleo": koleo_val,
            "sched/teacher_temp": torch.tensor(tt, device=x.device),
        }
        return {"loss": total_loss}, metrics

    def training_step(self, batch, batch_idx):
        """Lightning training step: log all metrics and return the total loss."""
        loss_dict, metrics = self.shared_step(batch, batch_idx)
        for k, v in metrics.items():
            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=(k == "loss"), sync_dist=True)
        return loss_dict["loss"]

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """After each optimizer step, EMA-update the teacher with a cosine momentum."""
        max_steps = max(1, getattr(self.trainer, "max_steps", getattr(self.trainer, "estimated_stepping_batches", 100000)))
        m = self._momentum(int(self.global_step), max_steps)
        self._ema_update(m)
        self.log("sched/momentum", torch.tensor(m, device=self.device), on_step=True, prog_bar=False)

    def validation_step(self, batch, batch_idx):
        """Lightning validation step: same computation as training, logged under val/."""
        loss_dict, metrics = self.shared_step(batch, batch_idx)
        for k, v in metrics.items():
            self.log(f"val/{k}", v, on_step=True, on_epoch=True, prog_bar=(k == "loss"), sync_dist=True)
        return loss_dict["loss"]
|
osf/models/dino_utils/dino_clstoken_loss.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DINOLoss(nn.Module):
    """DINO [CLS]-token distillation loss with EMA centering (from DINOv2).

    Teacher logits are centered by a running mean (reduced across ranks when
    torch.distributed is initialized) and sharpened with a low temperature;
    the student matches them via soft cross-entropy.
    """

    def __init__(
        self,
        out_dim,
        student_temp=0.1,
        center_momentum=0.9,
    ):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        # Running center of teacher logits, shape (1, out_dim).
        self.register_buffer("center", torch.zeros(1, out_dim))
        # State for the asynchronous center update: `updated` marks whether the
        # pending batch statistics have been folded into `center` yet.
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_output = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_output, teacher_temp):
        # Fold in any pending center update before using the center.
        self.apply_center_update()
        # teacher centering and sharpening
        return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
        """Alternative to centering: Sinkhorn-Knopp assignment of teacher logits.

        Returns a (batch, out_dim) matrix of soft assignments whose rows and
        columns are approximately balanced across the (global) batch.
        """
        teacher_output = teacher_output.float()
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        # Q is (prototypes K, local batch); B is the global batch size.
        Q = torch.exp(teacher_output / teacher_temp).t()
        B = Q.shape[1] * world_size
        K = Q.shape[0]
        sum_Q = torch.sum(Q)
        if dist.is_initialized():
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for it in range(n_iterations):
            # Normalize rows: total weight per prototype must be 1/K.
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if dist.is_initialized():
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K
            # Normalize columns: total weight per sample must be 1/B.
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        # Rescale so each column (sample) sums to 1.
        Q *= B
        return Q.t()

    def forward(self, student_output_list, teacher_out_softmaxed_centered_list):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.

        Sums the loss over every (student view, teacher view) pair; callers are
        expected to normalize by the pair count if a mean is desired.
        """
        # TODO: Use cross_entropy_distribution here
        total_loss = 0
        for s in student_output_list:
            lsm = F.log_softmax(s / self.student_temp, dim=-1)
            for t in teacher_out_softmaxed_centered_list:
                loss = torch.sum(t * lsm, dim=-1)
                total_loss -= loss.mean()
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        # Kick off the (possibly asynchronous) center statistics update.
        self.reduce_center_update(teacher_output)

    @torch.no_grad()
    def reduce_center_update(self, teacher_output):
        """Start accumulating this batch's teacher-logit sum (async all_reduce)."""
        self.updated = False
        self.len_teacher_output = len(teacher_output)
        self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        if dist.is_initialized():
            # async_op=True: overlap the reduction with other work; waited on in
            # apply_center_update.
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        """Finish the pending reduction (if any) and EMA-update the center."""
        if self.updated is False:
            world_size = dist.get_world_size() if dist.is_initialized() else 1

            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            # Mean over the *global* batch.
            _t = self.async_batch_center / (self.len_teacher_output * world_size)

            self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)

            self.updated = True
|
| 95 |
+
|
| 96 |
+
|
osf/models/dino_utils/ibot_patch_loss.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger("dinov2")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
try:
    # Prefer xformers' fused soft cross-entropy when available
    # (memory-efficient backward on GPU).
    from xformers.ops import cross_entropy

    def lossfunc(t, s, temp):
        """Soft cross-entropy between teacher probs *t* and student logits *s*."""
        s = s.float()
        t = t.float()
        if s.ndim == 2:
            # The fused kernel expects a 3-D (batch, tokens, dim) layout.
            return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0)
        elif s.ndim == 3:
            return -cross_entropy(s, t, temp, bw_inplace=True)

except ImportError:

    def lossfunc(t, s, temp):
        """Soft cross-entropy between teacher probs *t* and student logits *s* (pure PyTorch)."""
        return (t * F.log_softmax(s / temp, dim=-1)).sum(dim=-1)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class iBOTPatchLoss(nn.Module):
    """iBOT masked-patch distillation loss with EMA centering (from DINOv2).

    Teacher patch logits are centered by a running mean (reduced across ranks
    when torch.distributed is initialized) and sharpened; the student matches
    them on masked positions via soft cross-entropy (`lossfunc`).
    """

    def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        # Running center over teacher patch logits, shape (1, 1, patch_out_dim).
        self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
        # State for the asynchronous center update (see reduce/apply below).
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_patch_tokens = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
        """Center and sharpen teacher patch logits into soft targets."""
        self.apply_center_update()
        return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
        """Alternative to centering: Sinkhorn-Knopp balanced assignment.

        Args:
            teacher_output: (n_patches, patch_out_dim) logits for masked patches.
            teacher_temp: sharpening temperature.
            n_masked_patches_tensor: tensor holding the local masked-patch count;
                summed across ranks when distributed is initialized.
            n_iterations: number of row/column normalization rounds.
        """
        teacher_output = teacher_output.float()
        # Q is (prototypes K, local patches); B is the global masked-patch count.
        Q = torch.exp(teacher_output / teacher_temp).t()
        B = n_masked_patches_tensor
        # FIX: guard the all_reduce like every other collective in this module;
        # the original called dist.all_reduce(B) unconditionally, which raises
        # when no process group is initialized (single-process runs).
        if dist.is_initialized():
            dist.all_reduce(B)
        K = Q.shape[0]
        sum_Q = torch.sum(Q)
        if dist.is_initialized():
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for it in range(n_iterations):
            # Normalize rows: total weight per prototype must be 1/K.
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if dist.is_initialized():
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K
            # Normalize columns: total weight per patch must be 1/B.
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        # Rescale so each column (patch) sums to 1.
        Q *= B
        return Q.t()

    def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        student_patch_tokens: (B, N, D) tensor
        teacher_patch_tokens: (B, N, D) tensor of soft targets (probabilities)
        student_masks_flat: (B, N) tensor; loss is averaged over masked positions
        """
        t = teacher_patch_tokens
        s = student_patch_tokens
        loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
        # Per-sample mean over masked positions; clamp avoids 0/0 for unmasked rows.
        loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
        return -loss.mean()

    def forward_masked(
        self,
        student_patch_tokens_masked,
        teacher_patch_tokens_masked,
        student_masks_flat,
        n_masked_patches=None,
        masks_weight=None,
    ):
        """Loss over pre-gathered masked tokens (flat (n_masked, D) tensors).

        Each masked token is weighted by 1 / (masked tokens in its sample) so
        every sample contributes equally regardless of its mask size.
        """
        t = teacher_patch_tokens_masked
        s = student_patch_tokens_masked
        loss = lossfunc(t, s, self.student_temp)
        if masks_weight is None:
            masks_weight = (
                (1 / student_masks_flat.sum(-1).clamp(min=1.0))
                .unsqueeze(-1)
                .expand_as(student_masks_flat)[student_masks_flat]
            )
        if n_masked_patches is not None:
            # Buffers may be over-allocated; keep only the valid prefix.
            loss = loss[:n_masked_patches]
        loss = loss * masks_weight
        return -loss.sum() / student_masks_flat.shape[0]

    @torch.no_grad()
    def update_center(self, teacher_patch_tokens):
        """Start the (possibly asynchronous) center statistics update."""
        self.reduce_center_update(teacher_patch_tokens)

    @torch.no_grad()
    def reduce_center_update(self, teacher_patch_tokens):
        """Accumulate this batch's mean-over-patches logit sum (async all_reduce)."""
        self.updated = False
        self.len_teacher_patch_tokens = len(teacher_patch_tokens)
        self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
        if dist.is_initialized():
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        """Finish the pending reduction (if any) and EMA-update the center."""
        if self.updated is False:
            world_size = dist.get_world_size() if dist.is_initialized() else 1

            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)

            self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)

            self.updated = True
|
osf/models/dino_utils/koleo_loss.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
# import torch.distributed as dist
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("dinov2")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class KoLeoLoss(nn.Module):
    """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search"""

    def __init__(self):
        super().__init__()
        self.pdist = nn.PairwiseDistance(2, eps=1e-8)

    def pairwise_NNs_inner(self, x):
        """
        Pairwise nearest neighbors for L2-normalized vectors.
        Uses Torch rather than Faiss to remain on GPU.
        """
        sim = torch.mm(x, x.t())
        num = x.shape[0]
        # Knock out the diagonal so a point never picks itself; -1 is the
        # minimum possible inner product for unit vectors.
        sim.view(-1)[:: (num + 1)].fill_(-1)
        nn_idx = torch.max(sim, dim=1)[1]
        return nn_idx

    def forward(self, student_output, eps=1e-8):
        """
        Args:
            student_output (BxD): backbone output of student
        """
        # Force fp32: log/normalize are numerically fragile under autocast.
        with torch.cuda.amp.autocast(enabled=False):
            unit = F.normalize(student_output, eps=eps, p=2, dim=-1)
            neighbors = self.pairwise_NNs_inner(unit)
            gaps = self.pdist(unit, unit[neighbors])  # BxD, BxD -> B
            result = -torch.log(gaps + eps).mean()
        return result
|
osf/models/ssl_finetuner.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from pytorch_lightning import LightningModule
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from itertools import chain
|
| 9 |
+
from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC, ConfusionMatrix, CohenKappa, AveragePrecision, MetricCollection
|
| 10 |
+
from osf.models.balanced_losses import FocalLoss, BalancedSoftmax
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _create_pred_metrics(num_classes: int) -> MetricCollection:
    """Create metrics that take preds (class indices) as input."""
    return MetricCollection({
        "acc": Accuracy(task="multiclass", num_classes=num_classes, average="micro"),
        "f1": F1Score(task="multiclass", num_classes=num_classes, average="macro"),
        "f1_w": F1Score(task="multiclass", num_classes=num_classes, average="weighted"),
        "rec_m": Recall(task="multiclass", num_classes=num_classes, average="macro"),
        "kappa": CohenKappa(task="multiclass", num_classes=num_classes, weights="quadratic"),
    })
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _create_prob_metrics(num_classes: int) -> MetricCollection:
    """Create metrics that take probs (probabilities) as input."""
    return MetricCollection({
        "auc": AUROC(task="multiclass", num_classes=num_classes, average="macro"),
        "auprc": AveragePrecision(task="multiclass", num_classes=num_classes, average="macro"),
    })
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _create_perclass_pred_metrics(num_classes: int) -> MetricCollection:
    """Create per-class metrics that take preds as input."""
    # average=None yields one value per class rather than a scalar.
    return MetricCollection({
        "acc_c": Accuracy(task="multiclass", num_classes=num_classes, average=None),
        "prec_c": Precision(task="multiclass", num_classes=num_classes, average=None),
        "rec_c": Recall(task="multiclass", num_classes=num_classes, average=None),
        "f1_c": F1Score(task="multiclass", num_classes=num_classes, average=None),
        "cm": ConfusionMatrix(task="multiclass", num_classes=num_classes, normalize=None),
    })
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _create_perclass_prob_metrics(num_classes: int) -> MetricCollection:
    """Build the per-class (average=None) metrics that consume class probabilities."""
    return MetricCollection({
        "auc_c": AUROC(task="multiclass", num_classes=num_classes, average=None),
        "auprc_c": AveragePrecision(task="multiclass", num_classes=num_classes, average=None),
    })
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class SSLFineTuner(LightningModule):
    """Classification head (linear probe or full finetune) over pretrained backbones.

    ``backbones`` maps channel-group names ('ecg', 'resp', 'elect', 'all') to
    encoder modules.  ``use_which_backbone`` selects one key, or 'fusion' to
    average the features of every available group encoder.  With
    ``finetune_backbone=False`` the backbone is frozen (eval mode, no grads)
    and only the linear head trains; with ``True`` everything trains.

    Aggregate metrics (accuracy, F1, AUROC, AUPRC, kappa, ...) are logged as
    '<stage>_<metric>'; per-class metrics as '<stage>/<metric>_<class_name>'.

    NOTE(review): the PSG channel layout (ch 0 = ECG, 1-4 = respiration,
    5+ = electrophysiology) is assumed from the slicing below — confirm
    against the data module.
    """

    # Fixed channel layout of the PSG tensor (second dimension).
    _CHANNEL_SLICES = {
        "ecg": slice(0, 1),
        "resp": slice(1, 5),
        "elect": slice(5, None),
        "all": slice(None),
    }

    def __init__(self,
                 backbones,
                 use_which_backbone,
                 config=None,
                 in_features: int = 256,
                 num_classes: int = 2,
                 epochs: int = 10,
                 dropout: float = 0.0,
                 lr: float = 1e-3,
                 weight_decay: float = 1e-4,
                 final_lr: float = 1e-5,
                 use_channel_bank: bool = True,
                 loss_type: str = "ce",
                 class_distribution: Optional[torch.Tensor] = None,
                 focal_gamma: float = 2.0,
                 focal_alpha: Optional[float | torch.Tensor] = None,
                 use_mean_pool: bool = False,
                 total_training_steps: Optional[int] = None,
                 finetune_backbone: bool = False,
                 *args, **kwargs
                 ) -> None:
        """See class docstring.

        Args:
            backbones: dict or nn.ModuleDict of encoders keyed by channel group.
            use_which_backbone: 'ecg' | 'resp' | 'elect' | 'all' | 'fusion'.
            in_features: fallback head input width when a backbone does not
                expose an ``out_dim`` attribute.
            loss_type: 'ce' | 'focal' | 'balanced_softmax'.
            class_distribution: per-class sample counts; used to derive focal
                alpha weights and for BalancedSoftmax.
            total_training_steps: enables warmup+cosine LR schedule when > 0.
            finetune_backbone: unfreeze backbone parameters when True.

        Raises:
            ValueError: on unknown ``loss_type``, or invalid fusion setup.
        """
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.final_lr = final_lr
        self.use_channel_bank = use_channel_bank
        self.loss_type = loss_type
        self.focal_gamma = focal_gamma
        self.focal_alpha = focal_alpha
        self.use_mean_pool = use_mean_pool
        self.total_training_steps = total_training_steps
        self.finetune_backbone = finetune_backbone

        # Loss selection: criterion is None for plain cross entropy (applied
        # inline in shared_step); otherwise a dedicated loss module.
        if loss_type == "ce":
            self.criterion = None
        elif loss_type == "focal":
            alpha = focal_alpha
            if alpha is None and class_distribution is not None:
                # Inverse-frequency class weights, normalized to mean 1.
                class_dist = class_distribution.float()
                total_samples = class_dist.sum()
                alpha = total_samples / (num_classes * class_dist)
                alpha = alpha / alpha.mean()
            self.criterion = FocalLoss(alpha=alpha, gamma=focal_gamma, reduction="mean")
        elif loss_type == "balanced_softmax":
            self.criterion = BalancedSoftmax(class_distribution, reduction="mean")
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}. Must be one of ['ce', 'focal', 'balanced_softmax']")

        if isinstance(backbones, nn.ModuleDict):
            self.backbones = backbones
        else:
            self.backbones = nn.ModuleDict(backbones)
        self.config = config
        self.use_which_backbone = use_which_backbone
        # For 'fusion' there is no single selected backbone; features are
        # averaged across all available group encoders in shared_step.
        self.backbone = self.backbones[self.use_which_backbone] if self.use_which_backbone != "fusion" else None

        # Freeze (or unfreeze) backbone parameters according to the mode.
        if self.use_which_backbone == "fusion":
            for k in ("ecg", "resp", "elect"):
                if k in self.backbones:
                    for p in self.backbones[k].parameters():
                        p.requires_grad = self.finetune_backbone
                    if not self.finetune_backbone:
                        self.backbones[k].eval()
        else:
            for p in self.backbone.parameters():
                p.requires_grad = self.finetune_backbone
            if not self.finetune_backbone:
                self.backbone.eval()

        if self.finetune_backbone:
            print("[INFO] Full finetuning mode: backbone parameters are TRAINABLE")

        # Head input width: the backbone's advertised out_dim, else in_features.
        if self.use_which_backbone == "fusion":
            dims = [getattr(self.backbones[k], "out_dim", in_features)
                    for k in ("ecg", "resp", "elect") if k in self.backbones]
            if len(dims) == 0:
                raise ValueError("fusion requires at least one of {'ecg','resp','elect'} in backbones.")
            if len(set(dims)) != 1:
                raise ValueError(f"Mean fusion requires equal output dims, got {dims}")
            final_in_features = dims[0]
        else:
            final_in_features = getattr(self.backbone, "out_dim", in_features)

        self.linear_layer = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(final_in_features, num_classes)
        )

        # Separate metric collections per stage so updates never leak between
        # train / val / test.
        self.train_pred_metrics = _create_pred_metrics(num_classes)
        self.val_pred_metrics = _create_pred_metrics(num_classes)
        self.test_pred_metrics = _create_pred_metrics(num_classes)

        self.train_prob_metrics = _create_prob_metrics(num_classes)
        self.val_prob_metrics = _create_prob_metrics(num_classes)
        self.test_prob_metrics = _create_prob_metrics(num_classes)

        self.train_pred_metrics_c = _create_perclass_pred_metrics(num_classes)
        self.val_pred_metrics_c = _create_perclass_pred_metrics(num_classes)
        self.test_pred_metrics_c = _create_perclass_pred_metrics(num_classes)

        self.train_prob_metrics_c = _create_perclass_prob_metrics(num_classes)
        self.val_prob_metrics_c = _create_perclass_prob_metrics(num_classes)
        self.test_prob_metrics_c = _create_perclass_prob_metrics(num_classes)

        # Human-readable class names for per-class logging; fall back to indices.
        self.class_names = getattr(self.config, "class_names", [str(i) for i in range(num_classes)])

    def on_train_epoch_start(self) -> None:
        # Lightning switches every submodule to train(); keep a frozen backbone
        # in eval mode so dropout / norm statistics stay fixed.
        if not self.finetune_backbone:
            if self.use_which_backbone == "fusion":
                for k in ("ecg", "resp", "elect"):
                    if k in self.backbones:
                        self.backbones[k].eval()
            else:
                self.backbone.eval()

    def training_step(self, batch, batch_idx):
        """One optimization step: accumulate metrics, log the step loss."""
        loss, logits, y = self.shared_step(batch)
        probs = logits.softmax(-1)
        preds = logits.argmax(-1)

        self.train_pred_metrics.update(preds, y)
        self.train_prob_metrics.update(probs, y)
        self.train_pred_metrics_c.update(preds, y)
        self.train_prob_metrics_c.update(probs, y)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Compute, log and reset all train metrics accumulated this epoch."""
        pred_agg = self.train_pred_metrics.compute()
        prob_agg = self.train_prob_metrics.compute()

        self.log("train_acc", pred_agg["acc"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_f1", pred_agg["f1"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_auc", prob_agg["auc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_auprc", prob_agg["auprc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)

        pred_c = self.train_pred_metrics_c.compute()
        prob_c = self.train_prob_metrics_c.compute()
        cm = pred_c["cm"]
        # Row sums of the unnormalized confusion matrix = per-class support.
        support = cm.sum(dim=1) if cm is not None else None
        self._log_perclass("train", pred_c, prob_c, support)

        self.train_pred_metrics.reset()
        self.train_prob_metrics.reset()
        self.train_pred_metrics_c.reset()
        self.train_prob_metrics_c.reset()

    def validation_step(self, batch, batch_idx):
        """One validation step: accumulate metrics, log the epoch loss."""
        loss, logits, y = self.shared_step(batch)
        probs = logits.softmax(-1)
        preds = logits.argmax(-1)

        self.val_pred_metrics.update(preds, y)
        self.val_prob_metrics.update(probs, y)
        self.val_pred_metrics_c.update(preds, y)
        self.val_prob_metrics_c.update(probs, y)

        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Compute, log and reset all validation metrics accumulated this epoch."""
        pred_agg = self.val_pred_metrics.compute()
        prob_agg = self.val_prob_metrics.compute()

        self.log("val_acc", pred_agg["acc"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_f1", pred_agg["f1"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_f1_w", pred_agg["f1_w"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_rec_m", pred_agg["rec_m"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_auc", prob_agg["auc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_auprc", prob_agg["auprc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_kappa", pred_agg["kappa"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

        pred_c = self.val_pred_metrics_c.compute()
        prob_c = self.val_prob_metrics_c.compute()
        cm = pred_c["cm"]
        # Guard added for consistency with the train/test epoch hooks.
        support = cm.sum(dim=1) if cm is not None else None
        self._log_perclass("val", pred_c, prob_c, support)

        self.val_pred_metrics.reset()
        self.val_prob_metrics.reset()
        self.val_pred_metrics_c.reset()
        self.val_prob_metrics_c.reset()

    def test_step(self, batch, batch_idx):
        """One test step: accumulate metrics, log the epoch loss."""
        loss, logits, y = self.shared_step(batch)
        probs = logits.softmax(-1)
        preds = logits.argmax(-1)

        self.test_pred_metrics.update(preds, y)
        self.test_prob_metrics.update(probs, y)
        self.test_pred_metrics_c.update(preds, y)
        self.test_prob_metrics_c.update(probs, y)

        self.log("test_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def on_test_epoch_end(self):
        """Compute, log and reset all test metrics accumulated this epoch."""
        pred_agg = self.test_pred_metrics.compute()
        prob_agg = self.test_prob_metrics.compute()

        self.log("test_acc", pred_agg["acc"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_f1", pred_agg["f1"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_f1_w", pred_agg["f1_w"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_rec_m", pred_agg["rec_m"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_auc", prob_agg["auc"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_auprc", prob_agg["auprc"], on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_kappa", pred_agg["kappa"], on_step=False, on_epoch=True, sync_dist=True)

        pred_c = self.test_pred_metrics_c.compute()
        prob_c = self.test_prob_metrics_c.compute()
        cm = pred_c["cm"]
        support = cm.sum(dim=1) if cm is not None else None
        self._log_perclass("test", pred_c, prob_c, support)

        self.test_pred_metrics.reset()
        self.test_prob_metrics.reset()
        self.test_pred_metrics_c.reset()
        self.test_prob_metrics_c.reset()

    def _log_perclass(self, stage: str, pred_c, prob_c, support) -> None:
        """Log per-class metric values under '<stage>/<metric>_<class_name>'.

        ``support`` is the per-class sample count (confusion-matrix row sums)
        or None, in which case support logging is skipped.
        """
        for i in range(len(pred_c["acc_c"])):
            name = self.class_names[i] if i < len(self.class_names) else str(i)
            self.log(f"{stage}/acc_{name}", pred_c["acc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{stage}/prec_{name}", pred_c["prec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{stage}/rec_{name}", pred_c["rec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{stage}/f1_{name}", pred_c["f1_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{stage}/auc_{name}", prob_c["auc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            self.log(f"{stage}/auprc_{name}", prob_c["auprc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
            if support is not None:
                self.log(f"{stage}/support_{name}", support[i].to(pred_c["acc_c"][i].dtype), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)

    def shared_step(self, batch):
        """Encode the batch, apply the linear head and compute the loss.

        Returns:
            (loss, logits, y) with y flattened to shape [B] (long indices).
        """
        # In linear-probe mode the backbone runs without gradient tracking.
        context = torch.no_grad() if not self.finetune_backbone else torch.enable_grad()

        with context:
            psg = batch['psg']
            if self.use_which_backbone == 'fusion':
                feats_list = []
                for k in ('ecg', 'resp', 'elect'):
                    if k in self.backbones:
                        feats_list.append(
                            self._get_features(self.backbones[k], psg[:, self._CHANNEL_SLICES[k], :]))
                # Simple mean fusion over the available modality features.
                feats = torch.stack(feats_list, dim=0).mean(dim=0)
            elif self.use_which_backbone in self._CHANNEL_SLICES:
                x = psg[:, self._CHANNEL_SLICES[self.use_which_backbone], :]
                feats = self._get_features(self.backbone, x)
            else:
                raise ValueError(f"Unknown use_which_backbone: {self.use_which_backbone}")

        y = batch["label"]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        # Labels arrive as [B, 1]; flatten to [B] class indices.
        y = y.squeeze(1).long()

        if self.criterion is None:
            loss = F.cross_entropy(logits, y)
        else:
            loss = self.criterion(logits, y)

        return loss, logits, y

    def _get_features(self, backbone, x):
        """Get features from backbone. Uses mean pooling if use_mean_pool=True."""
        if self.use_mean_pool:
            # Prefer whichever pooling entry point the backbone exposes.
            if hasattr(backbone, 'forward_encoding_mean_pool'):
                return backbone.forward_encoding_mean_pool(x)
            elif hasattr(backbone, 'forward_avg_pool'):
                return backbone.forward_avg_pool(x)
        return backbone(x)

    def configure_optimizers(self):
        """AdamW over the trainable parameters, optional warmup+cosine schedule.

        The schedule (10% linear warmup, then cosine decay to ``final_lr``)
        is only attached when ``total_training_steps`` is set and positive.
        """
        if self.finetune_backbone:
            if self.use_which_backbone == "fusion":
                backbone_params = chain(*[self.backbones[k].parameters()
                                          for k in ("ecg", "resp", "elect") if k in self.backbones])
            else:
                backbone_params = self.backbone.parameters()
            params = chain(backbone_params, self.linear_layer.parameters())
        else:
            params = self.linear_layer.parameters()

        optimizer = torch.optim.AdamW(
            params,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

        if self.total_training_steps is not None and self.total_training_steps > 0:
            warmup_steps = int(0.1 * self.total_training_steps)
            cosine_steps = self.total_training_steps - warmup_steps

            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=0.1,
                end_factor=1.0,
                total_iters=warmup_steps
            )
            cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=cosine_steps,
                eta_min=self.final_lr
            )
            scheduler = torch.optim.lr_scheduler.SequentialLR(
                optimizer,
                schedulers=[warmup_scheduler, cosine_scheduler],
                milestones=[warmup_steps]
            )
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        else:
            return [optimizer]
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class SSLVitalSignsRegressor(SSLFineTuner):
    """SSL Finetuner for vital signs regression (HR, SPO2). Uses MSE loss.

    Reuses the parent's backbone handling (selection, freezing, feature
    extraction) but replaces the classification head, criterion, steps and
    epoch hooks with regression equivalents.  ``num_classes`` is reinterpreted
    as the number of regression targets.
    """
    def __init__(self,
                 backbones,
                 use_which_backbone,
                 config = None,
                 in_features: int = 256,
                 num_classes: int = 1,
                 target_names: Optional[list] = None,
                 dropout: float = 0.0,
                 **kwargs
                 ) -> None:
        # Force plain CE in the parent so no class-balanced criterion is
        # constructed; the real criterion (MSELoss) is set below.
        kwargs['loss_type'] = 'ce'

        # The parent is initialized with a dummy 2-class head; the actual
        # regression head with `num_classes` outputs is rebuilt afterwards.
        super().__init__(
            backbones=backbones,
            use_which_backbone=use_which_backbone,
            config=config,
            in_features=in_features,
            num_classes=2,
            dropout=dropout,
            **kwargs
        )

        # Number of regression outputs and their display names.
        self.num_targets = num_classes
        self.target_names = target_names or [f"target_{i}" for i in range(num_classes)]
        self.criterion = nn.MSELoss()

        # Rebuild the head with the correct output width, reusing the input
        # width the parent inferred (linear_layer = Sequential[Dropout, Linear]).
        in_feat = self.linear_layer[1].in_features
        self.linear_layer = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_feat, num_classes)
        )

        # The parent's classification metric collections are meaningless for
        # regression; remove them so they are never updated or logged.
        del self.train_pred_metrics, self.val_pred_metrics, self.test_pred_metrics
        del self.train_prob_metrics, self.val_prob_metrics, self.test_prob_metrics
        del self.train_pred_metrics_c, self.val_pred_metrics_c, self.test_pred_metrics_c
        del self.train_prob_metrics_c, self.val_prob_metrics_c, self.test_prob_metrics_c

    def shared_step(self, batch):
        """Override: regression loss instead of classification.

        Returns (loss, preds, y) where preds and y are float [B, num_targets].
        """
        # Backbone runs without gradients in linear-probe mode.
        context = torch.no_grad() if not self.finetune_backbone else torch.enable_grad()

        with context:
            psg = batch['psg']
            # Channel slicing mirrors the parent: 0 = ECG, 1-4 = respiration,
            # 5+ = electrophysiology.
            if self.use_which_backbone == 'ecg':
                x = psg[:, 0:1, :]
                feats = self._get_features(self.backbone, x)
            elif self.use_which_backbone == 'resp':
                x = psg[:, 1:5, :]
                feats = self._get_features(self.backbone, x)
            elif self.use_which_backbone == 'elect':
                x = psg[:, 5:, :]
                feats = self._get_features(self.backbone, x)
            elif self.use_which_backbone == 'all':
                x = psg
                feats = self._get_features(self.backbone, x)
            elif self.use_which_backbone == 'fusion':
                feats_list = []
                if 'ecg' in self.backbones:
                    f_ecg = self._get_features(self.backbones['ecg'], psg[:, 0:1, :])
                    feats_list.append(f_ecg)
                if 'resp' in self.backbones:
                    f_resp = self._get_features(self.backbones['resp'], psg[:, 1:5, :])
                    feats_list.append(f_resp)
                if 'elect' in self.backbones:
                    f_elect = self._get_features(self.backbones['elect'], psg[:, 5:, :])
                    feats_list.append(f_elect)

                # Mean fusion over the available modality features.
                feats = torch.stack(feats_list, dim=0).mean(dim=0)
            else:
                raise ValueError(f"Unknown use_which_backbone: {self.use_which_backbone}")

        y = batch["label"].float()  # [B, num_targets]
        feats = feats.view(feats.size(0), -1)
        preds = self.linear_layer(feats)  # [B, num_targets]

        loss = self.criterion(preds, y)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        """Override: regression metrics."""
        loss, preds, y = self.shared_step(batch)

        # Per-target MAE, computed without tracking gradients.
        with torch.no_grad():
            for i, name in enumerate(self.target_names):
                mae = F.l1_loss(preds[:, i], y[:, i])
                self.log(f"train_{name}_mae", mae, on_step=False, on_epoch=True, sync_dist=True)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Override: no classification metrics to compute."""
        pass

    def validation_step(self, batch, batch_idx):
        """Override: regression metrics."""
        loss, preds, y = self.shared_step(batch)

        for i, name in enumerate(self.target_names):
            mae = F.l1_loss(preds[:, i], y[:, i])
            self.log(f"val_{name}_mae", mae, on_step=False, on_epoch=True, sync_dist=True)

        overall_mae = F.l1_loss(preds, y)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_mae", overall_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Override: no classification metrics to compute."""
        pass

    def test_step(self, batch, batch_idx):
        """Override: regression metrics."""
        loss, preds, y = self.shared_step(batch)

        # Per-target MAE / MSE / RMSE.
        for i, name in enumerate(self.target_names):
            p, t = preds[:, i], y[:, i]
            mae = F.l1_loss(p, t)
            mse = F.mse_loss(p, t)
            rmse = torch.sqrt(mse)

            self.log(f"test_{name}_mae", mae, on_step=False, on_epoch=True, sync_dist=True)
            self.log(f"test_{name}_mse", mse, on_step=False, on_epoch=True, sync_dist=True)
            self.log(f"test_{name}_rmse", rmse, on_step=False, on_epoch=True, sync_dist=True)

        overall_mae = F.l1_loss(preds, y)
        overall_mse = F.mse_loss(preds, y)
        self.log("test_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_mae", overall_mae, on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_mse", overall_mse, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def on_test_epoch_end(self):
        """Override: no classification metrics to compute."""
        pass
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
class SupervisedVitalSignsRegressor(SSLVitalSignsRegressor):
    """Supervised from-scratch regression. Equivalent to SSLVitalSignsRegressor with finetune_backbone=True."""

    def __init__(self, backbones, use_which_backbone, epochs: int = 100, **kwargs):
        # Always train the backbone jointly with the regression head.
        kwargs['finetune_backbone'] = True
        super().__init__(backbones=backbones, use_which_backbone=use_which_backbone, epochs=epochs, **kwargs)
|
osf/utils/openclip_loss.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import torch.distributed.nn
|
| 10 |
+
from torch import distributed as dist
|
| 11 |
+
|
| 12 |
+
has_distributed = True
|
| 13 |
+
except ImportError:
|
| 14 |
+
has_distributed = False
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import horovod.torch as hvd
|
| 18 |
+
except ImportError:
|
| 19 |
+
hvd = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_clip_metrics(image_features, text_features, logit_scale):
    """Compute bidirectional retrieval metrics (mean/median rank, R@k).

    Features are paired row-wise (image i matches text i).  Similarity is
    ``logit_scale * image @ text.T``; a sample's rank is the position of its
    matching item in the descending similarity ordering (0-based internally,
    reported 1-based).
    """
    i2t_scores = (logit_scale * image_features @ text_features.t()).detach().cpu()
    t2i_scores = i2t_scores.t().detach().cpu()
    targets = torch.arange(len(text_features)).view(-1, 1)

    results = {}
    for direction, scores in (("image_to_text", i2t_scores), ("text_to_image", t2i_scores)):
        order = torch.argsort(scores, descending=True)
        # Column index where the matching item lands = its 0-based rank.
        ranks = torch.where(order == targets)[1].detach().cpu().numpy()
        results[f"{direction}_mean_rank"] = ranks.mean() + 1
        results[f"{direction}_median_rank"] = np.floor(np.median(ranks)) + 1
        for k in (1, 5, 10):
            results[f"{direction}_R@{k}"] = np.mean(ranks < k)

    return results
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    """All-gather image/text feature tensors across distributed workers.

    Args:
        image_features: local [B, D] image embeddings.
        text_features: local [B, D] text embeddings.
        local_loss: if True, each rank keeps only its local rows on the
            "query" side, so the gathered tensors need no local grad splice.
        gather_with_grad: use a differentiable all-gather so gradients flow
            through remote features; otherwise gather without grad and
            re-insert this rank's original (grad-tracking) tensors.
        rank: this worker's rank, used to splice local features back in.
        world_size: total number of workers.
        use_horovod: gather via horovod instead of torch.distributed.

    Returns:
        (all_image_features, all_text_features), each [world_size * B, D].
    """
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            # Non-differentiable gather; grads must be restored for local rows.
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        if gather_with_grad:
            # torch.distributed.nn.all_gather keeps autograd connectivity.
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            # Plain all_gather detaches remote tensors from the autograd graph.
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # Splice this rank's original tensors back in so its rows
                # still carry gradients.
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ClipLoss(nn.Module):
    """Bidirectional InfoNCE (CLIP-style) contrastive loss with optional
    cross-GPU feature gathering.

    Args:
        local_loss: when ``world_size > 1``, compute logits only for this
            rank's features against the gathered features (ground-truth
            indices are offset by ``rank`` accordingly) instead of the full
            global similarity matrix.
        gather_with_grad: gather features so gradients flow to all ranks
            (delegated to ``gather_features``).
        cache_labels: cache the arange ground-truth labels per device so they
            are not rebuilt every step.
        rank: distributed rank of this process.
        world_size: number of distributed processes; 1 disables gathering.
        use_horovod: use horovod collectives instead of torch.distributed.
    """

    def __init__(
        self,
        local_loss=True,
        gather_with_grad=True,
        cache_labels=True,
        rank=0,
        world_size=1,
        use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}  # device -> cached label tensor

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return targets [0..num_logits) (offset by rank when local_loss),
        rebuilding only when the batch size or device changes."""
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                # Local rows match columns [rank*B, (rank+1)*B) of the
                # gathered feature matrix.
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                # NOTE: prev_num_logits is only advanced when caching is on,
                # so with cache_labels=False labels are rebuilt every call.
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale, return_gather_features=False):
        """Compute scaled image->text and text->image similarity logits.

        With ``return_gather_features=True`` (multi-GPU path only) the
        globally gathered features are returned as well.
        """
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features,
                text_features,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=self.rank,
                world_size=self.world_size,
                use_horovod=self.use_horovod,
            )

            if self.local_loss:
                # Local batch vs. global batch: rectangular logit matrices.
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                # Full global similarity matrix; text logits are its transpose.
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T

            if return_gather_features:
                return logits_per_image, logits_per_text, all_image_features, all_text_features
            else:
                return logits_per_image, logits_per_text

        else:
            # Single-process: plain pairwise similarity in both directions.
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        """Return the symmetric cross-entropy contrastive loss.

        Returns a scalar tensor, or ``{"contrastive_loss": loss}`` when
        ``output_dict=True``.
        """
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        # Average of the image->text and text->image cross-entropies.
        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class CoCaLoss(ClipLoss):
    """CoCa training objective: weighted CLIP contrastive loss plus a
    weighted autoregressive captioning (cross-entropy) loss.

    Args:
        caption_loss_weight: multiplier applied to the captioning loss.
        clip_loss_weight: multiplier for the contrastive loss; a falsy value
            disables the contrastive term entirely.
        pad_id: token id excluded from the captioning cross-entropy.
        (remaining arguments are forwarded unchanged to ``ClipLoss``)
    """

    def __init__(
        self,
        caption_loss_weight,
        clip_loss_weight,
        pad_id=0,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod
        )

        self.caption_loss_weight = caption_loss_weight
        self.clip_loss_weight = clip_loss_weight
        # Padding positions contribute nothing to the captioning objective.
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
        """Return ``(clip_loss, caption_loss)``, or a dict when
        ``output_dict=True``."""
        if self.clip_loss_weight:
            clip_loss = self.clip_loss_weight * super().forward(image_features, text_features, logit_scale)
        else:
            # Contrastive term disabled: keep a zero tensor on the right device.
            clip_loss = torch.tensor(0, device=logits.device)

        # CrossEntropyLoss expects (batch, vocab, seq); logits arrive as
        # (batch, seq, vocab), hence the permute.
        caption_loss = self.caption_loss_weight * self.caption_loss(
            logits.permute(0, 2, 1),
            labels,
        )

        if output_dict:
            return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}

        return clip_loss, caption_loss
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class DistillClipLoss(ClipLoss):
    """CLIP contrastive loss plus a distillation term that pushes the
    student's logit distributions towards a teacher's."""

    def dist_loss(self, teacher_logits, student_logits):
        """Soft cross-entropy of student against teacher (batch-averaged)."""
        teacher_probs = teacher_logits.softmax(dim=1)
        student_log_probs = student_logits.log_softmax(dim=1)
        return -(teacher_probs * student_log_probs).sum(dim=1).mean(dim=0)

    def forward(
        self,
        image_features,
        text_features,
        logit_scale,
        dist_image_features,
        dist_text_features,
        dist_logit_scale,
        output_dict=False,
    ):
        """Return ``(contrastive_loss, distill_loss)``, or a dict when
        ``output_dict=True``."""
        # Student and teacher logits share the same gathering/label logic.
        student_img, student_txt = self.get_logits(image_features, text_features, logit_scale)
        teacher_img, teacher_txt = self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)

        labels = self.get_ground_truth(image_features.device, student_img.shape[0])

        # Symmetric InfoNCE on the student's logits.
        contrastive_loss = (
            F.cross_entropy(student_img, labels) +
            F.cross_entropy(student_txt, labels)
        ) / 2

        # Symmetric soft-label distillation from the teacher's logits.
        distill_loss = (
            self.dist_loss(teacher_img, student_img) +
            self.dist_loss(teacher_txt, student_txt)
        ) / 2

        if output_dict:
            return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}

        return contrastive_loss, distill_loss
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def neighbour_exchange(from_rank, to_rank, tensor, group=None):
    """Send ``tensor`` to ``to_rank`` while receiving a same-shaped tensor
    from ``from_rank`` (paired non-blocking P2P exchange, no autograd).

    Blocks until both transfers complete and returns the received tensor.
    """
    received = torch.zeros_like(tensor)
    ops = [
        torch.distributed.P2POp(
            torch.distributed.isend,
            tensor,
            to_rank,
            group=group,
        ),
        torch.distributed.P2POp(
            torch.distributed.irecv,
            received,
            from_rank,
            group=group,
        ),
    ]
    for pending in torch.distributed.batch_isend_irecv(ops):
        pending.wait()
    return received
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
    """Simultaneously swap tensors with both ring neighbours.

    Sends ``tensor_to_left`` to ``left_rank`` and ``tensor_to_right`` to
    ``right_rank`` while receiving one tensor from each side. Receive
    buffers mirror the outgoing shapes (what the left neighbour sends right
    has the shape of ``tensor_to_right`` and vice versa). Returns
    ``(tensor_from_right, tensor_from_left)``. No autograd support.
    """
    from_left = torch.zeros_like(tensor_to_right)
    from_right = torch.zeros_like(tensor_to_left)
    isend = torch.distributed.isend
    irecv = torch.distributed.irecv
    # Batch order preserved from the original: sends first, then receives.
    ops = [
        torch.distributed.P2POp(isend, tensor_to_right, right_rank, group=group),
        torch.distributed.P2POp(isend, tensor_to_left, left_rank, group=group),
        torch.distributed.P2POp(irecv, from_right, right_rank, group=group),
        torch.distributed.P2POp(irecv, from_left, left_rank, group=group),
    ]
    for pending in torch.distributed.batch_isend_irecv(ops):
        pending.wait()
    return from_right, from_left
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class NeighbourExchange(torch.autograd.Function):
    # Autograd wrapper around `neighbour_exchange`: forward sends to
    # `to_rank` / receives from `from_rank`; backward routes the incoming
    # gradient along the reversed path.
    @staticmethod
    def forward(ctx, from_rank, to_rank, group, tensor):
        # Stash the exchange topology for the backward pass.
        ctx.group = group
        ctx.from_rank = from_rank
        ctx.to_rank = to_rank
        return neighbour_exchange(from_rank, to_rank, tensor, group=group)

    @staticmethod
    def backward(ctx, grad_output):
        # First three inputs (ranks, group) are non-differentiable; the
        # gradient for `tensor` is exchanged with from/to ranks swapped.
        return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
    # Differentiable neighbour exchange: gradients travel back along the
    # reverse send/recv direction via the NeighbourExchange autograd Function.
    return NeighbourExchange.apply(from_rank, to_rank, group, tensor)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class NeighbourExchangeBidir(torch.autograd.Function):
    # Autograd-aware bidirectional ring exchange: forward swaps tensors with
    # both neighbours; backward performs the mirrored exchange on the grads.
    @staticmethod
    def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
        # Stash the ring topology for the backward pass.
        ctx.group = group
        ctx.left_rank = left_rank
        ctx.right_rank = right_rank
        return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Swap left/right so each gradient is routed back to the rank that
        # produced the matching forward tensor; the first three inputs
        # (ranks, group) are non-differentiable.
        return (None, None, None) + \
            NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
    # Differentiable bidirectional neighbour exchange; see
    # NeighbourExchangeBidir for the gradient routing.
    return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class SigLipLoss(nn.Module):
    """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343

    @article{zhai2023sigmoid,
        title={Sigmoid loss for language image pre-training},
        author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
        journal={arXiv preprint arXiv:2303.15343},
        year={2023}
    }

    Unlike the softmax-based CLIP loss, each image/text pair is scored
    independently with a sigmoid, so no global normalization is required.
    In the distributed setting, negatives from other ranks are accumulated
    by circulating text features between ranks; `dist_impl` selects the
    communication strategy ('bidir', 'shift', 'reduce', or 'gather').
    """
    def __init__(
            self,
            cache_labels: bool = False,
            rank: int = 0,
            world_size: int = 1,
            dist_impl: Optional[str] = None,
    ):
        super().__init__()
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.dist_impl = dist_impl or 'bidir'  # default to bidir exchange for now, this will likely change
        assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')

        # FIXME: cache not currently used
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
        # Target matrix of -1 (negative pair) everywhere; when this rank's
        # own pairs are included, the diagonal becomes +1 (-1 + 2).
        labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
        if not negative_only:
            labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
        return labels

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
        # Pairwise similarity, scaled and optionally shifted by a learned bias.
        logits = logit_scale * image_features @ text_features.T
        if logit_bias is not None:
            logits += logit_bias
        return logits

    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
        """Sigmoid loss of one image block against one text block.

        With ``negative_only=True`` every pair is treated as a negative
        (used for text features received from other ranks).
        """
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
        labels = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            image_features.shape[0],
            negative_only=negative_only,
        )
        # -log sigmoid(label * logit), summed over pairs, normalized by batch.
        loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
        return loss

    def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
        # Local positives + local negatives first.
        loss = self._loss(image_features, text_features, logit_scale, logit_bias)

        if self.world_size > 1:
            # Accumulate negatives against every other rank's text features.
            if self.dist_impl == 'bidir':
                # Exchange in both ring directions simultaneously, halving
                # the number of rounds; an odd world size needs one extra
                # single-direction step for the remaining rank.
                right_rank = (self.rank + 1) % self.world_size
                left_rank = (self.rank - 1 + self.world_size) % self.world_size
                text_features_to_right = text_features_to_left = text_features
                num_bidir, remainder = divmod(self.world_size - 1, 2)
                for i in range(num_bidir):
                    text_features_recv = neighbour_exchange_bidir_with_grad(
                        left_rank,
                        right_rank,
                        text_features_to_left,
                        text_features_to_right,
                    )
                    for f in text_features_recv:
                        loss += self._loss(
                            image_features,
                            f,
                            logit_scale,
                            logit_bias,
                            negative_only=True,
                        )
                    # Received tensors become the next round's outgoing ones.
                    text_features_to_left, text_features_to_right = text_features_recv

                if remainder:
                    text_features_recv = neighbour_exchange_with_grad(
                        left_rank,
                        right_rank,
                        text_features_to_right
                    )
                    loss += self._loss(
                        image_features,
                        text_features_recv,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
            elif self.dist_impl == "shift":
                # Single-direction ring: shift text features world_size-1 times
                # so every rank eventually sees every other rank's texts.
                right_rank = (self.rank + 1) % self.world_size
                left_rank = (self.rank - 1 + self.world_size) % self.world_size
                text_features_to_right = text_features
                for i in range(self.world_size - 1):
                    text_features_from_left = neighbour_exchange_with_grad(
                        left_rank,
                        right_rank,
                        text_features_to_right,
                    )
                    loss += self._loss(
                        image_features,
                        text_features_from_left,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
                    text_features_to_right = text_features_from_left
            elif self.dist_impl == "reduce":
                # Broadcast rank i's texts via a masked all_reduce each round;
                # skip the round where i is this rank (already counted above).
                for i in range(self.world_size):
                    text_from_other = torch.distributed.nn.all_reduce(
                        text_features * (self.rank == i),
                        torch.distributed.ReduceOp.SUM,
                    )
                    loss += float(i != self.rank) * self._loss(
                        image_features,
                        text_from_other,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
            elif self.dist_impl == "gather":
                # Gather all text features at once; add each other rank's
                # block as negatives.
                all_text = torch.distributed.nn.all_gather(text_features)
                for i in range(self.world_size):
                    loss += float(i != self.rank) * self._loss(
                        image_features,
                        all_text[i],
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
            else:
                # Unreachable: dist_impl is validated in __init__.
                assert False

        return {"contrastive_loss": loss} if output_dict else loss
|
osf/utils/results_utils.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for saving experiment results to JSON/CSV.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import glob
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from typing import Dict, Any, Optional, List
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def convert_to_serializable(value):
    """Convert tensor/numpy values to Python native types for JSON serialization.

    Single-element values (scalar torch tensors, numpy scalars, 0-d or
    1-element arrays) become ``float``; multi-element numpy arrays and
    tensors become (nested) Python lists instead of raising, which the
    original single-scalar-only version did. Anything already JSON-friendly
    is returned unchanged.
    """
    if isinstance(value, np.ndarray):
        # Multi-element arrays cannot collapse to one float; serialize as
        # nested lists rather than raising ValueError from .item()/float().
        if value.size == 1:
            return float(value.item())
        return value.tolist()
    if isinstance(value, np.generic):  # numpy scalar (np.float32, np.int64, ...)
        return float(value)
    if hasattr(value, 'item'):  # torch.Tensor
        try:
            return float(value.item())
        except (ValueError, RuntimeError):
            # .item() only works for single-element tensors; fall back to a list.
            return value.tolist()
    return value
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def extract_embedding_type(root_dir: str) -> str:
    """
    Extract embedding type identifier from root_dir path.

    Examples:
        ".../dino_stage1_emb_no_norm" -> "dino_no_norm"
        ".../dino_stage1_emb" -> "dino"
        ".../mae_emb_normalized" -> "mae_normalized"

    Args:
        root_dir: Path to embedding directory

    Returns:
        Short embedding type identifier
    """
    if not root_dir:
        return "unknown"

    # Work on the final path component, ignoring any trailing slash.
    name = os.path.basename(root_dir.rstrip('/'))

    # Strip boilerplate markers (order matters: longest/most-specific first).
    for marker in ("_stage1_emb", "_stage1", "_emb", "final_"):
        name = name.replace(marker, "")

    # Truncate overly long identifiers to keep filenames manageable.
    name = name[:30]

    return name or "emb"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def format_lr(lr: float) -> str:
    """Format a learning rate for filenames (e.g. 0.001 -> "1e-3").

    Args:
        lr: Learning rate. Values >= 1 are rendered as integers, values in
            [0.1, 1) with one decimal, and smaller positive values in
            compact scientific notation ("3e-4").

    Returns:
        Filename-friendly string. ``lr == 0`` returns "0" (the original
        scale-up loop never terminated for non-positive inputs); other
        non-positive values fall back to the general-purpose "%g" form.
    """
    if lr >= 1:
        return f"{lr:.0f}"
    elif lr >= 0.1:
        return f"{lr:.1f}"
    elif lr <= 0:
        # Guard: `while val < 1: val *= 10` below loops forever for lr <= 0.
        return "0" if lr == 0 else f"{lr:g}"
    else:
        # Convert to "<mantissa>e-<exp>" scientific notation.
        exp = 0
        val = lr
        while val < 1:
            val *= 10
            exp += 1
        return f"{val:.0f}e-{exp}"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def save_results_to_json(
    test_metrics: Dict[str, Any],
    hparams: Any,
    extension: str,
    ckpt_dir: str,
    timestamp: str,
    results_dir: str = "./results",
    extra_fields: Optional[Dict[str, Any]] = None,
    filename_prefix: str = "",
) -> str:
    """
    Save test results to a JSON file.

    Args:
        test_metrics: Dictionary of test metrics from trainer.test()
        hparams: Hyperparameters namespace/object
        extension: Experiment extension string (now used as run_name)
        ckpt_dir: Checkpoint directory path
        timestamp: Timestamp string
        results_dir: Directory to save results (default: ./results)
        extra_fields: Additional fields to include in the result record
            - Should include: exp_type, task, dataset, model, etc.
        filename_prefix: Prefix for the filename

    Returns:
        Path to the saved JSON file
    """
    os.makedirs(results_dir, exist_ok=True)

    # Identification fields first.
    record = {
        "run_name": extension,  # Renamed from "extension" for clarity
        "ckpt_dir": ckpt_dir,
        "timestamp": timestamp,
    }

    # Copy over whichever common hyperparameters the hparams object defines.
    for attr in (
        "model_name", "downstream_dataset_name", "ckpt_path", "stage2_ckpt_path",
        "eval_label", "patient_cols", "use_which_backbone", "variant",
        "in_features", "train_data_pct", "lr", "batch_size",
        "max_epochs", "max_steps", "loss_type", "use_mean_pool",
        "root_dir", "is_pretrain", "pooling", "use_transformer", "use_mil",
        "encoder_name", "encoder", "mask_channels", "encoder_size",
        "num_classes", "seed",
    ):
        if hasattr(hparams, attr):
            record[attr] = getattr(hparams, attr)

    # Well-known metrics first, with "/" normalized to "_" in the keys...
    for metric_name in (
        "test_acc", "test_f1", "test_f1_w", "test_auc", "test_auprc",
        "test_kappa", "test_rec_m", "test_loss",
        "test/acc", "test/f1_macro", "test/auc_macro", "test/auprc_macro",
    ):
        if metric_name in test_metrics:
            record[metric_name.replace("/", "_")] = test_metrics[metric_name]

    # ...then any remaining test metrics not captured above.
    for metric_name, metric_value in test_metrics.items():
        if metric_name.startswith("test/") or metric_name.startswith("test_"):
            record.setdefault(metric_name.replace("/", "_"), metric_value)

    if extra_fields:
        record.update(extra_fields)

    # Tensors / numpy scalars are not JSON-serializable as-is.
    record = {k: convert_to_serializable(v) for k, v in record.items()}

    if filename_prefix:
        result_filename = f"{filename_prefix}_{timestamp}.json"
    else:
        model_name = getattr(hparams, 'model_name', 'model')
        dataset_name = getattr(hparams, 'downstream_dataset_name', 'dataset')
        label = getattr(hparams, 'eval_label', None) or getattr(hparams, 'patient_cols', 'task')
        result_filename = f"{model_name}_{dataset_name}_{label}_{timestamp}.json"

    result_path = os.path.join(results_dir, result_filename)

    # Save to JSON
    with open(result_path, 'w') as f:
        json.dump(record, f, indent=2)

    print(f"\n{'='*80}")
    print(f"Results saved to: {result_path}")
    print(f"{'='*80}\n")

    return result_path
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def aggregate_results_to_csv(
    results_dirs: List[str],
    output_path: str = "./results/aggregated_results.csv",
    key_columns: Optional[List[str]] = None,
    metric_columns: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Aggregate all JSON result files from multiple directories into a single CSV.

    Args:
        results_dirs: List of directories containing JSON result files
        output_path: Path to save the aggregated CSV
        key_columns: Columns to use as identifiers (default: common experiment params)
        metric_columns: Metric columns to include (default: all test metrics)

    Returns:
        DataFrame with aggregated results
    """
    # Defaults cover the experiment-identifying fields written by
    # save_results_to_json and friends.
    if key_columns is None:
        key_columns = [
            "exp_type", "task", "dataset", "model", "encoder",
            "train_data_pct", "lr", "embedding_type",
            "pretrain_ckpt_path", "finetuned_ckpt_dir", "trained_ckpt_dir",
            "stage2_pretrain_ckpt", "embedding_root_dir",
            "model_name", "downstream_dataset_name", "eval_label", "patient_cols",
            "use_which_backbone", "variant", "loss_type",
            "use_mean_pool", "pooling", "use_transformer", "use_mil",
            "mask_channels", "mask_channels_str",
            "ckpt_path", "stage2_ckpt_path", "root_dir",
        ]
    if metric_columns is None:
        metric_columns = [
            "test_acc", "test_f1", "test_f1_w", "test_auc", "test_auprc",
            "test_kappa", "test_rec_m", "test_loss",
        ]

    records = []
    for directory in results_dirs:
        if not os.path.exists(directory):
            print(f"[WARN] Directory not found: {directory}")
            continue

        json_paths = glob.glob(os.path.join(directory, "*.json"))
        print(f"[INFO] Found {len(json_paths)} JSON files in {directory}")

        for json_path in json_paths:
            try:
                with open(json_path, 'r') as f:
                    record = json.load(f)
                # Track provenance so rows can be traced back to their file.
                record['_source_file'] = os.path.basename(json_path)
                record['_source_dir'] = directory
                records.append(record)
            except Exception as e:
                # Best-effort aggregation: skip unreadable/malformed files.
                print(f"[WARN] Failed to load {json_path}: {e}")

    if not records:
        print("[WARN] No records found!")
        return pd.DataFrame()

    df = pd.DataFrame(records)

    # Order columns: identifiers, headline metrics, remaining per-class
    # test metrics (sorted), then everything else.
    key_cols = [c for c in key_columns if c in df.columns]
    metric_cols = [c for c in metric_columns if c in df.columns]
    per_class_cols = sorted(c for c in df.columns if c.startswith("test_") and c not in metric_cols)
    other_cols = [c for c in df.columns if c not in key_cols + metric_cols + per_class_cols]

    df = df[key_cols + metric_cols + per_class_cols + other_cols]

    os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True)
    df.to_csv(output_path, index=False)

    print(f"\n{'='*80}")
    print(f"Aggregated {len(records)} results to: {output_path}")
    print(f"Columns: {list(df.columns[:10])}... ({len(df.columns)} total)")
    print(f"{'='*80}\n")

    return df
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def load_results_from_json(json_path: str) -> Dict[str, Any]:
    """Load a single JSON result file and return it as a dictionary."""
    with open(json_path, 'r') as handle:
        return json.load(handle)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def filter_results(
    df: pd.DataFrame,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    eval_label: Optional[str] = None,
    patient_cols: Optional[str] = None,
) -> pd.DataFrame:
    """
    Filter aggregated results DataFrame by common fields.

    Args:
        df: DataFrame from aggregate_results_to_csv()
        model_name: Filter by model name
        dataset_name: Filter by downstream dataset name
        eval_label: Filter by eval label (stage 1)
        patient_cols: Filter by patient columns (stage 2)

    Returns:
        Filtered DataFrame (a copy; the input is never mutated)
    """
    out = df.copy()

    # (column, requested value) pairs; a criterion is applied only when a
    # value was given AND the column actually exists in the frame.
    criteria = (
        ('model_name', model_name),
        ('downstream_dataset_name', dataset_name),
        ('eval_label', eval_label),
        ('patient_cols', patient_cols),
    )
    for column, wanted in criteria:
        if wanted is not None and column in out.columns:
            out = out[out[column] == wanted]

    return out
|
| 289 |
+
|
osf_backbone.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c51190b1942556969af3c3d63c2e59430ddb1ea0377c50ea87df83712fc31857
|
| 3 |
+
size 341360652
|
pretrained_weights/readme.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Please download the checkpoint through the link in the README at the repository root.
|
requirements.txt
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.3.1
|
| 2 |
+
accelerate==1.2.1
|
| 3 |
+
aiohappyeyeballs==2.4.4
|
| 4 |
+
aiohttp==3.11.10
|
| 5 |
+
aiosignal==1.3.1
|
| 6 |
+
albucore==0.0.24
|
| 7 |
+
albumentations==2.0.8
|
| 8 |
+
altair==5.5.0
|
| 9 |
+
annotated-types==0.7.0
|
| 10 |
+
antlr4-python3-runtime==4.9.3
|
| 11 |
+
asttokens==3.0.0
|
| 12 |
+
async-timeout==5.0.1
|
| 13 |
+
attrs==25.3.0
|
| 14 |
+
beartype==0.22.2
|
| 15 |
+
bitarray==3.0.0
|
| 16 |
+
blinker==1.9.0
|
| 17 |
+
braceexpand==0.1.7
|
| 18 |
+
certifi==2025.10.5
|
| 19 |
+
cffi==1.17.1
|
| 20 |
+
charset-normalizer==3.4.3
|
| 21 |
+
click==8.1.7
|
| 22 |
+
coloredlogs==15.0.1
|
| 23 |
+
comm==0.2.3
|
| 24 |
+
contourpy==1.3.1
|
| 25 |
+
cosine_annealing_warmup @ git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup@12d03c07553aedd3d9e9155e2b3e31ce8c64081a
|
| 26 |
+
cycler==0.12.1
|
| 27 |
+
Cython==3.0.11
|
| 28 |
+
datasets==3.2.0
|
| 29 |
+
debugpy==1.8.17
|
| 30 |
+
decorator==5.2.1
|
| 31 |
+
diffusers==0.32.1
|
| 32 |
+
dill==0.3.8
|
| 33 |
+
docker-pycreds==0.4.0
|
| 34 |
+
easydict==1.13
|
| 35 |
+
efficientnet_pytorch==0.7.1
|
| 36 |
+
einops==0.8.0
|
| 37 |
+
ema-pytorch==0.7.7
|
| 38 |
+
et_xmlfile==2.0.0
|
| 39 |
+
exceptiongroup==1.3.0
|
| 40 |
+
executing==2.2.1
|
| 41 |
+
fairseq_signals_backbone @ git+https://github.com/fuying-wang/fairseq-signals@27d94bab8a1040879c011609df1488aac21a586a
|
| 42 |
+
filelock==3.20.0
|
| 43 |
+
flatbuffers==25.9.23
|
| 44 |
+
fonttools==4.55.1
|
| 45 |
+
frozenlist==1.5.0
|
| 46 |
+
fsspec==2024.6.1
|
| 47 |
+
gitdb==4.0.11
|
| 48 |
+
GitPython==3.1.43
|
| 49 |
+
grpcio==1.75.1
|
| 50 |
+
h5py==3.14.0
|
| 51 |
+
hf-xet==1.1.10
|
| 52 |
+
huggingface-hub==0.34.4
|
| 53 |
+
humanfriendly==10.0
|
| 54 |
+
hydra-core==1.3.2
|
| 55 |
+
idna==3.10
|
| 56 |
+
imageio==2.37.0
|
| 57 |
+
importlib_metadata==8.7.0
|
| 58 |
+
insightface==0.7.3
|
| 59 |
+
ipdb==0.13.13
|
| 60 |
+
ipykernel==6.30.1
|
| 61 |
+
ipython==8.37.0
|
| 62 |
+
jedi==0.19.2
|
| 63 |
+
Jinja2==3.1.6
|
| 64 |
+
joblib==1.4.2
|
| 65 |
+
jupyter_client==8.6.3
|
| 66 |
+
jupyter_core==5.8.1
|
| 67 |
+
kiwisolver==1.4.7
|
| 68 |
+
kornia==0.8.1
|
| 69 |
+
kornia_rs==0.1.9
|
| 70 |
+
lazy_loader==0.4
|
| 71 |
+
lightning==2.4.0
|
| 72 |
+
lightning-utilities==0.11.9
|
| 73 |
+
llvmlite==0.46.0
|
| 74 |
+
loguru==0.7.3
|
| 75 |
+
lxml==5.3.0
|
| 76 |
+
Markdown==3.9
|
| 77 |
+
MarkupSafe==3.0.3
|
| 78 |
+
matplotlib==3.9.3
|
| 79 |
+
matplotlib-inline==0.1.7
|
| 80 |
+
ml_dtypes==0.5.3
|
| 81 |
+
mne==1.10.1
|
| 82 |
+
mpmath==1.3.0
|
| 83 |
+
multidict==6.1.0
|
| 84 |
+
multiprocess==0.70.16
|
| 85 |
+
munch==4.0.0
|
| 86 |
+
narwhals==2.6.0
|
| 87 |
+
nest-asyncio==1.6.0
|
| 88 |
+
networkx==3.4.2
|
| 89 |
+
neurokit2==0.2.12
|
| 90 |
+
ninja==1.13.0
|
| 91 |
+
nltk==3.9.1
|
| 92 |
+
numba==0.63.1
|
| 93 |
+
numpy==2.1.2
|
| 94 |
+
omegaconf==2.3.0
|
| 95 |
+
onnx==1.19.1
|
| 96 |
+
onnx2torch==1.5.15
|
| 97 |
+
onnxruntime==1.23.1
|
| 98 |
+
opencv-python==4.12.0.88
|
| 99 |
+
opencv-python-headless==4.12.0.88
|
| 100 |
+
openpyxl==3.1.5
|
| 101 |
+
packaging==24.2
|
| 102 |
+
pandas==2.2.3
|
| 103 |
+
parso==0.8.5
|
| 104 |
+
peft==0.14.0
|
| 105 |
+
pexpect==4.9.0
|
| 106 |
+
pillow==11.0.0
|
| 107 |
+
platformdirs==4.4.0
|
| 108 |
+
pooch==1.8.2
|
| 109 |
+
portalocker==3.0.0
|
| 110 |
+
POT==0.9.5
|
| 111 |
+
pretrainedmodels==0.7.4
|
| 112 |
+
prettytable==3.16.0
|
| 113 |
+
prompt_toolkit==3.0.52
|
| 114 |
+
propcache==0.2.1
|
| 115 |
+
protobuf==5.29.1
|
| 116 |
+
psutil==7.1.0
|
| 117 |
+
ptyprocess==0.7.0
|
| 118 |
+
pure_eval==0.2.3
|
| 119 |
+
pyarrow==18.1.0
|
| 120 |
+
pycparser==2.23
|
| 121 |
+
pydantic==2.10.3
|
| 122 |
+
pydantic_core==2.27.1
|
| 123 |
+
pydeck==0.9.1
|
| 124 |
+
Pygments==2.19.2
|
| 125 |
+
pynndescent==0.5.13
|
| 126 |
+
pyparsing==3.2.0
|
| 127 |
+
pysam==0.23.3
|
| 128 |
+
python-dateutil==2.9.0.post0
|
| 129 |
+
pytorch-lightning==2.4.0
|
| 130 |
+
pytorch-warmup==0.2.0
|
| 131 |
+
pytz==2024.2
|
| 132 |
+
PyWavelets==1.8.0
|
| 133 |
+
PyYAML==6.0.3
|
| 134 |
+
pyzmq==27.1.0
|
| 135 |
+
regex==2024.11.6
|
| 136 |
+
requests==2.32.5
|
| 137 |
+
sacrebleu==2.4.3
|
| 138 |
+
safetensors==0.6.2
|
| 139 |
+
scikit-image==0.25.2
|
| 140 |
+
scikit-learn==1.7.2
|
| 141 |
+
scipy==1.14.1
|
| 142 |
+
seaborn==0.13.2
|
| 143 |
+
segmentation_models_pytorch==0.4.0
|
| 144 |
+
sentencepiece==0.2.1
|
| 145 |
+
sentry-sdk==2.19.2
|
| 146 |
+
setproctitle==1.3.4
|
| 147 |
+
simsimd==6.5.3
|
| 148 |
+
six==1.17.0
|
| 149 |
+
smmap==5.0.1
|
| 150 |
+
soundfile==0.12.1
|
| 151 |
+
stack-data==0.6.3
|
| 152 |
+
streamlit==1.50.0
|
| 153 |
+
stringzilla==4.2.1
|
| 154 |
+
sympy==1.13.1
|
| 155 |
+
tabulate==0.9.0
|
| 156 |
+
tenacity==9.1.2
|
| 157 |
+
tensorboard==2.20.0
|
| 158 |
+
tensorboard-data-server==0.7.2
|
| 159 |
+
tensorboardX==2.6.4
|
| 160 |
+
threadpoolctl==3.5.0
|
| 161 |
+
tifffile==2025.5.10
|
| 162 |
+
timm==1.0.12
|
| 163 |
+
tokenizers==0.21.0
|
| 164 |
+
toml==0.10.2
|
| 165 |
+
tomli==2.2.1
|
| 166 |
+
torch==2.5.1
|
| 167 |
+
torchaudio==2.5.1
|
| 168 |
+
torchdiffeq==0.2.5
|
| 169 |
+
torchmetrics==1.6.0
|
| 170 |
+
torchtools @ git+https://github.com/pabloppp/pytorch-tools@610158d5016d6418aee27f956e7afd17ff35ba04
|
| 171 |
+
torchvision==0.20.1
|
| 172 |
+
tornado==6.5.2
|
| 173 |
+
tqdm==4.67.1
|
| 174 |
+
traitlets==5.14.3
|
| 175 |
+
transformers==4.47.0
|
| 176 |
+
typing-inspection==0.4.1
|
| 177 |
+
typing_extensions==4.15.0
|
| 178 |
+
tzdata==2024.2
|
| 179 |
+
umap-learn==0.5.9.post2
|
| 180 |
+
unet==0.8.1
|
| 181 |
+
urllib3==2.5.0
|
| 182 |
+
vitaldb==1.5.8
|
| 183 |
+
wandb==0.21.4
|
| 184 |
+
warmup_scheduler==0.3
|
| 185 |
+
watchdog==6.0.0
|
| 186 |
+
wcwidth==0.2.13
|
| 187 |
+
webdataset==1.0.2
|
| 188 |
+
Werkzeug==3.1.3
|
| 189 |
+
wfdb==4.1.2
|
| 190 |
+
xgboost==2.1.3
|
| 191 |
+
xxhash==3.5.0
|
| 192 |
+
yarl==1.18.3
|
| 193 |
+
zipp==3.23.0
|
train_config.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Training-time configuration: channel groupings and filesystem paths.

All channel identifiers (ECG, ABD, THX, ...) are constants re-exported from
``config`` via the star import below.
"""
from config import *


# Channel sets per model family.
# Uni-encoder models (simclr, dino, mae, vqvae, ar, etc.) consume one flat
# channel list; multi-encoder models use the same channels grouped by modality.
TRAIN_EDF_COLS_UNI_ENC = [
    ECG, EMG_Chin, EMG_LLeg, EMG_RLeg,
    ABD, THX, NP, SN,
    EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1,
]
TRAIN_EDF_COLS_MULTI_ENC = [
    ECG,
    ABD, THX, NP, SN,
    EMG_Chin, EMG_LLeg, EMG_RLeg,
    EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1,
]
# Reduced montages for home-sleep-test style monitors.
TRAIN_EDF_COLS_TYPE3 = [ECG, ABD, THX, NP, SN]
TRAIN_EDF_COLS_TYPE4 = [ECG, ABD, THX]


# Maps a monitor-type key to the channel list used for that configuration.
MONITOR_TYPE_MAP = {
    "main": TRAIN_EDF_COLS_UNI_ENC,
    "type3": TRAIN_EDF_COLS_TYPE3,
    "type4": TRAIN_EDF_COLS_TYPE4,
}

# Placeholder paths — replace with local filesystem locations before running.
STAGE2_LABEL_PATH_WITH_PATHHEAD = "/path/to/your/label/splits"
CKPT_PATH = "/path/to/your/checkpoints"

# Models to train / evaluate, and the augmentation recipe each one uses.
MODEL_LIST = ["dino_ours"]
AUGMENTATION_MAP = {
    "dino_ours": "chan_then_pcspan",
}

SPLIT_DATA_FOLDER = "/path/to/your/postprocessed/data"
PRETRAIN_VAL_DATASET_LIST = ['shhs']

# Channels that require normalization before use.
NEED_NORM_COL = [HR, SPO2, OX]
|