# PhysioJEPA/scripts/e0_audit_v2.py
# Uploaded to Hugging Face by guychuk via huggingface_hub (commit 31e2456).
"""E0 audit v2 — fixes:
1. Sample 120 of the 412 shards uniformly at random and scan their non-waveform
   columns to estimate the patient count (downloading every shard is not feasible).
2. Better PTT pairing: require clean QRS-to-PPG pairs (exactly one PPG peak
   in [50, 500] ms after R) and compute within-segment PTT std only from those
   clean beats.
3. Estimate alignment error as the within-segment std of PTT from clean beats.
"""
from __future__ import annotations
import json
import os
import random
import re
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from scipy.signal import butter, filtfilt, find_peaks
from tqdm import tqdm
# python-dotenv: read a .env file (if one is found) into os.environ.
load_dotenv()
# Mirror HUGGINGFACE_API_KEY into HF_TOKEN unless HF_TOKEN is already set.
# NOTE: if neither variable is set, this exports an empty HF_TOKEN.
# This deliberately runs before the Hugging Face imports below (presumably so
# they see the token) — keep this statement order.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))
from datasets import load_from_disk
from huggingface_hub import snapshot_download
# Hugging Face dataset repo under audit, and its number of shard directories
# (named shard_{i:05d} in main()).
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
N_SHARDS = 412
# Outputs go to <parent of this script's directory>/docs, figures to docs/figures.
OUT = Path(__file__).resolve().parent.parent / "docs"
FIG_DIR = OUT / "figures"
# Side effect at import time: the figure directory is created eagerly.
FIG_DIR.mkdir(parents=True, exist_ok=True)
# Fixed seed so the shard sample and the segment shuffle are reproducible.
RNG = random.Random(42)
def parse_subject_id(record_name: str) -> str:
    """Return the subject identifier embedded in a record path.

    For paths shaped like ``pNN/pNNNNNNNN/...`` (at least three components),
    the second component is returned. Otherwise this falls back to everything
    before the first ``/`` (the whole string if there is no slash).
    """
    match = re.match(r"p\d+/(p\d+)/", record_name)
    if match is None:
        head, _sep, _rest = record_name.partition("/")
        return head
    return match.group(1)
def bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray:
    """Zero-phase Butterworth band-pass filter.

    Args:
        x: 1-D signal to filter.
        fs: sampling rate in Hz.
        lo: lower cut-off in Hz.
        hi: upper cut-off in Hz. Its normalised value is clamped to 0.99 so a
            cut-off at or above Nyquist does not make ``butter`` reject it.
        order: Butterworth design order.

    Returns:
        The forward-backward (``filtfilt``, Gustafsson edge handling) output,
        with the same length as ``x``.
    """
    nyquist = fs / 2.0
    edges = [lo / nyquist, min(hi / nyquist, 0.99)]
    numer, denom = butter(order, edges, btype="band")
    return filtfilt(numer, denom, x, method="gust")
def r_peaks(ecg: np.ndarray, fs: float) -> np.ndarray:
    """Detect R-peak sample indices in a single ECG lead.

    Pan-Tompkins-like pipeline: 5-15 Hz band-pass, first difference, squaring,
    then a moving-average integration over ``max(int(0.12 * fs), 1)`` samples.
    Peaks of that envelope above ``mean + 0.5 * std`` and at least
    ``int(0.3 * fs)`` samples apart are candidates. Each candidate ``c`` is then
    moved to the argmax of the band-passed signal in the half-open window
    ``[c - w, c + w)``, where ``w = max(int(0.06 * fs), 1)``.

    Returns:
        Integer array of sample indices (empty when nothing is detected).
    """
    filtered = bandpass(ecg, fs, 5.0, 15.0)
    slope = np.diff(filtered, prepend=filtered[:1])
    energy = slope * slope

    win_len = max(int(0.12 * fs), 1)
    envelope = np.convolve(energy, np.ones(win_len) / win_len, mode="same")
    threshold = np.mean(envelope) + 0.5 * np.std(envelope)
    candidates, _props = find_peaks(envelope, height=threshold, distance=int(0.3 * fs))

    half_win = max(int(0.06 * fs), 1)
    n_samples = len(filtered)

    def snap(centre):
        # Refine an envelope peak to the local maximum of the filtered ECG.
        start = max(0, centre - half_win)
        stop = min(n_samples, centre + half_win)
        if stop > start:
            return start + int(np.argmax(filtered[start:stop]))
        return None

    snapped = (snap(c) for c in candidates)
    return np.asarray([idx for idx in snapped if idx is not None], dtype=int)
def ppg_peaks(ppg: np.ndarray, fs: float) -> np.ndarray:
    """Return sample indices of pulse peaks in a PPG trace.

    The trace is band-passed to 0.5-8 Hz. A peak must then rise above
    ``mean + 0.3 * std``, have a prominence of at least ``0.1 * std`` (both
    statistics taken over the filtered trace), and lie at least
    ``int(0.3 * fs)`` samples from the previous accepted peak.
    """
    smoothed = bandpass(ppg, fs, 0.5, 8.0)
    mu = np.mean(smoothed)
    sigma = np.std(smoothed)
    criteria = {
        "distance": int(0.3 * fs),
        "height": mu + 0.3 * sigma,
        "prominence": 0.1 * sigma,
    }
    found, _props = find_peaks(smoothed, **criteria)
    return found
def clean_ptts_ms(
    ecg_lead: np.ndarray,
    ecg_fs: float,
    ppg: np.ndarray,
    ppg_fs: float,
    t0_e: float,
    t0_p: float,
) -> list[float]:
    """Return pulse-transit times (ms) for unambiguous R-to-PPG beat pairs.

    R peaks and PPG peaks are converted to absolute times using each channel's
    start time (``t0_e`` / ``t0_p``) and sampling rate. An R peak contributes a
    PTT only when exactly one PPG peak falls in ``[R + 50 ms, R + 500 ms]``.
    An empty list is returned if either detector finds fewer than 3 peaks.
    """
    r_idx = r_peaks(ecg_lead, ecg_fs)
    p_idx = ppg_peaks(ppg, ppg_fs)
    if min(len(r_idx), len(p_idx)) < 3:
        return []

    beat_times = t0_e + r_idx / ecg_fs
    pulse_times = t0_p + p_idx / ppg_fs

    ptts = []
    for r_time in beat_times:
        window = (pulse_times >= r_time + 0.050) & (pulse_times <= r_time + 0.500)
        if np.count_nonzero(window) != 1:
            continue  # zero or several candidates: ambiguous pairing, skip the beat
        (pulse_time,) = pulse_times[window]
        ptts.append((pulse_time - r_time) * 1000.0)
    return ptts
def _finite_or_none(value):
    """Recursively replace non-finite floats (NaN, +/-inf) with ``None``.

    ``json.dumps`` would otherwise write the non-standard tokens ``NaN`` /
    ``Infinity``, which strict JSON parsers reject. Dicts and lists are rebuilt;
    every other value is returned unchanged.
    """
    if isinstance(value, dict):
        return {k: _finite_or_none(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_finite_or_none(v) for v in value]
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


def _median(values) -> float:
    """Median of ``values`` as a float, or NaN when ``values`` is empty."""
    return float(np.median(values)) if len(values) else float("nan")


def _percentile(values, q: float) -> float:
    """``q``-th percentile of ``values`` as a float, or NaN when empty."""
    return float(np.percentile(values, q)) if len(values) else float("nan")


def _plot_ptt_histogram(all_ptts, good_segments: int, n_shards: int) -> None:
    """Save a histogram of all clean-beat PTTs to ``FIG_DIR/ptt_histogram.png``."""
    plt.figure(figsize=(7, 4))
    plt.hist(all_ptts, bins=60, color="#3a7", edgecolor="black")
    plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms")
    plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms")
    plt.xlabel("PTT (ms)")
    plt.ylabel("count")
    plt.title(
        f"PTT distribution — {len(all_ptts)} clean beats, "
        f"{good_segments} segments, {n_shards} shards"
    )
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG_DIR / "ptt_histogram.png", dpi=120)
    plt.close()


def _plot_sanity_samples(samples) -> None:
    """Save stacked ECG/PPG overlays of a few good segments to ``FIG_DIR/sanity_check.png``.

    Each entry of ``samples`` is a dict with keys ``ecg``, ``ppg``, ``ecg_fs``,
    ``ppg_fs``, ``t0_e``, ``t0_p``, ``lead``, ``name`` and ``ptts``.
    """
    fig, axes = plt.subplots(
        len(samples), 1, figsize=(10, 2.4 * len(samples)), squeeze=False
    )
    for ax, s in zip(axes[:, 0], samples):
        # Put both channels on the clock that clean_ptts_ms uses (their absolute
        # start times), so the visible R -> PPG lag matches the reported PTT.
        t_ref = min(s["t0_e"], s["t0_p"])
        t_e = (s["t0_e"] - t_ref) + np.arange(len(s["ecg"])) / s["ecg_fs"]
        t_p = (s["t0_p"] - t_ref) + np.arange(len(s["ppg"])) / s["ppg_fs"]
        ax2 = ax.twinx()
        ax.plot(t_e, s["ecg"], color="#266", lw=0.6, label=f"ECG {s['lead']}")
        ax2.plot(t_p, s["ppg"], color="#b30", lw=0.6, label="PPG")
        ax.set_title(
            f"{s['name']} PTT median={np.median(s['ptts']):.0f} ms N={len(s['ptts'])}",
            fontsize=8,
        )
        ax.set_xlabel("time (s)")
        ax.set_ylabel(f"ECG {s['lead']}", color="#266")
        ax2.set_ylabel("PPG", color="#b30")
    fig.tight_layout()
    fig.savefig(FIG_DIR / "sanity_check.png", dpi=120)
    plt.close(fig)


def main() -> None:
    """Run the E0 dataset audit.

    Pass 1 downloads a uniform random sample of shards and scans their
    non-waveform columns (patients, duration, sampling rates, ECG leads).
    Pass 2 measures clean-beat PTTs on a random subset of those segments.
    Results are written to ``OUT/e0_report.json``, with non-finite statistics
    stored as ``null``, and also printed. Two figures are saved in ``FIG_DIR``.
    """
    # -------- Pass 1: metadata from a uniform random sample of shards --------
    # Downloading every shard is not feasible, so sample 120 of them (~29% of 412).
    # allow_patterns fetches each sampled shard directory in full (waveforms too).
    # Only the row scan below is restricted to the cheap columns.
    meta_shards = sorted(RNG.sample(range(N_SHARDS), 120))
    print(f"[pass 1] downloading metadata from {len(meta_shards)} shards")
    patterns = ["metadata.json"] + [f"shard_{i:05d}/*" for i in meta_shards]
    root = Path(
        snapshot_download(REPO, repo_type="dataset", allow_patterns=patterns, max_workers=12)
    )
    waveform_cols = ("ecg", "ppg", "ecg_time_s", "ppg_time_s")
    patients: set[str] = set()
    total_duration_s = 0.0
    ecg_fs_list: list[float] = []
    ppg_fs_list: list[float] = []
    ecg_siglen: list[int] = []
    ppg_siglen: list[int] = []
    ecg_leads_counter: dict[str, int] = {}
    has_lead_II = 0
    n_segments = 0
    # (shard index, row index) of every scanned segment; pass 2 samples from it.
    reservoir: list[tuple[int, int]] = []
    for sidx in tqdm(meta_shards, desc="shards(meta)"):
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        # Drop the waveform columns so the row loop never materialises them.
        cheap = ds.remove_columns([c for c in ds.column_names if c in waveform_cols])
        for i, row in enumerate(cheap):
            patients.add(parse_subject_id(row["record_name"]))
            total_duration_s += float(row["segment_duration_sec"])
            ecg_fs_list.append(float(row["ecg_fs"]))
            ppg_fs_list.append(float(row["ppg_fs"]))
            ecg_siglen.append(int(row["ecg_siglen"]))
            ppg_siglen.append(int(row["ppg_siglen"]))
            names = tuple(row["ecg_names"])
            for n in names:
                ecg_leads_counter[n] = ecg_leads_counter.get(n, 0) + 1
            if "II" in names:
                has_lead_II += 1
            n_segments += 1
            reservoir.append((sidx, i))

    # -------- Pass 2: PTT on up to 400 random segments, stop after 150 good ones --------
    max_attempts = 400
    target_good = 150
    want_sanity = 5
    RNG.shuffle(reservoir)
    # Group the selected rows by shard so each shard is loaded from disk once.
    by_shard: dict[int, list[int]] = {}
    for s, i in reservoir[:max_attempts]:
        by_shard.setdefault(s, []).append(i)
    all_ptts: list[float] = []
    clean_segment_stds: list[float] = []
    sanity_samples: list[dict] = []
    processed = 0
    good_segments = 0
    shards_visited = 0
    print(f"[pass 2] PTT on up to {max_attempts} segments")
    for sidx, idxs in tqdm(list(by_shard.items()), desc="shards(ptt)"):
        if good_segments >= target_good:
            break
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        shards_visited += 1
        for i in idxs:
            if good_segments >= target_good:
                break
            row = ds[i]
            ecg = np.asarray(row["ecg"], dtype=np.float32)
            ppg = np.asarray(row["ppg"], dtype=np.float32)
            if ecg.size == 0 or ppg.size == 0:
                continue
            names = list(row["ecg_names"])
            # Prefer lead II; otherwise fall back to the first ECG channel.
            lead_idx = names.index("II") if "II" in names else 0
            lead_name = names[lead_idx] if lead_idx < len(names) else f"ch{lead_idx}"
            ecg_lead = ecg[lead_idx]
            ppg_ch = ppg[0]
            ecg_fs = float(row["ecg_fs"])
            ppg_fs = float(row["ppg_fs"])
            t0_e = float(row["ecg_time_s"][0])
            t0_p = float(row["ppg_time_s"][0])
            ptts = clean_ptts_ms(ecg_lead, ecg_fs, ppg_ch, ppg_fs, t0_e, t0_p)
            processed += 1
            if len(ptts) < 3:
                continue
            all_ptts.extend(ptts)
            clean_segment_stds.append(float(np.std(ptts)))
            good_segments += 1
            if len(sanity_samples) < want_sanity:
                sanity_samples.append(
                    {
                        # Copies keep only this lead/channel alive, not the full arrays.
                        "ecg": ecg_lead.copy(),
                        "ppg": ppg_ch.copy(),
                        "ecg_fs": ecg_fs,
                        "ppg_fs": ppg_fs,
                        "t0_e": t0_e,
                        "t0_p": t0_p,
                        "lead": lead_name,
                        "name": row["record_name"],
                        "ptts": ptts,
                    }
                )

    # -------- Aggregate --------
    total_hours_sampled = total_duration_s / 3600.0
    total_hours_estimated = total_hours_sampled * (N_SHARDS / len(meta_shards))
    # Linear scaling of the unique-patient count assumes patients do not span
    # shards. If they do, this likely over-counts, so compare it against
    # unique_patients_in_sample before trusting it.
    patients_extrap = int(len(patients) * N_SHARDS / len(meta_shards))
    report = {
        "dataset": REPO,
        "shards_total": N_SHARDS,
        "shards_sampled_meta": len(meta_shards),
        "segments_meta_scanned": n_segments,
        "unique_patients_in_sample": len(patients),
        "unique_patients_extrapolated": patients_extrap,
        "total_duration_hours_sampled": round(total_hours_sampled, 2),
        "total_duration_hours_estimated": round(total_hours_estimated, 2),
        "ecg_fs_median_hz": _median(ecg_fs_list),
        "ppg_fs_median_hz": _median(ppg_fs_list),
        "ecg_siglen_median_samples": int(_median(ecg_siglen)) if ecg_siglen else 0,
        "ppg_siglen_median_samples": int(_median(ppg_siglen)) if ppg_siglen else 0,
        "ecg_lead_counts_top10": dict(
            sorted(ecg_leads_counter.items(), key=lambda kv: -kv[1])[:10]
        ),
        "lead_II_available_frac": has_lead_II / max(n_segments, 1),
        "ptt_beats_measured": len(all_ptts),
        "ptt_good_segments": good_segments,
        "ptt_segments_attempted": processed,
        "ptt_shards_visited": shards_visited,
        "ptt_median_ms": _median(all_ptts),
        "ptt_p5_ms": _percentile(all_ptts, 5),
        "ptt_p95_ms": _percentile(all_ptts, 95),
        "ptt_within_segment_std_median_ms": _median(clean_segment_stds),
        "ptt_within_segment_std_p90_ms": _percentile(clean_segment_stds, 90),
    }

    # -------- Plots --------
    if all_ptts:
        _plot_ptt_histogram(all_ptts, good_segments, shards_visited)
    if sanity_samples:
        _plot_sanity_samples(sanity_samples)

    # Non-finite statistics (e.g. medians of empty lists) become null so the
    # report stays valid strict JSON.
    report_text = json.dumps(_finite_or_none(report), indent=2)
    (OUT / "e0_report.json").write_text(report_text)
    print(report_text)


if __name__ == "__main__":
    main()