File size: 1,264 Bytes
e7d8e79 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | import torch
from train.losses import LossWeights, compute_total_loss
def test_planner_gradient_flow():
    """Check that ``compute_total_loss`` back-propagates into every planner output.

    Builds a single-sample model output whose planner tensors are leaf tensors
    created with ``requires_grad=True``, computes the weighted total loss, and
    calls ``backward()``. Every planner leaf must receive a finite gradient,
    and the ranking scores must receive a non-zero one.
    """
    # Leaf tensors so their ``.grad`` is populated by backward().
    # NOTE(review): presumably shaped (batch, num_candidates) — confirm
    # against compute_total_loss.
    planner_scores = torch.tensor([[0.2, -0.1, 0.0]], requires_grad=True)
    success_logits = torch.tensor([[0.0, 0.0, 0.0]], requires_grad=True)
    risk_values = torch.tensor([[0.2, 0.2, 0.2]], requires_grad=True)

    model_output = {
        "action_mean": torch.zeros(1, 1, 14),
        "planner_scores": planner_scores,
        "planner_success_logits": success_logits,
        "planner_risk_values": risk_values,
    }
    # Per-candidate targets; presumably aligned index-for-index with the
    # planner tensors above — verify against the loss implementation.
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.0, 0.0]]),
        "candidate_final_disturbance_cost": torch.tensor([[0.0, 0.2, 0.4]]),
        "candidate_reocclusion_rate": torch.tensor([[0.0, 0.1, 0.2]]),
        "candidate_utility": torch.tensor([[1.0, 0.1, -0.6]]),
    }
    # action=0.0 presumably removes the action term, so only the planner
    # terms contribute to the total.
    weights = LossWeights(
        action=0.0,
        planner_success=0.2,
        planner_risk=0.1,
        planner_ranking=1.0,
    )

    losses = compute_total_loss(model_output, batch, weights=weights)
    total = losses["total"]
    # A NaN/Inf total would silently poison every gradient below.
    assert torch.isfinite(total).all(), "total loss is not finite"
    total.backward()

    # `grad is not None` alone accepts NaN/Inf gradients, so also require
    # every gradient that reaches a planner output to be finite.
    planner_leaves = (
        ("planner_scores", planner_scores),
        ("planner_success_logits", success_logits),
        ("planner_risk_values", risk_values),
    )
    for name, leaf in planner_leaves:
        assert leaf.grad is not None, f"no gradient reached {name}"
        assert torch.isfinite(leaf.grad).all(), f"non-finite gradient in {name}"

    # The ranking scores must actually move, not merely be connected.
    assert float(planner_scores.grad.abs().sum()) > 0.0
|