Fix null-rollout world-model ablation
Browse files- .gitignore +4 -0
- README.md +2 -0
- code/reveal_vla_bimanual/models/policy.py +52 -17
- code/reveal_vla_bimanual/train/losses.py +1 -1
- results/phase_tracking.md +2 -2
- tests/test_policy_topk_cascade.py +43 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
.pytest_cache/
|
| 3 |
+
.cache/
|
| 4 |
+
*.pyc
|
README.md
CHANGED
|
@@ -116,6 +116,8 @@ Bundle uploaded from the `/workspace` runpod session dated `2026-03-25 UTC`.
|
|
| 116 |
|
| 117 |
Full artifact roots are indexed in `MODEL_INDEX.md`.
|
| 118 |
|
|
|
|
|
|
|
| 119 |
## Raw Training Summaries
|
| 120 |
|
| 121 |
| Run | Mean train time (s) | Mean peak GPU memory (MB) |
|
|
|
|
| 116 |
|
| 117 |
Full artifact roots are indexed in `MODEL_INDEX.md`.
|
| 118 |
|
| 119 |
+
Note: the stored `stage2 dummy no_world_model` row above was produced before the `2026-03-25` null-rollout ablation fix in `ElasticRevealBimanualPolicy`. The raw artifact is retained unchanged, but it should be rerun before using it as a fair world-model comparison.
|
| 120 |
+
|
| 121 |
## Raw Training Summaries
|
| 122 |
|
| 123 |
| Run | Mean train time (s) | Mean peak GPU memory (MB) |
|
code/reveal_vla_bimanual/models/policy.py
CHANGED
|
@@ -485,6 +485,28 @@ class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
|
|
| 485 |
tiled[key] = self._tile_tensor(value, num_candidates)
|
| 486 |
return tiled
|
| 487 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
def forward(
|
| 489 |
self,
|
| 490 |
images: Tensor,
|
|
@@ -570,6 +592,7 @@ class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
|
|
| 570 |
"reveal_state": elastic_state,
|
| 571 |
"view_summaries": scene_output["view_summaries"],
|
| 572 |
"geometry_summaries": scene_output["geometry_summaries"],
|
|
|
|
| 573 |
}
|
| 574 |
|
| 575 |
candidate_chunks = candidate_chunks_override
|
|
@@ -607,34 +630,45 @@ class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
|
|
| 607 |
batch_size = candidate_chunks.shape[0]
|
| 608 |
batch_indices = torch.arange(batch_size, device=candidate_chunks.device).unsqueeze(-1)
|
| 609 |
topk_candidates = candidate_chunks[batch_indices, shortlist_indices]
|
|
|
|
| 610 |
outputs["planner_topk_candidates"] = topk_candidates
|
| 611 |
if proposal_logits is not None:
|
| 612 |
topk_proposal_logits = proposal_logits.gather(1, shortlist_indices)
|
| 613 |
else:
|
| 614 |
topk_proposal_logits = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
|
| 616 |
if not use_world_model:
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 629 |
return outputs
|
| 630 |
|
| 631 |
-
num_topk = topk_candidates.shape[1]
|
| 632 |
flat_chunks = topk_candidates.view(batch_size * num_topk, topk_candidates.shape[2], topk_candidates.shape[3])
|
| 633 |
tiled_scene = self._tile_tensor(scene_tokens, num_topk)
|
| 634 |
-
planning_state = elastic_state
|
| 635 |
-
if not support_mode_conditioning:
|
| 636 |
-
planning_state = dict(elastic_state)
|
| 637 |
-
planning_state["support_mode_logits"] = torch.zeros_like(elastic_state["support_mode_logits"])
|
| 638 |
tiled_state = self._tile_state(planning_state, num_topk)
|
| 639 |
rollout = self.world_model(
|
| 640 |
scene_tokens=tiled_scene,
|
|
@@ -664,4 +698,5 @@ class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
|
|
| 664 |
outputs["utility_residual"] = selected["utility_residual"]
|
| 665 |
outputs["utility_total"] = selected["utility_total"]
|
| 666 |
outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
|
|
|
|
| 667 |
return outputs
|
|
|
|
| 485 |
tiled[key] = self._tile_tensor(value, num_candidates)
|
| 486 |
return tiled
|
| 487 |
|
| 488 |
+
def _detach_state(self, state: dict[str, Tensor]) -> dict[str, Tensor]:
|
| 489 |
+
detached: dict[str, Tensor] = {}
|
| 490 |
+
for key, value in state.items():
|
| 491 |
+
detached[key] = value.detach() if isinstance(value, Tensor) else value
|
| 492 |
+
return detached
|
| 493 |
+
|
| 494 |
+
def _repeat_rollout_tensor(self, value: Tensor, num_candidates: int, horizon: int) -> Tensor:
|
| 495 |
+
value = value.detach()
|
| 496 |
+
return value.unsqueeze(1).unsqueeze(2).expand(-1, num_candidates, horizon, *value.shape[1:])
|
| 497 |
+
|
| 498 |
+
def _identity_rollout(
|
| 499 |
+
self,
|
| 500 |
+
interaction_state: dict[str, Tensor],
|
| 501 |
+
num_candidates: int,
|
| 502 |
+
) -> dict[str, Tensor]:
|
| 503 |
+
horizon = max(1, self.config.world_model.rollout_horizon)
|
| 504 |
+
rollout: dict[str, Tensor] = {}
|
| 505 |
+
for key, value in interaction_state.items():
|
| 506 |
+
if isinstance(value, Tensor):
|
| 507 |
+
rollout[key] = self._repeat_rollout_tensor(value, num_candidates, horizon)
|
| 508 |
+
return rollout
|
| 509 |
+
|
| 510 |
def forward(
|
| 511 |
self,
|
| 512 |
images: Tensor,
|
|
|
|
| 592 |
"reveal_state": elastic_state,
|
| 593 |
"view_summaries": scene_output["view_summaries"],
|
| 594 |
"geometry_summaries": scene_output["geometry_summaries"],
|
| 595 |
+
"rollout_source": "none",
|
| 596 |
}
|
| 597 |
|
| 598 |
candidate_chunks = candidate_chunks_override
|
|
|
|
| 630 |
batch_size = candidate_chunks.shape[0]
|
| 631 |
batch_indices = torch.arange(batch_size, device=candidate_chunks.device).unsqueeze(-1)
|
| 632 |
topk_candidates = candidate_chunks[batch_indices, shortlist_indices]
|
| 633 |
+
num_topk = topk_candidates.shape[1]
|
| 634 |
outputs["planner_topk_candidates"] = topk_candidates
|
| 635 |
if proposal_logits is not None:
|
| 636 |
topk_proposal_logits = proposal_logits.gather(1, shortlist_indices)
|
| 637 |
else:
|
| 638 |
topk_proposal_logits = None
|
| 639 |
+
planning_state = elastic_state
|
| 640 |
+
if not support_mode_conditioning:
|
| 641 |
+
planning_state = dict(elastic_state)
|
| 642 |
+
planning_state["support_mode_logits"] = torch.zeros_like(elastic_state["support_mode_logits"])
|
| 643 |
|
| 644 |
if not use_world_model:
|
| 645 |
+
detached_state = self._detach_state(planning_state)
|
| 646 |
+
identity_rollout = self._identity_rollout(
|
| 647 |
+
interaction_state=detached_state,
|
| 648 |
+
num_candidates=num_topk,
|
| 649 |
+
)
|
| 650 |
+
selected = self.planner.select_best(
|
| 651 |
+
initial_state=detached_state,
|
| 652 |
+
candidate_chunks=topk_candidates,
|
| 653 |
+
rollout_state=identity_rollout,
|
| 654 |
+
proposal_logits=topk_proposal_logits,
|
| 655 |
+
candidate_indices=shortlist_indices,
|
| 656 |
+
)
|
| 657 |
+
outputs["planned_rollout"] = identity_rollout
|
| 658 |
+
outputs["planned_chunk"] = selected["best_chunk"]
|
| 659 |
+
outputs["planner_success_logits"] = selected["success_logits"]
|
| 660 |
+
outputs["planner_risk_values"] = selected["risk_values"]
|
| 661 |
+
outputs["planner_scores"] = selected["utility_total"]
|
| 662 |
+
outputs["best_candidate_indices"] = selected["best_indices"]
|
| 663 |
+
outputs["utility_structured"] = selected["utility_structured"]
|
| 664 |
+
outputs["utility_residual"] = selected["utility_residual"]
|
| 665 |
+
outputs["utility_total"] = selected["utility_total"]
|
| 666 |
+
outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
|
| 667 |
+
outputs["rollout_source"] = "identity"
|
| 668 |
return outputs
|
| 669 |
|
|
|
|
| 670 |
flat_chunks = topk_candidates.view(batch_size * num_topk, topk_candidates.shape[2], topk_candidates.shape[3])
|
| 671 |
tiled_scene = self._tile_tensor(scene_tokens, num_topk)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
tiled_state = self._tile_state(planning_state, num_topk)
|
| 673 |
rollout = self.world_model(
|
| 674 |
scene_tokens=tiled_scene,
|
|
|
|
| 698 |
outputs["utility_residual"] = selected["utility_residual"]
|
| 699 |
outputs["utility_total"] = selected["utility_total"]
|
| 700 |
outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
|
| 701 |
+
outputs["rollout_source"] = "learned"
|
| 702 |
return outputs
|
code/reveal_vla_bimanual/train/losses.py
CHANGED
|
@@ -303,7 +303,7 @@ def compute_total_loss(
|
|
| 303 |
+ 0.01 * reveal_losses["uncertainty"]
|
| 304 |
)
|
| 305 |
|
| 306 |
-
if model_output.get("planned_rollout") and (
|
| 307 |
"candidate_rollout_support_mode" in batch or "rollout_support_mode" in batch
|
| 308 |
):
|
| 309 |
if "candidate_rollout_support_mode" in batch:
|
|
|
|
| 303 |
+ 0.01 * reveal_losses["uncertainty"]
|
| 304 |
)
|
| 305 |
|
| 306 |
+
if model_output.get("planned_rollout") and model_output.get("rollout_source", "learned") == "learned" and (
|
| 307 |
"candidate_rollout_support_mode" in batch or "rollout_support_mode" in batch
|
| 308 |
):
|
| 309 |
if "candidate_rollout_support_mode" in batch:
|
results/phase_tracking.md
CHANGED
|
@@ -83,10 +83,10 @@ Date closed: `2026-03-25 UTC`
|
|
| 83 |
- `short_history`: `0.5463` mean success, delta `0.0000`
|
| 84 |
- Gate decisions:
|
| 85 |
- hard success gate `>= 0.60`: fail
|
| 86 |
-
- `no_world_model` must hurt:
|
| 87 |
- full memory must stop losing to short history: hard gate passes narrowly because full equals short-history; preferred gate fails because full does not beat short-history
|
| 88 |
- state metrics should improve over phase 1: fail, reocclusion rate increased (`0.0000 -> 0.0121`), persistence MAE worsened (`1.9553 -> 2.2358`), and calibration worsened
|
| 89 |
-
- Takeaway: the expanded state/memory path did not validate on the dummy proxy benchmark. Planner classification improved, but
|
| 90 |
|
| 91 |
## Phase 3
|
| 92 |
|
|
|
|
| 83 |
- `short_history`: `0.5463` mean success, delta `0.0000`
|
| 84 |
- Gate decisions:
|
| 85 |
- hard success gate `>= 0.60`: fail
|
| 86 |
+
- `no_world_model` must hurt: not interpretable from the stored artifact alone; the recorded `no_world_model` run predates the `2026-03-25` null-rollout ablation fix and should be rerun for a fair comparison
|
| 87 |
- full memory must stop losing to short history: hard gate passes narrowly because full equals short-history; preferred gate fails because full does not beat short-history
|
| 88 |
- state metrics should improve over phase 1: fail, reocclusion rate increased (`0.0000 -> 0.0121`), persistence MAE worsened (`1.9553 -> 2.2358`), and calibration worsened
|
| 89 |
+
- Takeaway: the expanded state/memory path did not validate on the dummy proxy benchmark. Planner classification improved, but the world-model ablation needs a post-fix rerun before it can be interpreted fairly.
|
| 90 |
|
| 91 |
## Phase 3
|
| 92 |
|
tests/test_policy_topk_cascade.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
from train.trainer import build_policy
|
| 2 |
|
| 3 |
|
|
@@ -23,3 +25,44 @@ def test_policy_topk_cascade(tiny_policy_config, tiny_trainer_config, tiny_batch
|
|
| 23 |
assert output["planner_topk_indices"].shape[1] == config.planner.top_k
|
| 24 |
assert output["planned_rollout"]["target_belief_field"].shape[1] == config.planner.top_k
|
| 25 |
assert (output["best_candidate_indices"] < config.decoder.num_candidates).all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
from train.trainer import build_policy
|
| 4 |
|
| 5 |
|
|
|
|
| 25 |
assert output["planner_topk_indices"].shape[1] == config.planner.top_k
|
| 26 |
assert output["planned_rollout"]["target_belief_field"].shape[1] == config.planner.top_k
|
| 27 |
assert (output["best_candidate_indices"] < config.decoder.num_candidates).all()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_policy_null_rollout_ablation_keeps_planner_interface(
|
| 31 |
+
tiny_policy_config,
|
| 32 |
+
tiny_trainer_config,
|
| 33 |
+
tiny_batch,
|
| 34 |
+
):
|
| 35 |
+
config = tiny_policy_config(num_candidates=4, top_k=2)
|
| 36 |
+
batch = tiny_batch(chunk_size=config.decoder.chunk_size)
|
| 37 |
+
policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
|
| 38 |
+
output = policy(
|
| 39 |
+
images=batch["images"],
|
| 40 |
+
depths=batch["depths"],
|
| 41 |
+
depth_valid=batch["depth_valid"],
|
| 42 |
+
camera_intrinsics=batch["camera_intrinsics"],
|
| 43 |
+
camera_extrinsics=batch["camera_extrinsics"],
|
| 44 |
+
proprio=batch["proprio"],
|
| 45 |
+
texts=batch["texts"],
|
| 46 |
+
history_images=batch["history_images"],
|
| 47 |
+
history_depths=batch["history_depths"],
|
| 48 |
+
history_depth_valid=batch["history_depth_valid"],
|
| 49 |
+
history_proprio=batch["history_proprio"],
|
| 50 |
+
history_actions=batch["history_actions"],
|
| 51 |
+
plan=True,
|
| 52 |
+
use_world_model=False,
|
| 53 |
+
use_planner=True,
|
| 54 |
+
)
|
| 55 |
+
rollout = output["planned_rollout"]
|
| 56 |
+
current_state = output["interaction_state"]
|
| 57 |
+
assert output["rollout_source"] == "identity"
|
| 58 |
+
assert output["planner_topk_indices"].shape[1] == config.planner.top_k
|
| 59 |
+
assert rollout["target_belief_field"].shape[1] == config.planner.top_k
|
| 60 |
+
repeated_belief = current_state["target_belief_field"].detach().unsqueeze(1).unsqueeze(2).expand_as(
|
| 61 |
+
rollout["target_belief_field"]
|
| 62 |
+
)
|
| 63 |
+
repeated_phase = current_state["phase_logits"].detach().unsqueeze(1).unsqueeze(2).expand_as(
|
| 64 |
+
rollout["phase_logits"]
|
| 65 |
+
)
|
| 66 |
+
assert output["utility_total"].shape == (batch["images"].shape[0], config.planner.top_k)
|
| 67 |
+
assert torch.allclose(rollout["target_belief_field"], repeated_belief)
|
| 68 |
+
assert torch.allclose(rollout["phase_logits"], repeated_phase)
|