import torch
from train.losses import LossWeights, compute_total_loss
def test_planner_gradient_flow():
    """Check that the total loss back-propagates into every planner head.

    Builds a minimal single-sample batch with three candidate plans, computes
    the weighted total loss (action weight set to 0.0; the three planner
    terms carry non-zero weights), calls backward(), and asserts that .grad
    is populated on the score, success-logit, and risk-value tensors.
    """
    # Leaf tensors standing in for the three planner prediction heads.
    scores = torch.tensor([[0.2, -0.1, 0.0]], requires_grad=True)
    success = torch.tensor([[0.0, 0.0, 0.0]], requires_grad=True)
    risk = torch.tensor([[0.2, 0.2, 0.2]], requires_grad=True)

    model_output = {
        "action_mean": torch.zeros(1, 1, 14),
        "planner_scores": scores,
        "planner_success_logits": success,
        "planner_risk_values": risk,
    }
    # Per-candidate supervision targets for one sample with three candidates.
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.0, 0.0]]),
        "candidate_final_disturbance_cost": torch.tensor([[0.0, 0.2, 0.4]]),
        "candidate_reocclusion_rate": torch.tensor([[0.0, 0.1, 0.2]]),
        "candidate_utility": torch.tensor([[1.0, 0.1, -0.6]]),
    }

    losses = compute_total_loss(
        model_output,
        batch,
        weights=LossWeights(
            action=0.0,
            planner_success=0.2,
            planner_risk=0.1,
            planner_ranking=1.0,
        ),
    )
    losses["total"].backward()

    # Every planner head must have received a gradient.
    for head in (scores, success, risk):
        assert head.grad is not None
    # The ranking term must actually move the scores, not just touch them.
    assert float(scores.grad.abs().sum()) > 0.0