File size: 2,484 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Fast mmap-backed dataset for precomputed ECG/PPG windows.

__getitem__ reads one mmap slice per signal (~0.1 ms) — no per-window file
opens, no bandpass, no zscore. All preprocessing happened in precompute_windows.py.
"""
from __future__ import annotations

import json
import mmap
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset


class MIMICFastDataset(Dataset):
    """Map-style dataset over precomputed ECG/PPG windows in a cache directory.

    ``cache_dir`` must contain:

    * ``windows_meta.json`` with keys ``n_windows``, ``ecg_win``, ``ppg_win``
      and ``subjects`` (one subject ID per window, in window order);
    * ``windows_ecg.bin`` / ``windows_ppg.bin``: raw float32 samples where
      window ``i`` is ``ecg_win`` (resp. ``ppg_win``) consecutive values
      starting at sample ``i * ecg_win`` (resp. ``i * ppg_win``).

    Both binary files are memory-mapped read-only. The maps are dropped when
    the dataset is pickled (e.g. when DataLoader workers are started with the
    "spawn" method) and are reopened lazily on first access in the receiving
    process.

    Args:
        cache_dir: Directory holding the three files above.
        subjects_allow: If given, keep only windows whose subject ID is in
            this collection (any iterable of str is accepted; it is copied
            into a frozenset). ``None`` keeps every window.
    """

    # Bytes per stored sample; the bin files hold float32 values.
    _ITEMSIZE = np.dtype(np.float32).itemsize

    def __init__(
        self,
        cache_dir: Path,
        subjects_allow: set[str] | None = None,
    ):
        self.cache_dir = Path(cache_dir)
        meta = json.loads((self.cache_dir / "windows_meta.json").read_text())
        self.n_total = meta["n_windows"]
        self.ecg_win = meta["ecg_win"]
        self.ppg_win = meta["ppg_win"]
        self.subjects = meta["subjects"]
        self.ecg_bytes = self.ecg_win * self._ITEMSIZE
        self.ppg_bytes = self.ppg_win * self._ITEMSIZE

        # Build the index of allowed windows; a frozenset keeps the
        # membership test O(1) even if the caller passed a list.
        if subjects_allow is not None:
            allowed = frozenset(subjects_allow)
            self.indices = [i for i, s in enumerate(self.subjects) if s in allowed]
        else:
            self.indices = list(range(self.n_total))

        # Initialise before opening so close()/__del__ work even if opening fails.
        self._ecg_mm = None
        self._ppg_mm = None
        # Open eagerly so missing bin files are reported at construction time.
        self._open_mmaps()

    def _open_mmaps(self) -> None:
        """Memory-map both binary files read-only (whole file)."""
        # The mapping does not need the Python file object to stay open
        # (CPython's mmap keeps its own duplicate descriptor), so close it
        # immediately instead of holding two extra handles for the object's
        # lifetime.
        with open(self.cache_dir / "windows_ecg.bin", "rb") as fh:
            self._ecg_mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)
        with open(self.cache_dir / "windows_ppg.bin", "rb") as fh:
            self._ppg_mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)

    def __len__(self) -> int:
        """Number of windows after subject filtering."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> dict:
        """Return window ``idx`` of the (possibly subject-filtered) index.

        Returns:
            dict with ``"ecg"`` (float32 tensor ``[1, ecg_win]``), ``"ppg"``
            (float32 tensor ``[1, ppg_win]``), ``"subject_id"`` (the window's
            entry in ``subjects``) and ``"ptt_ms"`` (always NaN; this dataset
            provides no PTT value).

        Raises:
            IndexError: if ``idx`` is out of range for ``self.indices``.
            ValueError: if a bin file is too short for the requested window.
        """
        if self._ecg_mm is None or self._ppg_mm is None:
            # Maps are absent after unpickling (e.g. in a DataLoader worker)
            # or after close(); reopen them in this process.
            self._open_mmaps()
        real_idx = self.indices[idx]
        # NOTE(review): samples are read in native byte order — assumes the
        # writer (precompute_windows.py) used the same byte order.
        # .copy() detaches from the read-only mmap so torch gets a writable,
        # independently owned array.
        ecg = np.frombuffer(self._ecg_mm, dtype=np.float32,
                            count=self.ecg_win,
                            offset=real_idx * self.ecg_bytes).copy()
        ppg = np.frombuffer(self._ppg_mm, dtype=np.float32,
                            count=self.ppg_win,
                            offset=real_idx * self.ppg_bytes).copy()
        return {
            "ecg": torch.from_numpy(ecg).unsqueeze(0),  # [1, ecg_win]
            "ppg": torch.from_numpy(ppg).unsqueeze(0),  # [1, ppg_win]
            "subject_id": self.subjects[real_idx],
            "ptt_ms": float("nan"),
        }

    def __getstate__(self) -> dict:
        """Pickle support: mmap objects cannot be pickled, so drop them."""
        state = self.__dict__.copy()
        state["_ecg_mm"] = None
        state["_ppg_mm"] = None
        return state

    def close(self) -> None:
        """Unmap both binary files. Idempotent; a later read reopens them."""
        for attr in ("_ecg_mm", "_ppg_mm"):
            # getattr default covers a partially failed __init__.
            mm = getattr(self, attr, None)
            setattr(self, attr, None)
            if mm is not None:
                mm.close()

    def __del__(self):
        # Best-effort cleanup: __del__ can run after a failed __init__ or
        # during interpreter shutdown, where raising would only emit a
        # warning, so errors are deliberately ignored here.
        try:
            self.close()
        except Exception:
            pass