ztshuaiUCLA committed
Commit 8f8716a · verified · 1 Parent(s): ea50973

Upload folder using huggingface_hub
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Health Intelligence Lab @ UCLA (https://github.com/yang-ai-lab)
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,256 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ tags:
+ - sleep
+ - eeg
+ - polysomnography
+ - foundation-model
+ - self-supervised
+ - vit
+ - biosignals
+ pipeline_tag: feature-extraction
+ library_name: pytorch
+ language:
+ - en
+ ---
+
+ # OSF: On Pre-training and Scaling of Sleep Foundation Models
+
+ [![Paper](https://img.shields.io/badge/paper-arXiv-red)](#citation)
+ [![Webpage](https://img.shields.io/badge/website-demo-blue)](https://yang-ai-lab.github.io/osf/)
+ [![License](https://img.shields.io/badge/license-MIT-green)](LICENSE)
+ [![Python](https://img.shields.io/badge/python-3.10%2B-brightgreen)](#installation)
+
+ ## 🔥 News
+
+ - [2026-02-24] Our codebase and checkpoints are released. The full benchmarking codebase will be made publicly available after acceptance.
+ - [2026-02-22] Our paper is out.
+
+ ## 📖 Introduction
+
+ Polysomnography (PSG) provides the gold standard for sleep assessment but suffers from substantial heterogeneity across recording devices and cohorts.
+ There have been growing efforts to build general-purpose foundation models (FMs) for sleep physiology, but these efforts lack an in-depth understanding of the pre-training process and the scaling patterns that lead to more generalizable sleep FMs.
+ To fill this gap, we curate a massive corpus of 166,500 hours of sleep recordings from nine public sources and establish SleepBench, a comprehensive, fully open-source benchmark.
+ Leveraging SleepBench, we systematically evaluate four families of self-supervised pre-training objectives and uncover three critical findings:
+ (1) existing FMs fail to generalize to missing channels at inference;
+ (2) channel-invariant feature learning is essential for pre-training;
+ and (3) scaling sample size, model capacity, and multi-source data mixture consistently improves downstream performance.
+ With an enhanced pre-training and scaling recipe, we introduce OSF, a family of sleep FMs that achieves state-of-the-art performance across nine datasets on diverse sleep and disease prediction tasks.
+ Further analysis of OSF also reveals intriguing properties in sample efficiency, hierarchical aggregation, and cross-dataset scaling.
+
+ ## 📖 Table of Contents
+
+ 1. [Installation](#-installation)
+ 2. [Quick Start](#-quick-start)
+ 3. [Pretrained Weights](#-pretrained-weights)
+ 4. [Usage](#-usage)
+ 5. [Benchmark Evaluations](#-benchmark-evaluations)
+ 6. [Supported Datasets](#-supported-datasets)
+ 7. [Citation](#-citation)
+
+ ## 💿 Installation
+
+ ```bash
+ git clone https://huggingface.co/yang-ai-lab/OSF-Base
+ cd OSF-Base
+ conda env create -f environment.yml
+ conda activate myenv
+ ```
+
+ ### Dependencies
+
+ - Python >= 3.10
+ - PyTorch >= 2.5.1 (the version pinned in `environment.yml`)
+ - PyTorch Lightning >= 2.5.5
+
+ ## 🚀 Quick Start
+
+ We provide a demo notebook (`demo.ipynb`) demonstrating how to extract embeddings from PSG signals using the pretrained model.
+
+ ```python
+ import torch
+ from osf.backbone.vit1d_cls import vit_base
+
+ # Load pretrained weights (included in this repo)
+ payload = torch.load("osf_backbone.pth", map_location="cpu", weights_only=False)
+ meta = payload["metadata"]
+
+ # Initialize model
+ backbone = vit_base(
+     num_leads=meta["num_leads"],        # 12 channels
+     seq_len=meta["seq_len"],            # 1920 (64 Hz × 30 s)
+     patch_size=meta["patch_size_time"],
+     lead_wise=meta["lead_wise"],
+     patch_size_ch=meta["patch_size_ch"],
+ )
+ backbone.load_state_dict(payload["state_dict"])
+ backbone.eval()
+
+ # Extract embeddings
+ # x: [B, 12, 1920] - 12-channel PSG, 64 Hz × 30 seconds
+ x = torch.randn(1, 12, 1920)  # replace with real PSG epochs
+ with torch.no_grad():
+     cls_embs, patch_embs = backbone.forward_encoding(x, return_sequence=False)
+     # cls_embs: [B, 768] - Global epoch-level representation
+     # patch_embs: [B, 90, 768] - Local patch representations
+ ```
+
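The introduction mentions hierarchical aggregation as a property of OSF. As a minimal, hypothetical sketch (not the repository's implementation), per-epoch `cls_embs` can be pooled into a single recording-level vector:

```python
import torch

# Hypothetical sketch: aggregate per-epoch embeddings (as returned by
# backbone.forward_encoding) into one recording-level representation.
epoch_embs = torch.randn(960, 768)   # stand-in for ~8 h of 30-s epoch cls_embs
night_emb = epoch_embs.mean(dim=0)   # simple mean pooling -> shape [768]
```

A learned aggregator over the epoch sequence is a natural alternative; mean pooling is just the simplest baseline.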
+ ## 📦 Pretrained Weights
+
+ | Model | Backbone | Channels |
+ |-------|----------|----------|
+ | OSF | ViT-Base | 12-ch |
+
+ The pretrained weights are included in this repository. You can download them via the Hugging Face Hub:
+
+ ```python
+ from huggingface_hub import hf_hub_download
+ checkpoint_path = hf_hub_download(repo_id="yang-ai-lab/OSF-Base", filename="osf_backbone.pth")
+ ```
+
+ Or via the CLI:
+
+ ```bash
+ huggingface-cli download yang-ai-lab/OSF-Base osf_backbone.pth
+ ```
+
+ ## 👩‍💻 Usage
+
+ ### Input Format
+
+ Expected input format:
+ - **12 PSG Channels**: ECG, EMG_Chin, EMG_LLeg, EMG_RLeg, ABD, THX, NP, SN, EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1
+ - **Sample Rate**: 64 Hz
+ - **Epoch Length**: 30 seconds
+ - **Input Shape**: `[B, 12, 1920]`
+
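To build a correctly shaped input from raw per-channel signals, one option is to resample each 30-second channel onto the 64 Hz grid and stack the channels in the order listed above. This is a hedged sketch, not the repository's preprocessing pipeline: `to_model_input` and the use of linear interpolation are illustrative assumptions.

```python
import numpy as np
import torch

CHANNELS = ["ECG", "EMG_Chin", "EMG_LLeg", "EMG_RLeg", "ABD", "THX", "NP",
            "SN", "EOG_E1_A2", "EOG_E2_A1", "EEG_C3_A2", "EEG_C4_A1"]
TARGET_FS, EPOCH_SEC = 64, 30  # 64 Hz × 30 s -> 1920 samples per channel

def to_model_input(signals: dict) -> torch.Tensor:
    """Resample each 30-s channel to 64 Hz and stack into [1, 12, 1920]."""
    n_out = TARGET_FS * EPOCH_SEC
    rows = []
    for ch in CHANNELS:
        x = np.asarray(signals[ch], dtype=np.float32)
        # Linear resampling onto the 64 Hz grid (illustrative; a real
        # pipeline may use a proper anti-aliased resampler instead).
        t_in = np.linspace(0.0, EPOCH_SEC, num=len(x), endpoint=False)
        t_out = np.linspace(0.0, EPOCH_SEC, num=n_out, endpoint=False)
        rows.append(np.interp(t_out, t_in, x))
    return torch.from_numpy(np.stack(rows).astype(np.float32)).unsqueeze(0)
```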
+ ### Pretraining
+
+ We support multiple self-supervised pretraining methods. For example, to launch pre-training with our OSF method, run:
+
+ ```bash
+ python main_pretrain.py \
+     --model_name "dino_ours" \
+     --psg_encoder_name "vit_base" \
+     --batch_size 256 \
+     --lr 5e-5 \
+     --max_epochs 30 \
+     --num_devices 4 \
+     --patch_size_time 64 \
+     --patch_size_ch 4 \
+     --precision "bf16-mixed"
+ ```
+
+ See `main_pipelines/main_pretrain.py` for more detailed settings.
+
+ ### Fine-tuning
+
+ Fine-tune the pretrained model on downstream tasks:
+
+ ```bash
+ python main_finetune.py \
+     --model_name "dino_ours" \
+     --ckpt_path "/path/to/pretrained/checkpoint.ckpt" \
+     --downstream_dataset_name "shhs" \
+     --eval_label "Stage" \
+     --train_data_pct 1.0 \
+     --max_steps 500 \
+     --lr 0.1 \
+     --num_devices 4
+ ```
+
+ ## 📊 Benchmark Evaluations
+
+ ### Benchmarked SSL Methods
+
+ | Method | Type | Original Paper |
+ |--------|------|-------------|
+ | SleepFM | Contrastive | [Leave-one-out multi-modal contrastive learning](https://www.nature.com/articles/s41591-025-04133-4.pdf) |
+ | SimCLR | Contrastive | [Simple Contrastive Learning](https://proceedings.mlr.press/v119/chen20j/chen20j.pdf) |
+ | DINO | Self-distillation | [DINO](https://arxiv.org/pdf/2304.07193) |
+ | VQ-VAE | Reconstruction | [Vector-quantized variational autoencoder](https://proceedings.neurips.cc/paper/2017/file/7a98af17e63a0ac09ce2e96d03992fbc-Paper.pdf) |
+ | MAE | Reconstruction | [Masked Autoencoding](https://openaccess.thecvf.com/content/CVPR2022/papers/He_Masked_Autoencoders_Are_Scalable_Vision_Learners_CVPR_2022_paper.pdf) |
+ | AR | Autoregressive | [Autoregressive Next-Token prediction](https://storage.prod.researchhub.com/uploads/papers/2020/06/01/language-models.pdf) |
+ | OSF | Self-distillation | ours |
+
+ ### Downstream Tasks
+
+ **Epoch-level Classification Tasks:**
+
+ | Task | Classes | Description |
+ |------|---------|-------------|
+ | Sleep Stage | 4 | Awake, Light Sleep, Deep Sleep, REM classification |
+ | Arousal | 2 | Arousal event detection |
+ | Hypopnea | 2 | Hypopnea event detection |
+ | Oxygen Desaturation | 2 | Oxygen desaturation detection |
+
+ ### Evaluation Settings
+
+ | Setting | Description |
+ |---------|-------------|
+ | Linear Probing | Freeze backbone, train linear classifier |
+ | Full Fine-tuning | Fine-tune entire model end-to-end |
+ | Few-shot (k-shot) | Train with limited labeled samples |
+
+ For example scripts, see the `main_pipelines` and `bash_scripts` folders.
+
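As one hedged illustration of the linear-probing setting (frozen backbone, linear classifier on top), a scikit-learn logistic regression can be fit on frozen embeddings. The random arrays below are stand-ins for real `cls_embs` and sleep-stage labels:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X_train = rng.normal(size=(256, 768)).astype(np.float32)  # frozen cls_embs stand-in
y_train = rng.integers(0, 4, size=256)                    # 4 sleep-stage labels
# The backbone stays frozen; only this linear classifier is trained.
probe = LogisticRegression(max_iter=1000).fit(X_train, y_train)
acc = probe.score(X_train, y_train)
```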
+ ## 📊 Supported Datasets
+
+ We aggregated the following large-scale datasets from the National Sleep Research Resource (NSRR) platform.
+
+ | Dataset | Full Name | Source |
+ |---------|-----------|--------|
+ | SHHS | Sleep Heart Health Study | NSRR |
+ | CHAT | Childhood Adenotonsillectomy Trial | NSRR |
+ | MROS | MrOS Sleep Study | NSRR |
+ | CCSHS | Cleveland Children's Sleep and Health Study | NSRR |
+ | CFS | Cleveland Family Study | NSRR |
+ | MESA | Multi-Ethnic Study of Atherosclerosis | NSRR |
+ | SOF | Study of Osteoporotic Fractures | NSRR |
+ | WSC | Wisconsin Sleep Cohort | NSRR |
+ | STAGES | Stanford Technology Analytics and Genomics in Sleep | NSRR |
+ | NCHSDB | NCH Sleep DataBank | NSRR |
+
+ New users should apply for an account and request access to each of these datasets following the instructions at [NSRR Registration](https://sleepdata.org/join).
+
+ ## 📁 Project Structure
+
+ ```
+ OSF-Open-Sleep-Foundation-Model/
+ ├── osf/
+ │   ├── backbone/          # ViT backbone implementations
+ │   │   └── vit1d_cls.py
+ │   ├── models/            # SSL model implementations
+ │   │   └── dino_model_cls.py
+ │   ├── datasets/          # Data loading utilities
+ │   └── utils/             # Helper functions
+ ├── main_pipelines/        # Training scripts
+ │   ├── main_pretrain.py
+ │   └── ...
+ ├── bash_scripts/          # Example bash scripts
+ ├── osf_backbone.pth       # Pretrained model weights
+ ├── demo.ipynb             # Quick start demo
+ ├── config.py              # Dataset and channel configurations
+ └── train_config.py        # Training configurations
+ ```
+
+ ## 📝 Citation
+
+ If you use this code or models in your research, please cite our paper:
+
+ ```bibtex
+ @article{shuai2026osf,
+   title={OSF: On Pre-training and Scaling of Sleep Foundation Models},
+   author={Shuai, Zitao and Xu, Zongzhe and Yang, David and Wang, Wei and Yang, Yuzhe},
+   journal={arXiv preprint},
+   year={2026}
+ }
+ ```
+
config.py ADDED
@@ -0,0 +1,231 @@
+ """
+ Configuration constants for sleep data processing.
+ Contains dataset names, paths, channel definitions, and event labels.
+ """
+ import pandas as pd
+ import numpy as np
+
+ # =============================================================================
+ # Dataset name constants
+ # =============================================================================
+ SHHS = 'shhs'
+ CHAT = 'chat'
+ MROS = 'mros'
+ CCSHS = 'ccshs'
+ CFS = 'cfs'
+ MESA = 'mesa'
+ SOF = 'sof'
+ WSC = 'wsc'
+ HSP = 'hsp'
+ NCHSDB = 'nchsdb'
+ STAGES = 'stages'
+ PATS = 'pats'
+ SHHS2 = 'shhs2'
+ NUMOM2B = 'numom2b'
+
+ # =============================================================================
+ # Data paths
+ # =============================================================================
+ META_PATH = '/path/to/your/nsrr/data'
+
+ MASTER_SHHS = [META_PATH + "/" + SHHS + "/datasets/shhs-harmonized-dataset-0.21.0.csv"]
+ MASTER_CHAT = [META_PATH + "/" + CHAT + "/datasets/chat-harmonized-dataset-0.14.0.csv"]
+ MASTER_MROS = [META_PATH + "/" + MROS + "/datasets/mros-visit1-harmonized-0.6.0.csv"]
+ MASTER_CCSHS = [META_PATH + "/" + CCSHS + "/datasets/ccshs-trec-harmonized-0.8.0.csv"]
+ MASTER_CFS = [META_PATH + "/" + CFS + "/datasets/cfs-visit5-harmonized-dataset-0.7.0.csv"]
+ MASTER_MESA = [META_PATH + "/" + MESA + "/datasets/mesa-sleep-harmonized-dataset-0.7.0.csv"]
+ MASTER_SOF = [META_PATH + "/" + SOF + "/datasets/sof-visit-8-harmonized-dataset-0.8.0.csv"]
+ MASTER_WSC = [META_PATH + "/" + WSC + "/datasets/wsc-harmonized-dataset-0.7.0.csv"]
+ MASTER_HSP = [
+     META_PATH + "/" + HSP + "/psg-metadata/I0001_psg_metadata_2025-05-06.csv",
+     META_PATH + "/" + HSP + "/psg-metadata/I0002_psg_metadata_2025-05-06.csv",
+     META_PATH + "/" + HSP + "/psg-metadata/I0003_psg_metadata_2025-05-06.csv",
+     META_PATH + "/" + HSP + "/psg-metadata/I0004_psg_metadata_2025-05-06.csv",
+     META_PATH + "/" + HSP + "/psg-metadata/I0006_psg_metadata_2025-05-06.csv",
+ ]
+ MASTER_STAGES = [META_PATH + "/" + STAGES + "/metadata/stages-harmonized-dataset-0.3.0.csv"]
+ MASTER_NCHSDB = [META_PATH + "/" + NCHSDB + "/datasets/nchsdb-dataset-harmonized-0.3.0.csv"]
+ MASTER_PATS = [META_PATH + "/" + PATS + "/datasets/pats-harmonized-dataset-0.1.0.csv"]
+
+ MASTER_CSV_LIST = {
+     'shhs': MASTER_SHHS,
+     'chat': MASTER_CHAT,
+     'mros': MASTER_MROS,
+     'ccshs': MASTER_CCSHS,
+     'cfs': MASTER_CFS,
+     'mesa': MASTER_MESA,
+     'sof': MASTER_SOF,
+     'wsc': MASTER_WSC,
+     'hsp': MASTER_HSP,
+     'stages': MASTER_STAGES,
+     'pats': MASTER_PATS,
+     'nchsdb': MASTER_NCHSDB,
+ }
+
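A small usage sketch for these path lists. The helper below is hypothetical and not part of `config.py`; it assumes the harmonized CSVs exist at the configured paths:

```python
import pandas as pd

def load_master(csv_paths: list) -> pd.DataFrame:
    """Concatenate the harmonized metadata CSVs registered for one dataset,
    e.g. load_master(MASTER_CSV_LIST['shhs'])."""
    frames = [pd.read_csv(p) for p in csv_paths]
    return pd.concat(frames, ignore_index=True)
```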
+ # =============================================================================
+ # Channel name constants
+ # =============================================================================
+ # ECG channels
+ ECG = 'ECG'
+ ECG1 = 'ECG1'
+ ECG2 = 'ECG2'
+ ECG3 = 'ECG3'
+ HR = 'HR'
+ PPG = 'PPG'
+
+ # Respiratory channels
+ SPO2 = 'SPO2'
+ OX = 'OX'
+ ABD = 'ABD'
+ THX = 'THX'
+ AF = 'AF'
+ NP = 'NP'
+ SN = 'SN'
+
+ # EOG channels
+ EOG_L = 'EOG_L'
+ EOG_R = 'EOG_R'
+ EOG_E1_A2 = 'EOG_E1_A2'
+ EOG_E2_A1 = 'EOG_E2_A1'
+
+ # EMG Leg channels
+ EMG_LLeg = 'EMG_LLeg'
+ EMG_RLeg = 'EMG_RLeg'
+ EMG_LLeg1 = 'EMG_LLeg1'
+ EMG_LLeg2 = 'EMG_LLeg2'
+ EMG_RLeg1 = 'EMG_RLeg1'
+ EMG_RLeg2 = 'EMG_RLeg2'
+ EMG_Leg = 'EMG_Leg'
+
+ # Sensor Leg channels
+ SENSOR_Leg = 'SENSOR_Leg'
+ SENSOR_LLeg = 'SENSOR_LLeg'
+ SENSOR_LLeg1 = 'SENSOR_LLeg1'
+ SENSOR_LLeg2 = 'SENSOR_LLeg2'
+ SENSOR_RLeg = 'SENSOR_RLeg'
+ SENSOR_RLeg1 = 'SENSOR_RLeg1'
+ SENSOR_RLeg2 = 'SENSOR_RLeg2'
+
+ # EMG Chin channels
+ EMG_Chin = 'EMG_Chin'
+ EMG_RChin = 'EMG_RChin'
+ EMG_LChin = 'EMG_LChin'
+ EMG_CChin = 'EMG_CChin'
+
+ # EEG channels (unipolar)
+ EEG_C3 = 'EEG_C3'
+ EEG_C4 = 'EEG_C4'
+ EEG_A1 = 'EEG_A1'
+ EEG_A2 = 'EEG_A2'
+ EEG_O1 = 'EEG_O1'
+ EEG_O2 = 'EEG_O2'
+ EEG_F3 = 'EEG_F3'
+ EEG_F4 = 'EEG_F4'
+
+ # EEG channels (bipolar/referenced)
+ EEG_C3_A2 = 'EEG_C3_A2'
+ EEG_C4_A1 = 'EEG_C4_A1'
+ EEG_F3_A2 = 'EEG_F3_A2'
+ EEG_F4_A1 = 'EEG_F4_A1'
+ EEG_O1_A2 = 'EEG_O1_A2'
+ EEG_O2_A1 = 'EEG_O2_A1'
+
+ # Other channels
+ FPZ = 'FPZ'
+ GROUND = 'GROUND'
+ POS = 'POS'
+
+ # =============================================================================
+ # Sampling frequencies (Hz)
+ # =============================================================================
+ FREQ_ECG = 128
+ FREQ_ECG1 = 128
+ FREQ_ECG2 = 128
+ FREQ_ECG3 = 128
+ FREQ_HR = 1
+ FREQ_PPG = 128
+
+ FREQ_SPO2 = 1
+ FREQ_OX = 1
+ FREQ_ABD = 8
+ FREQ_THX = 8
+ FREQ_AF = 8
+ FREQ_NP = 8
+ FREQ_SN = 32
+
+ FREQ_EOG_L = 64
+ FREQ_EOG_R = 64
+ FREQ_EOG_E1_A2 = 64
+ FREQ_EOG_E2_A1 = 64
+
+ FREQ_EMG_Leg = 64
+ FREQ_EMG_LLeg = 64
+ FREQ_EMG_RLeg = 64
+ FREQ_EMG_LLeg1 = 64
+ FREQ_EMG_LLeg2 = 64
+ FREQ_EMG_RLeg1 = 64
+ FREQ_EMG_RLeg2 = 64
+
+ FREQ_SENSOR_Leg = 64
+ FREQ_SENSOR_LLeg = 64
+ FREQ_SENSOR_LLeg1 = 64
+ FREQ_SENSOR_LLeg2 = 64
+ FREQ_SENSOR_RLeg = 64
+ FREQ_SENSOR_RLeg1 = 64
+ FREQ_SENSOR_RLeg2 = 64
+
+ FREQ_EMG_Chin = 64
+ FREQ_EMG_LChin = 64
+ FREQ_EMG_RChin = 64
+ FREQ_EMG_CChin = 64
+
+ FREQ_EEG_C3 = 64
+ FREQ_EEG_C4 = 64
+ FREQ_EEG_A1 = 64
+ FREQ_EEG_A2 = 64
+ FREQ_EEG_O1 = 64
+ FREQ_EEG_O2 = 64
+ FREQ_EEG_F3 = 64
+ FREQ_EEG_F4 = 64
+
+ FREQ_EEG_C3_A2 = 64
+ FREQ_EEG_C4_A1 = 64
+ FREQ_EEG_F3_A2 = 64
+ FREQ_EEG_F4_A1 = 64
+ FREQ_EEG_O1_A2 = 64
+ FREQ_EEG_O2_A1 = 64
+
+ FREQ_POS = 1
+
+ # =============================================================================
+ # Event annotation column names
+ # =============================================================================
+ EVENT_NAME_COLUMN = 'EVENT'
+ START_TIME_COLUMN = 'START_SEC'
+ END_TIME_COLUMN = 'END_SEC'
+
+ # =============================================================================
+ # Respiratory event names
+ # =============================================================================
+ RESPIRATORY_EVENT_CENTRAL_APNEA = 'Central Apnea'
+ RESPIRATORY_EVENT_OBSTRUCTIVE_APNEA = 'Obstructive Apnea'
+ RESPIRATORY_EVENT_MIXED_APNEA = 'Mixed Apnea'
+ RESPIRATORY_EVENT_HYPOPNEA = 'Hypopnea'
+ RESPIRATORY_EVENT_DESATURATION = 'Oxygen Desaturation'
+
+ # =============================================================================
+ # Limb movement event names
+ # =============================================================================
+ LIMB_MOVEMENT_ISOLATED = 'Limb Movement Isolated'
+ LIMB_MOVEMENT_PERIODIC = 'Limb Movement Periodic'
+ LIMB_MOVEMENT_ISOLATED_LEFT = 'Left Limb Movement Isolated'
+ LIMB_MOVEMENT_ISOLATED_RIGHT = 'Right Limb Movement Isolated'
+ LIMB_MOVEMENT_PERIODIC_LEFT = 'Left Limb Movement Periodic'
+ LIMB_MOVEMENT_PERIODIC_RIGHT = 'Right Limb Movement Periodic'
+
+ # =============================================================================
+ # Arousal event names
+ # =============================================================================
+ AROUSAL_EVENT_CLASSIC = 'Arousal'
+ AROUSAL_EVENT_RESPIRATORY = 'RERA'
+ AROUSAL_EVENT_EMG = 'EMG-Related Arousal'
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
environment.yml ADDED
@@ -0,0 +1,187 @@
+ name: myenv
+ channels:
+   - conda-forge
+ dependencies:
+   - _openmp_mutex=4.5=2_gnu
+   - bzip2=1.0.8=h4777abc_8
+   - ca-certificates=2025.8.3=hbd8a1cb_0
+   - ld_impl_linux-aarch64=2.44=h5e2c951_1
+   - libexpat=2.7.1=hfae3067_0
+   - libffi=3.4.6=he21f813_1
+   - libgcc=15.1.0=he277a41_5
+   - libgcc-ng=15.1.0=he9431aa_5
+   - libgomp=15.1.0=he277a41_5
+   - liblzma=5.8.1=h86ecc28_2
+   - libnsl=2.0.1=h86ecc28_1
+   - libsqlite=3.50.4=h022381a_0
+   - libuuid=2.41.1=h3e4203c_0
+   - libxcrypt=4.4.36=h31becfc_1
+   - libzlib=1.3.1=h86ecc28_2
+   - ncurses=6.5=ha32ae93_3
+   - openssl=3.5.2=h8e36d6e_0
+   - pip=25.2=pyh8b19718_0
+   - python=3.10.18=h256493d_0_cpython
+   - readline=8.2=h8382b9d_2
+   - setuptools=80.9.0=pyhff2d567_0
+   - tk=8.6.13=noxft_h5688188_102
+   - wheel=0.45.1=pyhd8ed1ab_1
+   - pip:
+     - absl-py==2.3.1
+     - accelerate==1.2.1
+     - aiohappyeyeballs==2.6.1
+     - aiohttp==3.12.15
+     - aiosignal==1.4.0
+     - albucore==0.0.24
+     - albumentations==2.0.8
+     - annotated-types==0.7.0
+     - asttokens==3.0.0
+     - async-timeout==5.0.1
+     - attrs==25.3.0
+     - beartype==0.22.2
+     - braceexpand==0.1.7
+     - certifi==2025.10.5
+     - cffi==2.0.0
+     - charset-normalizer==3.4.4
+     - click==8.2.1
+     - coloredlogs==15.0.1
+     - comm==0.2.3
+     - contourpy==1.3.2
+     - cosine-annealing-warmup==2.0
+     - cycler==0.12.1
+     - debugpy==1.8.17
+     - decorator==5.2.1
+     - easydict==1.13
+     - einops==0.8.1
+     - ema-pytorch==0.7.7
+     - et-xmlfile==2.0.0
+     - exceptiongroup==1.3.0
+     - executing==2.2.1
+     - filelock==3.13.1
+     - flatbuffers==25.9.23
+     - fonttools==4.59.2
+     - frozenlist==1.7.0
+     - fsspec==2024.6.1
+     - gitdb==4.0.12
+     - gitpython==3.1.45
+     - grpcio==1.75.1
+     - h5py==3.14.0
+     - hf-xet==1.1.10
+     - huggingface-hub==0.35.3
+     - humanfriendly==10.0
+     - idna==3.11
+     - imageio==2.37.0
+     - importlib-metadata==8.7.0
+     - insightface==0.7.3
+     - ipdb==0.13.13
+     - ipykernel==6.30.1
+     - ipython==8.37.0
+     - jedi==0.19.2
+     - jinja2==3.1.6
+     - joblib==1.5.2
+     - jupyter-client==8.6.3
+     - jupyter-core==5.8.1
+     - kiwisolver==1.4.9
+     - kornia==0.8.1
+     - kornia-rs==0.1.9
+     - lazy-loader==0.4
+     - lightning-utilities==0.15.2
+     - llvmlite==0.46.0
+     - loguru==0.7.3
+     - markdown==3.9
+     - markupsafe==2.1.5
+     - matplotlib==3.10.6
+     - matplotlib-inline==0.1.7
+     - ml-dtypes==0.5.3
+     - mne==1.10.1
+     - mpmath==1.3.0
+     - multidict==6.6.4
+     - munch==4.0.0
+     - nest-asyncio==1.6.0
+     - networkx==3.4.2
+     - neurokit2==0.2.12
+     - ninja==1.13.0
+     - numba==0.63.1
+     - numpy==2.2.6
+     - onnx==1.19.1
+     - onnx2torch==1.5.15
+     - onnxruntime==1.23.1
+     - opencv-python==4.12.0.88
+     - opencv-python-headless==4.12.0.88
+     - openpyxl==3.1.5
+     - packaging==24.2
+     - pandas==2.3.2
+     - parso==0.8.5
+     - pexpect==4.9.0
+     - pillow==11.0.0
+     - platformdirs==4.5.0
+     - pooch==1.8.2
+     - prettytable==3.16.0
+     - prompt-toolkit==3.0.52
+     - propcache==0.3.2
+     - protobuf==6.32.1
+     - psutil==7.1.0
+     - ptyprocess==0.7.0
+     - pure-eval==0.2.3
+     - pyarrow==21.0.0
+     - pycparser==2.23
+     - pydantic==2.11.7
+     - pydantic-core==2.33.2
+     - pygments==2.19.2
+     - pynndescent==0.5.13
+     - pyparsing==3.2.3
+     - pysam==0.23.3
+     - python-dateutil==2.9.0.post0
+     - pytorch-lightning==2.5.5
+     - pytorch-warmup==0.2.0
+     - pytz==2025.2
+     - pyyaml==6.0.3
+     - pyzmq==27.1.0
+     - regex==2025.9.1
+     - requests==2.32.5
+     - safetensors==0.6.2
+     - scikit-image==0.25.2
+     - scikit-learn==1.7.2
+     - scipy==1.15.3
+     - seaborn==0.13.2
+     - sentencepiece==0.2.1
+     - sentry-sdk==2.37.1
+     - simsimd==6.5.3
+     - six==1.17.0
+     - smmap==5.0.2
+     - soundfile==0.13.1
+     - stack-data==0.6.3
+     - stringzilla==4.2.1
+     - sympy==1.13.1
+     - tabulate==0.9.0
+     - tensorboard==2.20.0
+     - tensorboard-data-server==0.7.2
+     - tensorboardx==2.6.4
+     - threadpoolctl==3.6.0
+     - tifffile==2025.5.10
+     - timm==1.0.19
+     - tokenizers==0.22.0
+     - tomli==2.2.1
+     - torch==2.5.1
+     - torchdiffeq==0.2.5
+     - torchmetrics==1.8.2
+     - torchtools==0.3.5
+     - torchvision==0.20.1
+     - tornado==6.5.2
+     - tqdm==4.67.1
+     - traitlets==5.14.3
+     - transformers==4.56.1
+     - typing-extensions==4.15.0
+     - typing-inspection==0.4.1
+     - tzdata==2025.2
+     - umap-learn==0.5.9.post2
+     - urllib3==2.5.0
+     - vitaldb==1.5.8
+     - wandb==0.22.1
+     - warmup-scheduler==0.3
+     - wcwidth==0.2.13
+     - webdataset==1.0.2
+     - werkzeug==3.1.3
+     - wfdb==4.3.0
+     - xxhash==3.5.0
+     - yarl==1.20.1
+     - zipp==3.23.0
finetune.bash ADDED
@@ -0,0 +1,35 @@
+ DATASETS=("shhs" "mros")
+ LABELS=("Stage" "Arousal" "Hypopnea" "Oxygen Desaturation")
+
+ TRAIN_PCTS=(1.0)
+
+ declare -A MODELS
+
+ MODELS["dino_ours"]="osf_vit_base.ckpt|all"
+
+ for model_name in "${!MODELS[@]}"; do
+
+     IFS='|' read -r ckpt_path use_backbone <<< "${MODELS[$model_name]}"
+
+     for dataset in "${DATASETS[@]}"; do
+         for label in "${LABELS[@]}"; do
+             for pct in "${TRAIN_PCTS[@]}"; do
+                 echo "===== Model: ${model_name}, Dataset: ${dataset}, Label: ${label}, Pct: ${pct} ====="
+
+                 CUDA_VISIBLE_DEVICES=0,1,2,3 python main_finetune.py \
+                     --train_data_pct ${pct} \
+                     --max_steps 500 \
+                     --use_which_backbone "${use_backbone}" \
+                     --model_name "${model_name}" \
+                     --ckpt_path "${ckpt_path}" \
+                     --lr 0.1 \
+                     --eval_label "${label}" \
+                     --num_devices 4 \
+                     --data_source both \
+                     --include_datasets "${dataset}" \
+                     --downstream_dataset_name "${dataset}"
+             done
+         done
+     done
+ done
main_pipelines/main_finetune.py ADDED
@@ -0,0 +1,261 @@
+ from pprint import pprint
+ import os
+ from argparse import ArgumentParser, Namespace
+ import datetime
+ from dateutil import tz
+ import random
+ import numpy as np
+ import torch
+ import warnings
+ from pytorch_lightning import seed_everything, Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
+ from pytorch_lightning.loggers import WandbLogger
+
+ from osf.datasets.pretrain_datamodule import SleepDataModule
+ from osf.models.dino_model_cls import DINOCLSModel
+ from config import *
+ from train_config import *
+ from osf.models.ssl_finetuner import SSLFineTuner, SSLVitalSignsRegressor
+ from osf.utils.results_utils import save_results_to_json
+
+ warnings.filterwarnings("ignore")
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = True
+ torch.set_float32_matmul_precision('high')
+
+
+ def main(hparams: Namespace):
+     now = datetime.datetime.now(tz.tzlocal())
+     timestamp = now.strftime("%Y_%m_%d_%H_%M_%S") + f"_{now.microsecond // 1000:03d}"
+
+     if hparams.monitor_type == "main":
+         exp_name = "finetune_12ch"
+     else:
+         exp_name = f"finetune_{hparams.monitor_type}"
+
+     if hparams.finetune_backbone:
+         exp_name = f"{exp_name}_full"
+
+     if hasattr(hparams, 'n_train_samples') and hparams.n_train_samples is not None and hparams.n_train_samples > 0:
+         pct_str = f"k{hparams.n_train_samples}"
+     elif hparams.train_data_pct < 1:
+         pct_str = f"{int(hparams.train_data_pct * 100)}pct"
+     else:
+         pct_str = "full"
+     if hparams.task_type == "classification":
+         task_label = hparams.eval_label
+     elif hparams.task_type == "regression":
+         task_label = "_".join(hparams.regression_targets)
+     else:
+         raise NotImplementedError(f"Unknown task_type: {hparams.task_type}")
+     run_name = f"{task_label}_{hparams.downstream_dataset_name}_{hparams.model_name}_{pct_str}_{timestamp}"
+
+     ckpt_dir = os.path.join(
+         CKPT_PATH, f"logs/{exp_name}/ckpts/{run_name}")
+     os.makedirs(ckpt_dir, exist_ok=True)
+
+     if hparams.task_type == "regression":
+         ckpt_monitor = "val_mae"
+         ckpt_mode = "min"
+     else:
+         ckpt_monitor = "val_auc"
+         ckpt_mode = "max"
+
+     callbacks = [
+         LearningRateMonitor(logging_interval="step"),
+         ModelCheckpoint(monitor=ckpt_monitor, dirpath=ckpt_dir,
+                         save_last=False, mode=ckpt_mode, save_top_k=1,
+                         auto_insert_metric_name=True),
+     ]
+     if getattr(hparams, 'early_stopping', False):
+         early_stop_callback = EarlyStopping(
+             monitor=ckpt_monitor,
+             patience=getattr(hparams, 'early_stopping_patience', 10),
+             mode=ckpt_mode,
+             verbose=True,
+         )
+         callbacks.append(early_stop_callback)
+         print(f"[INFO] Early stopping enabled: monitor={ckpt_monitor}, patience={getattr(hparams, 'early_stopping_patience', 10)}")
+     logger_dir = os.path.join(CKPT_PATH, f"logs/{exp_name}")
+     os.makedirs(logger_dir, exist_ok=True)
+     wandb_logger = WandbLogger(
+         project=f"{exp_name}_sleepuni", save_dir=logger_dir, name=run_name)
+     trainer = Trainer(
+         max_steps=hparams.max_steps,
+         accelerator="gpu",
+         accumulate_grad_batches=hparams.accumulate_grad_batches,
+         deterministic=True,
+         devices=hparams.num_devices,
+         strategy="ddp_find_unused_parameters_true",
+         precision=hparams.precision,
+         callbacks=callbacks,
+         logger=wandb_logger
+     )
+
+     hparams.exp_log_dir = os.path.join(
+         CKPT_PATH, f"data/{run_name}/exp_logs")
+     train_edf_cols = MONITOR_TYPE_MAP.get(hparams.monitor_type, TRAIN_EDF_COLS_UNI_ENC)
+
+     if hparams.task_type == "regression":
+         event_cols = None
+         regression_targets = hparams.regression_targets
+         print(f"[INFO] Regression task with targets: {regression_targets}")
+     else:  # classification
+         event_cols = hparams.eval_label
+         regression_targets = None
+
+     regression_filter_config = None
+     if hparams.task_type == "regression" and "SPO2" in hparams.regression_targets:
+         if hparams.filter_spo2_min is not None or hparams.filter_spo2_max is not None:
+             spo2_filter = {}
+             if hparams.filter_spo2_min is not None:
+                 spo2_filter["min"] = hparams.filter_spo2_min
+             if hparams.filter_spo2_max is not None:
+                 spo2_filter["max"] = hparams.filter_spo2_max
+             regression_filter_config = {"SPO2_mean": spo2_filter}
+             print(f"[INFO] Will filter SPO2_mean with: {spo2_filter}")
+
+     datamodule = SleepDataModule(
+         is_pretrain=0,
+         data_pct=hparams.train_data_pct,
+         downstream_dataset_name=hparams.downstream_dataset_name,
+         csv_dir=SPLIT_DATA_FOLDER,
+         train_edf_cols=train_edf_cols,
+         event_cols=event_cols,
+         batch_size=hparams.batch_size,
+         num_workers=hparams.num_workers,
+         sample_rate=hparams.sample_rate,
+         window_size=30,
+         data_source=hparams.data_source,
+         include_datasets=hparams.include_datasets,
+         regression_targets=regression_targets,
+         regression_filter_config=regression_filter_config,
+         n_train_samples=getattr(hparams, 'n_train_samples', None),
+         val_batch_size=getattr(hparams, 'val_batch_size', None),
+         val_data_pct=getattr(hparams, 'val_data_pct', None),
+         random_seed=hparams.seed,
+     )
+     if hparams.task_type == "regression":
+         hparams.num_classes = len(hparams.regression_targets)  # output dim
+         hparams.target_names = hparams.regression_targets
+         print(f"[INFO] Regression targets: {hparams.target_names}, num_classes={hparams.num_classes}")
+     else:  # classification
+         train_dataset = datamodule.train_dataloader().dataset
+         if hasattr(train_dataset, 'dataset'):  # It's a Subset
+             hparams.num_classes = train_dataset.dataset.num_classes
+         else:
+             hparams.num_classes = train_dataset.num_classes
+         print(f"[INFO] Classification num_classes: {hparams.num_classes}")
+     hparams.training_steps_per_epoch = len(datamodule.train_dataloader()) // hparams.accumulate_grad_batches // hparams.num_devices
+
+     if hparams.max_steps > 0:
+         hparams.total_training_steps = hparams.max_steps
+     else:
+         hparams.total_training_steps = hparams.training_steps_per_epoch * hparams.max_epochs
+
+     print(f"Total training steps: {hparams.total_training_steps}")
+     print(f"Steps per epoch: {hparams.training_steps_per_epoch}")
+
+     class_distribution = datamodule.get_class_distribution()
+     if class_distribution is not None:
+         print(f"Class distribution: {class_distribution}")
+         hparams.class_distribution = class_distribution
+
+     # Load pretrained DINO model
+     pretrain_model = DINOCLSModel.load_from_checkpoint(hparams.ckpt_path)
+     pprint(vars(hparams))
+
+     hparams.epochs = hparams.max_epochs
+
+     def create_finetuner(backbones, hparams, train_edf_cols=None):
172
+ exclude_keys = {'train_edf_cols', 'regression_targets'}
173
+ hparams_dict = {k: v for k, v in vars(hparams).items() if k not in exclude_keys}
174
+
175
+ if hparams.task_type == "regression":
176
+ return SSLVitalSignsRegressor(backbones=backbones, **hparams_dict)
177
+ else:
178
+ return SSLFineTuner(backbones=backbones, **hparams_dict)
179
+
180
+ # Extract ViT backbone from DINO model
181
+ vit = pretrain_model.encoders["all"].backbone
182
+ hparams.in_features = vit.width
183
+ print(f"[INFO] Extracted ViT backbone for dino_ours, in_features={hparams.in_features}")
184
+ model = create_finetuner(backbones={"all": vit}, hparams=hparams, train_edf_cols=train_edf_cols)
185
+
186
+ trainer.fit(model, datamodule=datamodule)
187
+ trainer.test(model, datamodule=datamodule, ckpt_path="last")
188
+
189
+
190
+ if __name__ == '__main__':
191
+ parser = ArgumentParser(description="Fine-tune pretrained model for downstream tasks.")
192
+ parser.add_argument("--model_name", type=str, default="dino_ours")
193
+ parser.add_argument("--eval_label", type=str, default="Stage",
194
+ )
195
+ parser.add_argument("--downstream_dataset_name", type=str, default="mros",
196
+ )
197
+ parser.add_argument("--use_which_backbone", type=str, default="all",
198
+ )
199
+ parser.add_argument("--monitor_type", type=str, default="main",
200
+ choices=["main", "type3", "type4"],
201
+ help="Channel configuration: main (12ch), type3 (5ch), type4 (3ch)")
202
+ parser.add_argument("--seed", type=int, default=42)
203
+ parser.add_argument("--train_data_pct", type=float, default=1.)
204
+ parser.add_argument("--n_train_samples", type=int, default=None,
205
+ help="If set, use exactly this many training samples (overrides train_data_pct for few-shot)")
206
+ parser.add_argument("--data_source", type=str, default="auto",
207
+ choices=["auto", "pretrain", "downstream", "both"],
208
+ help="Which CSV source to use: auto (default), pretrain, downstream, or both")
209
+ parser.add_argument("--include_datasets", type=str, nargs="*", default=None,
210
+ help="Filter by dataset names, e.g., --include_datasets shhs mros")
211
+ parser.add_argument("--batch_size", type=int, default=800)
212
+ parser.add_argument("--val_batch_size", type=int, default=None,
213
+ help="Batch size for val/test (defaults to batch_size if not set, useful for few-shot)")
214
+ parser.add_argument("--val_data_pct", type=float, default=None,
215
+ help="Percentage of val data to use (0-1, useful for few-shot to speed up validation)")
216
+ parser.add_argument("--patch_size_time", type=int, default=64)
217
+ parser.add_argument("--patch_size_ch", type=int, default=4,
218
+ help="Channel patch size for 2D patchify (default: 4)")
219
+ parser.add_argument("--num_workers", type=int, default=32)
220
+ parser.add_argument("--num_devices", type=int, default=1)
221
+ parser.add_argument("--max_epochs", type=int, default=10)
222
+ parser.add_argument("--max_steps", type=int, default=2500)
223
+ parser.add_argument("--early_stopping", action="store_true",
224
+ help="Enable early stopping based on val metric (useful for few-shot)")
225
+ parser.add_argument("--early_stopping_patience", type=int, default=10,
226
+ help="Patience for early stopping (number of val checks without improvement)")
227
+ parser.add_argument("--accumulate_grad_batches", type=int, default=1)
228
+ parser.add_argument("--ckpt_path", type=str, default="")
229
+ parser.add_argument("--lr", type=float, default=1e-2)
230
+ parser.add_argument("--num_classes", type=int, default=2)
231
+ parser.add_argument("--in_features", type=int, default=256)
232
+ parser.add_argument("--loss_type", type=str, default="ce", choices=["ce", "focal", "balanced_softmax"],
233
+ help="Loss type: 'ce' (cross-entropy), 'focal' (Focal Loss), or 'balanced_softmax' (Balanced Softmax)")
234
+ parser.add_argument("--focal_gamma", type=float, default=1.0,
235
+ help="Gamma parameter for Focal Loss (focusing parameter)")
236
+ parser.add_argument("--focal_alpha", type=float, default=None,
237
+ help="Alpha parameter for Focal Loss (class weighting). If None, computed from class distribution.")
238
+ parser.add_argument("--final_lr", type=float, default=0,
239
+ help="Final learning rate for cosine annealing scheduler")
240
+ parser.add_argument("--use_mean_pool", action="store_true",
241
+ help="Use mean pooling of all patches instead of CLS token for feature extraction")
242
+ parser.add_argument("--task_type", type=str, default="classification",
243
+ choices=["classification", "regression"],
244
+ help="Task type: classification or regression")
245
+ parser.add_argument("--regression_targets", type=str, nargs="*", default=["HR", "SPO2"],
246
+ help="Regression targets, e.g., --regression_targets HR SPO2")
247
+ parser.add_argument("--filter_spo2_min", type=float, default=None,
248
+ help="Filter out SPO2 values below this threshold (e.g., 70). Only applies when SPO2 is a regression target.")
249
+ parser.add_argument("--filter_spo2_max", type=float, default=None,
250
+ help="Filter out SPO2 values above this threshold (e.g., 100). Only applies when SPO2 is a regression target.")
251
+ parser.add_argument("--finetune_backbone", action="store_true",
252
+ help="If set, finetune the entire backbone (full finetuning); otherwise linear probing only")
253
+ parser.add_argument("--precision", type=str, default="32-true",
254
+ choices=["32-true", "16-mixed", "bf16-mixed"],
255
+ help="Training precision: 32-true (full), 16-mixed (FP16), bf16-mixed (BF16)")
256
+ parser.add_argument("--sample_rate", type=int, default=64,
257
+ help="Input sample rate in Hz (default: 64). Use 32 for half resolution.")
258
+ hparams = parser.parse_args()
259
+
260
+ seed_everything(hparams.seed)
261
+ main(hparams)
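The SPO2 filter-config block in `main_finetune.py` only includes the bounds that were actually supplied on the command line, and stays `None` when neither bound is given. A minimal standalone sketch of that logic (`build_spo2_filter` is a hypothetical helper name for illustration, not a function in this repo):

```python
def build_spo2_filter(filter_min=None, filter_max=None):
    """Mirror the script's filter construction: set only the bounds
    that were provided; return None when neither bound is given."""
    if filter_min is None and filter_max is None:
        return None
    spo2_filter = {}
    if filter_min is not None:
        spo2_filter["min"] = filter_min
    if filter_max is not None:
        spo2_filter["max"] = filter_max
    # Keyed by the target column name, as in regression_filter_config
    return {"SPO2_mean": spo2_filter}

print(build_spo2_filter(70, 100))  # {'SPO2_mean': {'min': 70, 'max': 100}}
print(build_spo2_filter())         # None
```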
main_pipelines/main_pretrain.py ADDED
@@ -0,0 +1,196 @@
+ from pprint import pprint
+ import os
+ from argparse import ArgumentParser, Namespace
+ import datetime
+ from dateutil import tz
+ import random
+ import numpy as np
+ import torch
+ import warnings
+ from datetime import timedelta
+ from pytorch_lightning import seed_everything, Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback
+ from pytorch_lightning.loggers import WandbLogger
+ from pytorch_lightning.strategies import DDPStrategy
+ 
+ from osf.datasets.pretrain_datamodule import SleepDataModule
+ from osf.models.dino_model_cls import DINOCLSModel
+ from config import *
+ from train_config import *
+ 
+ warnings.filterwarnings("ignore")
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = True
+ torch.set_float32_matmul_precision('high')
+ 
+ torch._dynamo.config.cache_size_limit = 128
+ torch._dynamo.config.optimize_ddp = False
+ 
+ 
+ class DenseStepCheckpoint(Callback):
+     """Save checkpoints at specific training steps."""
+ 
+     def __init__(self, dirpath: str, save_steps: list = None):
+         super().__init__()
+         self.dirpath = dirpath
+         self.save_steps = set(save_steps) if save_steps else {1, 10, 100, 1000, 10000, 100000}
+         self.saved_steps = set()
+ 
+     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+         global_step = trainer.global_step
+         if global_step in self.save_steps and global_step not in self.saved_steps:
+             ckpt_path = os.path.join(self.dirpath, f"step={global_step}.ckpt")
+             trainer.save_checkpoint(ckpt_path)
+             self.saved_steps.add(global_step)
+             if trainer.is_global_zero:
+                 print(f"[DenseStepCheckpoint] Saved checkpoint at step {global_step}: {ckpt_path}")
+ 
+ 
+ def param_stats(model: torch.nn.Module, verbose: bool = False):
+     total = sum(p.numel() for p in model.parameters())
+     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     if verbose:
+         print(f"{'Name':40s} {'Shape':20s} {'#Params':>10s} {'Train?':>6s}")
+         print("-" * 80)
+         for name, p in model.named_parameters():
+             print(f"{name:40s} {str(list(p.shape)):20s} {p.numel():10d} {str(p.requires_grad):>6s}")
+         print("-" * 80)
+     print(f"Total parameters: {total / 1e6:.3f} M ({total})")
+     print(f"  Trainable params: {trainable / 1e6:.3f} M ({trainable})")
+     print(f"  Frozen params:    {(total - trainable) / 1e6:.3f} M ({total - trainable})")
+ 
+ 
+ def main(hparams: Namespace):
+     now = datetime.datetime.now(tz.tzlocal())
+     extension = now.strftime("%Y_%m_%d_%H_%M_%S")
+     extension = f"final_sleep_unimodal_{hparams.model_name}_{hparams.psg_encoder_name}_bz{hparams.batch_size}_{extension}"
+     ckpt_dir = os.path.join(
+         CKPT_PATH, f"logs/sleepuni/ckpts/{extension}")
+     os.makedirs(ckpt_dir, exist_ok=True)
+     if hparams.model_name in MODEL_LIST:
+         callbacks = [
+             LearningRateMonitor(logging_interval="step"),
+             ModelCheckpoint(monitor="val/loss", dirpath=ckpt_dir,
+                             save_last=True, every_n_epochs=2, mode="min", save_top_k=-1,
+                             save_on_train_epoch_end=False, auto_insert_metric_name=True),
+         ]
+         if hparams.dense_ckpt:
+             dense_ckpt_dir = os.path.join(ckpt_dir, "dense_steps")
+             os.makedirs(dense_ckpt_dir, exist_ok=True)
+             callbacks.append(DenseStepCheckpoint(
+                 dirpath=dense_ckpt_dir,
+                 save_steps=hparams.dense_ckpt_steps,
+             ))
+     else:
+         raise NotImplementedError
+     logger_dir = os.path.join(CKPT_PATH, "logs/sleepuni")
+     os.makedirs(logger_dir, exist_ok=True)
+     print("wandb logger dir: ", logger_dir)
+     wandb_logger = WandbLogger(
+         project=hparams.wandb_proj_name + f'final_{hparams.model_name}_{hparams.psg_encoder_name}_bz{hparams.batch_size}',
+         save_dir=logger_dir, name=extension)
+ 
+     strategy = DDPStrategy(
+         find_unused_parameters=True,
+         static_graph=False,
+         timeout=timedelta(minutes=15),
+     )
+ 
+     trainer = Trainer(
+         max_epochs=hparams.max_epochs,
+         accelerator="gpu",
+         accumulate_grad_batches=hparams.accumulate_grad_batches,
+         devices=hparams.num_devices,
+         num_nodes=hparams.num_nodes,
+         precision=hparams.precision,
+         gradient_clip_val=3.0,
+         gradient_clip_algorithm="norm",
+         strategy=strategy,
+         callbacks=callbacks,
+         logger=wandb_logger,
+         log_every_n_steps=10,
+     )
+ 
+     hparams.exp_log_dir = os.path.join(
+         CKPT_PATH, f"data/{extension}/exp_logs")
+     train_edf_cols = MONITOR_TYPE_MAP.get(hparams.monitor_type, TRAIN_EDF_COLS_UNI_ENC)
+     hparams.num_leads = len(train_edf_cols)
+ 
+     dm = SleepDataModule(
+         is_pretrain=1,
+         csv_dir=SPLIT_DATA_FOLDER,
+         train_edf_cols=train_edf_cols,
+         batch_size=hparams.batch_size,
+         num_workers=hparams.num_workers,
+         data_pct=hparams.train_data_pct,
+         window_size=30,
+         sample_rate=64,
+         val_dataset_list=hparams.val_dataset_list,
+         data_source=hparams.data_source,
+         include_datasets=hparams.include_datasets,
+     )
+ 
+     hparams.simclr_augmentation = AUGMENTATION_MAP.get(hparams.model_name, "none")
+ 
+     # Create DINO model
+     model = DINOCLSModel(**vars(hparams))
+     model.training_steps_per_epoch = len(dm.train_dataloader()) // hparams.accumulate_grad_batches // hparams.num_devices
+     model.teacher_temp_warmup_iters = model.training_steps_per_epoch * 0.1 * hparams.max_epochs
+     print(f"[INFO] DINO teacher warmup steps: {model.teacher_temp_warmup_iters}")
+     pprint(vars(hparams))
+ 
+     if hparams.ckpt_path:
+         trainer.fit(model, datamodule=dm, ckpt_path=hparams.ckpt_path)
+     else:
+         trainer.fit(model, datamodule=dm)
+ 
+ 
+ if __name__ == '__main__':
+     parser = ArgumentParser(description="Pretrain a DINO model on sleep PSG data.")
+     parser.add_argument("--model_name", type=str, default="dino_ours",
+                         choices=MODEL_LIST)
+     parser.add_argument("--psg_encoder_name", type=str, default="vit_base")
+     parser.add_argument("--val_dataset_list", default=PRETRAIN_VAL_DATASET_LIST)
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--train_data_pct", type=float, default=1.)
+     parser.add_argument("--data_source", type=str, default="auto",
+                         choices=["auto", "pretrain", "downstream", "both"])
+     parser.add_argument("--include_datasets", type=str, nargs="*", default=None)
+     parser.add_argument("--monitor_type", type=str, default="main",
+                         choices=["main", "type3", "type4"],
+                         help="Channel configuration: main (12ch), type3 (5ch), type4 (3ch)")
+     parser.add_argument("--batch_size", type=int, default=32)
+     parser.add_argument("--patch_size_time", type=int, default=4)
+     parser.add_argument("--patch_size_ch", type=int, default=4)
+     parser.add_argument("--use_2d_pos_embed", type=bool, default=True)
+     parser.add_argument("--sample_rate", type=int, default=64)
+     parser.add_argument("--num_workers", type=int, default=64)
+     parser.add_argument("--num_devices", type=int, default=4)
+     parser.add_argument("--num_nodes", type=int, default=1)
+     parser.add_argument("--max_epochs", type=int, default=30)
+     parser.add_argument("--accumulate_grad_batches", type=int, default=1)
+     parser.add_argument("--precision", type=str, default="32-true")
+     parser.add_argument("--lr", type=float, default=1e-4)
+     parser.add_argument("--text_encoder_name", type=str, default="google/flan-t5-base")
+     parser.add_argument("--lead_wise", type=int, default=0)
+     parser.add_argument("--use_lead_embedding", type=int, default=1)
+     # DINO-specific args
+     parser.add_argument("--koleo_lambda", type=float, default=0.0)
+     parser.add_argument("--ibot_lambda", type=float, default=0.0)
+     parser.add_argument("--dino_out_dim", type=int, default=2048)
+     parser.add_argument("--dino_patch_out_dim", type=int, default=2048)
+     parser.add_argument("--dino_hidden_dim", type=int, default=2048)
+     parser.add_argument("--dino_bottleneck_dim", type=int, default=256)
+     parser.add_argument("--wandb_proj_name", type=str, default="sleepuni")
+     parser.add_argument("--ckpt_path", type=str, default=None)
+     parser.add_argument("--dense_ckpt", action="store_true")
+     parser.add_argument("--dense_ckpt_steps", type=int, nargs="+",
+                         default=[10, 100, 200, 400, 500, 800, 1000, 1600, 2500, 3200, 6400, 10000, 12500, 12800, 25600, 51200, 62500, 100000])
+ 
+     hparams = parser.parse_args()
+ 
+     seed_everything(hparams.seed)
+     main(hparams)
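`main_pretrain.py` derives the optimizer steps per epoch from the dataloader length (each optimizer step consumes `accumulate_grad_batches` batches, and DDP splits batches across `num_devices`), then sets the DINO teacher-temperature warmup to 10% of the total steps. A standalone sketch of the same integer arithmetic, with illustrative numbers (`training_steps_per_epoch` here is a hypothetical helper, not the repo attribute):

```python
def training_steps_per_epoch(num_batches, accumulate_grad_batches, num_devices):
    # Same floor-division chain as the script:
    # len(dataloader) // accumulate_grad_batches // num_devices
    return num_batches // accumulate_grad_batches // num_devices

steps = training_steps_per_epoch(num_batches=10_000, accumulate_grad_batches=2, num_devices=4)
warmup = steps * 0.1 * 30  # teacher_temp_warmup_iters: 10% of steps over 30 epochs
print(steps, warmup)  # 1250 3750.0
```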
osf/__init__.py ADDED
File without changes
osf/backbone/__init__.py ADDED
File without changes
osf/backbone/pos_embed.py ADDED
@@ -0,0 +1,76 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ 
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ import numpy as np
+ 
+ import torch
+ 
+ 
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)
+     grid = np.stack(grid, axis=0)
+ 
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+ 
+ 
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+ 
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+ 
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+ 
+ 
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=float)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega
+ 
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2)
+ 
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+ 
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+ 
+ 
+ def interpolate_pos_embed(model, checkpoint_model):
+     if 'pos_embed' in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model['pos_embed']
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+         new_size = int(num_patches ** 0.5)
+         if orig_size != new_size:
+             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model['pos_embed'] = new_pos_embed
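`get_1d_sincos_pos_embed_from_grid` pairs each position with a geometric ladder of frequencies `1 / 10000^(i/(D/2))` and concatenates the sines and cosines. A numpy-free sketch of the same formula (`sincos_1d` is a hypothetical name for illustration):

```python
import math

def sincos_1d(embed_dim, positions):
    """Pure-Python version of the 1D sin-cos positional embedding:
    first half of each row is sin(p * omega), second half is cos(p * omega)."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]  # geometric frequencies
    rows = []
    for p in positions:
        angles = [p * w for w in omega]
        rows.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return rows  # shape (M, embed_dim)

emb = sincos_1d(8, [0, 1, 2])
print(len(emb), len(emb[0]))  # 3 8
# Position 0: all sines are 0.0, all cosines are 1.0
print(emb[0])  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```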
osf/backbone/vit1d.py ADDED
@@ -0,0 +1,209 @@
+ """
+ 1D Vision Transformer for time-series signals.
+ 
+ Patchify modes:
+ - lead_wise=0: 1D patchify (all channels in one patch), no lead embedding
+ - lead_wise=1: 2D patchify (channel groups), with lead embedding by default
+ """
+ 
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ 
+ 
+ class DropPath(nn.Module):
+     def __init__(self, drop_prob: float, scale_by_keep: bool = True):
+         super().__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+ 
+     def forward(self, x):
+         if self.drop_prob <= 0. or not self.training:
+             return x
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+         random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+         if keep_prob > 0.0 and self.scale_by_keep:
+             random_tensor.div_(keep_prob)
+         return x * random_tensor
+ 
+ 
+ class PreNorm(nn.Module):
+     def __init__(self, dim: int, fn: nn.Module):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.fn = fn
+ 
+     def forward(self, x, **kwargs):
+         return self.fn(self.norm(x), **kwargs)
+ 
+ 
+ class FeedForward(nn.Module):
+     def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, drop_out_rate=0.):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(input_dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(drop_out_rate),
+             nn.Linear(hidden_dim, output_dim),
+             nn.Dropout(drop_out_rate)
+         )
+ 
+     def forward(self, x):
+         return self.net(x)
+ 
+ 
+ class Attention(nn.Module):
+     def __init__(self, input_dim: int, output_dim: int, heads: int = 8, dim_head: int = 64,
+                  qkv_bias: bool = True, drop_out_rate: float = 0., attn_drop_out_rate: float = 0.):
+         super().__init__()
+         inner_dim = dim_head * heads
+         project_out = not (heads == 1 and dim_head == input_dim)
+ 
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+         self.attend = nn.Softmax(dim=-1)
+         self.dropout = nn.Dropout(attn_drop_out_rate)
+         self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=qkv_bias)
+ 
+         if project_out:
+             self.to_out = nn.Sequential(nn.Linear(inner_dim, output_dim), nn.Dropout(drop_out_rate))
+         else:
+             self.to_out = nn.Identity()
+ 
+     def forward(self, x):
+         qkv = self.to_qkv(x).chunk(3, dim=-1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
+         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+         attn = self.attend(dots)
+         attn = self.dropout(attn)
+         out = torch.matmul(attn, v)
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+ 
+ 
+ class TransformerBlock(nn.Module):
+     def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, heads: int = 8,
+                  dim_head: int = 32, qkv_bias: bool = True, drop_out_rate: float = 0.,
+                  attn_drop_out_rate: float = 0., drop_path_rate: float = 0.):
+         super().__init__()
+         attn = Attention(input_dim, output_dim, heads, dim_head, qkv_bias, drop_out_rate, attn_drop_out_rate)
+         self.attn = PreNorm(input_dim, attn)
+         self.droppath1 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+ 
+         ff = FeedForward(output_dim, output_dim, hidden_dim, drop_out_rate)
+         self.ff = PreNorm(output_dim, ff)
+         self.droppath2 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+ 
+     def forward(self, x):
+         x = self.droppath1(self.attn(x)) + x
+         x = self.droppath2(self.ff(x)) + x
+         return x
+ 
+ 
+ class ViT(nn.Module):
+     def __init__(self,
+                  num_leads: int,
+                  seq_len: int,
+                  patch_size: int,
+                  lead_wise=0,
+                  patch_size_ch=4,
+                  use_lead_embedding: bool = True,
+                  width: int = 768,
+                  depth: int = 12,
+                  mlp_dim: int = 3072,
+                  heads: int = 12,
+                  dim_head: int = 64,
+                  qkv_bias: bool = True,
+                  drop_out_rate: float = 0.,
+                  attn_drop_out_rate: float = 0.,
+                  drop_path_rate: float = 0.,
+                  **kwargs):
+         super().__init__()
+         assert seq_len % patch_size == 0
+         num_patches = seq_len // patch_size
+         self.lead_wise = lead_wise
+         self.use_lead_embedding = use_lead_embedding
+ 
+         if lead_wise == 0:
+             self.to_patch_embedding = nn.Conv1d(num_leads, width, kernel_size=patch_size, stride=patch_size, bias=False)
+             self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, width))
+         else:
+             self.to_patch_embedding = nn.Conv2d(1, width, kernel_size=(patch_size_ch, patch_size),
+                                                 stride=(patch_size_ch, patch_size), bias=False)
+             self.pos_embedding = nn.Parameter(torch.randn(1, num_patches * num_leads // patch_size_ch, width))
+             if use_lead_embedding:
+                 self.lead_emb = nn.Embedding(num_leads // patch_size_ch, width)
+             else:
+                 self.lead_emb = None
+ 
+         self.dropout = nn.Dropout(drop_out_rate)
+         self.depth = depth
+         self.width = width
+ 
+         drop_path_rate_list = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+         for i in range(depth):
+             block = TransformerBlock(width, width, mlp_dim, heads, dim_head, qkv_bias,
+                                      drop_out_rate, attn_drop_out_rate, drop_path_rate_list[i])
+             self.add_module(f'block{i}', block)
+ 
+         self.norm = nn.LayerNorm(width)
+         self.head = nn.Identity()
+ 
+     def _patchify_and_embed(self, series: torch.Tensor) -> torch.Tensor:
+         """Patchify input and add positional/lead embeddings. [B, C, T] -> [B, N, D]"""
+         if self.lead_wise == 0:
+             x = self.to_patch_embedding(series)  # [B, D, N]
+             x = rearrange(x, 'b c n -> b n c')  # [B, N, D]
+             x = x + self.pos_embedding[:, :x.size(1), :].to(x.device)
+         else:
+             x = self.to_patch_embedding(series.unsqueeze(1))  # [B, D, Lr, Nt]
+             Lr, Nt = x.shape[-2], x.shape[-1]
+             x = rearrange(x, 'b c lr nt -> b (lr nt) c')  # [B, N, D]
+             x = x + self.pos_embedding[:, :x.size(1), :].to(x.device)
+             if self.use_lead_embedding and self.lead_emb is not None:
+                 row_ids = torch.arange(Lr, device=x.device).repeat_interleave(Nt)
+                 x = x + self.lead_emb(row_ids)[None, :, :]
+         return x
+ 
+     def forward_encoding(self, series: torch.Tensor) -> torch.Tensor:
+         """Encode series. Returns [B, D] (mean pooled)."""
+         x = self._patchify_and_embed(series)
+         x = self.dropout(x)
+         for i in range(self.depth):
+             x = getattr(self, f'block{i}')(x)
+         x = x.mean(dim=1)
+         return self.norm(x)
+ 
+     def forward(self, series):
+         x = self.forward_encoding(series)
+         return self.head(x)
+ 
+     def reset_head(self, num_classes=1):
+         del self.head
+         self.head = nn.Linear(self.width, num_classes)
+ 
+ 
+ def vit_nano(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
+     return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len, patch_size=patch_size,
+                width=128, depth=6, heads=4, mlp_dim=512, **kwargs)
+ 
+ 
+ def vit_tiny(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
+     return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len, patch_size=patch_size,
+                width=192, depth=12, heads=3, mlp_dim=768, **kwargs)
+ 
+ 
+ def vit_small(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
+     return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len, patch_size=patch_size,
+                width=384, depth=12, heads=6, mlp_dim=1536, **kwargs)
+ 
+ 
+ def vit_middle(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
+     return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len, patch_size=patch_size,
+                width=512, depth=12, heads=8, mlp_dim=2048, **kwargs)
+ 
+ 
+ def vit_base(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
+     return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len, patch_size=patch_size,
+                width=768, depth=12, heads=12, mlp_dim=3072, **kwargs)
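The two patchify branches in `ViT.__init__` produce different token counts: the `lead_wise=0` Conv1d (kernel = stride = patch_size) yields `seq_len // patch_size` tokens, while the `lead_wise=1` Conv2d multiplies that by `num_leads // patch_size_ch` channel groups. A sketch of that arithmetic (`num_tokens` is a hypothetical helper; the 12-channel, 30 s at 64 Hz numbers are illustrative):

```python
def num_tokens(num_leads, seq_len, patch_size, lead_wise=0, patch_size_ch=4):
    """Token count produced by the two patchify modes in vit1d.py,
    mirroring the kernel=stride Conv1d/Conv2d arithmetic."""
    assert seq_len % patch_size == 0
    if lead_wise == 0:
        # 1D patchify: all channels folded into each temporal patch
        return seq_len // patch_size
    # 2D patchify: channel groups of size patch_size_ch x temporal patches
    assert num_leads % patch_size_ch == 0
    return (num_leads // patch_size_ch) * (seq_len // patch_size)

# 12 channels, 30 s at 64 Hz = 1920 samples, temporal patch 64:
print(num_tokens(12, 1920, 64, lead_wise=0))                   # 30
print(num_tokens(12, 1920, 64, lead_wise=1, patch_size_ch=4))  # 90
```

The 2D mode trades a 3x longer token sequence for channel-group resolution, which is what the optional `lead_emb` then distinguishes.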
osf/backbone/vit1d_cls.py ADDED
@@ -0,0 +1,363 @@
1
+ """
2
+ 1D Vision Transformer with CLS token support.
3
+
4
+ Patchify modes:
5
+ - lead_wise=0: 1D patchify (all channels in one patch)
6
+ - lead_wise=1: 2D patchify (channel groups)
7
+
8
+ Note: lead_emb is DEPRECATED and not used in data flow. It is kept only for
9
+ checkpoint compatibility. Do NOT add lead_emb usage without careful consideration.
10
+ """
11
+ import torch
12
+ import torch.nn as nn
13
+ from einops import rearrange
14
+
15
+
16
+ class DropPath(nn.Module):
17
+ '''
18
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
19
+ '''
20
+ def __init__(self, drop_prob: float, scale_by_keep: bool = True):
21
+ super(DropPath, self).__init__()
22
+ self.drop_prob = drop_prob
23
+ self.scale_by_keep = scale_by_keep
24
+
25
+ def forward(self, x):
26
+ if self.drop_prob <= 0. or not self.training:
27
+ return x
28
+ keep_prob = 1 - self.drop_prob
29
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
30
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
31
+ if keep_prob > 0.0 and self.scale_by_keep:
32
+ random_tensor.div_(keep_prob)
33
+ return x * random_tensor
34
+
35
+
class PreNorm(nn.Module):
    def __init__(self,
                 dim: int,
                 fn: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """
    MLP Module with GELU activation fn + dropout.
    """
    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: int,
                 drop_out_rate=0.):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                 nn.GELU(),
                                 nn.Dropout(drop_out_rate),
                                 nn.Linear(hidden_dim, output_dim),
                                 nn.Dropout(drop_out_rate))

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 heads: int = 8,
                 dim_head: int = 64,
                 qkv_bias: bool = True,
                 drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == input_dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attn_drop_out_rate)
        self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=qkv_bias)

        if project_out:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, output_dim),
                                        nn.Dropout(drop_out_rate))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

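The head bookkeeping in `Attention` reduces to simple arithmetic; the values below are illustrative, matching the class defaults:

```python
heads, dim_head, input_dim = 8, 64, 512
inner_dim = heads * dim_head   # width of the concatenated per-head outputs
scale = dim_head ** -0.5       # 1/sqrt(d) softmax temperature
qkv_width = 3 * inner_dim      # to_qkv projects input_dim -> q, k, v in one linear
# With 8 heads of 64 dims, a (B, n, 512) input is reshaped into
# per-head (B, 8, n, 64) tensors before the scaled dot product.
```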
class TransformerBlock(nn.Module):
    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: int,
                 heads: int = 8,
                 dim_head: int = 32,
                 qkv_bias: bool = True,
                 drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0.,
                 drop_path_rate: float = 0.):
        super().__init__()
        attn = Attention(input_dim=input_dim,
                         output_dim=output_dim,
                         heads=heads,
                         dim_head=dim_head,
                         qkv_bias=qkv_bias,
                         drop_out_rate=drop_out_rate,
                         attn_drop_out_rate=attn_drop_out_rate)
        self.attn = PreNorm(dim=input_dim, fn=attn)
        self.droppath1 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

        ff = FeedForward(input_dim=output_dim,
                         output_dim=output_dim,
                         hidden_dim=hidden_dim,
                         drop_out_rate=drop_out_rate)
        self.ff = PreNorm(dim=output_dim, fn=ff)
        self.droppath2 = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        x = self.droppath1(self.attn(x)) + x
        x = self.droppath2(self.ff(x)) + x
        return x


class ViT(nn.Module):

    def __init__(self,
                 num_leads: int,
                 seq_len: int,
                 patch_size: int,
                 lead_wise: int = 0,
                 patch_size_ch: int = 4,
                 width: int = 768,
                 depth: int = 12,
                 mlp_dim: int = 3072,
                 heads: int = 12,
                 dim_head: int = 64,
                 qkv_bias: bool = True,
                 drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 **kwargs):
        super().__init__()
        assert seq_len % patch_size == 0
        num_patches_time = seq_len // patch_size

        self.lead_wise = lead_wise
        self.width = width
        self.depth = depth

        if lead_wise == 0:
            self.to_patch_embedding = nn.Conv1d(num_leads, width, kernel_size=patch_size,
                                                stride=patch_size, bias=False)
            N_max = num_patches_time
            self.lead_emb = None
        else:
            self.to_patch_embedding = nn.Conv2d(1, width,
                                                kernel_size=(patch_size_ch, patch_size),
                                                stride=(patch_size_ch, patch_size),
                                                bias=False)
            Lr = num_leads // patch_size_ch
            N_max = Lr * num_patches_time
            self.lead_emb = nn.Embedding(Lr, width)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, width))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.pos_embedding = nn.Parameter(torch.zeros(1, N_max + 1, width))
        nn.init.trunc_normal_(self.pos_embedding, std=0.02)

        self.dropout = nn.Dropout(drop_out_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        for i in range(depth):
            block = TransformerBlock(input_dim=width, output_dim=width,
                                     hidden_dim=mlp_dim, heads=heads, dim_head=dim_head,
                                     qkv_bias=qkv_bias, drop_out_rate=drop_out_rate,
                                     attn_drop_out_rate=attn_drop_out_rate,
                                     drop_path_rate=dpr[i])
            self.add_module(f'block{i}', block)

        self.norm = nn.LayerNorm(width)
        self.head = nn.Identity()

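The `torch.linspace(0, drop_path_rate, depth)` schedule gives deeper blocks a higher drop-path probability. A dependency-free sketch of the same ramp (`drop_path_schedule` is a hypothetical helper for illustration):

```python
def drop_path_schedule(drop_path_rate, depth):
    # Linearly ramp from 0 at the first block to drop_path_rate at the last,
    # matching torch.linspace(0, drop_path_rate, depth).
    if depth == 1:
        return [0.0]
    step = drop_path_rate / (depth - 1)
    return [i * step for i in range(depth)]

dpr = drop_path_schedule(0.1, 12)
# dpr[0] == 0.0, dpr[-1] == 0.1, non-decreasing in between
```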
    def to_tokens_2d(self, series: torch.Tensor,
                     patch_size_ch: int | None = None,
                     patch_size_time: int | None = None):
        """Patchify only (no pos embedding). Returns (tokens, meta)."""
        B, L, T = series.shape

        if self.lead_wise == 0:
            x = self.to_patch_embedding(series)       # [B,C,Nt]
            Nt = x.shape[-1]
            x = rearrange(x, 'b c n -> b n c')        # [B,Nt,C]
            meta = dict(lead_wise=0, L=L, Nt=Nt, pz_ch=1)
            return x, meta

        # lead_wise == 1
        if patch_size_ch is None or patch_size_time is None:
            kch, kt = self.to_patch_embedding.kernel_size
            patch_size_ch = patch_size_ch or kch
            patch_size_time = patch_size_time or kt
        assert L % patch_size_ch == 0 and T % patch_size_time == 0

        x = series.unsqueeze(1)                       # [B,1,L,T]
        x = self.to_patch_embedding(x)                # [B,C,Lr,Nt]
        Lr, Nt = x.shape[-2], x.shape[-1]
        x = rearrange(x, 'b c lr nt -> b (lr nt) c')  # [B,Lr*Nt,C]
        meta = dict(lead_wise=1, L=L, Nt=Nt, pz_ch=patch_size_ch)
        return x, meta

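Assuming illustrative values close to the factory defaults, the token counts for the two patchify modes work out as:

```python
num_leads, seq_len, patch_size, patch_size_ch = 12, 5000, 50, 4
num_patches_time = seq_len // patch_size   # time patches per record

# lead_wise == 0: Conv1d mixes all leads into one token per time patch
n_tokens_joint = num_patches_time

# lead_wise == 1: the lead axis is also patchified into Lr groups
Lr = num_leads // patch_size_ch
n_tokens_leadwise = Lr * num_patches_time
# pos_embedding is sized for N_max + 1 positions (the +1 is the CLS token)
```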
    def forward_encoding(self, series: torch.Tensor,
                         return_sequence: bool = False):
        """Encode with CLS token. Returns (cls, patches) or full sequence if return_sequence=True."""
        tokens, meta = self.to_tokens_2d(series)
        B = tokens.size(0)
        cls_tok = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tok, tokens], dim=1)  # [B,N+1,C]

        pe = self.pos_embedding[:, :x.size(1), :].to(x.device)
        x = x + pe

        x = self._run_blocks(x)
        if return_sequence:
            return x
        cls, patches = x[:, 0], x[:, 1:]
        return cls, patches

    def _run_blocks(self, x: torch.Tensor):
        x = self.dropout(x)
        for i in range(self.depth):
            x = getattr(self, f'block{i}')(x)
        x = self.norm(x)
        return self.head(x)

    def forward(self, series: torch.Tensor):
        cls, _ = self.forward_encoding(series, return_sequence=False)
        return cls

    def forward_avg_pool(self, series: torch.Tensor):
        """Returns avg-pooled patch embeddings. series: [B,C,T] -> [B,D]"""
        _, patches = self.forward_encoding(series, return_sequence=False)  # [B,N,D]
        return patches.mean(dim=1)  # [B,D]

    def reset_head(self, num_classes=1):
        del self.head
        self.head = nn.Linear(self.width, num_classes)


def vit_nano(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    model_args = dict(num_leads=num_leads,
                      num_classes=num_classes,
                      seq_len=seq_len,
                      patch_size=patch_size,
                      width=128,
                      depth=6,
                      heads=4,
                      mlp_dim=512,
                      **kwargs)
    return ViT(**model_args)


def vit_tiny(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    model_args = dict(num_leads=num_leads,
                      num_classes=num_classes,
                      seq_len=seq_len,
                      patch_size=patch_size,
                      width=192,
                      depth=12,
                      heads=3,
                      mlp_dim=768,
                      **kwargs)
    return ViT(**model_args)


def vit_small(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    model_args = dict(num_leads=num_leads,
                      num_classes=num_classes,
                      seq_len=seq_len,
                      patch_size=patch_size,
                      width=384,
                      depth=12,
                      heads=6,
                      mlp_dim=1536,
                      **kwargs)
    return ViT(**model_args)


def vit_middle(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    model_args = dict(num_leads=num_leads,
                      num_classes=num_classes,
                      seq_len=seq_len,
                      patch_size=patch_size,
                      width=512,
                      depth=12,
                      heads=8,
                      mlp_dim=2048,
                      **kwargs)
    return ViT(**model_args)


def vit_base(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    model_args = dict(num_leads=num_leads,
                      num_classes=num_classes,
                      seq_len=seq_len,
                      patch_size=patch_size,
                      width=768,
                      depth=12,
                      heads=12,
                      mlp_dim=3072,
                      **kwargs)
    return ViT(**model_args)


def vit_large(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    return ViT(
        num_leads=num_leads,
        num_classes=num_classes,
        seq_len=seq_len,
        patch_size=patch_size,
        width=1024,
        depth=24,
        heads=16,
        mlp_dim=4096,
        **kwargs
    )


def vit_xl(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    return ViT(
        num_leads=num_leads,
        num_classes=num_classes,
        seq_len=seq_len,
        patch_size=patch_size,
        width=1536,
        depth=24,
        heads=24,
        mlp_dim=6144,
        **kwargs
    )
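Every preset above uses a 4x MLP expansion. Collecting the factory configs (values copied from the functions above) makes that invariant easy to check:

```python
vit_presets = {
    "nano":   dict(width=128,  depth=6,  heads=4,  mlp_dim=512),
    "tiny":   dict(width=192,  depth=12, heads=3,  mlp_dim=768),
    "small":  dict(width=384,  depth=12, heads=6,  mlp_dim=1536),
    "middle": dict(width=512,  depth=12, heads=8,  mlp_dim=2048),
    "base":   dict(width=768,  depth=12, heads=12, mlp_dim=3072),
    "large":  dict(width=1024, depth=24, heads=16, mlp_dim=4096),
    "xl":     dict(width=1536, depth=24, heads=24, mlp_dim=6144),
}
# mlp_dim / width == 4 for every preset
mlp_ratios = {name: cfg["mlp_dim"] / cfg["width"] for name, cfg in vit_presets.items()}
```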
osf/datasets/__init__.py ADDED
File without changes
osf/datasets/augmentations.py ADDED
@@ -0,0 +1,81 @@
"""
Data augmentations for SSL pretraining (SimCLR, DINO).
"""
import torch
import torch.nn.functional as F
from typing import Tuple


@torch.no_grad()
def random_time_crop(
    x: torch.Tensor,
    ratio: Tuple[float, float] | float = (0.6, 0.9),
    *,
    resize_back: bool = True,
    align_to: int | None = 40
) -> torch.Tensor:
    """
    Randomly crop a contiguous sub-sequence per sample, optionally resize back to original T.

    Args:
        x: (B, C, T)
        ratio: crop length ratio in [low, high] or a float
        resize_back: if True, linearly interpolate the cropped view back to length T
        align_to: if not None, crop length is rounded to a multiple of align_to (>= align_to)
    """
    assert x.dim() == 3, f"expected (B,C,T), got {tuple(x.shape)}"
    B, C, T = x.shape
    dev = x.device

    def _sample_L() -> int:
        if isinstance(ratio, (tuple, list)):
            a, b = float(ratio[0]), float(ratio[1])
            r = torch.empty((), device=dev).uniform_(a, b).item()
        else:
            r = float(ratio)
        L = max(2, int(round(T * r)))
        if align_to and align_to > 1:
            L = max(align_to, int(round(L / align_to)) * align_to)
        return min(L, T)

    Ls = [_sample_L() for _ in range(B)]
    outs = []
    for b in range(B):
        L = Ls[b]
        max_start = max(0, T - L)
        s = int(torch.randint(0, max_start + 1, (1,), device=dev).item())
        v = x[b, :, s:s+L]  # (C, L)
        if resize_back and v.shape[-1] != T:
            v = F.interpolate(v[None], size=T, mode="linear", align_corners=False)[0]
        outs.append(v)
    return torch.stack(outs, dim=0)

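The `_sample_L` closure snaps the crop length to a patch-aligned multiple. The same rounding can be sketched in isolation (`crop_length` is a hypothetical standalone helper):

```python
def crop_length(T, r, align_to=40):
    # Target length: at least 2 samples of the T-sample sequence.
    L = max(2, int(round(T * r)))
    # Snap to a multiple of align_to so crops stay patch-aligned.
    if align_to and align_to > 1:
        L = max(align_to, int(round(L / align_to)) * align_to)
    return min(L, T)

crop_length(3840, 0.75)  # 2880, a multiple of 40
```

Note that very small ratios still return at least one aligned patch (`align_to` samples), never a degenerate crop.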
@torch.no_grad()
def channel_dropout(
    x: torch.Tensor,
    drop_prob: float = 0.2,
    min_keep: int = 1
) -> torch.Tensor:
    """
    Drop entire channels to zero with probability drop_prob (per sample, per channel).
    Ensures at least `min_keep` channels remain active in each sample.

    Args:
        x: (B, C, T)
        drop_prob: probability to drop each channel
        min_keep: minimum number of channels to keep per sample
    """
    assert x.dim() == 3
    B, C, T = x.shape
    mask = (torch.rand(B, C, 1, device=x.device, dtype=x.dtype) > drop_prob).to(x.dtype)

    # Ensure at least min_keep channels kept
    keep = mask.sum(dim=1, keepdim=True)              # (B, 1, 1)
    need = (keep < min_keep).squeeze(-1).squeeze(-1)  # (B,)
    if need.any():
        for b in torch.where(need)[0]:
            idx = torch.randperm(C, device=x.device)[:min_keep]
            mask[b, idx, 0] = 1.0

    return x * mask
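The `min_keep` repair step guarantees no sample is zeroed out entirely, even at high drop rates. A list-based sketch of the same mask logic (`channel_mask` is a hypothetical helper, no torch needed):

```python
import random

def channel_mask(C, drop_prob, min_keep=1, rng=random):
    # Bernoulli keep-mask per channel, mirroring channel_dropout.
    mask = [1.0 if rng.random() > drop_prob else 0.0 for _ in range(C)]
    # Repair: re-enable min_keep random channels if too few survived.
    if sum(mask) < min_keep:
        for i in rng.sample(range(C), min_keep):
            mask[i] = 1.0
    return mask

random.seed(0)
# Even with drop_prob=0.9, every mask keeps at least one channel.
masks = [channel_mask(4, drop_prob=0.9) for _ in range(500)]
```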
osf/datasets/pretrain_datamodule.py ADDED
@@ -0,0 +1,303 @@
import os
from typing import List, Sequence, Optional, Dict, Union
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
from torch.utils.data import DataLoader, Subset

from osf.datasets.pretrain_dataset import SleepEpochDataset


class SleepDataModule(LightningDataModule):

    def __init__(
        self,
        csv_dir: str | Path,
        *,
        is_pretrain,
        data_pct=1,
        val_dataset_list: Optional[List[str]] = None,
        downstream_dataset_name=None,
        batch_size: int = 128,
        num_workers: int = 4,
        patient_cols: Optional[Union[str, Sequence[str]]] = None,
        event_cols: Optional[Union[str, Sequence[str]]] = None,
        train_edf_cols: Sequence[str] | None,
        transforms=None,
        n_views: int = 1,
        cache_size: int = 8,
        sample_rate: int = 128,
        window_size: int = 30,
        pin_memory: bool = False,
        persistent_workers: bool = False,
        data_source: str = "auto",
        include_datasets: Optional[List[str]] = None,
        regression_targets: Optional[List[str]] = None,
        regression_filter_config: Optional[Dict] = None,
        n_train_samples: Optional[int] = None,
        val_batch_size: Optional[int] = None,
        val_data_pct: Optional[float] = None,
        return_all_event_cols: bool = False,
        return_nsrrid: bool = False,
        random_seed: int = 42,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["transforms"])
        self.downstream_dataset_name = downstream_dataset_name
        self.csv_dir = csv_dir
        self.transforms = transforms
        self.n_views = n_views
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.is_pretrain = is_pretrain
        self.patient_cols = patient_cols
        self.event_cols = event_cols
        self.data_pct = data_pct
        self.data_source = data_source
        self.include_datasets = include_datasets
        self.regression_targets = regression_targets
        self.regression_filter_config = regression_filter_config
        self.n_train_samples = n_train_samples
        self.val_batch_size = val_batch_size
        self.val_data_pct = val_data_pct
        self.return_all_event_cols = return_all_event_cols
        self.return_nsrrid = return_nsrrid
        self.random_seed = random_seed

    def train_dataloader(self):
        if self.is_pretrain == 1:
            train_set = SleepEpochDataset(
                csv_dir=self.csv_dir,
                split="pretrain",
                data_pct=self.data_pct,
                train_edf_cols=self.hparams.train_edf_cols,
                transform=self.transforms,
                sample_rate=self.hparams.sample_rate,
                window_size=self.hparams.window_size,
                cache_size=self.hparams.cache_size,
                data_source=self.data_source,
                include_datasets=self.include_datasets,
            )
            persistent_workers = self.persistent_workers
        else:
            train_set = SleepEpochDataset(
                csv_dir=self.csv_dir,
                split="train",
                data_pct=self.data_pct,
                patient_cols=self.patient_cols,
                event_cols=self.event_cols,
                train_edf_cols=self.hparams.train_edf_cols,
                transform=self.transforms,
                sample_rate=self.hparams.sample_rate,
                window_size=self.hparams.window_size,
                cache_size=self.hparams.cache_size,
                downstream_dataset_name=self.downstream_dataset_name,
                data_source=self.data_source,
                include_datasets=self.include_datasets,
                regression_targets=self.regression_targets,
                regression_filter_config=self.regression_filter_config,
                return_all_event_cols=self.return_all_event_cols,
                return_nsrrid=self.return_nsrrid,
            )
            self._train_dataset = train_set
            persistent_workers = True

        if self.n_train_samples is not None and self.n_train_samples > 0:
            n_total = len(train_set)
            rng = np.random.default_rng(seed=self.random_seed)
            indices = None

            # Class-balanced few-shot sampling when labels are available.
            if (hasattr(train_set, 'event_cols') and train_set.event_cols
                    and hasattr(train_set, 'all_epoch_df')):
                label_col = train_set.event_cols[0]
                num_classes = getattr(train_set, 'num_classes', None)
                if label_col in train_set.all_epoch_df.columns and num_classes is not None:
                    labels = train_set.all_epoch_df[label_col].values
                    indices = []
                    for c in range(num_classes):
                        class_indices = np.where(labels == c)[0]
                        n_per_class = min(self.n_train_samples, len(class_indices))
                        if n_per_class > 0:
                            sampled = rng.choice(class_indices, size=n_per_class, replace=False)
                            indices.extend(sampled.tolist())
                            print(f"[Few-shot] Class {c}: sampled {n_per_class}/{len(class_indices)} samples")
                    print(f"[Few-shot] Total: {len(indices)}/{n_total} samples ({self.n_train_samples}-shot per class)")

            # Fallback: plain random subsampling.
            if indices is None:
                n_keep = min(self.n_train_samples, n_total)
                indices = rng.choice(n_total, size=n_keep, replace=False).tolist()
                print(f"[Few-shot] Using {n_keep}/{n_total} training samples (random, n_train_samples={self.n_train_samples})")

            train_set = Subset(train_set, indices)

        return DataLoader(
            train_set,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=persistent_workers,
            drop_last=True,
        )

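The balanced few-shot branch above amounts to per-class sampling without replacement, falling back to whatever a minority class has. A dependency-free sketch (`few_shot_indices` is a hypothetical helper mirroring that loop):

```python
import random

def few_shot_indices(labels, n_shot, seed=42):
    # Sample up to n_shot indices per class, without replacement.
    rng = random.Random(seed)
    picked = []
    for c in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == c]
        picked.extend(rng.sample(idx, min(n_shot, len(idx))))
    return picked

labels = [0] * 50 + [1] * 8 + [2] * 3
picked = few_shot_indices(labels, n_shot=5)
# 5 + 5 + 3 = 13 indices: minority classes contribute all they have
```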
    def get_class_distribution(self) -> Optional[torch.Tensor]:
        """
        Get class distribution from training dataset.
        Returns [num_classes] tensor of class counts, or None if not available.
        """
        if hasattr(self, '_train_dataset'):
            counts = self._train_dataset.get_class_counts()
            if counts is not None:
                return torch.from_numpy(counts).float()
        return None

    def val_dataloader(self):
        if self.hparams.val_dataset_list and self.is_pretrain == 1:
            # One pretrain-val set per dataset named in val_dataset_list.
            val_sets = [
                SleepEpochDataset(
                    csv_dir=self.csv_dir,
                    split="pretrain-val",
                    data_pct=self.data_pct,
                    patient_cols=self.patient_cols,
                    event_cols=self.event_cols,
                    train_edf_cols=self.hparams.train_edf_cols,
                    transform=None,
                    sample_rate=self.hparams.sample_rate,
                    window_size=self.hparams.window_size,
                    cache_size=self.hparams.cache_size,
                    downstream_dataset_name=ds_name,
                    data_source=self.data_source,
                    include_datasets=self.include_datasets,
                )
                for ds_name in self.hparams.val_dataset_list
            ]
            persistent_workers = self.persistent_workers
        elif self.is_pretrain == 1:
            val_sets = [
                SleepEpochDataset(
                    csv_dir=self.csv_dir,
                    split="pretrain-val",
                    data_pct=self.data_pct,
                    patient_cols=self.patient_cols,
                    event_cols=self.event_cols,
                    train_edf_cols=self.hparams.train_edf_cols,
                    transform=None,
                    sample_rate=self.hparams.sample_rate,
                    window_size=self.hparams.window_size,
                    cache_size=self.hparams.cache_size,
                    data_source=self.data_source,
                    include_datasets=self.include_datasets,
                )
            ]
            persistent_workers = self.persistent_workers
        else:
            val_sets = [
                SleepEpochDataset(
                    csv_dir=self.csv_dir,
                    split="val",
                    data_pct=self.data_pct,
                    patient_cols=self.patient_cols,
                    event_cols=self.event_cols,
                    train_edf_cols=self.hparams.train_edf_cols,
                    transform=None,
                    sample_rate=self.hparams.sample_rate,
                    window_size=self.hparams.window_size,
                    cache_size=self.hparams.cache_size,
                    downstream_dataset_name=self.downstream_dataset_name,
                    data_source=self.data_source,
                    include_datasets=self.include_datasets,
                    regression_targets=self.regression_targets,
                    regression_filter_config=self.regression_filter_config,
                )
            ]
            persistent_workers = True

        if self.val_data_pct is not None and 0 < self.val_data_pct < 1.0:
            subsampled_val_sets = []
            for ds in val_sets:
                n_total = len(ds)
                n_keep = max(1, int(n_total * self.val_data_pct))
                rng = np.random.default_rng(seed=self.random_seed)
                indices = rng.choice(n_total, size=n_keep, replace=False).tolist()
                subsampled_val_sets.append(Subset(ds, indices))
                print(f"[Val subsample] Using {n_keep}/{n_total} val samples ({self.val_data_pct*100:.1f}%)")
            val_sets = subsampled_val_sets

        val_bs = self.val_batch_size if self.val_batch_size is not None else self.hparams.batch_size
        return [
            DataLoader(
                ds,
                batch_size=val_bs,
                shuffle=False,
                num_workers=self.hparams.num_workers,
                pin_memory=self.pin_memory,
                persistent_workers=persistent_workers,
                drop_last=True,
            )
            for ds in val_sets
        ]

    def test_dataloader(self):
        if self.is_pretrain == 1:
            test_set = SleepEpochDataset(
                csv_dir=self.csv_dir,
                split="pretrain-test",
                patient_cols=self.patient_cols,
                event_cols=self.event_cols,
                train_edf_cols=self.hparams.train_edf_cols,
                transform=None,
                sample_rate=self.hparams.sample_rate,
                window_size=self.hparams.window_size,
                cache_size=self.hparams.cache_size,
                data_source=self.data_source,
                include_datasets=self.include_datasets,
            )
            persistent_workers = self.persistent_workers
        else:
            test_set = SleepEpochDataset(
                csv_dir=self.csv_dir,
                split="test",
                patient_cols=self.patient_cols,
                event_cols=self.event_cols,
                train_edf_cols=self.hparams.train_edf_cols,
                transform=None,
                sample_rate=self.hparams.sample_rate,
                window_size=self.hparams.window_size,
                cache_size=self.hparams.cache_size,
                downstream_dataset_name=self.downstream_dataset_name,
                data_source=self.data_source,
                include_datasets=self.include_datasets,
                regression_targets=self.regression_targets,
                regression_filter_config=self.regression_filter_config,
            )
            persistent_workers = True
        test_bs = self.val_batch_size if self.val_batch_size is not None else self.hparams.batch_size
        return DataLoader(
            test_set,
            batch_size=test_bs,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
            persistent_workers=persistent_workers,
        )

osf/datasets/pretrain_dataset.py ADDED
@@ -0,0 +1,381 @@
# Sleep Epoch Dataset for pretraining and downstream tasks

import os
import numpy as np
import pandas as pd
import torch
from pathlib import Path
from contextlib import suppress
from typing import Sequence, Optional, Dict, Union, List
from torch.utils.data import Dataset
from train_config import NEED_NORM_COL


def to_pm1(s: pd.Series) -> pd.Series:
    s = pd.to_numeric(s, errors="coerce")
    vmin, vmax = s.min(skipna=True), s.max(skipna=True)
    if pd.isna(vmin) or pd.isna(vmax) or vmax <= vmin:
        return pd.Series(0.0, index=s.index)
    return (2 * (s - vmin) / (vmax - vmin) - 1).fillna(0.0)

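`to_pm1` is plain min-max scaling into [-1, 1], with degenerate (constant or all-NaN) columns mapped to 0. The same mapping on plain lists, without pandas (`to_pm1_list` is a hypothetical equivalent for illustration):

```python
def to_pm1_list(vals):
    # Min-max scale into [-1, 1]; constant input degenerates to all zeros,
    # matching the to_pm1 guard above.
    vmin, vmax = min(vals), max(vals)
    if vmax <= vmin:
        return [0.0 for _ in vals]
    return [2 * (v - vmin) / (vmax - vmin) - 1 for v in vals]

to_pm1_list([10, 20, 30])  # [-1.0, 0.0, 1.0]
```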
class SleepEpochDataset(Dataset):
    def __init__(
        self,
        csv_dir='/path/to/your/postprocessed/data',
        split: str = "train",
        *,
        data_pct=1,
        patient_cols: Optional[Union[str, Sequence[str]]] = None,
        event_cols: Optional[Union[str, Sequence[str]]] = None,
        train_edf_cols=None,
        test_size: float = 0.15,
        val_size: float = 0.15,
        random_state: int = 1337,
        sample_rate: int = 128,
        window_size: int = 300,
        epoch_length: int = 30,
        cache_size: int = 8,
        transform=None,
        downstream_dataset_name=None,
        data_source: str = "auto",
        include_datasets: Optional[List[str]] = None,
        regression_targets: Optional[List[str]] = None,
        regression_filter_config: Optional[Dict] = None,
        return_all_event_cols: bool = False,
        return_nsrrid: bool = False,
    ):
        assert split in {"pretrain", "pretrain-val", "pretrain-test", "train", "val", "test"}
        assert data_source in {"auto", "pretrain", "downstream", "both"}

        self.transform = transform
        self.sample_rate = sample_rate
        self.window_size = window_size
        self.epoch_length = epoch_length
        self.patient_cols = [patient_cols] if isinstance(patient_cols, str) else patient_cols
        self.event_cols = [event_cols] if isinstance(event_cols, str) else event_cols
        self.train_edf_cols = train_edf_cols
        self.split = split
        self.data_pct = float(data_pct)
        self.data_source = data_source
        self.regression_targets = regression_targets
        self.regression_filter_config = regression_filter_config
        self.return_all_event_cols = return_all_event_cols
        self.return_nsrrid = return_nsrrid

        patient_df, epoch_df = self._load_csvs(
            csv_dir, split, data_source, include_datasets, self.event_cols,
            regression_targets=self.regression_targets,
            regression_filter_config=self.regression_filter_config,
            return_all_event_cols=self.return_all_event_cols,
        )

        if downstream_dataset_name and include_datasets is None:
            if downstream_dataset_name != "all":
                mask = epoch_df['dataset_name'].astype(str).str.lower().str.startswith(downstream_dataset_name)
                epoch_df = epoch_df.loc[mask].copy()
                ids = epoch_df["nsrrid"].astype(str).unique()
                patient_df = patient_df[patient_df["nsrrid"].astype(str).isin(ids)].copy()

        # Determine num_classes
        if self.event_cols:
            if self.event_cols[0] in ['Hypopnea', 'Arousal', 'Oxygen Desaturation']:
                self.num_classes = 2
            elif self.event_cols[0] == 'Stage':
                self.num_classes = 4
                mapping = {0: 0, 1: 1, 2: 1, 3: 2, 4: 3}
                epoch_df['Stage'] = epoch_df['Stage'].replace(mapping)
            else:
                self.num_classes = 2
        else:
            self.num_classes = 2

        # Drop Stage == -1
        if self.event_cols and ('Stage' in self.event_cols) and ('Stage' in epoch_df.columns):
            epoch_df = epoch_df.loc[epoch_df['Stage'] != -1].copy()

        # Build tables
        if split in ("pretrain", "pretrain-val"):
            sort_cols = [c for c in ['nsrrid', 'seg_id', 'epoch_id'] if c in epoch_df.columns]
            self.all_epoch_df = epoch_df.sort_values(sort_cols).reset_index(drop=True)

            idx_keep_cols = [c for c in ['nsrrid', 'seg_id', 'path_head'] if c in self.all_epoch_df.columns]
            if self.regression_targets:
                for t in self.regression_targets:
                    col = f"{t}_mean"
                    if col in self.all_epoch_df.columns:
                        idx_keep_cols.append(col)
            self.epoch_df = (
                self.all_epoch_df[idx_keep_cols]
                .drop_duplicates(['nsrrid', 'seg_id'], keep='first')
                .reset_index(drop=True)
            )
        else:
            expected_len = self.window_size // self.epoch_length
            grp = epoch_df.groupby(['nsrrid', 'seg_id']).size().rename('n').reset_index()
            valid_keys = grp.loc[grp['n'] == expected_len, ['nsrrid', 'seg_id']]
            epoch_df_valid = epoch_df.merge(valid_keys, on=['nsrrid', 'seg_id'], how='inner')

            sort_cols = [c for c in ['nsrrid', 'seg_id', 'epoch_id'] if c in epoch_df_valid.columns]
            self.all_epoch_df = epoch_df_valid.sort_values(sort_cols).reset_index(drop=True)

            idx_keep_cols = [c for c in ['nsrrid', 'seg_id', 'path_head'] if c in self.all_epoch_df.columns]
            if self.regression_targets:
                for t in self.regression_targets:
                    col = f"{t}_mean"
                    if col in self.all_epoch_df.columns:
                        idx_keep_cols.append(col)
            self.epoch_df = (
                self.all_epoch_df[idx_keep_cols]
                .drop_duplicates(['nsrrid', 'seg_id'], keep='first')
                .reset_index(drop=True)
            )

        # Patient-level sampling
        if not (0 < self.data_pct <= 1.0):
            raise ValueError(f"data_pct must be in (0,1], got {self.data_pct}")

        if self.data_pct < 1.0:
            eligible_patients = pd.Index(self.epoch_df['nsrrid'].unique())
            n_keep = max(1, int(len(eligible_patients) * self.data_pct))
            sampled_nsrrids = pd.Series(eligible_patients).sample(n=n_keep, random_state=random_state).to_list()
            self.epoch_df = self.epoch_df.loc[self.epoch_df['nsrrid'].isin(sampled_nsrrids)].reset_index(drop=True)
            self.all_epoch_df = self.all_epoch_df.loc[self.all_epoch_df['nsrrid'].isin(sampled_nsrrids)].reset_index(drop=True)
            patient_df = patient_df.loc[patient_df['nsrrid'].isin(sampled_nsrrids)].copy()

        self.patient_df = patient_df.set_index("nsrrid")

        # Build segment indices
        self._seg_indices = None
        if hasattr(self, "all_epoch_df") and {'nsrrid', 'seg_id'}.issubset(self.all_epoch_df.columns):
            grp_indices = self.all_epoch_df.groupby(['nsrrid', 'seg_id'], sort=False).indices
            self._seg_indices = {}
            has_epoch_id = 'epoch_id' in self.all_epoch_df.columns
            epoch_id_values = self.all_epoch_df['epoch_id'].to_numpy() if has_epoch_id else None
            for key, idx_list in grp_indices.items():
                idx_arr = np.fromiter(idx_list, dtype=np.int64)
                if has_epoch_id:
                    order = np.argsort(epoch_id_values[idx_arr])
                    idx_arr = idx_arr[order]
                self._seg_indices[key] = idx_arr

        # Compute class distribution
        self._class_counts = None
        if self.event_cols and self.event_cols[0] in self.all_epoch_df.columns:
            label_col = self.event_cols[0]
            value_counts = self.all_epoch_df[label_col].value_counts().sort_index()
            class_counts = np.zeros(self.num_classes, dtype=np.int64)
            for cls_idx, count in value_counts.items():
                if 0 <= int(cls_idx) < self.num_classes:
                    class_counts[int(cls_idx)] = int(count)
            self._class_counts = class_counts

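The `Stage` handling in `__init__` collapses the five raw labels into four classes via `{0: 0, 1: 1, 2: 1, 3: 2, 4: 3}` (raw labels 1 and 2 are merged), then drops `-1` rows. In isolation:

```python
mapping = {0: 0, 1: 1, 2: 1, 3: 2, 4: 3}
raw = [0, 1, 2, 3, 4, 2, -1]
# -1 rows are filtered out first, mirroring the `Stage != -1` drop above.
merged = [mapping[s] for s in raw if s != -1]
num_classes = len(set(mapping.values()))  # 4 merged classes
```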
    def _load_csvs(self, csv_dir, split, data_source, include_datasets, event_cols,
                   regression_targets=None, regression_filter_config=None, return_all_event_cols=False):
        split_suffix_map = {
            "pretrain": "train", "pretrain-val": "valid", "pretrain-test": "test",
            "train": "train", "val": "valid", "test": "test"
        }
        split_suffix = split_suffix_map[split]

        if data_source == "auto":
            sources = ["pretrain"] if split.startswith("pretrain") else ["downstream"]
        elif data_source == "both":
            sources = ["pretrain", "downstream"]
        else:
            sources = [data_source]

        patient_dfs = []
        epoch_dfs = []
        csv_prefix = "epoch_regression" if regression_targets else "epoch"

        for source in sources:
            patient_csv = f"{csv_dir}/patient_{source}_{split_suffix}.csv"
            epoch_csv = f"{csv_dir}/{csv_prefix}_{source}_{split_suffix}.csv"

            if Path(patient_csv).is_file() and Path(epoch_csv).is_file():
                patient_dfs.append(pd.read_csv(patient_csv))
                epoch_dfs.append(pd.read_csv(epoch_csv))

        patient_df = pd.concat(patient_dfs, ignore_index=True).drop_duplicates(subset=['nsrrid'])
        epoch_df = pd.concat(epoch_dfs, ignore_index=True)

        base_cols = ['nsrrid', 'seg_id', 'dataset_name', 'epoch_id', 'path_head']
        if event_cols:
            if return_all_event_cols:
                for col in event_cols:
                    if col and col not in base_cols:
                        base_cols.append(col)
            elif event_cols[0]:
                base_cols.append(event_cols[0])

        if regression_targets:
            for t in regression_targets:
                col_name = f"{t}_mean"
                if col_name in epoch_df.columns:
                    base_cols.append(col_name)

        keep_cols = [c for c in base_cols if c in epoch_df.columns]
        epoch_df = epoch_df[keep_cols].copy()

        if regression_targets:
            label_cols = [f"{t}_mean" for t in regression_targets]
            existing = [c for c in label_cols if c in epoch_df.columns]
            if existing:
                epoch_df = epoch_df.dropna(subset=existing).reset_index(drop=True)

        if regression_filter_config:
            for col_name, filter_rules in regression_filter_config.items():
                if col_name in epoch_df.columns:
                    mask = pd.Series([True] * len(epoch_df))
                    if "min" in filter_rules:
                        mask = mask & (epoch_df[col_name] >= filter_rules["min"])
                    if "max" in filter_rules:
                        mask = mask & (epoch_df[col_name] <= filter_rules["max"])
                    epoch_df = epoch_df[mask].reset_index(drop=True)

        if include_datasets is not None and 'dataset_name' in epoch_df.columns:
238
+ include_lower = [d.lower() for d in include_datasets]
239
+ mask = epoch_df['dataset_name'].astype(str).str.lower().isin(include_lower)
240
+ epoch_df = epoch_df[mask].copy()
241
+ patient_df = patient_df[patient_df['nsrrid'].isin(epoch_df['nsrrid'].unique())].copy()
242
+
243
+ return patient_df, epoch_df
244
+
245
+ def __len__(self) -> int:
246
+ return len(self.epoch_df)
247
+
248
+ def get_class_counts(self) -> Optional[np.ndarray]:
249
+ return self._class_counts
250
+
251
+ def _resample_df(self, df: pd.DataFrame, target_hz: int) -> pd.DataFrame:
252
+ if not np.issubdtype(df.index.dtype, np.number):
253
+ t = np.arange(len(df)) / float(target_hz)
254
+ df = df.copy()
255
+ df.index = t
256
+
257
+ t0 = float(df.index.min())
258
+ t1 = float(df.index.max())
259
+ t_target = np.arange(t0, t0 + self.window_size, 1.0 / target_hz)
260
+ if t_target[-1] > t1:
261
+ t_target = t_target[t_target <= t1 + 1e-9]
262
+ out = df.reindex(t_target).interpolate(method="linear", limit_direction="both")
263
+ return out.fillna(0.0)
264
+
265
+ def __getitem__(self, idx: int):
266
+ row = self.epoch_df.iloc[idx]
267
+ nsrrid = row["nsrrid"]
268
+ seg_id = int(row["seg_id"])
269
+ cols = list(self.train_edf_cols) if self.train_edf_cols is not None else None
270
+
271
+ if self.split == "pretrain":
272
+ df_epoch = self._load_epoch_all_df(row["path_head"], seg_id, columns=cols)
273
+ df_epoch = self._resample_df(df_epoch, self.sample_rate)
274
+
275
+ if cols is not None:
276
+ for ch in cols:
277
+ if ch not in df_epoch.columns:
278
+ df_epoch[ch] = 0.0
279
+ elif ch in NEED_NORM_COL:
280
+ df_epoch[ch] = to_pm1(df_epoch[ch])
281
+ df_epoch = df_epoch[cols]
282
+
283
+ samples_per_epoch = int(self.window_size * self.sample_rate)
284
+ if len(df_epoch) < samples_per_epoch:
285
+ pad = samples_per_epoch - len(df_epoch)
286
+ tail = pd.DataFrame({c: 0.0 for c in df_epoch.columns},
287
+ index=df_epoch.index[-1] + (np.arange(1, pad + 1) / self.sample_rate))
288
+ df_epoch = pd.concat([df_epoch, tail], axis=0)
289
+ elif len(df_epoch) > samples_per_epoch:
290
+ df_epoch = df_epoch.iloc[:samples_per_epoch]
291
+
292
+ x = torch.tensor(df_epoch.to_numpy(copy=False), dtype=torch.float32).t().contiguous()
293
+ x = torch.clamp(x, min=-6, max=6)
294
+
295
+ output = {"psg": x}
296
+ if self.return_nsrrid:
297
+ output["nsrrid"] = nsrrid
298
+ output["seg_id"] = seg_id
299
+
300
+ if self.patient_cols:
301
+ y = torch.tensor(self.patient_df.loc[nsrrid, self.patient_cols].values.astype(float), dtype=torch.float32)
302
+ output["label"] = y.long() if not self.return_nsrrid else y
303
+ elif self.event_cols:
304
+ if self.return_all_event_cols:
305
+ available_cols = [c for c in self.event_cols if c in row.index]
306
+ y = torch.tensor([row[c] for c in available_cols], dtype=torch.float32)
307
+ else:
308
+ y = torch.tensor([row[self.event_cols[0]]], dtype=torch.float32)
309
+ output["label"] = y
310
+
311
+ return output
312
+ else:
313
+ # Downstream split
314
+ if self._seg_indices is None:
315
+ seg_df = self.all_epoch_df[
316
+ (self.all_epoch_df['nsrrid'] == nsrrid) & (self.all_epoch_df['seg_id'] == seg_id)
317
+ ].sort_values('epoch_id')
318
+ else:
319
+ idx_arr = self._seg_indices.get((nsrrid, seg_id))
320
+ seg_df = self.all_epoch_df.iloc[idx_arr] if idx_arr is not None else \
321
+ self.all_epoch_df[(self.all_epoch_df['nsrrid'] == nsrrid) & (self.all_epoch_df['seg_id'] == seg_id)].sort_values('epoch_id')
322
+
323
+ df_epoch = self._load_epoch_all_df(row["path_head"], seg_id, columns=cols)
324
+ df_epoch = self._resample_df(df_epoch, self.sample_rate)
325
+
326
+ if cols is not None:
327
+ for ch in cols:
328
+ if ch not in df_epoch.columns:
329
+ df_epoch[ch] = 0.0
330
+ elif ch in NEED_NORM_COL:
331
+ df_epoch[ch] = to_pm1(df_epoch[ch])
332
+ df_epoch = df_epoch[cols]
333
+
334
+ samples_per_epoch = int(self.window_size * self.sample_rate)
335
+ if len(df_epoch) < samples_per_epoch:
336
+ pad = samples_per_epoch - len(df_epoch)
337
+ tail = pd.DataFrame({c: 0.0 for c in df_epoch.columns},
338
+ index=df_epoch.index[-1] + (np.arange(1, pad + 1) / self.sample_rate))
339
+ df_epoch = pd.concat([df_epoch, tail], axis=0)
340
+ elif len(df_epoch) > samples_per_epoch:
341
+ df_epoch = df_epoch.iloc[:samples_per_epoch]
342
+
343
+ x = torch.tensor(df_epoch.to_numpy(copy=False), dtype=torch.float32).t().contiguous()
344
+ x = torch.clamp(x, min=-6, max=6)
345
+
346
+ output = {"psg": x}
347
+ if self.return_nsrrid:
348
+ output["nsrrid"] = nsrrid
349
+ output["seg_id"] = seg_id
350
+
351
+ if self.patient_cols:
352
+ y = torch.tensor(self.patient_df.loc[nsrrid, self.patient_cols].values.astype(float), dtype=torch.float32)
353
+ y = y.repeat(self.window_size // self.epoch_length)
354
+ output["label"] = y
355
+ elif self.event_cols:
356
+ if self.return_all_event_cols:
357
+ available_cols = [c for c in self.event_cols if c in seg_df.columns]
358
+ y = torch.tensor(seg_df[available_cols].values.astype(float), dtype=torch.float32).squeeze(0)
359
+ else:
360
+ y = torch.tensor(seg_df[self.event_cols].values.astype(float), dtype=torch.float32).squeeze(1)
361
+ output["label"] = y
362
+ elif self.regression_targets:
363
+ label_cols = [f"{t}_mean" for t in self.regression_targets]
364
+ y = torch.tensor([row[c] for c in label_cols], dtype=torch.float32)
365
+ output["label"] = y
366
+
367
+ return output
368
+
369
+ def _build_epoch_all_path(self, path_head: str, epoch_id: int) -> Path:
370
+ return Path(f"{path_head}/epoch-{epoch_id:05d}_all.parquet")
371
+
372
+ def _load_epoch_all_df(self, path_head: str, epoch_id: int, columns=None) -> pd.DataFrame:
373
+ fp = self._build_epoch_all_path(path_head, epoch_id)
374
+ if not fp.is_file():
375
+ raise FileNotFoundError(f"Parquet missing: {fp}")
376
+ df = pd.read_parquet(fp)
377
+ for c in df.columns:
378
+ if not np.issubdtype(df[c].dtype, np.floating):
379
+ with suppress(Exception):
380
+ df[c] = df[c].astype(np.float32)
381
+ return df
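
The `__getitem__` branches above both coerce each window to exactly `window_size * sample_rate` samples: a short window is zero-padded at the tail, a long one truncated. The shape logic can be sketched outside pandas/torch with a minimal numpy helper (the 30 s / 64 Hz values are illustrative, mirroring the defaults used elsewhere in this diff, and `fit_to_length` is a hypothetical name, not a function from the repo):

```python
import numpy as np

def fit_to_length(sig: np.ndarray, n_target: int) -> np.ndarray:
    """Zero-pad the tail or truncate so the time axis has exactly n_target samples.

    sig: [T, C] array, mirroring the DataFrame layout before the final transpose.
    """
    t, c = sig.shape
    if t < n_target:
        pad = np.zeros((n_target - t, c), dtype=sig.dtype)
        return np.concatenate([sig, pad], axis=0)
    return sig[:n_target]

window_size, sample_rate = 30, 64      # 30 s window at 64 Hz
n_target = window_size * sample_rate   # 1920 samples

short = np.ones((1900, 3), dtype=np.float32)
long_ = np.ones((2000, 3), dtype=np.float32)
print(fit_to_length(short, n_target).shape)  # (1920, 3)
print(fit_to_length(long_, n_target).shape)  # (1920, 3)
```

The dataset code then transposes to `[C, T]` and clamps to ±6 before handing the tensor to the model.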
osf/datasets/simclr_aug_registry.py ADDED
@@ -0,0 +1,258 @@
+"""
+Two-view augmentation registry for SSL pretraining (SimCLR, DINO).
+Provides multi-view generation pipelines for contrastive and self-distillation methods.
+"""
+from __future__ import annotations
+from typing import Callable, Dict
+import torch
+from osf.datasets import augmentations as A
+
+
+def _two_view(pipe1: Callable, pipe2: Callable | None = None) -> Callable:
+    """Wrap one/two single-view pipelines into a two-view augmentation maker."""
+    if pipe2 is None:
+        pipe2 = pipe1
+
+    def make(x: torch.Tensor):
+        return pipe1(x), pipe2(x)
+
+    return make
+
+
+SIMCLR_AUG_REGISTRY: Dict[str, Callable] = {
+    "none": _two_view(lambda x: x),
+
+    "channel_dropout": _two_view(lambda x: A.channel_dropout(x, drop_prob=0.2, min_keep=1)),
+    "channel_dropout_light": _two_view(lambda x: A.channel_dropout(x, drop_prob=0.25, min_keep=1)),
+    "channel_dropout_aligned": _two_view(lambda x: A.channel_dropout(x, drop_prob=0.5, min_keep=1)),
+}
+
+SIMCLR_AUG_FACTORIES: Dict[str, Callable[..., Callable]] = {}
+
+
+def build_simclr_augmentor(name: str, **kwargs) -> Callable:
+    key = (name or "none").lower()
+    if key in SIMCLR_AUG_REGISTRY:
+        return SIMCLR_AUG_REGISTRY[key]
+    if key in SIMCLR_AUG_FACTORIES:
+        return SIMCLR_AUG_FACTORIES[key](**kwargs)
+    raise ValueError(
+        f"Unknown simclr_augmentation '{name}'. "
+        f"Available presets: {list(SIMCLR_AUG_REGISTRY.keys())} | "
+        f"factories: {list(SIMCLR_AUG_FACTORIES.keys())}"
+    )
+
+
+def _per_channel_span_mask_factory(
+    ratio: tuple[float, float] = (0.10, 0.30),
+    n_spans: int = 1,
+    fill: str | torch.Tensor = "zero",
+    noise_scale: float = 0.05,
+    same_mask_for_batch: bool = False,
+):
+    assert 0.0 <= ratio[0] <= ratio[1] <= 1.0
+
+    def _single_view(x: torch.Tensor) -> torch.Tensor:
+        B, C, T = x.shape
+        device, dtype = x.device, x.dtype
+
+        min_len = max(1, int(round(ratio[0] * T)))
+        max_len = max(min_len, int(round(ratio[1] * T)))
+        arange_T = torch.arange(T, device=device)
+        mask = torch.zeros((B, C, T), device=device, dtype=torch.bool)
+        shape_bc = (1, C) if same_mask_for_batch else (B, C)
+
+        for _ in range(max(1, int(n_spans))):
+            if max_len == min_len:
+                lengths = torch.full(shape_bc, max_len, device=device, dtype=torch.long)
+            else:
+                lengths = torch.randint(min_len, max_len + 1, shape_bc, device=device)
+            max_start = (T - lengths).clamp_min(0)
+            if (max_start > 0).any():
+                rnd = torch.rand_like(max_start, dtype=torch.float32)
+                starts = torch.floor(rnd * (max_start.to(torch.float32) + 1)).to(torch.long)
+            else:
+                starts = torch.zeros_like(max_start)
+            if same_mask_for_batch and B > 1:
+                starts = starts.expand(B, C)
+                lengths = lengths.expand(B, C)
+            span_mask = (arange_T.view(1, 1, T) >= starts.unsqueeze(-1)) & \
+                        (arange_T.view(1, 1, T) < (starts + lengths).unsqueeze(-1))
+            mask |= span_mask
+
+        y = x.clone()
+        if isinstance(fill, torch.Tensor):
+            fill_t = fill.to(device=device, dtype=dtype)
+            if fill_t.dim() == 0:
+                fill_t = fill_t.view(1, 1, 1)
+            # Any fill shape broadcastable to (B, C, T) -- e.g. scalar, (B, 1, 1),
+            # (1, C, 1) or (B, C, T) -- is expanded over the masked positions.
+            y = torch.where(mask, fill_t.expand_as(y), y)
+        elif fill == "zero":
+            y[mask] = 0.0
+        elif fill == "mean":
+            m = x.mean(dim=-1, keepdim=True)
+            y = torch.where(mask, m.expand_as(x), y)
+        elif fill == "noise":
+            m = x.mean(dim=-1, keepdim=True)
+            s = x.std(dim=-1, keepdim=True, unbiased=False).clamp_min(1e-8)
+            noise = torch.randn_like(x) * (s * noise_scale) + m
+            y = torch.where(mask, noise, y)
+        else:
+            raise ValueError(f"Unknown fill mode: {fill!r}")
+        return y
+
+    return _two_view(_single_view)
+
+
+SIMCLR_AUG_FACTORIES["pc_span_mask"] = _per_channel_span_mask_factory
+
+SIMCLR_AUG_REGISTRY.update({
+    "pc_span_mask_light": _per_channel_span_mask_factory(
+        ratio=(0.1, 0.3), n_spans=1, fill="zero", noise_scale=0.05, same_mask_for_batch=False
+    ),
+    "pc_span_mask_heavy": _per_channel_span_mask_factory(
+        ratio=(0.20, 0.6), n_spans=2, fill="zero", noise_scale=0.05, same_mask_for_batch=False
+    ),
+    "pc_span_mask_aligned": _per_channel_span_mask_factory(
+        ratio=(0.3, 0.6), n_spans=1, fill="zero", noise_scale=0, same_mask_for_batch=False
+    ),
+})
+
+
+def _channel_then_pcspan_factory(
+    drop_prob: float = 0.3,
+    min_keep: int = 1,
+    ratio: tuple[float, float] = (0.10, 0.30),
+    n_spans: int = 1,
+    fill: str = "zero",
+    noise_scale: float = 0.05,
+    same_mask_for_batch: bool = False,
+):
+    def single_view(x: torch.Tensor) -> torch.Tensor:
+        y = A.channel_dropout(x, drop_prob=drop_prob, min_keep=min_keep)
+        B, C, T = y.shape
+        device = y.device
+
+        min_len = max(1, int(round(ratio[0] * T)))
+        max_len = max(min_len, int(round(ratio[1] * T)))
+        arange_T = torch.arange(T, device=device)
+        mask = torch.zeros((B, C, T), device=device, dtype=torch.bool)
+        shape_bc = (1, C) if same_mask_for_batch else (B, C)
+
+        for _ in range(max(1, int(n_spans))):
+            lengths = torch.full(shape_bc, max_len, device=device, dtype=torch.long) \
+                if max_len == min_len else torch.randint(min_len, max_len + 1, shape_bc, device=device)
+            max_start = (T - lengths).clamp_min(0)
+            if (max_start > 0).any():
+                rnd = torch.rand_like(max_start, dtype=torch.float32)
+                starts = torch.floor(rnd * (max_start.to(torch.float32) + 1)).to(torch.long)
+            else:
+                starts = torch.zeros_like(max_start)
+            if same_mask_for_batch and B > 1:
+                starts = starts.expand(B, C)
+                lengths = lengths.expand(B, C)
+            span_mask = (arange_T.view(1, 1, T) >= starts.unsqueeze(-1)) & \
+                        (arange_T.view(1, 1, T) < (starts + lengths).unsqueeze(-1))
+            mask |= span_mask
+
+        out = y.clone()
+        if fill == "zero":
+            out[mask] = 0.0
+        elif fill == "mean":
+            m = y.mean(dim=-1, keepdim=True)
+            out = torch.where(mask, m.expand_as(y), out)
+        elif fill == "noise":
+            m = y.mean(dim=-1, keepdim=True)
+            s = y.std(dim=-1, keepdim=True, unbiased=False).clamp_min(1e-8)
+            noise = torch.randn_like(y) * (s * noise_scale) + m
+            out = torch.where(mask, noise, out)
+        else:
+            raise ValueError(f"Unknown fill: {fill!r}")
+        return out
+
+    return _two_view(single_view)
+
+
+SIMCLR_AUG_FACTORIES["chan_then_pcspan"] = _channel_then_pcspan_factory
+
+
+def _crop_then_chan_pcspan_factory(
+    crop_ratio: tuple[float, float] = (0.25, 0.75),
+    align_to: int = 40,
+    drop_prob: float = 0.5,
+    min_keep: int = 1,
+    span_ratio: tuple[float, float] = (0.3, 0.6),
+    n_spans: int = 1,
+    fill: str = "zero",
+    noise_scale: float = 0.0,
+    same_mask_for_batch: bool = False,
+):
+    def single_view(x: torch.Tensor) -> torch.Tensor:
+        y = A.random_time_crop(x, ratio=crop_ratio, resize_back=True, align_to=align_to)
+        y = A.channel_dropout(y, drop_prob=drop_prob, min_keep=min_keep)
+
+        B, C, T = y.shape
+        device = y.device
+        min_len = max(1, int(round(span_ratio[0] * T)))
+        max_len = max(min_len, int(round(span_ratio[1] * T)))
+        arange_T = torch.arange(T, device=device)
+        mask = torch.zeros((B, C, T), device=device, dtype=torch.bool)
+        shape_bc = (1, C) if same_mask_for_batch else (B, C)
+
+        for _ in range(max(1, int(n_spans))):
+            lengths = torch.full(shape_bc, max_len, device=device, dtype=torch.long) \
+                if max_len == min_len else torch.randint(min_len, max_len + 1, shape_bc, device=device)
+            max_start = (T - lengths).clamp_min(0)
+            if (max_start > 0).any():
+                rnd = torch.rand_like(max_start, dtype=torch.float32)
+                starts = torch.floor(rnd * (max_start.to(torch.float32) + 1)).to(torch.long)
+            else:
+                starts = torch.zeros_like(max_start)
+            if same_mask_for_batch and B > 1:
+                starts = starts.expand(B, C)
+                lengths = lengths.expand(B, C)
+            span_mask = (arange_T.view(1, 1, T) >= starts.unsqueeze(-1)) & \
+                        (arange_T.view(1, 1, T) < (starts + lengths).unsqueeze(-1))
+            mask |= span_mask
+
+        out = y.clone()
+        if fill == "zero":
+            out[mask] = 0.0
+        elif fill == "mean":
+            m = y.mean(dim=-1, keepdim=True)
+            out = torch.where(mask, m.expand_as(y), out)
+        elif fill == "noise":
+            m = y.mean(dim=-1, keepdim=True)
+            s = y.std(dim=-1, keepdim=True, unbiased=False).clamp_min(1e-8)
+            noise = torch.randn_like(y) * (s * noise_scale) + m
+            out = torch.where(mask, noise, out)
+        else:
+            raise ValueError(f"Unknown fill: {fill!r}")
+        return out
+
+    return _two_view(single_view)
+
+
+SIMCLR_AUG_FACTORIES["crop_then_chan_pcspan"] = _crop_then_chan_pcspan_factory
+
+SIMCLR_AUG_REGISTRY.update({
+    "chan_then_pcspan": _channel_then_pcspan_factory(
+        drop_prob=0.5, min_keep=1, ratio=(0.3, 0.6), n_spans=1, fill="zero",
+        noise_scale=0, same_mask_for_batch=False
+    ),
+    "chan_then_pcspan_light": _channel_then_pcspan_factory(
+        drop_prob=0.25, min_keep=1, ratio=(0.3, 0.6), n_spans=1, fill="zero",
+        noise_scale=0, same_mask_for_batch=False
+    ),
+    "crop_then_chan_pcspan": _crop_then_chan_pcspan_factory(
+        crop_ratio=(0.25, 0.75), align_to=40, drop_prob=0.5, min_keep=1,
+        span_ratio=(0.3, 0.6), n_spans=1, fill="zero", noise_scale=0, same_mask_for_batch=False
+    ),
+    "crop_then_chan_pcspan_light": _crop_then_chan_pcspan_factory(
+        crop_ratio=(0.25, 0.75), align_to=40, drop_prob=0.25, min_keep=1,
+        span_ratio=(0.3, 0.6), n_spans=1, fill="zero", noise_scale=0, same_mask_for_batch=False
+    ),
+})
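
The registry above pairs a single-view transform into a two-view maker and masks one random contiguous span per channel. The same two ideas can be sketched without torch, on a plain Python list (the helper names `two_view` and `span_zero_mask` are illustrative, not from the repo):

```python
import random

def two_view(pipe1, pipe2=None):
    """Wrap one or two single-view transforms into a two-view maker."""
    pipe2 = pipe2 or pipe1
    def make(x):
        return pipe1(x), pipe2(x)
    return make

def span_zero_mask(x, lo=0.1, hi=0.3, rng=random):
    """Zero out one contiguous span covering lo..hi of the sequence length."""
    t = len(x)
    length = rng.randint(max(1, int(lo * t)), max(1, int(hi * t)))
    start = rng.randint(0, t - length)
    return [0.0 if start <= i < start + length else v for i, v in enumerate(x)]

rng = random.Random(0)
aug = two_view(lambda x: span_zero_mask(x, rng=rng))
v1, v2 = aug([1.0] * 100)
print(sum(1 for v in v1 if v == 0.0))  # between 10 and 30 zeros
```

Because the wrapper calls the pipeline twice, the two views draw independent spans, which is what gives the contrastive objective something to align.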
osf/models/__init__.py ADDED
File without changes
osf/models/balanced_losses.py ADDED
@@ -0,0 +1,89 @@
+"""
+Balanced/imbalanced learning losses.
+Reference: https://github.com/YyzHarry/SubpopBench
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional
+
+
+class FocalLoss(nn.Module):
+    """
+    Focal Loss: FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)
+    Paper: https://arxiv.org/abs/1708.02002
+
+    Args:
+        alpha: Weighting factor (float or [num_classes] tensor)
+        gamma: Focusing parameter (higher = more focus on hard examples)
+        reduction: 'mean' or 'none'
+    """
+    def __init__(self, alpha: Optional[float | torch.Tensor] = None, gamma: float = 2.0, reduction: str = "mean"):
+        super().__init__()
+        self.gamma = gamma
+        self.reduction = reduction
+
+        if isinstance(alpha, (float, int)):
+            self.register_buffer("alpha", torch.tensor([alpha], dtype=torch.float32))
+        elif isinstance(alpha, torch.Tensor):
+            self.register_buffer("alpha", alpha.float())
+        elif alpha is None:
+            self.alpha = None
+        else:
+            raise ValueError(f"alpha must be float, Tensor, or None, got {type(alpha)}")
+
+    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            logits: [B, C] unnormalized logits
+            targets: [B] class indices
+        """
+        ce_loss = F.cross_entropy(logits, targets, reduction="none")
+        pt = torch.exp(-ce_loss)  # p_t
+        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
+
+        if self.alpha is not None:
+            if self.alpha.dim() == 0 or len(self.alpha) == 1:
+                alpha_t = self.alpha.squeeze()
+            else:
+                alpha_t = self.alpha[targets]  # [B]
+            focal_loss = alpha_t * focal_loss
+
+        if self.reduction == "mean":
+            return focal_loss.mean()
+        elif self.reduction == "none":
+            return focal_loss
+        else:
+            raise ValueError(f"reduction must be 'mean' or 'none', got {self.reduction}")
+
+
+class BalancedSoftmax(nn.Module):
+    """
+    Balanced Softmax: adjusted_logits = logits + log(class_counts)
+    Paper: https://arxiv.org/abs/2007.10740
+
+    Args:
+        class_counts: [C] tensor of sample counts per class
+        reduction: 'mean' or 'none'
+    """
+    def __init__(self, class_counts: torch.Tensor, reduction: str = "mean"):
+        super().__init__()
+        if not isinstance(class_counts, torch.Tensor):
+            class_counts = torch.tensor(class_counts, dtype=torch.float32)
+
+        class_counts = class_counts.float()
+        if (class_counts == 0).any():
+            zero_classes = (class_counts == 0).nonzero(as_tuple=True)[0].tolist()
+            raise ValueError(f"BalancedSoftmax requires non-zero class counts. Zero counts: {zero_classes}")
+
+        self.register_buffer("log_class_counts", torch.log(class_counts))
+        self.reduction = reduction
+
+    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            logits: [B, C] unnormalized logits
+            targets: [B] class indices
+        """
+        adjusted_logits = logits + self.log_class_counts.unsqueeze(0)
+        return F.cross_entropy(adjusted_logits, targets, reduction=self.reduction)
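
The focal-loss docstring above can be checked by hand: with γ = 0 the loss reduces to plain cross-entropy, and a larger γ shrinks the loss on confidently correct examples. A pure-Python sketch of the same FL(p_t) = -(1 - p_t)^γ log(p_t) arithmetic for a single example (no torch; the helper names are illustrative):

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def focal_loss(logits, target, gamma=2.0):
    """FL(p_t) = (1 - p_t)^gamma * (-log p_t) for a single example."""
    p_t = softmax(logits)[target]
    return ((1.0 - p_t) ** gamma) * (-math.log(p_t))

logits = [2.0, 0.5, -1.0]
ce = focal_loss(logits, 0, gamma=0.0)  # gamma=0 -> plain cross-entropy
fl = focal_loss(logits, 0, gamma=2.0)  # easy example is down-weighted
print(fl < ce)  # True
```

This is exactly the modulation `((1 - pt) ** self.gamma) * ce_loss` in `FocalLoss.forward`, just without batching or the α term.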
osf/models/base_pretrain_model.py ADDED
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pytorch_lightning import LightningModule
+
+from osf.backbone.vit1d import vit_nano, vit_tiny, vit_small, vit_middle, vit_base
+
+VIT_FACTORIES = {
+    "vit_nano": vit_nano,
+    "vit_tiny": vit_tiny,
+    "vit_small": vit_small,
+    "vit_middle": vit_middle,
+    "vit_base": vit_base,
+}
+
+
+class PSGModalityEncoder(nn.Module):
+    """ViT encoder for PSG signals: backbone -> optional projection -> L2-norm"""
+
+    def __init__(self, *,
+                 encoder_name: str,
+                 proj_out: int = 256,
+                 proj_hidden: int = 512,
+                 freq: int = 64,
+                 win_sec: int = 30,
+                 channel: int = 11,
+                 lead_wise=0,
+                 patch_size=40,
+                 patch_size_ch=4,
+                 use_lead_embedding: bool = True,
+                 is_proj_head=1):
+        super().__init__()
+        token_len = freq * win_sec
+        self.token_len = token_len
+        self.patch_size = patch_size
+
+        if encoder_name not in VIT_FACTORIES:
+            raise ValueError(f"Unknown encoder_name: {encoder_name}. Choose from {list(VIT_FACTORIES.keys())}")
+
+        self.backbone = VIT_FACTORIES[encoder_name](
+            num_leads=channel, seq_len=token_len, patch_size=patch_size,
+            lead_wise=lead_wise, patch_size_ch=patch_size_ch,
+            use_lead_embedding=use_lead_embedding,
+        )
+
+        d_model = self.backbone.width
+        if is_proj_head == 1:
+            self.proj_head = nn.Sequential(
+                nn.Linear(d_model, proj_hidden),
+                nn.LayerNorm(proj_hidden),
+                nn.ReLU(inplace=True),
+                nn.Linear(proj_hidden, proj_out),
+                nn.LayerNorm(proj_out),
+            )
+        else:
+            self.proj_head = None
+
+    def forward(self, x, normalize=True):
+        # x: [B, C, T]
+        h = self.backbone(x)  # [B, D]
+        if self.proj_head is not None:
+            h = self.proj_head(h)  # [B, proj_out]
+        if normalize:
+            return F.normalize(h, dim=-1)
+        return h
+
+
+class BasePretrainModel(LightningModule):
+    def __init__(self,
+                 psg_encoder_name: str = "vit_base",
+                 text_encoder_name: str = "google/flan-t5-base",
+                 fusion_decoder_name: str = 'cross-attn',
+                 shared_emb_dim: int = 256,
+                 lr: float = 2e-4,
+                 weight_decay: float = 0.2,
+                 training_steps_per_epoch: int = 7000,
+                 max_epochs: int = 100,
+                 *args, **kwargs):
+        super().__init__()
+        self.save_hyperparameters()
+        self.psg_encoder_name = psg_encoder_name
+        self.text_encoder_name = text_encoder_name
+        self.fusion_decoder_name = fusion_decoder_name
+        self.shared_emb_dim = shared_emb_dim
+        self.lr = lr
+        self.weight_decay = weight_decay
+        self.training_steps_per_epoch = training_steps_per_epoch
+        self.max_epochs = max_epochs
+        self.warmup_epochs = 0.1 * self.max_epochs
+        self.proj_out = shared_emb_dim
+        self.proj_hidden = 256
+
+        assert self.training_steps_per_epoch > 1
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            self.parameters(),
+            lr=self.lr,
+            weight_decay=self.weight_decay,
+            betas=(0.9, 0.95),
+        )
+
+        total_steps = int(self.training_steps_per_epoch * self.max_epochs)
+        warmup_steps = int(round(self.training_steps_per_epoch * self.warmup_epochs))
+        warmup_steps = max(0, warmup_steps)
+        decay_steps = max(1, total_steps - warmup_steps)
+
+        if warmup_steps > 0:
+            warmup = torch.optim.lr_scheduler.LinearLR(
+                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
+            cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
+                optimizer, T_max=decay_steps, eta_min=1e-8)
+            sched = torch.optim.lr_scheduler.SequentialLR(
+                optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])
+        else:
+            sched = torch.optim.lr_scheduler.CosineAnnealingLR(
+                optimizer, T_max=decay_steps, eta_min=1e-8)
+
+        return [optimizer], [{"scheduler": sched, "interval": "step", "frequency": 1}]
+
+    def training_step(self, batch, batch_idx):
+        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
+        for k, v in loss_dict.items():
+            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+        for k, v in metrics_dict.items():
+            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+        return loss_dict['loss']
+
+    def validation_step(self, batch, batch_idx):
+        with torch.no_grad():
+            loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
+        for k, v in loss_dict.items():
+            self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+        for k, v in metrics_dict.items():
+            self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+        return loss_dict
+
+    def test_step(self, batch, batch_idx):
+        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
+        for k, v in loss_dict.items():
+            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+        for k, v in metrics_dict.items():
+            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+        return loss_dict
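
`configure_optimizers` above chains a `LinearLR` warmup (start_factor 0.01, 10% of training) into a `CosineAnnealingLR` decay via `SequentialLR`. The resulting per-step multiplier of the base LR can be sketched in closed form — an approximation of what those PyTorch schedulers compute (with eta_min taken as ≈0), not the Lightning code itself:

```python
import math

def lr_multiplier(step, total_steps, warmup_steps, start_factor=0.01, eta_min_ratio=0.0):
    """Approximate multiplier of the base LR: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # LinearLR: ramp from start_factor up to 1.0 over warmup_steps
        return start_factor + (1.0 - start_factor) * step / warmup_steps
    # CosineAnnealingLR: half-cosine from 1.0 down to eta_min_ratio
    t = step - warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    cos = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
    return eta_min_ratio + (1.0 - eta_min_ratio) * cos

total, warmup = 700000, 70000   # e.g. 7000 steps/epoch x 100 epochs, 10% warmup
print(round(lr_multiplier(0, total, warmup), 2))       # 0.01
print(round(lr_multiplier(warmup, total, warmup), 2))  # 1.0
print(round(lr_multiplier(total, total, warmup), 2))   # 0.0
```

Scheduling by `"interval": "step"` matters here: both phases are defined in optimizer steps, so `training_steps_per_epoch` must match the real dataloader length for the warmup fraction to be correct.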
osf/models/base_pretrain_model_cls.py ADDED
@@ -0,0 +1,56 @@
+import torch.nn as nn
+from osf.backbone.vit1d_cls import vit_nano, vit_tiny, vit_small, vit_middle, vit_base, vit_large, vit_xl
+
+_VIT_CLS_FACTORIES = {
+    "vit_nano": vit_nano,
+    "vit_tiny": vit_tiny,
+    "vit_small": vit_small,
+    "vit_middle": vit_middle,
+    "vit_base": vit_base,
+    "vit_large": vit_large,
+    "vit_xl": vit_xl,
+}
+
+
+class PSGModalityEncoderCLS(nn.Module):
+    """
+    Init helper for ViT with CLS token. No forward() - access .backbone directly.
+
+    Used by DINO to initialize encoder, then DINO accesses self.encoders["all"].backbone.
+    """
+    def __init__(self, *,
+                 encoder_name: str,
+                 proj_out: int = 256,
+                 proj_hidden: int = 512,
+                 freq: int = 64,
+                 win_sec: int = 30,
+                 channel: int = 12,
+                 lead_wise=0,
+                 patch_size=40,
+                 patch_size_ch=4,
+                 is_proj_head=1,
+                 ):
+        super().__init__()
+        token_len = freq * win_sec
+
+        self.token_len = token_len
+        self.patch_size = patch_size
+
+        if encoder_name not in _VIT_CLS_FACTORIES:
+            raise ValueError(f"Unknown encoder_name for CLS variant: {encoder_name}")
+        self.backbone = _VIT_CLS_FACTORIES[encoder_name](
+            num_leads=channel, seq_len=token_len, patch_size=patch_size,
+            lead_wise=lead_wise, patch_size_ch=patch_size_ch,
+        )
+
+        d_model = self.backbone.width
+        if is_proj_head == 1:
+            self.proj_head = nn.Sequential(
+                nn.Linear(d_model, proj_hidden),
+                nn.LayerNorm(proj_hidden),
+                nn.ReLU(inplace=True),
+                nn.Linear(proj_hidden, proj_out),
+                nn.LayerNorm(proj_out),
+            )
+        else:
+            self.proj_head = None
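
With the defaults above (`freq=64`, `win_sec=30`, `patch_size=40`), `token_len` is 1920 samples, which divides evenly into 48 time patches per lead. A quick arithmetic check — assuming, as is typical for 1D ViTs, that patches are non-overlapping; the `vit1d_cls` internals are not shown in this diff, so treat this as a sketch rather than the repo's actual tokenization:

```python
freq, win_sec = 64, 30
token_len = freq * win_sec          # samples per 30 s epoch at 64 Hz
patch_size = 40                     # samples per (assumed non-overlapping) time patch

assert token_len % patch_size == 0  # the sequence must divide evenly into patches
patches_per_lead = token_len // patch_size
print(token_len, patches_per_lead)  # 1920 48
```

If `patch_size` were changed, this divisibility check is the first thing to re-verify, since a remainder would truncate or break the patch embedding.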
osf/models/dino_model_cls.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import copy
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional
+
+ from osf.models.dino_utils.dino_clstoken_loss import DINOLoss
+ from osf.models.dino_utils.ibot_patch_loss import iBOTPatchLoss
+ from osf.models.dino_utils.koleo_loss import KoLeoLoss
+ from osf.models.base_pretrain_model import BasePretrainModel
+ from osf.models.base_pretrain_model_cls import PSGModalityEncoderCLS
+ from osf.datasets.simclr_aug_registry import build_simclr_augmentor
+
+
+ class DINOHead(nn.Module):
+     def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=256, nlayers=3):
+         super().__init__()
+         num_layers = max(nlayers, 1)
+         if num_layers == 1:
+             self.mlp = nn.Sequential(nn.Linear(in_dim, bottleneck_dim))
+         else:
+             layers = [nn.Linear(in_dim, hidden_dim), nn.GELU()]
+             for _ in range(num_layers - 2):
+                 layers += [nn.Linear(hidden_dim, hidden_dim), nn.GELU()]
+             layers += [nn.Linear(hidden_dim, bottleneck_dim)]
+             self.mlp = nn.Sequential(*layers)
+
+         self.apply(self._init_weights)
+         self.prototypes = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+         self.prototypes.weight_g.data.fill_(1.0)
+
+     @staticmethod
+     def _init_weights(m):
+         if isinstance(m, nn.Linear):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.zeros_(m.bias)
+
+     def forward(self, x):
+         x = self.mlp(x)
+         x = F.normalize(x, dim=-1)
+         return self.prototypes(x)
+
+
+ class DINOCLSModel(BasePretrainModel):
+     def __init__(
+         self,
+         psg_encoder_name: str = "vit_base",
+         text_encoder_name: Optional[str] = None,
+         shared_emb_dim: int = 768,
+         out_dim: int = 2048,
+         patch_out_dim: int = 2048,
+         dino_out_dim: int = None,
+         dino_patch_out_dim: int = None,
+         dino_hidden_dim: int = 2048,
+         dino_bottleneck_dim: int = 256,
+         student_temp: float = 0.1,
+         teacher_temp_warmup: float = 0.04,
+         teacher_temp_final: float = 0.07,
+         teacher_temp_warmup_iters: int = 10000,
+         base_momentum: float = 0.996,
+         use_koleo: bool = True,
+         koleo_lambda: float = 0.0,
+         ibot_lambda: float = 0.0,
+         lr: float = 2e-4,
+         weight_decay: float = 0.2,
+         num_freeze_layers: int = 6,
+         simclr_augmentation: dict | None = None,
+         n_local_crops: int = 2,
+         *args, **kwargs
+     ):
+         super().__init__(
+             psg_encoder_name=psg_encoder_name,
+             text_encoder_name=None,
+             shared_emb_dim=shared_emb_dim,
+             lr=lr,
+             weight_decay=weight_decay,
+             *args, **kwargs
+         )
+         self.save_hyperparameters()
+
+         self.proj_out = shared_emb_dim
+         self.proj_hidden = 256
+         self.num_freeze_layers = num_freeze_layers
+
+         num_leads = kwargs.get('num_leads', 12)
+         self.num_leads = num_leads
+
+         self.cfg = [dict(name="all", freq=64, win_sec=30, in_ch=num_leads)]
+         self.encoders = nn.ModuleDict()
+         for mod in self.cfg:
+             self.encoders[mod["name"]] = PSGModalityEncoderCLS(
+                 encoder_name=psg_encoder_name,
+                 proj_out=shared_emb_dim,
+                 proj_hidden=256,
+                 freq=mod["freq"],
+                 win_sec=mod["win_sec"],
+                 channel=mod["in_ch"],
+                 patch_size=kwargs['patch_size_time'],
+                 lead_wise=kwargs['lead_wise'],
+                 patch_size_ch=(num_leads if kwargs['lead_wise'] == 0 else kwargs['patch_size_ch']),
+                 is_proj_head=0,
+             )
+         self.lead_wise = kwargs['lead_wise']
+         self.patch_size_time = kwargs['patch_size_time']
+         self.patch_size_ch = (num_leads if self.lead_wise == 0 else kwargs['patch_size_ch'])
+         trunk_dim = self.encoders['all'].backbone.width
+         out_dim = dino_out_dim if dino_out_dim is not None else out_dim
+         patch_out_dim = dino_patch_out_dim if dino_patch_out_dim is not None else patch_out_dim
+         self.out_dim = out_dim
+         self.patch_out_dim = patch_out_dim
+
+         self.student_global_head = DINOHead(trunk_dim, out_dim, dino_hidden_dim, dino_bottleneck_dim, 3)
+         self.student_patch_head = DINOHead(trunk_dim, patch_out_dim, dino_hidden_dim, dino_bottleneck_dim, 3)
+         self.teacher_encoder = copy.deepcopy(self.encoders["all"])
+         for p in self.teacher_encoder.parameters():
+             p.requires_grad = False
+         self.teacher_encoder.eval()
+
+         self.teacher_global_head = DINOHead(trunk_dim, out_dim, dino_hidden_dim, dino_bottleneck_dim, 3)
+         self.teacher_patch_head = DINOHead(trunk_dim, patch_out_dim, dino_hidden_dim, dino_bottleneck_dim, 3)
+         self.teacher_global_head.load_state_dict(self.student_global_head.state_dict(), strict=True)
+         self.teacher_patch_head.load_state_dict(self.student_patch_head.state_dict(), strict=True)
+         for p in self.teacher_global_head.parameters():
+             p.requires_grad = False
+         for p in self.teacher_patch_head.parameters():
+             p.requires_grad = False
+         self.teacher_global_head.eval()
+         self.teacher_patch_head.eval()
+         self.dino_loss = DINOLoss(out_dim=out_dim, student_temp=student_temp, center_momentum=0.9)
+         self.ibot_loss = iBOTPatchLoss(patch_out_dim=patch_out_dim, student_temp=student_temp, center_momentum=0.9)
+         self.koleo = KoLeoLoss() if use_koleo else None
+         self.koleo_lambda = float(koleo_lambda)
+         self.ibot_lambda = float(ibot_lambda)
+         self.teacher_temp_warmup = float(teacher_temp_warmup)
+         self.teacher_temp_final = float(teacher_temp_final)
+         self.teacher_temp_warmup_iters = int(teacher_temp_warmup_iters)
+         self.base_momentum = float(base_momentum)
+
+         self.register_buffer("seen_steps", torch.tensor(0, dtype=torch.long))
+
+         if simclr_augmentation is None:
+             simclr_augmentation = {}
+         self.simclr_augmentation = simclr_augmentation
+         self.augmentor = build_simclr_augmentor(self.simclr_augmentation)
+         self.n_local_crops = int(n_local_crops)
+
+         self.mask_token = nn.Parameter(torch.zeros(1, 1, trunk_dim))
+         nn.init.trunc_normal_(self.mask_token, std=0.02)
+
+     def _teacher_temp(self, step: int) -> float:
+         if step < self.teacher_temp_warmup_iters:
+             alpha = step / float(max(1, self.teacher_temp_warmup_iters))
+             return self.teacher_temp_warmup * (1 - alpha) + self.teacher_temp_final * alpha
+         return self.teacher_temp_final
+
+     def _momentum(self, step: int, max_steps: int) -> float:
+         return 1.0 - (1.0 - self.base_momentum) * (math.cos(math.pi * step / max_steps) + 1) / 2
+
+     @torch.no_grad()
+     def _ema_update(self, m: float):
+         for param_q, param_k in zip(self.encoders['all'].parameters(), self.teacher_encoder.parameters()):
+             param_k.data.mul_(m).add_(param_q.data, alpha=1.0 - m)
+         for param_q, param_k in zip(self.student_global_head.parameters(), self.teacher_global_head.parameters()):
+             param_k.data.mul_(m).add_(param_q.data, alpha=1.0 - m)
+         for param_q, param_k in zip(self.student_patch_head.parameters(), self.teacher_patch_head.parameters()):
+             param_k.data.mul_(m).add_(param_q.data, alpha=1.0 - m)
+         self.teacher_encoder.eval()
+         self.teacher_global_head.eval()
+         self.teacher_patch_head.eval()
+
+     def _forward_encoder(self, encoder, x, return_tokens=True):
+         # x: [B, C, T]
+         if return_tokens:
+             cls, patches = encoder.backbone.forward_encoding(x, return_sequence=False)
+             return cls, patches  # [B, D], [B, N, D]
+         else:
+             cls = encoder.backbone(x)
+             return cls, None  # [B, D], None
+
+     def _make_views_aug(self, x: torch.Tensor):
+         v1, v2 = self.augmentor(x)
+         globals_x = [v1, v2]
+         locals_x = []
+         for _ in range(self.n_local_crops):
+             lv1, _ = self.augmentor(x)
+             locals_x.append(lv1)
+         return globals_x, locals_x
+
+     def shared_step(self, batch, batch_idx):
+         x = batch["psg"]
+         globals_x, locals_x = self._make_views_aug(x)
+         tt = self._teacher_temp(int(self.global_step))
+
+         with torch.no_grad():
+             teacher_out_soft_list = []
+             teacher_global_logits_cache = []
+             teacher_patch_logits_cache = []
+
+             if len(globals_x) > 0:
+                 g_sizes = [gx.size(0) for gx in globals_x]
+                 g_cat = torch.cat(globals_x, dim=0)
+                 cls_t_cat, _ = self._forward_encoder(self.teacher_encoder, g_cat, return_tokens=True)
+                 g_logits_cat = self.teacher_global_head(cls_t_cat)
+                 g_logits_split = list(torch.split(g_logits_cat, g_sizes, dim=0))
+                 teacher_out_soft_list = [self.dino_loss.softmax_center_teacher(gl, tt) for gl in g_logits_split]
+                 teacher_global_logits_cache = g_logits_split
+
+         student_global_logits = []
+         student_cls_tokens = []
+         all_student_views = globals_x + locals_x
+         if len(all_student_views) > 0:
+             s_sizes = [sx.size(0) for sx in all_student_views]
+             s_cat = torch.cat(all_student_views, dim=0)
+             cls_s_cat, _ = self._forward_encoder(self.encoders["all"], s_cat, return_tokens=False)
+             sg_logits_cat = self.student_global_head(cls_s_cat)
+             student_global_logits = list(torch.split(sg_logits_cat, s_sizes, dim=0))
+             student_cls_tokens = list(torch.split(cls_s_cat, s_sizes, dim=0))
+
+         ibot_loss_val = torch.tensor(0.0, device=x.device)
+         if len(globals_x) > 0:
+             with torch.no_grad():
+                 t_tokens, _ = self.teacher_encoder.backbone.to_tokens_2d(
+                     globals_x[0], patch_size_ch=self.patch_size_ch, patch_size_time=self.patch_size_time)
+                 B2 = t_tokens.size(0)
+                 cls_tok = self.teacher_encoder.backbone.cls_token.expand(B2, -1, -1)
+                 t_full = torch.cat([cls_tok, t_tokens], dim=1)
+                 pe_full = self.teacher_encoder.backbone.pos_embedding[:, :t_full.size(1), :].to(t_full.device)
+                 t_full = t_full + pe_full
+                 t_full = self.teacher_encoder.backbone._run_blocks(t_full)
+                 _, t_patches = t_full[:, 0], t_full[:, 1:]
+                 t_logits_all = self.teacher_patch_head(t_patches)
+                 t_soft = self.ibot_loss.softmax_center_teacher(t_logits_all, tt)
+
+             s_tokens, _ = self.encoders["all"].backbone.to_tokens_2d(
+                 globals_x[0], patch_size_ch=self.patch_size_ch, patch_size_time=self.patch_size_time)
+             B2, N, Dtok = s_tokens.shape
+
+             mask_ratio = float(getattr(self, "ibot_mask_ratio", 0.3))
+             n_mask = max(1, int(round(N * mask_ratio)))
+             rand = torch.rand(B2, N, device=x.device)
+             topk_idx = rand.topk(k=n_mask, dim=1, largest=True).indices
+             masks = torch.zeros(B2, N, dtype=torch.bool, device=x.device)
+             masks.scatter_(1, topk_idx, True)
+
+             s_tokens_masked = torch.where(
+                 masks.unsqueeze(-1),
+                 self.mask_token.expand_as(s_tokens),
+                 s_tokens
+             )
+
+             cls_tok_s = self.encoders["all"].backbone.cls_token.expand(B2, -1, -1)
+             s_full = torch.cat([cls_tok_s, s_tokens_masked], dim=1)
+             pe_full_s = self.encoders["all"].backbone.pos_embedding[:, :s_full.size(1), :].to(s_full.device)
+             s_full = s_full + pe_full_s
+             s_full = self.encoders["all"].backbone._run_blocks(s_full)
+             _, s_patches = s_full[:, 0], s_full[:, 1:]
+             s_logits_all = self.student_patch_head(s_patches)
+
+             ibot_loss_val = self.ibot_loss.forward_masked(
+                 student_patch_tokens_masked=s_logits_all[masks],
+                 teacher_patch_tokens_masked=t_soft[masks],
+                 student_masks_flat=masks,
+             )
+
+             with torch.no_grad():
+                 teacher_patch_logits_cache.append(t_logits_all)
+
+         dino_loss_val = self.dino_loss(student_global_logits, teacher_out_soft_list)
+         pair_norm = max(1, len(student_global_logits) * len(teacher_out_soft_list))
+         dino_loss_val = dino_loss_val / pair_norm
+         koleo_val = torch.tensor(0.0, device=x.device)
+         if self.koleo is not None and len(student_cls_tokens) > 0:
+             koleo_val = self.koleo(F.normalize(student_cls_tokens[0], dim=-1))
+
+         total_loss = dino_loss_val + self.ibot_lambda * ibot_loss_val + self.koleo_lambda * koleo_val
+
+         with torch.no_grad():
+             if self.training:
+                 if len(teacher_global_logits_cache) > 0:
+                     self.dino_loss.update_center(torch.cat(teacher_global_logits_cache, dim=0))
+                 if len(teacher_patch_logits_cache) > 0:
+                     self.ibot_loss.update_center(torch.cat(teacher_patch_logits_cache, dim=0))
+
+         metrics = {
+             "loss": total_loss,
+             "loss/dino": dino_loss_val,
+             "loss/ibot": ibot_loss_val,
+             "loss/koleo": koleo_val,
+             "sched/teacher_temp": torch.tensor(tt, device=x.device),
+         }
+         return {"loss": total_loss}, metrics
+
+     def training_step(self, batch, batch_idx):
+         loss_dict, metrics = self.shared_step(batch, batch_idx)
+         for k, v in metrics.items():
+             self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=(k == "loss"), sync_dist=True)
+         return loss_dict["loss"]
+
+     def on_train_batch_end(self, outputs, batch, batch_idx):
+         max_steps = max(1, getattr(self.trainer, "max_steps", getattr(self.trainer, "estimated_stepping_batches", 100000)))
+         m = self._momentum(int(self.global_step), max_steps)
+         self._ema_update(m)
+         self.log("sched/momentum", torch.tensor(m, device=self.device), on_step=True, prog_bar=False)
+
+     def validation_step(self, batch, batch_idx):
+         loss_dict, metrics = self.shared_step(batch, batch_idx)
+         for k, v in metrics.items():
+             self.log(f"val/{k}", v, on_step=True, on_epoch=True, prog_bar=(k == "loss"), sync_dist=True)
+         return loss_dict["loss"]
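The two schedules in `DINOCLSModel._teacher_temp` and `_momentum` above are pure arithmetic and can be sanity-checked without torch. A minimal standalone sketch (the constants mirror the constructor defaults; the function names here are illustrative, not part of the repository):

```python
import math

TEMP_WARMUP, TEMP_FINAL, WARMUP_ITERS = 0.04, 0.07, 10000
BASE_MOMENTUM = 0.996

def teacher_temp(step: int) -> float:
    # Linear warmup from TEMP_WARMUP to TEMP_FINAL, then held constant.
    if step < WARMUP_ITERS:
        alpha = step / float(max(1, WARMUP_ITERS))
        return TEMP_WARMUP * (1 - alpha) + TEMP_FINAL * alpha
    return TEMP_FINAL

def momentum(step: int, max_steps: int) -> float:
    # Cosine ramp of the EMA momentum from BASE_MOMENTUM at step 0 to 1.0
    # at the final step, matching DINOCLSModel._momentum.
    return 1.0 - (1.0 - BASE_MOMENTUM) * (math.cos(math.pi * step / max_steps) + 1) / 2

print(teacher_temp(0))                 # 0.04
print(teacher_temp(20000))             # 0.07
print(round(momentum(0, 100), 6))      # 0.996
print(momentum(100, 100))              # 1.0
```

So the teacher distribution starts sharp (low temperature) and softens during warmup, while the teacher weights track the student ever more slowly as training ends.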
osf/models/dino_utils/dino_clstoken_loss.py ADDED
@@ -0,0 +1,96 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ class DINOLoss(nn.Module):
+     def __init__(
+         self,
+         out_dim,
+         student_temp=0.1,
+         center_momentum=0.9,
+     ):
+         super().__init__()
+         self.student_temp = student_temp
+         self.center_momentum = center_momentum
+         self.register_buffer("center", torch.zeros(1, out_dim))
+         self.updated = True
+         self.reduce_handle = None
+         self.len_teacher_output = None
+         self.async_batch_center = None
+
+     @torch.no_grad()
+     def softmax_center_teacher(self, teacher_output, teacher_temp):
+         self.apply_center_update()
+         # teacher centering and sharpening
+         return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1)
+
+     @torch.no_grad()
+     def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
+         teacher_output = teacher_output.float()
+         world_size = dist.get_world_size() if dist.is_initialized() else 1
+         Q = torch.exp(teacher_output / teacher_temp).t()
+         B = Q.shape[1] * world_size
+         K = Q.shape[0]
+         sum_Q = torch.sum(Q)
+         if dist.is_initialized():
+             dist.all_reduce(sum_Q)
+         Q /= sum_Q
+
+         for it in range(n_iterations):
+             sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
+             if dist.is_initialized():
+                 dist.all_reduce(sum_of_rows)
+             Q /= sum_of_rows
+             Q /= K
+             Q /= torch.sum(Q, dim=0, keepdim=True)
+             Q /= B
+
+         Q *= B
+         return Q.t()
+
+     def forward(self, student_output_list, teacher_out_softmaxed_centered_list):
+         """
+         Cross-entropy between softmax outputs of the teacher and student networks.
+         """
+         # TODO: Use cross_entropy_distribution here
+         total_loss = 0
+         for s in student_output_list:
+             lsm = F.log_softmax(s / self.student_temp, dim=-1)
+             for t in teacher_out_softmaxed_centered_list:
+                 loss = torch.sum(t * lsm, dim=-1)
+                 total_loss -= loss.mean()
+         return total_loss
+
+     @torch.no_grad()
+     def update_center(self, teacher_output):
+         self.reduce_center_update(teacher_output)
+
+     @torch.no_grad()
+     def reduce_center_update(self, teacher_output):
+         self.updated = False
+         self.len_teacher_output = len(teacher_output)
+         self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
+         if dist.is_initialized():
+             self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)
+
+     @torch.no_grad()
+     def apply_center_update(self):
+         if self.updated is False:
+             world_size = dist.get_world_size() if dist.is_initialized() else 1
+
+             if self.reduce_handle is not None:
+                 self.reduce_handle.wait()
+             _t = self.async_batch_center / (self.len_teacher_output * world_size)
+
+             self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)
+
+             self.updated = True
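The centering buffer maintained by `update_center` / `apply_center_update` is just an exponential moving average of the mean teacher logit. The arithmetic can be checked with plain lists, assuming a single process (so no all-reduce); the helper name here is illustrative:

```python
def update_center(center, teacher_logits, momentum=0.9):
    # center: list of D floats; teacher_logits: list of B rows of D floats.
    # Returns momentum * center + (1 - momentum) * batch_mean, per dimension.
    dim = len(center)
    batch_mean = [sum(row[d] for row in teacher_logits) / len(teacher_logits)
                  for d in range(dim)]
    return [momentum * c + (1 - momentum) * m for c, m in zip(center, batch_mean)]

center = [0.0, 0.0]
center = update_center(center, [[1.0, 3.0], [3.0, 5.0]])
print([round(c, 6) for c in center])  # [0.2, 0.4]
```

With `center_momentum=0.9`, each batch nudges the center 10% of the way toward the current batch mean; subtracting this center before the teacher softmax is what prevents collapse onto a single prototype.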
osf/models/dino_utils/ibot_patch_loss.py ADDED
@@ -0,0 +1,134 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ from torch import nn
+
+ import logging
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ try:
+     from xformers.ops import cross_entropy
+
+     def lossfunc(t, s, temp):
+         s = s.float()
+         t = t.float()
+         if s.ndim == 2:
+             return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0)
+         elif s.ndim == 3:
+             return -cross_entropy(s, t, temp, bw_inplace=True)
+
+ except ImportError:
+
+     def lossfunc(t, s, temp):
+         return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1)
+
+
+ class iBOTPatchLoss(nn.Module):
+     def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
+         super().__init__()
+         self.student_temp = student_temp
+         self.center_momentum = center_momentum
+         self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
+         self.updated = True
+         self.reduce_handle = None
+         self.len_teacher_patch_tokens = None
+         self.async_batch_center = None
+
+     @torch.no_grad()
+     def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
+         self.apply_center_update()
+         return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)
+
+     @torch.no_grad()
+     def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
+         teacher_output = teacher_output.float()
+         # world_size = dist.get_world_size() if dist.is_initialized() else 1
+         Q = torch.exp(teacher_output / teacher_temp).t()
+         B = n_masked_patches_tensor
+         dist.all_reduce(B)
+         K = Q.shape[0]
+         sum_Q = torch.sum(Q)
+         if dist.is_initialized():
+             dist.all_reduce(sum_Q)
+         Q /= sum_Q
+
+         for it in range(n_iterations):
+             sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
+             if dist.is_initialized():
+                 dist.all_reduce(sum_of_rows)
+             Q /= sum_of_rows
+             Q /= K
+             Q /= torch.sum(Q, dim=0, keepdim=True)
+             Q /= B
+
+         Q *= B
+         return Q.t()
+
+     def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
+         """
+         Cross-entropy between softmax outputs of the teacher and student networks.
+         student_patch_tokens: (B, N, D) tensor
+         teacher_patch_tokens: (B, N, D) tensor
+         student_masks_flat: (B, N) tensor
+         """
+         t = teacher_patch_tokens
+         s = student_patch_tokens
+         loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
+         loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
+         return -loss.mean()
+
+     def forward_masked(
+         self,
+         student_patch_tokens_masked,
+         teacher_patch_tokens_masked,
+         student_masks_flat,
+         n_masked_patches=None,
+         masks_weight=None,
+     ):
+         t = teacher_patch_tokens_masked
+         s = student_patch_tokens_masked
+         # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
+         loss = lossfunc(t, s, self.student_temp)
+         if masks_weight is None:
+             masks_weight = (
+                 (1 / student_masks_flat.sum(-1).clamp(min=1.0))
+                 .unsqueeze(-1)
+                 .expand_as(student_masks_flat)[student_masks_flat]
+             )
+         if n_masked_patches is not None:
+             loss = loss[:n_masked_patches]
+         loss = loss * masks_weight
+         return -loss.sum() / student_masks_flat.shape[0]
+
+     @torch.no_grad()
+     def update_center(self, teacher_patch_tokens):
+         self.reduce_center_update(teacher_patch_tokens)
+
+     @torch.no_grad()
+     def reduce_center_update(self, teacher_patch_tokens):
+         self.updated = False
+         self.len_teacher_patch_tokens = len(teacher_patch_tokens)
+         self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
+         if dist.is_initialized():
+             self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)
+
+     @torch.no_grad()
+     def apply_center_update(self):
+         if self.updated is False:
+             world_size = dist.get_world_size() if dist.is_initialized() else 1
+
+             if self.reduce_handle is not None:
+                 self.reduce_handle.wait()
+             _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)
+
+             self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)
+
+             self.updated = True
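In `forward_masked`, the default `masks_weight` gives each masked patch a weight of one over the number of masked patches in its own sample, so every sample contributes equally regardless of how many patches were masked. A tensor-free sketch of that weighting (the function name is illustrative):

```python
def masked_patch_weights(masks):
    # masks: list of per-sample boolean lists, True = patch was masked.
    # Returns one weight per masked position in row-major order,
    # 1 / (#masked patches in that sample), mirroring the expand+index
    # in iBOTPatchLoss.forward_masked.
    weights = []
    for row in masks:
        n = max(sum(row), 1)  # clamp(min=1) guard for all-False rows
        weights.extend(1.0 / n for flag in row if flag)
    return weights

masks = [[True, True, False], [False, False, True]]
print(masked_patch_weights(masks))  # [0.5, 0.5, 1.0]
```

The weighted per-patch losses are then summed and divided by the batch size, matching `-loss.sum() / student_masks_flat.shape[0]`.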
osf/models/dino_utils/koleo_loss.py ADDED
@@ -0,0 +1,46 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import logging
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # import torch.distributed as dist
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ class KoLeoLoss(nn.Module):
+     """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search"""
+
+     def __init__(self):
+         super().__init__()
+         self.pdist = nn.PairwiseDistance(2, eps=1e-8)
+
+     def pairwise_NNs_inner(self, x):
+         """
+         Pairwise nearest neighbors for L2-normalized vectors.
+         Uses Torch rather than Faiss to remain on GPU.
+         """
+         dots = torch.mm(x, x.t())
+         n = x.shape[0]
+         dots.view(-1)[:: (n + 1)].fill_(-1)  # exclude self-matches on the diagonal
+         _, I = torch.max(dots, dim=1)  # noqa: E741
+         return I
+
+     def forward(self, student_output, eps=1e-8):
+         """
+         Args:
+             student_output (BxD): backbone output of student
+         """
+         with torch.cuda.amp.autocast(enabled=False):
+             student_output = F.normalize(student_output, eps=eps, p=2, dim=-1)
+             I = self.pairwise_NNs_inner(student_output)  # noqa: E741
+             distances = self.pdist(student_output, student_output[I])  # BxD, BxD -> B
+             loss = -torch.log(distances + eps).mean()
+         return loss
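The KoLeo regularizer rewards embeddings that spread out: for each vector it finds the nearest neighbour by maximum inner product (self excluded, valid for L2-normalized vectors) and penalizes small distances via `-log`. A tiny pure-Python sketch of the same computation, assuming pre-normalized 2-D inputs (the function name is illustrative):

```python
import math

def koleo(vectors, eps=1e-8):
    # vectors: list of L2-normalized lists of floats.
    # Nearest neighbour = max inner product with self excluded (-1 sentinel),
    # then mean of -log(nn_distance), as in KoLeoLoss.forward.
    total = 0.0
    for i, v in enumerate(vectors):
        dots = [sum(a * b for a, b in zip(v, w)) if j != i else -1.0
                for j, w in enumerate(vectors)]
        nn = dots.index(max(dots))
        dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(v, vectors[nn])))
        total += -math.log(dist + eps)
    return total / len(vectors)

# Spread-out points score lower (better) than nearly coincident ones.
spread = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]]
tight = [[1.0, 0.0], [0.999, 0.0447], [0.0, 1.0]]
print(koleo(spread) < koleo(tight))  # True
```

In the model above this is applied to the first student view's normalized CLS tokens, scaled by `koleo_lambda`.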
osf/models/ssl_finetuner.py ADDED
@@ -0,0 +1,568 @@
1
+ from typing import Tuple, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from pytorch_lightning import LightningModule
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from itertools import chain
9
+ from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC, ConfusionMatrix, CohenKappa, AveragePrecision, MetricCollection
10
+ from osf.models.balanced_losses import FocalLoss, BalancedSoftmax
11
+
12
+
13
+ def _create_pred_metrics(num_classes: int) -> MetricCollection:
14
+ """Create metrics that take preds (class indices) as input."""
15
+ metrics = {
16
+ "acc": Accuracy(task="multiclass", num_classes=num_classes, average="micro"),
17
+ "f1": F1Score(task="multiclass", num_classes=num_classes, average="macro"),
18
+ "f1_w": F1Score(task="multiclass", num_classes=num_classes, average="weighted"),
19
+ "rec_m": Recall(task="multiclass", num_classes=num_classes, average="macro"),
20
+ "kappa": CohenKappa(task="multiclass", num_classes=num_classes, weights="quadratic"),
21
+ }
22
+ return MetricCollection(metrics)
23
+
24
+
25
+ def _create_prob_metrics(num_classes: int) -> MetricCollection:
26
+ """Create metrics that take probs (probabilities) as input."""
27
+ metrics = {
28
+ "auc": AUROC(task="multiclass", num_classes=num_classes, average="macro"),
29
+ "auprc": AveragePrecision(task="multiclass", num_classes=num_classes, average="macro"),
30
+ }
31
+ return MetricCollection(metrics)
32
+
33
+
34
+ def _create_perclass_pred_metrics(num_classes: int) -> MetricCollection:
35
+ """Create per-class metrics that take preds as input."""
36
+ metrics = {
37
+ "acc_c": Accuracy(task="multiclass", num_classes=num_classes, average=None),
38
+ "prec_c": Precision(task="multiclass", num_classes=num_classes, average=None),
39
+ "rec_c": Recall(task="multiclass", num_classes=num_classes, average=None),
40
+ "f1_c": F1Score(task="multiclass", num_classes=num_classes, average=None),
41
+ "cm": ConfusionMatrix(task="multiclass", num_classes=num_classes, normalize=None),
42
+ }
43
+ return MetricCollection(metrics)
44
+
45
+
46
+ def _create_perclass_prob_metrics(num_classes: int) -> MetricCollection:
47
+ """Create per-class metrics that take probs as input."""
48
+ metrics = {
49
+ "auc_c": AUROC(task="multiclass", num_classes=num_classes, average=None),
50
+ "auprc_c": AveragePrecision(task="multiclass", num_classes=num_classes, average=None),
51
+ }
52
+ return MetricCollection(metrics)
53
+
54
+
55
+
56
+ class SSLFineTuner(LightningModule):
57
+ def __init__(self,
58
+ backbones,
59
+ use_which_backbone,
60
+ config = None,
61
+ in_features: int = 256,
62
+ num_classes: int = 2,
63
+ epochs: int = 10,
64
+ dropout: float = 0.0,
65
+ lr: float = 1e-3,
66
+ weight_decay: float = 1e-4,
67
+ final_lr: float = 1e-5,
68
+ use_channel_bank: bool = True,
69
+ loss_type: str = "ce",
70
+ class_distribution: Optional[torch.Tensor] = None,
71
+ focal_gamma: float = 2.0,
72
+ focal_alpha: Optional[float | torch.Tensor] = None,
73
+ use_mean_pool: bool = False,
74
+ total_training_steps: int = None,
75
+ finetune_backbone: bool = False,
76
+ *args, **kwargs
77
+ ) -> None:
78
+ super().__init__()
79
+ self.save_hyperparameters()
80
+ self.lr = lr
81
+ self.weight_decay = weight_decay
82
+ self.epochs = epochs
83
+ self.final_lr = final_lr
84
+ self.use_channel_bank = use_channel_bank
85
+ self.loss_type = loss_type
86
+ self.focal_gamma = focal_gamma
87
+ self.focal_alpha = focal_alpha
88
+ self.use_mean_pool = use_mean_pool
89
+ self.total_training_steps = total_training_steps
90
+ self.finetune_backbone = finetune_backbone
91
+
92
+ if loss_type == "ce":
93
+ self.criterion = None
94
+ elif loss_type == "focal":
95
+ alpha = focal_alpha
96
+ if alpha is None and class_distribution is not None:
97
+ class_dist = class_distribution.float()
98
+ total_samples = class_dist.sum()
99
+ alpha = total_samples / (num_classes * class_dist)
100
+ alpha = alpha / alpha.mean()
101
+ self.criterion = FocalLoss(alpha=alpha, gamma=focal_gamma, reduction="mean")
102
+ elif loss_type == "balanced_softmax":
103
+ self.criterion = BalancedSoftmax(class_distribution, reduction="mean")
104
+ else:
105
+ raise ValueError(f"Unknown loss_type: {loss_type}. Must be one of ['ce', 'focal', 'balanced_softmax']")
106
+
107
+ if isinstance(backbones, nn.ModuleDict):
108
+ self.backbones = backbones
109
+ else:
110
+ self.backbones = nn.ModuleDict(backbones)
111
+ self.config = config
112
+ self.use_which_backbone = use_which_backbone
113
+ self.backbone = self.backbones[self.use_which_backbone] if self.use_which_backbone != "fusion" else None
114
+
115
+
116
+ if self.use_which_backbone == "fusion":
117
+ for k in ("ecg", "resp", "elect"):
118
+ if k in self.backbones:
119
+ for p in self.backbones[k].parameters():
120
+ p.requires_grad = self.finetune_backbone
121
+ if not self.finetune_backbone:
122
+ self.backbones[k].eval()
123
+ else:
124
+ for p in self.backbone.parameters():
125
+ p.requires_grad = self.finetune_backbone
126
+ if not self.finetune_backbone:
127
+ self.backbone.eval()
128
+
129
+ if self.finetune_backbone:
130
+ print(f"[INFO] Full finetuning mode: backbone parameters are TRAINABLE")
131
+
132
+ if self.use_which_backbone == "fusion":
133
+ dims = [getattr(self.backbones[k], "out_dim", in_features)
134
+ for k in ("ecg", "resp", "elect") if k in self.backbones]
135
+ if len(dims) == 0:
136
+ raise ValueError("fusion requires at least one of {'ecg','resp','elect'} in backbones.")
137
+ if len(set(dims)) != 1:
138
+ raise ValueError(f"Mean fusion requires equal output dims, got {dims}")
139
+ final_in_features = dims[0]
140
+ else:
141
+ final_in_features = getattr(self.backbone, "out_dim", in_features)
142
+
143
+ self.linear_layer = nn.Sequential(
144
+ nn.Dropout(dropout),
145
+ nn.Linear(final_in_features, num_classes)
146
+ )
147
+
148
+ self.train_pred_metrics = _create_pred_metrics(num_classes)
149
+ self.val_pred_metrics = _create_pred_metrics(num_classes)
150
+ self.test_pred_metrics = _create_pred_metrics(num_classes)
151
+
152
+ self.train_prob_metrics = _create_prob_metrics(num_classes)
153
+ self.val_prob_metrics = _create_prob_metrics(num_classes)
154
+ self.test_prob_metrics = _create_prob_metrics(num_classes)
155
+
156
+ self.train_pred_metrics_c = _create_perclass_pred_metrics(num_classes)
157
+ self.val_pred_metrics_c = _create_perclass_pred_metrics(num_classes)
158
+         self.test_pred_metrics_c = _create_perclass_pred_metrics(num_classes)
+
+         self.train_prob_metrics_c = _create_perclass_prob_metrics(num_classes)
+         self.val_prob_metrics_c = _create_perclass_prob_metrics(num_classes)
+         self.test_prob_metrics_c = _create_perclass_prob_metrics(num_classes)
+
+         self.class_names = getattr(self.config, "class_names", [str(i) for i in range(num_classes)])
+
+     def on_train_epoch_start(self) -> None:
+         # Keep frozen backbones in eval mode (Lightning switches modules back to train()).
+         if not self.finetune_backbone:
+             if self.use_which_backbone == "fusion":
+                 for k in ("ecg", "resp", "elect"):
+                     if k in self.backbones:
+                         self.backbones[k].eval()
+             else:
+                 self.backbone.eval()
+
+     def training_step(self, batch, batch_idx):
+         loss, logits, y = self.shared_step(batch)
+         probs = logits.softmax(-1)
+         preds = logits.argmax(-1)
+
+         self.train_pred_metrics.update(preds, y)
+         self.train_prob_metrics.update(probs, y)
+         self.train_pred_metrics_c.update(preds, y)
+         self.train_prob_metrics_c.update(probs, y)
+
+         self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
+         return loss
+
+     def on_train_epoch_end(self):
+         pred_agg = self.train_pred_metrics.compute()
+         prob_agg = self.train_prob_metrics.compute()
+
+         self.log("train_acc", pred_agg["acc"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("train_f1", pred_agg["f1"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("train_auc", prob_agg["auc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("train_auprc", prob_agg["auprc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
+
+         pred_c = self.train_pred_metrics_c.compute()
+         prob_c = self.train_prob_metrics_c.compute()
+         cm = pred_c["cm"]
+         support = cm.sum(dim=1) if cm is not None else None
+
+         for i in range(len(pred_c["acc_c"])):
+             name = self.class_names[i] if i < len(self.class_names) else str(i)
+             self.log(f"train/acc_{name}", pred_c["acc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"train/prec_{name}", pred_c["prec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"train/rec_{name}", pred_c["rec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"train/f1_{name}", pred_c["f1_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"train/auc_{name}", prob_c["auc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"train/auprc_{name}", prob_c["auprc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             if support is not None:
+                 self.log(f"train/support_{name}", support[i].to(pred_c["acc_c"][i].dtype), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+
+         self.train_pred_metrics.reset()
+         self.train_prob_metrics.reset()
+         self.train_pred_metrics_c.reset()
+         self.train_prob_metrics_c.reset()
+
219
+     def validation_step(self, batch, batch_idx):
+         loss, logits, y = self.shared_step(batch)
+         probs = logits.softmax(-1)
+         preds = logits.argmax(-1)
+
+         self.val_pred_metrics.update(preds, y)
+         self.val_prob_metrics.update(probs, y)
+         self.val_pred_metrics_c.update(preds, y)
+         self.val_prob_metrics_c.update(probs, y)
+
+         self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+         return loss
+
+     def on_validation_epoch_end(self):
+         pred_agg = self.val_pred_metrics.compute()
+         prob_agg = self.val_prob_metrics.compute()
+
+         self.log("val_acc", pred_agg["acc"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("val_f1", pred_agg["f1"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("val_f1_w", pred_agg["f1_w"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("val_rec_m", pred_agg["rec_m"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("val_auc", prob_agg["auc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("val_auprc", prob_agg["auprc"], prog_bar=False, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("val_kappa", pred_agg["kappa"], prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+
+         pred_c = self.val_pred_metrics_c.compute()
+         prob_c = self.val_prob_metrics_c.compute()
+         cm = pred_c["cm"]
+         support = cm.sum(dim=1) if cm is not None else None  # guard added to match train/test
+
+         for i in range(len(pred_c["acc_c"])):
+             name = self.class_names[i] if i < len(self.class_names) else str(i)
+             self.log(f"val/acc_{name}", pred_c["acc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"val/prec_{name}", pred_c["prec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"val/rec_{name}", pred_c["rec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"val/f1_{name}", pred_c["f1_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"val/auc_{name}", prob_c["auc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"val/auprc_{name}", prob_c["auprc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             if support is not None:
+                 self.log(f"val/support_{name}", support[i].to(pred_c["acc_c"][i].dtype), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+
+         self.val_pred_metrics.reset()
+         self.val_prob_metrics.reset()
+         self.val_pred_metrics_c.reset()
+         self.val_prob_metrics_c.reset()
+
+     def test_step(self, batch, batch_idx):
+         loss, logits, y = self.shared_step(batch)
+         probs = logits.softmax(-1)
+         preds = logits.argmax(-1)
+
+         self.test_pred_metrics.update(preds, y)
+         self.test_prob_metrics.update(probs, y)
+         self.test_pred_metrics_c.update(preds, y)
+         self.test_prob_metrics_c.update(probs, y)
+
+         self.log("test_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
+         return loss
+
+     def on_test_epoch_end(self):
+         pred_agg = self.test_pred_metrics.compute()
+         prob_agg = self.test_prob_metrics.compute()
+
+         self.log("test_acc", pred_agg["acc"], on_step=False, on_epoch=True, sync_dist=True)
+         self.log("test_f1", pred_agg["f1"], on_step=False, on_epoch=True, sync_dist=True)
+         self.log("test_f1_w", pred_agg["f1_w"], on_step=False, on_epoch=True, sync_dist=True)
+         self.log("test_rec_m", pred_agg["rec_m"], on_step=False, on_epoch=True, sync_dist=True)
+         self.log("test_auc", prob_agg["auc"], on_step=False, on_epoch=True, sync_dist=True)
+         self.log("test_auprc", prob_agg["auprc"], on_step=False, on_epoch=True, sync_dist=True)
+         self.log("test_kappa", pred_agg["kappa"], on_step=False, on_epoch=True, sync_dist=True)
+
+         pred_c = self.test_pred_metrics_c.compute()
+         prob_c = self.test_prob_metrics_c.compute()
+         cm = pred_c["cm"]
+         support = cm.sum(dim=1) if cm is not None else None
+
+         for i in range(len(pred_c["acc_c"])):
+             name = self.class_names[i] if i < len(self.class_names) else str(i)
+             self.log(f"test/acc_{name}", pred_c["acc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"test/prec_{name}", pred_c["prec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"test/rec_{name}", pred_c["rec_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"test/f1_{name}", pred_c["f1_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"test/auc_{name}", prob_c["auc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             self.log(f"test/auprc_{name}", prob_c["auprc_c"][i], on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+             if support is not None:
+                 self.log(f"test/support_{name}", support[i].to(pred_c["acc_c"][i].dtype), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+
+         self.test_pred_metrics.reset()
+         self.test_prob_metrics.reset()
+         self.test_pred_metrics_c.reset()
+         self.test_prob_metrics_c.reset()
310
+
+     def shared_step(self, batch):
+         context = torch.no_grad() if not self.finetune_backbone else torch.enable_grad()
+
+         with context:
+             psg = batch['psg']
+             # PSG channel layout: 0 -> ECG, 1:5 -> respiratory, 5: -> electrophysiology
+             if self.use_which_backbone == 'ecg':
+                 x = psg[:, 0:1, :]
+                 feats = self._get_features(self.backbone, x)
+             elif self.use_which_backbone == 'resp':
+                 x = psg[:, 1:5, :]
+                 feats = self._get_features(self.backbone, x)
+             elif self.use_which_backbone == 'elect':
+                 x = psg[:, 5:, :]
+                 feats = self._get_features(self.backbone, x)
+             elif self.use_which_backbone == 'all':
+                 x = psg
+                 feats = self._get_features(self.backbone, x)
+             elif self.use_which_backbone == 'fusion':
+                 feats_list = []
+                 if 'ecg' in self.backbones:
+                     x_ecg = psg[:, 0:1, :]
+                     f_ecg = self._get_features(self.backbones['ecg'], x_ecg)
+                     feats_list.append(f_ecg)
+                 if 'resp' in self.backbones:
+                     x_resp = psg[:, 1:5, :]
+                     f_resp = self._get_features(self.backbones['resp'], x_resp)
+                     feats_list.append(f_resp)
+                 if 'elect' in self.backbones:
+                     x_elect = psg[:, 5:, :]
+                     f_elect = self._get_features(self.backbones['elect'], x_elect)
+                     feats_list.append(f_elect)
+
+                 feats = torch.stack(feats_list, dim=0).mean(dim=0)
+             else:
+                 raise ValueError(f"Unknown use_which_backbone: {self.use_which_backbone}")
+
+         # The head runs outside the (possibly) no-grad context so it always trains.
+         y = batch["label"]
+         feats = feats.view(feats.size(0), -1)
+         logits = self.linear_layer(feats)
+         y = y.squeeze(1).long()
+
+         if self.criterion is None:
+             loss = F.cross_entropy(logits, y)
+         else:
+             loss = self.criterion(logits, y)
+
+         return loss, logits, y
+
+     def _get_features(self, backbone, x):
+         """Get features from backbone. Uses mean pooling if use_mean_pool=True."""
+         if self.use_mean_pool:
+             if hasattr(backbone, 'forward_encoding_mean_pool'):
+                 return backbone.forward_encoding_mean_pool(x)
+             elif hasattr(backbone, 'forward_avg_pool'):
+                 return backbone.forward_avg_pool(x)
+         return backbone(x)
+
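As a sanity check on the fusion branch above, the element-wise averaging of per-modality embeddings can be sketched in plain Python; the feature vectors below are made-up stand-ins for backbone outputs, not real ECG/respiratory/electrophysiology features.

```python
# Pure-Python mirror of torch.stack(feats_list, dim=0).mean(dim=0)
# over per-modality embedding vectors of equal length.
def fuse_mean(feats_list):
    n = len(feats_list)
    return [sum(col) / n for col in zip(*feats_list)]

ecg = [1.0, 2.0, 3.0]     # hypothetical ECG embedding
resp = [3.0, 2.0, 1.0]    # hypothetical respiratory embedding
elect = [2.0, 2.0, 2.0]   # hypothetical electrophysiology embedding
print(fuse_mean([ecg, resp, elect]))  # [2.0, 2.0, 2.0]
```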
372
+     def configure_optimizers(self):
+         if self.finetune_backbone:
+             if self.use_which_backbone == "fusion":
+                 backbone_params = chain(*[self.backbones[k].parameters()
+                                           for k in ("ecg", "resp", "elect") if k in self.backbones])
+             else:
+                 backbone_params = self.backbone.parameters()
+             params = chain(backbone_params, self.linear_layer.parameters())
+         else:
+             params = self.linear_layer.parameters()
+
+         optimizer = torch.optim.AdamW(
+             params,
+             lr=self.lr,
+             weight_decay=self.weight_decay,
+         )
+
+         if self.total_training_steps is not None and self.total_training_steps > 0:
+             warmup_steps = int(0.1 * self.total_training_steps)
+             cosine_steps = self.total_training_steps - warmup_steps
+
+             warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
+                 optimizer,
+                 start_factor=0.1,
+                 end_factor=1.0,
+                 total_iters=warmup_steps
+             )
+             cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                 optimizer,
+                 T_max=cosine_steps,
+                 eta_min=self.final_lr
+             )
+             scheduler = torch.optim.lr_scheduler.SequentialLR(
+                 optimizer,
+                 schedulers=[warmup_scheduler, cosine_scheduler],
+                 milestones=[warmup_steps]
+             )
+             return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
+         else:
+             return [optimizer]
+
+
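The schedule configured in `configure_optimizers` is a 10% linear warmup followed by cosine annealing to `final_lr`. A pure-Python approximation of that curve (illustrative `base_lr`/`final_lr`/`total_steps` values; PyTorch's `LinearLR` interpolates per optimizer step, so edge behavior may differ slightly):

```python
import math

def lr_at(step, base_lr, final_lr, total_steps):
    warmup = int(0.1 * total_steps)
    if step < warmup:
        # linear ramp of the factor from 0.1 to 1.0 over the warmup steps
        frac = step / warmup
        return base_lr * (0.1 + 0.9 * frac)
    # cosine annealing over the remaining steps, ending at final_lr
    t = (step - warmup) / max(1, total_steps - warmup)
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * t))

print(lr_at(0, 1e-3, 1e-5, 1000))     # 1e-4 (10% of base at step 0)
print(lr_at(1000, 1e-3, 1e-5, 1000))  # 1e-5 (fully annealed)
```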
415
+ class SSLVitalSignsRegressor(SSLFineTuner):
+     """SSL Finetuner for vital signs regression (HR, SPO2). Uses MSE loss."""
+     def __init__(self,
+                  backbones,
+                  use_which_backbone,
+                  config=None,
+                  in_features: int = 256,
+                  num_classes: int = 1,
+                  target_names: list = None,
+                  dropout: float = 0.0,
+                  **kwargs
+                  ) -> None:
+         # The parent expects a classification setup; initialize it with
+         # placeholder values (loss_type='ce', num_classes=2), then swap in
+         # the regression head and MSE criterion below.
+         kwargs['loss_type'] = 'ce'
+
+         super().__init__(
+             backbones=backbones,
+             use_which_backbone=use_which_backbone,
+             config=config,
+             in_features=in_features,
+             num_classes=2,
+             dropout=dropout,
+             **kwargs
+         )
+
+         self.num_targets = num_classes
+         self.target_names = target_names or [f"target_{i}" for i in range(num_classes)]
+         self.criterion = nn.MSELoss()
+
+         # Rebuild the head with `num_classes` regression outputs.
+         in_feat = self.linear_layer[1].in_features
+         self.linear_layer = nn.Sequential(
+             nn.Dropout(dropout),
+             nn.Linear(in_feat, num_classes)
+         )
+
+         # Classification metrics do not apply to regression; drop them.
+         del self.train_pred_metrics, self.val_pred_metrics, self.test_pred_metrics
+         del self.train_prob_metrics, self.val_prob_metrics, self.test_prob_metrics
+         del self.train_pred_metrics_c, self.val_pred_metrics_c, self.test_pred_metrics_c
+         del self.train_prob_metrics_c, self.val_prob_metrics_c, self.test_prob_metrics_c
+
+     def shared_step(self, batch):
+         """Override: regression loss instead of classification."""
+         context = torch.no_grad() if not self.finetune_backbone else torch.enable_grad()
+
+         with context:
+             psg = batch['psg']
+             if self.use_which_backbone == 'ecg':
+                 x = psg[:, 0:1, :]
+                 feats = self._get_features(self.backbone, x)
+             elif self.use_which_backbone == 'resp':
+                 x = psg[:, 1:5, :]
+                 feats = self._get_features(self.backbone, x)
+             elif self.use_which_backbone == 'elect':
+                 x = psg[:, 5:, :]
+                 feats = self._get_features(self.backbone, x)
+             elif self.use_which_backbone == 'all':
+                 x = psg
+                 feats = self._get_features(self.backbone, x)
+             elif self.use_which_backbone == 'fusion':
+                 feats_list = []
+                 if 'ecg' in self.backbones:
+                     f_ecg = self._get_features(self.backbones['ecg'], psg[:, 0:1, :])
+                     feats_list.append(f_ecg)
+                 if 'resp' in self.backbones:
+                     f_resp = self._get_features(self.backbones['resp'], psg[:, 1:5, :])
+                     feats_list.append(f_resp)
+                 if 'elect' in self.backbones:
+                     f_elect = self._get_features(self.backbones['elect'], psg[:, 5:, :])
+                     feats_list.append(f_elect)
+
+                 feats = torch.stack(feats_list, dim=0).mean(dim=0)
+             else:
+                 raise ValueError(f"Unknown use_which_backbone: {self.use_which_backbone}")
+
+         y = batch["label"].float()  # [B, num_targets]
+         feats = feats.view(feats.size(0), -1)
+         preds = self.linear_layer(feats)  # [B, num_targets]
+
+         loss = self.criterion(preds, y)
+         return loss, preds, y
+
495
+     def training_step(self, batch, batch_idx):
+         """Override: regression metrics."""
+         loss, preds, y = self.shared_step(batch)
+
+         with torch.no_grad():
+             for i, name in enumerate(self.target_names):
+                 mae = F.l1_loss(preds[:, i], y[:, i])
+                 self.log(f"train_{name}_mae", mae, on_step=False, on_epoch=True, sync_dist=True)
+
+         self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
+         return loss
+
+     def on_train_epoch_end(self):
+         """Override: no classification metrics to compute."""
+         pass
+
+     def validation_step(self, batch, batch_idx):
+         """Override: regression metrics."""
+         loss, preds, y = self.shared_step(batch)
+
+         for i, name in enumerate(self.target_names):
+             mae = F.l1_loss(preds[:, i], y[:, i])
+             self.log(f"val_{name}_mae", mae, on_step=False, on_epoch=True, sync_dist=True)
+
+         overall_mae = F.l1_loss(preds, y)
+         self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("val_mae", overall_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+         return loss
+
+     def on_validation_epoch_end(self):
+         """Override: no classification metrics to compute."""
+         pass
+
+     def test_step(self, batch, batch_idx):
+         """Override: regression metrics."""
+         loss, preds, y = self.shared_step(batch)
+
+         for i, name in enumerate(self.target_names):
+             p, t = preds[:, i], y[:, i]
+             mae = F.l1_loss(p, t)
+             mse = F.mse_loss(p, t)
+             rmse = torch.sqrt(mse)
+
+             self.log(f"test_{name}_mae", mae, on_step=False, on_epoch=True, sync_dist=True)
+             self.log(f"test_{name}_mse", mse, on_step=False, on_epoch=True, sync_dist=True)
+             self.log(f"test_{name}_rmse", rmse, on_step=False, on_epoch=True, sync_dist=True)
+
+         overall_mae = F.l1_loss(preds, y)
+         overall_mse = F.mse_loss(preds, y)
+         self.log("test_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("test_mae", overall_mae, on_step=False, on_epoch=True, sync_dist=True)
+         self.log("test_mse", overall_mse, on_step=False, on_epoch=True, sync_dist=True)
+         return loss
+
+     def on_test_epoch_end(self):
+         """Override: no classification metrics to compute."""
+         pass
+
+
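The per-target metrics logged in `test_step` reduce to standard MAE/MSE/RMSE; a pure-Python sketch with made-up heart-rate predictions (illustrative values only):

```python
import math

# MAE and MSE over paired prediction/target lists, as logged per target.
def mae(p, t): return sum(abs(a - b) for a, b in zip(p, t)) / len(p)
def mse(p, t): return sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)

preds, targets = [60.0, 72.0, 80.0], [62.0, 70.0, 80.0]  # hypothetical HR values
print(mae(preds, targets))             # 1.333...
print(math.sqrt(mse(preds, targets)))  # RMSE, 1.632...
```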
554
+ class SupervisedVitalSignsRegressor(SSLVitalSignsRegressor):
+     """Supervised from-scratch regression. Equivalent to SSLVitalSignsRegressor with finetune_backbone=True."""
+     def __init__(self,
+                  backbones,
+                  use_which_backbone,
+                  epochs: int = 100,
+                  **kwargs
+                  ):
+         kwargs['finetune_backbone'] = True
+         super().__init__(
+             backbones=backbones,
+             use_which_backbone=use_which_backbone,
+             epochs=epochs,
+             **kwargs
+         )
osf/utils/openclip_loss.py ADDED
@@ -0,0 +1,472 @@
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ try:
+     import torch.distributed.nn
+     from torch import distributed as dist
+
+     has_distributed = True
+ except ImportError:
+     has_distributed = False
+
+ try:
+     import horovod.torch as hvd
+ except ImportError:
+     hvd = None
+
+
+ def get_clip_metrics(image_features, text_features, logit_scale):
+     metrics = {}
+     logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
+     logits_per_text = logits_per_image.t().detach().cpu()
+
+     logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
+     ground_truth = torch.arange(len(text_features)).view(-1, 1)
+
+     for name, logit in logits.items():
+         ranking = torch.argsort(logit, descending=True)
+         preds = torch.where(ranking == ground_truth)[1]
+         preds = preds.detach().cpu().numpy()
+         metrics[f"{name}_mean_rank"] = preds.mean() + 1
+         metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
+         for k in [1, 5, 10]:
+             metrics[f"{name}_R@{k}"] = np.mean(preds < k)
+
+     return metrics
+
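The rank bookkeeping in `get_clip_metrics` can be shown without tensors: for each query, find the 0-based rank of its matching item (index i in row i) in the descending similarity order, then R@k is the fraction of ranks below k. The similarity rows here are illustrative, not model outputs.

```python
# Rank of the correct item for each query, given raw similarity rows.
def retrieval_ranks(sim_rows):
    ranks = []
    for i, row in enumerate(sim_rows):
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        ranks.append(order.index(i))  # 0-based rank of the matching item
    return ranks

rows = [[0.9, 0.1, 0.2],   # correct item ranked 1st
        [0.3, 0.2, 0.8],   # correct item ranked last
        [0.1, 0.2, 0.9]]   # correct item ranked 1st
ranks = retrieval_ranks(rows)
r_at_1 = sum(r < 1 for r in ranks) / len(ranks)
print(ranks, r_at_1)  # [0, 2, 0] 0.666...
```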
41
+
+ def gather_features(
+         image_features,
+         text_features,
+         local_loss=False,
+         gather_with_grad=False,
+         rank=0,
+         world_size=1,
+         use_horovod=False
+ ):
+     assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
+     if use_horovod:
+         assert hvd is not None, 'Please install horovod'
+         if gather_with_grad:
+             all_image_features = hvd.allgather(image_features)
+             all_text_features = hvd.allgather(text_features)
+         else:
+             with torch.no_grad():
+                 all_image_features = hvd.allgather(image_features)
+                 all_text_features = hvd.allgather(text_features)
+             if not local_loss:
+                 # ensure grads for local rank when all_* features don't have a gradient
+                 gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
+                 gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
+                 gathered_image_features[rank] = image_features
+                 gathered_text_features[rank] = text_features
+                 all_image_features = torch.cat(gathered_image_features, dim=0)
+                 all_text_features = torch.cat(gathered_text_features, dim=0)
+     else:
+         if gather_with_grad:
+             all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
+             all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+         else:
+             gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
+             gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
+             dist.all_gather(gathered_image_features, image_features)
+             dist.all_gather(gathered_text_features, text_features)
+             if not local_loss:
+                 # ensure grads for local rank when all_* features don't have a gradient
+                 gathered_image_features[rank] = image_features
+                 gathered_text_features[rank] = text_features
+                 all_image_features = torch.cat(gathered_image_features, dim=0)
+                 all_text_features = torch.cat(gathered_text_features, dim=0)
+
+     return all_image_features, all_text_features
+
87
+
+ class ClipLoss(nn.Module):
+
+     def __init__(
+             self,
+             local_loss=True,
+             gather_with_grad=True,
+             cache_labels=True,
+             rank=0,
+             world_size=1,
+             use_horovod=False,
+     ):
+         super().__init__()
+         self.local_loss = local_loss
+         self.gather_with_grad = gather_with_grad
+         self.cache_labels = cache_labels
+         self.rank = rank
+         self.world_size = world_size
+         self.use_horovod = use_horovod
+
+         # cache state
+         self.prev_num_logits = 0
+         self.labels = {}
+
+     def get_ground_truth(self, device, num_logits) -> torch.Tensor:
+         if self.prev_num_logits != num_logits or device not in self.labels:
+             labels = torch.arange(num_logits, device=device, dtype=torch.long)
+             if self.world_size > 1 and self.local_loss:
+                 labels = labels + num_logits * self.rank
+             if self.cache_labels:
+                 self.labels[device] = labels
+                 self.prev_num_logits = num_logits
+         else:
+             labels = self.labels[device]
+         return labels
+
+     def get_logits(self, image_features, text_features, logit_scale, return_gather_features=False):
+         if self.world_size > 1:
+             all_image_features, all_text_features = gather_features(
+                 image_features,
+                 text_features,
+                 local_loss=self.local_loss,
+                 gather_with_grad=self.gather_with_grad,
+                 rank=self.rank,
+                 world_size=self.world_size,
+                 use_horovod=self.use_horovod,
+             )
+
+             if self.local_loss:
+                 logits_per_image = logit_scale * image_features @ all_text_features.T
+                 logits_per_text = logit_scale * text_features @ all_image_features.T
+             else:
+                 logits_per_image = logit_scale * all_image_features @ all_text_features.T
+                 logits_per_text = logits_per_image.T
+
+             if return_gather_features:
+                 return logits_per_image, logits_per_text, all_image_features, all_text_features
+             else:
+                 return logits_per_image, logits_per_text
+
+         else:
+             logits_per_image = logit_scale * image_features @ text_features.T
+             logits_per_text = logit_scale * text_features @ image_features.T
+
+             return logits_per_image, logits_per_text
+
+     def forward(self, image_features, text_features, logit_scale, output_dict=False):
+         device = image_features.device
+         logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
+
+         labels = self.get_ground_truth(device, logits_per_image.shape[0])
+
+         total_loss = (
+             F.cross_entropy(logits_per_image, labels) +
+             F.cross_entropy(logits_per_text, labels)
+         ) / 2
+
+         return {"contrastive_loss": total_loss} if output_dict else total_loss
+
+
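With `local_loss=True` and `world_size > 1`, `get_ground_truth` offsets each rank's targets into the globally gathered batch, since the local B queries are scored against all `world_size * B` gathered candidates. A minimal sketch of that label layout (pure Python, illustrative sizes):

```python
# Each rank's positives sit at columns [B*rank, B*rank + B) of the gathered axis.
def local_labels(batch_size, rank):
    return [i + batch_size * rank for i in range(batch_size)]

print(local_labels(4, 0))  # [0, 1, 2, 3]
print(local_labels(4, 2))  # [8, 9, 10, 11]
```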
166
+ class CoCaLoss(ClipLoss):
+     def __init__(
+             self,
+             caption_loss_weight,
+             clip_loss_weight,
+             pad_id=0,
+             local_loss=False,
+             gather_with_grad=False,
+             cache_labels=False,
+             rank=0,
+             world_size=1,
+             use_horovod=False,
+     ):
+         super().__init__(
+             local_loss=local_loss,
+             gather_with_grad=gather_with_grad,
+             cache_labels=cache_labels,
+             rank=rank,
+             world_size=world_size,
+             use_horovod=use_horovod
+         )
+
+         self.clip_loss_weight = clip_loss_weight
+         self.caption_loss_weight = caption_loss_weight
+         self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
+
+     def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
+         if self.clip_loss_weight:
+             clip_loss = super().forward(image_features, text_features, logit_scale)
+             clip_loss = self.clip_loss_weight * clip_loss
+         else:
+             clip_loss = torch.tensor(0, device=logits.device)
+
+         caption_loss = self.caption_loss(
+             logits.permute(0, 2, 1),
+             labels,
+         )
+         caption_loss = caption_loss * self.caption_loss_weight
+
+         if output_dict:
+             return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
+
+         return clip_loss, caption_loss
+
210
+
+ class DistillClipLoss(ClipLoss):
+
+     def dist_loss(self, teacher_logits, student_logits):
+         return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
+
+     def forward(
+             self,
+             image_features,
+             text_features,
+             logit_scale,
+             dist_image_features,
+             dist_text_features,
+             dist_logit_scale,
+             output_dict=False,
+     ):
+         logits_per_image, logits_per_text = \
+             self.get_logits(image_features, text_features, logit_scale)
+
+         dist_logits_per_image, dist_logits_per_text = \
+             self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
+
+         labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
+
+         contrastive_loss = (
+             F.cross_entropy(logits_per_image, labels) +
+             F.cross_entropy(logits_per_text, labels)
+         ) / 2
+
+         distill_loss = (
+             self.dist_loss(dist_logits_per_image, logits_per_image) +
+             self.dist_loss(dist_logits_per_text, logits_per_text)
+         ) / 2
+
+         if output_dict:
+             return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
+
+         return contrastive_loss, distill_loss
+
+
250
+ def neighbour_exchange(from_rank, to_rank, tensor, group=None):
+     tensor_recv = torch.zeros_like(tensor)
+     send_op = torch.distributed.P2POp(
+         torch.distributed.isend,
+         tensor,
+         to_rank,
+         group=group,
+     )
+     recv_op = torch.distributed.P2POp(
+         torch.distributed.irecv,
+         tensor_recv,
+         from_rank,
+         group=group,
+     )
+     reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
+     for req in reqs:
+         req.wait()
+     return tensor_recv
+
+
+ def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
+     tensor_from_left = torch.zeros_like(tensor_to_right)
+     tensor_from_right = torch.zeros_like(tensor_to_left)
+     send_op_left = torch.distributed.P2POp(
+         torch.distributed.isend,
+         tensor_to_left,
+         left_rank,
+         group=group,
+     )
+     send_op_right = torch.distributed.P2POp(
+         torch.distributed.isend,
+         tensor_to_right,
+         right_rank,
+         group=group,
+     )
+     recv_op_left = torch.distributed.P2POp(
+         torch.distributed.irecv,
+         tensor_from_left,
+         left_rank,
+         group=group,
+     )
+     recv_op_right = torch.distributed.P2POp(
+         torch.distributed.irecv,
+         tensor_from_right,
+         right_rank,
+         group=group,
+     )
+     reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left])
+     for req in reqs:
+         req.wait()
+     return tensor_from_right, tensor_from_left
+
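These exchange helpers move tensors around a ring: each rank talks only to `(rank ± 1) mod world_size`. The bidirectional variant used by `SigLipLoss` needs `floor((world_size - 1) / 2)` paired rounds plus one extra one-way shift when `world_size` is even, which is exactly the `divmod` in its forward pass. A pure-Python sketch of that plan:

```python
# Neighbor ranks and exchange-round counts for the bidir SigLIP schedule.
def ring_plan(rank, world_size):
    right = (rank + 1) % world_size
    left = (rank - 1 + world_size) % world_size
    num_bidir, remainder = divmod(world_size - 1, 2)
    return left, right, num_bidir, remainder

print(ring_plan(0, 4))  # (3, 1, 1, 1): one bidir round plus one extra shift
```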
302
+
+ class NeighbourExchange(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, from_rank, to_rank, group, tensor):
+         ctx.group = group
+         ctx.from_rank = from_rank
+         ctx.to_rank = to_rank
+         return neighbour_exchange(from_rank, to_rank, tensor, group=group)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)
+
+
+ def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
+     return NeighbourExchange.apply(from_rank, to_rank, group, tensor)
+
+
+ class NeighbourExchangeBidir(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
+         ctx.group = group
+         ctx.left_rank = left_rank
+         ctx.right_rank = right_rank
+         return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)
+
+     @staticmethod
+     def backward(ctx, *grad_outputs):
+         return (None, None, None) + \
+             NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)
+
+
+ def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
+     return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
+
+
338
+ class SigLipLoss(nn.Module):
+     """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
+
+     @article{zhai2023sigmoid,
+         title={Sigmoid loss for language image pre-training},
+         author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
+         journal={arXiv preprint arXiv:2303.15343},
+         year={2023}
+     }
+     """
+     def __init__(
+             self,
+             cache_labels: bool = False,
+             rank: int = 0,
+             world_size: int = 1,
+             dist_impl: Optional[str] = None,
+     ):
+         super().__init__()
+         self.cache_labels = cache_labels
+         self.rank = rank
+         self.world_size = world_size
+         self.dist_impl = dist_impl or 'bidir'  # default to bidir exchange for now, this will likely change
+         assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')
+
+         # FIXME: cache not currently used
+         self.prev_num_logits = 0
+         self.labels = {}
+
+     def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
+         labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
+         if not negative_only:
+             labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
+         return labels
+
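`get_ground_truth` builds a matrix of -1s with +1 on the diagonal (matched pairs); with `negative_only=True`, used for text features gathered from other ranks, every entry is -1. The same layout in plain Python:

```python
# SigLIP label matrix: +1 for matched (i == j) pairs, -1 otherwise;
# negative_only marks an entire block as negatives (cross-rank features).
def siglip_labels(n, negative_only=False):
    return [[1 if (i == j and not negative_only) else -1
             for j in range(n)] for i in range(n)]

print(siglip_labels(3))
# [[1, -1, -1], [-1, 1, -1], [-1, -1, 1]]
```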
372
+ def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
373
+ logits = logit_scale * image_features @ text_features.T
374
+ if logit_bias is not None:
375
+ logits += logit_bias
376
+ return logits
377
+
378
+ def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
379
+ logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
380
+ labels = self.get_ground_truth(
381
+ image_features.device,
382
+ image_features.dtype,
383
+ image_features.shape[0],
384
+ negative_only=negative_only,
385
+ )
386
+ loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
387
+ return loss
388
+
389
+    def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
+        loss = self._loss(image_features, text_features, logit_scale, logit_bias)
+
+        if self.world_size > 1:
+            if self.dist_impl == 'bidir':
+                right_rank = (self.rank + 1) % self.world_size
+                left_rank = (self.rank - 1 + self.world_size) % self.world_size
+                text_features_to_right = text_features_to_left = text_features
+                num_bidir, remainder = divmod(self.world_size - 1, 2)
+                for i in range(num_bidir):
+                    text_features_recv = neighbour_exchange_bidir_with_grad(
+                        left_rank,
+                        right_rank,
+                        text_features_to_left,
+                        text_features_to_right,
+                    )
+                    for f in text_features_recv:
+                        loss += self._loss(
+                            image_features,
+                            f,
+                            logit_scale,
+                            logit_bias,
+                            negative_only=True,
+                        )
+                    text_features_to_left, text_features_to_right = text_features_recv
+
+                if remainder:
+                    text_features_recv = neighbour_exchange_with_grad(
+                        left_rank,
+                        right_rank,
+                        text_features_to_right,
+                    )
+                    loss += self._loss(
+                        image_features,
+                        text_features_recv,
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+            elif self.dist_impl == "shift":
+                right_rank = (self.rank + 1) % self.world_size
+                left_rank = (self.rank - 1 + self.world_size) % self.world_size
+                text_features_to_right = text_features
+                for i in range(self.world_size - 1):
+                    text_features_from_left = neighbour_exchange_with_grad(
+                        left_rank,
+                        right_rank,
+                        text_features_to_right,
+                    )
+                    loss += self._loss(
+                        image_features,
+                        text_features_from_left,
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+                    text_features_to_right = text_features_from_left
+            elif self.dist_impl == "reduce":
+                for i in range(self.world_size):
+                    text_from_other = torch.distributed.nn.all_reduce(
+                        text_features * (self.rank == i),
+                        torch.distributed.ReduceOp.SUM,
+                    )
+                    loss += float(i != self.rank) * self._loss(
+                        image_features,
+                        text_from_other,
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+            elif self.dist_impl == "gather":
+                all_text = torch.distributed.nn.all_gather(text_features)
+                for i in range(self.world_size):
+                    loss += float(i != self.rank) * self._loss(
+                        image_features,
+                        all_text[i],
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+            else:
+                raise ValueError(f"Unknown dist_impl: {self.dist_impl}")
+
+        return {"contrastive_loss": loss} if output_dict else loss
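All four distributed branches above implement the same idea: each rank pairs its own images with every other rank's text batch as extra negatives, differing only in how the text batches travel between ranks. A minimal single-process simulation of the `"shift"` ring pattern (plain Python, no `torch.distributed`; `simulate_shift` is a hypothetical helper for illustration) shows why `world_size - 1` shifts suffice:

```python
def simulate_shift(world_size):
    """Simulate the 'shift' ring exchange: after world_size - 1 steps,
    every rank has seen every rank's text batch exactly once."""
    # texts[r] is the batch currently held at rank r; start with own batch
    texts = list(range(world_size))
    seen = {r: [r] for r in range(world_size)}  # own batch gives the positives
    for _ in range(world_size - 1):
        # each rank forwards its current batch to the right neighbour,
        # i.e. receives the batch previously held by its left neighbour
        texts = [texts[(r - 1) % world_size] for r in range(world_size)]
        for r in range(world_size):
            seen[r].append(texts[r])  # contributes a negative-only loss term
    return seen

seen = simulate_shift(4)
assert all(sorted(v) == [0, 1, 2, 3] for v in seen.values())
```

The `"bidir"` variant halves the number of exchange rounds by sending in both directions at once, while `"gather"` trades communication rounds for memory by materialising all text batches on every rank.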
osf/utils/results_utils.py ADDED
@@ -0,0 +1,289 @@
+"""
+Utility functions for saving experiment results to JSON/CSV.
+"""
+
+import os
+import json
+import glob
+import numpy as np
+import pandas as pd
+from typing import Dict, Any, Optional, List
+
+
+def convert_to_serializable(value):
+    """Convert tensor/numpy values to Python native types for JSON serialization."""
+    if hasattr(value, 'item'):  # torch.Tensor
+        return float(value.item())
+    elif isinstance(value, (np.ndarray, np.generic)):
+        return float(value)
+    return value
+
+
+def extract_embedding_type(root_dir: str) -> str:
+    """
+    Extract embedding type identifier from root_dir path.
+
+    Examples:
+        ".../dino_stage1_emb_no_norm" -> "dino_no_norm"
+        ".../dino_stage1_emb" -> "dino"
+        ".../mae_emb_normalized" -> "mae_normalized"
+
+    Args:
+        root_dir: Path to embedding directory
+
+    Returns:
+        Short embedding type identifier
+    """
+    if not root_dir:
+        return "unknown"
+
+    basename = os.path.basename(root_dir.rstrip('/'))
+
+    # Remove common suffixes/patterns
+    emb_type = basename
+    emb_type = emb_type.replace("_stage1_emb", "")
+    emb_type = emb_type.replace("_stage1", "")
+    emb_type = emb_type.replace("_emb", "")
+    emb_type = emb_type.replace("final_", "")
+
+    # Keep it concise
+    if len(emb_type) > 30:
+        emb_type = emb_type[:30]
+
+    return emb_type if emb_type else "emb"
+
+
+def format_lr(lr: float) -> str:
+    """Format learning rate for filenames (e.g., 0.001 -> 1e-3)."""
+    if lr >= 1:
+        return f"{lr:.0f}"
+    elif lr >= 0.1:
+        return f"{lr:.1f}"
+    else:
+        # Convert to scientific notation
+        exp = 0
+        val = lr
+        while val < 1:
+            val *= 10
+            exp += 1
+        return f"{val:.0f}e-{exp}"
+
+
+def save_results_to_json(
+    test_metrics: Dict[str, Any],
+    hparams: Any,
+    extension: str,
+    ckpt_dir: str,
+    timestamp: str,
+    results_dir: str = "./results",
+    extra_fields: Optional[Dict[str, Any]] = None,
+    filename_prefix: str = "",
+) -> str:
+    """
+    Save test results to a JSON file.
+
+    Args:
+        test_metrics: Dictionary of test metrics from trainer.test()
+        hparams: Hyperparameters namespace/object
+        extension: Experiment extension string (now used as run_name)
+        ckpt_dir: Checkpoint directory path
+        timestamp: Timestamp string
+        results_dir: Directory to save results (default: ./results)
+        extra_fields: Additional fields to include in the result record
+            (should include: exp_type, task, dataset, model, etc.)
+        filename_prefix: Prefix for the filename
+
+    Returns:
+        Path to the saved JSON file
+    """
+    os.makedirs(results_dir, exist_ok=True)
+
+    # Base result record
+    result_record = {
+        "run_name": extension,  # Renamed from "extension" for clarity
+        "ckpt_dir": ckpt_dir,
+        "timestamp": timestamp,
+    }
+
+    common_fields = [
+        "model_name", "downstream_dataset_name", "ckpt_path", "stage2_ckpt_path",
+        "eval_label", "patient_cols", "use_which_backbone", "variant",
+        "in_features", "train_data_pct", "lr", "batch_size",
+        "max_epochs", "max_steps", "loss_type", "use_mean_pool",
+        "root_dir", "is_pretrain", "pooling", "use_transformer", "use_mil",
+        "encoder_name", "encoder", "mask_channels", "encoder_size",
+        "num_classes", "seed",
+    ]
+    for field in common_fields:
+        if hasattr(hparams, field):
+            result_record[field] = getattr(hparams, field)
+
+    standard_metrics = [
+        "test_acc", "test_f1", "test_f1_w", "test_auc", "test_auprc",
+        "test_kappa", "test_rec_m", "test_loss",
+        "test/acc", "test/f1_macro", "test/auc_macro", "test/auprc_macro",
+    ]
+    for metric in standard_metrics:
+        if metric in test_metrics:
+            key = metric.replace("/", "_")
+            result_record[key] = test_metrics[metric]
+
+    for key, value in test_metrics.items():
+        if key.startswith("test/") or key.startswith("test_"):
+            normalized_key = key.replace("/", "_")
+            if normalized_key not in result_record:
+                result_record[normalized_key] = value
+
+    if extra_fields:
+        result_record.update(extra_fields)
+
+    for key, value in result_record.items():
+        result_record[key] = convert_to_serializable(value)
+
+    if filename_prefix:
+        result_filename = f"{filename_prefix}_{timestamp}.json"
+    else:
+        model_name = getattr(hparams, 'model_name', 'model')
+        dataset_name = getattr(hparams, 'downstream_dataset_name', 'dataset')
+        label = getattr(hparams, 'eval_label', None) or getattr(hparams, 'patient_cols', 'task')
+        result_filename = f"{model_name}_{dataset_name}_{label}_{timestamp}.json"
+
+    result_path = os.path.join(results_dir, result_filename)
+
+    # Save to JSON
+    with open(result_path, 'w') as f:
+        json.dump(result_record, f, indent=2)
+
+    print(f"\n{'='*80}")
+    print(f"Results saved to: {result_path}")
+    print(f"{'='*80}\n")
+
+    return result_path
+
+
+def aggregate_results_to_csv(
+    results_dirs: List[str],
+    output_path: str = "./results/aggregated_results.csv",
+    key_columns: Optional[List[str]] = None,
+    metric_columns: Optional[List[str]] = None,
+) -> pd.DataFrame:
+    """
+    Aggregate all JSON result files from multiple directories into a single CSV.
+
+    Args:
+        results_dirs: List of directories containing JSON result files
+        output_path: Path to save the aggregated CSV
+        key_columns: Columns to use as identifiers (default: common experiment params)
+        metric_columns: Metric columns to include (default: all test metrics)
+
+    Returns:
+        DataFrame with aggregated results
+    """
+    if key_columns is None:
+        key_columns = [
+            "exp_type", "task", "dataset", "model", "encoder",
+            "train_data_pct", "lr", "embedding_type",
+            "pretrain_ckpt_path", "finetuned_ckpt_dir", "trained_ckpt_dir",
+            "stage2_pretrain_ckpt", "embedding_root_dir",
+            "model_name", "downstream_dataset_name", "eval_label", "patient_cols",
+            "use_which_backbone", "variant", "loss_type",
+            "use_mean_pool", "pooling", "use_transformer", "use_mil",
+            "mask_channels", "mask_channels_str",
+            "ckpt_path", "stage2_ckpt_path", "root_dir",
+        ]
+
+    if metric_columns is None:
+        metric_columns = [
+            "test_acc", "test_f1", "test_f1_w", "test_auc", "test_auprc",
+            "test_kappa", "test_rec_m", "test_loss",
+        ]
+
+    all_records = []
+
+    for results_dir in results_dirs:
+        if not os.path.exists(results_dir):
+            print(f"[WARN] Directory not found: {results_dir}")
+            continue
+
+        # Find all JSON files
+        json_files = glob.glob(os.path.join(results_dir, "*.json"))
+        print(f"[INFO] Found {len(json_files)} JSON files in {results_dir}")
+
+        for json_file in json_files:
+            try:
+                with open(json_file, 'r') as f:
+                    record = json.load(f)
+                record['_source_file'] = os.path.basename(json_file)
+                record['_source_dir'] = results_dir
+                all_records.append(record)
+            except Exception as e:
+                print(f"[WARN] Failed to load {json_file}: {e}")
+
+    if not all_records:
+        print("[WARN] No records found!")
+        return pd.DataFrame()
+
+    # Convert to DataFrame
+    df = pd.DataFrame(all_records)
+
+    existing_key_cols = [c for c in key_columns if c in df.columns]
+    existing_metric_cols = [c for c in metric_columns if c in df.columns]
+
+    per_class_cols = [c for c in df.columns if c.startswith("test_") and c not in existing_metric_cols]
+    per_class_cols = sorted(per_class_cols)
+
+    other_cols = [c for c in df.columns if c not in existing_key_cols + existing_metric_cols + per_class_cols]
+
+    ordered_cols = existing_key_cols + existing_metric_cols + per_class_cols + other_cols
+    df = df[[c for c in ordered_cols if c in df.columns]]
+
+    os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True)
+    df.to_csv(output_path, index=False)
+
+    print(f"\n{'='*80}")
+    print(f"Aggregated {len(all_records)} results to: {output_path}")
+    print(f"Columns: {list(df.columns[:10])}... ({len(df.columns)} total)")
+    print(f"{'='*80}\n")
+
+    return df
+
+
+def load_results_from_json(json_path: str) -> Dict[str, Any]:
+    """Load a single JSON result file."""
+    with open(json_path, 'r') as f:
+        return json.load(f)
+
+
+def filter_results(
+    df: pd.DataFrame,
+    model_name: Optional[str] = None,
+    dataset_name: Optional[str] = None,
+    eval_label: Optional[str] = None,
+    patient_cols: Optional[str] = None,
+) -> pd.DataFrame:
+    """
+    Filter aggregated results DataFrame by common fields.
+
+    Args:
+        df: DataFrame from aggregate_results_to_csv()
+        model_name: Filter by model name
+        dataset_name: Filter by downstream dataset name
+        eval_label: Filter by eval label (stage 1)
+        patient_cols: Filter by patient columns (stage 2)
+
+    Returns:
+        Filtered DataFrame
+    """
+    filtered = df.copy()
+
+    if model_name is not None and 'model_name' in filtered.columns:
+        filtered = filtered[filtered['model_name'] == model_name]
+    if dataset_name is not None and 'downstream_dataset_name' in filtered.columns:
+        filtered = filtered[filtered['downstream_dataset_name'] == dataset_name]
+    if eval_label is not None and 'eval_label' in filtered.columns:
+        filtered = filtered[filtered['eval_label'] == eval_label]
+    if patient_cols is not None and 'patient_cols' in filtered.columns:
+        filtered = filtered[filtered['patient_cols'] == patient_cols]
+
+    return filtered
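The key convention in `save_results_to_json` is that metric keys logged as `test/...` are normalized to `test_...` so the aggregated CSV gets flat column names. A stdlib-only sketch of that record round trip (hypothetical metric values, no pandas needed):

```python
import json
import os
import tempfile

# Metrics as a Lightning-style logger might report them (illustrative values)
test_metrics = {"test/acc": 0.91, "test/f1_macro": 0.88, "test_loss": 0.34}

# Normalize "test/..." keys to "test_..." as save_results_to_json does
record = {"run_name": "dino_shhs_demo"}
for key, value in test_metrics.items():
    record[key.replace("/", "_")] = value

# Write and reload the JSON record, mirroring the save/aggregate round trip
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "demo_result.json")
    with open(path, "w") as f:
        json.dump(record, f, indent=2)
    with open(path) as f:
        loaded = json.load(f)

assert loaded["test_f1_macro"] == 0.88
assert "test/acc" not in loaded and loaded["test_acc"] == 0.91
```

Flattening the keys up front means `aggregate_results_to_csv` can treat every record as a row without any further renaming.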
osf_backbone.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c51190b1942556969af3c3d63c2e59430ddb1ea0377c50ea87df83712fc31857
+size 341360652
pretrained_weights/readme.md ADDED
@@ -0,0 +1 @@
+Please download the checkpoint through the link in the README at the repository root.
requirements.txt ADDED
+absl-py==2.3.1
+accelerate==1.2.1
+aiohappyeyeballs==2.4.4
+aiohttp==3.11.10
+aiosignal==1.3.1
+albucore==0.0.24
+albumentations==2.0.8
+altair==5.5.0
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+asttokens==3.0.0
+async-timeout==5.0.1
+attrs==25.3.0
+beartype==0.22.2
+bitarray==3.0.0
+blinker==1.9.0
+braceexpand==0.1.7
+certifi==2025.10.5
+cffi==1.17.1
+charset-normalizer==3.4.3
+click==8.1.7
+coloredlogs==15.0.1
+comm==0.2.3
+contourpy==1.3.1
+cosine_annealing_warmup @ git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup@12d03c07553aedd3d9e9155e2b3e31ce8c64081a
+cycler==0.12.1
+Cython==3.0.11
+datasets==3.2.0
+debugpy==1.8.17
+decorator==5.2.1
+diffusers==0.32.1
+dill==0.3.8
+docker-pycreds==0.4.0
+easydict==1.13
+efficientnet_pytorch==0.7.1
+einops==0.8.0
+ema-pytorch==0.7.7
+et_xmlfile==2.0.0
+exceptiongroup==1.3.0
+executing==2.2.1
+fairseq_signals_backbone @ git+https://github.com/fuying-wang/fairseq-signals@27d94bab8a1040879c011609df1488aac21a586a
+filelock==3.20.0
+flatbuffers==25.9.23
+fonttools==4.55.1
+frozenlist==1.5.0
+fsspec==2024.6.1
+gitdb==4.0.11
+GitPython==3.1.43
+grpcio==1.75.1
+h5py==3.14.0
+hf-xet==1.1.10
+huggingface-hub==0.34.4
+humanfriendly==10.0
+hydra-core==1.3.2
+idna==3.10
+imageio==2.37.0
+importlib_metadata==8.7.0
+insightface==0.7.3
+ipdb==0.13.13
+ipykernel==6.30.1
+ipython==8.37.0
+jedi==0.19.2
+Jinja2==3.1.6
+joblib==1.4.2
+jupyter_client==8.6.3
+jupyter_core==5.8.1
+kiwisolver==1.4.7
+kornia==0.8.1
+kornia_rs==0.1.9
+lazy_loader==0.4
+lightning==2.4.0
+lightning-utilities==0.11.9
+llvmlite==0.46.0
+loguru==0.7.3
+lxml==5.3.0
+Markdown==3.9
+MarkupSafe==3.0.3
+matplotlib==3.9.3
+matplotlib-inline==0.1.7
+ml_dtypes==0.5.3
+mne==1.10.1
+mpmath==1.3.0
+multidict==6.1.0
+multiprocess==0.70.16
+munch==4.0.0
+narwhals==2.6.0
+nest-asyncio==1.6.0
+networkx==3.4.2
+neurokit2==0.2.12
+ninja==1.13.0
+nltk==3.9.1
+numba==0.63.1
+numpy==2.1.2
+omegaconf==2.3.0
+onnx==1.19.1
+onnx2torch==1.5.15
+onnxruntime==1.23.1
+opencv-python==4.12.0.88
+opencv-python-headless==4.12.0.88
+openpyxl==3.1.5
+packaging==24.2
+pandas==2.2.3
+parso==0.8.5
+peft==0.14.0
+pexpect==4.9.0
+pillow==11.0.0
+platformdirs==4.4.0
+pooch==1.8.2
+portalocker==3.0.0
+POT==0.9.5
+pretrainedmodels==0.7.4
+prettytable==3.16.0
+prompt_toolkit==3.0.52
+propcache==0.2.1
+protobuf==5.29.1
+psutil==7.1.0
+ptyprocess==0.7.0
+pure_eval==0.2.3
+pyarrow==18.1.0
+pycparser==2.23
+pydantic==2.10.3
+pydantic_core==2.27.1
+pydeck==0.9.1
+Pygments==2.19.2
+pynndescent==0.5.13
+pyparsing==3.2.0
+pysam==0.23.3
+python-dateutil==2.9.0.post0
+pytorch-lightning==2.4.0
+pytorch-warmup==0.2.0
+pytz==2024.2
+PyWavelets==1.8.0
+PyYAML==6.0.3
+pyzmq==27.1.0
+regex==2024.11.6
+requests==2.32.5
+sacrebleu==2.4.3
+safetensors==0.6.2
+scikit-image==0.25.2
+scikit-learn==1.7.2
+scipy==1.14.1
+seaborn==0.13.2
+segmentation_models_pytorch==0.4.0
+sentencepiece==0.2.1
+sentry-sdk==2.19.2
+setproctitle==1.3.4
+simsimd==6.5.3
+six==1.17.0
+smmap==5.0.1
+soundfile==0.12.1
+stack-data==0.6.3
+streamlit==1.50.0
+stringzilla==4.2.1
+sympy==1.13.1
+tabulate==0.9.0
+tenacity==9.1.2
+tensorboard==2.20.0
+tensorboard-data-server==0.7.2
+tensorboardX==2.6.4
+threadpoolctl==3.5.0
+tifffile==2025.5.10
+timm==1.0.12
+tokenizers==0.21.0
+toml==0.10.2
+tomli==2.2.1
+torch==2.5.1
+torchaudio==2.5.1
+torchdiffeq==0.2.5
+torchmetrics==1.6.0
+torchtools @ git+https://github.com/pabloppp/pytorch-tools@610158d5016d6418aee27f956e7afd17ff35ba04
+torchvision==0.20.1
+tornado==6.5.2
+tqdm==4.67.1
+traitlets==5.14.3
+transformers==4.47.0
+typing-inspection==0.4.1
+typing_extensions==4.15.0
+tzdata==2024.2
+umap-learn==0.5.9.post2
+unet==0.8.1
+urllib3==2.5.0
+vitaldb==1.5.8
+wandb==0.21.4
+warmup_scheduler==0.3
+watchdog==6.0.0
+wcwidth==0.2.13
+webdataset==1.0.2
+Werkzeug==3.1.3
+wfdb==4.1.2
+xgboost==2.1.3
+xxhash==3.5.0
+yarl==1.18.3
+zipp==3.23.0
train_config.py ADDED
@@ -0,0 +1,32 @@
+from config import *
+
+
+# Uni-encoder models (simclr, dino, mae, vqvae, ar, etc.)
+TRAIN_EDF_COLS_UNI_ENC = [ECG, EMG_Chin, EMG_LLeg, EMG_RLeg,
+                          ABD, THX, NP, SN,
+                          EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1,
+                          ]
+TRAIN_EDF_COLS_MULTI_ENC = [ECG,
+                            ABD, THX, NP, SN,
+                            EMG_Chin, EMG_LLeg, EMG_RLeg,
+                            EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1,
+                            ]
+TRAIN_EDF_COLS_TYPE3 = [ECG, ABD, THX, NP, SN]
+TRAIN_EDF_COLS_TYPE4 = [ECG, ABD, THX]
+
+
+MONITOR_TYPE_MAP = {
+    "main": TRAIN_EDF_COLS_UNI_ENC,
+    "type3": TRAIN_EDF_COLS_TYPE3,
+    "type4": TRAIN_EDF_COLS_TYPE4,
+}
+STAGE2_LABEL_PATH_WITH_PATHHEAD = "/path/to/your/label/splits"
+CKPT_PATH = "/path/to/your/checkpoints"
+MODEL_LIST = ["dino_ours"]
+
+AUGMENTATION_MAP = {
+    "dino_ours": "chan_then_pcspan",
+}
+SPLIT_DATA_FOLDER = "/path/to/your/postprocessed/data"
+PRETRAIN_VAL_DATASET_LIST = ['shhs']
+NEED_NORM_COL = [HR, SPO2, OX]
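`MONITOR_TYPE_MAP` lets training scripts pick an input channel list by montage name (e.g. a type-4 home sleep test records fewer channels than the full PSG montage). A minimal sketch of that lookup, with hypothetical string stand-ins for the channel constants that `config.py` actually provides:

```python
# Hypothetical stand-ins for the channel constants imported from config.py
ECG, ABD, THX, NP, SN = "ECG", "ABD", "THX", "NP", "SN"

TRAIN_EDF_COLS_TYPE3 = [ECG, ABD, THX, NP, SN]  # type-3 montage
TRAIN_EDF_COLS_TYPE4 = [ECG, ABD, THX]          # type-4 montage (fewest channels)

MONITOR_TYPE_MAP = {
    "type3": TRAIN_EDF_COLS_TYPE3,
    "type4": TRAIN_EDF_COLS_TYPE4,
}

def select_channels(monitor_type: str):
    """Pick the input channel list for a given monitor/montage type."""
    return MONITOR_TYPE_MAP[monitor_type]

assert select_channels("type4") == ["ECG", "ABD", "THX"]
```

Keeping the mapping in one config module means every training entry point resolves the same montage name to the same channel ordering.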