"""PTB-XL fetch v2 — multiprocessing, full zip download via wget.

Downloads PTB-XL v1.0.3 from PhysioNet, extracts, parses with wfdb in parallel
using a process pool (8-16 workers), caches to /workspace/cache/ptbxl_af.npz.

Usage:
    python scripts/fetch_ptbxl_v2.py --root /workspace/cache/ptbxl --out /workspace/cache/ptbxl_af.npz [--workers 12]
"""
from __future__ import annotations

import argparse
import json
import multiprocessing as mp
import subprocess
import zipfile
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.signal import resample_poly
from tqdm import tqdm
import wfdb

PTBXL_VERSION = "1.0.3"
PTBXL_URL = (
    f"https://physionet.org/static/published-projects/ptb-xl/"
    f"ptb-xl-a-large-publicly-available-electrocardiography-dataset-{PTBXL_VERSION}.zip"
)
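
# PhysioNet publishes the full dataset as one versioned zip; bump PTBXL_VERSION
# if targeting a newer release.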


def _resample_500_to_250(x):
    """Decimate a 500 Hz signal to 250 Hz; resample_poly applies an anti-aliasing FIR."""
    return resample_poly(x, up=1, down=2, axis=-1).astype(np.float32)


def _parse_scp(val):
    """Parse a PTB-XL scp_codes entry (a stringified dict of code -> likelihood)."""
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        return {}
    try:
        return json.loads(val.replace("'", '"'))
    except Exception:
        # Fallback: hand-parse "{'CODE': likelihood, ...}" token by token.
        out = {}
        for tok in val.strip("{} ").split(","):
            if ":" in tok:
                k, v = tok.split(":", 1)
                out[k.strip().strip("'\"")] = float(v.strip())
        return out
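
# Example: _parse_scp("{'AFIB': 100.0, 'SR': 0.0}") -> {"AFIB": 100.0, "SR": 0.0}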


def _process_one(arg):
    """Read one PTB-XL record's lead II; return (signal, label) or None on failure."""
    db_root, fname_hr, afib = arg
    try:
        rec = wfdb.rdrecord(str(Path(db_root) / fname_hr))
        signals = rec.p_signal
        lead_names = rec.sig_name
        if "II" not in lead_names:
            return None
        lead_ii = signals[:, lead_names.index("II")]
        x = _resample_500_to_250(lead_ii)
        # Pad or crop to exactly 2500 samples (10 s at 250 Hz).
        if x.shape[0] < 2500:
            x = np.pad(x, (0, 2500 - x.shape[0]))
        else:
            x = x[:2500]
        # Per-record z-score; the epsilon guards flat (zero-variance) leads.
        x = (x - x.mean()) / (x.std() + 1e-6)
        return (x.astype(np.float32), int(afib))
    except Exception:
        # Unreadable or malformed records are dropped; the caller skips None.
        return None


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", default="/workspace/cache/ptbxl")
    ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz")
    ap.add_argument("--workers", type=int, default=16)
    ap.add_argument("--limit", type=int, default=None)
    args = ap.parse_args()

    root = Path(args.root)
    root.mkdir(parents=True, exist_ok=True)
    zip_path = root / "ptbxl.zip"

    # Download via wget (resumable; more robust than requests for a multi-GB zip).
    # A zip under 1 GB is assumed to be an incomplete partial and is re-downloaded.
    if not zip_path.exists() or zip_path.stat().st_size < 1_000_000_000:
        print("[fetch] downloading PTB-XL via wget", flush=True)
        zip_path.unlink(missing_ok=True)
        subprocess.run([
            "wget", "-c", "-O", str(zip_path), PTBXL_URL
        ], check=True)

    print(f"[fetch] zip size: {zip_path.stat().st_size / 1e9:.2f} GB", flush=True)

    extract_dir = root / "extracted"
    if not extract_dir.exists() or not list(extract_dir.rglob("ptbxl_database.csv")):
        print(f"[fetch] extracting to {extract_dir}", flush=True)
        extract_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(zip_path) as z:
            z.extractall(extract_dir)

    csvs = list(extract_dir.rglob("ptbxl_database.csv"))
    assert csvs, "ptbxl_database.csv not found after extract"
    db_csv = csvs[0]
    db_root = db_csv.parent
    print(f"[fetch] db_root = {db_root}", flush=True)

    meta = pd.read_csv(db_csv, index_col="ecg_id")
    meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp)
    meta["afib"] = meta["scp_parsed"].apply(
        lambda d: int(any(k in ("AFIB", "AFLT") for k in d.keys()))
    )
    if args.limit:
        meta = meta.sample(n=args.limit, random_state=0)
    print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}", flush=True)

    work = [(str(db_root), row["filename_hr"], row["afib"])
            for _, row in meta.iterrows()]

    print(f"[fetch] parsing with {args.workers} workers", flush=True)
    xs, ys = [], []
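    # imap_unordered yields results as workers finish, so progress is steady
    # but result order is not preserved (labels travel with their signals).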
    with mp.Pool(args.workers) as pool:
        for r in tqdm(pool.imap_unordered(_process_one, work, chunksize=8),
                      total=len(work), desc="ptb-xl"):
            if r is None:
                continue
            xs.append(r[0])
            ys.append(r[1])

    if not xs:
        raise SystemExit("[fetch] no records parsed successfully; aborting")
    X = np.stack(xs).astype(np.float32)[:, None, :]  # (N, 1, 2500): channel-first
    y = np.array(ys, dtype=np.int64)
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(out, X=X, y=y)
    print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1-y).sum())}",
          flush=True)


if __name__ == "__main__":
    main()