| from __future__ import annotations |
|
|
| import numpy as np |
| import pandas as pd |
|
|
| from sepsis_mcp.grud_data import ( |
| build_patient_grud_samples, |
| fit_grud_scaler, |
| stack_grud_samples, |
| transform_grud_stacked, |
| ) |
|
|
|
|
def _patient_frame() -> pd.DataFrame:
    """Return the four-row, single-patient frame shared by the tests in this module.

    The frame holds two time-varying columns (``HR`` and ``O2Sat``, the latter
    with a single NaN in its second row), five columns whose value is repeated
    on every row, an ``ICULOS`` counter running 1..4 and an integer
    ``SepsisLabel`` column that is 0 for the first two rows and 1 afterwards.
    """
    n_rows = 4

    vitals = {
        "HR": [80.0, 82.0, 84.0, 86.0],
        "O2Sat": [95.0, np.nan, 93.0, 92.0],
    }

    # One scalar per column, broadcast to every row below.
    per_patient = {
        "Age": 65.0,
        "Gender": 1.0,
        "Unit1": 1.0,
        "Unit2": 0.0,
        "HospAdmTime": -5.0,
    }
    repeated = {name: [value] * n_rows for name, value in per_patient.items()}

    # Dict insertion order fixes the column order of the resulting frame.
    return pd.DataFrame(
        {
            **vitals,
            **repeated,
            "ICULOS": [1.0, 2.0, 3.0, 4.0],
            "SepsisLabel": [0, 0, 1, 1],
        }
    )
|
|
|
|
def test_build_patient_grud_samples_pads_left_and_aligns_future_targets() -> None:
    """Check the first window produced for the fixture patient.

    With a 3-hour lookback and 1-hour horizon the four-row fixture is expected
    to yield three samples. The first sample should carry the patient id,
    index 0 and label 0, expose ``(3, 2)`` value/mask/delta arrays plus a
    5-element static vector, and place the only real observation in the last
    row with all-zero padding (values and masks) in front of it.
    """
    window_shape = (3, 2)  # (lookback_hours, number of dynamic columns)

    windows = build_patient_grud_samples(
        _patient_frame(),
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )

    assert len(windows) == 3

    earliest = windows[0]
    assert earliest["patient_id"] == "p000001"
    assert earliest["sample_index"] == 0
    assert earliest["label"] == 0

    for array_key in ("values", "masks", "deltas"):
        assert earliest[array_key].shape == window_shape
    assert earliest["static"].shape == (5,)

    # Two padded rows followed by the first fixture row (HR=80, O2Sat=95).
    expected_values = [
        [0.0, 0.0],
        [0.0, 0.0],
        [80.0, 95.0],
    ]
    # Padding is flagged as unobserved; the real row is fully observed.
    expected_masks = [
        [0.0, 0.0],
        [0.0, 0.0],
        [1.0, 1.0],
    ]
    assert earliest["values"].tolist() == expected_values
    assert earliest["masks"].tolist() == expected_masks
|
|
|
|
def test_build_patient_grud_samples_computes_featurewise_deltas_for_missing_values() -> None:
    """Check the second window, whose last row has a missing ``O2Sat`` reading.

    The missing cell is expected to appear as 0.0 in ``values`` with a 0.0
    mask, and only that feature should carry a delta of 1.0 in the last row.
    The sample is also expected to have label 1 and a global missing rate of
    0.5 (three of the six mask cells are zero).
    """
    windows = build_patient_grud_samples(
        _patient_frame(),
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )
    sample = windows[1]

    # One padded row, then fixture rows 0 and 1; the NaN O2Sat shows up as 0.0.
    expected_values = [
        [0.0, 0.0],
        [80.0, 95.0],
        [82.0, 0.0],
    ]
    expected_masks = [
        [0.0, 0.0],
        [1.0, 1.0],
        [1.0, 0.0],
    ]
    # Only the unobserved O2Sat cell is expected to have a non-zero delta.
    expected_deltas = [
        [0.0, 0.0],
        [0.0, 0.0],
        [0.0, 1.0],
    ]

    assert sample["values"].tolist() == expected_values
    assert sample["masks"].tolist() == expected_masks
    assert sample["deltas"].tolist() == expected_deltas
    assert sample["label"] == 1
    assert np.isclose(sample["global_missing_rate"], 0.5)
|
|
|
|
def test_build_patient_grud_samples_replaces_missing_static_values_with_zero() -> None:
    """Check that NaN ``Unit1``/``Unit2`` values come out as 0.0 in the static vector.

    The first fixture row has both unit columns blanked out; the first sample's
    static vector is then expected to be ``[65.0, 1.0, 0.0, 0.0, -5.0]``.
    """
    patient = _patient_frame()
    for unit_column in ("Unit1", "Unit2"):
        patient.loc[0, unit_column] = np.nan

    windows = build_patient_grud_samples(
        patient,
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )

    # Values line up with the fixture's Age, Gender, Unit1, Unit2, HospAdmTime.
    expected_static = [65.0, 1.0, 0.0, 0.0, -5.0]
    assert windows[0]["static"].tolist() == expected_static
|
|
|
|
def test_transform_grud_stacked_standardizes_observed_values_and_preserves_missing_zero() -> None:
    """Check scaling of observed HR values and the zero kept at a missing cell.

    A scaler is fitted on the stacked fixture samples and applied to the same
    data. The HR entries whose mask equals 1 are expected to have mean 0 and
    standard deviation 1 (within 1e-5), and the cell at sample 1, step 2,
    feature 1 (the missing ``O2Sat`` reading) is expected to remain exactly 0.0.
    """
    windows = build_patient_grud_samples(
        _patient_frame(),
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )
    batch = stack_grud_samples(windows)
    scaled = transform_grud_stacked(batch, fit_grud_scaler(batch))

    # Select HR (feature 0) at positions flagged as observed in the unscaled batch.
    hr_is_observed = batch["masks"][:, :, 0] == 1
    scaled_hr = scaled["values"][hr_is_observed, 0]

    assert np.isclose(scaled_hr.mean(), 0.0, atol=1e-5)
    assert np.isclose(scaled_hr.std(), 1.0, atol=1e-5)
    assert scaled["values"][1, 2, 1] == 0.0
|
|
|
|
def test_transform_grud_stacked_standardizes_static_features_after_imputation() -> None:
    """Check that the transformed static block is finite and centred in column 0.

    A scaler is fitted on the stacked fixture samples and applied to the same
    data. Every entry of the transformed ``static`` array must be finite, and
    its first column is expected to average to 0 (within 1e-6).
    """
    windows = build_patient_grud_samples(
        _patient_frame(),
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )
    batch = stack_grud_samples(windows)
    scaled = transform_grud_stacked(batch, fit_grud_scaler(batch))

    scaled_static = scaled["static"]
    # The fixture's per-patient columns are constant across rows, so this also
    # guards against NaN/inf leaking out of the scaling step.
    assert np.all(np.isfinite(scaled_static))
    assert np.isclose(scaled_static[:, 0].mean(), 0.0, atol=1e-6)
|
|
|
|
def test_stack_grud_samples_preserves_sample_missing_rates() -> None:
    """Check that stacking exposes one global missing rate per sample.

    The stacked result is expected to contain a ``global_missing_rates`` array
    of shape ``(3,)`` whose first entry is 2/3, i.e. four of the six mask
    cells in the first window are zero.
    """
    windows = build_patient_grud_samples(
        _patient_frame(),
        patient_id="p000001",
        dynamic_columns=["HR", "O2Sat"],
        lookback_hours=3,
        horizon_hours=1,
    )
    batch = stack_grud_samples(windows)

    assert "global_missing_rates" in batch

    missing_rates = batch["global_missing_rates"]
    assert missing_rates.shape == (3,)

    first_window_rate = 2.0 / 3.0
    assert np.isclose(missing_rates[0], first_window_rate)
|
|