File size: 4,560 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Precompute all ECG/PPG windows into a single memory-mapped tensor file.

Reads the MIMIC shard index, applies bandpass + zscore per window, and writes
a flat binary file with a companion metadata JSON. At runtime, __getitem__
is a single mmap read (~0.1 ms) instead of load_from_disk + filter (~20 ms).

Output:
  /workspace/cache/windows_ecg.bin   (float32, [N, 2500])
  /workspace/cache/windows_ppg.bin   (float32, [N, 1250])
  /workspace/cache/windows_meta.json (subject_id per window, N total)
"""
from __future__ import annotations

import argparse
import json
import os
import struct
from functools import lru_cache
from pathlib import Path

import numpy as np
from scipy.signal import butter, filtfilt
from tqdm import tqdm

from dotenv import load_dotenv

load_dotenv()
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))

from datasets import load_from_disk

ECG_FS = 250.0  # ECG sampling rate in Hz
PPG_FS = 125.0  # PPG sampling rate in Hz
ECG_WIN = 2500  # samples per ECG window (10 s at 250 Hz)
PPG_WIN = 1250  # samples per PPG window (10 s at 125 Hz)


def _bandpass(x, fs, lo, hi, order=3):
    ny = 0.5 * fs
    b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band")
    return filtfilt(b, a, x, method="gust").astype(np.float32)


def _zscore(x, eps=1e-6):
    return ((x - x.mean()) / (x.std() + eps)).astype(np.float32)


def main():
    """Precompute filtered ECG/PPG windows into flat float32 binary files.

    Reads the window index JSON (one record per window with shard/row
    coordinates and window start offsets), loads the referenced MIMIC shard,
    slices the lead-II ECG and the PPG windows, applies bandpass + z-score,
    and appends each window to the output ``.bin`` files. A companion
    metadata JSON records the per-window subject ids and the total count.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--index", required=True, help="Path to the window index JSON")
    ap.add_argument("--out_dir", default="/workspace/cache")
    ap.add_argument("--workers", type=int, default=1)  # NOTE(review): currently unused
    args = ap.parse_args()

    index = json.loads(Path(args.index).read_text())
    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)

    ecg_path = out / "windows_ecg.bin"
    ppg_path = out / "windows_ppg.bin"
    meta_path = out / "windows_meta.json"

    # Resume check: skip work if a previous run already produced all windows.
    # NOTE(review): windows can be skipped below (missing shard, no lead II,
    # short slice), in which case n_windows < len(index) and this check never
    # matches — the whole precompute re-runs every time. Confirm whether skips
    # occur in practice; if so, persist len(index) in the metadata instead.
    if ecg_path.exists() and ppg_path.exists() and meta_path.exists():
        existing = json.loads(meta_path.read_text())
        if existing.get("n_windows") == len(index):
            print(f"[precompute] already done: {existing['n_windows']} windows")
            return

    print(f"[precompute] {len(index)} windows to process")

    # Locate the directory holding the HF-dataset shards (shard_00000, ...).
    mimic_root = None
    for candidate in [Path(args.out_dir) / "mimic", Path(args.out_dir).parent / "mimic",
                      Path("/workspace/cache/mimic")]:
        if candidate.exists():
            mimic_root = candidate
            break
    assert mimic_root, "mimic shard root not found"

    shard_cache = {}  # shard_idx -> loaded HF dataset; each shard is loaded once

    def load_shard_v2(sidx):
        # Lazily load and memoize shard `sidx`; returns None when the shard
        # directory does not exist (caller skips those windows).
        if sidx not in shard_cache:
            p = mimic_root / f"shard_{sidx:05d}"
            if (p / "dataset_info.json").exists():
                shard_cache[sidx] = load_from_disk(str(p))
        return shard_cache.get(sidx)

    subjects = []
    n_written = 0

    with open(ecg_path, "wb") as f_ecg, open(ppg_path, "wb") as f_ppg:
        for rec in tqdm(index, desc="precompute"):
            ds = load_shard_v2(rec["shard_idx"])
            if ds is None:
                continue
            row = ds[rec["row_idx"]]
            ecg_full = np.asarray(row["ecg"], dtype=np.float32)
            # assumes row["ppg"] is [1, T] (single channel) — TODO confirm
            ppg_full = np.asarray(row["ppg"], dtype=np.float32)[0]
            names = list(row["ecg_names"])
            if "II" not in names:
                continue  # only lead II is precomputed
            ecg_lead = ecg_full[names.index("II")]
            se = rec["win_start_ecg"]
            sp = rec["win_start_ppg"]
            ecg_win = ecg_lead[se : se + ECG_WIN]
            ppg_win = ppg_full[sp : sp + PPG_WIN]
            if ecg_win.shape[0] != ECG_WIN or ppg_win.shape[0] != PPG_WIN:
                continue  # window runs past the end of the record

            ecg_win = _zscore(_bandpass(ecg_win, ECG_FS, 0.5, 40.0))
            ppg_win = _zscore(_bandpass(ppg_win, PPG_FS, 0.5, 8.0))

            f_ecg.write(ecg_win.tobytes())
            f_ppg.write(ppg_win.tobytes())
            subjects.append(rec["subject_id"])
            n_written += 1

    meta = {
        "n_windows": n_written,
        "ecg_win": ECG_WIN,
        "ppg_win": PPG_WIN,
        "dtype": "float32",
        "subjects": subjects,  # subject_id per written window, in file order
    }
    meta_path.write_text(json.dumps(meta))
    ecg_gb = ecg_path.stat().st_size / 1e9
    ppg_gb = ppg_path.stat().st_size / 1e9
    print(f"[precompute] wrote {n_written} windows: ecg={ecg_gb:.2f}GB ppg={ppg_gb:.2f}GB")


# Script entry point: run the precompute pipeline when executed directly.
if __name__ == "__main__":
    main()