| """E0 data audit for lucky9-cyou/mimic-iv-aligned-ppg-ecg. |
| |
| Computes: patient count, total hours, sample rates, alignment tolerance, |
| PTT distribution, missing-value rate, and sanity plots. |
| |
| Strategy: stream across ALL shards for cheap metadata (record_name, fs, siglen, |
| nan rates). Subsample shards for the expensive per-beat PTT computation. |
| """ |
from __future__ import annotations

import json
import os
import random
import re
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # headless backend: figures are saved, never shown
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from scipy.signal import butter, filtfilt, find_peaks
from tqdm import tqdm

# Load .env and mirror HUGGINGFACE_API_KEY into HF_TOKEN (the variable the HF
# libraries read) before importing them, so downloads are authenticated.
# Guarded so an absent key does not export an empty token.
load_dotenv()
if "HF_TOKEN" not in os.environ and os.environ.get("HUGGINGFACE_API_KEY"):
    os.environ["HF_TOKEN"] = os.environ["HUGGINGFACE_API_KEY"]

from datasets import load_from_disk
from huggingface_hub import snapshot_download

REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
N_SHARDS = 412
# Sampling budget: shards scanned in the metadata pass, candidate segments
# shuffled into the PTT pass, and the hard cap on segments actually processed.
N_META_SHARDS = 40
N_PTT_CANDIDATES = 250
MAX_PTT_SEGMENTS = 100
OUT = Path(__file__).resolve().parent.parent / "docs"
FIG_DIR = OUT / "figures"
FIG_DIR.mkdir(parents=True, exist_ok=True)

RNG = random.Random(42)  # fixed seed: shard/segment sampling is reproducible
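# Sample scale: N_META_SHARDS / N_SHARDS = 40 / 412, roughly 9.7% of shards,
# so sampled hours are scaled by N_SHARDS / N_META_SHARDS (about 10.3x) in
# the report. This treats shards as roughly uniform in size.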

def parse_subject_id(record_name: str) -> str:
    """Extract the subject id (the second 'pXXXXXXXX' path component)."""
    m = re.match(r"p\d+/(p\d+)/", record_name)
    return m.group(1) if m else record_name.split("/")[0]
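# Example (assuming MIMIC-style record names such as "p100/p10014354/<rec>"):
# parse_subject_id("p100/p10014354/81739927") -> "p10014354". Names that do
# not match that layout fall back to the first path component.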

def bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray:
    """Zero-phase Butterworth band-pass with cutoffs clamped to (0, Nyquist)."""
    ny = 0.5 * fs
    lo_n = max(lo / ny, 1e-4)  # keep the normalized cutoff strictly > 0
    hi_n = min(hi / ny, 0.99)  # ... and strictly < 1 (Nyquist)
    b, a = butter(order, [lo_n, hi_n], btype="band")
    return filtfilt(b, a, x, method="gust")
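# Illustrative check, not executed by the audit (synthetic values, chosen
# arbitrarily): at fs = 125 Hz a 10 Hz component survives the 5-15 Hz QRS
# band while a 0.05 Hz baseline drift is suppressed:
#
#   t = np.arange(0, 60, 1 / 125.0)
#   x = np.sin(2 * np.pi * 10 * t) + 5 * np.sin(2 * np.pi * 0.05 * t)
#   y = bandpass(x, 125.0, 5.0, 15.0)  # drift removed, 10 Hz retained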

def pan_tompkins_lite(ecg: np.ndarray, fs: float) -> np.ndarray:
    """Simple QRS detector. Returns R-peak sample indices."""
    # Classic Pan-Tompkins stages: band-pass (5-15 Hz), derivative, squaring,
    # then a ~120 ms moving-window average.
    x = bandpass(ecg, fs, 5.0, 15.0)
    d = np.diff(x, prepend=x[:1])
    s = d * d
    w = max(int(0.12 * fs), 1)
    mwa = np.convolve(s, np.ones(w) / w, mode="same")
    # Fixed (not adaptive) threshold, plus a 300 ms refractory distance.
    thr = np.mean(mwa) + 0.5 * np.std(mwa)
    min_dist = int(0.3 * fs)
    peaks, _ = find_peaks(mwa, height=thr, distance=min_dist)

    # Snap each detection to the local maximum of the band-passed signal
    # within +/-60 ms, which lands closer to the true R-peak.
    snap = max(int(0.06 * fs), 1)
    refined = []
    for p in peaks:
        lo = max(0, p - snap)
        hi = min(len(x), p + snap)
        if hi > lo:
            refined.append(lo + int(np.argmax(x[lo:hi])))
    return np.asarray(refined, dtype=int)
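# Illustrative use (synthetic, not part of the audit): for a clean 60 bpm
# rhythm in a single-lead array `ecg` sampled at 125 Hz,
#
#   r = pan_tompkins_lite(ecg, 125.0)
#   np.diff(r)  # expect spacings near 125 samples (one beat per second)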

def ppg_systolic_peaks(ppg: np.ndarray, fs: float) -> np.ndarray:
    """Detect PPG systolic peaks in the 0.5-8 Hz pulsatile band."""
    x = bandpass(ppg, fs, 0.5, 8.0)
    min_dist = int(0.3 * fs)  # refractory distance: at most ~200 bpm
    thr = np.mean(x) + 0.3 * np.std(x)
    peaks, _ = find_peaks(x, distance=min_dist, height=thr, prominence=0.1 * np.std(x))
    return peaks

def compute_ptt_ms(
    ecg_lead: np.ndarray,
    ecg_fs: float,
    ppg: np.ndarray,
    ppg_fs: float,
    t0_ecg: float,
    t0_ppg: float,
) -> list[float]:
    """For each R-peak, find the next PPG systolic peak within [50, 500] ms."""
    r_idx = pan_tompkins_lite(ecg_lead, ecg_fs)
    p_idx = ppg_systolic_peaks(ppg, ppg_fs)
    if len(r_idx) < 3 or len(p_idx) < 3:
        return []
    # Convert sample indices to absolute times so the two signals can be
    # matched despite different start times and sample rates.
    r_t = t0_ecg + r_idx / ecg_fs
    p_t = t0_ppg + p_idx / ppg_fs
    # Two-pointer sweep over the sorted peak times: advance j past PPG peaks
    # earlier than 50 ms after the current R-peak, then accept the first
    # candidate only if it falls inside the physiological 50-500 ms window.
    ptts = []
    j = 0
    for rt in r_t:
        while j < len(p_t) and p_t[j] < rt + 0.050:
            j += 1
        if j >= len(p_t):
            break
        dt = p_t[j] - rt
        if 0.050 <= dt <= 0.500:
            ptts.append(dt * 1000.0)
    return ptts
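# Worked example of the matcher above (times in seconds): with R-peaks at
# [10.00, 11.00] and PPG peaks at [10.22, 10.70, 11.25], the first R-peak
# pairs with 10.22 (PTT = 220 ms) and the second with 11.25 (PTT = 250 ms).
# The 10.70 peak matches neither: it lies 700 ms after the first R-peak
# (outside the 500 ms window) and before the second one.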

def quick_snapshot(allow_shards: list[int]) -> str:
    """Download metadata.json plus only the requested shard directories."""
    patterns = ["metadata.json"] + [f"shard_{i:05d}/*" for i in allow_shards]
    return snapshot_download(
        REPO, repo_type="dataset", allow_patterns=patterns, max_workers=8
    )
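# E.g. quick_snapshot([3, 17]) downloads metadata.json, shard_00003/* and
# shard_00017/* into the local HF cache and returns the snapshot root path.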

def main() -> None:
    # Pass 1: cheap metadata scan over a random sample of shards.
    meta_shards = sorted(RNG.sample(range(N_SHARDS), N_META_SHARDS))
    print(f"[pass 1] downloading metadata from {len(meta_shards)} shards")
    root = quick_snapshot(meta_shards)
    root_p = Path(root)
    patients: set[str] = set()
    total_duration_s = 0.0
    ecg_fs_list: list[float] = []
    ppg_fs_list: list[float] = []
    ecg_siglen: list[int] = []
    ppg_siglen: list[int] = []
    ecg_names_seen: set[tuple[str, ...]] = set()
    ppg_names_seen: set[tuple[str, ...]] = set()
    n_segments = 0
    missing_ecg = 0
    missing_ppg = 0
    nan_ecg_frac: list[float] = []
    nan_ppg_frac: list[float] = []

    # (shard, row) index of every segment scanned; shuffled below to draw PTT
    # candidates (a simple shuffle-and-slice, not true reservoir sampling).
    reservoir: list[tuple[int, int]] = []
    for sidx in tqdm(meta_shards, desc="shards(meta)"):
        ds = load_from_disk(str(root_p / f"shard_{sidx:05d}"))
        # Drop the heavy waveform columns so iteration only decodes metadata.
        cols_cheap = ds.remove_columns(
            [c for c in ds.column_names if c in ("ecg", "ppg", "ecg_time_s", "ppg_time_s")]
        )
        for i, row in enumerate(cols_cheap):
            patients.add(parse_subject_id(row["record_name"]))
            total_duration_s += float(row["segment_duration_sec"])
            ecg_fs_list.append(float(row["ecg_fs"]))
            ppg_fs_list.append(float(row["ppg_fs"]))
            ecg_siglen.append(int(row["ecg_siglen"]))
            ppg_siglen.append(int(row["ppg_siglen"]))
            ecg_names_seen.add(tuple(row["ecg_names"]))
            ppg_names_seen.add(tuple(row["ppg_names"]))
            n_segments += 1
            reservoir.append((sidx, i))
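    # The metadata pass above assumes these per-row fields exist in every
    # shard: record_name, segment_duration_sec, ecg_fs, ppg_fs, ecg_siglen,
    # ppg_siglen, ecg_names, ppg_names.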
    # Pass 2: expensive per-beat PTT on a shuffled subsample of segments.
    RNG.shuffle(reservoir)
    ptt_targets = reservoir[:N_PTT_CANDIDATES]
    print(f"[pass 2] computing PTT on up to {MAX_PTT_SEGMENTS} of {len(ptt_targets)} candidate segments")

    all_ptts: list[float] = []
    per_segment_ptt_std: list[float] = []
    per_patient_ptt_median: dict[str, list[float]] = {}

    sanity_samples = []
    want_sanity = 5

    # Group candidate rows by shard so each shard is loaded from disk once.
    by_shard: dict[int, list[int]] = {}
    for s, i in ptt_targets:
        by_shard.setdefault(s, []).append(i)
    processed = 0
    for sidx, idxs in tqdm(by_shard.items(), desc="shards(ptt)"):
        ds = load_from_disk(str(root_p / f"shard_{sidx:05d}"))
        for i in idxs:
            if processed >= MAX_PTT_SEGMENTS:
                break
            row = ds[i]
            ecg = np.asarray(row["ecg"], dtype=np.float32)
            ppg = np.asarray(row["ppg"], dtype=np.float32)
            if ecg.size == 0 or ppg.size == 0:
                # Missing-signal counts cover only this PTT subsample.
                missing_ecg += int(ecg.size == 0)
                missing_ppg += int(ppg.size == 0)
                continue
            nan_ecg_frac.append(float(np.isnan(ecg).mean()))
            nan_ppg_frac.append(float(np.isnan(ppg).mean()))
            if np.isnan(ecg).any() or np.isnan(ppg).any():
                ecg = np.nan_to_num(ecg, nan=0.0)
                ppg = np.nan_to_num(ppg, nan=0.0)
            # Waveforms are stored as (channels, samples); use the first ECG
            # lead and the first PPG channel.
            ecg_lead = ecg[0]
            ppg_ch = ppg[0]
            ecg_fs = float(row["ecg_fs"])
            ppg_fs = float(row["ppg_fs"])
            t0_e = float(row["ecg_time_s"][0])
            t0_p = float(row["ppg_time_s"][0])
            ptts = compute_ptt_ms(ecg_lead, ecg_fs, ppg_ch, ppg_fs, t0_e, t0_p)
            if len(ptts) >= 3:
                all_ptts.extend(ptts)
                per_segment_ptt_std.append(float(np.std(ptts)))
                pid = parse_subject_id(row["record_name"])
                per_patient_ptt_median.setdefault(pid, []).append(float(np.median(ptts)))
            if len(sanity_samples) < want_sanity:
                sanity_samples.append(
                    (ecg_lead.copy(), ppg_ch.copy(), ecg_fs, ppg_fs, row["record_name"])
                )
            processed += 1
        if processed >= MAX_PTT_SEGMENTS:
            break
    # Aggregate metadata stats. Hours are extrapolated linearly from the
    # sampled shards, which assumes shards are roughly uniform in size.
    ecg_fs_med = float(np.median(ecg_fs_list)) if ecg_fs_list else 0.0
    ppg_fs_med = float(np.median(ppg_fs_list)) if ppg_fs_list else 0.0
    total_hours_sampled = total_duration_s / 3600.0
    total_hours_estimated = total_hours_sampled * (N_SHARDS / len(meta_shards))
    patients_sampled = len(patients)
    ptt_median = float(np.median(all_ptts)) if all_ptts else float("nan")
    ptt_p5 = float(np.percentile(all_ptts, 5)) if all_ptts else float("nan")
    ptt_p95 = float(np.percentile(all_ptts, 95)) if all_ptts else float("nan")
    within_seg_std_median = (
        float(np.median(per_segment_ptt_std)) if per_segment_ptt_std else float("nan")
    )

    # Spread of per-segment PTT medians within each patient (needs >= 2 segments).
    within_patient_std: list[float] = []
    for pid, meds in per_patient_ptt_median.items():
        if len(meds) >= 2:
            within_patient_std.append(float(np.std(meds)))
    within_patient_std_median = (
        float(np.median(within_patient_std)) if within_patient_std else float("nan")
    )

    nan_ecg_frac_mean = float(np.mean(nan_ecg_frac)) if nan_ecg_frac else 0.0
    nan_ppg_frac_mean = float(np.mean(nan_ppg_frac)) if nan_ppg_frac else 0.0
    # NOTE: compute_ptt_ms already restricts beats to [50, 500] ms, so this
    # fraction is 1.0 by construction whenever any beats were measured.
    ptt_plausible_frac = (
        float(np.mean([(50 <= p <= 500) for p in all_ptts])) if all_ptts else 0.0
    )
    # Figure: PTT histogram with reference lines for a typical normal range.
    if all_ptts:
        plt.figure(figsize=(7, 4))
        plt.hist(all_ptts, bins=50, color="#3a7", edgecolor="black")
        plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms (lower normal)")
        plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms (upper normal)")
        plt.xlabel("PTT (ms)")
        plt.ylabel("count")
        plt.title(f"PTT distribution, N={len(all_ptts)} beats across {len(by_shard)} shards")
        plt.legend()
        plt.tight_layout()
        plt.savefig(FIG_DIR / "ptt_histogram.png", dpi=120)
        plt.close()
    # Figure: per-segment ECG/PPG overlays on twin axes for visual sanity.
    if sanity_samples:
        fig, axes = plt.subplots(len(sanity_samples), 1, figsize=(10, 2.2 * len(sanity_samples)))
        if len(sanity_samples) == 1:
            axes = [axes]
        for ax, (ecg, ppg, efs, pfs, name) in zip(axes, sanity_samples):
            t_e = np.arange(len(ecg)) / efs
            t_p = np.arange(len(ppg)) / pfs
            ax2 = ax.twinx()
            ax.plot(t_e, ecg, color="#266", lw=0.6, label="ECG[0]")
            ax2.plot(t_p, ppg, color="#b30", lw=0.6, label="PPG")
            ax.set_title(name, fontsize=8)
            ax.set_xlabel("time (s)")
            ax.set_ylabel("ECG", color="#266")
            ax2.set_ylabel("PPG", color="#b30")
        plt.tight_layout()
        plt.savefig(FIG_DIR / "sanity_check.png", dpi=120)
        plt.close()
    report = {
        "dataset": REPO,
        "shards_total": N_SHARDS,
        "shards_sampled_meta": len(meta_shards),
        "segments_meta_scanned": n_segments,
        "unique_patients_in_sample": patients_sampled,
        "total_duration_hours_sampled": round(total_hours_sampled, 2),
        "total_duration_hours_estimated": round(total_hours_estimated, 2),
        "ecg_fs_median_hz": ecg_fs_med,
        "ppg_fs_median_hz": ppg_fs_med,
        "ecg_siglen_median_samples": int(np.median(ecg_siglen)) if ecg_siglen else 0,
        "ppg_siglen_median_samples": int(np.median(ppg_siglen)) if ppg_siglen else 0,
        "ecg_leads_seen": [list(t) for t in list(ecg_names_seen)[:10]],
        "ppg_channels_seen": [list(t) for t in list(ppg_names_seen)[:10]],
        "n_ecg_lead_combinations": len(ecg_names_seen),
        "n_ppg_channel_combinations": len(ppg_names_seen),
        "missing_ecg_segments": missing_ecg,
        "missing_ppg_segments": missing_ppg,
        "nan_ecg_frac_mean": nan_ecg_frac_mean,
        "nan_ppg_frac_mean": nan_ppg_frac_mean,
        "ptt_beats_measured": len(all_ptts),
        "ptt_median_ms": ptt_median,
        "ptt_p5_ms": ptt_p5,
        "ptt_p95_ms": ptt_p95,
        "ptt_within_segment_std_median_ms": within_seg_std_median,
        "ptt_within_patient_std_median_ms": within_patient_std_median,
        "ptt_physio_plausible_frac": ptt_plausible_frac,
    }
    (OUT / "e0_report.json").write_text(json.dumps(report, indent=2))
    print(json.dumps(report, indent=2))

if __name__ == "__main__":
    main()