| from __future__ import annotations | |
| import torch | |
| from sepsis_mcp.grud_model import GRUDModel | |
def test_grud_forward_returns_batch_probabilities() -> None:
    """A forward pass over a two-sample batch yields one value in [0, 1] per sample.

    The first sample has its second feature masked out (mask 0, delta 1) at the
    middle time step; the second sample has every entry masked in.
    """
    grud = GRUDModel(input_size=2, static_size=5, hidden_size=8)

    # Raw nested data: (batch=2, time=3, features=2).
    observed = [
        [[80.0, 95.0], [82.0, 0.0], [84.0, 93.0]],
        [[70.0, 97.0], [71.0, 96.0], [72.0, 95.0]],
    ]
    observed_mask = [
        [[1.0, 1.0], [1.0, 0.0], [1.0, 1.0]],
        [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
    ]
    time_gaps = [
        [[0.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
    ]

    values = torch.tensor(observed, dtype=torch.float32)
    masks = torch.tensor(observed_mask, dtype=torch.float32)
    deltas = torch.tensor(time_gaps, dtype=torch.float32)
    static = torch.ones((2, 5), dtype=torch.float32)

    out = grud(values, masks, deltas, static)

    assert out.shape == (2,)
    # NaN compares False on both sides, so these also reject non-finite output.
    assert out.ge(0.0).all()
    assert out.le(1.0).all()


def test_grud_forward_logits_matches_probability_output() -> None:
    """Check that ``model(...)`` equals ``sigmoid(model.forward_logits(...))``.

    The two outputs come from two separate forward passes. Both passes therefore
    run in eval mode under ``torch.no_grad()``. Any stochastic layers the model
    may contain (e.g. dropout -- not visible from this file) then cannot make the
    passes disagree. The seed makes the random inputs reproducible.
    """
    torch.manual_seed(0)
    batch_size, seq_len, input_size, static_size = 2, 3, 2, 5
    model = GRUDModel(
        input_size=input_size,
        static_size=static_size,
        hidden_size=8,
    )
    # Both passes must perform identical computation for the comparison below.
    model.eval()

    values = torch.rand((batch_size, seq_len, input_size), dtype=torch.float32)
    masks = torch.ones((batch_size, seq_len, input_size), dtype=torch.float32)
    deltas = torch.zeros((batch_size, seq_len, input_size), dtype=torch.float32)
    static = torch.rand((batch_size, static_size), dtype=torch.float32)

    with torch.no_grad():
        logits = model.forward_logits(values, masks, deltas, static)
        probabilities = model(values, masks, deltas, static)

    assert logits.shape == (batch_size,)
    # allclose broadcasts, so pin the probability shape explicitly as well.
    assert probabilities.shape == (batch_size,)
    assert torch.allclose(torch.sigmoid(logits), probabilities)


def test_grud_supports_single_optimizer_step() -> None:
    """Run one Adam step on a small random batch and check it is well-behaved.

    Verifies that:
    - the BCE loss on the model's probabilities is finite;
    - backpropagation yields finite gradients;
    - the optimizer step updates at least one parameter and leaves every
      parameter finite.
    """
    torch.manual_seed(0)
    model = GRUDModel(
        input_size=2,
        static_size=5,
        hidden_size=8,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.BCELoss()

    values = torch.rand((4, 3, 2), dtype=torch.float32)
    masks = torch.randint(0, 2, (4, 3, 2), dtype=torch.float32)
    deltas = torch.rand((4, 3, 2), dtype=torch.float32)
    static = torch.rand((4, 5), dtype=torch.float32)
    labels = torch.tensor([0.0, 1.0, 0.0, 1.0], dtype=torch.float32)

    # Snapshot parameters so we can confirm the step actually changed something.
    params_before = [p.detach().clone() for p in model.parameters()]

    optimizer.zero_grad()
    probabilities = model(values, masks, deltas, static)
    loss = criterion(probabilities, labels)
    loss.backward()

    assert torch.isfinite(loss)
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    assert grads, "backward() produced no gradients"
    assert all(torch.isfinite(g).all() for g in grads)

    optimizer.step()

    params_after = [p.detach() for p in model.parameters()]
    assert all(torch.isfinite(p).all() for p in params_after)
    # Parameters whose gradient is exactly zero are left unchanged by Adam, so
    # require only that *some* parameter moved.
    assert any(
        not torch.equal(before, after)
        for before, after in zip(params_before, params_after)
    )