# PhysioJEPA/scripts/e1_ppg_encoding.py
"""E1 — PPG encoding decision: morphological vs raw patch.
Per the E1 decision rule in EXPERIMENT_TRACKING.md:
if morphology_extraction_rate < 0.70: -> raw patches
elif E1b_linear_probe_AUROC > E1a + 0.02: -> morphological
else: -> raw patches
This script implements Stage 1 (extraction rate) directly. If extraction rate
passes, we'd move to Stage 2 (linear probe comparison on AF) — but that
requires AF labels, which are pending. For now we decide Stage 1 and defer
Stage 2 until AF labels land.
Features extracted (Bishop & Ercole / neurokit2):
PPG_Rate, PPG_Width, PPG_UpstrokeSlope, PPG_Amplitude, PPG_DicroticNotch.
"""
from __future__ import annotations

import json
import os
import random
import re  # NOTE(review): `re` appears unused in this file — confirm before removing.
import warnings
from pathlib import Path

import numpy as np
from dotenv import load_dotenv
from tqdm import tqdm

# neurokit2 emits many RuntimeWarnings on noisy/short PPG segments; silence them.
warnings.filterwarnings("ignore")
# Pull credentials (HUGGINGFACE_API_KEY) from a local .env file, if present.
load_dotenv()
# Mirror HUGGINGFACE_API_KEY into HF_TOKEN for huggingface_hub auth.
# NOTE(review): if the key is absent this sets HF_TOKEN to "" — confirm an
# empty token does not interfere with anonymous downloads.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))

# Imported after the env setup above so the HF libraries see HF_TOKEN.
from datasets import load_from_disk
from huggingface_hub import snapshot_download
import neurokit2 as nk

# HF dataset of time-aligned PPG/ECG segments (MIMIC-IV derived).
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
# Reports land next to the experiment docs: <repo_root>/docs.
OUT = Path(__file__).resolve().parent.parent / "docs"
# Fixed seed so the shard sample is reproducible across runs.
RNG = random.Random(11)
def try_morphology(ppg: np.ndarray, fs: float) -> tuple[bool, int, int]:
    """Probe whether neurokit2 can extract PPG morphology from one segment.

    Returns ``(ok, n_detected_beats, n_expected_beats)``. ``ok`` is True when
    all of the following hold: at least 5 beats are detected, the median
    detected rate is physiologically plausible (30-200 bpm), at least 3 beats
    are expected from that rate over the segment duration, and the detected
    beat count lies within 70%-130% of the expected count.
    """
    try:
        signals, info = nk.ppg_process(ppg, sampling_rate=int(round(fs)))
        n_peaks = len(np.asarray(info.get("PPG_Peaks", [])))
        if n_peaks < 5:
            return False, n_peaks, 0
        # Use the median detected instantaneous rate as a robust estimate
        # of how many beats this segment *should* contain.
        rate_bpm = signals["PPG_Rate"].dropna().median()
        if not (np.isfinite(rate_bpm) and 30 <= rate_bpm <= 200):
            return False, n_peaks, 0
        duration_s = len(ppg) / fs
        expected = int(duration_s * rate_bpm / 60.0)
        if expected < 3:
            return False, n_peaks, expected
        frac = n_peaks / expected
        return 0.70 <= frac <= 1.30, n_peaks, expected
    except Exception:
        # neurokit2 raises on flat, too-short, or degenerate signals;
        # any failure counts as "morphology not extractable".
        return False, 0, 0
def main() -> None:
    """E1 Stage 1: estimate the PPG morphology extraction rate.

    Samples 40 of the 412 dataset shards (fixed-seed RNG for
    reproducibility), attempts morphology extraction on up to ``target``
    PPG segments, then writes the Stage-1 decision report to
    ``docs/e1_stage1_report.json`` and echoes it to stdout.

    Fix vs. previous revision: a per-segment ``results`` list of dicts was
    accumulated for every row but never written, printed, or returned —
    pure dead work; it has been removed.
    """
    # Use shards we already have in cache (from E0 audits).
    want = sorted(RNG.sample(range(412), 40))
    root = Path(
        snapshot_download(
            REPO,
            repo_type="dataset",
            allow_patterns=[f"shard_{i:05d}/*" for i in want],
            max_workers=12,
        )
    )
    # Keep only shards that actually materialized on disk.
    shards = [s for s in want if (root / f"shard_{s:05d}" / "dataset_info.json").exists()]

    target = 500  # cap on segments to attempt
    n_attempted = 0
    n_nonempty = 0
    n_ok = 0
    beat_counts: list[int] = []

    for sidx in tqdm(shards, desc="shards"):
        if n_attempted >= target:
            break
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        for i in range(len(ds)):
            if n_attempted >= target:
                break
            row = ds[i]
            # Assumes channel-first storage, i.e. row["ppg"] is (1, T) — TODO confirm.
            ppg = np.asarray(row["ppg"], dtype=np.float32)[0]
            fs = float(row["ppg_fs"])
            n_attempted += 1
            if ppg.size == 0:
                continue
            n_nonempty += 1
            ok, got, _ = try_morphology(ppg, fs)
            beat_counts.append(got)
            if ok:
                n_ok += 1

    # Guard the denominator in case every sampled segment was empty.
    extraction_rate = n_ok / max(n_nonempty, 1)
    # Stage-1 decision per the E1 rule in EXPERIMENT_TRACKING.md.
    decision = "raw_patches" if extraction_rate < 0.70 else "needs_stage2_probe"
    report = {
        "n_segments_attempted": n_attempted,
        "n_segments_nonempty": n_nonempty,
        "n_segments_ok": n_ok,
        "extraction_rate": extraction_rate,
        "median_detected_beats_per_segment": (
            float(np.median(beat_counts)) if beat_counts else 0.0
        ),
        "mean_detected_beats_per_segment": (
            float(np.mean(beat_counts)) if beat_counts else 0.0
        ),
        "stage1_decision": decision,
        "rule": (
            "extraction_rate < 0.70 -> raw_patches (stop). "
            "else -> run stage-2 linear-probe comparison after AF labels arrive."
        ),
    }
    (OUT / "e1_stage1_report.json").write_text(json.dumps(report, indent=2))
    print(json.dumps(report, indent=2))
# Script entry point: run Stage 1 only when executed directly, not on import.
if __name__ == "__main__":
    main()