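# PhysioJEPA / scripts/fetch_ptbxl_v2.py
# Uploaded by guychuk with huggingface_hub ("Upload folder using huggingface_hub"),
# commit 31e2456.
"""PTB-XL fetch v2 — multiprocessing, full zip download via wget.
Downloads PTB-XL v1.0.3 from PhysioNet, extracts it, and parses the records with
wfdb in parallel using a process pool (--workers, default 16). Each record's
lead II is resampled from 500 Hz to 250 Hz, fixed to 2500 samples (10 s) and
z-scored; the result is written together with binary AF labels (AFIB or AFLT)
to --out (default /workspace/cache/ptbxl_af.npz).
Usage:
python scripts/fetch_ptbxl_v2.py --root /workspace/cache/ptbxl --out /workspace/cache/ptbxl_af.npz [--workers 12] [--limit N]
"""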
from __future__ import annotations

import argparse
import ast
import json
import multiprocessing as mp
import os
import subprocess
import zipfile
from pathlib import Path

import numpy as np
import pandas as pd
import wfdb
from scipy.signal import resample_poly
from tqdm import tqdm
# Pinned PTB-XL release to download.
PTBXL_VERSION = "1.0.3"
# Full-release archive on PhysioNet's static file server, built from the version.
PTBXL_URL = "/".join(
    (
        "https://physionet.org/static/published-projects/ptb-xl",
        f"ptb-xl-a-large-publicly-available-electrocardiography-dataset-{PTBXL_VERSION}.zip",
    )
)
def _resample_500_to_250(x):
    """Downsample *x* by a factor of two along its last axis, returning float32.

    Intended for 500 Hz -> 250 Hz conversion; filtering and decimation are
    delegated to ``scipy.signal.resample_poly`` (up=1, down=2).
    """
    decimated = resample_poly(x, 1, 2, axis=-1)
    return decimated.astype(np.float32)
def _parse_scp(val):
    """Parse a PTB-XL ``scp_codes`` cell into a ``{code: value}`` dict.

    Cells are expected to hold dict-literal text such as
    ``"{'NORM': 100.0, 'SR': 0.0}"``. Strategies are tried in order:

    1. ``ast.literal_eval`` (safe: evaluates literals only, never code);
    2. ``json.loads`` after swapping single quotes for double quotes;
    3. a lenient ``key: number`` token scan that skips tokens whose value is
       not a number instead of raising.

    Args:
        val: A cell value. Dicts are returned unchanged; any other non-string
            (e.g. the float NaN pandas uses for empty cells) yields ``{}``.

    Returns:
        Always a dict. Malformed text does not raise; unparsable parts just
        contribute no entries.
    """
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        return {}
    for parse in (ast.literal_eval, lambda s: json.loads(s.replace("'", '"'))):
        try:
            parsed = parse(val)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        # A valid literal that is not a mapping (e.g. a list) is not an
        # scp_codes dict; callers iterate its keys, so keep looking.
        if isinstance(parsed, dict):
            return parsed
    out = {}
    for tok in val.strip("{} ").split(","):
        key, sep, value = tok.partition(":")
        if not sep:
            continue
        try:
            out[key.strip().strip("'\"")] = float(value.strip())
        except ValueError:
            # Non-numeric value: drop this token rather than fail the whole
            # metadata column.
            continue
    return out
def _process_one(arg):
    """Read one PTB-XL record's lead II and return ``(x, label)``, or ``None``.

    Designed as a ``multiprocessing`` pool task: it takes a single tuple and
    does not propagate ordinary exceptions. A record that cannot be used is
    reported on stdout and skipped by returning ``None``, so one bad file
    cannot abort the whole pool.

    Args:
        arg: ``(db_root, fname_hr, afib)`` — dataset root directory, the
            record path relative to it (passed to ``wfdb.rdrecord``), and
            the label value.

    Returns:
        ``(x, int(afib))``, where ``x`` is a float32 array of 2500 samples:
        lead II downsampled by ``_resample_500_to_250``, zero-padded or
        truncated to 2500, then z-scored. Returns ``None`` if the record has
        no lead II, contains non-finite samples, or fails to load/process.
    """
    db_root, fname_hr, afib = arg
    try:
        rec = wfdb.rdrecord(str(Path(db_root) / fname_hr))
        lead_names = rec.sig_name
        if "II" not in lead_names:
            print(f"[fetch] skipping {fname_hr}: no lead II", flush=True)
            return None
        lead_ii = rec.p_signal[:, lead_names.index("II")]
        if not np.isfinite(lead_ii).all():
            # A single NaN/inf would spread through the resampling filter and
            # the z-score below, silently turning the whole window into NaN.
            print(f"[fetch] skipping {fname_hr}: non-finite samples", flush=True)
            return None
        x = _resample_500_to_250(lead_ii)
        # Fixed-length window: zero-pad short records, truncate long ones.
        if x.shape[0] < 2500:
            x = np.pad(x, (0, 2500 - x.shape[0]))
        else:
            x = x[:2500]
        # Per-record z-score; the epsilon keeps a flat-line record finite.
        x = (x - x.mean()) / (x.std() + 1e-6)
        return (x.astype(np.float32), int(afib))
    except Exception as exc:
        # Best-effort by design, but report why the record was dropped
        # instead of swallowing the error silently.
        print(f"[fetch] skipping {fname_hr}: {type(exc).__name__}: {exc}", flush=True)
        return None
def _ensure_zip(zip_path: Path) -> None:
    """Ensure *zip_path* holds a complete PTB-XL archive, downloading it if needed.

    Completeness is checked with ``zipfile.is_zipfile``. That check needs the
    archive's trailing end-of-central-directory record, so it catches a
    truncated download regardless of its size. An existing partial file is
    resumed with ``wget -c``. If resuming fails or still does not yield a
    valid archive, the file is discarded and downloaded again from scratch.

    Raises:
        subprocess.CalledProcessError: if the from-scratch ``wget`` fails.
        RuntimeError: if the downloaded file is still not a valid zip.
    """
    if zipfile.is_zipfile(zip_path):
        return
    cmd = ["wget", "-c", "-O", str(zip_path), PTBXL_URL]
    if zip_path.exists():
        print(f"[fetch] resuming partial download of {zip_path}", flush=True)
        if subprocess.run(cmd).returncode == 0 and zipfile.is_zipfile(zip_path):
            return
        zip_path.unlink(missing_ok=True)
    print("[fetch] downloading PTB-XL via wget", flush=True)
    subprocess.run(cmd, check=True)
    if not zipfile.is_zipfile(zip_path):
        raise RuntimeError(f"{zip_path} is not a valid zip archive after download")


def _ensure_extracted(zip_path: Path, extract_dir: Path) -> Path:
    """Extract *zip_path* into *extract_dir* once and return the PTB-XL root.

    A marker file is written only after ``extractall`` finishes. An
    interrupted extraction is therefore redone on the next run instead of
    being mistaken for a complete one; finding ``ptbxl_database.csv`` alone
    does not prove every record was written. Directories without the marker
    are re-extracted once.

    Returns:
        The directory containing ``ptbxl_database.csv``.

    Raises:
        FileNotFoundError: if no ``ptbxl_database.csv`` exists after extraction.
    """
    marker = extract_dir / ".extract_complete"
    if not marker.exists():
        print(f"[fetch] extracting to {extract_dir}", flush=True)
        extract_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(zip_path) as z:
            z.extractall(extract_dir)
        marker.touch()
    csvs = sorted(extract_dir.rglob("ptbxl_database.csv"))
    if not csvs:
        raise FileNotFoundError(f"ptbxl_database.csv not found under {extract_dir}")
    return csvs[0].parent


def _load_metadata(db_csv: Path, limit: int | None) -> pd.DataFrame:
    """Read the PTB-XL metadata CSV (indexed by ``ecg_id``) and add an ``afib`` column.

    ``afib`` is 1 when a record's parsed ``scp_codes`` contain ``AFIB`` or
    ``AFLT``, else 0. With a truthy *limit*, a reproducible random subset of
    at most *limit* rows is returned (``random_state=0``).
    """
    meta = pd.read_csv(db_csv, index_col="ecg_id")
    meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp)
    af_codes = ("AFIB", "AFLT")
    meta["afib"] = meta["scp_parsed"].apply(lambda d: int(any(k in af_codes for k in d)))
    if limit:
        # Clamp: DataFrame.sample raises if asked for more rows than exist.
        meta = meta.sample(n=min(limit, len(meta)), random_state=0)
    return meta


def main() -> None:
    """Build the PTB-XL AF dataset: download, extract, parse lead II, save ``.npz``.

    The output ``.npz`` at ``--out`` contains:

    * ``X`` — float32, shape ``(N, 1, L)``: one lead-II window per kept record;
    * ``y`` — int64 AF label (1 when ``scp_codes`` has AFIB or AFLT);
    * ``ecg_id`` — the metadata index value of each row, so rows can be traced
      back to ``ptbxl_database.csv``;
    * ``strat_fold`` — the metadata's ``strat_fold`` value per row, only when
      that column is present.

    Rows follow the (optionally ``--limit``-sampled) metadata order, so the
    output is reproducible across runs. Records rejected by ``_process_one``
    are dropped and counted.

    Raises:
        RuntimeError: if no record could be parsed.
    """
    ap = argparse.ArgumentParser(description="Fetch PTB-XL and cache lead II with AF labels.")
    ap.add_argument("--root", default="/workspace/cache/ptbxl",
                    help="download/extraction directory")
    ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz",
                    help="output .npz path")
    ap.add_argument("--workers", type=int, default=16,
                    help="number of parsing processes")
    ap.add_argument("--limit", type=int, default=None,
                    help="parse only a random subset of this many records")
    args = ap.parse_args()

    root = Path(args.root)
    root.mkdir(parents=True, exist_ok=True)
    zip_path = root / "ptbxl.zip"
    _ensure_zip(zip_path)
    print(f"[fetch] zip size: {zip_path.stat().st_size / 1e9:.2f} GB", flush=True)

    db_root = _ensure_extracted(zip_path, root / "extracted")
    print(f"[fetch] db_root = {db_root}", flush=True)

    meta = _load_metadata(db_root / "ptbxl_database.csv", args.limit)
    print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}", flush=True)

    work = [(str(db_root), fname, afib)
            for fname, afib in zip(meta["filename_hr"], meta["afib"])]
    print(f"[fetch] parsing with {args.workers} workers", flush=True)
    kept, xs, ys = [], [], []
    with mp.Pool(args.workers) as pool:
        # Ordered imap (not imap_unordered): result i belongs to meta row i,
        # which makes the output order reproducible and lets each row keep
        # its ecg_id.
        results = pool.imap(_process_one, work, chunksize=8)
        for pos, r in enumerate(tqdm(results, total=len(work), desc="ptb-xl")):
            if r is None:
                continue
            kept.append(pos)
            xs.append(r[0])
            ys.append(r[1])
    n_dropped = len(work) - len(xs)
    if n_dropped:
        print(f"[fetch] dropped {n_dropped} unreadable/unusable records", flush=True)
    if not xs:
        raise RuntimeError("no PTB-XL records could be parsed; nothing to write")

    X = np.stack(xs).astype(np.float32)[:, None, :]
    y = np.array(ys, dtype=np.int64)
    kept_meta = meta.iloc[kept]
    extra = {"ecg_id": kept_meta.index.to_numpy()}
    if "strat_fold" in kept_meta.columns:
        extra["strat_fold"] = kept_meta["strat_fold"].to_numpy()
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(out, X=X, y=y, **extra)
    print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1 - y).sum())}",
          flush=True)


if __name__ == "__main__":
    main()