File size: 1,264 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

from train.losses import LossWeights, compute_total_loss


def test_planner_gradient_flow():
    """Backpropagating the total loss must populate gradients on the planner outputs.

    Uses one sample with three candidates. The action weight is zeroed
    (``action=0.0``) and the planner weights are non-zero. The test calls
    ``backward()`` on ``losses["total"]``, then checks two things: every
    planner leaf tensor received a ``.grad``, and the ranking scores received
    a non-zero one.
    """
    # Leaf tensors standing in for the planner heads (shape (1, 3): one sample, three candidates).
    scores = torch.tensor([[0.2, -0.1, 0.0]], requires_grad=True)
    logits = torch.zeros(1, 3, requires_grad=True)
    risks = torch.full((1, 3), 0.2, requires_grad=True)

    # The action head is a plain constant (no requires_grad), so it takes no part in the gradient checks below.
    dummy_actions = torch.zeros(1, 1, 14)

    outputs = {
        "action_mean": dummy_actions,
        "planner_scores": scores,
        "planner_success_logits": logits,
        "planner_risk_values": risks,
    }

    # Per-candidate supervision; presumably candidate 0 is the best one
    # (success=1, zero cost, highest utility) — inferred from the values, TODO confirm.
    targets = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.0, 0.0]]),
        "candidate_final_disturbance_cost": torch.tensor([[0.0, 0.2, 0.4]]),
        "candidate_reocclusion_rate": torch.tensor([[0.0, 0.1, 0.2]]),
        "candidate_utility": torch.tensor([[1.0, 0.1, -0.6]]),
    }

    # action=0.0 presumably switches off the action-regression term so only planner terms drive the gradients.
    loss_weights = LossWeights(
        action=0.0,
        planner_success=0.2,
        planner_risk=0.1,
        planner_ranking=1.0,
    )

    loss_dict = compute_total_loss(outputs, targets, weights=loss_weights)
    loss_dict["total"].backward()

    for leaf in (scores, logits, risks):
        assert leaf.grad is not None
    assert scores.grad.abs().sum().item() > 0.0