"""PTB-XL atrial-fibrillation loader (HuggingFace mirror or local PhysioNet copy).

Records come from the HuggingFace mirror ``PULSE-ECG/PTB-XL`` or from a local
PhysioNet download of PTB-XL. Only what the AF task needs is kept:

* the lead II waveform, recorded at 500 Hz, resampled to 250 Hz, then
  zero-padded/truncated to a 2500-sample window and z-scored;
* one binary label per record: 1 if the record carries an ``AFIB`` (atrial
  fibrillation) or ``AFLT`` (atrial flutter) SCP code, else 0.

The HF mirror is the default path because it needs no credentialing.
"""

from __future__ import annotations

import ast
import json
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.signal import resample_poly

# Samples kept per record after resampling (2500 samples = 10 s at 250 Hz).
_WINDOW_LEN = 2500
# SCP codes that make a record a positive (AF) example.
_AF_CODES = frozenset({"AFIB", "AFLT"})
# Lead order assumed when a HuggingFace row does not carry its own lead names.
_DEFAULT_LEADS = ("I", "II", "III", "aVR", "aVL", "aVF",
                  "V1", "V2", "V3", "V4", "V5", "V6")


def _resample_500_to_250(x: np.ndarray) -> np.ndarray:
    """Downsample a 500 Hz signal to 250 Hz along its last axis, as float32."""
    return resample_poly(x, up=1, down=2, axis=-1).astype(np.float32)


def _preprocess_lead(lead_500hz: np.ndarray) -> np.ndarray:
    """Resample one 500 Hz lead to 250 Hz, fix it to ``_WINDOW_LEN`` samples, z-score it.

    Short signals are zero-padded at the end; long ones are truncated.
    Returns a 1-D float32 array of length ``_WINDOW_LEN``.
    """
    x = _resample_500_to_250(lead_500hz)
    if x.shape[0] < _WINDOW_LEN:
        x = np.pad(x, (0, _WINDOW_LEN - x.shape[0]))
    else:
        x = x[:_WINDOW_LEN]
    # The epsilon keeps a flat-line record from dividing by zero.
    return (x - x.mean()) / (x.std() + 1e-6)


def _stack_dataset(xs: list[np.ndarray], ys: list[int]) -> tuple[np.ndarray, np.ndarray]:
    """Stack per-record windows into ``X`` of shape (n, 1, _WINDOW_LEN) and labels ``y`` of shape (n,).

    An empty input yields correctly shaped empty arrays instead of letting
    ``np.stack`` raise on an empty list.
    """
    if not xs:
        return (np.empty((0, 1, _WINDOW_LEN), dtype=np.float32),
                np.empty(0, dtype=np.int64))
    X = np.stack(xs).astype(np.float32)[:, None, :]
    y = np.array(ys, dtype=np.int64)
    return X, y


def _parse_scp_string(text: str) -> dict:
    """Parse a stringified scp_codes dict, falling back to a tolerant key scan."""
    # PTB-XL stores scp_codes as a Python dict repr, e.g. "{'NORM': 100.0, 'SR': 0.0}".
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # JSON after swapping quotes; unlike literal_eval this also accepts NaN / null.
    try:
        parsed = json.loads(text.replace("'", '"'))
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Last resort: split "key: value" pairs by hand, keeping numeric values only.
    out = {}
    for tok in text.strip("{} ").split(","):
        key, sep, value = tok.partition(":")
        if not sep:
            continue
        key = key.strip().strip("'").strip('"')
        try:
            out[key] = float(value.strip())
        except ValueError:
            pass
    return out


def _parse_scp_dict(val) -> dict:
    """Normalise a scp_codes cell (dict, stringified dict, or missing) to a dict.

    Entries whose value is ``None`` are dropped. A HuggingFace struct column
    presumably materialises every struct field on every row, with ``None`` for
    codes the record does not have; keeping those keys would make every code
    look present. Real PTB-XL likelihoods, including ``0.0`` on rhythm codes,
    are kept.
    Anything that is neither a dict nor a string yields ``{}``.
    """
    if isinstance(val, dict):
        parsed = val
    elif isinstance(val, str):
        parsed = _parse_scp_string(val)
    else:
        return {}
    return {code: score for code, score in parsed.items() if score is not None}


def _af_label(scp: dict) -> int:
    """Return 1 if any AF code (AFIB/AFLT) is a key of *scp*, else 0.

    Presence is what counts, not the likelihood value. PTB-XL rhythm
    statements carry a likelihood of 0.0 even when they apply.
    """
    return int(any(code in _AF_CODES for code in scp))


def load_ptbxl_af_from_physionet_local(root: Path, limit: int | None = None):
    """Load PTB-XL from a local PhysioNet download directory.

    Args:
        root: directory containing ``ptbxl_database.csv`` and the 500 Hz
            records referenced by its ``filename_hr`` column.
        limit: if given, a reproducible random subset (``random_state=0``) of
            at most ``limit`` records; capped at the number of records.

    Returns:
        ``(X, y)``: float32 ``X`` of shape (n, 1, 2500) holding z-scored lead II
        at 250 Hz, and int64 ``y`` of shape (n,) holding the AF label.
    """
    import wfdb

    root = Path(root)
    meta = pd.read_csv(root / "ptbxl_database.csv", index_col="ecg_id")
    meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp_dict)
    meta["afib"] = meta["scp_parsed"].apply(_af_label)
    if limit is not None:
        meta = meta.sample(n=min(limit, len(meta)), random_state=0)

    xs, ys = [], []
    for rel_path, label in zip(meta["filename_hr"], meta["afib"]):
        rec = wfdb.rdrecord(str(root / rel_path))  # 500 Hz record
        # p_signal is indexed [:, lead] here, i.e. (samples, leads).
        lead_ii = rec.p_signal[:, rec.sig_name.index("II")]
        xs.append(_preprocess_lead(lead_ii))
        ys.append(int(label))
    return _stack_dataset(xs, ys)


def load_ptbxl_af_from_hf(limit: int | None = None):
    """Load PTB-XL via HuggingFace (``PULSE-ECG/PTB-XL``); open access, no credentials.

    Each row is read through the keys ``scp_codes``, ``signal`` (or ``ecg``)
    and optionally ``lead_names``. NOTE(review): the signal is assumed to be
    500 Hz, matching the PhysioNet path. Confirm this against the mirror's card.

    Args:
        limit: if given, only the first ``limit`` rows of the train split are
            considered. Rows skipped for an unusable signal still count.

    Returns:
        ``(X, y)`` with the same shapes and dtypes as
        ``load_ptbxl_af_from_physionet_local``.
    """
    from datasets import load_dataset

    ds = load_dataset("PULSE-ECG/PTB-XL", split="train", streaming=False)
    xs, ys = [], []
    for i, row in enumerate(ds):
        if limit is not None and i >= limit:
            break

        # Explicit None checks: `a or b` raises on numpy-array values.
        sig_raw = row.get("signal")
        if sig_raw is None:
            sig_raw = row.get("ecg")
        if sig_raw is None:
            continue
        sig = np.asarray(sig_raw, dtype=np.float32)
        if sig.ndim != 2:
            continue

        lead_names = row.get("lead_names")
        if lead_names is None or len(lead_names) == 0:
            lead_names = _DEFAULT_LEADS
        lead_names = list(lead_names)

        # Rows are indexed as (leads, samples) below. If the stored layout is
        # (samples, leads), transpose so sig[k] is still lead k.
        n_leads = len(lead_names)
        if sig.shape[0] != n_leads and sig.shape[1] == n_leads:
            sig = sig.T

        # Without a lead II name, fall back to the second row (II in standard order).
        lead_ii = sig[lead_names.index("II")] if "II" in lead_names else sig[1]

        xs.append(_preprocess_lead(lead_ii))
        ys.append(_af_label(_parse_scp_dict(row.get("scp_codes", {}))))
    return _stack_dataset(xs, ys)