---
language:
  - en
library_name: pytorch
license: mit
pipeline_tag: other
tags:
  - sleep
  - eeg
  - polysomnography
  - foundation-model
  - self-supervised
  - vit
  - biosignals
---

# OSF: On Pre-training and Scaling of Sleep Foundation Models

This repository contains the weights for OSF, a family of sleep foundation models introduced in the paper OSF: On Pre-training and Scaling of Sleep Foundation Models.


## 🔥 News

- **[2026-02-24]** Our codebase and checkpoints are released. The full benchmarking codebase will be made publicly available after acceptance.
- **[2026-02-22]** Our paper is out.

## 📖 Introduction

Polysomnography (PSG) provides the gold standard for sleep assessment but suffers from substantial heterogeneity across recording devices and cohorts. OSF is a family of sleep foundation models (FMs) pre-trained on a massive corpus of 166,500 hours of sleep recordings from nine public sources. Leveraging the SleepBench benchmark, the authors establish an enhanced pre-training and scaling recipe that achieves state-of-the-art performance across diverse sleep and disease prediction tasks.

## 💿 Installation

```shell
git clone https://huggingface.co/yang-ai-lab/OSF-Base
cd OSF-Base
conda env create -f environment.yml
conda activate myenv
```

### Dependencies

- Python >= 3.10
- PyTorch >= 2.9.0
- PyTorch Lightning >= 2.5.5

## 🚀 Quick Start

We provide a demo notebook (`demo.ipynb`) demonstrating how to extract embeddings from PSG signals with the pretrained model.

```python
import torch
from osf.backbone.vit1d_cls import vit_base

# Load pretrained weights (included in this repo)
payload = torch.load("osf_backbone.pth", map_location="cpu", weights_only=False)
meta = payload["metadata"]

# Initialize model
backbone = vit_base(
    num_leads=meta["num_leads"],        # 12 channels
    seq_len=meta["seq_len"],            # 1920 (64 Hz × 30 s)
    patch_size=meta["patch_size_time"],
    lead_wise=meta["lead_wise"],
    patch_size_ch=meta["patch_size_ch"],
)
backbone.load_state_dict(payload["state_dict"])
backbone.eval()

# Extract embeddings
# x: [B, 12, 1920] - 12-channel PSG, 64 Hz × 30 seconds
x = torch.randn(2, 12, 1920)  # replace with real preprocessed PSG epochs
with torch.no_grad():
    cls_embs, patch_embs = backbone.forward_encoding(x, return_sequence=False)
# cls_embs: [B, 768] - Global epoch-level representation
# patch_embs: [B, 90, 768] - Local patch representations
```

πŸ‘©β€πŸ’» Usage

### Input Format

Expected input format:

- **12 PSG Channels:** ECG, EMG_Chin, EMG_LLeg, EMG_RLeg, ABD, THX, NP, SN, EOG_E1_A2, EOG_E2_A1, EEG_C3_A2, EEG_C4_A1
- **Sample Rate:** 64 Hz
- **Epoch Length:** 30 seconds
- **Input Shape:** `[B, 12, 1920]`
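A recording in this format can be assembled from raw channel arrays. The sketch below (an illustration, not the official preprocessing pipeline; the original sampling rate of 256 Hz and the use of plain linear interpolation are assumptions, and a proper anti-aliased resampler such as `scipy.signal.resample_poly` would normally be preferred) resamples a `[channels, samples]` array to 64 Hz and splits it into 30-second epochs:

```python
import numpy as np

def to_epochs(signal: np.ndarray, orig_rate: int,
              target_rate: int = 64, epoch_sec: int = 30) -> np.ndarray:
    """Resample a [channels, samples] recording to target_rate and
    split it into epochs of shape [num_epochs, channels, epoch_len]."""
    n_ch, n_samples = signal.shape
    duration = n_samples / orig_rate
    n_target = int(duration * target_rate)
    t_old = np.arange(n_samples) / orig_rate
    t_new = np.arange(n_target) / target_rate
    # Linear interpolation as a stand-in for a real resampler.
    resampled = np.stack([np.interp(t_new, t_old, ch) for ch in signal])
    # Drop the trailing partial epoch, then reshape.
    epoch_len = epoch_sec * target_rate  # 1920 samples at 64 Hz
    n_epochs = resampled.shape[1] // epoch_len
    resampled = resampled[:, : n_epochs * epoch_len]
    return resampled.reshape(n_ch, n_epochs, epoch_len).transpose(1, 0, 2)

# Example: 5 minutes of 12-channel data at a hypothetical 256 Hz -> 10 epochs
raw = np.random.randn(12, 256 * 300)
epochs = to_epochs(raw, orig_rate=256)
print(epochs.shape)  # (10, 12, 1920)
```

Each `epochs[i]` is then a valid `[12, 1920]` input; stacking several gives the batched `[B, 12, 1920]` tensor the model expects.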

### Pretraining and Fine-tuning

For detailed instructions on pretraining and fine-tuning using the OSF framework, please refer to the scripts in the official GitHub repository.
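One lightweight alternative to full fine-tuning is a linear probe on frozen embeddings. The sketch below is an illustration only (not the paper's recipe): it trains a linear head on random stand-in CLS embeddings with hypothetical 5-class sleep-stage labels; in practice you would substitute `cls_embs` from `backbone.forward_encoding` and real labels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for backbone outputs: 256 epochs x 768-dim CLS embeddings,
# with hypothetical 5-class sleep-stage labels (Wake, N1, N2, N3, REM).
cls_embs = torch.randn(256, 768)
labels = torch.randint(0, 5, (256,))

probe = nn.Linear(768, 5)  # frozen backbone, trainable linear head
opt = torch.optim.AdamW(probe.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

first_loss = None
for step in range(50):
    opt.zero_grad()
    loss = loss_fn(probe(cls_embs), labels)
    loss.backward()
    opt.step()
    if first_loss is None:
        first_loss = loss.item()

print(f"loss: {first_loss:.3f} -> {loss.item():.3f}")
```

Because the backbone stays frozen, this is cheap to run per task and gives a quick read on embedding quality before committing to full fine-tuning.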

## 📊 Benchmark Evaluations

OSF has been evaluated on the SleepBench benchmark across tasks including sleep stage classification, arousal detection, hypopnea event detection, and oxygen desaturation detection, outperforming existing self-supervised methods such as SleepFM, SimCLR, and DINO.

πŸ“ Citation

If you use this code or models in your research, please cite the paper:

```bibtex
@article{shuai2026osf,
  title={OSF: On Pre-training and Scaling of Sleep Foundation Models},
  author={Shuai, Zitao and Xu, Zongzhe and Yang, David and Wang, Wei and Yang, Yuzhe},
  journal={arXiv preprint arXiv:2603.00190},
  year={2026}
}
```