diff --git a/code/reveal_vla_bimanual/models/__init__.py b/code/reveal_vla_bimanual/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f85c00caf39a513251a11d813a0bf17343b2dc2 --- /dev/null +++ b/code/reveal_vla_bimanual/models/__init__.py @@ -0,0 +1,24 @@ +from models.action_decoder import ACTBimanualChunkDecoder, ChunkDecoderConfig +from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig +from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig +from models.planner import PlannerConfig, RevealPlanner +from models.policy import BackboneOnlyPolicy, RevealBimanualPolicy +from models.reveal_head import RevealHeadConfig, RevealStateHead +from models.world_model import RevealWM, RevealWMConfig + +__all__ = [ + "ACTBimanualChunkDecoder", + "BackboneOnlyPolicy", + "ChunkDecoderConfig", + "FrozenVLBackbone", + "FrozenVLBackboneConfig", + "MultiViewFusion", + "MultiViewFusionConfig", + "PlannerConfig", + "RevealBimanualPolicy", + "RevealHeadConfig", + "RevealPlanner", + "RevealStateHead", + "RevealWM", + "RevealWMConfig", +] diff --git a/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc b/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa8efaaf51a1e5c75903734e4c69f76c16b0840c Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc b/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fb3dde09812584bab64b143d2e15f0d91c2e576 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc 
b/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e1c39a20c85e381f939876843fff7f0a1b1d5d0 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc b/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20319c73218869dc2d75b0d151c4c2e86964d921 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc b/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3b09e7fb557764965aa1346d90fd97b45694dd3 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc b/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df0820f700f2c56f6c47d40429e7a00c8dde45a3 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc b/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcc7ea68debc8d2ceed7b5d7da3f576da528cb3c Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc b/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..12405ac30755c8500650b968a49bc45091be9d2b Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc b/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dedfcd9b138743e3fb8e086443c80a6e1e46cb64 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc b/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f535680dd124c59378d01b261e3bdd13b64d1857 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc b/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04df6721fbf70b6ca53fc82d024822707956bca5 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc b/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6a99bf6abbebfc88f3616113ad451ae00a775ff Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc b/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a975fd599eaa13964deeef470fd58c1c958726e Binary files /dev/null and 
b/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc b/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ab75e248e2a50a607a25f08e69666feb1ca467d Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc b/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67e86b1b04894cfd3b077b3e494629fad788ffa0 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc b/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d823eead8bb59d340306f4ddaab3c86e2823b8a0 Binary files /dev/null and b/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/models/action_decoder.py b/code/reveal_vla_bimanual/models/action_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0ab8e1223ea95307136c96c1e2aba35d60a821e5 --- /dev/null +++ b/code/reveal_vla_bimanual/models/action_decoder.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch +from torch import Tensor, nn + + +@dataclass +class ChunkDecoderConfig: + hidden_dim: int = 512 + num_heads: int = 8 + num_layers: int = 4 + ff_dim: int = 2048 + dropout: float = 0.1 + chunk_size: int = 8 + action_dim: int = 14 + num_candidates: int = 8 + + +class ACTBimanualChunkDecoder(nn.Module): + def __init__(self, config: ChunkDecoderConfig) -> None: + super().__init__() + 
self.config = config + decoder_layer = nn.TransformerDecoderLayer( + d_model=config.hidden_dim, + nhead=config.num_heads, + dim_feedforward=config.ff_dim, + dropout=config.dropout, + batch_first=True, + norm_first=True, + ) + self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers) + self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim) + self.action_mean = nn.Linear(config.hidden_dim, config.action_dim) + self.action_log_std = nn.Linear(config.hidden_dim, config.action_dim) + self.proposal_score = nn.Sequential( + nn.LayerNorm(config.hidden_dim), + nn.Linear(config.hidden_dim, 1), + ) + + def forward(self, scene_tokens: Tensor) -> dict[str, Tensor]: + batch_size = scene_tokens.shape[0] + query = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1) + decoded = self.decoder(query, scene_tokens) + return { + "decoded_tokens": decoded, + "action_mean": self.action_mean(decoded), + "action_log_std": self.action_log_std(decoded).clamp(min=-5.0, max=2.0), + "proposal_score": self.proposal_score(decoded.mean(dim=1)).squeeze(-1), + } + + def sample_candidates(self, action_mean: Tensor, action_log_std: Tensor, num_candidates: int | None = None) -> Tensor: + num_candidates = num_candidates or self.config.num_candidates + if num_candidates <= 1: + return action_mean.unsqueeze(1) + std = action_log_std.exp() + noise = torch.randn( + action_mean.size(0), + num_candidates, + action_mean.size(1), + action_mean.size(2), + device=action_mean.device, + dtype=action_mean.dtype, + ) + candidates = action_mean.unsqueeze(1) + noise * std.unsqueeze(1) + candidates[:, 0] = action_mean + return candidates diff --git a/code/reveal_vla_bimanual/models/backbones.py b/code/reveal_vla_bimanual/models/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..30ec627dd90a7c2c0b35287ef08c689c2024b86b --- /dev/null +++ b/code/reveal_vla_bimanual/models/backbones.py @@ -0,0 +1,116 @@ +from __future__ import annotations + 
+from dataclasses import dataclass +import math +from typing import Sequence + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +@dataclass +class FrozenVLBackboneConfig: + model_name: str = "openai/clip-vit-base-patch32" + hidden_dim: int = 512 + max_text_tokens: int = 32 + freeze_backbone: bool = True + gradient_checkpointing: bool = True + use_dummy_backbone: bool = False + + +class _DummyTextTokenizer: + def __init__(self, vocab_size: int = 8192, max_length: int = 32) -> None: + self.vocab_size = vocab_size + self.max_length = max_length + + def __call__(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]: + token_ids = torch.zeros((len(texts), self.max_length), dtype=torch.long, device=device) + attention_mask = torch.zeros_like(token_ids) + for row, text in enumerate(texts): + encoded = [min(ord(char), self.vocab_size - 1) for char in text[: self.max_length]] + if encoded: + token_ids[row, : len(encoded)] = torch.tensor(encoded, dtype=torch.long, device=device) + attention_mask[row, : len(encoded)] = 1 + return {"input_ids": token_ids, "attention_mask": attention_mask} + + +class FrozenVLBackbone(nn.Module): + def __init__(self, config: FrozenVLBackboneConfig) -> None: + super().__init__() + self.config = config + self.hidden_dim = config.hidden_dim + self.use_dummy_backbone = config.use_dummy_backbone + + if config.use_dummy_backbone: + self.image_patch_size = 16 + self.tokenizer = _DummyTextTokenizer(max_length=config.max_text_tokens) + else: + from transformers import AutoTokenizer, CLIPModel + + clip_model = CLIPModel.from_pretrained(config.model_name) + self.vision_model = clip_model.vision_model + self.text_model = clip_model.text_model + self.visual_projection = clip_model.visual_projection + self.text_projection = clip_model.text_projection + self.tokenizer = AutoTokenizer.from_pretrained(config.model_name) + self.hidden_dim = clip_model.config.projection_dim + if config.gradient_checkpointing: + if 
hasattr(self.vision_model, "gradient_checkpointing_enable"): + self.vision_model.gradient_checkpointing_enable() + if hasattr(self.text_model, "gradient_checkpointing_enable"): + self.text_model.gradient_checkpointing_enable() + + if config.freeze_backbone: + for parameter in self.parameters(): + parameter.requires_grad = False + + def tokenize_text(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]: + if self.use_dummy_backbone: + return self.tokenizer(texts, device=device) + return self.tokenizer( + list(texts), + padding=True, + truncation=True, + max_length=self.config.max_text_tokens, + return_tensors="pt", + ).to(device) + + def encode_images(self, images: Tensor) -> Tensor: + batch_size, num_views, channels, height, width = images.shape + flat_images = images.reshape(batch_size * num_views, channels, height, width) + if self.use_dummy_backbone: + pooled = F.avg_pool2d(flat_images.float(), kernel_size=self.image_patch_size, stride=self.image_patch_size) + patch_tokens = pooled.flatten(2).transpose(1, 2) + grid_h, grid_w = pooled.shape[-2], pooled.shape[-1] + y_coords = torch.linspace(-1.0, 1.0, steps=grid_h, device=images.device) + x_coords = torch.linspace(-1.0, 1.0, steps=grid_w, device=images.device) + grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij") + coords = torch.stack([grid_x, grid_y], dim=-1).reshape(1, grid_h * grid_w, 2) + coords = coords.expand(patch_tokens.shape[0], -1, -1) + intensity = patch_tokens.mean(dim=-1, keepdim=True) + base = torch.cat([patch_tokens, intensity, coords], dim=-1) + repeat_factor = math.ceil(self.hidden_dim / base.shape[-1]) + tokens = base.repeat(1, 1, repeat_factor)[..., : self.hidden_dim] + else: + outputs = self.vision_model(pixel_values=flat_images) + tokens = self.visual_projection(outputs.last_hidden_state) + num_tokens = tokens.shape[1] + return tokens.reshape(batch_size, num_views, num_tokens, -1) + + def encode_text(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: 
+ if self.use_dummy_backbone: + vocab_scale = float(self.tokenizer.vocab_size - 1) + token_values = input_ids.float() / vocab_scale + frequencies = torch.linspace( + 1.0, + 4.0, + steps=max(1, self.hidden_dim // 2), + device=input_ids.device, + dtype=token_values.dtype, + ) + phases = token_values.unsqueeze(-1) * frequencies.view(1, 1, -1) * (2.0 * math.pi) + embeddings = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)[..., : self.hidden_dim] + return embeddings * attention_mask.unsqueeze(-1).float() + outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask) + return self.text_projection(outputs.last_hidden_state) diff --git a/code/reveal_vla_bimanual/models/multiview_fusion.py b/code/reveal_vla_bimanual/models/multiview_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..98b70e99a80efad63301771635a9def96c0a1637 --- /dev/null +++ b/code/reveal_vla_bimanual/models/multiview_fusion.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch +from torch import Tensor, nn + + +@dataclass +class MultiViewFusionConfig: + hidden_dim: int = 512 + num_cameras: int = 3 + num_layers: int = 4 + num_heads: int = 8 + ff_dim: int = 2048 + dropout: float = 0.1 + proprio_dim: int = 32 + proprio_tokens: int = 1 + + +class MultiViewFusion(nn.Module): + def __init__(self, config: MultiViewFusionConfig) -> None: + super().__init__() + self.config = config + self.camera_embedding = nn.Embedding(config.num_cameras, config.hidden_dim) + encoder_layer = nn.TransformerEncoderLayer( + d_model=config.hidden_dim, + nhead=config.num_heads, + dim_feedforward=config.ff_dim, + dropout=config.dropout, + batch_first=True, + norm_first=True, + ) + self.cross_view_transformer = nn.TransformerEncoder( + encoder_layer, + num_layers=config.num_layers, + ) + self.proprio_adapter = nn.Sequential( + nn.LayerNorm(config.proprio_dim), + nn.Linear(config.proprio_dim, config.hidden_dim * 
config.proprio_tokens), + nn.GELU(), + ) + + def forward(self, image_tokens: Tensor, proprio: Tensor, language_tokens: Tensor) -> Tensor: + batch_size, num_views, num_tokens, hidden_dim = image_tokens.shape + if num_views != self.config.num_cameras: + raise ValueError(f"Expected {self.config.num_cameras} views, received {num_views}") + + camera_ids = torch.arange(num_views, device=image_tokens.device) + camera_embed = self.camera_embedding(camera_ids).view(1, num_views, 1, hidden_dim) + image_tokens = image_tokens + camera_embed + fused = self.cross_view_transformer(image_tokens.reshape(batch_size, num_views * num_tokens, hidden_dim)) + + proprio_tokens = self.proprio_adapter(proprio).view( + batch_size, self.config.proprio_tokens, hidden_dim + ) + return torch.cat([fused, proprio_tokens, language_tokens], dim=1) diff --git a/code/reveal_vla_bimanual/models/planner.py b/code/reveal_vla_bimanual/models/planner.py new file mode 100644 index 0000000000000000000000000000000000000000..9c9e341edaaee01db3ffe199d737c91b57e328ce --- /dev/null +++ b/code/reveal_vla_bimanual/models/planner.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch +from torch import Tensor + + +@dataclass +class PlannerConfig: + num_candidates: int = 8 + corridor_weight: float = 1.0 + persistence_weight: float = 0.5 + proposal_weight: float = 0.5 + task_progress_weight: float = 0.75 + disturbance_weight: float = 0.75 + reocclusion_weight: float = 0.5 + visibility_weight: float = 0.25 + + +class RevealPlanner: + def __init__(self, config: PlannerConfig) -> None: + self.config = config + + def score_rollouts( + self, + rollout_state: dict[str, Tensor], + proposal_scores: Tensor, + candidate_chunks: Tensor | None = None, + belief_gain: Tensor | None = None, + ) -> Tensor: + corridor_prob = rollout_state["corridor_logits"].sigmoid().amax(dim=-1).mean(dim=(-1, -2)) + persistence = rollout_state["persistence_horizon"].mean(dim=(-1, -2)) + 
disturbance = rollout_state["disturbance_cost"].mean(dim=-1) + reocclusion_penalty = torch.relu(1.0 - rollout_state["corridor_logits"].sigmoid().amax(dim=-1)).mean(dim=(-1, -2)) + task_progress = proposal_scores.new_zeros(proposal_scores.shape) + if candidate_chunks is not None: + actor_reach = torch.tanh(candidate_chunks[..., 8]).mean(dim=-1) + actor_retrieve = torch.tanh(candidate_chunks[..., 13]).amax(dim=-1) + task_progress = 0.5 * (actor_reach + 1.0) * 0.5 + 0.5 * (actor_retrieve + 1.0) * 0.5 + score = ( + self.config.corridor_weight * corridor_prob + + self.config.persistence_weight * persistence + + self.config.proposal_weight * proposal_scores + + self.config.task_progress_weight * task_progress + - self.config.disturbance_weight * disturbance + - self.config.reocclusion_weight * reocclusion_penalty + ) + if belief_gain is not None: + score = score + self.config.visibility_weight * belief_gain + return score + + def select_best(self, candidate_chunks: Tensor, rollout_state: dict[str, Tensor], proposal_scores: Tensor) -> dict[str, Tensor]: + scores = self.score_rollouts(rollout_state, proposal_scores, candidate_chunks=candidate_chunks) + best_idx = scores.argmax(dim=-1) + batch_indices = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device) + return { + "scores": scores, + "best_indices": best_idx, + "best_chunk": candidate_chunks[batch_indices, best_idx], + } diff --git a/code/reveal_vla_bimanual/models/policy.py b/code/reveal_vla_bimanual/models/policy.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6604f02e83922abe7828d174871ced24de419b --- /dev/null +++ b/code/reveal_vla_bimanual/models/policy.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Sequence + +import torch +from torch import Tensor, nn + +from models.action_decoder import ACTBimanualChunkDecoder, ChunkDecoderConfig +from models.backbones import FrozenVLBackbone, 
FrozenVLBackboneConfig +from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig +from models.planner import PlannerConfig, RevealPlanner +from models.reveal_head import RevealHeadConfig, RevealStateHead +from models.world_model import RevealWM, RevealWMConfig + + +@dataclass +class PolicyConfig: + backbone: FrozenVLBackboneConfig = field(default_factory=FrozenVLBackboneConfig) + fusion: MultiViewFusionConfig = field(default_factory=MultiViewFusionConfig) + decoder: ChunkDecoderConfig = field(default_factory=ChunkDecoderConfig) + reveal_head: RevealHeadConfig = field(default_factory=RevealHeadConfig) + world_model: RevealWMConfig = field(default_factory=RevealWMConfig) + planner: PlannerConfig = field(default_factory=PlannerConfig) + + +class BackboneOnlyPolicy(nn.Module): + def __init__(self, config: PolicyConfig) -> None: + super().__init__() + self.config = config + self.backbone = FrozenVLBackbone(config.backbone) + self.fusion = MultiViewFusion(config.fusion) + self.decoder = ACTBimanualChunkDecoder(config.decoder) + + def _encode_language( + self, + images: Tensor, + texts: Sequence[str] | None = None, + language_tokens: dict[str, Tensor] | None = None, + ) -> Tensor: + if language_tokens is None: + if texts is None: + raise ValueError("Either texts or language_tokens must be provided.") + language_tokens = self.backbone.tokenize_text(texts, device=images.device) + return self.backbone.encode_text( + input_ids=language_tokens["input_ids"], + attention_mask=language_tokens["attention_mask"], + ) + + def encode_scene( + self, + images: Tensor, + proprio: Tensor, + texts: Sequence[str] | None = None, + language_tokens: dict[str, Tensor] | None = None, + ) -> Tensor: + image_tokens = self.backbone.encode_images(images) + text_tokens = self._encode_language(images, texts=texts, language_tokens=language_tokens) + return self.fusion(image_tokens=image_tokens, proprio=proprio, language_tokens=text_tokens) + + def forward( + self, + images: Tensor, + 
proprio: Tensor, + texts: Sequence[str] | None = None, + language_tokens: dict[str, Tensor] | None = None, + ) -> dict[str, Tensor]: + scene_tokens = self.encode_scene(images, proprio, texts=texts, language_tokens=language_tokens) + decoded = self.decoder(scene_tokens) + decoded["scene_tokens"] = scene_tokens + return decoded + + +class RevealBimanualPolicy(BackboneOnlyPolicy): + def __init__(self, config: PolicyConfig) -> None: + super().__init__(config) + self.reveal_head = RevealStateHead(config.reveal_head) + self.world_model = RevealWM(config.world_model) + self.planner = RevealPlanner(config.planner) + + def forward( + self, + images: Tensor, + proprio: Tensor, + texts: Sequence[str] | None = None, + language_tokens: dict[str, Tensor] | None = None, + plan: bool = True, + support_mode_conditioning: bool = True, + ) -> dict[str, Tensor]: + outputs = super().forward(images, proprio, texts=texts, language_tokens=language_tokens) + reveal_state = self.reveal_head(outputs["scene_tokens"]) + outputs["reveal_state"] = reveal_state + + candidate_chunks = self.decoder.sample_candidates( + outputs["action_mean"], + outputs["action_log_std"], + num_candidates=self.config.decoder.num_candidates, + ) + outputs["candidate_chunks"] = candidate_chunks + + if plan: + batch_size, num_candidates, chunk_size, action_dim = candidate_chunks.shape + flat_chunks = candidate_chunks.view(batch_size * num_candidates, chunk_size, action_dim) + tiled_scene = outputs["scene_tokens"].unsqueeze(1).expand(-1, num_candidates, -1, -1) + tiled_scene = tiled_scene.reshape(batch_size * num_candidates, outputs["scene_tokens"].shape[1], outputs["scene_tokens"].shape[2]) + planning_reveal_state = reveal_state + if not support_mode_conditioning: + planning_reveal_state = dict(reveal_state) + planning_reveal_state["support_mode_logits"] = torch.zeros_like(reveal_state["support_mode_logits"]) + tiled_reveal = { + key: value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(batch_size * 
num_candidates, *value.shape[1:]) + for key, value in planning_reveal_state.items() + } + rollout = self.world_model(tiled_scene, tiled_reveal, flat_chunks) + reshaped_rollout = { + key: value.view(batch_size, num_candidates, *value.shape[1:]) for key, value in rollout.items() + } + selected = self.planner.select_best( + candidate_chunks=candidate_chunks, + rollout_state=reshaped_rollout, + proposal_scores=outputs["proposal_score"].unsqueeze(-1).expand(-1, num_candidates), + ) + outputs["planned_rollout"] = reshaped_rollout + outputs["planned_chunk"] = selected["best_chunk"] + outputs["planner_scores"] = selected["scores"] + outputs["best_candidate_indices"] = selected["best_indices"] + return outputs diff --git a/code/reveal_vla_bimanual/models/reveal_head.py b/code/reveal_vla_bimanual/models/reveal_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7c357d6993dd5ea0222df81336a53e1744493ad4 --- /dev/null +++ b/code/reveal_vla_bimanual/models/reveal_head.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from torch import Tensor, nn + + +@dataclass +class RevealHeadConfig: + hidden_dim: int = 512 + num_support_modes: int = 3 + num_approach_templates: int = 32 + rollout_horizon: int = 5 + belief_map_size: int = 32 + predict_belief_map: bool = False + + +class RevealStateHead(nn.Module): + def __init__(self, config: RevealHeadConfig) -> None: + super().__init__() + self.config = config + self.trunk = nn.Sequential( + nn.LayerNorm(config.hidden_dim), + nn.Linear(config.hidden_dim, config.hidden_dim), + nn.GELU(), + ) + self.support_mode = nn.Linear(config.hidden_dim, config.num_support_modes) + self.corridor = nn.Linear( + config.hidden_dim, + config.num_support_modes * config.num_approach_templates, + ) + self.persistence = nn.Linear(config.hidden_dim, config.num_support_modes) + self.disturbance = nn.Linear(config.hidden_dim, 1) + self.belief_map = None + if config.predict_belief_map: + map_side = 
config.belief_map_size + self.belief_map = nn.Linear(config.hidden_dim, map_side * map_side) + + def forward(self, scene_tokens: Tensor) -> dict[str, Tensor]: + pooled = scene_tokens.mean(dim=1) + hidden = self.trunk(pooled) + output = { + "support_mode_logits": self.support_mode(hidden), + "corridor_logits": self.corridor(hidden).view( + hidden.shape[0], + self.config.num_support_modes, + self.config.num_approach_templates, + ), + "persistence_horizon": self.persistence(hidden), + "disturbance_cost": self.disturbance(hidden).squeeze(-1), + } + if self.belief_map is not None: + side = self.config.belief_map_size + output["belief_map"] = self.belief_map(hidden).view(hidden.shape[0], 1, side, side) + return output diff --git a/code/reveal_vla_bimanual/models/world_model.py b/code/reveal_vla_bimanual/models/world_model.py new file mode 100644 index 0000000000000000000000000000000000000000..208995eb992876eea0e36a9d4d34edcb880408e8 --- /dev/null +++ b/code/reveal_vla_bimanual/models/world_model.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch +from torch import Tensor, nn + + +@dataclass +class RevealWMConfig: + hidden_dim: int = 512 + action_dim: int = 14 + num_support_modes: int = 3 + num_approach_templates: int = 32 + rollout_horizon: int = 5 + + +class RevealWM(nn.Module): + def __init__(self, config: RevealWMConfig) -> None: + super().__init__() + self.config = config + reveal_dim = ( + config.num_support_modes + + config.num_support_modes * config.num_approach_templates + + config.num_support_modes + + 1 + ) + self.initial = nn.Sequential( + nn.LayerNorm(config.hidden_dim + reveal_dim), + nn.Linear(config.hidden_dim + reveal_dim, config.hidden_dim), + nn.GELU(), + ) + self.action_encoder = nn.Linear(config.action_dim, config.hidden_dim) + self.gru = nn.GRU(config.hidden_dim, config.hidden_dim, batch_first=True) + self.support_mode = nn.Linear(config.hidden_dim, config.num_support_modes) + self.corridor = 
nn.Linear( + config.hidden_dim, + config.num_support_modes * config.num_approach_templates, + ) + self.persistence = nn.Linear(config.hidden_dim, config.num_support_modes) + self.disturbance = nn.Linear(config.hidden_dim, 1) + + def _flatten_reveal(self, reveal_state: dict[str, Tensor]) -> Tensor: + return torch.cat( + [ + reveal_state["support_mode_logits"], + reveal_state["corridor_logits"].flatten(start_dim=1), + reveal_state["persistence_horizon"], + reveal_state["disturbance_cost"].unsqueeze(-1), + ], + dim=-1, + ) + + def forward(self, scene_tokens: Tensor, reveal_state: dict[str, Tensor], action_chunk: Tensor) -> dict[str, Tensor]: + pooled = scene_tokens.mean(dim=1) + initial_hidden = self.initial(torch.cat([pooled, self._flatten_reveal(reveal_state)], dim=-1)) + encoded_actions = self.action_encoder(action_chunk) + rollout, _ = self.gru(encoded_actions, initial_hidden.unsqueeze(0)) + batch_size, horizon, _ = rollout.shape + return { + "support_mode_logits": self.support_mode(rollout), + "corridor_logits": self.corridor(rollout).view( + batch_size, + horizon, + self.config.num_support_modes, + self.config.num_approach_templates, + ), + "persistence_horizon": self.persistence(rollout), + "disturbance_cost": self.disturbance(rollout).squeeze(-1), + } diff --git a/code/reveal_vla_bimanual/sim_reveal/__init__.py b/code/reveal_vla_bimanual/sim_reveal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00cbb06aa8ba74722af418ff162a272d2009bf3a --- /dev/null +++ b/code/reveal_vla_bimanual/sim_reveal/__init__.py @@ -0,0 +1,15 @@ +from sim_reveal.base import RevealProxyConfig, RevealState, SupportMode +from sim_reveal.procedural_envs import ProceduralRevealEnv, available_proxy_names, make_proxy_env +from sim_reveal.proxy_specs import BAG_PROXY, CLOTH_PROXY, FOLIAGE_PROXY + +__all__ = [ + "BAG_PROXY", + "CLOTH_PROXY", + "FOLIAGE_PROXY", + "ProceduralRevealEnv", + "RevealProxyConfig", + "RevealState", + "SupportMode", + 
"available_proxy_names", + "make_proxy_env", +] diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-310.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85bae60505e1ddb22d575f64f32b85b22e8b269e Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2265eb93e4e6f10f4a8db8799ced43795a21b6b6 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-310.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82bb41cbff5b805485eeca8d4d3f77af4875bb45 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce5bdcc7e5cf7cc03b6ec652df5328da709a71a8 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-310.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e82852acc00756b05bc37b91f89c117d561f220 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-310.pyc differ diff --git 
a/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2cdda9a4de4a9ed4854ea35914b88a06cd0f972 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-310.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc6a75fda6b11b002d9a216cc814890eaa026bd3 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af7e59508f7d41a6e08982dc53acd3c23a396547 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-310.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..709800afbf80649f6e938453a89800b2f30623a6 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30963d15e237483b0c4f1f8998653b68d8044e2d Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-311.pyc differ diff --git 
a/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-310.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc5f78042310680208e2b75448ccd4441711b35c Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c147c3abf45d9c26719055d9da92d642e5d5f99 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/labels.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/labels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0f10bdcfd92708560d87f8375025e96375cb048 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/labels.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-310.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0b7ade829db7abc22e87d5ac5d83e2f48def773 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6dea4ffe5267197b2d25ed424be73d4099c69f2 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc differ diff --git 
a/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-310.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..731cb1d47df0565947f6703bfd844a8029068d03 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99e1e931f10c9c0468529acee53cbd0dc3f6bca6 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/__pycache__/teachers.cpython-311.pyc b/code/reveal_vla_bimanual/sim_reveal/__pycache__/teachers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73cb5310c38b91039539abe952adafb957b9c7f1 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/teachers.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_reveal/base.py b/code/reveal_vla_bimanual/sim_reveal/base.py new file mode 100644 index 0000000000000000000000000000000000000000..442eb618613d233add1d2986b0d368d94d812051 --- /dev/null +++ b/code/reveal_vla_bimanual/sim_reveal/base.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import IntEnum + +import numpy as np + + +class SupportMode(IntEnum): + HOLD = 0 + TRANSFER = 1 + PASSIVE = 2 + + +@dataclass +class RevealState: + support_mode_logits: np.ndarray + corridor_logits: np.ndarray + persistence_horizon: np.ndarray + disturbance_cost: np.ndarray + belief_map: np.ndarray | None = None + + +@dataclass +class RevealProxyConfig: + name: str + num_templates: int = 32 + rollout_horizon: int = 5 + max_steps: int = 80 + disturbance_key: str = 
"disturbance_cost" + success_key: str = "retrieval_success" + metadata: dict[str, str] = field(default_factory=dict) diff --git a/code/reveal_vla_bimanual/sim_reveal/dataset.py b/code/reveal_vla_bimanual/sim_reveal/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e9244d7f7dccd81041e5438232a62f96eda8590e --- /dev/null +++ b/code/reveal_vla_bimanual/sim_reveal/dataset.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Sequence + +import torch +from torch import Tensor +from torch.utils.data import Dataset + +from sim_reveal.procedural_envs import available_proxy_names, make_proxy_env, render_views_from_state + + +def collect_teacher_dataset( + proxy_names: Sequence[str] | None = None, + episodes_per_proxy: int = 32, + resolution: int = 96, + seed: int = 0, + chunk_horizon: int = 8, + rollout_horizon: int = 5, +) -> dict[str, Any]: + proxy_names = tuple(proxy_names or available_proxy_names()) + samples: list[dict[str, Any]] = [] + summary: dict[str, dict[str, float]] = {} + + for proxy_offset, proxy_name in enumerate(proxy_names): + proxy_samples = 0 + proxy_success = 0 + for episode_idx in range(episodes_per_proxy): + env = make_proxy_env( + proxy_name=proxy_name, + resolution=resolution, + seed=seed + proxy_offset * 10_000 + episode_idx, + rollout_horizon=rollout_horizon, + ) + _, privileged_state = env.reset(seed=seed + proxy_offset * 10_000 + episode_idx) + while True: + action_chunk, rollout = env.teacher_chunk_and_rollout( + chunk_horizon=chunk_horizon, + rollout_horizon=rollout_horizon, + ) + samples.append( + { + "proxy_name": proxy_name, + "episode_id": episode_idx, + "render_state": env.render_state(privileged_state), + "proprio": env.get_observation(privileged_state)["proprio"].astype("float32"), + "language_goal": env.get_observation(privileged_state)["text"], + "action_chunk": action_chunk.astype("float32"), + "support_mode": int(privileged_state["support_mode"]), + 
"corridor_feasible": privileged_state["corridor_feasible"].astype("float32"), + "persistence_horizon": privileged_state["persistence_horizon"].astype("float32"), + "disturbance_cost": float(privileged_state["disturbance_cost"]), + "belief_map": privileged_state["belief_map"].astype("float32"), + "rollout_support_mode": rollout["rollout_support_mode"].astype("int64"), + "rollout_corridor_feasible": rollout["rollout_corridor_feasible"].astype("float32"), + "rollout_persistence_horizon": rollout["rollout_persistence_horizon"].astype("float32"), + "rollout_disturbance_cost": rollout["rollout_disturbance_cost"].astype("float32"), + } + ) + proxy_samples += 1 + _, _, terminated, truncated, privileged_state = env.step(env.teacher_action()) + if terminated: + proxy_success += 1 + if terminated or truncated: + break + summary[proxy_name] = { + "episodes": float(episodes_per_proxy), + "samples": float(proxy_samples), + "teacher_success": proxy_success / float(max(1, episodes_per_proxy)), + } + return { + "resolution": resolution, + "chunk_horizon": chunk_horizon, + "rollout_horizon": rollout_horizon, + "samples": samples, + "summary": summary, + } + + +def save_teacher_dataset(output_path: str | Path, dataset_bundle: dict[str, Any]) -> Path: + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(dataset_bundle, output_path) + return output_path + + +def load_teacher_dataset(dataset_path: str | Path) -> dict[str, Any]: + return torch.load(Path(dataset_path), map_location="cpu", weights_only=False) + + +class RevealOfflineDataset(Dataset[dict[str, Any]]): + def __init__(self, samples: Sequence[dict[str, Any]], resolution: int = 96) -> None: + self.samples = list(samples) + self.resolution = resolution + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> dict[str, Any]: + sample = self.samples[index] + images = render_views_from_state( + proxy_name=sample["proxy_name"], + 
render_state=sample["render_state"], + resolution=self.resolution, + ) + stacked = torch.from_numpy( + torch.stack( + [ + torch.from_numpy(images["front"]), + torch.from_numpy(images["wrist_left"]), + torch.from_numpy(images["wrist_right"]), + ], + dim=0, + ).numpy() + ).permute(0, 3, 1, 2).float() / 255.0 + return { + "images": stacked, + "proprio": torch.as_tensor(sample["proprio"], dtype=torch.float32), + "texts": sample["language_goal"], + "action_chunk": torch.as_tensor(sample["action_chunk"], dtype=torch.float32), + "support_mode": torch.as_tensor(sample["support_mode"], dtype=torch.long), + "corridor_feasible": torch.as_tensor(sample["corridor_feasible"], dtype=torch.float32), + "persistence_horizon": torch.as_tensor(sample["persistence_horizon"], dtype=torch.float32), + "disturbance_cost": torch.as_tensor(sample["disturbance_cost"], dtype=torch.float32), + "belief_map": torch.as_tensor(sample["belief_map"], dtype=torch.float32).unsqueeze(0), + "rollout_support_mode": torch.as_tensor(sample["rollout_support_mode"], dtype=torch.long), + "rollout_corridor_feasible": torch.as_tensor(sample["rollout_corridor_feasible"], dtype=torch.float32), + "rollout_persistence_horizon": torch.as_tensor(sample["rollout_persistence_horizon"], dtype=torch.float32), + "rollout_disturbance_cost": torch.as_tensor(sample["rollout_disturbance_cost"], dtype=torch.float32), + "proxy_name": sample["proxy_name"], + "episode_id": sample["episode_id"], + } + + +def dataset_from_bundle(dataset_bundle: dict[str, Any], resolution: int | None = None) -> RevealOfflineDataset: + resolution = resolution or int(dataset_bundle["resolution"]) + return RevealOfflineDataset(dataset_bundle["samples"], resolution=resolution) diff --git a/code/reveal_vla_bimanual/sim_reveal/generate_dataset.py b/code/reveal_vla_bimanual/sim_reveal/generate_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a84b0ff9500396e95bee88f6e9bc3a48da9f400c --- /dev/null +++ 
b/code/reveal_vla_bimanual/sim_reveal/generate_dataset.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from sim_reveal.dataset import collect_teacher_dataset, save_teacher_dataset + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--proxies", nargs="*", default=None) + parser.add_argument("--episodes-per-proxy", type=int, default=32) + parser.add_argument("--resolution", type=int, default=96) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--chunk-horizon", type=int, default=8) + parser.add_argument("--rollout-horizon", type=int, default=5) + parser.add_argument("--output-path", default="/workspace/data/reveal_proxy/reveal_proxy_teacher.pt") + args = parser.parse_args() + + dataset_bundle = collect_teacher_dataset( + proxy_names=args.proxies, + episodes_per_proxy=args.episodes_per_proxy, + resolution=args.resolution, + seed=args.seed, + chunk_horizon=args.chunk_horizon, + rollout_horizon=args.rollout_horizon, + ) + output_path = save_teacher_dataset(Path(args.output_path), dataset_bundle) + payload = { + "output_path": str(output_path), + "resolution": dataset_bundle["resolution"], + "num_samples": len(dataset_bundle["samples"]), + "summary": dataset_bundle["summary"], + } + print(json.dumps(payload, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/code/reveal_vla_bimanual/sim_reveal/isaac_smoke.py b/code/reveal_vla_bimanual/sim_reveal/isaac_smoke.py new file mode 100644 index 0000000000000000000000000000000000000000..e87a6197b106606395e9a759b10d142b6206349a --- /dev/null +++ b/code/reveal_vla_bimanual/sim_reveal/isaac_smoke.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import argparse +import json + +from sim_reveal.isaac_wrapper import IsaacRevealRuntime + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--visible", action="store_true") + args = parser.parse_args() + + 
runtime = IsaacRevealRuntime(headless=not args.visible) + try: + import isaacsim + + payload = { + "headless": not args.visible, + "isaacsim_version": getattr(isaacsim, "__version__", "unknown"), + "status": "ok", + } + print(json.dumps(payload, indent=2)) + finally: + runtime.close() + + +if __name__ == "__main__": + main() diff --git a/code/reveal_vla_bimanual/sim_reveal/isaac_wrapper.py b/code/reveal_vla_bimanual/sim_reveal/isaac_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..9355aaa273b4a308a1d406daa50e3b3ccfbef0dd --- /dev/null +++ b/code/reveal_vla_bimanual/sim_reveal/isaac_wrapper.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class IsaacRevealRuntime: + headless: bool = True + + def __post_init__(self) -> None: + from isaacsim import SimulationApp + + self._simulation_app = SimulationApp({"headless": self.headless}) + + def close(self) -> None: + self._simulation_app.close() diff --git a/code/reveal_vla_bimanual/sim_reveal/labels.py b/code/reveal_vla_bimanual/sim_reveal/labels.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8dec475dd8fba1b877e0375a553af28d09b68f --- /dev/null +++ b/code/reveal_vla_bimanual/sim_reveal/labels.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from sim_reveal.base import RevealState, SupportMode + + +def privileged_state_to_reveal_labels( + state: dict[str, Any], + num_modes: int = 3, + num_templates: int = 32, + rollout_horizon: int = 5, +) -> RevealState: + support_mode = int(state["support_mode"]) + support_logits = np.full((num_modes,), -4.0, dtype=np.float32) + support_logits[support_mode] = 4.0 + + corridor = np.asarray(state["corridor_feasible"], dtype=np.float32) + if corridor.shape != (num_modes, num_templates): + raise ValueError( + f"Expected corridor_feasible shape {(num_modes, num_templates)}, got {corridor.shape}" + ) + corridor_logits = 
np.where(corridor > 0.5, 4.0, -4.0).astype(np.float32) + + persistence = np.asarray(state["persistence_horizon"], dtype=np.float32) + if persistence.shape != (num_modes,): + raise ValueError(f"Expected persistence_horizon shape {(num_modes,)}, got {persistence.shape}") + persistence = np.clip(persistence, 0.0, float(rollout_horizon)) + + disturbance = np.asarray([state["disturbance_cost"]], dtype=np.float32) + belief_map = state.get("belief_map") + if belief_map is not None: + belief_map = np.asarray(belief_map, dtype=np.float32) + + return RevealState( + support_mode_logits=support_logits, + corridor_logits=corridor_logits, + persistence_horizon=persistence, + disturbance_cost=disturbance, + belief_map=belief_map, + ) + + +def reocclusion_rate(corridor_open_history: np.ndarray) -> float: + corridor_open_history = np.asarray(corridor_open_history, dtype=np.float32) + if corridor_open_history.ndim != 1: + raise ValueError("corridor_open_history must be 1D.") + if corridor_open_history.size < 2: + return 0.0 + open_then_closed = np.logical_and(corridor_open_history[:-1] > 0.5, corridor_open_history[1:] <= 0.5) + return float(open_then_closed.mean()) + + +def infer_support_mode_from_flags(holding: bool, transferred: bool) -> SupportMode: + if holding: + return SupportMode.HOLD + if transferred: + return SupportMode.TRANSFER + return SupportMode.PASSIVE diff --git a/code/reveal_vla_bimanual/sim_reveal/procedural_envs.py b/code/reveal_vla_bimanual/sim_reveal/procedural_envs.py new file mode 100644 index 0000000000000000000000000000000000000000..5fce8d031e06478ce60ec47570137fd0b20ae21c --- /dev/null +++ b/code/reveal_vla_bimanual/sim_reveal/procedural_envs.py @@ -0,0 +1,545 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from sim_reveal.base import RevealProxyConfig, SupportMode +from sim_reveal.proxy_specs import BAG_PROXY, CLOTH_PROXY, FOLIAGE_PROXY + + +@dataclass(frozen=True) +class 
ProxyDynamics: + hold_decay: float + transfer_decay: float + passive_decay: float + disturbance_gain: float + settle_rate: float + desired_opening: float + preferred_mode: SupportMode + transfer_support_factor: float + passive_support_factor: float + visibility_bias: float + retrieve_visibility_threshold: float + palette: tuple[float, float, float] + + +PROXY_CONFIGS: dict[str, RevealProxyConfig] = { + FOLIAGE_PROXY.name: FOLIAGE_PROXY, + BAG_PROXY.name: BAG_PROXY, + CLOTH_PROXY.name: CLOTH_PROXY, +} + +PROXY_DYNAMICS: dict[str, ProxyDynamics] = { + FOLIAGE_PROXY.name: ProxyDynamics( + hold_decay=0.02, + transfer_decay=0.07, + passive_decay=0.15, + disturbance_gain=0.06, + settle_rate=0.03, + desired_opening=0.60, + preferred_mode=SupportMode.HOLD, + transfer_support_factor=0.76, + passive_support_factor=0.42, + visibility_bias=0.03, + retrieve_visibility_threshold=0.42, + palette=(0.16, 0.30, 0.12), + ), + BAG_PROXY.name: ProxyDynamics( + hold_decay=0.04, + transfer_decay=0.03, + passive_decay=0.12, + disturbance_gain=0.05, + settle_rate=0.02, + desired_opening=0.68, + preferred_mode=SupportMode.TRANSFER, + transfer_support_factor=0.96, + passive_support_factor=0.55, + visibility_bias=0.06, + retrieve_visibility_threshold=0.48, + palette=(0.26, 0.17, 0.10), + ), + CLOTH_PROXY.name: ProxyDynamics( + hold_decay=0.03, + transfer_decay=0.05, + passive_decay=0.04, + disturbance_gain=0.04, + settle_rate=0.04, + desired_opening=0.50, + preferred_mode=SupportMode.PASSIVE, + transfer_support_factor=0.82, + passive_support_factor=0.90, + visibility_bias=0.08, + retrieve_visibility_threshold=0.38, + palette=(0.24, 0.24, 0.29), + ), +} + +PROXY_GOALS = { + FOLIAGE_PROXY.name: "create a gap in the foliage and retrieve the target", + BAG_PROXY.name: "open the bag mouth and retrieve the target object", + CLOTH_PROXY.name: "lift the top layer enough to retrieve the hidden object", +} + + +def available_proxy_names() -> tuple[str, ...]: + return tuple(PROXY_CONFIGS.keys()) + + 
+def make_proxy_env( + proxy_name: str, + resolution: int = 96, + seed: int = 0, + num_templates: int = 32, + rollout_horizon: int = 5, + max_steps: int | None = None, +) -> "ProceduralRevealEnv": + return ProceduralRevealEnv( + proxy_name=proxy_name, + resolution=resolution, + seed=seed, + num_templates=num_templates, + rollout_horizon=rollout_horizon, + max_steps=max_steps, + ) + + +class ProceduralRevealEnv: + camera_names = ("front", "wrist_left", "wrist_right") + + def __init__( + self, + proxy_name: str, + resolution: int = 96, + seed: int = 0, + num_templates: int = 32, + rollout_horizon: int = 5, + max_steps: int | None = None, + ) -> None: + if proxy_name not in PROXY_CONFIGS: + raise KeyError(f"Unknown proxy: {proxy_name}") + self.proxy = PROXY_CONFIGS[proxy_name] + self.dynamics = PROXY_DYNAMICS[proxy_name] + self.proxy_name = proxy_name + self.resolution = resolution + self.num_templates = num_templates + self.rollout_horizon = rollout_horizon + self.max_steps = max_steps or self.proxy.max_steps + self.rng = np.random.default_rng(seed) + self.reset(seed=seed) + + def clone_state(self) -> dict[str, Any]: + return { + "step_count": self.step_count, + "opening": self.opening, + "disturbance": self.disturbance, + "target_template": self.target_template, + "target_depth": self.target_depth, + "holding": self.holding, + "transferred": self.transferred, + "retrieved": self.retrieved, + "actor_progress": self.actor_progress, + "last_actor_template": self.last_actor_template, + "visibility_trace": list(self.visibility_trace), + "corridor_trace": list(self.corridor_trace), + } + + def restore_state(self, state: dict[str, Any]) -> None: + self.step_count = int(state["step_count"]) + self.opening = float(state["opening"]) + self.disturbance = float(state["disturbance"]) + self.target_template = int(state["target_template"]) + self.target_depth = float(state["target_depth"]) + self.holding = bool(state["holding"]) + self.transferred = bool(state["transferred"]) + 
self.retrieved = bool(state["retrieved"]) + self.actor_progress = float(state["actor_progress"]) + self.last_actor_template = int(state["last_actor_template"]) + self.visibility_trace = list(state["visibility_trace"]) + self.corridor_trace = list(state["corridor_trace"]) + + def reset(self, seed: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]: + if seed is not None: + self.rng = np.random.default_rng(seed) + self.step_count = 0 + self.opening = float(self.rng.uniform(0.08, 0.22)) + self.disturbance = float(self.rng.uniform(0.02, 0.12)) + self.target_template = int(self.rng.integers(4, self.num_templates - 4)) + self.target_depth = float(self.rng.uniform(0.15, 0.45)) + self.holding = False + self.transferred = False + self.retrieved = False + self.actor_progress = 0.0 + self.last_actor_template = self.target_template + privileged_state = self.get_privileged_state() + self.visibility_trace = [float(privileged_state["visibility"])] + self.corridor_trace = [float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any())] + return self.get_observation(privileged_state), privileged_state + + def _normalized_template(self, template_index: int) -> float: + return (template_index / float(self.num_templates - 1)) * 2.0 - 1.0 + + def _current_support_mode(self) -> SupportMode: + if self.holding: + return SupportMode.HOLD + if self.transferred: + return SupportMode.TRANSFER + return SupportMode.PASSIVE + + def _mode_from_action(self, action: np.ndarray) -> SupportMode: + hold_score = (np.tanh(float(action[6])) + 1.0) * 0.5 + transfer_score = (np.tanh(float(action[1])) + 1.0) * 0.5 + passive_score = (np.tanh(float(action[2])) + 1.0) * 0.5 + if hold_score >= max(transfer_score, passive_score): + return SupportMode.HOLD + if transfer_score >= passive_score and self.opening >= 0.32: + return SupportMode.TRANSFER + return SupportMode.PASSIVE + + def _visibility(self, opening: float | None = None, disturbance: float | None = None) -> float: + 
opening = self.opening if opening is None else float(opening) + disturbance = self.disturbance if disturbance is None else float(disturbance) + visibility = ( + 1.35 * opening + - 0.58 * disturbance + - 0.25 * self.target_depth + + self.dynamics.visibility_bias + ) + return float(np.clip(visibility, 0.0, 1.0)) + + def _mode_factor(self, mode: SupportMode) -> float: + if mode == SupportMode.HOLD: + return 1.0 + if mode == SupportMode.TRANSFER: + return self.dynamics.transfer_support_factor + return self.dynamics.passive_support_factor + + def _mode_decay(self, mode: SupportMode) -> float: + if mode == SupportMode.HOLD: + return self.dynamics.hold_decay + if mode == SupportMode.TRANSFER: + return self.dynamics.transfer_decay + return self.dynamics.passive_decay + + def _corridor_for_mode( + self, + mode: SupportMode, + opening: float | None = None, + disturbance: float | None = None, + ) -> np.ndarray: + opening = self.opening if opening is None else float(opening) + disturbance = self.disturbance if disturbance is None else float(disturbance) + visibility = self._visibility(opening, disturbance) + effective = opening * self._mode_factor(mode) - 0.35 * disturbance - 0.18 * self.target_depth + width = int(np.floor(max(0.0, effective) * 8.0)) + corridor = np.zeros((self.num_templates,), dtype=np.float32) + if visibility < self.dynamics.retrieve_visibility_threshold * 0.7 or width <= 0: + return corridor + low = max(0, self.target_template - width) + high = min(self.num_templates, self.target_template + width + 1) + corridor[low:high] = 1.0 + return corridor + + def _persistence_for_mode(self, mode: SupportMode) -> float: + opening = self.opening + disturbance = self.disturbance + persisted = 0.0 + for _ in range(self.rollout_horizon): + if self._corridor_for_mode(mode, opening, disturbance).any(): + persisted += 1.0 + else: + break + opening = float(np.clip(opening - self._mode_decay(mode) + (0.035 if mode == SupportMode.HOLD else 0.0), 0.0, 1.0)) + disturbance = 
float(np.clip(disturbance * (1.0 - self.dynamics.settle_rate), 0.0, 1.0)) + return persisted + + def _belief_map(self, visibility: float) -> np.ndarray: + side = 32 + x = np.linspace(0.0, 1.0, side, dtype=np.float32) + y = np.linspace(0.0, 1.0, side, dtype=np.float32) + yy, xx = np.meshgrid(y, x, indexing="ij") + center_x = self.target_template / float(self.num_templates - 1) + center_y = 0.72 - 0.25 * self.target_depth + sigma = 0.08 + 0.05 * (1.0 - visibility) + belief = np.exp(-(((xx - center_x) ** 2) + ((yy - center_y) ** 2)) / (2.0 * sigma**2)) + belief *= visibility + return belief.astype(np.float32) + + def get_privileged_state(self) -> dict[str, Any]: + support_mode = int(self._current_support_mode()) + corridor = np.stack( + [self._corridor_for_mode(mode) for mode in SupportMode], + axis=0, + ) + persistence = np.asarray([self._persistence_for_mode(mode) for mode in SupportMode], dtype=np.float32) + visibility = self._visibility() + disturbance_cost = float(np.clip(self.disturbance + 0.08 * max(0.0, self.opening - self.dynamics.desired_opening), 0.0, 1.0)) + return { + "support_mode": support_mode, + "corridor_feasible": corridor, + "persistence_horizon": persistence, + "disturbance_cost": disturbance_cost, + "belief_map": self._belief_map(visibility), + "visibility": visibility, + "retrieval_success": bool(self.retrieved), + "target_template": self.target_template, + } + + def render_state(self, privileged_state: dict[str, Any] | None = None) -> dict[str, Any]: + privileged_state = privileged_state or self.get_privileged_state() + current_mode = int(privileged_state["support_mode"]) + return { + "opening": float(self.opening), + "disturbance": float(self.disturbance), + "target_template": int(self.target_template), + "support_mode": current_mode, + "visibility": float(privileged_state["visibility"]), + "actor_template": int(self.last_actor_template), + "actor_progress": float(self.actor_progress), + "corridor_current": 
privileged_state["corridor_feasible"][current_mode].astype(np.float32), + "step_fraction": float(self.step_count / max(1, self.max_steps)), + } + + def _proprio(self, privileged_state: dict[str, Any]) -> np.ndarray: + mode = privileged_state["support_mode"] + features = np.zeros((32,), dtype=np.float32) + features[0] = self.opening + features[1] = self.disturbance + features[2] = privileged_state["visibility"] + features[3 + mode] = 1.0 + features[6] = self.target_template / float(self.num_templates - 1) + features[7] = self.last_actor_template / float(self.num_templates - 1) + features[8] = self.step_count / float(max(1, self.max_steps)) + features[9:12] = privileged_state["persistence_horizon"] / float(self.rollout_horizon) + features[12] = float(privileged_state["corridor_feasible"][mode].any()) + features[13] = float(self.retrieved) + features[14] = self.actor_progress + return features + + def get_observation(self, privileged_state: dict[str, Any] | None = None) -> dict[str, Any]: + privileged_state = privileged_state or self.get_privileged_state() + render_state = self.render_state(privileged_state) + images = render_views_from_state( + proxy_name=self.proxy_name, + render_state=render_state, + resolution=self.resolution, + num_templates=self.num_templates, + ) + return { + "images": np.stack([images[camera] for camera in self.camera_names], axis=0), + "proprio": self._proprio(privileged_state), + "text": PROXY_GOALS[self.proxy_name], + "camera_names": self.camera_names, + "render_state": render_state, + } + + def teacher_action(self) -> np.ndarray: + privileged_state = self.get_privileged_state() + preferred_mode = self.dynamics.preferred_mode + if self.opening < self.dynamics.desired_opening: + chosen_mode = SupportMode.HOLD + open_cmd = 0.95 + elif privileged_state["persistence_horizon"][preferred_mode] >= 2.0: + chosen_mode = preferred_mode + open_cmd = 0.12 + else: + chosen_mode = SupportMode.HOLD + open_cmd = 0.30 + + corridor = 
privileged_state["corridor_feasible"][int(chosen_mode)] + actor_ready = bool(corridor[self.target_template] > 0.5) + retrieve = ( + actor_ready + and privileged_state["visibility"] >= self.dynamics.retrieve_visibility_threshold + and self.actor_progress >= 0.55 + ) + action = np.zeros((14,), dtype=np.float32) + action[0] = np.float32(open_cmd) + action[1] = np.float32(1.0 if chosen_mode == SupportMode.TRANSFER else -1.0) + action[2] = np.float32(1.0 if chosen_mode == SupportMode.PASSIVE else -1.0) + action[6] = np.float32(1.0 if chosen_mode == SupportMode.HOLD else -1.0) + action[7] = np.float32(self._normalized_template(self.target_template)) + action[8] = np.float32(1.0 if actor_ready else 0.2) + action[13] = np.float32(1.0 if retrieve else -1.0) + return action + + def teacher_chunk_and_rollout( + self, + chunk_horizon: int = 8, + rollout_horizon: int | None = None, + ) -> tuple[np.ndarray, dict[str, np.ndarray]]: + rollout_horizon = rollout_horizon or self.rollout_horizon + snapshot = self.clone_state() + action_chunk: list[np.ndarray] = [] + rollout_support_mode = [] + rollout_corridor = [] + rollout_persistence = [] + rollout_disturbance = [] + for step in range(chunk_horizon): + action = self.teacher_action() + action_chunk.append(action) + _, _, terminated, truncated, privileged_state = self.step(action) + if step < rollout_horizon: + rollout_support_mode.append(privileged_state["support_mode"]) + rollout_corridor.append(privileged_state["corridor_feasible"]) + rollout_persistence.append(privileged_state["persistence_horizon"]) + rollout_disturbance.append(privileged_state["disturbance_cost"]) + if terminated or truncated: + break + while len(action_chunk) < chunk_horizon: + action_chunk.append(np.zeros((14,), dtype=np.float32)) + while len(rollout_support_mode) < rollout_horizon: + rollout_support_mode.append(int(self._current_support_mode())) + rollout_corridor.append(self.get_privileged_state()["corridor_feasible"]) + 
rollout_persistence.append(self.get_privileged_state()["persistence_horizon"]) + rollout_disturbance.append(self.get_privileged_state()["disturbance_cost"]) + self.restore_state(snapshot) + return np.stack(action_chunk, axis=0).astype(np.float32), { + "rollout_support_mode": np.asarray(rollout_support_mode, dtype=np.int64), + "rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32), + "rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32), + "rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32), + } + + def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + action = np.asarray(action, dtype=np.float32) + mode = self._mode_from_action(action) + self.holding = mode == SupportMode.HOLD + self.transferred = mode == SupportMode.TRANSFER + open_cmd = float(np.clip(action[0], -1.0, 1.0)) + actor_reach = float((np.tanh(float(action[8])) + 1.0) * 0.5) + retrieve_cmd = float((np.tanh(float(action[13])) + 1.0) * 0.5) + self.last_actor_template = int( + np.clip( + round(((float(np.clip(action[7], -1.0, 1.0)) + 1.0) * 0.5) * (self.num_templates - 1)), + 0, + self.num_templates - 1, + ) + ) + + support_bonus = {SupportMode.HOLD: 0.08, SupportMode.TRANSFER: 0.04, SupportMode.PASSIVE: 0.0}[mode] + closure = self._mode_decay(mode) + self.opening = float( + np.clip( + self.opening + 0.16 * open_cmd + support_bonus - closure - 0.05 * self.disturbance, + 0.0, + 1.0, + ) + ) + self.disturbance = float( + np.clip( + self.disturbance + + self.dynamics.disturbance_gain * abs(open_cmd) + + 0.025 * actor_reach + + 0.05 * max(0.0, self.opening - self.dynamics.desired_opening) + - self.dynamics.settle_rate, + 0.0, + 1.0, + ) + ) + + self.step_count += 1 + privileged_state = self.get_privileged_state() + corridor = privileged_state["corridor_feasible"][privileged_state["support_mode"]] + if corridor[self.last_actor_template] > 0.5 and actor_reach >= 0.55: + persistence_ratio = 
privileged_state["persistence_horizon"][privileged_state["support_mode"]] / float( + max(1, self.rollout_horizon) + ) + self.actor_progress = float(np.clip(self.actor_progress + 0.55 * persistence_ratio, 0.0, 1.0)) + shock = 0.16 * max(0.0, 0.8 - persistence_ratio) + if shock > 0.0: + self.opening = float(np.clip(self.opening - shock, 0.0, 1.0)) + privileged_state = self.get_privileged_state() + corridor = privileged_state["corridor_feasible"][privileged_state["support_mode"]] + else: + self.actor_progress = float(np.clip(self.actor_progress - 0.20, 0.0, 1.0)) + success = bool( + retrieve_cmd >= 0.55 + and self.actor_progress >= 0.80 + and corridor[self.last_actor_template] > 0.5 + and privileged_state["visibility"] >= self.dynamics.retrieve_visibility_threshold + and self.disturbance < 0.9 + ) + if success: + self.retrieved = True + privileged_state["retrieval_success"] = True + + self.visibility_trace.append(float(privileged_state["visibility"])) + self.corridor_trace.append(float(corridor.any())) + + reward = 1.0 if success else (0.08 * privileged_state["visibility"] - 0.03 * privileged_state["disturbance_cost"]) + terminated = bool(self.retrieved) + truncated = bool(self.step_count >= self.max_steps) + return self.get_observation(privileged_state), float(reward), terminated, truncated, privileged_state + + +def render_views_from_state( + proxy_name: str, + render_state: dict[str, Any], + resolution: int, + num_templates: int = 32, +) -> dict[str, np.ndarray]: + dynamics = PROXY_DYNAMICS[proxy_name] + opening = float(render_state["opening"]) + disturbance = float(render_state["disturbance"]) + target_template = int(render_state["target_template"]) + support_mode = int(render_state["support_mode"]) + visibility = float(render_state["visibility"]) + actor_template = int(render_state["actor_template"]) + actor_progress = float(render_state["actor_progress"]) + corridor_current = np.asarray(render_state["corridor_current"], dtype=np.float32) + step_fraction = 
float(render_state["step_fraction"]) + + height = width = resolution + base = np.ones((height, width, 3), dtype=np.float32) + base *= np.asarray(dynamics.palette, dtype=np.float32) + + x = np.linspace(0.0, 1.0, width, dtype=np.float32) + y = np.linspace(0.0, 1.0, height, dtype=np.float32) + yy, xx = np.meshgrid(y, x, indexing="ij") + center_x = target_template / float(max(1, num_templates - 1)) + gap_width = 0.04 + 0.18 * opening + gap_mask = np.abs(xx - center_x) <= gap_width + stripe_mask = (np.sin(xx * np.pi * 18.0) > 0.2).astype(np.float32) + + front = base.copy() + front[..., 1] += 0.22 * stripe_mask + front[..., 0] += 0.07 * stripe_mask + front[gap_mask, :] = np.clip(front[gap_mask, :] + np.asarray([0.18, 0.18, 0.18], dtype=np.float32), 0.0, 1.0) + target_mask = ((xx - center_x) ** 2 + (yy - 0.76) ** 2) <= (0.03 + 0.015 * visibility) ** 2 + front[target_mask, 0] = np.clip(front[target_mask, 0] + 0.55 * visibility, 0.0, 1.0) + front[target_mask, 1] *= 0.55 + front[..., 2] = np.clip(front[..., 2] + 0.18 * disturbance + 0.05 * step_fraction, 0.0, 1.0) + + wrist_left = np.full((height, width, 3), 0.12, dtype=np.float32) + open_rows = int(opening * height) + wrist_left[height - open_rows :, : width // 3, 1] = 0.75 + wrist_left[height - int(disturbance * height) :, width // 3 : (2 * width) // 3, 0] = 0.85 + mode_colors = { + SupportMode.HOLD: np.asarray([0.92, 0.82, 0.16], dtype=np.float32), + SupportMode.TRANSFER: np.asarray([0.16, 0.78, 0.92], dtype=np.float32), + SupportMode.PASSIVE: np.asarray([0.86, 0.86, 0.86], dtype=np.float32), + } + wrist_left[:, (2 * width) // 3 :, :] = mode_colors[SupportMode(support_mode)] + + wrist_right = np.full((height, width, 3), 0.08, dtype=np.float32) + template_edges = np.linspace(0, width, num_templates + 1, dtype=np.int32) + for template_idx in range(num_templates): + col_start = template_edges[template_idx] + col_end = template_edges[template_idx + 1] + if corridor_current[template_idx] > 0.5: + wrist_right[:, 
col_start:col_end, 1] = 0.70 + if template_idx == target_template: + wrist_right[:, col_start:col_end, 0] = 0.78 + if template_idx == actor_template: + wrist_right[:, col_start:col_end, 2] = 0.90 + wrist_right[: max(1, int(visibility * height)), :, :] += 0.10 + wrist_right[height - max(1, int(actor_progress * height)) :, :, 2] += 0.12 + wrist_right = np.clip(wrist_right, 0.0, 1.0) + + return { + "front": (front * 255.0).astype(np.uint8), + "wrist_left": (wrist_left * 255.0).astype(np.uint8), + "wrist_right": (wrist_right * 255.0).astype(np.uint8), + } diff --git a/code/reveal_vla_bimanual/sim_reveal/proxy_specs.py b/code/reveal_vla_bimanual/sim_reveal/proxy_specs.py new file mode 100644 index 0000000000000000000000000000000000000000..3bcf0e768b1ad0ff6bc64a9077d8af7c37f8e732 --- /dev/null +++ b/code/reveal_vla_bimanual/sim_reveal/proxy_specs.py @@ -0,0 +1,29 @@ +from sim_reveal.base import RevealProxyConfig + + +FOLIAGE_PROXY = RevealProxyConfig( + name="foliage_proxy", + disturbance_key="strip_strain", + metadata={ + "teacher": "pull lowest-cost strip cluster, retrieve when corridor exists", + "goal": "maintain a usable corridor with low strip strain and low collateral motion", + }, +) + +BAG_PROXY = RevealProxyConfig( + name="bag_proxy", + disturbance_key="aperture_collapse_and_non_target_motion", + metadata={ + "teacher": "expand mouth with two contacts, retrieve after aperture threshold", + "goal": "open the mouth enough for actor entry while minimizing collapse after release", + }, +) + +CLOTH_PROXY = RevealProxyConfig( + name="cloth_proxy", + disturbance_key="fold_line_deviation", + metadata={ + "teacher": "lift top layer minimally, retrieve once target is exposed", + "goal": "expose the target while penalizing unnecessary fold disruption", + }, +) diff --git a/code/reveal_vla_bimanual/sim_reveal/teachers.py b/code/reveal_vla_bimanual/sim_reveal/teachers.py new file mode 100644 index 
0000000000000000000000000000000000000000..3f85af25891c8468e00f853305ac24a0a9f79e4c --- /dev/null +++ b/code/reveal_vla_bimanual/sim_reveal/teachers.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import numpy as np + + +@dataclass +class TeacherAction: + revealer_action: np.ndarray + actor_action: np.ndarray + + +def foliage_teacher_step(privileged_state: dict[str, Any]) -> TeacherAction: + target_cluster = privileged_state.get("lowest_cost_strip_direction", [0.0, 0.0, 1.0]) + actor_ready = bool(privileged_state.get("corridor_exists", False)) + revealer = np.asarray(target_cluster, dtype=np.float32) + actor = np.asarray(privileged_state.get("retrieve_direction", [0.0, 0.0, 0.0]), dtype=np.float32) + if not actor_ready: + actor = np.zeros_like(actor) + return TeacherAction(revealer_action=revealer, actor_action=actor) + + +def bag_teacher_step(privileged_state: dict[str, Any]) -> TeacherAction: + contact_a = np.asarray(privileged_state.get("expand_contact_a", [1.0, 0.0, 0.0]), dtype=np.float32) + contact_b = np.asarray(privileged_state.get("expand_contact_b", [-1.0, 0.0, 0.0]), dtype=np.float32) + aperture_ready = float(privileged_state.get("aperture", 0.0)) >= float(privileged_state.get("aperture_threshold", 1.0)) + actor = np.asarray(privileged_state.get("retrieve_direction", [0.0, 0.0, 0.0]), dtype=np.float32) + if not aperture_ready: + actor = np.zeros_like(actor) + return TeacherAction(revealer_action=np.concatenate([contact_a, contact_b]), actor_action=actor) + + +def cloth_teacher_step(privileged_state: dict[str, Any]) -> TeacherAction: + lift = np.asarray(privileged_state.get("minimal_lift_direction", [0.0, 0.0, 1.0]), dtype=np.float32) + actor_ready = bool(privileged_state.get("target_exposed", False)) + actor = np.asarray(privileged_state.get("retrieve_direction", [0.0, 0.0, 0.0]), dtype=np.float32) + if not actor_ready: + actor = np.zeros_like(actor) + return 
TeacherAction(revealer_action=lift, actor_action=actor) diff --git a/code/reveal_vla_bimanual/sim_rlbench/__init__.py b/code/reveal_vla_bimanual/sim_rlbench/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31bbcc1b0a491931b44f618dc01b676528ef9b43 --- /dev/null +++ b/code/reveal_vla_bimanual/sim_rlbench/__init__.py @@ -0,0 +1,9 @@ +from sim_rlbench.camera_spec import RLBenchThreeCameraSpec, default_three_camera_spec +from sim_rlbench.obs_adapter import CanonicalBimanualObservation, extract_canonical_bimanual_obs + +__all__ = [ + "CanonicalBimanualObservation", + "RLBenchThreeCameraSpec", + "default_three_camera_spec", + "extract_canonical_bimanual_obs", +] diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44fc04bbcaea185d87e82560f428275ff4825478 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-311.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f05298afd93dc6e1d06f291dc27c5e06da5a3c14 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e9f011ec0170802f9166532fb3ce75124fe78e7 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-311.pyc 
b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a02192e47678b9520bbd126abe5d0d0524f7697c Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/dataset_download.cpython-310.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/dataset_download.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a043b4bed04d4ece7f614026734f699fe924d02a Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/dataset_download.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/dataset_download.cpython-311.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/dataset_download.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba38b8d6f0073b5552805b4c5743c34586f3e3ff Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/dataset_download.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/generate_smoke_dataset.cpython-311.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/generate_smoke_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3259191e2dcb432b9ab84d0d3b4067bdabbc25b Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/generate_smoke_dataset.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/launch_smoke.cpython-310.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/launch_smoke.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38145fdbd0dca601342a3b17aec83334f52a2c50 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/launch_smoke.cpython-310.pyc differ diff --git 
a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/launch_smoke.cpython-311.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/launch_smoke.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..881e2c8a479e899f86612e7922157b350e38fb33 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/launch_smoke.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/obs_adapter.cpython-310.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/obs_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9496f6b4a672f7db221602562b7114852124bf3a Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/obs_adapter.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/obs_adapter.cpython-311.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/obs_adapter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2346fe07fd1651f972a3fbb9caec65133673224 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/obs_adapter.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/peract2_runner.cpython-310.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/peract2_runner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7331730f7cd114de76eed00adcd32b3a672c42a Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/peract2_runner.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/peract2_runner.cpython-311.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/peract2_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4da3a4842ab350320cf289d4a9a83e1699ee4a2 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/peract2_runner.cpython-311.pyc differ diff --git 
a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/smoke_test.cpython-310.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/smoke_test.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..884bea0ac5a0d0f202df5a404d02218e5094710b Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/smoke_test.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/smoke_test.cpython-311.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/smoke_test.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b7bca5c920afdb744c7edd29183c8fe05b576a3 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/smoke_test.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/task_splits.cpython-310.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/task_splits.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..191b2138ced2921e51f98b15c75b17bcabac21d0 Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/task_splits.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/task_splits.cpython-311.pyc b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/task_splits.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edb8ce379b02c5e26e59574c5fdf46506b879bae Binary files /dev/null and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/task_splits.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/sim_rlbench/camera_spec.py b/code/reveal_vla_bimanual/sim_rlbench/camera_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..626f8647116b5554d30fbed0ce01298a85b64db8 --- /dev/null +++ b/code/reveal_vla_bimanual/sim_rlbench/camera_spec.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class 
RLBenchThreeCameraSpec: + cameras: tuple[str, str, str] = ("front", "wrist_left", "wrist_right") + resolution: tuple[int, int] = (224, 224) + + @property + def global_camera(self) -> str: + return self.cameras[0] + + @property + def wrist_cameras(self) -> tuple[str, str]: + return self.cameras[1], self.cameras[2] + + def hydra_overrides(self, prefix: str = "rlbench") -> list[str]: + camera_list = ",".join(self.cameras) + height, width = self.resolution + return [ + f"{prefix}.cameras=[{camera_list}]", + f"{prefix}.camera_resolution=[{height},{width}]", + ] + + +def default_three_camera_spec(resolution: int = 224) -> RLBenchThreeCameraSpec: + return RLBenchThreeCameraSpec(resolution=(resolution, resolution)) diff --git a/code/reveal_vla_bimanual/sim_rlbench/dataset_download.py b/code/reveal_vla_bimanual/sim_rlbench/dataset_download.py new file mode 100644 index 0000000000000000000000000000000000000000..b37a5bfbd51ab8431f0e3185ab3fd0181f054559 --- /dev/null +++ b/code/reveal_vla_bimanual/sim_rlbench/dataset_download.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import argparse +import hashlib +import subprocess +import sys +from pathlib import Path +from urllib.request import urlopen + +from sim_rlbench.task_splits import PERACT2_BIMANUAL_TASKS + + +def _sha256(path: Path, chunk_size: int = 8 * 1024 * 1024) -> str: + digest = hashlib.sha256() + with path.open("rb") as handle: + while True: + chunk = handle.read(chunk_size) + if not chunk: + break + digest.update(chunk) + return digest.hexdigest() + + +def _download_checksum_table(base_url: str) -> dict[str, str]: + with urlopen(f"{base_url}/SHA256SUM") as response: + lines = response.read().decode("utf-8").splitlines() + + checksums: dict[str, str] = {} + for line in lines: + line = line.strip() + if not line: + continue + checksum, filename = line.split(maxsplit=1) + checksums[filename] = checksum + return checksums + + +def _download_file(url: str, destination: Path) -> None: + 
destination.parent.mkdir(parents=True, exist_ok=True) + subprocess.run( + ["curl", "-L", "-C", "-", "--fail", "-o", str(destination), url], + check=True, + ) + + +def _ensure_unsquashfs() -> None: + if subprocess.run(["which", "unsquashfs"], check=False, capture_output=True).returncode == 0: + return + raise RuntimeError( + "unsquashfs is required for extraction. Install squashfs-tools first." + ) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--resolution", type=int, default=256, choices=(128, 256)) + parser.add_argument("--splits", nargs="+", default=["train"]) + parser.add_argument("--tasks", nargs="+", default=list(PERACT2_BIMANUAL_TASKS)) + parser.add_argument("--output-root", default="/workspace/data/rlbench2") + parser.add_argument("--extract", action="store_true") + parser.add_argument("--keep-archives", action="store_true") + parser.add_argument("--dry-run", action="store_true") + args = parser.parse_args() + + output_root = Path(args.output_root) + archive_root = output_root / "_archives" / f"image_size_{args.resolution}" + extract_root = output_root + base_url = f"https://dataset.cs.washington.edu/fox/bimanual/image_size_{args.resolution}" + + checksums = _download_checksum_table(base_url) + requested_files = [ + f"{task}.{split}.squashfs" + for task in args.tasks + for split in args.splits + ] + + missing = [filename for filename in requested_files if filename not in checksums] + if missing: + raise RuntimeError(f"Missing checksum entries for: {missing}") + + if args.extract: + _ensure_unsquashfs() + + for filename in requested_files: + archive_path = archive_root / filename + expected_sha = checksums[filename] + url = f"{base_url}/{filename}" + + print(f"[plan] {filename}", flush=True) + print(f" url={url}", flush=True) + print(f" archive={archive_path}", flush=True) + if args.extract: + print(f" extract_root={extract_root}", flush=True) + + if args.dry_run: + continue + + needs_download = True + if 
archive_path.exists(): + current_sha = _sha256(archive_path) + if current_sha == expected_sha: + needs_download = False + print(f"[skip] checksum ok for {filename}", flush=True) + else: + print(f"[redo] checksum mismatch for {filename}", flush=True) + + if needs_download: + _download_file(url, archive_path) + current_sha = _sha256(archive_path) + if current_sha != expected_sha: + raise RuntimeError(f"Checksum mismatch after download: {filename}") + print(f"[done] downloaded {filename}", flush=True) + + if args.extract: + subprocess.run( + ["unsquashfs", "-f", "-d", str(extract_root), str(archive_path)], + check=True, + ) + print(f"[done] extracted {filename}", flush=True) + if not args.keep_archives: + archive_path.unlink() + print(f"[done] removed archive {archive_path}", flush=True) + + print("[done] dataset stage complete", flush=True) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + sys.exit(130) diff --git a/code/reveal_vla_bimanual/sim_rlbench/generate_smoke_dataset.py b/code/reveal_vla_bimanual/sim_rlbench/generate_smoke_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..097505273fd630cbf94455763f2743cdf323f3c0 --- /dev/null +++ b/code/reveal_vla_bimanual/sim_rlbench/generate_smoke_dataset.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import argparse +import pickle +from pathlib import Path + +import numpy as np +from PIL import Image +from pyrep.const import RenderMode +from rlbench.action_modes.action_mode import BimanualMoveArmThenGripper +from rlbench.action_modes.arm_action_modes import BimanualJointPosition +from rlbench.action_modes.gripper_action_modes import BimanualDiscrete +from rlbench.backend.const import ( + DEPTH_SCALE, + EPISODE_FOLDER, + EPISODES_FOLDER, + LOW_DIM_PICKLE, + VARIATION_DESCRIPTIONS, + VARIATION_NUMBER, + VARIATIONS_ALL_FOLDER, +) +from rlbench.backend.utils import float_array_to_rgb_image, task_file_to_task_class +from rlbench.environment import Environment
+from rlbench.observation_config import CameraConfig, ObservationConfig + +from sim_rlbench.camera_spec import default_three_camera_spec + + +def _save_demo(demo, episode_path: Path, cameras: list[str]) -> None: + data_types = ("rgb", "depth", "mask") + for obs_idx, obs in enumerate(demo): + for camera_name in cameras: + for dtype in data_types: + output_dir = episode_path / f"{camera_name}_{dtype}" + output_dir.mkdir(parents=True, exist_ok=True) + payload = obs.perception_data.get(f"{camera_name}_{dtype}") + if payload is None: + continue + if dtype == "rgb": + image = Image.fromarray(payload) + elif dtype == "depth": + image = float_array_to_rgb_image(payload, scale_factor=DEPTH_SCALE) + elif dtype == "mask": + image = Image.fromarray((payload * 255).astype(np.uint8)) + else: + raise ValueError(dtype) + image.save(output_dir / f"{dtype}_{obs_idx:04d}.png") + obs.perception_data.clear() + + with (episode_path / LOW_DIM_PICKLE).open("wb") as handle: + pickle.dump(demo, handle) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--task", default="bimanual_lift_ball") + parser.add_argument("--episodes", type=int, default=1) + parser.add_argument("--resolution", type=int, default=224) + parser.add_argument("--output-root", default="/workspace/data/rlbench2_smoke") + args = parser.parse_args() + + spec = default_three_camera_spec(args.resolution) + camera_config = CameraConfig( + rgb=True, + depth=True, + point_cloud=False, + mask=True, + image_size=list(spec.resolution), + render_mode=RenderMode.OPENGL, + masks_as_one_channel=False, + depth_in_meters=False, + ) + obs_config = ObservationConfig( + camera_configs={camera_name: camera_config for camera_name in spec.cameras}, + joint_forces=False, + joint_positions=True, + joint_velocities=True, + task_low_dim_state=False, + gripper_touch_forces=False, + gripper_pose=True, + gripper_open=True, + gripper_matrix=True, + gripper_joint_positions=True, + robot_name="bimanual", + ) + + 
task_class = task_file_to_task_class(args.task, bimanual=True) + env = Environment( + action_mode=BimanualMoveArmThenGripper(BimanualJointPosition(), BimanualDiscrete()), + obs_config=obs_config, + robot_setup="dual_panda", + headless=True, + ) + output_root = Path(args.output_root) + episodes_root = output_root / args.task / VARIATIONS_ALL_FOLDER / EPISODES_FOLDER + episodes_root.mkdir(parents=True, exist_ok=True) + + try: + env.launch() + task_env = env.get_task(task_class) + variation_count = task_env.variation_count() + rng = np.random.default_rng(0) + + for episode_idx in range(args.episodes): + task_env = env.get_task(task_class) + variation = int(rng.integers(variation_count)) + task_env.set_variation(variation) + descriptions, _ = task_env.reset() + (demo,) = task_env.get_demos(amount=1, live_demos=True) + episode_path = episodes_root / (EPISODE_FOLDER % episode_idx) + episode_path.mkdir(parents=True, exist_ok=True) + _save_demo(demo, episode_path, list(spec.cameras)) + with (episode_path / VARIATION_NUMBER).open("wb") as handle: + pickle.dump(variation, handle) + with (episode_path / VARIATION_DESCRIPTIONS).open("wb") as handle: + pickle.dump(descriptions, handle) + print( + f"[done] wrote {args.task} episode {episode_idx} variation {variation} to {episode_path}", + flush=True, + ) + finally: + env.shutdown() + + +if __name__ == "__main__": + main() diff --git a/code/reveal_vla_bimanual/sim_rlbench/obs_adapter.py b/code/reveal_vla_bimanual/sim_rlbench/obs_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..9342089c10f2082fce71acafeab159e60e2b73ad --- /dev/null +++ b/code/reveal_vla_bimanual/sim_rlbench/obs_adapter.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from sim_rlbench.camera_spec import RLBenchThreeCameraSpec + + +@dataclass +class CanonicalBimanualObservation: + rgb: dict[str, np.ndarray] + proprio: np.ndarray + 
language_goal: str + camera_intrinsics: dict[str, np.ndarray] + camera_extrinsics: dict[str, np.ndarray] + point_cloud: dict[str, np.ndarray] | None = None + + def as_batch_dict(self) -> dict[str, Any]: + batch = { + "rgb": self.rgb, + "proprio": self.proprio.astype(np.float32), + "language_goal": self.language_goal, + "camera_intrinsics": self.camera_intrinsics, + "camera_extrinsics": self.camera_extrinsics, + } + if self.point_cloud is not None: + batch["point_cloud"] = self.point_cloud + return batch + + +def _camera_rgb(obs: Any, camera_name: str) -> np.ndarray: + value = obs.perception_data[f"{camera_name}_rgb"] + return np.asarray(value, dtype=np.uint8) + + +def _camera_point_cloud(obs: Any, camera_name: str) -> np.ndarray: + value = obs.perception_data[f"{camera_name}_point_cloud"] + return np.asarray(value, dtype=np.float32) + + +def _bimanual_proprio(obs: Any, timestep: int | None = None, episode_length: int | None = None) -> np.ndarray: + right = np.asarray(obs.get_low_dim_data(obs.right), dtype=np.float32) + left = np.asarray(obs.get_low_dim_data(obs.left), dtype=np.float32) + proprio = np.concatenate([right, left], axis=0) + if timestep is not None and episode_length and episode_length > 1: + time_feature = np.array( + [(1.0 - (timestep / float(episode_length - 1))) * 2.0 - 1.0], + dtype=np.float32, + ) + proprio = np.concatenate([proprio, time_feature], axis=0) + return proprio + + +def extract_canonical_bimanual_obs( + obs: Any, + language_goal: str, + camera_spec: RLBenchThreeCameraSpec | None = None, + include_point_cloud: bool = False, + timestep: int | None = None, + episode_length: int | None = None, +) -> CanonicalBimanualObservation: + camera_spec = camera_spec or RLBenchThreeCameraSpec() + rgb = {camera: _camera_rgb(obs, camera) for camera in camera_spec.cameras} + intrinsics = { + camera: np.asarray(obs.misc[f"{camera}_camera_intrinsics"], dtype=np.float32) + for camera in camera_spec.cameras + } + extrinsics = { + camera: 
np.asarray(obs.misc[f"{camera}_camera_extrinsics"], dtype=np.float32) + for camera in camera_spec.cameras + } + point_cloud = None + if include_point_cloud: + point_cloud = { + camera: _camera_point_cloud(obs, camera) for camera in camera_spec.cameras + } + return CanonicalBimanualObservation( + rgb=rgb, + proprio=_bimanual_proprio(obs, timestep=timestep, episode_length=episode_length), + language_goal=language_goal, + camera_intrinsics=intrinsics, + camera_extrinsics=extrinsics, + point_cloud=point_cloud, + ) diff --git a/code/reveal_vla_bimanual/sim_rlbench/peract2_runner.py b/code/reveal_vla_bimanual/sim_rlbench/peract2_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..b7545a0d5cd77722581489beeb0e86cb042b570e --- /dev/null +++ b/code/reveal_vla_bimanual/sim_rlbench/peract2_runner.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from dataclasses import dataclass, field +from pathlib import Path + +from sim_rlbench.camera_spec import RLBenchThreeCameraSpec, default_three_camera_spec +from sim_rlbench.task_splits import PERACT2_BIMANUAL_TASKS + + +def _hydra_list(values: tuple[str, ...] 
| list[str]) -> str: + return "[" + ",".join(values) + "]" + + +def _default_nvidia_shim_root() -> Path | None: + try: + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"], + text=True, + ).strip() + except (FileNotFoundError, subprocess.SubprocessError): + return None + if not output: + return None + major = output.split(".", maxsplit=1)[0] + candidate = Path(f"/workspace/system_shims/nvidia{major}") + return candidate if candidate.exists() else None + + +@dataclass +class BenchmarkRunSpec: + upstream_root: Path = Path("/workspace/third_party/peract_bimanual") + demo_path: Path = Path("/workspace/data/rlbench2") + replay_path: Path = Path("/workspace/replays/rlbench2") + logdir: Path = Path("/workspace/logs/rlbench2") + method: str = "BIMANUAL_PERACT" + tasks: tuple[str, ...] = field(default_factory=lambda: PERACT2_BIMANUAL_TASKS) + demos: int = 100 + training_iterations: int = 40000 + seed: int = 0 + gpu: int = 0 + display: str = ":99" + coppeliasim_root: Path = Path("/workspace/assets/coppeliasim_v4_1_0") + camera_spec: RLBenchThreeCameraSpec = field(default_factory=default_three_camera_spec) + + def common_overrides(self) -> list[str]: + task_name = "multi_3cam" if len(self.tasks) > 1 else self.tasks[0] + overrides = [ + f"method={self.method}", + f"rlbench.task_name={task_name}", + f"rlbench.tasks={_hydra_list(list(self.tasks))}", + f"rlbench.demos={self.demos}", + f"rlbench.demo_path={self.demo_path}", + f"replay.path={self.replay_path}", + f"framework.logdir={self.logdir}", + f"framework.training_iterations={self.training_iterations}", + f"framework.gpu={self.gpu}", + f"framework.env_gpu={self.gpu}", + f"framework.start_seed={self.seed}", + "ddp.num_devices=1", + ] + overrides.extend(self.camera_spec.hydra_overrides()) + return overrides + + def train_command(self, python_executable: str | None = None) -> list[str]: + python_executable = python_executable or sys.executable + return [python_executable, 
"train.py", *self.common_overrides()] + + def eval_command( + self, + checkpoint_root: Path, + episodes: int = 25, + save_videos: bool = False, + python_executable: str | None = None, + ) -> list[str]: + python_executable = python_executable or sys.executable + return [ + python_executable, + "eval.py", + f"method={self.method}", + f"rlbench.demo_path={self.demo_path}", + f"framework.logdir={checkpoint_root}", + f"eval_episodes={episodes}", + f"cinematic_recorder.enabled={str(save_videos)}", + *self.camera_spec.hydra_overrides(), + ] + + def env(self) -> dict[str, str]: + env = os.environ.copy() + env.setdefault("PYTHONUNBUFFERED", "1") + env.setdefault("DISPLAY", self.display) + env.setdefault("QT_QPA_PLATFORM", "xcb") + if self.coppeliasim_root.exists(): + env.setdefault("COPPELIASIM_ROOT", str(self.coppeliasim_root)) + env.setdefault("QT_QPA_PLATFORM_PLUGIN_PATH", str(self.coppeliasim_root)) + + ld_paths: list[str] = [] + shim_root = _default_nvidia_shim_root() + if shim_root is not None: + ld_paths.extend( + [ + str(shim_root / "usr/lib/x86_64-linux-gnu"), + str(shim_root / "usr/lib/x86_64-linux-gnu/nvidia"), + ] + ) + if self.coppeliasim_root.exists(): + ld_paths.append(str(self.coppeliasim_root)) + existing_ld = env.get("LD_LIBRARY_PATH") + if existing_ld: + ld_paths.append(existing_ld) + if ld_paths: + env["LD_LIBRARY_PATH"] = ":".join(ld_paths) + return env + + def run_train(self) -> subprocess.CompletedProcess[bytes]: + return subprocess.run( + self.train_command(), + cwd=self.upstream_root, + env=self.env(), + check=True, + ) diff --git a/code/reveal_vla_bimanual/sim_rlbench/smoke_test.py b/code/reveal_vla_bimanual/sim_rlbench/smoke_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fd57a836889cc022c9785c9b264fb9ba8397a770 --- /dev/null +++ b/code/reveal_vla_bimanual/sim_rlbench/smoke_test.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from 
sim_rlbench.camera_spec import default_three_camera_spec +from sim_rlbench.peract2_runner import BenchmarkRunSpec + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--demo-path", default="/workspace/data/rlbench2") + parser.add_argument("--upstream-root", default="/workspace/third_party/peract_bimanual") + parser.add_argument("--print-train-command", action="store_true") + args = parser.parse_args() + + spec = default_three_camera_spec(224) + payload = { + "camera_names": list(spec.cameras), + "resolution": list(spec.resolution), + "global_camera": spec.global_camera, + } + + import_status = {} + for module_name in ("rlbench", "pyrep", "yarr"): + try: + __import__(module_name) + import_status[module_name] = "ok" + except Exception as exc: # pragma: no cover - smoke path + import_status[module_name] = f"error: {exc}" + + payload["imports"] = import_status + print(json.dumps(payload, indent=2)) + + if args.print_train_command: + run_spec = BenchmarkRunSpec( + upstream_root=Path(args.upstream_root), + demo_path=Path(args.demo_path), + camera_spec=spec, + ) + print(" ".join(run_spec.train_command())) + + +if __name__ == "__main__": + main() diff --git a/code/reveal_vla_bimanual/sim_rlbench/task_splits.py b/code/reveal_vla_bimanual/sim_rlbench/task_splits.py new file mode 100644 index 0000000000000000000000000000000000000000..4526c494da939599862ca114cef53e176b925d46 --- /dev/null +++ b/code/reveal_vla_bimanual/sim_rlbench/task_splits.py @@ -0,0 +1,18 @@ +PERACT2_BIMANUAL_TASKS: tuple[str, ...] = ( + "bimanual_push_box", + "bimanual_lift_ball", + "bimanual_dual_push_buttons", + "bimanual_pick_plate", + "bimanual_put_item_in_drawer", + "bimanual_put_bottle_in_fridge", + "bimanual_handover_item", + "bimanual_pick_laptop", + "bimanual_straighten_rope", + "bimanual_sweep_to_dustpan", + "bimanual_lift_tray", + "bimanual_handover_item_easy", + "bimanual_take_tray_out_of_oven", +) + + +ANYBIMANUAL_OPTIONAL_TASKS: tuple[str, ...] 
= PERACT2_BIMANUAL_TASKS diff --git a/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-310.pyc b/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1524e0379a5c8e15813a3d873dc9005496091d0c Binary files /dev/null and b/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-311.pyc b/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf6eba138d570e48199427459481d843ae3035df Binary files /dev/null and b/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/train/__pycache__/losses.cpython-310.pyc b/code/reveal_vla_bimanual/train/__pycache__/losses.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa72fc9c325967ef3d797dab673fd30e412bbc34 Binary files /dev/null and b/code/reveal_vla_bimanual/train/__pycache__/losses.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/train/__pycache__/losses.cpython-311.pyc b/code/reveal_vla_bimanual/train/__pycache__/losses.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b3511e6d1b2c3c7eb94a5adb0c103c560f61304 Binary files /dev/null and b/code/reveal_vla_bimanual/train/__pycache__/losses.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/train/__pycache__/run_experiment.cpython-310.pyc b/code/reveal_vla_bimanual/train/__pycache__/run_experiment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb9c9974821c826bfe9ff9414b91d821f25bc04 Binary files /dev/null and b/code/reveal_vla_bimanual/train/__pycache__/run_experiment.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/train/__pycache__/run_experiment.cpython-311.pyc 
b/code/reveal_vla_bimanual/train/__pycache__/run_experiment.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d23fc1661217dce4f1c55eb18c351ba8745c2338 Binary files /dev/null and b/code/reveal_vla_bimanual/train/__pycache__/run_experiment.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/train/__pycache__/run_rlbench_experiment.cpython-310.pyc b/code/reveal_vla_bimanual/train/__pycache__/run_rlbench_experiment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee6c4541669fe741e3fd82f315136ded2aca5405 Binary files /dev/null and b/code/reveal_vla_bimanual/train/__pycache__/run_rlbench_experiment.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/train/__pycache__/trainer.cpython-310.pyc b/code/reveal_vla_bimanual/train/__pycache__/trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0f571eee234935a4feeaa1ee06affda87eb3b90 Binary files /dev/null and b/code/reveal_vla_bimanual/train/__pycache__/trainer.cpython-310.pyc differ diff --git a/code/reveal_vla_bimanual/train/__pycache__/trainer.cpython-311.pyc b/code/reveal_vla_bimanual/train/__pycache__/trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e22b9b1d79888917f07fcf19c20daf311aa7bead Binary files /dev/null and b/code/reveal_vla_bimanual/train/__pycache__/trainer.cpython-311.pyc differ diff --git a/code/reveal_vla_bimanual/train/configs/data/reveal_proxies.yaml b/code/reveal_vla_bimanual/train/configs/data/reveal_proxies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cab33bea8ad5a44f936ff4d01d8ddc0ab0a4361c --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/data/reveal_proxies.yaml @@ -0,0 +1,9 @@ +dataset: + name: reveal_proxies + proxies: [foliage_proxy, bag_proxy, cloth_proxy] + image_resolution: 224 + chunk_size: 8 + action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 diff 
--git a/code/reveal_vla_bimanual/train/configs/data/rlbench_3cam.yaml b/code/reveal_vla_bimanual/train/configs/data/rlbench_3cam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d6a2a4c1ec742dccc6b4bb28dbc3b5178e08187b --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/data/rlbench_3cam.yaml @@ -0,0 +1,8 @@ +dataset: + name: rlbench_3cam + cameras: [front, wrist_left, wrist_right] + image_resolution: 224 + chunk_size: 8 + action_dim: 14 + include_point_cloud_for_peract2: true + task_split: peract2 diff --git a/code/reveal_vla_bimanual/train/configs/model/backbone_only.yaml b/code/reveal_vla_bimanual/train/configs/model/backbone_only.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b0ef785494d7633e52f4d362f2aeb6804225f532 --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/model/backbone_only.yaml @@ -0,0 +1,26 @@ +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 512 + max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: true + use_dummy_backbone: false + fusion: + hidden_dim: 512 + num_cameras: 3 + num_layers: 4 + num_heads: 8 + ff_dim: 2048 + dropout: 0.1 + proprio_dim: 32 + proprio_tokens: 1 + decoder: + hidden_dim: 512 + num_heads: 8 + num_layers: 4 + ff_dim: 2048 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 diff --git a/code/reveal_vla_bimanual/train/configs/model/reveal_state.yaml b/code/reveal_vla_bimanual/train/configs/model/reveal_state.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d93f14efe43a8e05402e9ac128d743d0f775918 --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/model/reveal_state.yaml @@ -0,0 +1,47 @@ +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 512 + max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: true + use_dummy_backbone: false + fusion: + hidden_dim: 512 + num_cameras: 3 + num_layers: 4 + num_heads: 8 + ff_dim: 2048 + 
dropout: 0.1 + proprio_dim: 32 + proprio_tokens: 1 + decoder: + hidden_dim: 512 + num_heads: 8 + num_layers: 4 + ff_dim: 2048 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 + reveal_head: + hidden_dim: 512 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + belief_map_size: 32 + predict_belief_map: false + world_model: + hidden_dim: 512 + action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + planner: + num_candidates: 8 + corridor_weight: 1.0 + persistence_weight: 0.5 + proposal_weight: 0.5 + disturbance_weight: 0.75 + reocclusion_weight: 0.5 + visibility_weight: 0.25 diff --git a/code/reveal_vla_bimanual/train/configs/proxy_backbone_only.yaml b/code/reveal_vla_bimanual/train/configs/proxy_backbone_only.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8b67aa367dc1847f8ffddb2c31d2634905440a59 --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/proxy_backbone_only.yaml @@ -0,0 +1,87 @@ +experiment_name: proxy_backbone_only +output_dir: /workspace/outputs/reveal_runs +device: cuda +seed: 7 + +data: + proxies: [foliage_proxy, bag_proxy, cloth_proxy] + resolution: 96 + train_episodes_per_proxy: 48 + val_episodes_per_proxy: 16 + train_dataset_path: /workspace/data/reveal_proxy/proxy_train_v2.pt + val_dataset_path: /workspace/data/reveal_proxy/proxy_val_v2.pt + rebuild_dataset: true + chunk_horizon: 8 + rollout_horizon: 5 + seed: 7 + +optim: + epochs: 8 + batch_size: 16 + num_workers: 0 + lr: 0.001 + weight_decay: 0.0001 + +trainer: + policy_type: backbone_only + use_bf16: true + grad_clip_norm: 1.0 + freeze_backbone: true + gradient_checkpointing: false + +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 128 + max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: false + use_dummy_backbone: true + fusion: + hidden_dim: 128 + num_cameras: 3 + num_layers: 2 + num_heads: 4 + ff_dim: 256 + dropout: 0.1 + proprio_dim: 32 + 
proprio_tokens: 1 + decoder: + hidden_dim: 128 + num_heads: 4 + num_layers: 2 + ff_dim: 256 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 + reveal_head: + hidden_dim: 128 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + belief_map_size: 32 + predict_belief_map: true + world_model: + hidden_dim: 128 + action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + planner: + num_candidates: 8 + corridor_weight: 1.0 + persistence_weight: 0.5 + proposal_weight: 0.5 + disturbance_weight: 0.75 + reocclusion_weight: 0.5 + visibility_weight: 0.25 + +loss_weights: + action: 1.0 + support_mode: 0.1 + corridor: 0.1 + persistence: 0.05 + disturbance: 0.05 + world_model: 0.1 + belief: 0.05 diff --git a/code/reveal_vla_bimanual/train/configs/proxy_backbone_only_clip.yaml b/code/reveal_vla_bimanual/train/configs/proxy_backbone_only_clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c9be7b0554012a5f63070edf882fe65c367e2515 --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/proxy_backbone_only_clip.yaml @@ -0,0 +1,90 @@ +experiment_name: proxy_backbone_only_clip +output_dir: /workspace/outputs/reveal_runs +device: cuda +seed: 7 + +data: + proxies: [foliage_proxy, bag_proxy, cloth_proxy] + resolution: 224 + train_episodes_per_proxy: 48 + val_episodes_per_proxy: 16 + train_dataset_path: /workspace/data/reveal_proxy/proxy_train_clip224.pt + val_dataset_path: /workspace/data/reveal_proxy/proxy_val_clip224.pt + rebuild_dataset: true + chunk_horizon: 8 + rollout_horizon: 5 + seed: 7 + +optim: + epochs: 4 + batch_size: 2 + num_workers: 0 + lr: 0.0003 + weight_decay: 0.0001 + +trainer: + policy_type: backbone_only + use_bf16: true + grad_clip_norm: 1.0 + freeze_backbone: true + gradient_checkpointing: false + plan_during_train: false + plan_during_eval: false + support_mode_conditioning: true + +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 512 + 
max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: false + use_dummy_backbone: false + fusion: + hidden_dim: 512 + num_cameras: 3 + num_layers: 4 + num_heads: 8 + ff_dim: 2048 + dropout: 0.1 + proprio_dim: 32 + proprio_tokens: 1 + decoder: + hidden_dim: 512 + num_heads: 8 + num_layers: 4 + ff_dim: 2048 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 + reveal_head: + hidden_dim: 512 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + belief_map_size: 32 + predict_belief_map: true + world_model: + hidden_dim: 512 + action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + planner: + num_candidates: 8 + corridor_weight: 1.0 + persistence_weight: 0.5 + proposal_weight: 0.5 + disturbance_weight: 0.75 + reocclusion_weight: 0.5 + visibility_weight: 0.25 + +loss_weights: + action: 1.0 + support_mode: 0.1 + corridor: 0.1 + persistence: 0.05 + disturbance: 0.05 + world_model: 0.1 + belief: 0.05 diff --git a/code/reveal_vla_bimanual/train/configs/proxy_reveal_state.yaml b/code/reveal_vla_bimanual/train/configs/proxy_reveal_state.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb79c1dba1905c1e65cfc585b4bc9d1659e9fd36 --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/proxy_reveal_state.yaml @@ -0,0 +1,87 @@ +experiment_name: proxy_reveal_state +output_dir: /workspace/outputs/reveal_runs +device: cuda +seed: 7 + +data: + proxies: [foliage_proxy, bag_proxy, cloth_proxy] + resolution: 96 + train_episodes_per_proxy: 48 + val_episodes_per_proxy: 16 + train_dataset_path: /workspace/data/reveal_proxy/proxy_train_v2.pt + val_dataset_path: /workspace/data/reveal_proxy/proxy_val_v2.pt + rebuild_dataset: false + chunk_horizon: 8 + rollout_horizon: 5 + seed: 7 + +optim: + epochs: 8 + batch_size: 16 + num_workers: 0 + lr: 0.001 + weight_decay: 0.0001 + +trainer: + policy_type: reveal_state + use_bf16: true + grad_clip_norm: 1.0 + freeze_backbone: true + 
gradient_checkpointing: false + +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 128 + max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: false + use_dummy_backbone: true + fusion: + hidden_dim: 128 + num_cameras: 3 + num_layers: 2 + num_heads: 4 + ff_dim: 256 + dropout: 0.1 + proprio_dim: 32 + proprio_tokens: 1 + decoder: + hidden_dim: 128 + num_heads: 4 + num_layers: 2 + ff_dim: 256 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 + reveal_head: + hidden_dim: 128 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + belief_map_size: 32 + predict_belief_map: true + world_model: + hidden_dim: 128 + action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + planner: + num_candidates: 8 + corridor_weight: 1.0 + persistence_weight: 0.65 + proposal_weight: 0.35 + disturbance_weight: 0.8 + reocclusion_weight: 0.6 + visibility_weight: 0.35 + +loss_weights: + action: 1.0 + support_mode: 0.15 + corridor: 0.2 + persistence: 0.1 + disturbance: 0.1 + world_model: 0.2 + belief: 0.05 diff --git a/code/reveal_vla_bimanual/train/configs/proxy_reveal_state_clip.yaml b/code/reveal_vla_bimanual/train/configs/proxy_reveal_state_clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2aee557e132dea5d368366f26a797c197f47a09e --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/proxy_reveal_state_clip.yaml @@ -0,0 +1,90 @@ +experiment_name: proxy_reveal_state_clip +output_dir: /workspace/outputs/reveal_runs +device: cuda +seed: 7 + +data: + proxies: [foliage_proxy, bag_proxy, cloth_proxy] + resolution: 224 + train_episodes_per_proxy: 48 + val_episodes_per_proxy: 16 + train_dataset_path: /workspace/data/reveal_proxy/proxy_train_clip224.pt + val_dataset_path: /workspace/data/reveal_proxy/proxy_val_clip224.pt + rebuild_dataset: false + chunk_horizon: 8 + rollout_horizon: 5 + seed: 7 + +optim: + epochs: 4 + batch_size: 2 + num_workers: 0 + lr: 
0.0003 + weight_decay: 0.0001 + +trainer: + policy_type: reveal_state + use_bf16: true + grad_clip_norm: 1.0 + freeze_backbone: true + gradient_checkpointing: false + plan_during_train: true + plan_during_eval: true + support_mode_conditioning: true + +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 512 + max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: false + use_dummy_backbone: false + fusion: + hidden_dim: 512 + num_cameras: 3 + num_layers: 4 + num_heads: 8 + ff_dim: 2048 + dropout: 0.1 + proprio_dim: 32 + proprio_tokens: 1 + decoder: + hidden_dim: 512 + num_heads: 8 + num_layers: 4 + ff_dim: 2048 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 + reveal_head: + hidden_dim: 512 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + belief_map_size: 32 + predict_belief_map: true + world_model: + hidden_dim: 512 + action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + planner: + num_candidates: 8 + corridor_weight: 1.0 + persistence_weight: 0.65 + proposal_weight: 0.35 + disturbance_weight: 0.8 + reocclusion_weight: 0.6 + visibility_weight: 0.35 + +loss_weights: + action: 1.0 + support_mode: 0.15 + corridor: 0.2 + persistence: 0.1 + disturbance: 0.1 + world_model: 0.2 + belief: 0.05 diff --git a/code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_clip.yaml b/code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38bcc370da1cad88ecd18b8399190d92d13d0e86 --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_clip.yaml @@ -0,0 +1,89 @@ +experiment_name: rlbench_subset3_backbone_only_clip +output_dir: /workspace/outputs/rlbench_custom +device: cuda +seed: 7 +init_checkpoint: /workspace/outputs/reveal_runs/proxy_backbone_only_clip/checkpoint_best.pt +init_strict: false + +data: + dataset_root: 
/workspace/data/rlbench2 + tasks: [bimanual_lift_ball, bimanual_push_box, bimanual_dual_push_buttons] + train_episodes: [0] + val_episodes: [1] + resolution: 224 + chunk_horizon: 8 + proprio_dim: 32 + +optim: + epochs: 2 + batch_size: 2 + num_workers: 0 + lr: 0.0002 + weight_decay: 0.0001 + +trainer: + policy_type: backbone_only + use_bf16: true + grad_clip_norm: 1.0 + freeze_backbone: true + gradient_checkpointing: false + plan_during_train: false + plan_during_eval: false + support_mode_conditioning: true + +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 512 + max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: false + use_dummy_backbone: false + fusion: + hidden_dim: 512 + num_cameras: 3 + num_layers: 4 + num_heads: 8 + ff_dim: 2048 + dropout: 0.1 + proprio_dim: 32 + proprio_tokens: 1 + decoder: + hidden_dim: 512 + num_heads: 8 + num_layers: 4 + ff_dim: 2048 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 + reveal_head: + hidden_dim: 512 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + belief_map_size: 32 + predict_belief_map: true + world_model: + hidden_dim: 512 + action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + planner: + num_candidates: 8 + corridor_weight: 1.0 + persistence_weight: 0.5 + proposal_weight: 0.5 + disturbance_weight: 0.75 + reocclusion_weight: 0.5 + visibility_weight: 0.25 + +loss_weights: + action: 1.0 + support_mode: 0.1 + corridor: 0.1 + persistence: 0.05 + disturbance: 0.05 + world_model: 0.1 + belief: 0.05 diff --git a/code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_dummy.yaml b/code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_dummy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..31ad915f0716a77a9be572803b8431295cacd7fa --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_dummy.yaml @@ -0,0 +1,89 @@ 
+experiment_name: rlbench_subset3_backbone_only_dummy +output_dir: /workspace/outputs/rlbench_custom +device: cuda +seed: 7 +init_checkpoint: /workspace/outputs/reveal_runs/proxy_backbone_only/checkpoint_best.pt +init_strict: false + +data: + dataset_root: /workspace/data/rlbench2 + tasks: [bimanual_lift_ball, bimanual_push_box, bimanual_dual_push_buttons] + train_episodes: [0] + val_episodes: [1] + resolution: 224 + chunk_horizon: 8 + proprio_dim: 32 + +optim: + epochs: 2 + batch_size: 4 + num_workers: 0 + lr: 0.0005 + weight_decay: 0.0001 + +trainer: + policy_type: backbone_only + use_bf16: true + grad_clip_norm: 1.0 + freeze_backbone: true + gradient_checkpointing: false + plan_during_train: false + plan_during_eval: false + support_mode_conditioning: true + +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 128 + max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: false + use_dummy_backbone: true + fusion: + hidden_dim: 128 + num_cameras: 3 + num_layers: 2 + num_heads: 4 + ff_dim: 256 + dropout: 0.1 + proprio_dim: 32 + proprio_tokens: 1 + decoder: + hidden_dim: 128 + num_heads: 4 + num_layers: 2 + ff_dim: 256 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 + reveal_head: + hidden_dim: 128 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + belief_map_size: 32 + predict_belief_map: true + world_model: + hidden_dim: 128 + action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + planner: + num_candidates: 8 + corridor_weight: 1.0 + persistence_weight: 0.5 + proposal_weight: 0.5 + disturbance_weight: 0.75 + reocclusion_weight: 0.5 + visibility_weight: 0.25 + +loss_weights: + action: 1.0 + support_mode: 0.1 + corridor: 0.1 + persistence: 0.05 + disturbance: 0.05 + world_model: 0.1 + belief: 0.05 diff --git a/code/reveal_vla_bimanual/train/configs/rlbench_subset3_reveal_state_clip.yaml 
b/code/reveal_vla_bimanual/train/configs/rlbench_subset3_reveal_state_clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7cd4f11b6689c83cf0d540db1a15ff88842dcdf9 --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/rlbench_subset3_reveal_state_clip.yaml @@ -0,0 +1,89 @@ +experiment_name: rlbench_subset3_reveal_state_clip +output_dir: /workspace/outputs/rlbench_custom +device: cuda +seed: 7 +init_checkpoint: /workspace/outputs/reveal_runs/proxy_reveal_state_clip/checkpoint_best.pt +init_strict: false + +data: + dataset_root: /workspace/data/rlbench2 + tasks: [bimanual_lift_ball, bimanual_push_box, bimanual_dual_push_buttons] + train_episodes: [0] + val_episodes: [1] + resolution: 224 + chunk_horizon: 8 + proprio_dim: 32 + +optim: + epochs: 2 + batch_size: 2 + num_workers: 0 + lr: 0.0002 + weight_decay: 0.0001 + +trainer: + policy_type: reveal_state + use_bf16: true + grad_clip_norm: 1.0 + freeze_backbone: true + gradient_checkpointing: false + plan_during_train: false + plan_during_eval: false + support_mode_conditioning: true + +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 512 + max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: false + use_dummy_backbone: false + fusion: + hidden_dim: 512 + num_cameras: 3 + num_layers: 4 + num_heads: 8 + ff_dim: 2048 + dropout: 0.1 + proprio_dim: 32 + proprio_tokens: 1 + decoder: + hidden_dim: 512 + num_heads: 8 + num_layers: 4 + ff_dim: 2048 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 + reveal_head: + hidden_dim: 512 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + belief_map_size: 32 + predict_belief_map: true + world_model: + hidden_dim: 512 + action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + planner: + num_candidates: 8 + corridor_weight: 1.0 + persistence_weight: 0.65 + proposal_weight: 0.35 + disturbance_weight: 0.8 + reocclusion_weight: 0.6 + 
visibility_weight: 0.35 + +loss_weights: + action: 1.0 + support_mode: 0.15 + corridor: 0.2 + persistence: 0.1 + disturbance: 0.1 + world_model: 0.2 + belief: 0.05 diff --git a/code/reveal_vla_bimanual/train/configs/rlbench_subset3_reveal_state_dummy.yaml b/code/reveal_vla_bimanual/train/configs/rlbench_subset3_reveal_state_dummy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..02eb05b5f28024b4b9b46e6d53b8f139b0d74414 --- /dev/null +++ b/code/reveal_vla_bimanual/train/configs/rlbench_subset3_reveal_state_dummy.yaml @@ -0,0 +1,89 @@ +experiment_name: rlbench_subset3_reveal_state_dummy +output_dir: /workspace/outputs/rlbench_custom +device: cuda +seed: 7 +init_checkpoint: /workspace/outputs/reveal_runs/proxy_reveal_state/checkpoint_best.pt +init_strict: false + +data: + dataset_root: /workspace/data/rlbench2 + tasks: [bimanual_lift_ball, bimanual_push_box, bimanual_dual_push_buttons] + train_episodes: [0] + val_episodes: [1] + resolution: 224 + chunk_horizon: 8 + proprio_dim: 32 + +optim: + epochs: 2 + batch_size: 4 + num_workers: 0 + lr: 0.0005 + weight_decay: 0.0001 + +trainer: + policy_type: reveal_state + use_bf16: true + grad_clip_norm: 1.0 + freeze_backbone: true + gradient_checkpointing: false + plan_during_train: false + plan_during_eval: false + support_mode_conditioning: true + +policy: + backbone: + model_name: openai/clip-vit-base-patch32 + hidden_dim: 128 + max_text_tokens: 32 + freeze_backbone: true + gradient_checkpointing: false + use_dummy_backbone: true + fusion: + hidden_dim: 128 + num_cameras: 3 + num_layers: 2 + num_heads: 4 + ff_dim: 256 + dropout: 0.1 + proprio_dim: 32 + proprio_tokens: 1 + decoder: + hidden_dim: 128 + num_heads: 4 + num_layers: 2 + ff_dim: 256 + dropout: 0.1 + chunk_size: 8 + action_dim: 14 + num_candidates: 8 + reveal_head: + hidden_dim: 128 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + belief_map_size: 32 + predict_belief_map: true + world_model: + hidden_dim: 128 + 
action_dim: 14 + num_support_modes: 3 + num_approach_templates: 32 + rollout_horizon: 5 + planner: + num_candidates: 8 + corridor_weight: 1.0 + persistence_weight: 0.65 + proposal_weight: 0.35 + disturbance_weight: 0.8 + reocclusion_weight: 0.6 + visibility_weight: 0.35 + +loss_weights: + action: 1.0 + support_mode: 0.15 + corridor: 0.2 + persistence: 0.1 + disturbance: 0.1 + world_model: 0.2 + belief: 0.05 diff --git a/code/upstream_local_patches/YARR/yarr/runners/_independent_env_runner.py b/code/upstream_local_patches/YARR/yarr/runners/_independent_env_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..521c8cd8e565a0c072fdcf0bc85af9d38a3e54cf --- /dev/null +++ b/code/upstream_local_patches/YARR/yarr/runners/_independent_env_runner.py @@ -0,0 +1,303 @@ +import copy +import logging +import os +import time +import pandas as pd + +from multiprocessing import Process, Manager +from multiprocessing import get_start_method, set_start_method +from typing import Any + +import numpy as np +import torch +from yarr.agents.agent import Agent +from yarr.agents.agent import ScalarSummary +from yarr.agents.agent import Summary +from yarr.envs.env import Env +from yarr.utils.rollout_generator import RolloutGenerator +from yarr.utils.log_writer import LogWriter +from yarr.utils.process_str import change_case +from yarr.utils.video_utils import CircleCameraMotion, TaskRecorder + +from pyrep.objects.dummy import Dummy +from pyrep.objects.vision_sensor import VisionSensor + +from yarr.runners._env_runner import _EnvRunner + + +class _IndependentEnvRunner(_EnvRunner): + + def __init__(self, + train_env: Env, + eval_env: Env, + agent: Agent, + timesteps: int, + train_envs: int, + eval_envs: int, + rollout_episodes: int, + eval_episodes: int, + training_iterations: int, + eval_from_eps_number: int, + episode_length: int, + kill_signal: Any, + step_signal: Any, + num_eval_episodes_signal: Any, + eval_epochs_signal: Any, + eval_report_signal: Any, + log_freq: 
int, + rollout_generator: RolloutGenerator, + save_load_lock, + current_replay_ratio, + target_replay_ratio, + weightsdir: str = None, + logdir: str = None, + env_device: torch.device = None, + previous_loaded_weight_folder: str = '', + num_eval_runs: int = 1, + ): + + super().__init__(train_env, eval_env, agent, timesteps, + train_envs, eval_envs, rollout_episodes, eval_episodes, + training_iterations, eval_from_eps_number, episode_length, + kill_signal, step_signal, num_eval_episodes_signal, + eval_epochs_signal, eval_report_signal, log_freq, + rollout_generator, save_load_lock, current_replay_ratio, + target_replay_ratio, weightsdir, logdir, env_device, + previous_loaded_weight_folder, num_eval_runs) + + def _load_save(self): + if self._weightsdir is None: + logging.info("'weightsdir' was None, so not loading weights.") + return + while True: + weight_folders = [] + with self._save_load_lock: + if os.path.exists(self._weightsdir): + weight_folders = os.listdir(self._weightsdir) + if len(weight_folders) > 0: + weight_folders = sorted(map(int, weight_folders)) + # only load if there has been a new weight saving + if self._previous_loaded_weight_folder != weight_folders[-1]: + self._previous_loaded_weight_folder = weight_folders[-1] + d = os.path.join(self._weightsdir, str(weight_folders[-1])) + try: + self._agent.load_weights(d) + except FileNotFoundError: + # rare case when agent hasn't finished writing. 
+ time.sleep(1) + self._agent.load_weights(d) + logging.info('Agent %s: Loaded weights: %s' % (self._name, d)) + self._new_weights = True + else: + self._new_weights = False + break + logging.info('Waiting for weights to become available.') + time.sleep(1) + + def _get_task_name(self): + if hasattr(self._eval_env, '_task_class'): + eval_task_name = change_case(self._eval_env._task_class.__name__) + multi_task = False + elif hasattr(self._eval_env, '_task_classes'): + if self._eval_env.active_task_id != -1: + task_id = (self._eval_env.active_task_id) % len(self._eval_env._task_classes) + eval_task_name = change_case(self._eval_env._task_classes[task_id].__name__) + else: + eval_task_name = '' + multi_task = True + else: + raise Exception('Neither task_class nor task_classes found in eval env') + return eval_task_name, multi_task + + def _run_eval_independent(self, name: str, + stats_accumulator, + weight, + writer_lock, + eval=True, + device_idx=0, + save_metrics=True, + cinematic_recorder_cfg=None): + + self._name = name + self._save_metrics = save_metrics + self._is_test_set = type(weight) == dict + + self._agent = copy.deepcopy(self._agent) + + device = torch.device('cuda:%d' % device_idx) if torch.cuda.device_count() > 1 else torch.device('cuda:0') + with writer_lock: # hack to prevent multiple CLIP downloads ... argh should use a separate lock + self._agent.build(training=False, device=device) + + logging.info('%s: Launching env.' 
% name) + np.random.seed() + + logging.info('Agent information:') + logging.info(self._agent) + + env = self._eval_env + env.eval = eval + env.launch() + + # initialize cinematic recorder if specified + rec_cfg = cinematic_recorder_cfg + if rec_cfg.enabled: + cam_placeholder = Dummy('cam_cinematic_placeholder') + cam = VisionSensor.create(rec_cfg.camera_resolution) + cam.set_pose(cam_placeholder.get_pose()) + cam.set_parent(cam_placeholder) + + cam_motion = CircleCameraMotion(cam, Dummy('cam_cinematic_base'), rec_cfg.rotate_speed) + tr = TaskRecorder(env, cam_motion, fps=rec_cfg.fps) + + env.env._action_mode.arm_action_mode.set_callable_each_step(tr.take_snap) + + if not os.path.exists(self._weightsdir): + raise Exception('No weights directory found.') + + # to save or not to save evaluation metrics (set as False for recording videos) + if self._save_metrics: + csv_file = 'eval_data.csv' if not self._is_test_set else 'test_data.csv' + writer = LogWriter(self._logdir, True, True, + env_csv=csv_file) + + # one weight for all tasks (used for validation) + if type(weight) == int: + logging.info('Evaluating weight %s' % weight) + weight_path = os.path.join(self._weightsdir, str(weight)) + seed_path = self._weightsdir.replace('/weights', '') + self._agent.load_weights(weight_path) + weight_name = str(weight) + + new_transitions = {'train_envs': 0, 'eval_envs': 0} + total_transitions = {'train_envs': 0, 'eval_envs': 0} + current_task_id = -1 + + for n_eval in range(self._num_eval_runs): + if rec_cfg.enabled: + tr._cam_motion.save_pose() + + # best weight for each task (used for test evaluation) + if type(weight) == dict: + task_name = list(weight.keys())[n_eval] + task_weight = weight[task_name] + weight_path = os.path.join(self._weightsdir, str(task_weight)) + seed_path = self._weightsdir.replace('/weights', '') + self._agent.load_weights(weight_path) + weight_name = str(task_weight) + print('Evaluating weight %s for %s' % (weight_name, task_name)) + + # evaluate on N 
tasks * M episodes per task = total eval episodes + for ep in range(self._eval_episodes): + eval_demo_seed = ep + self._eval_from_eps_number + logging.info('%s: Starting episode %d, seed %d.' % (name, ep, eval_demo_seed)) + + # the current task gets reset after every M episodes + episode_rollout = [] + generator = self._rollout_generator.generator( + self._step_signal, env, self._agent, + self._episode_length, self._timesteps, + eval, eval_demo_seed=eval_demo_seed, + record_enabled=rec_cfg.enabled) + try: + for replay_transition in generator: + while True: + if self._kill_signal.value: + env.shutdown() + return + if (eval or self._target_replay_ratio is None or + self._step_signal.value <= 0 or ( + self._current_replay_ratio.value > + self._target_replay_ratio)): + break + time.sleep(1) + logging.debug( + 'Agent. Waiting for replay_ratio %f to be more than %f' % + (self._current_replay_ratio.value, self._target_replay_ratio)) + + with self.write_lock: + if len(self.agent_summaries) == 0: + # Only store new summaries if the previous ones + # have been popped by the main env runner. 
+ for s in self._agent.act_summaries(): + self.agent_summaries.append(s) + episode_rollout.append(replay_transition) + except StopIteration as e: + continue + except Exception as e: + env.shutdown() + raise e + + with self.write_lock: + for transition in episode_rollout: + self.stored_transitions.append((name, transition, eval)) + + new_transitions['eval_envs'] += 1 + total_transitions['eval_envs'] += 1 + stats_accumulator.step(transition, eval) + current_task_id = transition.info['active_task_id'] + + self._num_eval_episodes_signal.value += 1 + + task_name, _ = self._get_task_name() + reward = episode_rollout[-1].reward + lang_goal = env._lang_goal + print(f"Evaluating {task_name} | Episode {ep} | Score: {reward} | Lang Goal: {lang_goal}") + + # save recording + if rec_cfg.enabled: + success = reward > 0.99 + record_file = os.path.join(seed_path, 'videos', + '%s_w%s_s%s_%s.mp4' % (task_name, + weight_name, + eval_demo_seed, + 'succ' if success else 'fail')) + + lang_goal = self._eval_env._lang_goal + + tr.save(record_file, lang_goal, reward) + tr._cam_motion.restore_pose() + + # report summaries + summaries = [] + summaries.extend(stats_accumulator.pop()) + + eval_task_name, multi_task = self._get_task_name() + + if eval_task_name and multi_task: + for s in summaries: + if 'eval' in s.name: + s.name = '%s/%s' % (s.name, eval_task_name) + + if len(summaries) > 0: + if multi_task: + task_score = [s.value for s in summaries if f'eval_envs/return/{eval_task_name}' in s.name][0] + else: + task_score = [s.value for s in summaries if f'eval_envs/return' in s.name][0] + else: + task_score = float(reward) + summary_name = ( + f'eval_envs/return/{eval_task_name}' + if multi_task + else 'eval_envs/return' + ) + summaries.append(ScalarSummary(summary_name, task_score)) + + print(f"Finished {eval_task_name} | Final Score: {task_score}\n") + + if self._save_metrics: + with writer_lock: + writer.add_summaries(weight_name, summaries) + + self._new_transitions = {'train_envs': 0, 
'eval_envs': 0} + self.agent_summaries[:] = [] + self.stored_transitions[:] = [] + + if self._save_metrics: + with writer_lock: + writer.end_iteration() + + logging.info('Finished evaluation.') + env.shutdown() + + def kill(self): + self._kill_signal.value = True diff --git a/code/upstream_local_patches/peract_bimanual/agents/bimanual_peract/qattention_peract_bc_agent.py b/code/upstream_local_patches/peract_bimanual/agents/bimanual_peract/qattention_peract_bc_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..192b0d50546b071df7aa9022e56e6e36f5945894 --- /dev/null +++ b/code/upstream_local_patches/peract_bimanual/agents/bimanual_peract/qattention_peract_bc_agent.py @@ -0,0 +1,1066 @@ +import copy +import logging +import os +from typing import List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from pytorch3d import transforms as torch3d_tf +from yarr.agents.agent import ( + Agent, + ActResult, + ScalarSummary, + HistogramSummary, + ImageSummary, + Summary, +) + +from helpers import utils +from helpers.utils import visualise_voxel, stack_on_channel +from voxel.voxel_grid import VoxelGrid +from voxel.augmentation import apply_se3_augmentation +from einops import rearrange +from helpers.clip.core.clip import build_model, load_clip + +import transformers +from helpers.optim.lamb import Lamb + +from torch.nn.parallel import DistributedDataParallel as DDP + +NAME = "QAttentionAgent" + + +class QFunction(nn.Module): + def __init__( + self, + perceiver_encoder: nn.Module, + voxelizer: VoxelGrid, + bounds_offset: float, + rotation_resolution: float, + device, + training, + ): + super(QFunction, self).__init__() + self._rotation_resolution = rotation_resolution + self._voxelizer = voxelizer + self._bounds_offset = bounds_offset + self._qnet = perceiver_encoder.to(device) + + # distributed training + if training: + self._qnet = DDP(self._qnet, device_ids=[device]) + + 
class QFunction(nn.Module):
    """Q-function over a voxelized multi-camera scene.

    Voxelizes RGB + point-cloud observations into one feature grid and runs
    the Perceiver encoder to predict per-voxel translation Q-values plus
    discrete rotation / gripper / collision Q-values (per arm, upstream).
    """

    def __init__(
        self,
        perceiver_encoder: nn.Module,
        voxelizer: "VoxelGrid",
        bounds_offset: float,
        rotation_resolution: float,
        device,
        training,
    ):
        super(QFunction, self).__init__()
        self._rotation_resolution = rotation_resolution
        self._voxelizer = voxelizer
        self._bounds_offset = bounds_offset
        self._qnet = perceiver_encoder.to(device)

        # Wrap in DDP only while training; evaluation runs the raw module.
        if training:
            self._qnet = DDP(self._qnet, device_ids=[device])

    def _argmax_3d(self, tensor_orig):
        """Return (B, 3) integer voxel indices of each batch item's argmax.

        Fix: the previous decode ``[(i // h) // d, (i // h) % w, i % w]`` is
        only correct when d == h == w (cubic grids).  For a row-major
        (d, h, w) volume the flat index of (x, y, z) is x*h*w + y*w + z, so
        the general inverse below is used; it yields identical results for
        cubic grids (the only case exercised upstream).
        """
        b, c, d, h, w = tensor_orig.shape  # c will be one
        idxs = tensor_orig.view(b, c, -1).argmax(-1)
        indices = torch.cat([idxs // (h * w), (idxs // w) % h, idxs % w], 1)
        return indices

    def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
        """Greedy decode of the Q heads.

        Returns voxel coords, (rot_x, rot_y, rot_z, grip) bin indices, and
        the ignore-collision index; the latter two are None when no
        rotation/grip head is provided.
        """
        coords = self._argmax_3d(q_trans)
        rot_and_grip_indicies = None
        ignore_collision = None
        if q_rot_grip is not None:
            # Split the flat vector into three rotation heads of
            # (360 / resolution) bins each; the last two logits are gripper.
            q_rot = torch.stack(
                torch.split(
                    q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
                ),
                dim=1,
            )
            rot_and_grip_indicies = torch.cat(
                [
                    q_rot[:, 0:1].argmax(-1),
                    q_rot[:, 1:2].argmax(-1),
                    q_rot[:, 2:3].argmax(-1),
                    q_rot_grip[:, -2:].argmax(-1, keepdim=True),
                ],
                -1,
            )
            ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
        return coords, rot_and_grip_indicies, ignore_collision

    def forward(
        self,
        rgb_pcd,
        proprio,
        pcd,
        lang_goal_emb,
        lang_token_embs,
        bounds=None,
        prev_bounds=None,
        prev_layer_voxel_grid=None,
    ):
        """Voxelize observations and run the encoder.

        rgb_pcd is a list of [rgb, pcd] pairs (one per camera); returns the
        encoder's prediction tuple and the constructed voxel grid.
        """
        b = rgb_pcd[0][0].shape[0]
        pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)

        # Flatten RGBs to per-point features matching the flattened clouds.
        rgb = [rp[0] for rp in rgb_pcd]
        feat_size = rgb[0].shape[1]
        flat_imag_features = torch.cat(
            [p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
        )

        # Construct the voxel grid from points + features.
        voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
            pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
        )

        # Swap to channels first; detach — the voxelizer is not trained through.
        voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()

        # Batch bounds if a single bound was supplied.
        if bounds.shape[0] != b:
            bounds = bounds.repeat(b, 1)

        split_pred = self._qnet(
            voxel_grid,
            proprio,
            lang_goal_emb,
            lang_token_embs,
            prev_layer_voxel_grid,
            bounds,
            prev_bounds,
        )

        return split_pred, voxel_grid
coordinate_bounds: list, + perceiver_encoder: nn.Module, + camera_names: list, + batch_size: int, + voxel_size: int, + bounds_offset: float, + voxel_feature_size: int, + image_crop_size: int, + num_rotation_classes: int, + rotation_resolution: float, + lr: float = 0.0001, + lr_scheduler: bool = False, + training_iterations: int = 100000, + num_warmup_steps: int = 20000, + trans_loss_weight: float = 1.0, + rot_loss_weight: float = 1.0, + grip_loss_weight: float = 1.0, + collision_loss_weight: float = 1.0, + include_low_dim_state: bool = False, + image_resolution: list = None, + lambda_weight_l2: float = 0.0, + transform_augmentation: bool = True, + transform_augmentation_xyz: list = [0.0, 0.0, 0.0], + transform_augmentation_rpy: list = [0.0, 0.0, 180.0], + transform_augmentation_rot_resolution: int = 5, + optimizer_type: str = "adam", + num_devices: int = 1, + ): + self._layer = layer + self._coordinate_bounds = coordinate_bounds + self._perceiver_encoder = perceiver_encoder + self._voxel_feature_size = voxel_feature_size + self._bounds_offset = bounds_offset + self._image_crop_size = image_crop_size + self._lr = lr + self._lr_scheduler = lr_scheduler + self._training_iterations = training_iterations + self._num_warmup_steps = num_warmup_steps + self._trans_loss_weight = trans_loss_weight + self._rot_loss_weight = rot_loss_weight + self._grip_loss_weight = grip_loss_weight + self._collision_loss_weight = collision_loss_weight + self._include_low_dim_state = include_low_dim_state + self._image_resolution = image_resolution or [128, 128] + self._voxel_size = voxel_size + self._camera_names = camera_names + self._num_cameras = len(camera_names) + self._batch_size = batch_size + self._lambda_weight_l2 = lambda_weight_l2 + self._transform_augmentation = transform_augmentation + self._transform_augmentation_xyz = torch.from_numpy( + np.array(transform_augmentation_xyz) + ) + self._transform_augmentation_rpy = transform_augmentation_rpy + 
    def build(self, training: bool, device: torch.device = None):
        """Construct the voxelizer, Q-network, optimizer/scheduler (train)
        or frozen network + CLIP text encoder (eval) on `device`.

        Must be called before `update` / `act`.
        """
        self._training = training

        if device is None:
            device = torch.device("cpu")

        self._device = device

        # Batch size 1 at eval time: act() processes one observation.
        self._voxelizer = VoxelGrid(
            coord_bounds=self._coordinate_bounds,
            voxel_size=self._voxel_size,
            device=device,
            batch_size=self._batch_size if training else 1,
            feature_size=self._voxel_feature_size,
            max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
        )

        self._q = (
            QFunction(
                self._perceiver_encoder,
                self._voxelizer,
                self._bounds_offset,
                self._rotation_resolution,
                device,
                training,
            )
            .to(device)
            .train(training)
        )

        # Pre-built (1, crop, crop, 2) pixel-coordinate grid reused by
        # _extract_crop for every batch.
        grid_for_crop = (
            torch.arange(0, self._image_crop_size, device=device)
            .unsqueeze(0)
            .repeat(self._image_crop_size, 1)
            .unsqueeze(-1)
        )
        self._grid_for_crop = torch.cat(
            [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
        ).unsqueeze(0)

        # NOTE: replaces the list set in __init__ with a (1, 6) tensor.
        self._coordinate_bounds = torch.tensor(
            self._coordinate_bounds, device=device
        ).unsqueeze(0)

        if self._training:
            # optimizer
            if self._optimizer_type == "lamb":
                self._optimizer = Lamb(
                    self._q.parameters(),
                    lr=self._lr,
                    weight_decay=self._lambda_weight_l2,
                    betas=(0.9, 0.999),
                    adam=False,
                )
            elif self._optimizer_type == "adam":
                self._optimizer = torch.optim.Adam(
                    self._q.parameters(),
                    lr=self._lr,
                    weight_decay=self._lambda_weight_l2,
                )
            else:
                raise Exception("Unknown optimizer type")

            # learning rate scheduler: cosine with hard restarts, one cycle
            # per 10k iterations.
            if self._lr_scheduler:
                self._scheduler = (
                    transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
                        self._optimizer,
                        num_warmup_steps=self._num_warmup_steps,
                        num_training_steps=self._training_iterations,
                        num_cycles=self._training_iterations // 10000,
                    )
                )

            # Pre-allocated one-hot zero templates, cloned each update()
            # to build ground-truth targets without re-allocating.
            self._action_trans_one_hot_zeros = torch.zeros(
                (
                    self._batch_size,
                    1,
                    self._voxel_size,
                    self._voxel_size,
                    self._voxel_size,
                ),
                dtype=int,
                device=device,
            )
            self._action_rot_x_one_hot_zeros = torch.zeros(
                (self._batch_size, self._num_rotation_classes), dtype=int, device=device
            )
            self._action_rot_y_one_hot_zeros = torch.zeros(
                (self._batch_size, self._num_rotation_classes), dtype=int, device=device
            )
            self._action_rot_z_one_hot_zeros = torch.zeros(
                (self._batch_size, self._num_rotation_classes), dtype=int, device=device
            )
            self._action_grip_one_hot_zeros = torch.zeros(
                (self._batch_size, 2), dtype=int, device=device
            )
            self._action_ignore_collisions_one_hot_zeros = torch.zeros(
                (self._batch_size, 2), dtype=int, device=device
            )

            # print total params (excluding any frozen CLIP weights)
            logging.info(
                "# Q Params: %d"
                % sum(
                    p.numel()
                    for name, p in self._q.named_parameters()
                    if p.requires_grad and "clip" not in name
                )
            )
        else:
            for param in self._q.parameters():
                param.requires_grad = False

            # load CLIP for encoding language goals during evaluation;
            # training gets the embeddings from the replay buffer instead.
            model, _ = load_clip("RN50", jit=False)
            self._clip_rn50 = build_model(model.state_dict())
            self._clip_rn50 = self._clip_rn50.float().to(device)
            self._clip_rn50.eval()
            del model

        self._voxelizer.to(device)
        self._q.to(device)

    def _extract_crop(self, pixel_action, observation):
        """Crop an `image_crop_size` square centered on `pixel_action`.

        pixel_action: (B, 2) pixel coordinates; observation: (B, C, H, W).
        The crop window is clamped so it never leaves the image.
        """
        # Pixel action will now be (B, 2)
        # observation = stack_on_channel(observation)
        h = observation.shape[-1]
        top_left_corner = torch.clamp(
            pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
        )
        grid = self._grid_for_crop + top_left_corner.unsqueeze(1)
        grid = ((grid / float(h)) * 2.0) - 1.0  # between -1 and 1
        # Used for cropping the images across a batch
        # swap fro y x, to x, y
        grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
        crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
        return crop
torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1) + crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True) + return crop + + def _preprocess_inputs(self, replay_sample): + obs = [] + pcds = [] + self._crop_summary = [] + for n in self._camera_names: + rgb = replay_sample["%s_rgb" % n] + pcd = replay_sample["%s_point_cloud" % n] + + obs.append([rgb, pcd]) + pcds.append(pcd) + return obs, pcds + + def _act_preprocess_inputs(self, observation): + obs, pcds = [], [] + for n in self._camera_names: + rgb = observation["%s_rgb" % n] + pcd = observation["%s_point_cloud" % n] + + obs.append([rgb, pcd]) + pcds.append(pcd) + return obs, pcds + + def _get_value_from_voxel_index(self, q, voxel_idx): + b, c, d, h, w = q.shape + q_trans_flat = q.view(b, c, d * h * w) + flat_indicies = ( + voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2] + )[:, None].int() + highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1) + chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[ + ..., 0 + ] # (B, trans + rot + grip) + return chosen_voxel_values + + def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx): + q_rot = torch.stack( + torch.split( + rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1 + ), + dim=1, + ) # B, 3, 72 + q_grip = rot_grip_q[:, -2:] + rot_and_grip_values = torch.cat( + [ + q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]), + q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]), + q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]), + q_grip.gather(1, rot_and_grip_idx[:, 3:4]), + ], + -1, + ) + return rot_and_grip_values + + def _celoss(self, pred, labels): + return self._cross_entropy_loss(pred, labels.argmax(-1)) + + def _softmax_q_trans(self, q): + q_shape = q.shape + return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape) + + def _softmax_q_rot_grip(self, q_rot_grip): + q_rot_x_flat = q_rot_grip[ + :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + q_rot_y_flat = 
q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes + ] + q_rot_z_flat = q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes + ] + q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :] + + q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1) + q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1) + q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1) + q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1) + + return torch.cat( + [ + q_rot_x_flat_softmax, + q_rot_y_flat_softmax, + q_rot_z_flat_softmax, + q_grip_flat_softmax, + ], + dim=1, + ) + + def _softmax_ignore_collision(self, q_collision): + q_collision_softmax = F.softmax(q_collision, dim=1) + return q_collision_softmax + + def update(self, step: int, replay_sample: dict) -> dict: + right_action_trans = replay_sample["right_trans_action_indicies"][ + :, self._layer * 3 : self._layer * 3 + 3 + ].int() + right_action_rot_grip = replay_sample["right_rot_grip_action_indicies"].int() + right_action_gripper_pose = replay_sample["right_gripper_pose"] + right_action_ignore_collisions = replay_sample["right_ignore_collisions"].int() + + left_action_trans = replay_sample["left_trans_action_indicies"][ + :, self._layer * 3 : self._layer * 3 + 3 + ].int() + left_action_rot_grip = replay_sample["left_rot_grip_action_indicies"].int() + left_action_gripper_pose = replay_sample["left_gripper_pose"] + left_action_ignore_collisions = replay_sample["left_ignore_collisions"].int() + + lang_goal_emb = replay_sample["lang_goal_emb"].float() + lang_token_embs = replay_sample["lang_token_embs"].float() + prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None) + prev_layer_bounds = replay_sample.get("prev_layer_bounds", None) + device = self._device + + bounds = self._coordinate_bounds.to(device) + if self._layer > 0: + right_cp = replay_sample[ + "right_attention_coordinate_layer_%d" % (self._layer - 1) + ] + + left_cp = replay_sample[ + 
"left_attention_coordinate_layer_%d" % (self._layer - 1) + ] + + right_bounds = torch.cat( + [right_cp - self._bounds_offset, right_cp + self._bounds_offset], dim=1 + ) + left_bounds = torch.cat( + [left_cp - self._bounds_offset, left_cp + self._bounds_offset], dim=1 + ) + + else: + right_bounds = bounds + left_bounds = bounds + + right_proprio = None + left_proprio = None + if self._include_low_dim_state: + right_proprio = replay_sample["right_low_dim_state"] + left_proprio = replay_sample["left_low_dim_state"] + + # ..TODO:: + # Can we add the coordinates of both robots? + # + + obs, pcd = self._preprocess_inputs(replay_sample) + + # batch size + bs = pcd[0].shape[0] + + # We can move the point cloud w.r.t to the other robot's cooridinate system + # similar to apply_se3_augmentation + # + + # SE(3) augmentation of point clouds and actions + if self._transform_augmentation: + from voxel import augmentation + + ( + right_action_trans, + right_action_rot_grip, + left_action_trans, + left_action_rot_grip, + pcd, + ) = augmentation.bimanual_apply_se3_augmentation( + pcd, + right_action_gripper_pose, + right_action_trans, + right_action_rot_grip, + left_action_gripper_pose, + left_action_trans, + left_action_rot_grip, + bounds, + self._layer, + self._transform_augmentation_xyz, + self._transform_augmentation_rpy, + self._transform_augmentation_rot_resolution, + self._voxel_size, + self._rotation_resolution, + self._device, + ) + else: + right_action_trans = right_action_trans.int() + left_action_trans = left_action_trans.int() + + proprio = torch.cat((right_proprio, left_proprio), dim=1) + + right_action = ( + right_action_trans, + right_action_rot_grip, + right_action_ignore_collisions, + ) + left_action = ( + left_action_trans, + left_action_rot_grip, + left_action_ignore_collisions, + ) + # forward pass + q, voxel_grid = self._q( + obs, + proprio, + pcd, + lang_goal_emb, + lang_token_embs, + bounds, + prev_layer_bounds, + prev_layer_voxel_grid, + ) + + ( + 
right_q_trans, + right_q_rot_grip, + right_q_collision, + left_q_trans, + left_q_rot_grip, + left_q_collision, + ) = q + + # argmax to choose best action + ( + right_coords, + right_rot_and_grip_indicies, + right_ignore_collision_indicies, + ) = self._q.choose_highest_action( + right_q_trans, right_q_rot_grip, right_q_collision + ) + + ( + left_coords, + left_rot_and_grip_indicies, + left_ignore_collision_indicies, + ) = self._q.choose_highest_action( + left_q_trans, left_q_rot_grip, left_q_collision + ) + + ( + right_q_trans_loss, + right_q_rot_loss, + right_q_grip_loss, + right_q_collision_loss, + ) = (0.0, 0.0, 0.0, 0.0) + left_q_trans_loss, left_q_rot_loss, left_q_grip_loss, left_q_collision_loss = ( + 0.0, + 0.0, + 0.0, + 0.0, + ) + + # translation one-hot + right_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach() + left_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach() + for b in range(bs): + right_gt_coord = right_action_trans[b, :].int() + right_action_trans_one_hot[ + b, :, right_gt_coord[0], right_gt_coord[1], right_gt_coord[2] + ] = 1 + left_gt_coord = left_action_trans[b, :].int() + left_action_trans_one_hot[ + b, :, left_gt_coord[0], left_gt_coord[1], left_gt_coord[2] + ] = 1 + + # translation loss + right_q_trans_flat = right_q_trans.view(bs, -1) + right_action_trans_one_hot_flat = right_action_trans_one_hot.view(bs, -1) + right_q_trans_loss = self._celoss( + right_q_trans_flat, right_action_trans_one_hot_flat + ) + left_q_trans_flat = left_q_trans.view(bs, -1) + left_action_trans_one_hot_flat = left_action_trans_one_hot.view(bs, -1) + left_q_trans_loss = self._celoss( + left_q_trans_flat, left_action_trans_one_hot_flat + ) + + q_trans_loss = right_q_trans_loss + left_q_trans_loss + + with_rot_and_grip = ( + len(right_rot_and_grip_indicies) > 0 and len(left_rot_and_grip_indicies) > 0 + ) + if with_rot_and_grip: + # rotation, gripper, and collision one-hots + right_action_rot_x_one_hot = 
self._action_rot_x_one_hot_zeros.clone() + right_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone() + right_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone() + right_action_grip_one_hot = self._action_grip_one_hot_zeros.clone() + right_action_ignore_collisions_one_hot = ( + self._action_ignore_collisions_one_hot_zeros.clone() + ) + + left_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone() + left_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone() + left_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone() + left_action_grip_one_hot = self._action_grip_one_hot_zeros.clone() + left_action_ignore_collisions_one_hot = ( + self._action_ignore_collisions_one_hot_zeros.clone() + ) + + for b in range(bs): + right_gt_rot_grip = right_action_rot_grip[b, :].int() + right_action_rot_x_one_hot[b, right_gt_rot_grip[0]] = 1 + right_action_rot_y_one_hot[b, right_gt_rot_grip[1]] = 1 + right_action_rot_z_one_hot[b, right_gt_rot_grip[2]] = 1 + right_action_grip_one_hot[b, right_gt_rot_grip[3]] = 1 + + right_gt_ignore_collisions = right_action_ignore_collisions[b, :].int() + right_action_ignore_collisions_one_hot[ + b, right_gt_ignore_collisions[0] + ] = 1 + + left_gt_rot_grip = left_action_rot_grip[b, :].int() + left_action_rot_x_one_hot[b, left_gt_rot_grip[0]] = 1 + left_action_rot_y_one_hot[b, left_gt_rot_grip[1]] = 1 + left_action_rot_z_one_hot[b, left_gt_rot_grip[2]] = 1 + left_action_grip_one_hot[b, left_gt_rot_grip[3]] = 1 + + left_gt_ignore_collisions = left_action_ignore_collisions[b, :].int() + left_action_ignore_collisions_one_hot[ + b, left_gt_ignore_collisions[0] + ] = 1 + + # flatten predictions + right_q_rot_x_flat = right_q_rot_grip[ + :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + right_q_rot_y_flat = right_q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes + ] + right_q_rot_z_flat = right_q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * 
self._num_rotation_classes + ] + right_q_grip_flat = right_q_rot_grip[:, 3 * self._num_rotation_classes :] + right_q_ignore_collisions_flat = right_q_collision + + left_q_rot_x_flat = left_q_rot_grip[ + :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + left_q_rot_y_flat = left_q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes + ] + left_q_rot_z_flat = left_q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes + ] + left_q_grip_flat = left_q_rot_grip[:, 3 * self._num_rotation_classes :] + left_q_ignore_collisions_flat = left_q_collision + + # rotation loss + right_q_rot_loss += self._celoss( + right_q_rot_x_flat, right_action_rot_x_one_hot + ) + right_q_rot_loss += self._celoss( + right_q_rot_y_flat, right_action_rot_y_one_hot + ) + right_q_rot_loss += self._celoss( + right_q_rot_z_flat, right_action_rot_z_one_hot + ) + + left_q_rot_loss += self._celoss( + left_q_rot_x_flat, left_action_rot_x_one_hot + ) + left_q_rot_loss += self._celoss( + left_q_rot_y_flat, left_action_rot_y_one_hot + ) + left_q_rot_loss += self._celoss( + left_q_rot_z_flat, left_action_rot_z_one_hot + ) + + # gripper loss + right_q_grip_loss += self._celoss( + right_q_grip_flat, right_action_grip_one_hot + ) + left_q_grip_loss += self._celoss(left_q_grip_flat, left_action_grip_one_hot) + + # collision loss + right_q_collision_loss += self._celoss( + right_q_ignore_collisions_flat, right_action_ignore_collisions_one_hot + ) + left_q_collision_loss += self._celoss( + left_q_ignore_collisions_flat, left_action_ignore_collisions_one_hot + ) + + q_trans_loss = right_q_trans_loss + left_q_trans_loss + q_rot_loss = right_q_rot_loss + left_q_rot_loss + q_grip_loss = right_q_grip_loss + left_q_grip_loss + q_collision_loss = right_q_collision_loss + left_q_collision_loss + + combined_losses = ( + (q_trans_loss * self._trans_loss_weight) + + (q_rot_loss * self._rot_loss_weight) + + (q_grip_loss * self._grip_loss_weight) + + 
(q_collision_loss * self._collision_loss_weight) + ) + total_loss = combined_losses.mean() + + self._optimizer.zero_grad() + total_loss.backward() + self._optimizer.step() + + self._summaries = { + "losses/total_loss": total_loss, + "losses/trans_loss": q_trans_loss.mean(), + "losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0, + "losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0, + "losses/right/trans_loss": q_trans_loss.mean(), + "losses/right/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0, + "losses/right/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0, + "losses/right/collision_loss": q_collision_loss.mean() + if with_rot_and_grip + else 0.0, + "losses/left/trans_loss": q_trans_loss.mean(), + "losses/left/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0, + "losses/left/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0, + "losses/left/collision_loss": q_collision_loss.mean() + if with_rot_and_grip + else 0.0, + "losses/collision_loss": q_collision_loss.mean() + if with_rot_and_grip + else 0.0, + } + + if self._lr_scheduler: + self._scheduler.step() + self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0] + + self._vis_voxel_grid = voxel_grid[0] + self._right_vis_translation_qvalue = self._softmax_q_trans(right_q_trans[0]) + self._right_vis_max_coordinate = right_coords[0] + self._right_vis_gt_coordinate = right_action_trans[0] + + self._left_vis_translation_qvalue = self._softmax_q_trans(left_q_trans[0]) + self._left_vis_max_coordinate = left_coords[0] + self._left_vis_gt_coordinate = left_action_trans[0] + + # Note: PerAct doesn't use multi-layer voxel grids like C2FARM + # stack prev_layer_voxel_grid(s) from previous layers into a list + if prev_layer_voxel_grid is None: + prev_layer_voxel_grid = [voxel_grid] + else: + prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid] + + # stack prev_layer_bound(s) from previous layers into a list + if prev_layer_bounds is 
    def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
        """Greedy bimanual inference for one observation.

        Encodes the language goal with CLIP, voxelizes the observation, runs
        the Q-network, and returns the argmax action
        (coords, rot+grip, collision) for each arm plus attention
        coordinates / voxel grids for any downstream layer.
        """
        # NOTE(review): the parameter is unconditionally overridden —
        # evaluation is always greedy here; confirm this is intentional.
        deterministic = True
        bounds = self._coordinate_bounds
        prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
        prev_layer_bounds = observation.get("prev_layer_bounds", None)
        # NOTE(review): .long() on the .get() result raises AttributeError if
        # the key is absent — the None default is never usable; confirm the
        # key is always present upstream.
        lang_goal_tokens = observation.get("lang_goal_tokens", None).long()

        # extract CLIP language embs (eval-time only; training reads them
        # from the replay buffer)
        with torch.no_grad():
            lang_goal_tokens = lang_goal_tokens.to(device=self._device)
            (
                lang_goal_emb,
                lang_token_embs,
            ) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])

        # voxelization resolution (metres per voxel along each axis)
        res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
        max_rot_index = int(360 // self._rotation_resolution)
        right_proprio = None
        left_proprio = None

        if self._include_low_dim_state:
            right_proprio = observation["right_low_dim_state"]
            left_proprio = observation["left_low_dim_state"]
            right_proprio = right_proprio[0].to(self._device)
            left_proprio = left_proprio[0].to(self._device)

        obs, pcd = self._act_preprocess_inputs(observation)

        # correct batch size and device (observations arrive with an extra
        # leading dimension from the env wrapper)
        obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs]

        pcd = [p[0].to(self._device) for p in pcd]
        lang_goal_emb = lang_goal_emb.to(self._device)
        lang_token_embs = lang_token_embs.to(self._device)
        bounds = torch.as_tensor(bounds, device=self._device)
        prev_layer_voxel_grid = (
            prev_layer_voxel_grid.to(self._device)
            if prev_layer_voxel_grid is not None
            else None
        )
        prev_layer_bounds = (
            prev_layer_bounds.to(self._device)
            if prev_layer_bounds is not None
            else None
        )

        proprio = torch.cat((right_proprio, left_proprio), dim=1)

        # inference
        (
            right_q_trans,
            right_q_rot_grip,
            right_q_ignore_collisions,
            left_q_trans,
            left_q_rot_grip,
            left_q_ignore_collisions,
        ), vox_grid = self._q(
            obs,
            proprio,
            pcd,
            lang_goal_emb,
            lang_token_embs,
            bounds,
            prev_layer_bounds,
            prev_layer_voxel_grid,
        )

        # softmax Q predictions (normalisation only; argmax is unchanged)
        right_q_trans = self._softmax_q_trans(right_q_trans)
        left_q_trans = self._softmax_q_trans(left_q_trans)

        if right_q_rot_grip is not None:
            right_q_rot_grip = self._softmax_q_rot_grip(right_q_rot_grip)

        if left_q_rot_grip is not None:
            left_q_rot_grip = self._softmax_q_rot_grip(left_q_rot_grip)

        if right_q_ignore_collisions is not None:
            right_q_ignore_collisions = self._softmax_ignore_collision(
                right_q_ignore_collisions
            )

        if left_q_ignore_collisions is not None:
            left_q_ignore_collisions = self._softmax_ignore_collision(
                left_q_ignore_collisions
            )

        # argmax Q predictions
        (
            right_coords,
            right_rot_and_grip_indicies,
            right_ignore_collisions,
        ) = self._q.choose_highest_action(
            right_q_trans, right_q_rot_grip, right_q_ignore_collisions
        )
        (
            left_coords,
            left_rot_and_grip_indicies,
            left_ignore_collisions,
        ) = self._q.choose_highest_action(
            left_q_trans, left_q_rot_grip, left_q_ignore_collisions
        )

        if right_q_rot_grip is not None:
            right_rot_grip_action = right_rot_and_grip_indicies
        if right_q_ignore_collisions is not None:
            right_ignore_collisions_action = right_ignore_collisions.int()

        if left_q_rot_grip is not None:
            left_rot_grip_action = left_rot_and_grip_indicies
        if left_q_ignore_collisions is not None:
            left_ignore_collisions_action = left_ignore_collisions.int()

        right_coords = right_coords.int()
        left_coords = left_coords.int()

        # voxel index -> world coordinate (voxel centre)
        right_attention_coordinate = bounds[:, :3] + res * right_coords + res / 2
        left_attention_coordinate = bounds[:, :3] + res * left_coords + res / 2

        # stack prev_layer_voxel_grid(s) into a list
        # NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM
        if prev_layer_voxel_grid is None:
            prev_layer_voxel_grid = [vox_grid]
        else:
            prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]

        if prev_layer_bounds is None:
            prev_layer_bounds = [bounds]
        else:
            prev_layer_bounds = prev_layer_bounds + [bounds]

        observation_elements = {
            "right_attention_coordinate": right_attention_coordinate,
            "left_attention_coordinate": left_attention_coordinate,
            "prev_layer_voxel_grid": prev_layer_voxel_grid,
            "prev_layer_bounds": prev_layer_bounds,
        }
        info = {
            "voxel_grid_depth%d" % self._layer: vox_grid,
            "right_q_depth%d" % self._layer: right_q_trans,
            "right_voxel_idx_depth%d" % self._layer: right_coords,
            "left_q_depth%d" % self._layer: left_q_trans,
            "left_voxel_idx_depth%d" % self._layer: left_coords,
        }
        # stash for act_summaries
        self._act_voxel_grid = vox_grid[0]
        self._right_act_max_coordinate = right_coords[0]
        self._right_act_qvalues = right_q_trans[0].detach()
        self._left_act_max_coordinate = left_coords[0]
        self._left_act_qvalues = left_q_trans[0].detach()

        action = (
            right_coords,
            right_rot_grip_action,
            right_ignore_collisions,
            left_coords,
            left_rot_grip_action,
            left_ignore_collisions,
        )

        return ActResult(action, observation_elements=observation_elements, info=info)
prev_layer_voxel_grid is None: + prev_layer_voxel_grid = [vox_grid] + else: + prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid] + + if prev_layer_bounds is None: + prev_layer_bounds = [bounds] + else: + prev_layer_bounds = prev_layer_bounds + [bounds] + + observation_elements = { + "right_attention_coordinate": right_attention_coordinate, + "left_attention_coordinate": left_attention_coordinate, + "prev_layer_voxel_grid": prev_layer_voxel_grid, + "prev_layer_bounds": prev_layer_bounds, + } + info = { + "voxel_grid_depth%d" % self._layer: vox_grid, + "right_q_depth%d" % self._layer: right_q_trans, + "right_voxel_idx_depth%d" % self._layer: right_coords, + "left_q_depth%d" % self._layer: left_q_trans, + "left_voxel_idx_depth%d" % self._layer: left_coords, + } + self._act_voxel_grid = vox_grid[0] + self._right_act_max_coordinate = right_coords[0] + self._right_act_qvalues = right_q_trans[0].detach() + self._left_act_max_coordinate = left_coords[0] + self._left_act_qvalues = left_q_trans[0].detach() + + action = ( + right_coords, + right_rot_grip_action, + right_ignore_collisions, + left_coords, + left_rot_grip_action, + left_ignore_collisions, + ) + + return ActResult(action, observation_elements=observation_elements, info=info) + + def update_summaries(self) -> List[Summary]: + voxel_grid = self._vis_voxel_grid.detach().cpu().numpy() + summaries = [] + summaries.append( + ImageSummary( + "%s/right_update_qattention" % self._name, + transforms.ToTensor()( + visualise_voxel( + voxel_grid, + self._right_vis_translation_qvalue.detach().cpu().numpy(), + self._right_vis_max_coordinate.detach().cpu().numpy(), + self._right_vis_gt_coordinate.detach().cpu().numpy(), + ) + ), + ) + ) + summaries.append( + ImageSummary( + "%s/left_update_qattention" % self._name, + transforms.ToTensor()( + visualise_voxel( + voxel_grid, + self._left_vis_translation_qvalue.detach().cpu().numpy(), + self._left_vis_max_coordinate.detach().cpu().numpy(), + 
    def act_summaries(self) -> List[Summary]:
        """Render the most recent act() Q-attention for both arms.

        Returns an empty list if the voxel renderer fails (e.g. headless
        environment) instead of crashing evaluation.
        """
        voxel_grid = self._act_voxel_grid.cpu().numpy()
        right_q_attention = self._right_act_qvalues.cpu().numpy()
        right_highlight_coordinate = self._right_act_max_coordinate.cpu().numpy()
        left_q_attention = self._left_act_qvalues.cpu().numpy()
        left_highlight_coordinate = self._left_act_max_coordinate.cpu().numpy()
        try:
            right_visualization = visualise_voxel(
                voxel_grid, right_q_attention, right_highlight_coordinate
            )
            left_visualization = visualise_voxel(
                voxel_grid, left_q_attention, left_highlight_coordinate
            )
        except Exception as exc:
            logging.warning("Skipping act_summaries voxel render: %s", exc)
            return []

        return [
            ImageSummary(
                f"{self._name}/right_act_Qattention",
                transforms.ToTensor()(right_visualization),
            ),
            ImageSummary(
                f"{self._name}/left_act_Qattention",
                transforms.ToTensor()(left_visualization),
            ),
        ]

    def load_weights(self, savedir: str):
        """Load `<savedir>/<name>.pt`, tolerating DDP-prefixed keys and
        re-shaping voxelizer buffers from the training batch size to the
        eval batch size of 1.
        """
        # NOTE(review): in the training branch this assumes self._device is
        # an int rank usable in "cuda:%d" — confirm load_weights is never
        # called after build() has replaced it with a torch.device.
        device = (
            self._device
            if not self._training
            else torch.device("cuda:%d" % self._device)
        )
        weight_file = os.path.join(savedir, "%s.pt" % self._name)
        state_dict = torch.load(weight_file, map_location=device)

        # load only keys that are in the current model; at eval time strip
        # the DDP "module." wrapper inserted during training
        merged_state_dict = self._q.state_dict()
        for k, v in state_dict.items():
            if not self._training:
                k = k.replace("_qnet.module", "_qnet")
            if k in merged_state_dict:
                merged_state_dict[k] = v
            else:
                # voxelizer buffers are rebuilt below, so only warn about
                # genuinely unknown keys
                if "_voxelizer" not in k:
                    logging.warning("key %s not found in checkpoint" % k)
        if not self._training:
            # reshape voxelizer weights: the checkpoint was saved with the
            # training batch size; slice every batch-indexed buffer down to
            # the eval batch size of 1
            b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
            merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
                "_voxelizer._ones_max_coords"
            ][0:1]
            flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
            merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
                "_voxelizer._flat_output"
            ][0 : flat_shape // b]
            merged_state_dict["_voxelizer._tiled_batch_indices"] = merged_state_dict[
                "_voxelizer._tiled_batch_indices"
            ][0:1]
            merged_state_dict["_voxelizer._index_grid"] = merged_state_dict[
                "_voxelizer._index_grid"
            ][0:1]
        self._q.load_state_dict(merged_state_dict)
        print("loaded weights from %s" % weight_file)

    def save_weights(self, savedir: str):
        """Save the Q-network state dict to `<savedir>/<name>.pt`."""
        torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))