| import pytest | |
| torch = pytest.importorskip("torch") | |
def test_student_model_forward_shapes():
    """Check the output map shapes of ``MetricDepthWithUncertainty``.

    A 5-D input of shape (2, 5, 3, 32, 32) is fed to a model built with
    ``temporal_window=5``; both ``depth`` and ``log_sigma`` on the returned
    object must come out with shape (2, 32, 32).
    """
    from ylff.models.metric_depth_with_uncertainty import MetricDepthWithUncertainty

    # NOTE(review): layout assumed to be (batch, time, channels, H, W); the
    # second axis is sized to match temporal_window — confirm against model.
    batch, window, channels, height, width = 2, 5, 3, 32, 32

    # Build the model before sampling the input so the global RNG is consumed
    # in the same order as before (weight init first, then torch.randn).
    net = MetricDepthWithUncertainty(temporal_window=window)
    frames = torch.randn(batch, window, channels, height, width)

    result = net(frames)

    expected = (batch, height, width)
    assert result.depth.shape == expected
    assert result.log_sigma.shape == expected
def test_losses_finite():
    """Check that ``compute_losses`` returns a finite ``total`` for constant maps.

    Inputs are (1, 8, 8) tensors: predicted depth 2.0, ground-truth depth 2.1,
    and log_sigma -2.0 everywhere.
    """
    from ylff.services.training.losses import compute_losses

    shape = (1, 8, 8)
    depth_pred = torch.full(shape, 2.0)
    log_sigma = torch.full(shape, -2.0)
    depth_gt = torch.full(shape, 2.1)

    losses = compute_losses(depth_pred, log_sigma, depth_gt)

    # as_tensor accepts a Python float as well as a tensor, and .all() keeps
    # the assertion valid even if ``total`` has more than one element
    # (bool() on a multi-element tensor raises instead of checking).
    assert torch.isfinite(torch.as_tensor(losses.total)).all()