# PhysioJEPA / scripts/e0_audit.py
# Uploaded by guychuk via huggingface_hub ("Upload folder using huggingface_hub"), commit 31e2456 (verified).
"""E0 data audit for lucky9-cyou/mimic-iv-aligned-ppg-ecg.
Computes: patient count, total hours, sample rates, alignment tolerance,
PTT distribution, missing-value rate, and sanity plots.
Strategy: scan a random sample of shards for cheap metadata (record_name, fs,
siglen, lead/channel names) and extrapolate total hours to the full dataset.
Load waveforms for only a small random subset of those segments to measure
missing/NaN rates and the expensive per-beat PTT statistics.
"""
from __future__ import annotations
import json
import os
import random
import re
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from scipy.signal import butter, filtfilt, find_peaks
from tqdm import tqdm
# Load variables from a local .env file (if one exists) into os.environ.
load_dotenv()
# Mirror HUGGINGFACE_API_KEY into HF_TOKEN unless HF_TOKEN is already set.
# NOTE(review): if neither variable is set this stores an empty HF_TOKEN —
# confirm huggingface_hub then falls back to anonymous access.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))
# These imports deliberately come after the token setup above — presumably so
# the Hugging Face libraries see HF_TOKEN when they initialise; keep the order.
from datasets import load_from_disk
from huggingface_hub import snapshot_download
# Hugging Face dataset repo audited by this script.
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
# Number of shard_XXXXX folders assumed to exist in REPO (indices 0..N_SHARDS-1);
# used for shard sampling and for extrapolating totals — TODO confirm against the repo.
N_SHARDS = 412
# Outputs go to the docs/ folder inside the parent of this script's directory.
OUT = Path(__file__).resolve().parent.parent / "docs"
FIG_DIR = OUT / "figures"
# Created at import time so later savefig calls can write into it.
FIG_DIR.mkdir(parents=True, exist_ok=True)
# Fixed-seed generator so shard and segment sampling is reproducible.
RNG = random.Random(42)
def parse_subject_id(record_name: str) -> str:
    """Return the patient identifier embedded in a record path.

    Paths that start like ``pNN/pNNNNNNNN/`` yield the second component.
    Anything else falls back to the first ``/``-separated component (the whole
    string when it contains no ``/``).
    """
    match = re.match(r"p\d+/(p\d+)/", record_name)
    if match is None:
        head, _sep, _rest = record_name.partition("/")
        return head
    return match.group(1)
def bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray:
    """Zero-phase Butterworth band-pass filter.

    The cutoffs ``lo``/``hi`` (Hz) are normalised by the Nyquist frequency and
    clamped to at least 1e-4 and at most 0.99, keeping them inside the open
    interval that ``butter`` requires even when ``lo`` is 0 or ``hi`` is at or
    above Nyquist. The signal is filtered forward and backward with
    Gustafsson's initial conditions, so there is no phase lag and the output
    has the same shape as ``x``.
    """
    nyquist = fs / 2.0
    low_cut = max(lo / nyquist, 1e-4)
    high_cut = min(hi / nyquist, 0.99)
    numerator, denominator = butter(order, [low_cut, high_cut], btype="band")
    return filtfilt(numerator, denominator, x, method="gust")
def pan_tompkins_lite(ecg: np.ndarray, fs: float) -> np.ndarray:
    """Detect QRS complexes with a simplified Pan-Tompkins pipeline.

    Steps: 5-15 Hz band-pass -> first difference -> squaring -> 120 ms moving
    average -> peaks above ``mean + 0.5 * std`` of the averaged signal with a
    300 ms minimum spacing. Each detection is then moved to the largest
    band-passed ECG sample within roughly +/-60 ms of it.

    Returns:
        Integer array of R-peak sample indices into ``ecg``.
    """
    filtered = bandpass(ecg, fs, 5.0, 15.0)
    slope = np.diff(filtered, prepend=filtered[:1])
    energy = slope * slope
    win = max(int(0.12 * fs), 1)
    integrated = np.convolve(energy, np.ones(win) / win, mode="same")
    threshold = np.mean(integrated) + 0.5 * np.std(integrated)
    # 300 ms refractory spacing caps detections at 200 bpm.
    coarse, _props = find_peaks(integrated, height=threshold, distance=int(0.3 * fs))

    # Snap each coarse detection onto the local maximum of the filtered ECG.
    half = max(int(0.06 * fs), 1)
    n_samples = len(filtered)
    windows = ((max(0, p - half), min(n_samples, p + half)) for p in coarse)
    snapped = [
        start + int(np.argmax(filtered[start:stop]))
        for start, stop in windows
        if stop > start
    ]
    return np.asarray(snapped, dtype=int)
def ppg_systolic_peaks(ppg: np.ndarray, fs: float) -> np.ndarray:
    """Locate systolic peaks in a PPG trace.

    The signal is band-passed to 0.5-8 Hz. A peak is kept when it rises above
    ``mean + 0.3 * std`` of the filtered signal and has a prominence of at least
    ``0.1 * std``. Of any peaks closer than 300 ms, only the larger is kept.

    Returns:
        Sample indices of the detected peaks.
    """
    smooth = bandpass(ppg, fs, 0.5, 8.0)
    mu = np.mean(smooth)
    sigma = np.std(smooth)
    found, _props = find_peaks(
        smooth,
        distance=int(0.3 * fs),
        height=mu + 0.3 * sigma,
        prominence=0.1 * sigma,
    )
    return found
def compute_ptt_ms(
    ecg_lead: np.ndarray,
    ecg_fs: float,
    ppg: np.ndarray,
    ppg_fs: float,
    t0_ecg: float,
    t0_ppg: float,
) -> list[float]:
    """Pulse transit times from ECG R-peaks to the following PPG systolic peak.

    Peak indices are converted to absolute seconds using each signal's own
    start time (``t0_ecg`` / ``t0_ppg``) and sample rate. Each R-peak is paired
    with the first PPG peak at least 50 ms after it. The pair is kept only if
    that delay is at most 500 ms, and pairing stops once PPG peaks run out.

    Returns:
        PTT values in milliseconds, in R-peak order. The list is empty when
        either detector finds fewer than 3 peaks.
    """
    min_lag_s, max_lag_s = 0.050, 0.500
    r_peaks = pan_tompkins_lite(ecg_lead, ecg_fs)
    ppg_peaks = ppg_systolic_peaks(ppg, ppg_fs)
    if min(len(r_peaks), len(ppg_peaks)) < 3:
        return []

    r_times = t0_ecg + r_peaks / ecg_fs
    ppg_times = t0_ppg + ppg_peaks / ppg_fs
    n_ppg = len(ppg_times)

    out: list[float] = []
    # Cursor into ppg_times; it only moves forward. A matched PPG peak is not
    # consumed, so it can also pair with the next R-peak.
    k = 0
    for r_time in r_times:
        earliest = r_time + min_lag_s
        while k < n_ppg and ppg_times[k] < earliest:
            k += 1
        if k == n_ppg:
            break
        lag = ppg_times[k] - r_time
        if min_lag_s <= lag <= max_lag_s:
            out.append(lag * 1000.0)
    return out
def quick_snapshot(allow_shards: list[int]) -> str:
    """Download ``metadata.json`` plus the requested shard folders from REPO.

    Args:
        allow_shards: Shard indices. Each one allows the files matching
            ``shard_{index:05d}/*``.

    Returns:
        The local snapshot directory returned by ``snapshot_download``.
    """
    shard_globs = [f"shard_{idx:05d}/*" for idx in allow_shards]
    return snapshot_download(
        REPO,
        repo_type="dataset",
        allow_patterns=["metadata.json", *shard_globs],
        max_workers=8,
    )
# Reference "normal" PTT band in ms. It is drawn on the histogram and used for
# the ``ptt_physio_plausible_frac`` report field. compute_ptt_ms already limits
# beats to 50-500 ms, so a fraction over that window would always be 1.0.
_PTT_NORMAL_RANGE_MS = (100.0, 400.0)


def _plot_ptt_histogram(all_ptts: list[float], n_shards: int) -> None:
    """Save ``ptt_histogram.png``: per-beat PTTs with the normal band marked."""
    lo_ms, hi_ms = _PTT_NORMAL_RANGE_MS
    plt.figure(figsize=(7, 4))
    plt.hist(all_ptts, bins=50, color="#3a7", edgecolor="black")
    plt.axvline(lo_ms, color="red", linestyle="--", alpha=0.5, label=f"{lo_ms:g} ms (lower normal)")
    plt.axvline(hi_ms, color="red", linestyle="--", alpha=0.5, label=f"{hi_ms:g} ms (upper normal)")
    plt.xlabel("PTT (ms)")
    plt.ylabel("count")
    plt.title(f"PTT distribution, N={len(all_ptts)} beats across {n_shards} shards")
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG_DIR / "ptt_histogram.png", dpi=120)
    plt.close()


def _plot_sanity_samples(
    sanity_samples: list[tuple[np.ndarray, np.ndarray, float, float, str]],
) -> None:
    """Save ``sanity_check.png``: one row per segment, ECG and PPG on twin y-axes."""
    n_rows = len(sanity_samples)
    # squeeze=False always yields a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(n_rows, 1, figsize=(10, 2.2 * n_rows), squeeze=False)
    for ax, (ecg, ppg, efs, pfs, name) in zip(axes[:, 0], sanity_samples):
        t_e = np.arange(len(ecg)) / efs
        t_p = np.arange(len(ppg)) / pfs
        ax2 = ax.twinx()
        ax.plot(t_e, ecg, color="#266", lw=0.6, label="ECG[0]")
        ax2.plot(t_p, ppg, color="#b30", lw=0.6, label="PPG")
        ax.set_title(name, fontsize=8)
        ax.set_xlabel("time (s)")
        ax.set_ylabel("ECG", color="#266")
        ax2.set_ylabel("PPG", color="#b30")
    fig.tight_layout()
    fig.savefig(FIG_DIR / "sanity_check.png", dpi=120)
    plt.close(fig)


def main() -> None:
    """Run the E0 audit and write ``docs/e0_report.json`` plus figures.

    Pass 1 reads only the cheap metadata columns of a random sample of 40
    shards: patients, durations, sample rates, signal lengths and lead/channel
    names. Total duration is then extrapolated to all ``N_SHARDS`` shards.

    Pass 2 shuffles the scanned segments and takes up to 250 candidates. It
    loads their waveforms until 100 segments have each yielded at least 3 PTT
    beats. Empty-waveform counts and NaN rates cover every row loaded in this
    pass. A few successful segments are kept for the ECG/PPG sanity plot.
    """
    n_meta_shards = 40
    max_ptt_candidates = 250
    target_ptt_segments = 100
    want_sanity = 5

    # -------- Pass 1: metadata over a wide shard sample (cheap columns only) --------
    # Whole shard folders are downloaded, but only non-waveform columns are iterated.
    meta_shards = sorted(RNG.sample(range(N_SHARDS), n_meta_shards))
    print(f"[pass 1] downloading metadata from {len(meta_shards)} shards")
    root_p = Path(quick_snapshot(meta_shards))

    patients: set[str] = set()
    total_duration_s = 0.0
    ecg_fs_list: list[float] = []
    ppg_fs_list: list[float] = []
    ecg_siglen: list[int] = []
    ppg_siglen: list[int] = []
    ecg_names_seen: set[tuple[str, ...]] = set()
    ppg_names_seen: set[tuple[str, ...]] = set()
    n_segments = 0
    # (shard_idx, within_shard_idx) of every scanned segment; pass 2 samples from it.
    reservoir: list[tuple[int, int]] = []
    waveform_cols = ("ecg", "ppg", "ecg_time_s", "ppg_time_s")
    for sidx in tqdm(meta_shards, desc="shards(meta)"):
        ds = load_from_disk(str(root_p / f"shard_{sidx:05d}"))
        cols_cheap = ds.remove_columns([c for c in ds.column_names if c in waveform_cols])
        for i, row in enumerate(cols_cheap):
            patients.add(parse_subject_id(row["record_name"]))
            total_duration_s += float(row["segment_duration_sec"])
            ecg_fs_list.append(float(row["ecg_fs"]))
            ppg_fs_list.append(float(row["ppg_fs"]))
            ecg_siglen.append(int(row["ecg_siglen"]))
            ppg_siglen.append(int(row["ppg_siglen"]))
            ecg_names_seen.add(tuple(row["ecg_names"]))
            ppg_names_seen.add(tuple(row["ppg_names"]))
            n_segments += 1
            reservoir.append((sidx, i))

    # -------- Pass 2: PTT + waveform stats until 100 segments yield PTTs --------
    RNG.shuffle(reservoir)
    # Oversample: empty waveforms and segments where peak detection yields
    # fewer than 3 PTT beats do not count toward target_ptt_segments.
    ptt_targets = reservoir[:max_ptt_candidates]
    print(f"[pass 2] computing PTT on up to {len(ptt_targets)} segments")

    missing_ecg = 0
    missing_ppg = 0
    nan_ecg_frac: list[float] = []
    nan_ppg_frac: list[float] = []
    all_ptts: list[float] = []
    per_segment_ptt_std: list[float] = []
    per_patient_ptt_median: dict[str, list[float]] = {}
    sanity_samples: list[tuple[np.ndarray, np.ndarray, float, float, str]] = []

    # Group targets by shard so each shard is loaded from disk only once.
    by_shard: dict[int, list[int]] = {}
    for s, i in ptt_targets:
        by_shard.setdefault(s, []).append(i)

    processed = 0  # segments that yielded >= 3 PTT beats
    waveform_rows = 0  # rows loaded in pass 2, including empty / failed ones
    shards_visited = 0
    for sidx, idxs in tqdm(by_shard.items(), desc="shards(ptt)"):
        if processed >= target_ptt_segments:
            break
        ds = load_from_disk(str(root_p / f"shard_{sidx:05d}"))
        shards_visited += 1
        for i in idxs:
            if processed >= target_ptt_segments:
                break
            row = ds[i]
            waveform_rows += 1
            ecg = np.asarray(row["ecg"], dtype=np.float32)
            ppg = np.asarray(row["ppg"], dtype=np.float32)
            if ecg.size == 0 or ppg.size == 0:
                missing_ecg += int(ecg.size == 0)
                missing_ppg += int(ppg.size == 0)
                continue
            ecg_nan = np.isnan(ecg)
            ppg_nan = np.isnan(ppg)
            nan_ecg_frac.append(float(ecg_nan.mean()))
            nan_ppg_frac.append(float(ppg_nan.mean()))
            if ecg_nan.any() or ppg_nan.any():
                # Filters and peak finders cannot handle NaN; zero-fill instead.
                ecg = np.nan_to_num(ecg, nan=0.0)
                ppg = np.nan_to_num(ppg, nan=0.0)
            # Only the first channel is used. Multi-channel arrays are assumed
            # to be (channels, samples) — TODO confirm; 1-D arrays are used as-is.
            ecg_lead = ecg if ecg.ndim == 1 else ecg[0]
            ppg_ch = ppg if ppg.ndim == 1 else ppg[0]
            ecg_fs = float(row["ecg_fs"])
            ppg_fs = float(row["ppg_fs"])
            # Absolute start times put ECG and PPG peaks on a shared clock.
            t0_e = float(row["ecg_time_s"][0])
            t0_p = float(row["ppg_time_s"][0])
            ptts = compute_ptt_ms(ecg_lead, ecg_fs, ppg_ch, ppg_fs, t0_e, t0_p)
            if len(ptts) < 3:
                continue
            all_ptts.extend(ptts)
            per_segment_ptt_std.append(float(np.std(ptts)))
            pid = parse_subject_id(row["record_name"])
            per_patient_ptt_median.setdefault(pid, []).append(float(np.median(ptts)))
            if len(sanity_samples) < want_sanity:
                sanity_samples.append(
                    (ecg_lead.copy(), ppg_ch.copy(), ecg_fs, ppg_fs, row["record_name"])
                )
            processed += 1

    # -------- Aggregate --------
    ecg_fs_med = float(np.median(ecg_fs_list)) if ecg_fs_list else 0.0
    ppg_fs_med = float(np.median(ppg_fs_list)) if ppg_fs_list else 0.0
    total_hours_sampled = total_duration_s / 3600.0
    # Linear extrapolation from the sampled shards to all N_SHARDS shards.
    total_hours_estimated = total_hours_sampled * (N_SHARDS / len(meta_shards))
    # The patient count is reported for the sample only. Unlike duration it is
    # not scaled up: a patient can recur across shards, so linear extrapolation
    # would overcount.
    patients_sampled = len(patients)

    ptt_median = float(np.median(all_ptts)) if all_ptts else float("nan")
    ptt_p5 = float(np.percentile(all_ptts, 5)) if all_ptts else float("nan")
    ptt_p95 = float(np.percentile(all_ptts, 95)) if all_ptts else float("nan")
    within_seg_std_median = (
        float(np.median(per_segment_ptt_std)) if per_segment_ptt_std else float("nan")
    )
    # Spread of per-segment median PTT within a patient (needs >= 2 segments).
    within_patient_std = [
        float(np.std(meds)) for meds in per_patient_ptt_median.values() if len(meds) >= 2
    ]
    within_patient_std_median = (
        float(np.median(within_patient_std)) if within_patient_std else float("nan")
    )
    nan_ecg_frac_mean = float(np.mean(nan_ecg_frac)) if nan_ecg_frac else 0.0
    nan_ppg_frac_mean = float(np.mean(nan_ppg_frac)) if nan_ppg_frac else 0.0
    lo_ms, hi_ms = _PTT_NORMAL_RANGE_MS
    if all_ptts:
        ptt_arr = np.asarray(all_ptts)
        ptt_plausible_frac = float(np.mean((ptt_arr >= lo_ms) & (ptt_arr <= hi_ms)))
    else:
        ptt_plausible_frac = 0.0

    # -------- Plots --------
    if all_ptts:
        _plot_ptt_histogram(all_ptts, shards_visited)
    if sanity_samples:
        _plot_sanity_samples(sanity_samples)

    # -------- Write JSON output --------
    # NOTE: json.dumps writes float("nan") as the non-standard token NaN.
    report = {
        "dataset": REPO,
        "shards_total": N_SHARDS,
        "shards_sampled_meta": len(meta_shards),
        "segments_meta_scanned": n_segments,
        "unique_patients_in_sample": patients_sampled,
        "total_duration_hours_sampled": round(total_hours_sampled, 2),
        "total_duration_hours_estimated": round(total_hours_estimated, 2),
        "ecg_fs_median_hz": ecg_fs_med,
        "ppg_fs_median_hz": ppg_fs_med,
        "ecg_siglen_median_samples": int(np.median(ecg_siglen)) if ecg_siglen else 0,
        "ppg_siglen_median_samples": int(np.median(ppg_siglen)) if ppg_siglen else 0,
        # Sorted so the listed combinations are stable between runs (set order
        # of strings varies with hash randomisation).
        "ecg_leads_seen": [list(t) for t in sorted(ecg_names_seen, key=str)[:10]],
        "ppg_channels_seen": [list(t) for t in sorted(ppg_names_seen, key=str)[:10]],
        "n_ecg_lead_combinations": len(ecg_names_seen),
        "n_ppg_channel_combinations": len(ppg_names_seen),
        # Missing/NaN fields cover pass-2 rows only (see segments_waveform_scanned).
        "missing_ecg_segments": missing_ecg,
        "missing_ppg_segments": missing_ppg,
        "nan_ecg_frac_mean": nan_ecg_frac_mean,
        "nan_ppg_frac_mean": nan_ppg_frac_mean,
        "ptt_beats_measured": len(all_ptts),
        "ptt_median_ms": ptt_median,
        "ptt_p5_ms": ptt_p5,
        "ptt_p95_ms": ptt_p95,
        "ptt_within_segment_std_median_ms": within_seg_std_median,
        "ptt_within_patient_std_median_ms": within_patient_std_median,
        # Fraction of beats inside ptt_normal_range_ms.
        "ptt_physio_plausible_frac": ptt_plausible_frac,
        "segments_waveform_scanned": waveform_rows,
        "ptt_segments_measured": len(per_segment_ptt_std),
        "ptt_normal_range_ms": list(_PTT_NORMAL_RANGE_MS),
    }
    (OUT / "e0_report.json").write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(json.dumps(report, indent=2))


if __name__ == "__main__":
    main()