# PhysioJEPA/scripts/fetch_ptbxl_v3.py
"""PTB-XL fetch v3 — concurrent per-file HTTP downloads (no 3 GB monolithic zip).
PhysioNet exposes individual files at:
https://physionet.org/files/ptb-xl/1.0.3/<filename>
Strategy:
1. Download just `ptbxl_database.csv` (~4 MB) to know which records exist
2. Concurrent download of the .hea/.dat pairs we need (lead II only — but
we need to download all 12 leads since each .dat is one multilead file)
3. Parse with wfdb in a process pool
Total bytes: 21k records × ~400 KB each ≈ 8 GB. Even at 200 KB/s that's
slow, but with 32 concurrent connections we should saturate the pod's
~1 Gbit network (~125 MB/s). 8 GB / 125 MB/s = 64 sec ideal, ~10 min
realistic given physionet bandwidth caps.
Actually shortcut — use the LR (low-res, 100 Hz) variant: ~75 KB per file
×21k = 1.5 GB total. We resample 100→250 Hz with scipy. Quality is fine
for AF detection (PTB-XL paper uses both 100 and 500 Hz freely).
"""
from __future__ import annotations

import argparse
import concurrent.futures as cf
import json
import multiprocessing as mp
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.signal import resample_poly
from tqdm import tqdm
import wfdb

BASE = "https://physionet.org/files/ptb-xl/1.0.3"
def _parse_scp(val):
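    # scp_codes is stored in the CSV as a Python-dict-literal string, e.g. "{'AFIB': 100.0, 'SR': 0.0}";
    # swapping single for double quotes usually makes it valid JSON, with a manual parse as fallback.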
if isinstance(val, dict):
return val
if not isinstance(val, str):
return {}
try:
return json.loads(val.replace("'", '"'))
except Exception:
out = {}
for tok in val.strip("{} ").split(","):
if ":" in tok:
k, v = tok.split(":", 1)
out[k.strip().strip("'\"")] = float(v.strip())
return out
def _download(args):
url, dst = args
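    # Skip files that already exist and are non-empty, so re-runs resume instead of re-downloading.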
if dst.exists() and dst.stat().st_size > 0:
return True
dst.parent.mkdir(parents=True, exist_ok=True)
try:
with urllib.request.urlopen(url, timeout=60) as r, open(dst, "wb") as f:
f.write(r.read())
return True
    except Exception:
        return False
def _resample(x, src_hz, dst_hz):
from math import gcd
g = gcd(int(src_hz), int(dst_hz))
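    # e.g. src=100 Hz, dst=250 Hz: gcd = 50, so upsample by 5 and downsample by 2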
return resample_poly(x, up=int(dst_hz)//g, down=int(src_hz)//g, axis=-1).astype(np.float32)
def _process_one(arg):
db_root, fname, afib, src_hz = arg
try:
rec = wfdb.rdrecord(str(Path(db_root) / fname))
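        # rec.p_signal is an (n_samples, n_leads) array in physical units;
        # rec.sig_name lists the lead names in the same column order.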
signals = rec.p_signal
lead_names = rec.sig_name
if "II" not in lead_names:
return None
lead_ii = signals[:, lead_names.index("II")]
x = _resample(lead_ii, src_hz, 250)
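        # Fix the window to 2500 samples (10 s at 250 Hz): zero-pad short records, crop long ones.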
if x.shape[0] < 2500:
x = np.pad(x, (0, 2500 - x.shape[0]))
else:
x = x[:2500]
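        # Per-record z-score normalization; the 1e-6 epsilon guards against near-constant signals.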
x = (x - x.mean()) / (x.std() + 1e-6)
return (x.astype(np.float32), int(afib))
except Exception:
return None
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--root", default="/workspace/cache/ptbxl")
ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz")
ap.add_argument("--use_lr", action="store_true", help="100 Hz variant (smaller, faster)")
ap.add_argument("--limit", type=int, default=None)
ap.add_argument("--dl_workers", type=int, default=32)
ap.add_argument("--parse_workers", type=int, default=16)
args = ap.parse_args()
root = Path(args.root)
root.mkdir(parents=True, exist_ok=True)
csv_path = root / "ptbxl_database.csv"
if not csv_path.exists():
print(f"[fetch] downloading ptbxl_database.csv", flush=True)
urllib.request.urlretrieve(f"{BASE}/ptbxl_database.csv", str(csv_path))
print(f"[fetch] csv size: {csv_path.stat().st_size/1e6:.1f} MB", flush=True)
meta = pd.read_csv(csv_path, index_col="ecg_id")
meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp)
meta["afib"] = meta["scp_parsed"].apply(
lambda d: int(any(k in ("AFIB", "AFLT") for k in d.keys()))
)
if args.limit:
meta = meta.sample(n=args.limit, random_state=0)
print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}", flush=True)
# Decide LR vs HR
fname_col = "filename_lr" if args.use_lr else "filename_hr"
src_hz = 100 if args.use_lr else 500
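    # filename_lr points under records100/ (100 Hz files), filename_hr under records500/ (500 Hz files).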
# Build download list (.hea + .dat per record)
dl_list = []
for _, row in meta.iterrows():
rel = row[fname_col] # e.g. records100/00000/00001_lr
for ext in (".hea", ".dat"):
url = f"{BASE}/{rel}{ext}"
dst = root / f"{rel}{ext}"
dl_list.append((url, dst))
# Filter out already-present
todo = [(u, d) for u, d in dl_list if not (d.exists() and d.stat().st_size > 0)]
print(f"[fetch] {len(todo)} files to download (skipping {len(dl_list)-len(todo)} cached)",
flush=True)
if todo:
with cf.ThreadPoolExecutor(max_workers=args.dl_workers) as ex:
ok_count = 0
for ok in tqdm(ex.map(_download, todo), total=len(todo), desc="dl"):
if ok:
ok_count += 1
print(f"[fetch] downloaded ok={ok_count}/{len(todo)}", flush=True)
# Parse
work = [(str(root), row[fname_col], row["afib"], src_hz)
for _, row in meta.iterrows()]
print(f"[fetch] parsing {len(work)} records with {args.parse_workers} workers",
flush=True)
xs, ys = [], []
with mp.Pool(args.parse_workers) as pool:
for r in tqdm(pool.imap_unordered(_process_one, work, chunksize=16),
total=len(work), desc="parse"):
if r is None:
continue
xs.append(r[0])
ys.append(r[1])
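    # Stack into (N, 1, 2500): float32 lead-II windows with a singleton channel axis, int64 labels.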
X = np.stack(xs).astype(np.float32)[:, None, :]
y = np.array(ys, dtype=np.int64)
out = Path(args.out)
out.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(out, X=X, y=y)
print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1-y).sum())}",
flush=True)
if __name__ == "__main__":
main()
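
# Downstream usage (illustrative) once the .npz has been written:
#   data = np.load("/workspace/cache/ptbxl_af.npz")
#   X, y = data["X"], data["y"]  # X: (N, 1, 2500) float32 windows, y: (N,) int64 AF labels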