PhysioJEPA / src /physiojepa /ptbxl.py
guychuk's picture
Upload folder using huggingface_hub
31e2456 verified
"""PTB-XL loader that reads from HuggingFace (`PULSE-ECG/PTB-XL`) or from a
local PhysioNet download.

We only need lead II waveforms (500 Hz, resampled to 250 Hz) plus one binary
label per record, which is positive when the record's SCP codes contain AFIB
or AFLT. The HF mirror is the default path because it needs no credentialing.
"""
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.signal import resample_poly
def _resample_500_to_250(x: np.ndarray) -> np.ndarray:
    """Decimate *x* by a factor of two along its last axis.

    Uses polyphase resampling (up=1, down=2), which is what takes a 500 Hz
    recording to 250 Hz. The result is always cast to float32.
    """
    halved = resample_poly(x, 1, 2, axis=-1)
    return halved.astype(np.float32)
def _parse_scp_dict(val):
    """Parse a PTB-XL ``scp_codes`` cell into a dict.

    Args:
        val: An already-parsed dict, or a stringified dict such as
            ``"{'NORM': 100.0, 'SR': 0.0}"``. Any other type (e.g. ``None``
            or a float NaN from a missing CSV cell) yields ``{}``.

    Returns:
        Always a dict. A dict input is returned unchanged. A string is first
        parsed as JSON after swapping single quotes for double quotes; if
        that fails, or succeeds but does not produce a dict, a lenient
        ``key: value`` scan is used that keeps only entries whose value
        parses as a float.
    """
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        return {}

    # scp_codes in PTB-XL look like "{'NORM': 100.0, 'SR': 0.0}", which is
    # valid JSON once the quotes are swapped.
    try:
        parsed = json.loads(val.replace("'", '"'))
    except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
        parsed = None
    # json.loads happily returns lists, numbers or None for inputs such as
    # "[1, 2]" or "null"; callers iterate the keys, so only accept a dict.
    if isinstance(parsed, dict):
        return parsed

    # Fallback: scan comma-separated "key: value" tokens.
    out = {}
    for tok in val.strip("{} ").split(","):
        key, sep, raw = tok.partition(":")
        if not sep:
            continue
        key = key.strip().strip("'").strip('"')
        try:
            out[key] = float(raw.strip())
        except ValueError:
            # Non-numeric value: drop this entry, keep scanning the rest.
            continue
    return out
def load_ptbxl_af_from_physionet_local(root: Path, limit: int | None = None):
    """Load lead-II waveforms and AF labels from a local PTB-XL download.

    Args:
        root: Directory containing ``ptbxl_database.csv`` and the record
            files referenced by its ``filename_hr`` column.
        limit: If given, load a reproducible random subset
            (``random_state=0``) of at most this many records. Values larger
            than the number of rows load every record.

    Returns:
        ``(X, y)``. ``X`` is float32 of shape ``(n, 1, 2500)``: lead II,
        decimated by two, zero-padded or cropped to 2500 samples, then
        z-scored per record. ``y`` is int64 of shape ``(n,)`` and is 1 when
        the record's ``scp_codes`` contain ``AFIB`` or ``AFLT``. When no
        record is selected both arrays are empty (``n == 0``).

    Raises:
        ValueError: If a record has no channel named ``"II"``.
    """
    import wfdb

    target_len = 2500  # 10 s at 250 Hz
    positive_codes = ("AFIB", "AFLT")

    root = Path(root)
    meta = pd.read_csv(root / "ptbxl_database.csv", index_col="ecg_id")
    if limit is not None:
        # DataFrame.sample raises when n exceeds the row count (sampling
        # without replacement), so clamp to what is available.
        meta = meta.sample(n=min(limit, len(meta)), random_state=0)

    xs, ys = [], []
    for _, row in meta.iterrows():
        # Labels are parsed per selected row, so a small `limit` does not pay
        # for parsing the whole table.
        scp = _parse_scp_dict(row["scp_codes"])
        label = int(any(code in scp for code in positive_codes))

        rec = wfdb.rdrecord(str(root / row["filename_hr"]))  # 500 Hz
        # p_signal is indexed as [sample, channel] here.
        lead_ii = rec.p_signal[:, rec.sig_name.index("II")]

        x = _resample_500_to_250(lead_ii)
        if x.shape[0] < target_len:
            x = np.pad(x, (0, target_len - x.shape[0]))
        else:
            x = x[:target_len]
        # Per-record z-score; the epsilon guards against flat-line signals.
        x = (x - x.mean()) / (x.std() + 1e-6)

        xs.append(x)
        ys.append(label)

    if not xs:
        # np.stack rejects an empty list; return correctly shaped empties.
        return (
            np.empty((0, 1, target_len), dtype=np.float32),
            np.empty((0,), dtype=np.int64),
        )

    X = np.stack(xs).astype(np.float32)[:, None, :]
    y = np.array(ys, dtype=np.int64)
    return X, y
def load_ptbxl_af_from_hf(limit: int | None = None):
    """Load lead-II waveforms and AF labels from the HuggingFace PTB-XL mirror.

    Reads the ``train`` split of ``PULSE-ECG/PTB-XL``. Each row is expected
    to carry the waveform under ``signal`` or ``ecg`` and, optionally,
    ``scp_codes`` and ``lead_names``.
    NOTE(review): the mirror's exact schema and its 500 Hz sampling rate are
    assumed, not checked here; confirm against the dataset card.

    Args:
        limit: If given, only the first ``limit`` rows of the split are
            examined. Rows skipped as malformed still count towards the
            limit, so fewer than ``limit`` records may be returned.

    Returns:
        ``(X, y)``. ``X`` is float32 of shape ``(n, 1, 2500)``: lead II,
        decimated by two, zero-padded or cropped to 2500 samples, then
        z-scored per record. ``y`` is int64 of shape ``(n,)`` and is 1 when
        the row's ``scp_codes`` contain ``AFIB`` or ``AFLT``. When no usable
        row is found both arrays are empty (``n == 0``).
    """
    from datasets import load_dataset

    target_len = 2500  # 10 s at 250 Hz
    positive_codes = ("AFIB", "AFLT")
    # Lead order assumed when a row does not provide its own lead names.
    default_leads = ["I", "II", "III", "aVR", "aVL", "aVF",
                     "V1", "V2", "V3", "V4", "V5", "V6"]

    ds = load_dataset("PULSE-ECG/PTB-XL", split="train", streaming=False)
    xs, ys = [], []
    for i, row in enumerate(ds):
        if limit is not None and i >= limit:
            break

        scp = _parse_scp_dict(row.get("scp_codes", {}))
        afib = int(any(code in scp for code in positive_codes))

        # Explicit None/empty checks: `a or b` raises on array-valued fields.
        sig = None
        for key in ("signal", "ecg"):
            raw = row.get(key)
            if raw is None:
                continue
            candidate = np.asarray(raw, dtype=np.float32)
            if candidate.size:
                sig = candidate
                break
        if sig is None or sig.ndim != 2:
            continue

        lead_names = row.get("lead_names")
        if lead_names is None or len(lead_names) == 0:
            lead_names = default_leads
        else:
            lead_names = list(lead_names)

        # Rows are indexed as [lead, sample] below. If the lead count only
        # matches the second axis the signal is stored [sample, lead]; flip
        # it so we do not mistake one time step for a lead.
        n_leads = len(lead_names)
        if sig.shape[0] != n_leads and sig.shape[1] == n_leads:
            sig = sig.T

        lead_idx = lead_names.index("II") if "II" in lead_names else 1
        if lead_idx >= sig.shape[0]:
            # Too few leads to contain lead II: skip like other malformed rows.
            continue
        lead_ii = sig[lead_idx]

        x = _resample_500_to_250(lead_ii)
        if x.shape[0] < target_len:
            x = np.pad(x, (0, target_len - x.shape[0]))
        else:
            x = x[:target_len]
        # Per-record z-score; the epsilon guards against flat-line signals.
        x = (x - x.mean()) / (x.std() + 1e-6)

        xs.append(x)
        ys.append(afib)

    if not xs:
        # np.stack rejects an empty list; return correctly shaped empties.
        return (
            np.empty((0, 1, target_len), dtype=np.float32),
            np.empty((0,), dtype=np.int64),
        )

    X = np.stack(xs).astype(np.float32)[:, None, :]
    y = np.array(ys, dtype=np.int64)
    return X, y