File size: 1,923 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Download all MIMIC shards on the training host, then build the window index.

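Intended to run on the RunPod pod right after boot. Downloads the first
--n_shards shards (default 412) of the MIMIC-IV aligned PPG/ECG dataset into
/workspace/cache/mimic/ and writes the window index to
/workspace/cache/mimic_index.json (plus a sibling .meta.json summary). Both
locations are only defaults; override them with --root and --index.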
"""
from __future__ import annotations

import argparse
import json
import os
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import snapshot_download

# Load variables from a .env file into os.environ (python-dotenv) so the
# HUGGINGFACE_API_KEY lookup on the next line can see them.
load_dotenv()
def export_hf_token(environ=None) -> None:
    """Mirror ``HUGGINGFACE_API_KEY`` into ``HF_TOKEN`` for huggingface_hub.

    A non-empty ``HF_TOKEN`` that is already set always wins. When
    ``HUGGINGFACE_API_KEY`` is unset or empty, ``HF_TOKEN`` is left untouched,
    so an empty-string credential is never exported.

    Args:
        environ: Mutable str->str mapping to update; ``None`` means
            ``os.environ``.
    """
    env = os.environ if environ is None else environ
    api_key = env.get("HUGGINGFACE_API_KEY", "")
    # Fill HF_TOKEN when it is missing OR empty; a bare setdefault would keep "".
    if api_key and not env.get("HF_TOKEN"):
        env["HF_TOKEN"] = api_key


export_hf_token()

# Deliberately imported only after the environment setup above (hence not in
# the import block). NOTE(review): presumably so that anything physiojepa reads
# from os.environ at import time already sees HF_TOKEN - confirm before moving.
from physiojepa.data import MIMICAlignedDataset

# Hugging Face dataset repo (downloaded with repo_type="dataset" in main) whose
# shard_NNNNN/ directories are fetched by this script.
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", type=str, default="/workspace/cache/mimic")
    ap.add_argument("--index", type=str, default="/workspace/cache/mimic_index.json")
    ap.add_argument("--n_shards", type=int, default=412)
    args = ap.parse_args()

    root = Path(args.root)
    root.mkdir(parents=True, exist_ok=True)
    patterns = [f"shard_{i:05d}/*" for i in range(args.n_shards)]
    print(f"[prepare] downloading {len(patterns)} shard patterns to {root}")
    local = snapshot_download(REPO, repo_type="dataset", allow_patterns=patterns,
                              local_dir=str(root), max_workers=16)
    shard_roots = sorted([p for p in Path(local).glob("shard_*")
                          if (p / "dataset_info.json").exists()])
    print(f"[prepare] {len(shard_roots)} shards ready; building window index")
    ds = MIMICAlignedDataset(shard_roots=shard_roots, index_path=Path(args.index),
                             build_index=True)
    info = {
        "n_shards": len(shard_roots),
        "n_windows": len(ds),
        "n_subjects": len(set(r["subject_id"] for r in ds.index)),
        "shard_roots": [str(p) for p in shard_roots],
    }
    Path(args.index).with_suffix(".meta.json").write_text(json.dumps(info, indent=2))
    print(f"[prepare] index built: {json.dumps(info, indent=2)}")


# Script entry point: parse CLI arguments, download the shards, and build the index.
if __name__ == "__main__":
    main()