"""PTB-XL fetch v3 — concurrent per-file HTTP downloads (no 3 GB monolithic zip).

PhysioNet exposes individual files at:
  https://physionet.org/files/ptb-xl/1.0.3/<filename>

Strategy:
  1. Download just `ptbxl_database.csv` (~4 MB) to know which records exist
  2. Concurrent download of the .hea/.dat pairs we need (lead II only — but
     we need to download all 12 leads since each .dat is one multilead file)
  3. Parse with wfdb in a process pool

Total bytes: 21k records × ~400 KB each ≈ 8 GB. Even at 200 KB/s that's
slow, but with 32 concurrent connections we should saturate the pod's
~1 Gbit network (~125 MB/s). 8 GB / 125 MB/s = 64 sec ideal, ~10 min
realistic given physionet bandwidth caps.

Actually shortcut — use the LR (low-res, 100 Hz) variant: ~75 KB per file
×21k = 1.5 GB total. We resample 100→250 Hz with scipy. Quality is fine
for AF detection (PTB-XL paper uses both 100 and 500 Hz freely).
"""
from __future__ import annotations

import argparse
import ast
import concurrent.futures as cf
import multiprocessing as mp
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.signal import resample_poly
from tqdm import tqdm
import wfdb

BASE = "https://physionet.org/files/ptb-xl/1.0.3"
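
# A record's files live at f"{BASE}/<rel>.hea" and f"{BASE}/<rel>.dat", where
# <rel> comes from the CSV's filename_lr / filename_hr column,
# e.g. records100/00000/00001_lr.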


def _parse_scp(val):
    """Parse a scp_codes cell, e.g. "{'AFIB': 100.0, 'SR': 0.0}" -> dict."""
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        return {}
    try:
        # scp_codes is serialized as a Python dict literal, so literal_eval
        # parses it directly (json.loads would choke on the single quotes).
        return ast.literal_eval(val)
    except (ValueError, SyntaxError):
        # Best-effort fallback for slightly malformed cells.
        out = {}
        for tok in val.strip("{} ").split(","):
            if ":" in tok:
                k, v = tok.split(":", 1)
                try:
                    out[k.strip().strip("'\"")] = float(v.strip())
                except ValueError:
                    continue
        return out


def _download(args):
    url, dst = args
    if dst.exists() and dst.stat().st_size > 0:
        return True  # already cached
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_suffix(dst.suffix + ".part")
    try:
        with urllib.request.urlopen(url, timeout=60) as r, open(tmp, "wb") as f:
            f.write(r.read())
        # Write-then-rename so an interrupted transfer never leaves a partial
        # file that would pass the size check above on a re-run.
        tmp.replace(dst)
        return True
    except Exception:
        tmp.unlink(missing_ok=True)
        return False


def _resample(x, src_hz, dst_hz):
    from math import gcd
    g = gcd(int(src_hz), int(dst_hz))
    return resample_poly(x, up=int(dst_hz)//g, down=int(src_hz)//g, axis=-1).astype(np.float32)
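
# Worked example for the call sites below: 100 -> 250 Hz gives gcd = 50, so
# up = 5 and down = 2, and a 1000-sample (10 s at 100 Hz) trace becomes 2500
# samples; 500 -> 250 Hz gives up = 1, down = 2.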


def _process_one(arg):
    db_root, fname, afib, src_hz = arg
    try:
        rec = wfdb.rdrecord(str(Path(db_root) / fname))
        signals = rec.p_signal
        lead_names = rec.sig_name
        if "II" not in lead_names:
            return None
        lead_ii = signals[:, lead_names.index("II")]
        x = _resample(lead_ii, src_hz, 250)
        # Fix the window at 2500 samples (10 s at 250 Hz): pad short records,
        # truncate long ones, then z-score.
        if x.shape[0] < 2500:
            x = np.pad(x, (0, 2500 - x.shape[0]))
        else:
            x = x[:2500]
        x = (x - x.mean()) / (x.std() + 1e-6)
        return (x.astype(np.float32), int(afib))
    except Exception:
        # Missing or corrupt files (e.g. failed downloads) are skipped here.
        return None


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", default="/workspace/cache/ptbxl")
    ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz")
    ap.add_argument("--use_lr", action="store_true", help="100 Hz variant (smaller, faster)")
    ap.add_argument("--limit", type=int, default=None)
    ap.add_argument("--dl_workers", type=int, default=32)
    ap.add_argument("--parse_workers", type=int, default=16)
    args = ap.parse_args()

    root = Path(args.root)
    root.mkdir(parents=True, exist_ok=True)

    csv_path = root / "ptbxl_database.csv"
    if not csv_path.exists():
        print(f"[fetch] downloading ptbxl_database.csv", flush=True)
        urllib.request.urlretrieve(f"{BASE}/ptbxl_database.csv", str(csv_path))
    print(f"[fetch] csv size: {csv_path.stat().st_size/1e6:.1f} MB", flush=True)

    meta = pd.read_csv(csv_path, index_col="ecg_id")
    meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp)
    meta["afib"] = meta["scp_parsed"].apply(
        lambda d: int(any(k in ("AFIB", "AFLT") for k in d.keys()))
    )
    if args.limit:
        meta = meta.sample(n=args.limit, random_state=0)
    print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}", flush=True)

    # Decide LR vs HR
    fname_col = "filename_lr" if args.use_lr else "filename_hr"
    src_hz = 100 if args.use_lr else 500
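    # Per the module docstring: LR keeps the transfer near 1.6 GB total,
    # vs roughly 8 GB for the 500 Hz files.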

    # Build download list (.hea + .dat per record)
    dl_list = []
    for _, row in meta.iterrows():
        rel = row[fname_col]  # e.g. records100/00000/00001_lr or records500/00000/00001_hr
        for ext in (".hea", ".dat"):
            url = f"{BASE}/{rel}{ext}"
            dst = root / f"{rel}{ext}"
            dl_list.append((url, dst))

    # Filter out already-present
    todo = [(u, d) for u, d in dl_list if not (d.exists() and d.stat().st_size > 0)]
    print(f"[fetch] {len(todo)} files to download (skipping {len(dl_list)-len(todo)} cached)",
          flush=True)

    if todo:
        with cf.ThreadPoolExecutor(max_workers=args.dl_workers) as ex:
            ok_count = 0
            for ok in tqdm(ex.map(_download, todo), total=len(todo), desc="dl"):
                if ok:
                    ok_count += 1
        print(f"[fetch] downloaded ok={ok_count}/{len(todo)}", flush=True)

    # Parse
    work = [(str(root), row[fname_col], row["afib"], src_hz)
            for _, row in meta.iterrows()]
    print(f"[fetch] parsing {len(work)} records with {args.parse_workers} workers",
          flush=True)
    xs, ys = [], []
    with mp.Pool(args.parse_workers) as pool:
        for r in tqdm(pool.imap_unordered(_process_one, work, chunksize=16),
                      total=len(work), desc="parse"):
            if r is None:
                continue
            xs.append(r[0])
            ys.append(r[1])

    if not xs:
        raise SystemExit("[fetch] no records parsed successfully; aborting")
    X = np.stack(xs).astype(np.float32)[:, None, :]  # shape (N, 1, 2500)
    y = np.array(ys, dtype=np.int64)
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(out, X=X, y=y)
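    # Downstream consumers can load the artifact directly; a sketch (not part
    # of this script), with key names matching the savez_compressed call above:
    #   data = np.load(args.out)
    #   X, y = data["X"], data["y"]  # X: (N, 1, 2500) float32; y: (N,) int64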
    print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1-y).sum())}",
          flush=True)


if __name__ == "__main__":
    main()