# misscp/tests/test_grud_experiment.py
# Anonymous
# Initial anonymous MissCP release
# 32f5a65
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import pandas as pd
from sepsis_mcp.grud_experiment import (
GRUDRunConfig,
_split_paths,
_positive_class_weight,
_select_validation_score,
build_parser,
main,
)
def _write_patient(path: Path, *, positive: bool, shifted: bool) -> None:
    """Write a synthetic eight-hour patient record to ``path`` as pipe-separated text.

    The file holds one header line followed by one row per hour (hours 1-8),
    joined with newlines and without a trailing newline.

    Args:
        path: Destination file; written as UTF-8.
        positive: When true, heart rate starts at 78 and rises by 3 per hour,
            O2Sat falls by 1 per hour from 96, and ``SepsisLabel`` is 1 from
            hour 5 onward. Otherwise heart rate starts at 62 and rises by 1
            per hour, O2Sat alternates between 97 and 98, and the label is 0.
        shifted: When true, the starting heart rate is raised by 6.
    """
    start_hr = (78 if positive else 62) + (6 if shifted else 0)
    hr_slope = 3 if positive else 1

    def _record(hour: int) -> str:
        # One data row; the constant columns are Age, Gender, Unit1, Unit2
        # and HospAdmTime, in the order given by the header.
        hr = start_hr + hour * hr_slope
        if positive:
            o2sat = 96 - hour
        else:
            o2sat = 98 - (hour % 2)
        return f"{hr}|{o2sat}|65|1|1|0|-5|{hour}|{1 if positive and hour >= 5 else 0}"

    header = "HR|O2Sat|Age|Gender|Unit1|Unit2|HospAdmTime|ICULOS|SepsisLabel"
    body = [_record(hour) for hour in range(1, 9)]
    path.write_text("\n".join([header, *body]), encoding="utf-8")
def test_grud_experiment_writes_metrics_and_predictions(tmp_path: Path) -> None:
    """Run the experiment CLI on synthetic data and check the files it writes.

    Eight unshifted patients are written under ``training_setA`` and four
    shifted patients under ``training_setB`` (even indices positive). ``main``
    is then invoked with hospital ``A`` as the training hospital and test mode
    ``both``, and the test inspects ``metrics.json``, ``predictions.csv`` and
    ``training_history.csv`` in the output directory.
    """
    data_root = tmp_path / "training"
    training_a, training_b = (
        data_root / folder for folder in ("training_setA", "training_setB")
    )
    for directory in (training_a, training_b):
        directory.mkdir(parents=True)

    for index in range(8):
        target = training_a / f"pA{index:03d}.psv"
        _write_patient(target, positive=not index % 2, shifted=False)
    for index in range(4):
        target = training_b / f"pB{index:03d}.psv"
        _write_patient(target, positive=not index % 2, shifted=True)

    output_dir = tmp_path / "grud-output"
    cli_args = [
        "--data-root", str(data_root),
        "--train-hospital", "A",
        "--test-mode", "both",
        "--train-patients", "4",
        "--validation-patients", "2",
        "--test-patients", "2",
        "--lookback-hours", "3",
        "--horizon-hours", "2",
        "--epochs", "1",
        "--batch-size", "8",
        "--learning-rate", "0.001",
        "--output-dir", str(output_dir),
    ]
    main(cli_args)

    with (output_dir / "metrics.json").open(encoding="utf-8") as handle:
        metrics = json.load(handle)
    predictions = pd.read_csv(output_dir / "predictions.csv")
    history = pd.read_csv(output_dir / "training_history.csv")

    # Both experiments must be reported, each with its headline metric.
    experiments = metrics["experiments"]
    assert set(experiments) == {"A_to_A", "A_to_B"}
    assert "auroc" in experiments["A_to_A"]
    assert "auprc" in experiments["A_to_B"]

    # Training-summary fields expected on the A_to_A entry.
    summary_keys = (
        "train_positive_rate",
        "best_validation_metric",
        "best_epoch",
        "last_epoch",
        "best_train_loss",
        "best_validation_loss",
        "last_train_auroc",
        "last_validation_auprc",
    )
    for key in summary_keys:
        assert key in experiments["A_to_A"]

    prediction_columns = {"experiment", "patient_id", "sample_index", "probability", "label"}
    assert prediction_columns.issubset(predictions.columns)

    history_columns = {
        "experiment",
        "epoch",
        "train_loss",
        "train_auroc",
        "train_auprc",
        "validation_loss",
        "validation_auroc",
        "validation_auprc",
    }
    assert history_columns.issubset(history.columns)
    assert set(history["experiment"]) == {"A_to_A", "A_to_B"}
    assert history["epoch"].min() == 1
def test_positive_class_weight_is_greater_than_one_for_imbalanced_labels() -> None:
    """A label vector with four negatives and one positive must get a weight above 1."""
    imbalanced_labels = np.asarray([0, 0, 0, 0, 1], dtype=np.float32)
    assert _positive_class_weight(imbalanced_labels) > 1.0
def test_positive_class_weight_respects_max_cap() -> None:
    """With ``max_weight=3.0`` the weight for a 4:1 label vector must equal 3.0."""
    imbalanced_labels = np.asarray([0, 0, 0, 0, 1], dtype=np.float32)
    capped_weight = _positive_class_weight(imbalanced_labels, max_weight=3.0)
    assert capped_weight == 3.0
def test_select_validation_score_prefers_loss_when_configured() -> None:
    """Selecting on ``"loss"`` must report ``neg_loss`` with the negated loss value."""
    validation_metrics = {"auroc": 0.7, "auprc": 0.2}
    selected = _select_validation_score(
        metrics=validation_metrics,
        loss=0.4,
        selection_metric="loss",
    )
    reported_name = selected["metric_name"]
    reported_value = selected["metric_value"]
    assert reported_name == "neg_loss"
    assert reported_value == -0.4
def test_build_parser_defaults_match_run_config() -> None:
    """Parse only the required paths and check the defaults the parser fills in.

    ``epochs`` is compared against a ``GRUDRunConfig`` built from the same
    paths; the remaining defaults are compared against literal values.
    """
    parsed = build_parser().parse_args(
        ["--data-root", "/tmp/data", "--output-dir", "/tmp/output"]
    )
    reference = GRUDRunConfig(
        data_root=Path("/tmp/data"),
        output_dir=Path("/tmp/output"),
    )
    assert parsed.epochs == reference.epochs

    literal_defaults = (
        ("hidden_size", 24),
        ("patience", 3),
        ("weight_decay", 3e-4),
        ("rebalance_strategy", "none"),
        ("max_pos_weight", 10.0),
        ("selection_metric", "loss"),
    )
    for option, expected in literal_defaults:
        assert getattr(parsed, option) == expected
def test_split_paths_shuffles_and_stratifies_patients_for_grud(tmp_path: Path) -> None:
    """Check split sizes and that a different ``random_state`` changes the training split.

    Twelve unshifted patients are written for hospital A and six shifted
    patients for hospital B (even indices positive). Splits are requested with
    4 training, 4 validation and 2 test patients for seeds 0 and 1.

    NOTE(review): only sizes and seed-dependence are asserted here; class
    balance within each split is not checked despite the test name.
    """
    data_root = tmp_path / "training"
    training_a, training_b = (
        data_root / folder for folder in ("training_setA", "training_setB")
    )
    for directory in (training_a, training_b):
        directory.mkdir(parents=True)

    for index in range(12):
        target = training_a / f"pA{index:03d}.psv"
        _write_patient(target, positive=not index % 2, shifted=False)
    for index in range(6):
        target = training_b / f"pB{index:03d}.psv"
        _write_patient(target, positive=not index % 2, shifted=True)

    def _split_for(seed: int, out_name: str):
        # Identical configuration apart from the seed and the output folder.
        config = GRUDRunConfig(
            data_root=data_root,
            train_hospital="A",
            test_mode="both",
            train_patients=4,
            validation_patients=4,
            test_patients=2,
            random_state=seed,
            output_dir=tmp_path / out_name,
        )
        return _split_paths(config)

    split_zero = _split_for(0, "out-zero")
    split_one = _split_for(1, "out-one")

    aa_zero = split_zero["A_to_A"]
    aa_one = split_one["A_to_A"]
    ab_zero = split_zero["A_to_B"]

    # Positions 0, 1 and 2 are sized by train/validation/test patient counts.
    assert len(aa_zero[0]) == 4
    assert len(aa_zero[1]) == 4
    assert len(aa_zero[2]) == 2
    assert len(ab_zero[2]) == 2

    train_names_zero = [path.name for path in aa_zero[0]]
    train_names_one = [path.name for path in aa_one[0]]
    assert train_names_zero != train_names_one