Mamba-based Deep Learning Approach for Sleep Staging on a Wireless Multimodal Wearable System without Electroencephalography
Paper: 2412.15947
SleepStageNet is a deep-learning sleep-staging model that performs automatic sleep staging from four non-EEG features:
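| Feature | Description | Physiological meaning |
|---|---|---|
| HRV (RMSSD) | Heart rate variability | Reflects autonomic nervous system state (parasympathetic activity) |
| Heart rate (HR) | Heartbeats per minute | Reflects overall cardiovascular activity |
| Respiratory rate (RR) | Breaths per minute | Reflects respiratory regulation |
| Movement | Amount of body activity | Reflects limb movement / arousals |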
Output: one sleep stage per 30-second epoch → Wake / N1 / N2 / N3 / REM
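All four inputs are per-epoch summaries. The sketch below shows how the two cardiac features could be derived for each 30-s epoch from R-peak timestamps, using NumPy. The helper `epoch_hrv_hr` and its input format are hypothetical, because the repo's own feature pipeline is not described here. The sketch also assumes the R peaks are artifact-free.

```python
import numpy as np

EPOCH_SEC = 30.0  # SleepStageNet epoch length in seconds


def epoch_hrv_hr(r_peak_times, n_epochs, epoch_sec=EPOCH_SEC):
    """Hypothetical helper: per-epoch RMSSD (ms) and mean heart rate (bpm).

    r_peak_times: R-peak timestamps in seconds (assumed artifact-free).
    Epochs with fewer than two RR intervals are returned as NaN.
    """
    t = np.asarray(r_peak_times, dtype=float)
    rr = np.diff(t)   # RR intervals in seconds
    rr_end = t[1:]    # each interval is assigned to the epoch of its closing beat
    rmssd = np.full(n_epochs, np.nan)
    hr = np.full(n_epochs, np.nan)
    for e in range(n_epochs):
        in_epoch = (rr_end >= e * epoch_sec) & (rr_end < (e + 1) * epoch_sec)
        seg = rr[in_epoch]
        if seg.size < 2:
            continue
        # RMSSD: root mean square of successive RR differences, converted to ms
        rmssd[e] = np.sqrt(np.mean(np.diff(seg) ** 2)) * 1000.0
        # Mean heart rate from the mean RR interval
        hr[e] = 60.0 / seg.mean()
    return rmssd, hr


# Example: a perfectly regular 60 bpm rhythm over one minute
rmssd, hr = epoch_hrv_hr(np.arange(0, 61, 1.0), n_epochs=2)
# rmssd -> [0. 0.], hr -> [60. 60.]
```

Respiratory rate and movement would be summarized per epoch in the same way, from a respiration signal and an accelerometer respectively.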
It combines the best design choices from the following state-of-the-art (SOTA) papers:
```
SleepStageNet

Input: (batch, T, 4)           T 30-second epochs, 4 features per epoch
   │
   ▼
1. Feature Projection          (after SleepPPG-Net)
   MLP: 4 → d_model*2 → d_model
   │
   ▼
2. Cross-Feature Attention     (after wav2sleep)
   Independent per-feature projection + CLS token + Transformer;
   learns the interactions among HRV ↔ HR ↔ RR ↔ Movement
   │
   ▼  (gated fusion)
3. Positional Encoding
   Sinusoidal (temporal position matters for sleep architecture)
   │
   ▼
4. Dilated Temporal CNN        (after wav2sleep)
   2 blocks × [d=1,2,4,8,16,32], k=7
   Receptive field ≈ 6 h, enough to capture complete sleep cycles
   │
   ▼
5. Classification Head
   Linear(d_model → d_model/2 → n_classes)
   │
   ▼
Output: (batch, T, n_classes)  per-epoch classification logits
```
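Step 3 of the architecture adds sinusoidal positional encoding. The sketch below implements the standard Transformer formulation in NumPy. It assumes the repo uses the usual base of 10000 with interleaved sine/cosine channels, since the diagram does not specify these details.

```python
import numpy as np


def sinusoidal_positional_encoding(T, d_model, base=10000.0):
    """Sinusoidal positional encoding of shape (T, d_model).

    PE[t, 2i] = sin(t / base**(2i/d_model)), PE[t, 2i+1] = cos(t / base**(2i/d_model)).
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    pos = np.arange(T)[:, None]                # epoch index t = 0..T-1
    two_i = np.arange(0, d_model, 2)[None, :]  # even channel indices 2i
    angle = pos / np.power(base, two_i / d_model)
    pe = np.zeros((T, d_model))
    pe[:, 0::2] = np.sin(angle)  # even channels
    pe[:, 1::2] = np.cos(angle)  # odd channels
    return pe


pe = sinusoidal_positional_encoding(T=1200, d_model=128)  # 10-hour night, base config
```

The encoding is added element-wise to the fused (T, d_model) sequence before the dilated temporal CNN, e.g. `h = h + pe`.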
| Config | d_model | Parameters | Epoch Mixer | Sequence Mixer | Use case |
|---|---|---|---|---|---|
| small | 64 | ~195K | 1-layer Transformer | 1 block, d=[1,2,4,8,16] | Quick experiments |
| base | 128 | ~2M | 2-layer Transformer | 2 blocks, d=[1,2,4,8,16,32] | Recommended |
| large | 256 | ~13.7M | 3-layer Transformer | 3 blocks, d=[1,2,4,8,16,32,64] | Best performance |
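The "≈ 6 h" receptive field in the diagram can be checked by hand. Each stride-1 dilated convolution with kernel size k and dilation d widens the receptive field by (k − 1)·d epochs. The sketch below makes two assumptions: one convolution per dilation per block, and k=7 for every configuration (the diagram gives k only for the base setup).

```python
def receptive_field(n_blocks, dilations, kernel_size=7):
    """Receptive field (in epochs) of stacked stride-1 dilated 1-D convolutions.

    Assumes one convolution per dilation per block; each layer adds
    (kernel_size - 1) * dilation epochs to the field.
    """
    return 1 + n_blocks * sum((kernel_size - 1) * d for d in dilations)


EPOCH_SEC = 30
configs = {
    "small": (1, [1, 2, 4, 8, 16]),
    "base": (2, [1, 2, 4, 8, 16, 32]),
    "large": (3, [1, 2, 4, 8, 16, 32, 64]),
}
for name, (blocks, dil) in configs.items():
    rf = receptive_field(blocks, dil)
    print(f"{name}: {rf} epochs = {rf * EPOCH_SEC / 3600:.1f} h")
# small: 187 epochs = 1.6 h
# base: 757 epochs = 6.3 h
# large: 2287 epochs = 19.1 h
```

Under these assumptions, base covers 757 epochs (≈ 6.3 h), which is consistent with the diagram. If the convolutions use symmetric padding, that window is centred on each epoch.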
| Paper | Contribution | Year |
|---|---|---|
| wav2sleep | Epoch Mixer + Sequence Mixer + random modality masking | 2024 |
| Cross-Modal Transformer | Cross-modal attention + weighted cross-entropy (class weights [1,2,1,2,2]) | 2022 |
| SleepPPG-Net | Per-patient Z-score normalization + BiLSTM feature-engineering baseline | 2022 |
| Mamba-sleep | Lightweight Mamba sequence modeling + inverse-frequency class weighting | 2024 |
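Inverse-frequency class weighting (the Mamba-sleep contribution) up-weights rare stages such as N1. The sketch below uses the common N / (C · n_c) normalization, the same heuristic as scikit-learn's `class_weight="balanced"`. Whether Mamba-sleep or this repo normalizes the weights this way is an assumption.

```python
import numpy as np


def inverse_frequency_weights(labels, n_classes):
    """Class weights w_c = N / (n_classes * n_c); classes absent from labels get 0."""
    labels = np.asarray(labels)
    counts = np.bincount(labels, minlength=n_classes).astype(float)
    weights = np.zeros(n_classes)
    present = counts > 0
    weights[present] = labels.size / (n_classes * counts[present])
    return weights


# Rare stages receive larger weights; N3 (index 3) is absent here, so it gets 0
w = inverse_frequency_weights([0, 0, 2, 2, 2, 2, 1, 4, 4], n_classes=5)
# w -> approximately [0.9, 1.8, 0.45, 0.0, 0.9]
```

The Cross-Modal Transformer row uses fixed weights [1,2,1,2,2] instead. The weights above are computed from the training labels.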
```bash
pip install torch numpy
```
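```python
import numpy as np
import torch
from sleep_staging_model import create_model

# Create the model.
# n_classes=4 gives Wake/N1/N2/N3 (the no-REM setup); use n_classes=5 to include REM.
model = create_model('base', n_features=4, n_classes=4)

# Prepare the input: features for T 30-second epochs.
# hrv_rmssd, heart_rate, respiratory_rate and body_movement are 1-D arrays of
# length T (one value per epoch) produced by your own preprocessing.
T = 1200  # 10 hours = 1200 epochs
features = np.stack([
    hrv_rmssd,         # HRV (RMSSD) series
    heart_rate,        # heart-rate series
    respiratory_rate,  # respiratory-rate series
    body_movement,     # body-movement series
], axis=-1)  # shape: (T, 4)

# Per-recording Z-score normalization (key step!), then clip outliers
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
features = np.clip(features, -5, 5)

# Inference
x = torch.tensor(features, dtype=torch.float32).unsqueeze(0)  # (1, T, 4)
model.eval()
with torch.no_grad():
    logits = model(x)  # (1, T, n_classes)
predictions = torch.argmax(logits, dim=-1)  # (1, T)

# Labels: 0=Wake, 1=N1, 2=N2, 3=N3, 4=REM
stage_names = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'}
hypnogram = [stage_names[i] for i in predictions[0].tolist()]  # one stage name per epoch
```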
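`predictions[0]` holds one class index per epoch. As an example of post-processing, the hypothetical helper `stage_minutes` (not part of the repo) converts that sequence into total minutes per stage, counting each 30-s epoch as 0.5 min.

```python
import numpy as np

STAGE_NAMES = {0: "Wake", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}
EPOCH_MIN = 0.5  # one 30-second epoch = 0.5 minutes


def stage_minutes(pred, stage_names=STAGE_NAMES, epoch_min=EPOCH_MIN):
    """Hypothetical helper: total minutes per stage from per-epoch class indices."""
    pred = np.asarray(pred)
    return {name: float(np.sum(pred == k)) * epoch_min for k, name in stage_names.items()}


print(stage_minutes([0, 0, 2, 2, 2, 4]))
# {'Wake': 1.0, 'N1': 0.0, 'N2': 1.5, 'N3': 0.0, 'REM': 0.5}
```

With the model output, call `stage_minutes(predictions[0].numpy())`.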
```bash
python train_sleep_staging.py --model base --batch_size 16 --lr 1e-3 --max_epochs 100
```
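Training uses OneCycleLR with 10% warmup and cosine annealing. The function below is a simplified, framework-free stand-in: a linear warmup to the peak LR, then cosine decay. It is useful for plotting the schedule or sanity-checking step counts. It is not PyTorch's exact OneCycleLR formula, which starts from `max_lr / div_factor` and, by default, also anneals the warmup phase with a cosine. The function name and the `min_lr` parameter are assumptions.

```python
import math


def warmup_cosine_lr(step, total_steps, max_lr=1e-3, warmup_frac=0.1, min_lr=0.0):
    """Simplified warmup + cosine schedule (a stand-in for OneCycleLR, not its exact formula).

    Linear warmup to max_lr over the first warmup_frac of steps,
    then cosine decay from max_lr to min_lr over the remaining steps.
    """
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


lrs = [warmup_cosine_lr(s, total_steps=1000) for s in range(1000)]
```

In PyTorch, this setting would typically be expressed as `torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=total_steps, pct_start=0.1, anneal_strategy="cos")`, stepped once per batch.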
| Hyperparameter | Value | Source |
|---|---|---|
| Optimizer | AdamW | wav2sleep |
| Learning Rate | 1e-3 | wav2sleep |
| Weight Decay | 1e-2 | wav2sleep |
| Batch Size | 16 (whole nights) | wav2sleep |
| LR Schedule | OneCycleLR (10% warmup + cosine) | Adapted from wav2sleep |
| Loss | Weighted Focal Loss (γ=2) | Cross-Modal Transformer + focal loss |
| Class Weights | Inverse-frequency weighting | Mamba-sleep |
| Early Stopping | patience=10 | Adapted from wav2sleep (patience=5) |
| Gradient Clip | max_norm=1.0 | Standard practice |
| Augmentation | Random flip (p=0.5) + noise (p=0.3) | wav2sleep |
| Feature Mask | p=0.3 | wav2sleep |
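The loss combines per-class weights with a focal term: FL = −w_y (1 − p_y)^γ log p_y, averaged over epochs. The sketch below implements a weighted focal loss in NumPy. Averaging by the plain mean (rather than by the sum of weights) is an assumption, and the repo's PyTorch implementation may differ.

```python
import numpy as np


def weighted_focal_loss(logits, targets, class_weights, gamma=2.0):
    """Mean over epochs of -w_y * (1 - p_y)**gamma * log(p_y).

    logits: (N, C) raw scores; targets: (N,) integer labels; class_weights: (C,).
    """
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets)
    # Numerically stable log-softmax
    z = logits - logits.max(axis=1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    log_pt = log_p[np.arange(len(targets)), targets]  # log-probability of the true class
    pt = np.exp(log_pt)
    w = np.asarray(class_weights, dtype=float)[targets]  # weight of each epoch's true class
    return float(np.mean(-w * (1.0 - pt) ** gamma * log_pt))


rng = np.random.default_rng(0)
loss = weighted_focal_loss(rng.normal(size=(8, 5)), rng.integers(0, 5, size=8),
                           class_weights=[1, 2, 1, 2, 2], gamma=2.0)
```

With γ=0 and all weights equal to 1, this reduces to ordinary cross-entropy. With γ=2, epochs the model already classifies confidently contribute almost nothing.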
Training data: abmallick/heart-breath-sleep-stage-dataset
| Metric | 4-class (no REM) | 5-class (with REM) |
|---|---|---|
| Cohen's κ | 0.55-0.65 | 0.50-0.60 |
| Accuracy | 70-80% | 65-75% |
| F1 (macro) | 0.50-0.65 | 0.45-0.60 |
Note: sleep staging from non-EEG features generally performs below EEG-based methods (κ ≈ 0.75+); this is an inherent limitation of the field.
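Cohen's κ measures agreement beyond chance: κ = (p_o − p_e) / (1 − p_e). Here p_o is the observed accuracy and p_e is the agreement expected from the label marginals. The sketch below computes κ in NumPy from per-epoch labels. For unweighted κ, scikit-learn's `sklearn.metrics.cohen_kappa_score` should give the same value.

```python
import numpy as np


def cohens_kappa(y_true, y_pred, n_classes):
    """Unweighted Cohen's kappa: (p_o - p_e) / (1 - p_e)."""
    cm = np.zeros((n_classes, n_classes))
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)  # confusion matrix
    n = cm.sum()
    p_o = np.trace(cm) / n                                  # observed agreement
    p_e = (cm.sum(axis=1) @ cm.sum(axis=0)) / (n ** 2)      # chance agreement from marginals
    return (p_o - p_e) / (1.0 - p_e)


kappa = cohens_kappa([0, 0, 0, 1], [0, 0, 1, 1], n_classes=2)  # p_o = 0.75, p_e = 0.5
# kappa -> 0.5
```

κ is undefined when p_e = 1, i.e. when both label sequences contain a single, identical class.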
License: MIT