"""E0 data audit for lucky9-cyou/mimic-iv-aligned-ppg-ecg. Computes: patient count, total hours, sample rates, alignment tolerance, PTT distribution, missing-value rate, and sanity plots. Strategy: stream across ALL shards for cheap metadata (record_name, fs, siglen, nan rates). Subsample shards for the expensive per-beat PTT computation. """ from __future__ import annotations import json import os import random import re from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np from dotenv import load_dotenv from scipy.signal import butter, filtfilt, find_peaks from tqdm import tqdm load_dotenv() os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) from datasets import load_from_disk from huggingface_hub import snapshot_download REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg" N_SHARDS = 412 OUT = Path(__file__).resolve().parent.parent / "docs" FIG_DIR = OUT / "figures" FIG_DIR.mkdir(parents=True, exist_ok=True) RNG = random.Random(42) def parse_subject_id(record_name: str) -> str: m = re.match(r"p\d+/(p\d+)/", record_name) return m.group(1) if m else record_name.split("/")[0] def bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray: ny = 0.5 * fs lo_n = max(lo / ny, 1e-4) hi_n = min(hi / ny, 0.99) b, a = butter(order, [lo_n, hi_n], btype="band") return filtfilt(b, a, x, method="gust") def pan_tompkins_lite(ecg: np.ndarray, fs: float) -> np.ndarray: """Simple QRS detector. Returns R-peak sample indices.""" x = bandpass(ecg, fs, 5.0, 15.0) d = np.diff(x, prepend=x[:1]) s = d * d w = max(int(0.12 * fs), 1) mwa = np.convolve(s, np.ones(w) / w, mode="same") thr = np.mean(mwa) + 0.5 * np.std(mwa) min_dist = int(0.3 * fs) # refractory 300 ms -> max 200 bpm peaks, _ = find_peaks(mwa, height=thr, distance=min_dist) # Snap to local max in the filtered ECG within ±60 ms snap = max(int(0.06 * fs), 1) refined = [] for p in peaks: lo = max(0, p - snap) hi = min(len(x), p + snap) if hi > lo: refined.append(lo + int(np.argmax(x[lo:hi]))) return np.asarray(refined, dtype=int) def ppg_systolic_peaks(ppg: np.ndarray, fs: float) -> np.ndarray: x = bandpass(ppg, fs, 0.5, 8.0) min_dist = int(0.3 * fs) thr = np.mean(x) + 0.3 * np.std(x) peaks, _ = find_peaks(x, distance=min_dist, height=thr, prominence=0.1 * np.std(x)) return peaks def compute_ptt_ms( ecg_lead: np.ndarray, ecg_fs: float, ppg: np.ndarray, ppg_fs: float, t0_ecg: float, t0_ppg: float, ) -> list[float]: """For each R-peak, find the next PPG systolic peak within [50, 500] ms.""" r_idx = pan_tompkins_lite(ecg_lead, ecg_fs) p_idx = ppg_systolic_peaks(ppg, ppg_fs) if len(r_idx) < 3 or len(p_idx) < 3: return [] r_t = t0_ecg + r_idx / ecg_fs p_t = t0_ppg + p_idx / ppg_fs ptts = [] j = 0 for rt in r_t: while j < len(p_t) and p_t[j] < rt + 0.050: j += 1 if j >= len(p_t): break dt = p_t[j] - rt if 0.050 <= dt <= 0.500: ptts.append(dt * 1000.0) return ptts def quick_snapshot(allow_shards: list[int]) -> str: patterns = ["metadata.json"] + [f"shard_{i:05d}/*" for i in allow_shards] return snapshot_download( REPO, repo_type="dataset", allow_patterns=patterns, max_workers=8 ) def main() -> None: # -------- Pass 1: metadata over a wide shard sample (cheap columns only) -------- # We want ≥500 patients confirmed and overall fs/siglen stats. # Sample 40 shards uniformly → ~4000 segments; should hit plenty of patients. meta_shards = sorted(RNG.sample(range(N_SHARDS), 40)) print(f"[pass 1] downloading metadata from {len(meta_shards)} shards") root = quick_snapshot(meta_shards) root_p = Path(root) patients: set[str] = set() total_duration_s = 0.0 ecg_fs_list: list[float] = [] ppg_fs_list: list[float] = [] ecg_siglen: list[int] = [] ppg_siglen: list[int] = [] ecg_names_seen: set[tuple[str, ...]] = set() ppg_names_seen: set[tuple[str, ...]] = set() n_segments = 0 missing_ecg = 0 missing_ppg = 0 nan_ecg_frac = [] nan_ppg_frac = [] # keep a reservoir of (shard_idx, within_shard_idx) candidates for PTT sampling reservoir: list[tuple[int, int]] = [] for sidx in tqdm(meta_shards, desc="shards(meta)"): ds = load_from_disk(str(root_p / f"shard_{sidx:05d}")) cols_cheap = ds.remove_columns( [c for c in ds.column_names if c in ("ecg", "ppg", "ecg_time_s", "ppg_time_s")] ) for i, row in enumerate(cols_cheap): patients.add(parse_subject_id(row["record_name"])) total_duration_s += float(row["segment_duration_sec"]) ecg_fs_list.append(float(row["ecg_fs"])) ppg_fs_list.append(float(row["ppg_fs"])) ecg_siglen.append(int(row["ecg_siglen"])) ppg_siglen.append(int(row["ppg_siglen"])) ecg_names_seen.add(tuple(row["ecg_names"])) ppg_names_seen.add(tuple(row["ppg_names"])) n_segments += 1 reservoir.append((sidx, i)) # -------- Pass 2: PTT + waveform stats on 100 random segments -------- RNG.shuffle(reservoir) ptt_targets = reservoir[:250] # oversample; some will fail QRS detection print(f"[pass 2] computing PTT on up to {len(ptt_targets)} segments") all_ptts: list[float] = [] per_segment_ptt_std: list[float] = [] per_patient_ptt_median: dict[str, list[float]] = {} sanity_samples = [] # (ecg_lead, ppg, ecg_fs, ppg_fs, record_name) want_sanity = 5 # group by shard to avoid reloading by_shard: dict[int, list[int]] = {} for s, i in ptt_targets: by_shard.setdefault(s, []).append(i) processed = 0 for sidx, idxs in tqdm(by_shard.items(), desc="shards(ptt)"): ds = load_from_disk(str(root_p / f"shard_{sidx:05d}")) for i in idxs: if processed >= 100: break row = ds[i] ecg = np.asarray(row["ecg"], dtype=np.float32) ppg = np.asarray(row["ppg"], dtype=np.float32) if ecg.size == 0 or ppg.size == 0: missing_ecg += ecg.size == 0 missing_ppg += ppg.size == 0 continue nan_ecg_frac.append(float(np.isnan(ecg).mean())) nan_ppg_frac.append(float(np.isnan(ppg).mean())) if np.isnan(ecg).any() or np.isnan(ppg).any(): ecg = np.nan_to_num(ecg, nan=0.0) ppg = np.nan_to_num(ppg, nan=0.0) ecg_lead = ecg[0] ppg_ch = ppg[0] ecg_fs = float(row["ecg_fs"]) ppg_fs = float(row["ppg_fs"]) t0_e = float(row["ecg_time_s"][0]) t0_p = float(row["ppg_time_s"][0]) ptts = compute_ptt_ms(ecg_lead, ecg_fs, ppg_ch, ppg_fs, t0_e, t0_p) if len(ptts) >= 3: all_ptts.extend(ptts) per_segment_ptt_std.append(float(np.std(ptts))) pid = parse_subject_id(row["record_name"]) per_patient_ptt_median.setdefault(pid, []).append(float(np.median(ptts))) if len(sanity_samples) < want_sanity: sanity_samples.append( (ecg_lead.copy(), ppg_ch.copy(), ecg_fs, ppg_fs, row["record_name"]) ) processed += 1 if processed >= 100: break # -------- Aggregate -------- ecg_fs_med = float(np.median(ecg_fs_list)) if ecg_fs_list else 0.0 ppg_fs_med = float(np.median(ppg_fs_list)) if ppg_fs_list else 0.0 total_hours_sampled = total_duration_s / 3600.0 # Extrapolate to full dataset (we sampled 40/412 shards) total_hours_estimated = total_hours_sampled * (N_SHARDS / len(meta_shards)) patients_sampled = len(patients) # Extrapolate patient count (patients typically distribute roughly uniformly across shards) # but with a coupon-collector cap; report both figures. ptt_median = float(np.median(all_ptts)) if all_ptts else float("nan") ptt_p5 = float(np.percentile(all_ptts, 5)) if all_ptts else float("nan") ptt_p95 = float(np.percentile(all_ptts, 95)) if all_ptts else float("nan") within_seg_std_median = ( float(np.median(per_segment_ptt_std)) if per_segment_ptt_std else float("nan") ) within_patient_std = [] for pid, meds in per_patient_ptt_median.items(): if len(meds) >= 2: within_patient_std.append(float(np.std(meds))) within_patient_std_median = ( float(np.median(within_patient_std)) if within_patient_std else float("nan") ) nan_ecg_frac_mean = float(np.mean(nan_ecg_frac)) if nan_ecg_frac else 0.0 nan_ppg_frac_mean = float(np.mean(nan_ppg_frac)) if nan_ppg_frac else 0.0 ptt_plausible_frac = ( float(np.mean([(50 <= p <= 500) for p in all_ptts])) if all_ptts else 0.0 ) # -------- Plots -------- if all_ptts: plt.figure(figsize=(7, 4)) plt.hist(all_ptts, bins=50, color="#3a7", edgecolor="black") plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms (lower normal)") plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms (upper normal)") plt.xlabel("PTT (ms)") plt.ylabel("count") plt.title(f"PTT distribution, N={len(all_ptts)} beats across {len(by_shard)} shards") plt.legend() plt.tight_layout() plt.savefig(FIG_DIR / "ptt_histogram.png", dpi=120) plt.close() if sanity_samples: fig, axes = plt.subplots(len(sanity_samples), 1, figsize=(10, 2.2 * len(sanity_samples))) if len(sanity_samples) == 1: axes = [axes] for ax, (ecg, ppg, efs, pfs, name) in zip(axes, sanity_samples): t_e = np.arange(len(ecg)) / efs t_p = np.arange(len(ppg)) / pfs ax2 = ax.twinx() ax.plot(t_e, ecg, color="#266", lw=0.6, label="ECG[0]") ax2.plot(t_p, ppg, color="#b30", lw=0.6, label="PPG") ax.set_title(name, fontsize=8) ax.set_xlabel("time (s)") ax.set_ylabel("ECG", color="#266") ax2.set_ylabel("PPG", color="#b30") plt.tight_layout() plt.savefig(FIG_DIR / "sanity_check.png", dpi=120) plt.close() # -------- Write JSON output -------- report = { "dataset": REPO, "shards_total": N_SHARDS, "shards_sampled_meta": len(meta_shards), "segments_meta_scanned": n_segments, "unique_patients_in_sample": patients_sampled, "total_duration_hours_sampled": round(total_hours_sampled, 2), "total_duration_hours_estimated": round(total_hours_estimated, 2), "ecg_fs_median_hz": ecg_fs_med, "ppg_fs_median_hz": ppg_fs_med, "ecg_siglen_median_samples": int(np.median(ecg_siglen)) if ecg_siglen else 0, "ppg_siglen_median_samples": int(np.median(ppg_siglen)) if ppg_siglen else 0, "ecg_leads_seen": [list(t) for t in list(ecg_names_seen)[:10]], "ppg_channels_seen": [list(t) for t in list(ppg_names_seen)[:10]], "n_ecg_lead_combinations": len(ecg_names_seen), "n_ppg_channel_combinations": len(ppg_names_seen), "missing_ecg_segments": missing_ecg, "missing_ppg_segments": missing_ppg, "nan_ecg_frac_mean": nan_ecg_frac_mean, "nan_ppg_frac_mean": nan_ppg_frac_mean, "ptt_beats_measured": len(all_ptts), "ptt_median_ms": ptt_median, "ptt_p5_ms": ptt_p5, "ptt_p95_ms": ptt_p95, "ptt_within_segment_std_median_ms": within_seg_std_median, "ptt_within_patient_std_median_ms": within_patient_std_median, "ptt_physio_plausible_frac": ptt_plausible_frac, } (OUT / "e0_report.json").write_text(json.dumps(report, indent=2)) print(json.dumps(report, indent=2)) if __name__ == "__main__": main()