# PhysioJEPA / src/physiojepa/data_fast.py
# Uploaded by guychuk via huggingface_hub (commit 31e2456).
"""Fast mmap-backed dataset for precomputed ECG/PPG windows.

Each ``__getitem__`` is a single mmap slice plus a copy (~0.1 ms): no
per-window file opens, no band-pass filtering, no z-scoring. All
preprocessing is done ahead of time by precompute_windows.py.
"""
from __future__ import annotations
import json
import mmap
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
class MIMICFastDataset(Dataset):
def __init__(
self,
cache_dir: Path,
subjects_allow: set[str] | None = None,
):
meta_path = Path(cache_dir) / "windows_meta.json"
meta = json.loads(meta_path.read_text())
self.n_total = meta["n_windows"]
self.ecg_win = meta["ecg_win"]
self.ppg_win = meta["ppg_win"]
self.subjects = meta["subjects"]
self.ecg_bytes = self.ecg_win * 4 # float32
self.ppg_bytes = self.ppg_win * 4
# Build index of allowed windows
if subjects_allow is not None:
self.indices = [i for i, s in enumerate(self.subjects) if s in subjects_allow]
else:
self.indices = list(range(self.n_total))
# mmap the binary files (read-only)
ecg_path = Path(cache_dir) / "windows_ecg.bin"
ppg_path = Path(cache_dir) / "windows_ppg.bin"
self._ecg_fh = open(ecg_path, "rb")
self._ppg_fh = open(ppg_path, "rb")
self._ecg_mm = mmap.mmap(self._ecg_fh.fileno(), 0, access=mmap.ACCESS_READ)
self._ppg_mm = mmap.mmap(self._ppg_fh.fileno(), 0, access=mmap.ACCESS_READ)
def __len__(self) -> int:
return len(self.indices)
def __getitem__(self, idx: int) -> dict:
real_idx = self.indices[idx]
ecg_off = real_idx * self.ecg_bytes
ppg_off = real_idx * self.ppg_bytes
ecg = np.frombuffer(self._ecg_mm, dtype=np.float32,
count=self.ecg_win, offset=ecg_off).copy()
ppg = np.frombuffer(self._ppg_mm, dtype=np.float32,
count=self.ppg_win, offset=ppg_off).copy()
return {
"ecg": torch.from_numpy(ecg).unsqueeze(0), # [1, 2500]
"ppg": torch.from_numpy(ppg).unsqueeze(0), # [1, 1250]
"subject_id": self.subjects[real_idx],
"ptt_ms": float("nan"),
}
def __del__(self):
try:
self._ecg_mm.close()
self._ppg_mm.close()
self._ecg_fh.close()
self._ppg_fh.close()
except Exception:
pass