"""E0 audit v2 — fixes: 1. Download cheap metadata file from EVERY shard to get true patient count. 2. Better PTT pairing: require clean QRS-to-PPG pairs (exactly one PPG peak in [50, 500] ms after R) and report within-segment std only for tight beats. 3. Estimate alignment error as the within-segment std of PTT from clean beats. """ from __future__ import annotations import json import os import random import re from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np from dotenv import load_dotenv from scipy.signal import butter, filtfilt, find_peaks from tqdm import tqdm load_dotenv() os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) from datasets import load_from_disk from huggingface_hub import snapshot_download REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg" N_SHARDS = 412 OUT = Path(__file__).resolve().parent.parent / "docs" FIG_DIR = OUT / "figures" FIG_DIR.mkdir(parents=True, exist_ok=True) RNG = random.Random(42) def parse_subject_id(record_name: str) -> str: m = re.match(r"p\d+/(p\d+)/", record_name) return m.group(1) if m else record_name.split("/")[0] def bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray: ny = 0.5 * fs b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band") return filtfilt(b, a, x, method="gust") def r_peaks(ecg: np.ndarray, fs: float) -> np.ndarray: x = bandpass(ecg, fs, 5.0, 15.0) d = np.diff(x, prepend=x[:1]) s = d * d w = max(int(0.12 * fs), 1) mwa = np.convolve(s, np.ones(w) / w, mode="same") thr = np.mean(mwa) + 0.5 * np.std(mwa) peaks, _ = find_peaks(mwa, height=thr, distance=int(0.3 * fs)) snap = max(int(0.06 * fs), 1) out = [] for p in peaks: lo, hi = max(0, p - snap), min(len(x), p + snap) if hi > lo: out.append(lo + int(np.argmax(x[lo:hi]))) return np.asarray(out, dtype=int) def ppg_peaks(ppg: np.ndarray, fs: float) -> np.ndarray: x = bandpass(ppg, fs, 0.5, 8.0) peaks, _ = find_peaks( x, distance=int(0.3 * fs), height=np.mean(x) + 0.3 * np.std(x), prominence=0.1 * np.std(x), ) return peaks def clean_ptts_ms(ecg_lead, ecg_fs, ppg, ppg_fs, t0_e, t0_p): """Return list of clean PTTs: for each R, require exactly one PPG peak in [50,500]ms.""" r = r_peaks(ecg_lead, ecg_fs) p = ppg_peaks(ppg, ppg_fs) if len(r) < 3 or len(p) < 3: return [] r_t = t0_e + r / ecg_fs p_t = t0_p + p / ppg_fs out = [] for rt in r_t: cand = p_t[(p_t >= rt + 0.050) & (p_t <= rt + 0.500)] if len(cand) == 1: out.append((cand[0] - rt) * 1000.0) return out def main() -> None: # -------- Pass 1: download dataset_info.json (cheap) from ALL shards not feasible -- # Instead: sample 120 shards uniformly for metadata. That is >25% coverage. meta_shards = sorted(RNG.sample(range(N_SHARDS), 120)) print(f"[pass 1] downloading metadata from {len(meta_shards)} shards") patterns = ["metadata.json"] + [f"shard_{i:05d}/*" for i in meta_shards] root = Path( snapshot_download(REPO, repo_type="dataset", allow_patterns=patterns, max_workers=12) ) patients: set[str] = set() total_duration_s = 0.0 ecg_fs_list = [] ppg_fs_list = [] ecg_siglen = [] ppg_siglen = [] ecg_leads_counter: dict[str, int] = {} has_lead_II = 0 n_segments = 0 shard_to_rows: dict[int, int] = {} reservoir: list[tuple[int, int]] = [] for sidx in tqdm(meta_shards, desc="shards(meta)"): ds = load_from_disk(str(root / f"shard_{sidx:05d}")) shard_to_rows[sidx] = len(ds) cheap = ds.remove_columns( [c for c in ds.column_names if c in ("ecg", "ppg", "ecg_time_s", "ppg_time_s")] ) for i, row in enumerate(cheap): patients.add(parse_subject_id(row["record_name"])) total_duration_s += float(row["segment_duration_sec"]) ecg_fs_list.append(float(row["ecg_fs"])) ppg_fs_list.append(float(row["ppg_fs"])) ecg_siglen.append(int(row["ecg_siglen"])) ppg_siglen.append(int(row["ppg_siglen"])) names = tuple(row["ecg_names"]) for n in names: ecg_leads_counter[n] = ecg_leads_counter.get(n, 0) + 1 if "II" in names: has_lead_II += 1 n_segments += 1 reservoir.append((sidx, i)) # -------- Pass 2: PTT on 200 segments (stop at 150 with >=3 clean beats) -------- RNG.shuffle(reservoir) all_ptts = [] clean_segment_stds = [] sanity_samples = [] want_sanity = 5 processed = 0 good_segments = 0 by_shard: dict[int, list[int]] = {} for s, i in reservoir[:400]: by_shard.setdefault(s, []).append(i) print(f"[pass 2] PTT on up to 400 segments") for sidx, idxs in tqdm(list(by_shard.items()), desc="shards(ptt)"): if good_segments >= 150: break ds = load_from_disk(str(root / f"shard_{sidx:05d}")) for i in idxs: if good_segments >= 150: break row = ds[i] ecg = np.asarray(row["ecg"], dtype=np.float32) ppg = np.asarray(row["ppg"], dtype=np.float32) if ecg.size == 0 or ppg.size == 0: continue names = list(row["ecg_names"]) if "II" in names: lead_idx = names.index("II") else: lead_idx = 0 ecg_lead = ecg[lead_idx] ppg_ch = ppg[0] ptts = clean_ptts_ms( ecg_lead, float(row["ecg_fs"]), ppg_ch, float(row["ppg_fs"]), float(row["ecg_time_s"][0]), float(row["ppg_time_s"][0]), ) processed += 1 if len(ptts) >= 3: all_ptts.extend(ptts) clean_segment_stds.append(float(np.std(ptts))) good_segments += 1 if len(sanity_samples) < want_sanity and len(ptts) >= 3: sanity_samples.append( ( ecg_lead.copy(), ppg_ch.copy(), float(row["ecg_fs"]), float(row["ppg_fs"]), row["record_name"], ptts, ) ) # -------- Aggregate -------- total_hours_sampled = total_duration_s / 3600.0 total_hours_estimated = total_hours_sampled * (N_SHARDS / len(meta_shards)) # Patient count estimate: if sampled 120 shards and found K patients, and each shard seems # to be mostly one patient (a recording per patient), then true patients ≈ K * (412/120). # But de-duplicate: we also observed patient IDs; if #patients saturates well below 412, # the dataset has fewer than one-per-shard. patients_extrap = int(len(patients) * N_SHARDS / len(meta_shards)) median = lambda v: float(np.median(v)) if len(v) else float("nan") report = { "dataset": REPO, "shards_total": N_SHARDS, "shards_sampled_meta": len(meta_shards), "segments_meta_scanned": n_segments, "unique_patients_in_sample": len(patients), "unique_patients_extrapolated": patients_extrap, "total_duration_hours_sampled": round(total_hours_sampled, 2), "total_duration_hours_estimated": round(total_hours_estimated, 2), "ecg_fs_median_hz": median(ecg_fs_list), "ppg_fs_median_hz": median(ppg_fs_list), "ecg_siglen_median_samples": int(median(ecg_siglen)) if ecg_siglen else 0, "ppg_siglen_median_samples": int(median(ppg_siglen)) if ppg_siglen else 0, "ecg_lead_counts_top10": dict( sorted(ecg_leads_counter.items(), key=lambda kv: -kv[1])[:10] ), "lead_II_available_frac": has_lead_II / max(n_segments, 1), "ptt_beats_measured": len(all_ptts), "ptt_good_segments": good_segments, "ptt_segments_attempted": processed, "ptt_median_ms": median(all_ptts), "ptt_p5_ms": float(np.percentile(all_ptts, 5)) if all_ptts else float("nan"), "ptt_p95_ms": float(np.percentile(all_ptts, 95)) if all_ptts else float("nan"), "ptt_within_segment_std_median_ms": median(clean_segment_stds), "ptt_within_segment_std_p90_ms": ( float(np.percentile(clean_segment_stds, 90)) if clean_segment_stds else float("nan") ), } # Plots if all_ptts: plt.figure(figsize=(7, 4)) plt.hist(all_ptts, bins=60, color="#3a7", edgecolor="black") plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms") plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms") plt.xlabel("PTT (ms)") plt.ylabel("count") plt.title( f"PTT distribution — {len(all_ptts)} clean beats, " f"{good_segments} segments, {len(by_shard)} shards" ) plt.legend() plt.tight_layout() plt.savefig(FIG_DIR / "ptt_histogram.png", dpi=120) plt.close() if sanity_samples: fig, axes = plt.subplots(len(sanity_samples), 1, figsize=(10, 2.4 * len(sanity_samples))) if len(sanity_samples) == 1: axes = [axes] for ax, (ecg, ppg, efs, pfs, name, ptts) in zip(axes, sanity_samples): t_e = np.arange(len(ecg)) / efs t_p = np.arange(len(ppg)) / pfs ax2 = ax.twinx() ax.plot(t_e, ecg, color="#266", lw=0.6, label="ECG II") ax2.plot(t_p, ppg, color="#b30", lw=0.6, label="PPG") ax.set_title( f"{name} PTT median={np.median(ptts):.0f} ms N={len(ptts)}", fontsize=8, ) ax.set_xlabel("time (s)") ax.set_ylabel("ECG", color="#266") ax2.set_ylabel("PPG", color="#b30") plt.tight_layout() plt.savefig(FIG_DIR / "sanity_check.png", dpi=120) plt.close() (OUT / "e0_report.json").write_text(json.dumps(report, indent=2)) print(json.dumps(report, indent=2)) if __name__ == "__main__": main()