Upload folder using huggingface_hub
q1_simpleworld_cem/dreamer_model_trainer.py
ADDED
@@ -0,0 +1,702 @@
import sys
import os
from pathlib import Path

os.environ["MUJOCO_GL"] = "egl"

# Ensure the vendored LIBERO package is importable even if it hasn't been pip-installed.
# Hydra may change the working directory, so we resolve relative to this file.
_REPO_ROOT = Path(__file__).resolve().parents[1]
_LIBERO_ROOT = _REPO_ROOT / "LIBERO"
if _LIBERO_ROOT.exists():
    sys.path.insert(0, str(_LIBERO_ROOT))

import dill
import h5py
import numpy as np
import torch
import datasets
import hydra
from omegaconf import DictConfig, OmegaConf
from torchvision import transforms

# Support both `python hw2/dreamer_model_trainer.py` (cwd=hw2) and
# `python -m hw2.dreamer_model_trainer` / importing as a package.
try:
    from .dreamerV3 import DreamerV3
    from .simple_world_model import SimpleWorldModel
    from .planning import CEMPlanner, PolicyPlanner, RandomPlanner
except ImportError:
    from dreamerV3 import DreamerV3
    from simple_world_model import SimpleWorldModel
    from planning import CEMPlanner, PolicyPlanner, RandomPlanner


def create_model(model_type, img_shape, action_dim, device, cfg):
    """
    Factory function to create a world model based on the specified type.

    Args:
        model_type: 'dreamer' or 'simple'
        img_shape: Image shape [C, H, W]
        action_dim: Dimensionality of actions
        device: torch device
        cfg: Configuration object

    Returns:
        model: Instantiated model
    """
    if model_type.lower() == 'dreamer':
        model = DreamerV3(obs_shape=img_shape,
                          action_dim=action_dim, cfg=cfg).to(device)
    elif model_type.lower() == 'simple':
        model = SimpleWorldModel(
            action_dim=action_dim, pose_dim=7, hidden_dim=256, cfg=cfg).to(device)
    else:
        raise ValueError(
            f"Unknown model_type: {model_type}. Choose 'dreamer' or 'simple'.")

    return model
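

# Illustrative usage of the factory above, kept as a comment so nothing runs at
# import time (the cfg fields shown are the ones the constructors actually read;
# the values are hypothetical):
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = create_model('simple', img_shape=[3, 64, 64], action_dim=7,
#                        device=device, cfg=cfg)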


def batch_data(dataset, batch_size, cfg):
    """
    Utility function to batch data from the dataset with a fixed sequence length.

    Args:
        dataset: Dataset object that yields (images, actions, rewards, dones, poses) trajectories
        batch_size: Number of sequences per batch
        cfg: Config; cfg.policy.sequence_length is the sequence length T

    Returns:
        A DataLoader that yields batches of (images, actions, rewards, dones, poses) with shapes:
        - images: (B, T, C, H, W)
        - actions: (B, T, 7)
        - rewards: (B, T)
        - dones: (B, T)
        - poses: (B, T, 7)
    """
    # Chop each trajectory into non-overlapping windows of length T; any
    # remainder shorter than T is dropped (no padding).
    T = cfg.policy.sequence_length
    list_images, list_actions, list_rewards, list_dones, list_poses = [], [], [], [], []
    for img, act, rew, don, pos in dataset:
        list_images += [img[i:i + T] for i in range(0, len(img) - T + 1, T)]
        list_actions += [act[i:i + T] for i in range(0, len(act) - T + 1, T)]
        list_rewards += [rew[i:i + T] for i in range(0, len(rew) - T + 1, T)]
        list_dones += [don[i:i + T] for i in range(0, len(don) - T + 1, T)]
        list_poses += [pos[i:i + T] for i in range(0, len(pos) - T + 1, T)]
    images = torch.stack(list_images)    # (B, T, H, W, C)
    actions = torch.stack(list_actions)  # (B, T, action_dim)
    rewards = torch.stack(list_rewards)  # (B, T)
    dones = torch.stack(list_dones)      # (B, T)
    poses = torch.stack(list_poses)      # (B, T, pose_dim)
    images = images.permute(0, 1, 4, 2, 3).to(cfg.device)  # (B, T, H, W, C) -> (B, T, C, H, W)
    actions = actions.float().to(cfg.device)  # (B, T, action_dim)
    rewards = rewards.float().to(cfg.device)  # (B, T)
    dones = dones.float().to(cfg.device)      # (B, T)
    poses = poses.float().to(cfg.device)      # (B, T, pose_dim)
    # Alternative (unused): keep whole variable-length trajectories and zero-pad
    # them with torch.nn.utils.rnn.pad_sequence instead of chunking.
    print(f"[info] Batched data into tensors with shapes: images={images.shape}, actions={actions.shape}, rewards={rewards.shape}, dones={dones.shape}, poses={poses.shape}")
    out_dataset = torch.utils.data.TensorDataset(images, actions, rewards, dones, poses)
    print(f"[info] Created DataLoader with {len(out_dataset)} samples")
    return torch.utils.data.DataLoader(out_dataset, batch_size=batch_size, shuffle=True)
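

# Minimal sketch of the chunking that batch_data performs, on a toy 1-D
# "trajectory" (the helper name and numbers are illustrative, not part of the
# training pipeline): a length-10 trajectory with T=4 yields two non-overlapping
# windows, [0:4] and [4:8]; the length-2 remainder is dropped, not padded.
def _demo_chunking(T: int = 4):
    traj = torch.arange(10)  # stand-in for a length-10 trajectory
    windows = [traj[i:i + T] for i in range(0, len(traj) - T + 1, T)]
    return torch.stack(windows)  # shape (2, 4)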


class ModelTrainingWrapper:
    """
    Wrapper to provide a unified interface for training different world models.
    Handles differences in forward passes and loss computation between models.
    """

    def __init__(self, model, model_type, device):
        self.model = model
        self.model_type = model_type.lower()
        self.device = device

    def forward_pass(self, images, poses, actions):
        """
        Unified forward pass that works with both model types.

        Args:
            images: Image tensor (B, T, C, H, W), or None for the simple model
            poses: Pose tensor (B, T, 7)
            actions: Action tensor (B, T, 7)

        Returns:
            output: Model output (format depends on model type)
        """
        if self.model_type == 'dreamer':
            # DreamerV3 returns a dict of rollout predictions.
            return self.model(images, actions)
        elif self.model_type == 'simple':
            # SimpleWorldModel expects normalized inputs
            pred_pose_seq, pred_reward_seq = self.model(poses, actions)
            return {
                'pred_poses': pred_pose_seq,
                'pred_rewards': pred_reward_seq
            }
        raise ValueError(f"Unknown model type: {self.model_type}")

    def compute_loss(self, model_out, normalized_images, rewards, dones, poses, actions):
        """
        Compute loss in a way that works for both model types.

        Args:
            model_out: Output from forward_pass
            normalized_images: Image tensor (used for DreamerV3)
            rewards: Reward tensor
            dones: Done tensor
            poses: Pose tensor (used for SimpleWorldModel)
            actions: Action tensor (used for SimpleWorldModel)

        Returns:
            losses: Dictionary with loss information
        """
        if self.model_type == 'dreamer':
            # Use DreamerV3 loss computation
            if not isinstance(model_out, dict):
                raise ValueError(
                    f"DreamerV3 forward must return a dict, got {type(model_out)}"
                )
            return self.model.compute_loss(model_out, normalized_images, rewards, dones, self.device)
        elif self.model_type == 'simple':
            # Part 1.2 - SimpleWorldModel training loss:
            # MSE between predicted and target poses/rewards.
            pred_poses = model_out['pred_poses']
            pred_rewards = model_out['pred_rewards']
            if pred_rewards is None:
                raise ValueError("SimpleWorldModel path expected pred_rewards, got None")
            # Ensure rewards are always (B, T)
            if pred_rewards.dim() == 3 and pred_rewards.shape[-1] == 1:
                pred_rewards = pred_rewards.squeeze(-1)
            if pred_poses.dim() == 2:
                raise ValueError(
                    f"SimpleWorldModel output must be (B, T, 7); got 2D tensor of shape "
                    f"{pred_poses.shape}. Check model output formatting.")
            elif pred_poses.dim() == 3 and pred_poses.shape[2] != 7:
                raise ValueError(
                    f"SimpleWorldModel pose dim must be 7, got {pred_poses.shape[2]}. "
                    "Check model output formatting.")
            elif pred_poses.dim() == 3 and pred_poses.shape[2] == 7:
                B, T, _ = pred_poses.shape

                # Align shapes: predict at times [0..T-2] to match targets [1..T-1]
                pred_pose_seq = pred_poses[:, : T - 1, :]
                tgt_pose_seq = poses[:, 1:, :]

                # Rewards are (B, T); the full sequences are passed through unshifted.
                pred_rew_seq = pred_rewards
                tgt_rew_seq = rewards

                loss_dict = self.model.compute_loss(
                    pred_pose_seq,
                    pred_rew_seq,
                    target_pose=tgt_pose_seq,
                    target_reward=tgt_rew_seq,
                )
                return loss_dict

            raise ValueError(f"Unexpected pred_poses shape: {pred_poses.shape}")
        raise ValueError(f"Unknown model type: {self.model_type}")
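

# Sketch of the next-step alignment used in compute_loss above (helper and
# shapes are illustrative only): predictions at steps 0..T-2 are compared to
# ground-truth poses at steps 1..T-1, so a (B, T, 7) output trains on
# (B, T-1, 7) prediction/target pairs.
def _demo_alignment(B: int = 2, T: int = 5):
    pred_poses = torch.randn(B, T, 7)
    poses = torch.randn(B, T, 7)
    pred_seq = pred_poses[:, : T - 1, :]  # predictions for t = 0..T-2
    tgt_seq = poses[:, 1:, :]             # targets for t = 1..T-1
    return torch.nn.functional.mse_loss(pred_seq, tgt_seq)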


class LIBERODataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        # Crawl data_dir and build an (hdf5 file, demo key) index map.
        self.index_map = []
        for root, dirs, files in os.walk(self.data_dir):
            for file in files:
                if file.endswith('.hdf5') or file.endswith('.h5'):
                    file_path = os.path.join(root, file)
                    with h5py.File(file_path, 'r') as f:
                        for demo_key in f['data'].keys():
                            self.index_map.append((file_path, demo_key))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        file_path, demo_key = self.index_map[idx]
        with h5py.File(file_path, 'r') as f:
            demo = f['data'][demo_key]
            image = torch.from_numpy(demo['obs']['agentview_rgb'][()])
            action = torch.from_numpy(demo['actions'][()])
            dones = torch.from_numpy(demo['dones'][()])
            rewards = torch.from_numpy(demo['rewards'][()])
            # 7-D end-effector pose: position (3) + orientation (first 3 dims) + gripper state (1)
            poses = torch.from_numpy(np.concatenate((demo['obs']["ee_pos"],
                                                     demo['obs']["ee_ori"][:, :3],
                                                     demo['obs']["gripper_states"][:, :1]), axis=-1))
        # Note: Images are returned in channel-last format (T, H, W, C);
        # conversion to channel-first (T, C, H, W) happens in the training loop.
        return image, action, rewards, dones, poses


class CircularBufferDataset(torch.utils.data.Dataset):
    """Circular buffer dataset that holds up to cfg.dataset.buffer_size trajectories.
    When full, the oldest trajectories are overwritten.
    """

    def __init__(self, cfg=None, data_dir=None):
        self.trajectories = []
        self.write_idx = 0
        self._cfg = cfg

        # Resolve the data directory, honoring an explicitly passed data_dir first.
        if data_dir is None:
            data_dir = getattr(cfg, 'data_dir', None)
        if data_dir is None and cfg is not None:
            data_dir = getattr(
                getattr(cfg, 'dataset', None), 'data_dir', None)
        if data_dir is None:
            data_dir = '/network/projects/real-g-grp/libero/targets_clean/'

        # Seed the buffer from an existing dataset.
        if cfg.dataset.load_dataset:
            dataset = LIBERODatasetLeRobot(
                repo_id=cfg.dataset.to_name,
                transform=transforms.ToTensor(),
                cfg=cfg
            )
        else:
            dataset = LIBERODataset(data_dir, transform=transforms.ToTensor())
        num_to_load = min(len(dataset), self._cfg.dataset.buffer_size)
        if num_to_load == 0:
            return

        # Sample trajectories without replacement and load them into the buffer.
        indices = np.random.choice(
            len(dataset), size=num_to_load, replace=False)
        for idx in indices:
            images, actions, rewards, dones, poses = dataset[idx]
            self.add_trajectory(
                np.array(images),
                np.array(actions),
                np.array(rewards),
                np.array(dones),
                np.array(poses)
            )

    def add_trajectory(self, images, actions, rewards, dones, poses):
        """Add a trajectory to the buffer. Overwrites the oldest if full."""
        trajectory = {
            'images': torch.from_numpy(images),
            'actions': torch.from_numpy(actions),
            'rewards': torch.from_numpy(rewards),
            'dones': torch.from_numpy(dones),
            'poses': torch.from_numpy(poses)
        }

        if len(self.trajectories) < self._cfg.dataset.buffer_size:
            self.trajectories.append(trajectory)
        else:
            # Overwrite the oldest trajectory
            self.trajectories[self.write_idx] = trajectory
            self.write_idx = (
                self.write_idx + 1) % self._cfg.dataset.buffer_size

    def get_trajectory(self, idx):
        """Return trajectory idx as a list of per-step dicts."""
        trajectory = []
        traj = self.trajectories[idx]
        for i in range(len(traj['images'])):
            step_dict = {
                'observation': traj['images'][i],
                'action': traj['actions'][i],
                'reward': traj['rewards'][i],
                'done': traj['dones'][i],
                'pose': traj['poses'][i]
            }
            trajectory.append(step_dict)
        return trajectory

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        return traj['images'], traj['actions'], traj['rewards'], traj['dones'], traj['poses']
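

# Standalone sketch of the ring-buffer overwrite policy used by
# CircularBufferDataset (plain list, hypothetical capacity; not used by the
# trainer): once the buffer is full, write_idx wraps around so the oldest
# entry is replaced first.
def _demo_ring_buffer(capacity: int = 3):
    buf, write_idx = [], 0
    for item in range(5):  # insert 5 items into a capacity-3 buffer
        if len(buf) < capacity:
            buf.append(item)
        else:
            buf[write_idx] = item
            write_idx = (write_idx + 1) % capacity
    return buf  # [3, 4, 2]: entries 0 and 1 were overwritten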


class LIBERODatasetLeRobot(torch.utils.data.Dataset):
    """A dataset class for loading LIBERO data from the LeRobot repository."""

    def __init__(self, repo_id, transform=None, cfg=None):
        self.repo_id = repo_id
        self.transform = transform
        self._dataset = datasets.load_dataset(
            repo_id, split='train[:{}]'.format(cfg.dataset.buffer_size),
            keep_in_memory=True)

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        # Load trajectory data from the LeRobot dataset
        sample = self._dataset[idx]

        # Extract trajectory components, falling back to zeros for missing keys.
        images = torch.from_numpy(np.array(sample['img'])).float()
        actions = torch.from_numpy(np.array(sample['action'])).float()
        rewards = (torch.from_numpy(np.array(sample['rewards'])).float()
                   if 'rewards' in sample else torch.zeros(len(actions)))
        dones = (torch.from_numpy(np.array(sample['terminated'])).float()
                 if 'terminated' in sample else torch.zeros(len(actions)))
        poses = (torch.from_numpy(np.array(sample['poses'])).float()
                 if 'poses' in sample else torch.zeros(len(actions), 7))

        # Note: Images are returned in channel-last format (T, H, W, C);
        # conversion to channel-first (T, C, H, W) happens in the training loop.
        return images, actions, rewards, dones, poses


# ---------------------------------------------------------------------------
# Powerful stochastic policy network
# ---------------------------------------------------------------------------
class _ResLayer(torch.nn.Module):
    """Pre-norm residual MLP block: LayerNorm → Linear(d→4d) → SiLU → Linear(4d→d) + skip."""

    def __init__(self, dim: int, dropout: float = 0.0):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim)
        self.fc1 = torch.nn.Linear(dim, dim * 4)
        self.act = torch.nn.SiLU()
        self.fc2 = torch.nn.Linear(dim * 4, dim)
        self.drop = torch.nn.Dropout(dropout)

    def forward(self, x):
        return x + self.drop(self.fc2(self.act(self.fc1(self.norm(x)))))


class PolicyNet(torch.nn.Module):
    """Expressive Gaussian policy for both SimpleWorldModel and DreamerV3.

    Architecture
    ────────────
    input_proj  : Linear(in_dim → hidden_dim) + LayerNorm + SiLU
    trunk       : N × _ResLayer(hidden_dim)  (pre-norm residual blocks)
    mean_head   : Linear → SiLU → Linear → Tanh → action means in [-1, 1]
    logstd_head : Linear → SiLU → Linear → clamp → log-std in [-5, 2]

    Forward returns torch.cat([mean, log_std], dim=-1) of shape (B, 2*action_dim),
    so it is a drop-in replacement for the old nn.Sequential policy.
    """

    LOG_STD_MIN = -5.0
    LOG_STD_MAX = 2.0

    def __init__(self, in_dim: int, action_dim: int,
                 hidden_dim: int = 512, n_layers: int = 4,
                 dropout: float = 0.0):
        super().__init__()
        self.action_dim = action_dim

        # Input projection: lifts any input size into the hidden space
        self.input_proj = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.SiLU(),
        )

        # Deep residual trunk
        self.trunk = torch.nn.Sequential(
            *[_ResLayer(hidden_dim, dropout=dropout) for _ in range(n_layers)]
        )

        # Separate heads for mean and log-std → richer uncertainty estimates
        neck_dim = hidden_dim // 2
        self.mean_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, neck_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(neck_dim, action_dim),
            torch.nn.Tanh(),  # bounded action means in [-1, 1]
        )
        self.logstd_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, neck_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(neck_dim, action_dim),
        )

    def forward(self, x):
        h = self.trunk(self.input_proj(x))
        mean = self.mean_head(h)  # (B, A) in [-1, 1]
        log_std = self.logstd_head(h).clamp(self.LOG_STD_MIN, self.LOG_STD_MAX)  # (B, A)
        return torch.cat([mean, log_std], dim=-1)  # (B, 2A)
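

# Sketch of how a caller might turn PolicyNet's (B, 2A) output into a sampled
# action (the real consumption happens inside PolicyPlanner in planning.py,
# which is not shown here; this helper is illustrative only).
def _demo_sample_action(policy: PolicyNet, state: torch.Tensor) -> torch.Tensor:
    out = policy(state)                   # (B, 2A) = [mean, log_std]
    mean, log_std = out.chunk(2, dim=-1)  # (B, A) each
    dist = torch.distributions.Normal(mean, log_std.exp())
    return dist.sample()                  # (B, A) stochastic action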


@hydra.main(version_base=None, config_path="./conf", config_name="64pix-pose")
def my_main(cfg: DictConfig):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    wandb = None
    os.makedirs("checkpoints", exist_ok=True)
    subcheckpoint_dir = os.path.join("checkpoints", f"{cfg.experiment.name}")
    os.makedirs(subcheckpoint_dir, exist_ok=True)
    if not cfg.testing:
        import wandb
        # Start a new wandb run to track this script
        wandb.init(
            project=cfg.experiment.project,
            # Track hyperparameters and run metadata
            config=OmegaConf.to_container(cfg),
            name=cfg.experiment.name,
        )
        wandb.run.log_code(".")

    # Get the model type from the config or default to 'dreamer'
    model_type = getattr(cfg, 'model_type', 'dreamer')
    print(f"[info] Using model type: {model_type}")

    # Initialize the model using the factory
    img_shape = [3, 64, 64]
    model = create_model(model_type, img_shape,
                         action_dim=7, device=device, cfg=cfg)

    # Wrap the model for a unified training interface
    model_wrapper = ModelTrainingWrapper(model, model_type, device)

    # Initialize the planner (works with both model types through the model interface)
    if cfg.use_policy:
        print("[info] Using policy-based planner (CEMPlanner with policy)")

        # PolicyPlanner expects the policy input to match the planner's state feature:
        # - SimpleWorldModel: encoded pose (dim=7)
        # - DreamerV3: concat([h, z]) with dim = deter_dim + stoch_dim * discrete_dim
        if model_type == 'dreamer':
            policy_in_dim = int(model.deter_dim + model.stoch_dim * model.discrete_dim)
        else:
            policy_in_dim = 7

        # Stochastic policy: outputs [mean (Tanh-bounded), log_std] concatenated → shape (B, 14).
        # PolicyNet: deep residual MLP with pre-norm blocks and separate mean/log-std heads.
        policy = PolicyNet(in_dim=policy_in_dim, action_dim=7, hidden_dim=256,
                           n_layers=2, dropout=cfg.policy.dropout)
        policy.to(device)
        planner = PolicyPlanner(
            model,
            policy_model=policy,
            action_dim=7,
            cfg=cfg
        )
        if cfg.planner.type == 'policy_guided_cem':
            # Load a pretrained policy model for policy-guided CEM
            print(f"[info] Loading pretrained policy model from {cfg.load_policy}")
            planner.load_policy_model(cfg.load_policy)
    else:
        planner = CEMPlanner(
            model,
            action_dim=7,
            cfg=cfg
        )

    # Initialize the dataset
    if cfg.use_random_data:
        print("[info] Using CircularBufferDataset with random data collection")
        dataset = CircularBufferDataset(cfg=cfg)
        print(f"[info] Initialized buffer with {len(dataset)} trajectories")
    else:
        # Use the Hugging Face dataset by default for portability; fall back to local HDF5 if requested.
        if cfg.dataset.load_dataset:
            dataset = LIBERODatasetLeRobot(
                repo_id=cfg.dataset.to_name,
                transform=transforms.ToTensor(),
                cfg=cfg
            )
        else:
            data_dir = getattr(
                cfg.dataset, 'data_dir', '/network/projects/real-g-grp/libero/targets_clean/')
            dataset = LIBERODataset(data_dir, transform=transforms.ToTensor())

    load_world_model = getattr(cfg, 'load_world_model', None)
    if load_world_model is not None:
        planner.load_world_model(load_world_model)
        print(f"[info] Loaded world model weights from {load_world_model}")

    # Define the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Linear learning-rate scheduler that decays over training
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1.0,          # Start at the full learning rate
        end_factor=0.01,           # End at 1% of the learning rate
        total_iters=cfg.max_iters  # Decay over all training iterations
    )
    policy_loss = 0

    # Training loop
    for epoch in range(cfg.max_iters):
        loss = 0.0
        policy_loss = 0.0
        batch_counter = 0
        # Rebuild the DataLoader only on epochs where the dataset may have changed
        # (i.e., right after an evaluation that added new trajectories).
        if epoch == 0 or ((epoch - 1) % cfg.eval_vid_iters == 0):
            print(f"[info] Starting epoch {epoch+1}/{cfg.max_iters} with {len(dataset)} trajectories in dataset")
            # Batch data using the batch_data utility function
            dataloader = batch_data(dataset, batch_size=cfg.batch_size, cfg=cfg)

        # Process data in batches
        for batch in dataloader:
            images, actions, rewards, dones, poses = batch
            # Normalize poses and actions for the world model
            normalized_poses = model.encode_pose(poses)
            normalized_actions = model.encode_action(actions)
            # Scale uint8 images from [0, 255] to [-1, 1] for DreamerV3
            normalized_images = ((images.float() / 127.5) - 1.0).to(cfg.device) if model_type == 'dreamer' else None

            # Train the world model on the batch
            model.train()  # Set model to training mode
            # Call model_wrapper.forward_pass() with the inputs appropriate to the model type
            if model_type == 'dreamer':
                if cfg.use_policy and cfg.planner.type in ('policy', 'policy_guided_cem'):
                    # PolicyPlanner.update() for Dreamer expects image sequences (B, T, C, H, W)
                    # so it can encode them and build RSSM features [h, z] as policy inputs.
                    policy_loss = planner.update(normalized_images, normalized_actions)
                model_out = model_wrapper.forward_pass(normalized_images, None, normalized_actions)
                loss_dict = model_wrapper.compute_loss(
                    model_out,
                    normalized_images,
                    rewards,
                    dones,
                    None,
                    None,
                )
                batch_loss = loss_dict['total_loss']
            elif model_type == 'simple':
                if cfg.use_policy and cfg.planner.type in ('policy', 'policy_guided_cem'):
                    policy_loss = planner.update(normalized_poses, normalized_actions)
                model_out = model_wrapper.forward_pass(
                    None,
                    normalized_poses,
                    normalized_actions,
                )
                loss_dict = model_wrapper.compute_loss(
                    model_out,
                    None,
                    rewards,
                    dones,
                    normalized_poses,
                    normalized_actions,
                )
                batch_loss = loss_dict['total_loss']
            else:
                raise ValueError(f"Unknown model type: {model_type}")
            optimizer.zero_grad()
            batch_loss.backward()
            # Clip gradients. Essential for DreamerV3: without this, the
            # prior/posterior logits can explode and destabilize training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            loss = batch_loss.item()
            batch_counter += 1
            num_batches = (len(dataset) + cfg.batch_size - 1) // cfg.batch_size
            if model_type == 'dreamer':
                # Dreamer: log the loss components for quick debugging.
                print(
                    f"Epoch [{epoch+1}/{cfg.max_iters}], Batch [{batch_counter}/{num_batches}], "
                    f"Loss: {batch_loss.item():.4f}, recon: {loss_dict['recon_loss'].item():.4f}, "
                    f"reward: {loss_dict['reward_loss'].item():.4f}, cont: {loss_dict['continue_loss'].item():.4f}, "
                    f"dyn: {loss_dict['dyn_loss'].item():.4f}, rep: {loss_dict['rep_loss'].item():.4f}, policy_loss: {policy_loss:.4f}"
                )
            else:
                print(
                    f"Epoch [{epoch+1}/{cfg.max_iters}], Batch [{batch_counter}/{num_batches}], "
                    f"Loss: {batch_loss.item():.4f}, policy_loss: {policy_loss:.4f}"
                )

            # Log training losses to wandb
            if wandb is not None:
                if model_type == 'dreamer':
                    log_payload = {
                        "train_loss": loss,
                        "policy_loss": policy_loss,
                        "loss/recon": float(loss_dict['recon_loss'].detach().cpu()),
                        "loss/reward": float(loss_dict['reward_loss'].detach().cpu()),
                        "loss/continue": float(loss_dict['continue_loss'].detach().cpu()),
                        "loss/dyn": float(loss_dict['dyn_loss'].detach().cpu()),
                        "loss/rep": float(loss_dict['rep_loss'].detach().cpu()),
                    }
                else:
                    log_payload = {
                        "train_loss": loss,
                        "policy_loss": policy_loss,
                        "pose_loss": float(loss_dict['pose_loss'].detach().cpu()),
                        "reward_loss": float(loss_dict['reward_loss'].detach().cpu())
                    }
                wandb.log(log_payload)

        # Save model checkpoints and evaluate periodically
        if epoch % cfg.eval_vid_iters == 0:
            torch.save(model.state_dict(), os.path.join(subcheckpoint_dir, f'model_epoch_{epoch+1}_batch_{batch_counter}.pth'), pickle_module=dill)
            # Save the policy model if using a policy-based planner
            if cfg.use_policy:
                torch.save(planner.policy_model.state_dict(), os.path.join(subcheckpoint_dir, 'policy.pth'), pickle_module=dill)
            # Evaluate the model using eval_libero from sim_eval
            print("[info] Starting evaluation on LIBERO tasks...")
            # Import lazily so importing this module doesn't require robosuite/LIBERO deps.
            try:
                from .sim_eval import eval_libero
            except ImportError:
                from sim_eval import eval_libero
            data = eval_libero(planner, device, cfg, iter_=epoch, log_dir="./",
                               wandb=wandb)
            if cfg.use_random_data:
                # Add the newly collected trajectories to the buffer
                for traj in data['traj']:
                    dones = np.zeros_like(traj['rewards'])
                    dones[-1] = 1
                    # Observations stay channel-last (T, H, W, C); the training
                    # loop converts to channel-first when batching.
                    observations = np.array(traj['observations'])
                    dataset.add_trajectory(observations, np.array(traj['actions']),
                                           np.array(traj['rewards']), np.array(dones), np.array(traj['poses']))
                print(
                    f"[info] Added new random trajectories to buffer. Current buffer size: {len(dataset)}")

        # Step the learning rate scheduler after each epoch
        scheduler.step()
        print(
            f'Learning rate after epoch {epoch+1}: {scheduler.get_last_lr()[0]:.6f}')
    torch.save(model.state_dict(), os.path.join(subcheckpoint_dir, 'world_model.pth'), pickle_module=dill)


if __name__ == '__main__':
    my_main()