"""Fetch PTB-XL from PhysioNet (open access, no credentialing) and cache lead II
@ 250 Hz with binary AF labels (AFIB or AFLT SCP codes) into a single .npz file
for fast eval reload.

Resulting cache layout:
    /workspace/cache/ptbxl_af.npz   (X: [N,1,2500] float32, y: [N] int64)
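
Example reload (a sketch: it round-trips a tiny placeholder array through an
in-memory buffer to illustrate the shapes and dtypes the cache is expected to
hold; on disk you would pass the ptbxl_af.npz path to np.load instead):

```python
import io

import numpy as np

# Placeholder arrays matching the cache layout: X [N,1,2500] float32, y [N] int64.
X = np.zeros((4, 1, 2500), dtype=np.float32)
y = np.array([0, 1, 0, 1], dtype=np.int64)

# savez_compressed/load round-trip through a BytesIO standing in for the file.
buf = io.BytesIO()
np.savez_compressed(buf, X=X, y=y)
buf.seek(0)
data = np.load(buf)
assert data["X"].shape == (4, 1, 2500)
assert data["y"].dtype == np.int64
```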
"""
from __future__ import annotations

import argparse
import zipfile
from pathlib import Path

import numpy as np
import requests
from tqdm import tqdm

PTBXL_VERSION = "1.0.3"
PTBXL_URL = (
    f"https://physionet.org/static/published-projects/ptb-xl/"
    f"ptb-xl-a-large-publicly-available-electrocardiography-dataset-{PTBXL_VERSION}.zip"
)


def _resample_500_to_250(x):
    # Polyphase 2:1 decimation (500 Hz -> 250 Hz); resample_poly applies an
    # anti-aliasing FIR filter before downsampling.
    from scipy.signal import resample_poly
    return resample_poly(x, up=1, down=2, axis=-1).astype(np.float32)


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", default="/workspace/cache/ptbxl")
    ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz")
    ap.add_argument("--limit", type=int, default=None)
    args = ap.parse_args()

    root = Path(args.root)
    root.mkdir(parents=True, exist_ok=True)
    zip_path = root / "ptbxl.zip"
    if not zip_path.exists():
        print(f"[fetch] downloading PTB-XL ({PTBXL_URL})")
        r = requests.get(PTBXL_URL, stream=True, timeout=600)
        r.raise_for_status()
        total_mb = int(r.headers.get("content-length", 0)) // (1024 * 1024)
        tmp_path = zip_path.with_name(zip_path.name + ".part")
        with open(tmp_path, "wb") as f:
            for chunk in tqdm(r.iter_content(chunk_size=1024 * 1024),
                              total=total_mb, unit="MB", desc="download"):
                if chunk:
                    f.write(chunk)
        # Rename only after a complete download so an interrupted run never
        # leaves a truncated zip that the exists() check would then skip.
        tmp_path.rename(zip_path)
    extract_dir = root / "extracted"
    if not extract_dir.exists():
        print(f"[fetch] extracting to {extract_dir}")
        with zipfile.ZipFile(zip_path) as z:
            z.extractall(extract_dir)
    # find ptbxl_database.csv
    csvs = list(extract_dir.rglob("ptbxl_database.csv"))
    assert csvs, "ptbxl_database.csv not found in extracted zip"
    db_csv = csvs[0]
    db_root = db_csv.parent
    print(f"[fetch] db_root = {db_root}")

    import pandas as pd
    import wfdb

    meta = pd.read_csv(db_csv, index_col="ecg_id")
    # scp_codes is a stringified dict (e.g. "{'AFIB': 100.0}"); literal_eval
    # parses it directly, tolerates single quotes, and NaN cells map to {}.
    def _parse(val):
        import ast
        if not isinstance(val, str):
            return {}
        try:
            return ast.literal_eval(val)
        except (ValueError, SyntaxError):
            return {}

    meta["scp_parsed"] = meta["scp_codes"].apply(_parse)
    meta["afib"] = meta["scp_parsed"].apply(
        lambda d: int(any(k in ("AFIB", "AFLT") for k in d))  # AF: fibrillation or flutter
    )
    if args.limit:
        meta = meta.sample(n=min(args.limit, len(meta)), random_state=0)
    print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}")

    xs, ys = [], []
    for _, row in tqdm(meta.iterrows(), total=len(meta), desc="ptb-xl"):
        rec = wfdb.rdrecord(str(db_root / row["filename_hr"]))
        signals = rec.p_signal  # [T, 12] @ 500 Hz
        lead_names = rec.sig_name
        if "II" not in lead_names:
            continue
        lead_ii = signals[:, lead_names.index("II")]
        x = _resample_500_to_250(lead_ii)
        if x.shape[0] < 2500:
            x = np.pad(x, (0, 2500 - x.shape[0]))
        else:
            x = x[:2500]
        x = (x - x.mean()) / (x.std() + 1e-6)
        xs.append(x.astype(np.float32))
        ys.append(int(row["afib"]))

    assert xs, "no records with lead II were found"
    X = np.stack(xs).astype(np.float32)[:, None, :]
    y = np.array(ys, dtype=np.int64)
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(out, X=X, y=y)
    print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1-y).sum())}")


if __name__ == "__main__":
    main()