"""Fast mmap-backed dataset for precomputed ECG/PPG windows. __getitem__ is a single mmap slice (~0.1 ms) — no per-window I/O, no bandpass, no zscore. All preprocessing happened in precompute_windows.py. """ from __future__ import annotations import json import mmap from pathlib import Path import numpy as np import torch from torch.utils.data import Dataset class MIMICFastDataset(Dataset): def __init__( self, cache_dir: Path, subjects_allow: set[str] | None = None, ): meta_path = Path(cache_dir) / "windows_meta.json" meta = json.loads(meta_path.read_text()) self.n_total = meta["n_windows"] self.ecg_win = meta["ecg_win"] self.ppg_win = meta["ppg_win"] self.subjects = meta["subjects"] self.ecg_bytes = self.ecg_win * 4 # float32 self.ppg_bytes = self.ppg_win * 4 # Build index of allowed windows if subjects_allow is not None: self.indices = [i for i, s in enumerate(self.subjects) if s in subjects_allow] else: self.indices = list(range(self.n_total)) # mmap the binary files (read-only) ecg_path = Path(cache_dir) / "windows_ecg.bin" ppg_path = Path(cache_dir) / "windows_ppg.bin" self._ecg_fh = open(ecg_path, "rb") self._ppg_fh = open(ppg_path, "rb") self._ecg_mm = mmap.mmap(self._ecg_fh.fileno(), 0, access=mmap.ACCESS_READ) self._ppg_mm = mmap.mmap(self._ppg_fh.fileno(), 0, access=mmap.ACCESS_READ) def __len__(self) -> int: return len(self.indices) def __getitem__(self, idx: int) -> dict: real_idx = self.indices[idx] ecg_off = real_idx * self.ecg_bytes ppg_off = real_idx * self.ppg_bytes ecg = np.frombuffer(self._ecg_mm, dtype=np.float32, count=self.ecg_win, offset=ecg_off).copy() ppg = np.frombuffer(self._ppg_mm, dtype=np.float32, count=self.ppg_win, offset=ppg_off).copy() return { "ecg": torch.from_numpy(ecg).unsqueeze(0), # [1, 2500] "ppg": torch.from_numpy(ppg).unsqueeze(0), # [1, 1250] "subject_id": self.subjects[real_idx], "ptt_ms": float("nan"), } def __del__(self): try: self._ecg_mm.close() self._ppg_mm.close() self._ecg_fh.close() self._ppg_fh.close() except Exception: pass