from __future__ import annotations import torch from sepsis_mcp.grud_model import GRUDModel def test_grud_forward_returns_batch_probabilities() -> None: model = GRUDModel( input_size=2, static_size=5, hidden_size=8, ) values = torch.tensor( [ [[80.0, 95.0], [82.0, 0.0], [84.0, 93.0]], [[70.0, 97.0], [71.0, 96.0], [72.0, 95.0]], ], dtype=torch.float32, ) masks = torch.tensor( [ [[1.0, 1.0], [1.0, 0.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], ], dtype=torch.float32, ) deltas = torch.tensor( [ [[0.0, 0.0], [0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], ], dtype=torch.float32, ) static = torch.ones((2, 5), dtype=torch.float32) probabilities = model(values, masks, deltas, static) assert probabilities.shape == (2,) assert torch.all(probabilities >= 0.0) assert torch.all(probabilities <= 1.0) def test_grud_forward_logits_matches_probability_output() -> None: model = GRUDModel( input_size=2, static_size=5, hidden_size=8, ) values = torch.rand((2, 3, 2), dtype=torch.float32) masks = torch.ones((2, 3, 2), dtype=torch.float32) deltas = torch.zeros((2, 3, 2), dtype=torch.float32) static = torch.rand((2, 5), dtype=torch.float32) logits = model.forward_logits(values, masks, deltas, static) probabilities = model(values, masks, deltas, static) assert logits.shape == (2,) assert torch.allclose(torch.sigmoid(logits), probabilities) def test_grud_supports_single_optimizer_step() -> None: torch.manual_seed(0) model = GRUDModel( input_size=2, static_size=5, hidden_size=8, ) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = torch.nn.BCELoss() values = torch.rand((4, 3, 2), dtype=torch.float32) masks = torch.randint(0, 2, (4, 3, 2), dtype=torch.float32) deltas = torch.rand((4, 3, 2), dtype=torch.float32) static = torch.rand((4, 5), dtype=torch.float32) labels = torch.tensor([0.0, 1.0, 0.0, 1.0], dtype=torch.float32) optimizer.zero_grad() probabilities = model(values, masks, deltas, static) loss = criterion(probabilities, labels) loss.backward() optimizer.step() assert torch.isfinite(loss)