---
language:
- en
library_name: pytorch
license: mit
pipeline_tag: other
tags:
- sleep
- eeg
- polysomnography
- foundation-model
- self-supervised
- vit
- biosignals
---
# OSF: On Pre-training and Scaling of Sleep Foundation Models

This repository contains the weights for OSF, a family of sleep foundation models introduced in the paper *OSF: On Pre-training and Scaling of Sleep Foundation Models*.
## 🔥 News
- [2026-02-24] Our codebase and checkpoints are released. The full benchmarking codebase will be made publicly available after acceptance.
- [2026-02-22] Our paper is out.
## 📖 Introduction
Polysomnography (PSG) provides the gold standard for sleep assessment but suffers from substantial heterogeneity across recording devices and cohorts. OSF is a family of sleep foundation models (FMs) pre-trained on a massive corpus of 166,500 hours of sleep recordings from nine public sources. Leveraging the SleepBench benchmark, the authors establish an enhanced pre-training and scaling recipe that achieves state-of-the-art performance across diverse sleep and disease prediction tasks.
## 🌿 Installation

```bash
git clone https://huggingface.co/yang-ai-lab/OSF-Base
cd OSF-Base
conda env create -f environment.yml
conda activate myenv
```
### Dependencies

- Python >= 3.10
- PyTorch >= 2.9.0
- PyTorch Lightning >= 2.5.5
## 🚀 Quick Start
We provide a demo notebook (`demo.ipynb`) demonstrating how to extract embeddings from PSG signals using the pretrained model.
```python
import torch

from osf.backbone.vit1d_cls import vit_base

# Load pretrained weights (included in this repo)
payload = torch.load("osf_backbone.pth", map_location="cpu", weights_only=False)
meta = payload["metadata"]

# Initialize model
backbone = vit_base(
    num_leads=meta["num_leads"],        # 12 channels
    seq_len=meta["seq_len"],            # 1920 (64 Hz × 30 s)
    patch_size=meta["patch_size_time"],
    lead_wise=meta["lead_wise"],
    patch_size_ch=meta["patch_size_ch"],
)
backbone.load_state_dict(payload["state_dict"])
backbone.eval()

# Extract embeddings
# x: [B, 12, 1920] - 12-channel PSG, 64 Hz × 30 seconds
x = torch.randn(2, 12, 1920)  # replace with real preprocessed PSG epochs
with torch.no_grad():
    cls_embs, patch_embs = backbone.forward_encoding(x, return_sequence=False)
# cls_embs:   [B, 768]     - global epoch-level representation
# patch_embs: [B, 90, 768] - local patch representations
```
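The epoch-level embeddings can be used directly as features for downstream analysis. A minimal sketch of cosine-similarity retrieval over epochs, using random stand-in tensors in place of real `forward_encoding` outputs:

```python
import torch
import torch.nn.functional as F

# Stand-in for cls_embs produced by the backbone: [N, 768], one row per epoch
torch.manual_seed(0)
cls_embs = torch.randn(100, 768)

# L2-normalize so dot products equal cosine similarities
feats = F.normalize(cls_embs, dim=-1)

# Find the epoch most similar to a query epoch
query = feats[0]
sims = feats @ query           # [N] cosine similarities
sims[0] = float("-inf")        # exclude the query itself
nearest = sims.argmax().item() # index of the most similar epoch
```

The same normalized features also work as inputs to clustering or a k-NN classifier without any fine-tuning.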
## 👩‍💻 Usage
### Input Format

Expected input format:

- 12 PSG channels: ECG, EMG_Chin, EMG_LLeg, EMG_RLeg, ABD, THX, NP, SN, EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1
- Sample rate: 64 Hz
- Epoch length: 30 seconds
- Input shape: `[B, 12, 1920]`
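To get raw recordings into this shape, each channel must be resampled to 64 Hz and cut into non-overlapping 30 s windows. A sketch using linear interpolation as a stand-in for the (unspecified) official resampling/filtering pipeline:

```python
import torch
import torch.nn.functional as F

def to_model_epochs(raw: torch.Tensor, fs: float) -> torch.Tensor:
    """Convert a 12-channel recording [12, T] sampled at `fs` Hz into
    non-overlapping 30 s epochs [N, 12, 1920] at 64 Hz."""
    target_len = int(round(raw.shape[-1] * 64.0 / fs))
    # F.interpolate expects [batch, channels, length]
    x = F.interpolate(raw.unsqueeze(0), size=target_len,
                      mode="linear", align_corners=False).squeeze(0)
    n = x.shape[-1] // 1920                       # keep whole epochs only
    return x[:, : n * 1920].reshape(12, n, 1920).transpose(0, 1)

# Example: 5 minutes of 12-channel data recorded at 256 Hz
raw = torch.randn(12, 256 * 300)
epochs = to_model_epochs(raw, fs=256)  # -> [10, 12, 1920]
```

Channel ordering must match the list above; any per-channel filtering or amplitude normalization used during pre-training should be replicated as well.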
### Pretraining and Fine-tuning
For detailed instructions on pretraining and fine-tuning using the OSF framework, please refer to the scripts in the official GitHub repository.
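As a rough illustration of the fine-tuning setup, here is a linear-probe sketch: the backbone is frozen and only a classification head is trained. A tiny stand-in module with the same `forward_encoding` interface replaces the real backbone so the snippet is self-contained; the actual training scripts live in the GitHub repository.

```python
import torch
import torch.nn as nn

class StandInBackbone(nn.Module):
    """Placeholder with the same interface as the OSF backbone;
    in practice, use vit_base(...) with the pretrained weights."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(12, 768)

    def forward_encoding(self, x, return_sequence=False):
        # x: [B, 12, 1920] -> per-channel means -> [B, 768]
        return self.proj(x.mean(dim=-1)), None

backbone = StandInBackbone()
head = nn.Linear(768, 5)  # e.g. 5 sleep stages (W, N1, N2, N3, REM)
optim = torch.optim.AdamW(head.parameters(), lr=1e-3)

x = torch.randn(8, 12, 1920)        # batch of preprocessed epochs
y = torch.randint(0, 5, (8,))       # stage labels

with torch.no_grad():               # frozen backbone = linear probe
    feats, _ = backbone.forward_encoding(x)
loss = nn.functional.cross_entropy(head(feats), y)
optim.zero_grad()
loss.backward()
optim.step()
```

For full fine-tuning, drop the `torch.no_grad()` context and pass the backbone parameters (typically with a smaller learning rate) to the optimizer as well.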
## 📊 Benchmark Evaluations
OSF has been evaluated on the SleepBench benchmark across tasks such as sleep stage classification, arousal detection, hypopnea event detection, and oxygen desaturation detection, outperforming existing SSL methods such as SleepFM, SimCLR, and DINO.
## 📄 Citation
If you use this code or models in your research, please cite the paper:
```bibtex
@article{shuai2026osf,
  title={OSF: On Pre-training and Scaling of Sleep Foundation Models},
  author={Shuai, Zitao and Xu, Zongzhe and Yang, David and Wang, Wei and Yang, Yuzhe},
  journal={arXiv preprint arXiv:2603.00190},
  year={2026}
}
```