"""PTB-XL fetch v2 — multiprocessing, full zip download via wget. Downloads PTB-XL v1.0.3 from PhysioNet, extracts, parses with wfdb in parallel using a process pool (8-16 workers), caches to /workspace/cache/ptbxl_af.npz. Usage: python scripts/fetch_ptbxl_v2.py --root /workspace/cache/ptbxl --out /workspace/cache/ptbxl_af.npz [--workers 12] """ from __future__ import annotations import argparse import json import multiprocessing as mp import os import subprocess import zipfile from pathlib import Path import numpy as np import pandas as pd from scipy.signal import resample_poly from tqdm import tqdm import wfdb PTBXL_VERSION = "1.0.3" PTBXL_URL = ( f"https://physionet.org/static/published-projects/ptb-xl/" f"ptb-xl-a-large-publicly-available-electrocardiography-dataset-{PTBXL_VERSION}.zip" ) def _resample_500_to_250(x): return resample_poly(x, up=1, down=2, axis=-1).astype(np.float32) def _parse_scp(val): if isinstance(val, dict): return val if not isinstance(val, str): return {} try: return json.loads(val.replace("'", '"')) except Exception: out = {} for tok in val.strip("{} ").split(","): if ":" in tok: k, v = tok.split(":", 1) out[k.strip().strip("'\"")] = float(v.strip()) return out def _process_one(arg): """Read one PTB-XL record's lead II and return (x, y).""" db_root, fname_hr, afib = arg try: rec = wfdb.rdrecord(str(Path(db_root) / fname_hr)) signals = rec.p_signal lead_names = rec.sig_name if "II" not in lead_names: return None lead_ii = signals[:, lead_names.index("II")] x = _resample_500_to_250(lead_ii) if x.shape[0] < 2500: x = np.pad(x, (0, 2500 - x.shape[0])) else: x = x[:2500] x = (x - x.mean()) / (x.std() + 1e-6) return (x.astype(np.float32), int(afib)) except Exception as e: return None def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--root", default="/workspace/cache/ptbxl") ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz") ap.add_argument("--workers", type=int, default=16) ap.add_argument("--limit", type=int, default=None) args = ap.parse_args() root = Path(args.root) root.mkdir(parents=True, exist_ok=True) zip_path = root / "ptbxl.zip" # Download via wget (resumable, faster than requests for 3 GB) if not zip_path.exists() or zip_path.stat().st_size < 1_000_000_000: # < 1 GB = incomplete print(f"[fetch] downloading PTB-XL via wget", flush=True) zip_path.unlink(missing_ok=True) subprocess.run([ "wget", "-c", "-O", str(zip_path), PTBXL_URL ], check=True) print(f"[fetch] zip size: {zip_path.stat().st_size / 1e9:.2f} GB", flush=True) extract_dir = root / "extracted" if not extract_dir.exists() or not list(extract_dir.rglob("ptbxl_database.csv")): print(f"[fetch] extracting to {extract_dir}", flush=True) extract_dir.mkdir(parents=True, exist_ok=True) with zipfile.ZipFile(zip_path) as z: z.extractall(extract_dir) csvs = list(extract_dir.rglob("ptbxl_database.csv")) assert csvs, "ptbxl_database.csv not found after extract" db_csv = csvs[0] db_root = db_csv.parent print(f"[fetch] db_root = {db_root}", flush=True) meta = pd.read_csv(db_csv, index_col="ecg_id") meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp) meta["afib"] = meta["scp_parsed"].apply( lambda d: int(any(k in ("AFIB", "AFLT") for k in d.keys())) ) if args.limit: meta = meta.sample(n=args.limit, random_state=0) print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}", flush=True) work = [(str(db_root), row["filename_hr"], row["afib"]) for _, row in meta.iterrows()] print(f"[fetch] parsing with {args.workers} workers", flush=True) xs, ys = [], [] with 
mp.Pool(args.workers) as pool: for r in tqdm(pool.imap_unordered(_process_one, work, chunksize=8), total=len(work), desc="ptb-xl"): if r is None: continue xs.append(r[0]) ys.append(r[1]) X = np.stack(xs).astype(np.float32)[:, None, :] y = np.array(ys, dtype=np.int64) out = Path(args.out) out.parent.mkdir(parents=True, exist_ok=True) np.savez_compressed(out, X=X, y=y) print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1-y).sum())}", flush=True) if __name__ == "__main__": main()
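
# A minimal usage sketch (the paths match the defaults above; the loading
# snippet assumes only the X/y arrays this script writes):
#
#   python scripts/fetch_ptbxl_v2.py --root /workspace/cache/ptbxl \
#       --out /workspace/cache/ptbxl_af.npz --workers 12
#
#   import numpy as np
#   data = np.load("/workspace/cache/ptbxl_af.npz")
#   X, y = data["X"], data["y"]  # X: (N, 1, 2500) float32, y: (N,) int64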