# misscp/tests/test_grud_data.py
# Author: Anonymous
# Commit: Initial anonymous MissCP release (32f5a65)
from __future__ import annotations
import numpy as np
import pandas as pd
from sepsis_mcp.grud_data import (
build_patient_grud_samples,
fit_grud_scaler,
stack_grud_samples,
transform_grud_stacked,
)
def _patient_frame() -> pd.DataFrame:
return pd.DataFrame(
{
"HR": [80.0, 82.0, 84.0, 86.0],
"O2Sat": [95.0, np.nan, 93.0, 92.0],
"Age": [65.0] * 4,
"Gender": [1.0] * 4,
"Unit1": [1.0] * 4,
"Unit2": [0.0] * 4,
"HospAdmTime": [-5.0] * 4,
"ICULOS": [1.0, 2.0, 3.0, 4.0],
"SepsisLabel": [0, 0, 1, 1],
}
)
def test_build_patient_grud_samples_pads_left_and_aligns_future_targets() -> None:
    """Check sample count, metadata, array shapes and padding of the first sample.

    With a lookback of 3 and a horizon of 1 on the four-row fixture, three
    samples are expected. In the first sample only the last row of the
    window carries observed data; the two rows before it are all zeros in
    both ``values`` and ``masks``.
    """
    windows = build_patient_grud_samples(
        _patient_frame(),
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )

    assert len(windows) == 3

    head = windows[0]
    assert head["patient_id"] == "p000001"
    assert head["sample_index"] == 0
    assert head["label"] == 0

    # Every per-time-step array is (lookback, n_dynamic) = (3, 2).
    for array_key in ("values", "masks", "deltas"):
        assert head[array_key].shape == (3, 2)
    assert head["static"].shape == (5,)

    expected_values = [[0.0, 0.0], [0.0, 0.0], [80.0, 95.0]]
    expected_masks = [[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]]
    assert head["values"].tolist() == expected_values
    assert head["masks"].tolist() == expected_masks
def test_build_patient_grud_samples_computes_featurewise_deltas_for_missing_values() -> None:
    """Check values, masks, deltas, label and missing rate of the second sample.

    The second window ends on the fixture row where ``O2Sat`` is NaN: that
    entry is expected to be 0.0 in ``values``, 0.0 in ``masks`` and 1.0 in
    ``deltas``, while the ``HR`` column of ``deltas`` stays 0.0.
    """
    windows = build_patient_grud_samples(
        _patient_frame(),
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )
    window = windows[1]

    expected_arrays = {
        "values": [[0.0, 0.0], [80.0, 95.0], [82.0, 0.0]],
        "masks": [[0.0, 0.0], [1.0, 1.0], [1.0, 0.0]],
        "deltas": [[0.0, 0.0], [0.0, 0.0], [0.0, 1.0]],
    }
    for array_key, expected_rows in expected_arrays.items():
        assert window[array_key].tolist() == expected_rows

    assert window["label"] == 1
    # Three of the six cells in the window (one padded row plus the missing
    # O2Sat reading) have mask 0 above, matching the asserted rate of 0.5.
    assert np.isclose(window["global_missing_rate"], 0.5)
def test_build_patient_grud_samples_replaces_missing_static_values_with_zero() -> None:
    """Check that NaN ``Unit1``/``Unit2`` in the first row become 0.0 in ``static``."""
    patient = _patient_frame()
    # Blank out both unit flags on the first row only.
    for unit_column in ("Unit1", "Unit2"):
        patient.loc[0, unit_column] = np.nan

    windows = build_patient_grud_samples(
        patient,
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )

    # Expected values line up with the fixture's Age, Gender, Unit1, Unit2
    # and HospAdmTime columns, with the two unit flags replaced by 0.0.
    expected_static = [65.0, 1.0, 0.0, 0.0, -5.0]
    assert windows[0]["static"].tolist() == expected_static
def test_transform_grud_stacked_standardizes_observed_values_and_preserves_missing_zero() -> None:
    """Check that observed HR values are standardized and a masked entry stays 0.

    After fitting the scaler on the stacked samples and applying it, the HR
    entries whose mask is 1 should have mean ~0 and standard deviation ~1,
    and the entry at ``[1, 2, 1]`` should still be exactly 0.0.
    """
    stacked = stack_grud_samples(
        build_patient_grud_samples(
            _patient_frame(),
            patient_id="p000001",
            dynamic_columns=["HR", "O2Sat"],
            lookback_hours=3,
            horizon_hours=1,
        )
    )
    scaled = transform_grud_stacked(stacked, fit_grud_scaler(stacked))

    # Boolean (sample, time) selector for positions where HR was observed.
    hr_is_observed = stacked["masks"][:, :, 0] == 1
    hr_scaled = scaled["values"][hr_is_observed, 0]

    assert np.isclose(hr_scaled.mean(), 0.0, atol=1e-5)
    assert np.isclose(hr_scaled.std(), 1.0, atol=1e-5)

    # Sample 1, last time step, second dynamic column (O2Sat): this is the
    # NaN reading in the fixture, so it must not be shifted by the scaler.
    assert scaled["values"][1, 2, 1] == 0.0
def test_transform_grud_stacked_standardizes_static_features_after_imputation() -> None:
    """Check that transformed static features are finite and column 0 has mean ~0."""
    stacked = stack_grud_samples(
        build_patient_grud_samples(
            _patient_frame(),
            patient_id="p000001",
            dynamic_columns=["HR", "O2Sat"],
            lookback_hours=3,
            horizon_hours=1,
        )
    )
    scaled = transform_grud_stacked(stacked, fit_grud_scaler(stacked))
    static_scaled = scaled["static"]

    # The fixture's static columns are constant across rows, so this also
    # guards against NaN/inf from a zero-variance division.
    assert np.all(np.isfinite(static_scaled))

    first_static_mean = static_scaled[:, 0].mean()
    assert np.isclose(first_static_mean, 0.0, atol=1e-6)
def test_stack_grud_samples_preserves_sample_missing_rates() -> None:
    """Check that stacking exposes one ``global_missing_rates`` entry per sample."""
    windows = build_patient_grud_samples(
        _patient_frame(),
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )
    batch = stack_grud_samples(windows)

    assert "global_missing_rates" in batch

    missing_rates = batch["global_missing_rates"]
    assert missing_rates.shape == (3,)
    # The first sample is expected to report a missing rate of 2/3.
    assert np.isclose(missing_rates[0], 2.0 / 3.0)