# PhysioJEPA / scripts / fetch_ptbxl.py
"""Fetch PTB-XL from PhysioNet (open access, no credentialing) and cache lead II
@ 250 Hz with binary AFIB labels into a single .npz file for fast eval reload.
Resulting cache layout:
/workspace/cache/ptbxl_af.npz (X: [N,1,2500] float32, y: [N] int64)
"""
from __future__ import annotations
import argparse
import io
import os
import re
import tarfile
import zipfile
from pathlib import Path
import numpy as np
import requests
from tqdm import tqdm
# PhysioNet open-access mirror (no credentialing required).
# Bump PTBXL_VERSION to fetch a newer dataset release.
PTBXL_VERSION = "1.0.3"
PTBXL_URL = (
    f"https://physionet.org/static/published-projects/ptb-xl/"
    f"ptb-xl-a-large-publicly-available-electrocardiography-dataset-{PTBXL_VERSION}.zip"
)
def _resample_500_to_250(x):
from scipy.signal import resample_poly
return resample_poly(x, up=1, down=2, axis=-1).astype(np.float32)
def _parse_scp_codes(val):
    """Parse one PTB-XL ``scp_codes`` cell (a Python-dict literal string, e.g.
    ``"{'AFIB': 100.0}"``) into a dict of {code: confidence}.

    Uses ``ast.literal_eval`` (safe for literals) instead of patching quotes
    into JSON; falls back to a tolerant manual split for malformed cells.
    """
    import ast
    try:
        parsed = ast.literal_eval(val)
        if isinstance(parsed, dict):
            return parsed
    except (ValueError, SyntaxError):
        pass
    # Fallback: best-effort parse of "{K: V, ...}"-shaped strings.
    out = {}
    for tok in val.strip("{} ").split(","):
        if ":" in tok:
            k, v = tok.split(":", 1)
            out[k.strip().strip("'\"")] = float(v.strip())
    return out


def main() -> None:
    """Download PTB-XL, extract it, and cache lead II @ 250 Hz with binary
    AFIB/AFLT labels into a single compressed .npz for fast eval reload.

    Output: ``--out`` npz with X: [N, 1, 2500] float32 (z-scored 10 s
    windows) and y: [N] int64 (1 = AFIB or AFLT present in scp_codes).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", default="/workspace/cache/ptbxl")
    ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz")
    ap.add_argument("--limit", type=int, default=None)
    args = ap.parse_args()

    root = Path(args.root)
    root.mkdir(parents=True, exist_ok=True)

    zip_path = root / "ptbxl.zip"
    if not zip_path.exists():
        print(f"[fetch] downloading PTB-XL ({PTBXL_URL})")
        r = requests.get(PTBXL_URL, stream=True, timeout=600)
        r.raise_for_status()
        total = int(r.headers.get("content-length", 0))
        # Stream to a temp file and atomically rename so an interrupted
        # download does not leave a truncated zip that later runs skip.
        tmp_path = zip_path.with_suffix(".zip.part")
        with open(tmp_path, "wb") as f:
            # Ceiling division: floor would undercount the chunk total.
            n_chunks = ((total + (1 << 20) - 1) // (1 << 20)) or None
            for chunk in tqdm(r.iter_content(chunk_size=1024 * 1024),
                              total=n_chunks):
                if chunk:
                    f.write(chunk)
        os.replace(tmp_path, zip_path)

    extract_dir = root / "extracted"
    if not extract_dir.exists():
        print(f"[fetch] extracting to {extract_dir}")
        with zipfile.ZipFile(zip_path) as z:
            z.extractall(extract_dir)

    # Locate the metadata CSV anywhere under the (version-named) archive root.
    csvs = list(extract_dir.rglob("ptbxl_database.csv"))
    if not csvs:
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        raise FileNotFoundError("ptbxl_database.csv not found in extracted zip")
    db_csv = csvs[0]
    db_root = db_csv.parent
    print(f"[fetch] db_root = {db_root}")

    import pandas as pd
    import wfdb

    meta = pd.read_csv(db_csv, index_col="ecg_id")
    meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp_codes)
    # Positive label = atrial fibrillation OR atrial flutter code present.
    meta["afib"] = meta["scp_parsed"].apply(
        lambda d: int(any(k in ("AFIB", "AFLT") for k in d))
    )
    if args.limit:
        # Clamp so --limit larger than the dataset doesn't raise in sample().
        meta = meta.sample(n=min(args.limit, len(meta)), random_state=0)
    print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}")

    xs, ys = [], []
    for _, row in tqdm(meta.iterrows(), total=len(meta), desc="ptb-xl"):
        rec = wfdb.rdrecord(str(db_root / row["filename_hr"]))
        signals = rec.p_signal  # [T, 12] @ 500 Hz
        lead_names = rec.sig_name
        if "II" not in lead_names:
            continue  # skip the rare record without a lead II channel
        lead_ii = signals[:, lead_names.index("II")]
        x = _resample_500_to_250(lead_ii)
        # Pad/crop to exactly 2500 samples = 10 s @ 250 Hz.
        if x.shape[0] < 2500:
            x = np.pad(x, (0, 2500 - x.shape[0]))
        else:
            x = x[:2500]
        # Per-record z-score; epsilon guards flat (zero-variance) signals.
        x = (x - x.mean()) / (x.std() + 1e-6)
        xs.append(x.astype(np.float32))
        ys.append(int(row["afib"]))

    if not xs:
        # np.stack([]) raises a cryptic ValueError; fail with a clear message.
        raise RuntimeError("no usable records found (no lead II signals)")

    X = np.stack(xs).astype(np.float32)[:, None, :]
    y = np.array(ys, dtype=np.int64)
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(out, X=X, y=y)
    print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1-y).sum())}")
# Script entry point: python scripts/fetch_ptbxl.py [--root DIR] [--out NPZ] [--limit N]
if __name__ == "__main__":
    main()