import torch

from train.losses import LossWeights, compute_total_loss


def _model_output(planner_scores: torch.Tensor) -> dict[str, torch.Tensor]:
    """Build a minimal planner-head model output around the given candidate scores."""
    tensor_kwargs = {"dtype": planner_scores.dtype, "device": planner_scores.device}
    filler = torch.zeros(planner_scores.shape, **tensor_kwargs)
    return {
        "action_mean": torch.zeros(planner_scores.shape[0], 1, 14, **tensor_kwargs),
        "planner_scores": planner_scores,
        "planner_success_logits": filler,
        "planner_risk_values": filler,
    }


def test_candidate_ranking_loss_prefers_oracle_order():
    """Scores sorted like the oracle utility must yield a smaller ranking loss."""
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.5, 0.0]]),
        "candidate_final_disturbance_cost": torch.tensor([[0.0, 0.1, 0.2]]),
        "candidate_reocclusion_rate": torch.tensor([[0.0, 0.1, 0.3]]),
        "candidate_utility": torch.tensor([[1.0, 0.4, -0.5]]),
    }
    weights = LossWeights(action=0.0, planner_success=0.0, planner_risk=0.0, planner_ranking=1.0)

    oracle_order = compute_total_loss(
        _model_output(torch.tensor([[2.0, 0.5, -1.0]])), batch, weights=weights
    )
    inverted_order = compute_total_loss(
        _model_output(torch.tensor([[-1.0, 0.5, 2.0]])), batch, weights=weights
    )

    assert float(oracle_order["planner_ranking"]) < float(inverted_order["planner_ranking"])


def _proposal_model_output(proposal_logits: torch.Tensor) -> dict[str, torch.Tensor]:
    """Build a minimal proposal-head model output around the given proposal logits."""
    batch_size = proposal_logits.shape[0]
    num_proposals = proposal_logits.shape[1]
    tensor_kwargs = {"dtype": proposal_logits.dtype, "device": proposal_logits.device}
    return {
        "action_mean": torch.zeros(batch_size, 1, 14, **tensor_kwargs),
        "proposal_logits": proposal_logits,
        "proposal_mode_logits": torch.zeros(batch_size, 2, **tensor_kwargs),
        # Alternating 0/1 mode labels, one per proposal slot.
        "proposal_mode_assignments": torch.arange(num_proposals, device=proposal_logits.device) % 2,
        "proposal_candidates": torch.zeros(batch_size, num_proposals, 1, 14, **tensor_kwargs),
    }


def test_proposal_ranking_uses_aligned_targets_when_present():
    """When proposal_target_* keys exist, ranking supervision follows them, not candidate_*."""
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_action_chunks": torch.zeros(1, 3, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.0, 0.0]]),
        "candidate_utility": torch.tensor([[1.0, -0.2, -0.4]]),
        "proposal_target_action_chunks": torch.zeros(1, 3, 1, 14),
        "proposal_target_retrieval_success": torch.tensor([[0.0, 1.0, 0.0]]),
        "proposal_target_utility": torch.tensor([[-0.3, 1.0, -0.1]]),
        "proposal_target_risk": torch.tensor([[0.8, 0.0, 0.2]]),
    }
    weights = LossWeights(
        action=0.0, proposal_reconstruction=0.0, proposal_success=0.0, proposal_ranking=1.0
    )

    # Logits peaked on slot 1 match the proposal_target_utility ordering.
    matching = compute_total_loss(
        _proposal_model_output(torch.tensor([[0.0, 2.0, -1.0]])), batch, weights=weights
    )
    mismatched = compute_total_loss(
        _proposal_model_output(torch.tensor([[2.0, -1.0, 0.0]])), batch, weights=weights
    )

    assert float(matching["proposal_ranking"]) < float(mismatched["proposal_ranking"])


def test_proposal_reconstruction_uses_order_invariant_teacher_family():
    """Reconstruction should be ~0 when proposals equal the teacher set up to permutation."""
    proposal_candidates = torch.zeros(1, 2, 1, 14)
    proposal_candidates[0, 0, 0, 0] = 1.0
    proposal_candidates[0, 1, 0, 1] = 1.0

    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        # Same family as the proposals, but in reversed slot order.
        "candidate_action_chunks": proposal_candidates.flip(1).clone(),
        "proposal_target_action_chunks": torch.full_like(proposal_candidates, 5.0),
    }

    model_output = _proposal_model_output(torch.zeros(1, 2))
    model_output["proposal_candidates"] = proposal_candidates

    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=1.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_diversity=0.0,
    )
    losses = compute_total_loss(model_output, batch, weights=weights)
    assert float(losses["proposal_reconstruction"]) < 1e-6


def test_proposal_reconstruction_prefers_high_utility_teacher_subset():
    """Only the high-utility teacher candidates should anchor reconstruction."""
    proposal_candidates = torch.zeros(1, 4, 1, 14)
    for slot, hot_dim in enumerate((0, 1, 0, 1)):
        proposal_candidates[0, slot, 0, hot_dim] = 1.0

    teacher_candidates = torch.zeros(1, 4, 1, 14)
    # Slots 0/1 mirror the proposals (swapped); slots 2/3 are far away but low utility.
    teacher_candidates[0, 0, 0, 1] = 1.0
    teacher_candidates[0, 1, 0, 0] = 1.0
    teacher_candidates[0, 2, 0, 2] = 5.0
    teacher_candidates[0, 3, 0, 3] = 5.0

    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_action_chunks": teacher_candidates,
        "candidate_utility": torch.tensor([[1.0, 0.9, -1.0, -2.0]]),
    }

    model_output = _proposal_model_output(torch.zeros(1, 4))
    model_output["proposal_candidates"] = proposal_candidates

    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=1.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_diversity=0.0,
    )
    losses = compute_total_loss(model_output, batch, weights=weights)
    assert float(losses["proposal_reconstruction"]) < 1e-6


def test_bag_proposal_reconstruction_anchors_to_fallback_targets():
    """For the bag task, reconstruction anchors to proposal_target_action_chunks."""
    proposal_candidates = torch.zeros(1, 4, 1, 14)
    for slot, hot_dim in enumerate((0, 1, 0, 1)):
        proposal_candidates[0, slot, 0, hot_dim] = 1.0

    # Teacher candidates are all far from the proposals.
    teacher_candidates = torch.zeros(1, 4, 1, 14)
    for slot in range(4):
        teacher_candidates[0, slot, 0, slot + 2] = 5.0

    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "task_name": ["bag"],
        "candidate_action_chunks": teacher_candidates,
        "candidate_utility": torch.tensor([[1.0, 0.9, -1.0, -2.0]]),
        # Fallback targets coincide with the proposals, so the loss should vanish.
        "proposal_target_action_chunks": proposal_candidates.clone(),
    }

    model_output = _proposal_model_output(torch.zeros(1, 4))
    model_output["proposal_candidates"] = proposal_candidates

    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=1.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_diversity=0.0,
    )
    losses = compute_total_loss(model_output, batch, weights=weights)
    assert float(losses["proposal_reconstruction"]) < 1e-6


def test_proposal_mode_loss_prefers_mode_with_highest_utility():
    """Mode logits favoring the high-utility mode should score a lower mode loss."""
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "proposal_target_retrieval_success": torch.tensor([[1.0, 0.0, 1.0, 0.0]]),
        "proposal_target_utility": torch.tensor([[0.9, -0.4, 0.8, -0.3]]),
        "proposal_target_risk": torch.tensor([[0.0, 0.8, 0.1, 0.7]]),
    }
    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=0.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_mode=1.0,
        proposal_diversity=0.0,
    )

    favors_mode0 = _proposal_model_output(torch.zeros(1, 4))
    favors_mode0["proposal_mode_logits"] = torch.tensor([[2.0, -1.0]])
    favors_mode0["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)

    favors_mode1 = _proposal_model_output(torch.zeros(1, 4))
    favors_mode1["proposal_mode_logits"] = torch.tensor([[-1.0, 2.0]])
    favors_mode1["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)

    favored_losses = compute_total_loss(favors_mode0, batch, weights=weights)
    disfavored_losses = compute_total_loss(favors_mode1, batch, weights=weights)
    assert float(favored_losses["proposal_mode"]) < float(disfavored_losses["proposal_mode"])


def test_proposal_mode_loss_can_focus_on_cloth_only():
    """With cloth-only supervision, non-cloth samples must not affect the mode loss."""
    batch = {
        "action_chunk": torch.zeros(2, 1, 14),
        "task_name": ["cloth", "foliage"],
        "proposal_target_retrieval_success": torch.tensor(
            [[1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]]
        ),
        "proposal_target_utility": torch.tensor(
            [[0.9, -0.4, 0.8, -0.3], [0.9, -0.4, 0.8, -0.3]]
        ),
        "proposal_target_risk": torch.tensor(
            [[0.0, 0.8, 0.1, 0.7], [0.0, 0.8, 0.1, 0.7]]
        ),
    }
    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=0.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_mode=1.0,
        proposal_mode_cloth_only=True,
        proposal_diversity=0.0,
    )

    both_aligned = _proposal_model_output(torch.zeros(2, 4))
    both_aligned["proposal_mode_logits"] = torch.tensor([[2.0, -1.0], [2.0, -1.0]])
    both_aligned["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)

    # Flip only the foliage sample's mode logits; cloth stays aligned.
    foliage_flipped = _proposal_model_output(torch.zeros(2, 4))
    foliage_flipped["proposal_mode_logits"] = torch.tensor([[2.0, -1.0], [-1.0, 2.0]])
    foliage_flipped["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)

    aligned_losses = compute_total_loss(both_aligned, batch, weights=weights)
    flipped_losses = compute_total_loss(foliage_flipped, batch, weights=weights)
    assert torch.isclose(aligned_losses["proposal_mode"], flipped_losses["proposal_mode"])


def test_proposal_mode_loss_can_focus_on_selected_tasks():
    """A task filter restricts mode supervision to the listed tasks only."""
    batch = {
        "action_chunk": torch.zeros(2, 1, 14),
        "task_name": ["bag", "foliage"],
        "proposal_target_retrieval_success": torch.tensor(
            [[1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]]
        ),
        "proposal_target_utility": torch.tensor(
            [[0.9, -0.4, 0.8, -0.3], [0.9, -0.4, 0.8, -0.3]]
        ),
        "proposal_target_risk": torch.tensor(
            [[0.0, 0.8, 0.1, 0.7], [0.0, 0.8, 0.1, 0.7]]
        ),
    }
    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=0.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_mode=1.0,
        proposal_mode_task_filter=["bag"],
        proposal_diversity=0.0,
    )

    both_aligned = _proposal_model_output(torch.zeros(2, 4))
    both_aligned["proposal_mode_logits"] = torch.tensor([[2.0, -1.0], [2.0, -1.0]])
    both_aligned["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)

    # Flip only the (filtered-out) foliage sample; the loss must be unchanged.
    foliage_flipped = _proposal_model_output(torch.zeros(2, 4))
    foliage_flipped["proposal_mode_logits"] = torch.tensor([[2.0, -1.0], [-1.0, 2.0]])
    foliage_flipped["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)

    aligned_losses = compute_total_loss(both_aligned, batch, weights=weights)
    flipped_losses = compute_total_loss(foliage_flipped, batch, weights=weights)
    assert torch.isclose(aligned_losses["proposal_mode"], flipped_losses["proposal_mode"])