import torch

from train.losses import LossWeights, compute_total_loss


def test_planner_gradient_flow():
    """Check that the planner loss terms backpropagate into the planner outputs.

    Builds a one-sample, three-candidate model output whose planner tensors
    (scores, success logits, risk values) are leaf tensors created with
    ``requires_grad=True``. It computes the total loss with the action term
    weighted to zero, calls ``backward()``, and asserts two things:

    * every planner tensor received a finite gradient;
    * the planner scores received a non-zero one.
    """
    planner_scores = torch.tensor([[0.2, -0.1, 0.0]], requires_grad=True)
    success_logits = torch.tensor([[0.0, 0.0, 0.0]], requires_grad=True)
    risk_values = torch.tensor([[0.2, 0.2, 0.2]], requires_grad=True)

    model_output = {
        # Plain zeros without requires_grad: no gradient is tracked for it.
        "action_mean": torch.zeros(1, 1, 14),
        "planner_scores": planner_scores,
        "planner_success_logits": success_logits,
        "planner_risk_values": risk_values,
    }
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.0, 0.0]]),
        "candidate_final_disturbance_cost": torch.tensor([[0.0, 0.2, 0.4]]),
        "candidate_reocclusion_rate": torch.tensor([[0.0, 0.1, 0.2]]),
        "candidate_utility": torch.tensor([[1.0, 0.1, -0.6]]),
    }
    # action=0.0 -- presumably so only the planner terms drive the total;
    # confirm against compute_total_loss if that assumption matters.
    weights = LossWeights(
        action=0.0,
        planner_success=0.2,
        planner_risk=0.1,
        planner_ranking=1.0,
    )

    losses = compute_total_loss(model_output, batch, weights=weights)
    total = losses["total"]
    # An inf/NaN loss would still let a ``> 0`` gradient check pass (inf > 0),
    # so require finiteness explicitly.
    assert torch.isfinite(total).all(), "total loss is not finite"

    total.backward()

    for name, leaf in (
        ("planner_scores", planner_scores),
        ("planner_success_logits", success_logits),
        ("planner_risk_values", risk_values),
    ):
        assert leaf.grad is not None, f"no gradient reached {name}"
        assert torch.isfinite(leaf.grad).all(), f"non-finite gradient in {name}"

    assert float(planner_scores.grad.abs().sum()) > 0.0, (
        "planner_scores gradient is identically zero"
    )