| """E0 audit v2 — fixes: |
| |
| 1. Download cheap metadata file from EVERY shard to get true patient count. |
| 2. Better PTT pairing: require clean QRS-to-PPG pairs (exactly one PPG peak |
| in [50, 500] ms after R) and report within-segment std only for tight beats. |
| 3. Estimate alignment error as the within-segment std of PTT from clean beats. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import random |
| import re |
| from pathlib import Path |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from dotenv import load_dotenv |
| from scipy.signal import butter, filtfilt, find_peaks |
| from tqdm import tqdm |
|
|
| load_dotenv() |
| os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) |
|
|
| from datasets import load_from_disk |
| from huggingface_hub import snapshot_download |
|
|
# HF Hub dataset repo holding the aligned PPG/ECG shards.
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
# Total number of shard_XXXXX directories published in the repo; used to
# extrapolate sampled counts to the full dataset.
N_SHARDS = 412
# Reports are written to ../docs (relative to this script); figures to a subdir.
OUT = Path(__file__).resolve().parent.parent / "docs"
FIG_DIR = OUT / "figures"
FIG_DIR.mkdir(parents=True, exist_ok=True)


# Seeded RNG so shard sampling and reservoir shuffling are reproducible.
RNG = random.Random(42)
|
|
|
|
def parse_subject_id(record_name: str) -> str:
    """Extract the patient id (inner ``pXXXXXXXX``) from a record path.

    Falls back to the first path component when the name does not match
    the expected ``pXX/pXXXXXXXX/...`` layout.
    """
    match = re.match(r"p\d+/(p\d+)/", record_name)
    if match:
        return match.group(1)
    return record_name.split("/")[0]
|
|
|
|
def bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray:
    """Zero-phase Butterworth band-pass of *x* between *lo* and *hi* Hz.

    The upper edge is clipped just below Nyquist so ``butter`` always
    receives a valid normalized frequency; ``filtfilt`` with the Gustafsson
    method avoids edge transients from forward-backward filtering.
    """
    nyquist = fs * 0.5
    edges = [lo / nyquist, min(hi / nyquist, 0.99)]
    b, a = butter(order, edges, btype="band")
    return filtfilt(b, a, x, method="gust")
|
|
|
|
def r_peaks(ecg: np.ndarray, fs: float) -> np.ndarray:
    """Detect R-peak sample indices in one ECG lead.

    Pan-Tompkins-style pipeline: band-pass 5-15 Hz, differentiate, square,
    moving-window integrate, threshold at mean + 0.5*std, then snap each
    detection to the local maximum of the filtered signal within ±60 ms.
    """
    filtered = bandpass(ecg, fs, 5.0, 15.0)
    # Squared first difference emphasizes the steep QRS slopes.
    energy = np.square(np.diff(filtered, prepend=filtered[:1]))
    win = max(int(0.12 * fs), 1)
    integrated = np.convolve(energy, np.full(win, 1.0 / win), mode="same")
    height = integrated.mean() + 0.5 * integrated.std()
    # Refractory distance of 0.3 s caps detections at ~200 bpm.
    cand, _ = find_peaks(integrated, height=height, distance=int(0.3 * fs))
    half = max(int(0.06 * fs), 1)
    snapped = []
    for c in cand:
        a, b = max(0, c - half), min(len(filtered), c + half)
        if b > a:
            snapped.append(a + int(np.argmax(filtered[a:b])))
    return np.asarray(snapped, dtype=int)
|
|
|
|
def ppg_peaks(ppg: np.ndarray, fs: float) -> np.ndarray:
    """Detect systolic peak indices in a PPG channel.

    Band-passes 0.5-8 Hz, then keeps peaks at least 0.3 s apart that clear
    both a height floor (mean + 0.3*std) and a prominence floor (0.1*std).
    """
    filtered = bandpass(ppg, fs, 0.5, 8.0)
    sigma = np.std(filtered)
    locs, _ = find_peaks(
        filtered,
        distance=int(0.3 * fs),
        height=np.mean(filtered) + 0.3 * sigma,
        prominence=0.1 * sigma,
    )
    return locs
|
|
|
|
def clean_ptts_ms(ecg_lead, ecg_fs, ppg, ppg_fs, t0_e, t0_p):
    """Return list of clean PTTs: for each R, require exactly one PPG peak in [50,500]ms."""
    r_idx = r_peaks(ecg_lead, ecg_fs)
    p_idx = ppg_peaks(ppg, ppg_fs)
    # Too few peaks on either side means the segment is unusable.
    if min(len(r_idx), len(p_idx)) < 3:
        return []
    # Convert both peak trains to a shared absolute time axis (seconds).
    r_times = t0_e + r_idx / ecg_fs
    p_times = t0_p + p_idx / ppg_fs
    ptts = []
    for rt in r_times:
        in_window = p_times[(p_times >= rt + 0.050) & (p_times <= rt + 0.500)]
        # Keep only unambiguous pairings: exactly one PPG peak in the window.
        if in_window.size == 1:
            ptts.append((in_window[0] - rt) * 1000.0)
    return ptts
|
|
|
|
def main() -> None:
    """Run the two-pass audit: metadata scan, then PTT measurement + figures.

    Pass 1 downloads 120 randomly chosen shards (metadata columns only) and
    accumulates patient/duration/sampling statistics. Pass 2 measures clean
    per-beat PTTs on up to 400 sampled segments (stopping at 150 good ones),
    then writes docs/e0_report.json and two PNG figures.
    """

    # ---- pass 1: pull a reproducible random sample of shards from the Hub.
    meta_shards = sorted(RNG.sample(range(N_SHARDS), 120))
    print(f"[pass 1] downloading metadata from {len(meta_shards)} shards")
    patterns = ["metadata.json"] + [f"shard_{i:05d}/*" for i in meta_shards]
    root = Path(
        snapshot_download(REPO, repo_type="dataset", allow_patterns=patterns, max_workers=12)
    )

    # Accumulators for metadata-level statistics over all scanned segments.
    patients: set[str] = set()
    total_duration_s = 0.0
    ecg_fs_list = []
    ppg_fs_list = []
    ecg_siglen = []
    ppg_siglen = []
    ecg_leads_counter: dict[str, int] = {}
    has_lead_II = 0
    n_segments = 0
    shard_to_rows: dict[int, int] = {}

    # (shard_index, row_index) pairs for every segment seen; shuffled later
    # to pick a random subset for pass 2.
    reservoir: list[tuple[int, int]] = []

    for sidx in tqdm(meta_shards, desc="shards(meta)"):
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        shard_to_rows[sidx] = len(ds)
        # Drop the heavy waveform columns so iterating rows stays cheap.
        cheap = ds.remove_columns(
            [c for c in ds.column_names if c in ("ecg", "ppg", "ecg_time_s", "ppg_time_s")]
        )
        for i, row in enumerate(cheap):
            patients.add(parse_subject_id(row["record_name"]))
            total_duration_s += float(row["segment_duration_sec"])
            ecg_fs_list.append(float(row["ecg_fs"]))
            ppg_fs_list.append(float(row["ppg_fs"]))
            ecg_siglen.append(int(row["ecg_siglen"]))
            ppg_siglen.append(int(row["ppg_siglen"]))
            names = tuple(row["ecg_names"])
            for n in names:
                ecg_leads_counter[n] = ecg_leads_counter.get(n, 0) + 1
            if "II" in names:
                has_lead_II += 1
            n_segments += 1
            reservoir.append((sidx, i))

    # ---- pass 2: PTT on a random subset of segments, grouped by shard so
    # each shard is loaded from disk at most once.
    RNG.shuffle(reservoir)
    all_ptts = []
    clean_segment_stds = []
    sanity_samples = []
    want_sanity = 5
    processed = 0
    good_segments = 0
    by_shard: dict[int, list[int]] = {}
    for s, i in reservoir[:400]:
        by_shard.setdefault(s, []).append(i)

    print(f"[pass 2] PTT on up to 400 segments")
    for sidx, idxs in tqdm(list(by_shard.items()), desc="shards(ptt)"):
        # Stop early once enough good segments have been measured.
        if good_segments >= 150:
            break
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        for i in idxs:
            if good_segments >= 150:
                break
            row = ds[i]
            ecg = np.asarray(row["ecg"], dtype=np.float32)
            ppg = np.asarray(row["ppg"], dtype=np.float32)
            if ecg.size == 0 or ppg.size == 0:
                continue
            # Prefer lead II for QRS detection; fall back to the first lead.
            names = list(row["ecg_names"])
            if "II" in names:
                lead_idx = names.index("II")
            else:
                lead_idx = 0
            ecg_lead = ecg[lead_idx]
            ppg_ch = ppg[0]
            # t0 offsets come from the aligned absolute time axes, so ECG
            # and PPG peaks are compared on a shared clock.
            ptts = clean_ptts_ms(
                ecg_lead,
                float(row["ecg_fs"]),
                ppg_ch,
                float(row["ppg_fs"]),
                float(row["ecg_time_s"][0]),
                float(row["ppg_time_s"][0]),
            )
            processed += 1
            # "Good" = at least 3 clean beats; within-segment std from these
            # is the alignment-error estimate from the module docstring.
            if len(ptts) >= 3:
                all_ptts.extend(ptts)
                clean_segment_stds.append(float(np.std(ptts)))
                good_segments += 1
            # Keep a handful of raw traces for the visual sanity figure.
            if len(sanity_samples) < want_sanity and len(ptts) >= 3:
                sanity_samples.append(
                    (
                        ecg_lead.copy(),
                        ppg_ch.copy(),
                        float(row["ecg_fs"]),
                        float(row["ppg_fs"]),
                        row["record_name"],
                        ptts,
                    )
                )

    # ---- report: extrapolate sampled totals to all N_SHARDS shards.
    total_hours_sampled = total_duration_s / 3600.0
    total_hours_estimated = total_hours_sampled * (N_SHARDS / len(meta_shards))

    # NOTE(review): linear extrapolation of unique patients overcounts if
    # the same patient appears in multiple shards — treat as an upper bound.
    patients_extrap = int(len(patients) * N_SHARDS / len(meta_shards))

    # NaN-safe median helper for possibly-empty lists.
    median = lambda v: float(np.median(v)) if len(v) else float("nan")
    report = {
        "dataset": REPO,
        "shards_total": N_SHARDS,
        "shards_sampled_meta": len(meta_shards),
        "segments_meta_scanned": n_segments,
        "unique_patients_in_sample": len(patients),
        "unique_patients_extrapolated": patients_extrap,
        "total_duration_hours_sampled": round(total_hours_sampled, 2),
        "total_duration_hours_estimated": round(total_hours_estimated, 2),
        "ecg_fs_median_hz": median(ecg_fs_list),
        "ppg_fs_median_hz": median(ppg_fs_list),
        "ecg_siglen_median_samples": int(median(ecg_siglen)) if ecg_siglen else 0,
        "ppg_siglen_median_samples": int(median(ppg_siglen)) if ppg_siglen else 0,
        "ecg_lead_counts_top10": dict(
            sorted(ecg_leads_counter.items(), key=lambda kv: -kv[1])[:10]
        ),
        "lead_II_available_frac": has_lead_II / max(n_segments, 1),
        "ptt_beats_measured": len(all_ptts),
        "ptt_good_segments": good_segments,
        "ptt_segments_attempted": processed,
        "ptt_median_ms": median(all_ptts),
        "ptt_p5_ms": float(np.percentile(all_ptts, 5)) if all_ptts else float("nan"),
        "ptt_p95_ms": float(np.percentile(all_ptts, 95)) if all_ptts else float("nan"),
        "ptt_within_segment_std_median_ms": median(clean_segment_stds),
        "ptt_within_segment_std_p90_ms": (
            float(np.percentile(clean_segment_stds, 90)) if clean_segment_stds else float("nan")
        ),
    }

    # ---- figure 1: PTT histogram with 100/400 ms physiological guide lines.
    if all_ptts:
        plt.figure(figsize=(7, 4))
        plt.hist(all_ptts, bins=60, color="#3a7", edgecolor="black")
        plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms")
        plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms")
        plt.xlabel("PTT (ms)")
        plt.ylabel("count")
        plt.title(
            f"PTT distribution — {len(all_ptts)} clean beats, "
            f"{good_segments} segments, {len(by_shard)} shards"
        )
        plt.legend()
        plt.tight_layout()
        plt.savefig(FIG_DIR / "ptt_histogram.png", dpi=120)
        plt.close()

    # ---- figure 2: overlaid ECG/PPG traces for visual alignment checking.
    # Time axes here start at 0 per channel (relative, not absolute t0).
    if sanity_samples:
        fig, axes = plt.subplots(len(sanity_samples), 1, figsize=(10, 2.4 * len(sanity_samples)))
        if len(sanity_samples) == 1:
            axes = [axes]
        for ax, (ecg, ppg, efs, pfs, name, ptts) in zip(axes, sanity_samples):
            t_e = np.arange(len(ecg)) / efs
            t_p = np.arange(len(ppg)) / pfs
            ax2 = ax.twinx()
            ax.plot(t_e, ecg, color="#266", lw=0.6, label="ECG II")
            ax2.plot(t_p, ppg, color="#b30", lw=0.6, label="PPG")
            ax.set_title(
                f"{name} PTT median={np.median(ptts):.0f} ms N={len(ptts)}",
                fontsize=8,
            )
            ax.set_xlabel("time (s)")
            ax.set_ylabel("ECG", color="#266")
            ax2.set_ylabel("PPG", color="#b30")
        plt.tight_layout()
        plt.savefig(FIG_DIR / "sanity_check.png", dpi=120)
        plt.close()

    (OUT / "e0_report.json").write_text(json.dumps(report, indent=2))
    print(json.dumps(report, indent=2))
|
|
|
|
# Script entry point — runs the full two-pass audit.
if __name__ == "__main__":
    main()
|
|