lsnu committed on
Commit
a0b57b7
·
verified ·
1 Parent(s): 380eb78

Fix null-rollout world-model ablation

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __pycache__/
2
+ .pytest_cache/
3
+ .cache/
4
+ *.pyc
README.md CHANGED
@@ -116,6 +116,8 @@ Bundle uploaded from the `/workspace` runpod session dated `2026-03-25 UTC`.
116
 
117
  Full artifact roots are indexed in `MODEL_INDEX.md`.
118
 
 
 
119
  ## Raw Training Summaries
120
 
121
  | Run | Mean train time (s) | Mean peak GPU memory (MB) |
 
116
 
117
  Full artifact roots are indexed in `MODEL_INDEX.md`.
118
 
119
+ Note: the stored `stage2 dummy no_world_model` row above was produced before the `2026-03-25` null-rollout ablation fix in `ElasticRevealBimanualPolicy`. The raw artifact is retained unchanged, but it should be rerun before using it as a fair world-model comparison.
120
+
121
  ## Raw Training Summaries
122
 
123
  | Run | Mean train time (s) | Mean peak GPU memory (MB) |
code/reveal_vla_bimanual/models/policy.py CHANGED
@@ -485,6 +485,28 @@ class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
485
  tiled[key] = self._tile_tensor(value, num_candidates)
486
  return tiled
487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  def forward(
489
  self,
490
  images: Tensor,
@@ -570,6 +592,7 @@ class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
570
  "reveal_state": elastic_state,
571
  "view_summaries": scene_output["view_summaries"],
572
  "geometry_summaries": scene_output["geometry_summaries"],
 
573
  }
574
 
575
  candidate_chunks = candidate_chunks_override
@@ -607,34 +630,45 @@ class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
607
  batch_size = candidate_chunks.shape[0]
608
  batch_indices = torch.arange(batch_size, device=candidate_chunks.device).unsqueeze(-1)
609
  topk_candidates = candidate_chunks[batch_indices, shortlist_indices]
 
610
  outputs["planner_topk_candidates"] = topk_candidates
611
  if proposal_logits is not None:
612
  topk_proposal_logits = proposal_logits.gather(1, shortlist_indices)
613
  else:
614
  topk_proposal_logits = None
 
 
 
 
615
 
616
  if not use_world_model:
617
- score_source = topk_proposal_logits if topk_proposal_logits is not None else -topk_candidates.square().mean(dim=(-1, -2))
618
- best_local = score_source.argmax(dim=-1)
619
- best_indices = shortlist_indices[torch.arange(batch_size, device=best_local.device), best_local]
620
- outputs["planned_chunk"] = candidate_chunks[torch.arange(batch_size, device=best_local.device), best_indices]
621
- outputs["planned_rollout"] = {}
622
- outputs["planner_success_logits"] = torch.zeros_like(score_source)
623
- outputs["planner_risk_values"] = torch.zeros_like(score_source)
624
- outputs["planner_scores"] = score_source
625
- outputs["best_candidate_indices"] = best_indices
626
- outputs["utility_structured"] = score_source
627
- outputs["utility_residual"] = torch.zeros_like(score_source)
628
- outputs["utility_total"] = score_source
 
 
 
 
 
 
 
 
 
 
 
629
  return outputs
630
 
631
- num_topk = topk_candidates.shape[1]
632
  flat_chunks = topk_candidates.view(batch_size * num_topk, topk_candidates.shape[2], topk_candidates.shape[3])
633
  tiled_scene = self._tile_tensor(scene_tokens, num_topk)
634
- planning_state = elastic_state
635
- if not support_mode_conditioning:
636
- planning_state = dict(elastic_state)
637
- planning_state["support_mode_logits"] = torch.zeros_like(elastic_state["support_mode_logits"])
638
  tiled_state = self._tile_state(planning_state, num_topk)
639
  rollout = self.world_model(
640
  scene_tokens=tiled_scene,
@@ -664,4 +698,5 @@ class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
664
  outputs["utility_residual"] = selected["utility_residual"]
665
  outputs["utility_total"] = selected["utility_total"]
666
  outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
 
667
  return outputs
 
485
  tiled[key] = self._tile_tensor(value, num_candidates)
486
  return tiled
487
 
488
+ def _detach_state(self, state: dict[str, Tensor]) -> dict[str, Tensor]:
489
+ detached: dict[str, Tensor] = {}
490
+ for key, value in state.items():
491
+ detached[key] = value.detach() if isinstance(value, Tensor) else value
492
+ return detached
493
+
494
+ def _repeat_rollout_tensor(self, value: Tensor, num_candidates: int, horizon: int) -> Tensor:
495
+ value = value.detach()
496
+ return value.unsqueeze(1).unsqueeze(2).expand(-1, num_candidates, horizon, *value.shape[1:])
497
+
498
+ def _identity_rollout(
499
+ self,
500
+ interaction_state: dict[str, Tensor],
501
+ num_candidates: int,
502
+ ) -> dict[str, Tensor]:
503
+ horizon = max(1, self.config.world_model.rollout_horizon)
504
+ rollout: dict[str, Tensor] = {}
505
+ for key, value in interaction_state.items():
506
+ if isinstance(value, Tensor):
507
+ rollout[key] = self._repeat_rollout_tensor(value, num_candidates, horizon)
508
+ return rollout
509
+
510
  def forward(
511
  self,
512
  images: Tensor,
 
592
  "reveal_state": elastic_state,
593
  "view_summaries": scene_output["view_summaries"],
594
  "geometry_summaries": scene_output["geometry_summaries"],
595
+ "rollout_source": "none",
596
  }
597
 
598
  candidate_chunks = candidate_chunks_override
 
630
  batch_size = candidate_chunks.shape[0]
631
  batch_indices = torch.arange(batch_size, device=candidate_chunks.device).unsqueeze(-1)
632
  topk_candidates = candidate_chunks[batch_indices, shortlist_indices]
633
+ num_topk = topk_candidates.shape[1]
634
  outputs["planner_topk_candidates"] = topk_candidates
635
  if proposal_logits is not None:
636
  topk_proposal_logits = proposal_logits.gather(1, shortlist_indices)
637
  else:
638
  topk_proposal_logits = None
639
+ planning_state = elastic_state
640
+ if not support_mode_conditioning:
641
+ planning_state = dict(elastic_state)
642
+ planning_state["support_mode_logits"] = torch.zeros_like(elastic_state["support_mode_logits"])
643
 
644
  if not use_world_model:
645
+ detached_state = self._detach_state(planning_state)
646
+ identity_rollout = self._identity_rollout(
647
+ interaction_state=detached_state,
648
+ num_candidates=num_topk,
649
+ )
650
+ selected = self.planner.select_best(
651
+ initial_state=detached_state,
652
+ candidate_chunks=topk_candidates,
653
+ rollout_state=identity_rollout,
654
+ proposal_logits=topk_proposal_logits,
655
+ candidate_indices=shortlist_indices,
656
+ )
657
+ outputs["planned_rollout"] = identity_rollout
658
+ outputs["planned_chunk"] = selected["best_chunk"]
659
+ outputs["planner_success_logits"] = selected["success_logits"]
660
+ outputs["planner_risk_values"] = selected["risk_values"]
661
+ outputs["planner_scores"] = selected["utility_total"]
662
+ outputs["best_candidate_indices"] = selected["best_indices"]
663
+ outputs["utility_structured"] = selected["utility_structured"]
664
+ outputs["utility_residual"] = selected["utility_residual"]
665
+ outputs["utility_total"] = selected["utility_total"]
666
+ outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
667
+ outputs["rollout_source"] = "identity"
668
  return outputs
669
 
 
670
  flat_chunks = topk_candidates.view(batch_size * num_topk, topk_candidates.shape[2], topk_candidates.shape[3])
671
  tiled_scene = self._tile_tensor(scene_tokens, num_topk)
 
 
 
 
672
  tiled_state = self._tile_state(planning_state, num_topk)
673
  rollout = self.world_model(
674
  scene_tokens=tiled_scene,
 
698
  outputs["utility_residual"] = selected["utility_residual"]
699
  outputs["utility_total"] = selected["utility_total"]
700
  outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
701
+ outputs["rollout_source"] = "learned"
702
  return outputs
code/reveal_vla_bimanual/train/losses.py CHANGED
@@ -303,7 +303,7 @@ def compute_total_loss(
303
  + 0.01 * reveal_losses["uncertainty"]
304
  )
305
 
306
- if model_output.get("planned_rollout") and (
307
  "candidate_rollout_support_mode" in batch or "rollout_support_mode" in batch
308
  ):
309
  if "candidate_rollout_support_mode" in batch:
 
303
  + 0.01 * reveal_losses["uncertainty"]
304
  )
305
 
306
+ if model_output.get("planned_rollout") and model_output.get("rollout_source", "learned") == "learned" and (
307
  "candidate_rollout_support_mode" in batch or "rollout_support_mode" in batch
308
  ):
309
  if "candidate_rollout_support_mode" in batch:
results/phase_tracking.md CHANGED
@@ -83,10 +83,10 @@ Date closed: `2026-03-25 UTC`
83
  - `short_history`: `0.5463` mean success, delta `0.0000`
84
  - Gate decisions:
85
  - hard success gate `>= 0.60`: fail
86
- - `no_world_model` must hurt: fail, no success drop and no persuasive secondary metric degradation
87
  - full memory must stop losing to short history: hard gate passes narrowly because full equals short-history; preferred gate fails because full does not beat short-history
88
  - state metrics should improve over phase 1: fail, reocclusion rate increased (`0.0000 -> 0.0121`), persistence MAE worsened (`1.9553 -> 2.2358`), and calibration worsened
89
- - Takeaway: the expanded state/memory path did not validate on the dummy proxy benchmark. Planner classification improved, but task success and state quality did not.
90
 
91
  ## Phase 3
92
 
 
83
  - `short_history`: `0.5463` mean success, delta `0.0000`
84
  - Gate decisions:
85
  - hard success gate `>= 0.60`: fail
86
+ - `no_world_model` must hurt: not interpretable from the stored artifact alone; the recorded `no_world_model` run predates the `2026-03-25` null-rollout ablation fix and should be rerun for a fair comparison
87
  - full memory must stop losing to short history: hard gate passes narrowly because full equals short-history; preferred gate fails because full does not beat short-history
88
  - state metrics should improve over phase 1: fail, reocclusion rate increased (`0.0000 -> 0.0121`), persistence MAE worsened (`1.9553 -> 2.2358`), and calibration worsened
89
+ - Takeaway: the expanded state/memory path did not validate on the dummy proxy benchmark. Planner classification improved, but the world-model ablation needs a post-fix rerun before it can be interpreted fairly.
90
 
91
  ## Phase 3
92
 
tests/test_policy_topk_cascade.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from train.trainer import build_policy
2
 
3
 
@@ -23,3 +25,44 @@ def test_policy_topk_cascade(tiny_policy_config, tiny_trainer_config, tiny_batch
23
  assert output["planner_topk_indices"].shape[1] == config.planner.top_k
24
  assert output["planned_rollout"]["target_belief_field"].shape[1] == config.planner.top_k
25
  assert (output["best_candidate_indices"] < config.decoder.num_candidates).all()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
  from train.trainer import build_policy
4
 
5
 
 
25
  assert output["planner_topk_indices"].shape[1] == config.planner.top_k
26
  assert output["planned_rollout"]["target_belief_field"].shape[1] == config.planner.top_k
27
  assert (output["best_candidate_indices"] < config.decoder.num_candidates).all()
28
+
29
+
30
def test_policy_null_rollout_ablation_keeps_planner_interface(
    tiny_policy_config,
    tiny_trainer_config,
    tiny_batch,
):
    """Disabling the world model must fall back to an identity rollout while
    keeping the full planner output contract intact."""
    config = tiny_policy_config(num_candidates=4, top_k=2)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))

    # Forward exactly the batch tensors the policy signature expects.
    forward_kwargs = {
        key: batch[key]
        for key in (
            "images",
            "depths",
            "depth_valid",
            "camera_intrinsics",
            "camera_extrinsics",
            "proprio",
            "texts",
            "history_images",
            "history_depths",
            "history_depth_valid",
            "history_proprio",
            "history_actions",
        )
    }
    output = policy(
        plan=True,
        use_world_model=False,
        use_planner=True,
        **forward_kwargs,
    )

    rollout = output["planned_rollout"]
    current_state = output["interaction_state"]

    assert output["rollout_source"] == "identity"
    assert output["planner_topk_indices"].shape[1] == config.planner.top_k
    assert rollout["target_belief_field"].shape[1] == config.planner.top_k

    # The null rollout must be the detached current state tiled over the
    # candidate and horizon axes.
    repeated_belief = (
        current_state["target_belief_field"]
        .detach()
        .unsqueeze(1)
        .unsqueeze(2)
        .expand_as(rollout["target_belief_field"])
    )
    repeated_phase = (
        current_state["phase_logits"]
        .detach()
        .unsqueeze(1)
        .unsqueeze(2)
        .expand_as(rollout["phase_logits"])
    )
    assert output["utility_total"].shape == (batch["images"].shape[0], config.planner.top_k)
    assert torch.allclose(rollout["target_belief_field"], repeated_belief)
    assert torch.allclose(rollout["phase_logits"], repeated_phase)