# 3d_model/tests/test_model_and_losses_torch_optional.py
# Author: Azan
# Commit: Clean deployment build (Squashed)
# Revision: 7a87926
import pytest
torch = pytest.importorskip("torch")
def test_student_model_forward_shapes():
    """Check the output shapes of a MetricDepthWithUncertainty forward pass.

    Builds the model with ``temporal_window=5``, feeds it a random tensor of
    shape ``(2, 5, 3, 32, 32)`` and asserts that both ``depth`` and
    ``log_sigma`` on the returned object have shape ``(2, 32, 32)``.
    """
    # Imported lazily so collecting this module does not require the model package.
    from ylff.models.metric_depth_with_uncertainty import MetricDepthWithUncertainty

    # NOTE(review): axis names below are an assumption (batch, frames,
    # channels, height, width) -- confirm against the model's documented input.
    batch, window, channels, height, width = 2, 5, 3, 32, 32

    net = MetricDepthWithUncertainty(temporal_window=window)
    frames = torch.randn(batch, window, channels, height, width)
    result = net(frames)

    expected_shape = (batch, height, width)
    assert result.depth.shape == expected_shape
    assert result.log_sigma.shape == expected_shape
def test_losses_finite():
    """Check that ``compute_losses`` yields a finite ``total`` on constant inputs.

    Uses ``(1, 8, 8)`` tensors: prediction filled with 2.0, log-sigma filled
    with -2.0 and ground truth filled with 2.1.
    """
    # Imported lazily so collecting this module does not require the losses package.
    from ylff.services.training.losses import compute_losses

    shape = (1, 8, 8)
    predicted = torch.full(shape, 2.0)
    log_sigma_map = torch.full(shape, -2.0)
    target = torch.full(shape, 2.1)

    result = compute_losses(predicted, log_sigma_map, target)

    assert torch.isfinite(result.total)