"""E1 — PPG encoding decision: morphological vs raw patch.

Per the E1 decision rule in EXPERIMENT_TRACKING.md:

    if morphology_extraction_rate < 0.70:        -> raw patches
    elif E1b_linear_probe_AUROC > E1a + 0.02:    -> morphological
    else:                                        -> raw patches

This script implements Stage 1 (extraction rate) directly. If the extraction
rate passes, we move on to Stage 2 (linear-probe comparison on AF) — but that
requires AF labels, which are pending. For now we decide Stage 1 and defer
Stage 2 until AF labels land.

Features extracted (Bishop & Ercole / neurokit2):
    PPG_Rate, PPG_Width, PPG_UpstrokeSlope, PPG_Amplitude, PPG_DicroticNotch.
"""
from __future__ import annotations
import json
import os
import random
import re
import warnings
from pathlib import Path
import numpy as np
from dotenv import load_dotenv
from tqdm import tqdm
# Silence every Python warning for the rest of the process.
warnings.filterwarnings("ignore")
# Pull variables from a local .env file (python-dotenv) into os.environ.
load_dotenv()
# If HF_TOKEN is not already set, fall back to HUGGINGFACE_API_KEY, or to ""
# when that is unset too.
# NOTE(review): this runs before the datasets / huggingface_hub imports below,
# presumably so the token is already in the environment when they load — confirm.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))
from datasets import load_from_disk
from huggingface_hub import snapshot_download
import neurokit2 as nk
# Hugging Face dataset repository id handed to snapshot_download() in main().
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
# Report directory: "docs" next to the directory that contains this script.
OUT = Path(__file__).resolve().parent.parent / "docs"
# Fixed-seed RNG so the shard sample drawn in main() is the same on every run.
RNG = random.Random(11)
def try_morphology(ppg: np.ndarray, fs: float) -> tuple[bool, int, int]:
    """Check whether neurokit2 extracts a plausible beat train from one segment.

    Args:
        ppg: PPG samples of a single segment.
        fs: Sampling rate in Hz. It is rounded to an integer for neurokit2,
            while the unrounded value is used to compute the duration.

    Returns:
        ``(ok, n_detected_beats, n_expected_beats)``. ``ok`` is True only when
        all of the following hold:

        * at least 5 peaks are detected,
        * the median detected rate is finite and within 30-200 bpm,
        * the expected beat count is at least 3, and
        * ``detected / expected`` lies in ``[0.70, 1.30]``.

        ``n_expected_beats`` is ``int(duration_s * median_rate / 60)``: it is
        derived from the detected rate itself, not from a fixed "typical"
        heart rate. It is reported as 0 when the segment is rejected before
        that value is computed. Unusable input (empty signal, non-finite or
        sub-1 Hz sampling rate) and any processing error give ``(False, 0, 0)``.
    """
    try:
        # Reject unusable input explicitly rather than relying on neurokit2
        # to blow up on it.
        if np.size(ppg) == 0 or not np.isfinite(fs) or int(round(fs)) < 1:
            return False, 0, 0

        signals, info = nk.ppg_process(ppg, sampling_rate=int(round(fs)))
        peaks = np.asarray(info.get("PPG_Peaks", []))
        n_peaks = len(peaks)
        if n_peaks < 5:
            return False, n_peaks, 0

        duration_s = len(ppg) / fs
        # Expected beats: use the detected rate itself for a robust estimate.
        detected_rate = signals["PPG_Rate"].dropna().median()
        if not np.isfinite(detected_rate) or detected_rate < 30 or detected_rate > 200:
            return False, n_peaks, 0

        expected = int(duration_s * detected_rate / 60.0)
        if expected < 3:
            return False, n_peaks, expected

        extracted_frac = n_peaks / expected
        return 0.70 <= extracted_frac <= 1.30, n_peaks, expected
    except Exception:
        # Deliberate best-effort: a segment neurokit2 cannot process counts as
        # a failed extraction instead of aborting the whole audit.
        return False, 0, 0
def main() -> None:
    """Run the E1 Stage-1 extraction-rate audit and write the JSON report.

    Draws a reproducible sample of 40 shard indices, fetches those shards via
    ``snapshot_download``, and runs ``try_morphology`` on up to 500 segments.
    The extraction rate is ``n_ok / n_nonempty``; a rate below 0.70 yields the
    decision ``"raw_patches"``, otherwise ``"needs_stage2_probe"``. The report
    is written to ``OUT / "e1_stage1_report.json"`` and echoed to stdout.
    """
    # Use shards we already have in cache (from E0 audits).
    # NOTE(review): 412 is presumably the total shard count of REPO — confirm.
    want = sorted(RNG.sample(range(412), 40))
    root = Path(
        snapshot_download(
            REPO,
            repo_type="dataset",
            allow_patterns=[f"shard_{i:05d}/*" for i in want],
            max_workers=12,
        )
    )
    # Keep only the shards that are actually present in the snapshot.
    shards = [s for s in want if (root / f"shard_{s:05d}" / "dataset_info.json").exists()]

    target = 500  # maximum number of segments to attempt, empty ones included
    n_attempted = 0
    n_nonempty = 0
    n_ok = 0
    beat_counts: list[int] = []

    for sidx in tqdm(shards, desc="shards"):
        if n_attempted >= target:
            break
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        for row in ds:
            if n_attempted >= target:
                break
            n_attempted += 1

            raw = np.asarray(row["ppg"], dtype=np.float32)
            # Test for emptiness before indexing so a fully empty record is
            # skipped instead of raising IndexError.
            if raw.size == 0:
                continue
            # NOTE(review): assumes "ppg" is stored as (channels, samples) and
            # that the first channel is the one to analyse — confirm.
            ppg = raw[0]
            fs = float(row["ppg_fs"])
            n_nonempty += 1

            ok, got, _expected = try_morphology(ppg, fs)
            beat_counts.append(got)
            if ok:
                n_ok += 1

    # max(..., 1) avoids a division by zero when every segment was empty.
    extraction_rate = n_ok / max(n_nonempty, 1)
    decision = "raw_patches" if extraction_rate < 0.70 else "needs_stage2_probe"
    report = {
        "n_segments_attempted": n_attempted,
        "n_segments_nonempty": n_nonempty,
        "n_segments_ok": n_ok,
        "extraction_rate": extraction_rate,
        "median_detected_beats_per_segment": (
            float(np.median(beat_counts)) if beat_counts else 0.0
        ),
        "mean_detected_beats_per_segment": (
            float(np.mean(beat_counts)) if beat_counts else 0.0
        ),
        "stage1_decision": decision,
        "rule": (
            "extraction_rate < 0.70 -> raw_patches (stop). "
            "else -> run stage-2 linear-probe comparison after AF labels arrive."
        ),
    }

    payload = json.dumps(report, indent=2)
    # The docs directory may not exist yet on a fresh checkout.
    OUT.mkdir(parents=True, exist_ok=True)
    (OUT / "e1_stage1_report.json").write_text(payload)
    print(payload)
# Run the Stage-1 audit only when this file is executed as a script.
if __name__ == "__main__":
    main()