NMThuan032k committed (verified)
Commit bfc4a29 · 1 Parent(s): 47b6579

Upload folder using huggingface_hub
q1_simpleworld_cem/dreamer_model_trainer.py ADDED
@@ -0,0 +1,702 @@
import sys
import os
from pathlib import Path

os.environ["MUJOCO_GL"] = "egl"

# Ensure the vendored LIBERO package is importable even if it hasn't been pip-installed.
# Hydra may change the working directory, so we resolve relative to this file.
_REPO_ROOT = Path(__file__).resolve().parents[1]
_LIBERO_ROOT = _REPO_ROOT / "LIBERO"
if _LIBERO_ROOT.exists():
    sys.path.insert(0, str(_LIBERO_ROOT))

import dill
import h5py
import numpy as np
import torch
import hydra
from omegaconf import DictConfig, OmegaConf
from torchvision import transforms
import datasets

# Support both `python hw2/dreamer_model_trainer.py` (cwd=hw2) and
# `python -m hw2.dreamer_model_trainer` / importing as a package.
try:
    from .dreamerV3 import DreamerV3
    from .simple_world_model import SimpleWorldModel
    from .planning import CEMPlanner, PolicyPlanner, RandomPlanner
except ImportError:
    from dreamerV3 import DreamerV3
    from simple_world_model import SimpleWorldModel
    from planning import CEMPlanner, PolicyPlanner, RandomPlanner

# Factory function to instantiate the correct model.
def create_model(model_type, img_shape, action_dim, device, cfg):
    """
    Create a world model of the requested type.

    Args:
        model_type: 'dreamer' or 'simple'
        img_shape: Image shape [C, H, W]
        action_dim: Dimensionality of actions
        device: torch device
        cfg: Configuration object

    Returns:
        model: Instantiated model, moved to `device`
    """
    if model_type.lower() == 'dreamer':
        model = DreamerV3(obs_shape=img_shape,
                          action_dim=action_dim, cfg=cfg).to(device)
    elif model_type.lower() == 'simple':
        model = SimpleWorldModel(
            action_dim=action_dim, pose_dim=7, hidden_dim=256, cfg=cfg).to(device)
    else:
        raise ValueError(
            f"Unknown model_type: {model_type}. Choose 'dreamer' or 'simple'.")

    return model

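# Illustrative sketch (never called): instantiating a world model from a Hydra
# config. The 64x64 RGB observation shape and 7-DoF action space mirror the
# LIBERO setup used in my_main() below; `cfg` is assumed to carry the same keys.
def _create_model_example(cfg):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return create_model('simple', img_shape=[3, 64, 64], action_dim=7,
                        device=device, cfg=cfg)
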
def batch_data(dataset, batch_size, cfg):
    """
    Batch trajectories from the dataset into fixed-length sequences.

    Each trajectory is chunked into non-overlapping windows of length
    cfg.policy.sequence_length; any remainder shorter than one window is dropped.

    Args:
        dataset: Dataset yielding (images, actions, rewards, dones, poses) per trajectory
        batch_size: Number of sequences per batch
        cfg: Config providing cfg.policy.sequence_length and cfg.device

    Returns:
        A DataLoader yielding batches of (images, actions, rewards, dones, poses) with shapes:
        - images:  (B, T, C, H, W)
        - actions: (B, T, 7)
        - rewards: (B, T)
        - dones:   (B, T)
        - poses:   (B, T, 7)
    """
    seq_len = cfg.policy.sequence_length
    list_images, list_actions, list_rewards, list_dones, list_poses = [], [], [], [], []
    for img, act, rew, don, pos in dataset:
        list_images += [img[i:i + seq_len] for i in range(0, len(img) - seq_len + 1, seq_len)]
        list_actions += [act[i:i + seq_len] for i in range(0, len(act) - seq_len + 1, seq_len)]
        list_rewards += [rew[i:i + seq_len] for i in range(0, len(rew) - seq_len + 1, seq_len)]
        list_dones += [don[i:i + seq_len] for i in range(0, len(don) - seq_len + 1, seq_len)]
        list_poses += [pos[i:i + seq_len] for i in range(0, len(pos) - seq_len + 1, seq_len)]
    images = torch.stack(list_images)    # (B, T, H, W, C)
    actions = torch.stack(list_actions)  # (B, T, action_dim)
    rewards = torch.stack(list_rewards)  # (B, T)
    dones = torch.stack(list_dones)      # (B, T)
    poses = torch.stack(list_poses)      # (B, T, pose_dim)
    images = images.permute(0, 1, 4, 2, 3).to(cfg.device)  # (B, T, H, W, C) -> (B, T, C, H, W)
    actions = actions.float().to(cfg.device)
    rewards = rewards.float().to(cfg.device)
    dones = dones.float().to(cfg.device)
    poses = poses.float().to(cfg.device)
    # An alternative (previously prototyped here) is to keep full variable-length
    # trajectories and zero-pad them with torch.nn.utils.rnn.pad_sequence; fixed-length
    # chunking is used instead because it avoids masking logic in the losses.
    print(f"[info] Batched data into tensors with shapes: images={images.shape}, "
          f"actions={actions.shape}, rewards={rewards.shape}, dones={dones.shape}, poses={poses.shape}")
    out_dataset = torch.utils.data.TensorDataset(images, actions, rewards, dones, poses)
    print(f"[info] Created DataLoader with {len(out_dataset)} samples")
    return torch.utils.data.DataLoader(out_dataset, batch_size=batch_size, shuffle=True)

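# Illustrative sketch (never called): pulling one batch out of batch_data() and
# the shapes to expect. `dataset` and `cfg` are assumed to come from the dataset
# classes and Hydra config used elsewhere in this file.
def _batch_data_usage_example(dataset, cfg):
    loader = batch_data(dataset, batch_size=cfg.batch_size, cfg=cfg)
    images, actions, rewards, dones, poses = next(iter(loader))
    # images: (B, T, C, H, W) on cfg.device; rewards and dones: (B, T)
    return images.shape, actions.shape, rewards.shape
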
class ModelTrainingWrapper:
    """
    Wrapper providing a unified interface for training different world models.
    Handles differences in forward passes and loss computation between models.
    """

    def __init__(self, model, model_type, device):
        self.model = model
        self.model_type = model_type.lower()
        self.device = device

    def forward_pass(self, images, poses, actions):
        """
        Unified forward pass that works with both model types.

        Args:
            images: Image tensor (B, T, C, H, W), or None for the simple model
            poses: Pose tensor (B, T, 7), or None for the dreamer model
            actions: Action tensor (B, T, 7)

        Returns:
            output: Model output (format depends on model type)
        """
        if self.model_type == 'dreamer':
            # DreamerV3 returns a dict of rollout predictions.
            return self.model(images, actions)
        elif self.model_type == 'simple':
            # SimpleWorldModel expects normalized pose/action inputs.
            pred_pose_seq, pred_reward_seq = self.model(poses, actions)
            return {
                'pred_poses': pred_pose_seq,
                'pred_rewards': pred_reward_seq
            }
        raise ValueError(f"Unknown model_type: {self.model_type}")

    def compute_loss(self, model_out, normalized_images, rewards, dones, poses, actions):
        """
        Compute the training loss for either model type.

        Args:
            model_out: Output from forward_pass
            normalized_images: Image tensor (used for DreamerV3)
            rewards: Reward tensor (B, T)
            dones: Done tensor (B, T)
            poses: Pose tensor (used for SimpleWorldModel)
            actions: Action tensor (used for SimpleWorldModel)

        Returns:
            losses: Dictionary with loss information
        """
        if self.model_type == 'dreamer':
            # Use DreamerV3's own loss computation.
            if not isinstance(model_out, dict):
                raise ValueError(
                    f"DreamerV3 forward must return a dict, got {type(model_out)}"
                )
            return self.model.compute_loss(model_out, normalized_images, rewards, dones, self.device)
        elif self.model_type == 'simple':
            # Part 1.2 - SimpleWorldModel training loss:
            # MSE between predicted and target poses/rewards.
            pred_poses = model_out['pred_poses']
            pred_rewards = model_out['pred_rewards']
            if pred_rewards is None:
                raise ValueError("SimpleWorldModel path expected pred_rewards, got None")
            # Ensure rewards are always (B, T).
            if pred_rewards.dim() == 3 and pred_rewards.shape[-1] == 1:
                pred_rewards = pred_rewards.squeeze(-1)
            # Validate the shape of pred_poses before computing the loss.
            if pred_poses.dim() == 2:
                raise ValueError(
                    f"SimpleWorldModel output must be (B, T, 7); got 2D tensor of shape "
                    f"{pred_poses.shape}. Check model output formatting.")
            elif pred_poses.dim() == 3 and pred_poses.shape[2] != 7:
                raise ValueError(
                    f"SimpleWorldModel pose dim must be 7; got last dimension "
                    f"{pred_poses.shape[2]}. Check model output formatting.")
            elif pred_poses.dim() == 3 and pred_poses.shape[2] == 7:
                B, T, _ = pred_poses.shape

                # Align shapes: predictions at times [0..T-2] match targets [1..T-1].
                pred_pose_seq = pred_poses[:, : T - 1, :]
                tgt_pose_seq = poses[:, 1:, :]

                # Rewards are (B, T) and are compared without shifting here.
                pred_rew_seq = pred_rewards
                tgt_rew_seq = rewards

                loss_dict = self.model.compute_loss(
                    pred_pose_seq,
                    pred_rew_seq,
                    target_pose=tgt_pose_seq,
                    target_reward=tgt_rew_seq,
                )
                return loss_dict

            raise ValueError(f"Unexpected pred_poses shape: {pred_poses.shape}")


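# Illustrative sketch (never called): one SimpleWorldModel gradient step through
# the wrapper. Assumes `batch` comes from batch_data() and that the wrapped
# model exposes encode_pose / encode_action, as used in the training loop below.
def _wrapper_train_step_example(model_wrapper, optimizer, batch):
    images, actions, rewards, dones, poses = batch
    norm_poses = model_wrapper.model.encode_pose(poses)
    norm_actions = model_wrapper.model.encode_action(actions)
    out = model_wrapper.forward_pass(None, norm_poses, norm_actions)
    loss_dict = model_wrapper.compute_loss(out, None, rewards, dones, norm_poses, norm_actions)
    optimizer.zero_grad()
    loss_dict['total_loss'].backward()
    optimizer.step()
    return loss_dict['total_loss'].item()

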
class LIBERODataset(torch.utils.data.Dataset):
    """Loads LIBERO demonstration trajectories from HDF5 files on disk."""

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        # Crawl data_dir and build an index of (file, demo) pairs across all HDF5 files.
        self.index_map = []
        for root, dirs, files in os.walk(self.data_dir):
            for file in files:
                if file.endswith('.hdf5') or file.endswith('.h5'):
                    file_path = os.path.join(root, file)
                    with h5py.File(file_path, 'r') as f:
                        for demo_key in f['data'].keys():
                            self.index_map.append((file_path, demo_key))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        file_path, demo_key = self.index_map[idx]
        with h5py.File(file_path, 'r') as f:
            demo = f['data'][demo_key]
            image = torch.from_numpy(demo['obs']['agentview_rgb'][()])
            action = torch.from_numpy(demo['actions'][()])
            dones = torch.from_numpy(demo['dones'][()])
            rewards = torch.from_numpy(demo['rewards'][()])
            # Pose = end-effector position (3) + orientation (3) + gripper state (1).
            poses = torch.from_numpy(np.concatenate((demo['obs']["ee_pos"],
                                                     demo['obs']["ee_ori"][:, :3],
                                                     demo['obs']["gripper_states"][:, :1]), axis=-1))
        # Note: Images are returned in channel-last format (T, H, W, C).
        # Conversion to channel-first (T, C, H, W) happens in the training loop.
        return image, action, rewards, dones, poses

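# Expected HDF5 layout, inferred from the reads in __getitem__ above
# (exact field shapes are assumptions):
#   data/
#     demo_0/
#       obs/agentview_rgb   uint8 (T, H, W, C)
#       obs/ee_pos          (T, 3)
#       obs/ee_ori          (T, >=3)
#       obs/gripper_states  (T, >=1)
#       actions             (T, 7)
#       rewards             (T,)
#       dones               (T,)
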
class CircularBufferDataset(torch.utils.data.Dataset):
    """Circular buffer dataset that holds up to cfg.dataset.buffer_size trajectories.
    When full, the oldest trajectories are overwritten.
    """

    def __init__(self, cfg=None, data_dir=None):
        self.trajectories = []
        self.write_idx = 0
        self._cfg = cfg

        if data_dir is None:
            data_dir = getattr(cfg, 'data_dir', None)
        if data_dir is None and cfg is not None:
            data_dir = getattr(
                getattr(cfg, 'dataset', None), 'data_dir', None)
        if data_dir is None:
            data_dir = '/network/projects/real-g-grp/libero/targets_clean/'

        if cfg.dataset.load_dataset:
            dataset = LIBERODatasetLeRobot(
                repo_id=cfg.dataset.to_name,
                transform=transforms.ToTensor(),
                cfg=cfg
            )
        else:
            data_dir = getattr(
                cfg.dataset, 'data_dir', '/network/projects/real-g-grp/libero/targets_clean/')
            dataset = LIBERODataset(data_dir, transform=transforms.ToTensor())
        num_to_load = min(len(dataset), self._cfg.dataset.buffer_size)
        if num_to_load == 0:
            return

        # Seed the buffer with a random subset of trajectories.
        indices = np.random.choice(
            len(dataset), size=num_to_load, replace=False)
        for idx in indices:
            images, actions, rewards, dones, poses = dataset[idx]
            self.add_trajectory(
                np.array(images),
                np.array(actions),
                np.array(rewards),
                np.array(dones),
                np.array(poses)
            )

    def add_trajectory(self, images, actions, rewards, dones, poses):
        """Add a trajectory to the buffer, overwriting the oldest one if full."""
        trajectory = {
            'images': torch.from_numpy(images),
            'actions': torch.from_numpy(actions),
            'rewards': torch.from_numpy(rewards),
            'dones': torch.from_numpy(dones),
            'poses': torch.from_numpy(poses)
        }

        if len(self.trajectories) < self._cfg.dataset.buffer_size:
            self.trajectories.append(trajectory)
        else:
            # Overwrite the oldest trajectory.
            self.trajectories[self.write_idx] = trajectory
            self.write_idx = (
                self.write_idx + 1) % self._cfg.dataset.buffer_size

    def get_trajectory(self, idx):
        """Return trajectory `idx` as a list of per-step dicts."""
        trajectory = []
        traj = self.trajectories[idx]
        for i in range(len(traj['images'])):
            step_dict = {
                'observation': traj['images'][i],
                'action': traj['actions'][i],
                'reward': traj['rewards'][i],
                'done': traj['dones'][i],
                'pose': traj['poses'][i]
            }
            trajectory.append(step_dict)
        return trajectory

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        return traj['images'], traj['actions'], traj['rewards'], traj['dones'], traj['poses']

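# Illustrative sketch (never called): appending one evaluation rollout to the
# buffer. `eval_traj` is assumed to be a dict with 'observations', 'actions',
# 'rewards', and 'poses' arrays, matching what eval_libero() returns below.
def _buffer_usage_example(dataset, eval_traj):
    dones = np.zeros_like(eval_traj['rewards'])
    dones[-1] = 1  # mark the final step as terminal
    dataset.add_trajectory(np.array(eval_traj['observations']),
                           np.array(eval_traj['actions']),
                           np.array(eval_traj['rewards']),
                           dones,
                           np.array(eval_traj['poses']))

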
class LIBERODatasetLeRobot(torch.utils.data.Dataset):
    """A dataset class for loading LIBERO data from the LeRobot repository."""

    def __init__(self, repo_id, transform=None, cfg=None):
        self.repo_id = repo_id
        self.transform = transform
        self._dataset = datasets.load_dataset(repo_id, split='train[:{}]'.format(
            cfg.dataset.buffer_size), keep_in_memory=True)

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        # Load one trajectory from the Hugging Face dataset.
        sample = self._dataset[idx]

        # Extract trajectory components; fall back to zeros for missing fields.
        images = torch.from_numpy(np.array(sample['img'])).float()
        actions = torch.from_numpy(np.array(sample['action'])).float()
        rewards = torch.from_numpy(np.array(sample['rewards'])).float() \
            if 'rewards' in sample else torch.zeros(len(actions))
        dones = torch.from_numpy(np.array(sample['terminated'])).float() \
            if 'terminated' in sample else torch.zeros(len(actions))
        poses = torch.from_numpy(np.array(sample['poses'])).float() \
            if 'poses' in sample else torch.zeros(len(actions), 7)

        # Note: Images are returned in channel-last format (T, H, W, C).
        # Conversion to channel-first (T, C, H, W) happens in the training loop.
        return images, actions, rewards, dones, poses

# ---------------------------------------------------------------------------
# Stochastic policy network
# ---------------------------------------------------------------------------
class _ResLayer(torch.nn.Module):
    """Pre-norm residual MLP block: LayerNorm → Linear(d→4d) → SiLU → Linear(4d→d) + skip."""

    def __init__(self, dim: int, dropout: float = 0.0):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim)
        self.fc1 = torch.nn.Linear(dim, dim * 4)
        self.act = torch.nn.SiLU()
        self.fc2 = torch.nn.Linear(dim * 4, dim)
        self.drop = torch.nn.Dropout(dropout)

    def forward(self, x):
        return x + self.drop(self.fc2(self.act(self.fc1(self.norm(x)))))


class PolicyNet(torch.nn.Module):
    """Expressive Gaussian policy for both SimpleWorldModel and DreamerV3.

    Architecture
    ────────────
    input_proj  : Linear(in_dim → hidden_dim) + LayerNorm + SiLU
    trunk       : N × _ResLayer(hidden_dim) (pre-norm residual blocks)
    mean_head   : Linear → SiLU → Linear → Tanh → action means in [-1, 1]
    logstd_head : Linear → SiLU → Linear → clamp → log-std in [-5, 2]

    Forward returns torch.cat([mean, log_std], dim=-1) of shape (B, 2*action_dim),
    so it is a drop-in replacement for the old nn.Sequential policy.
    """

    LOG_STD_MIN = -5.0
    LOG_STD_MAX = 2.0

    def __init__(self, in_dim: int, action_dim: int,
                 hidden_dim: int = 512, n_layers: int = 4,
                 dropout: float = 0.0):
        super().__init__()
        self.action_dim = action_dim

        # Input projection: lifts any input size into the hidden space.
        self.input_proj = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.SiLU(),
        )

        # Deep residual trunk.
        self.trunk = torch.nn.Sequential(
            *[_ResLayer(hidden_dim, dropout=dropout) for _ in range(n_layers)]
        )

        # Separate heads for mean and log-std → richer uncertainty estimates.
        neck_dim = hidden_dim // 2
        self.mean_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, neck_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(neck_dim, action_dim),
            torch.nn.Tanh(),  # bounded action means in [-1, 1]
        )
        self.logstd_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, neck_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(neck_dim, action_dim),
        )

    def forward(self, x):
        h = self.trunk(self.input_proj(x))
        mean = self.mean_head(h)  # (B, A) in [-1, 1]
        log_std = self.logstd_head(h).clamp(self.LOG_STD_MIN, self.LOG_STD_MAX)  # (B, A)
        return torch.cat([mean, log_std], dim=-1)  # (B, 2A)

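# Illustrative sketch (never called): turning PolicyNet's [mean, log_std] output
# into a sampled action. `features` stands in for whatever state the planner
# feeds the policy (pose for SimpleWorldModel, [h, z] features for DreamerV3);
# the reparameterized sample below is one common choice, and the planners may
# do this differently.
def _policy_sampling_example(policy: PolicyNet, features: torch.Tensor) -> torch.Tensor:
    out = policy(features)                       # (B, 2A)
    mean, log_std = out.chunk(2, dim=-1)         # (B, A) each
    std = log_std.exp()
    action = mean + std * torch.randn_like(std)  # reparameterized Gaussian sample
    return action.clamp(-1.0, 1.0)               # keep actions in the bounded range
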

@hydra.main(version_base=None, config_path="./conf", config_name="64pix-pose")
def my_main(cfg: DictConfig):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    wandb = None
    os.makedirs("checkpoints", exist_ok=True)
    subcheckpoint_dir = os.path.join("checkpoints", cfg.experiment.name)
    os.makedirs(subcheckpoint_dir, exist_ok=True)
    if not cfg.testing:
        import wandb
        # Start a new wandb run to track this script.
        wandb.init(
            project=cfg.experiment.project,
            # Track hyperparameters and run metadata.
            config=OmegaConf.to_container(cfg),
            name=cfg.experiment.name,
        )
        wandb.run.log_code(".")

    # Get the model type from the config, defaulting to 'dreamer'.
    model_type = getattr(cfg, 'model_type', 'dreamer')
    print(f"[info] Using model type: {model_type}")

    # Initialize the model using the factory.
    img_shape = [3, 64, 64]
    model = create_model(model_type, img_shape,
                         action_dim=7, device=device, cfg=cfg)

    # Wrap the model for a unified training interface.
    model_wrapper = ModelTrainingWrapper(model, model_type, device)

    # Initialize the planner (works with both model types through the model interface).
    if cfg.use_policy:
        print("[info] Using policy-based planner (CEMPlanner with policy)")

        # PolicyPlanner expects the policy input to match the planner's state feature:
        # - SimpleWorldModel: encoded pose (dim=7)
        # - DreamerV3: concat([h, z]) with dim = deter_dim + stoch_dim * discrete_dim
        if model_type == 'dreamer':
            policy_in_dim = int(model.deter_dim + model.stoch_dim * model.discrete_dim)
        else:
            policy_in_dim = 7

        # Stochastic policy: outputs [mean (Tanh-bounded), log_std] concatenated → shape (B, 14).
        # PolicyNet: deep residual MLP with pre-norm blocks and separate mean/log-std heads.
        policy = PolicyNet(in_dim=policy_in_dim, action_dim=7, hidden_dim=256,
                           n_layers=2, dropout=cfg.policy.dropout)
        policy.to(device)
        planner = PolicyPlanner(
            model,
            policy_model=policy,
            action_dim=7,
            cfg=cfg
        )
        if cfg.planner.type == 'policy_guided_cem':
            # Load a pretrained policy model for policy-guided CEM.
            print(f"[info] Loading pretrained policy model from {cfg.load_policy}")
            planner.load_policy_model(cfg.load_policy)
    else:
        planner = CEMPlanner(
            model,
            action_dim=7,
            cfg=cfg
        )

    # Initialize the dataset.
    if cfg.use_random_data:
        print("[info] Using CircularBufferDataset with random data collection")
        dataset = CircularBufferDataset(cfg=cfg)
        print(f"[info] Initialized buffer with {len(dataset)} trajectories")
    else:
        # Use the Hugging Face dataset by default for portability; fall back to local HDF5 if requested.
        if cfg.dataset.load_dataset:
            dataset = LIBERODatasetLeRobot(
                repo_id=cfg.dataset.to_name,
                transform=transforms.ToTensor(),
                cfg=cfg
            )
        else:
            data_dir = getattr(
                cfg.dataset, 'data_dir', '/network/projects/real-g-grp/libero/targets_clean/')
            dataset = LIBERODataset(data_dir, transform=transforms.ToTensor())

    load_world_model = getattr(cfg, 'load_world_model', None)
    if load_world_model is not None:
        planner.load_world_model(load_world_model)
        print(f"[info] Loaded world model weights from {load_world_model}")

    # Define the optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Linear learning-rate scheduler that decays over training.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1.0,          # start at the full learning rate
        end_factor=0.01,           # end at 1% of the initial learning rate
        total_iters=cfg.max_iters  # decay over cfg.max_iters epochs
    )

    # Training loop
    for epoch in range(cfg.max_iters):
        loss = 0.0
        policy_loss = 0.0
        batch_counter = 0
        # Re-batch the dataset periodically (and on the first epoch); the
        # DataLoader itself shuffles the sequences.
        if epoch == 0 or ((epoch - 1) % cfg.eval_vid_iters == 0):
            print(f"[info] Starting epoch {epoch+1}/{cfg.max_iters} with {len(dataset)} trajectories in dataset")
            dataloader = batch_data(dataset, batch_size=cfg.batch_size, cfg=cfg)

        # Process data in batches.
        for batch in dataloader:
            images, actions, rewards, dones, poses = batch
            # Normalize poses and actions; normalize images only for Dreamer.
            normalized_poses = model.encode_pose(poses)
            normalized_actions = model.encode_action(actions)
            normalized_images = ((images.float() / 127.5) - 1.0).to(cfg.device) if model_type == 'dreamer' else None

            # Train the world model on the batch.
            model.train()  # set model to training mode
            if model_type == 'dreamer':
                if cfg.use_policy and cfg.planner.type in ('policy', 'policy_guided_cem'):
                    # PolicyPlanner.update() for Dreamer expects image sequences (B, T, C, H, W)
                    # so it can encode them and build RSSM features [h, z] as policy inputs.
                    policy_loss = planner.update(normalized_images, normalized_actions)
                model_out = model_wrapper.forward_pass(normalized_images, None, normalized_actions)
                loss_dict = model_wrapper.compute_loss(
                    model_out,
                    normalized_images,
                    rewards,
                    dones,
                    None,
                    None,
                )
                batch_loss = loss_dict['total_loss']
            elif model_type == 'simple':
                if cfg.use_policy and cfg.planner.type in ('policy', 'policy_guided_cem'):
                    policy_loss = planner.update(normalized_poses, normalized_actions)
                model_out = model_wrapper.forward_pass(
                    None,
                    normalized_poses,
                    normalized_actions,
                )
                loss_dict = model_wrapper.compute_loss(
                    model_out,
                    None,
                    rewards,
                    dones,
                    normalized_poses,
                    normalized_actions,
                )
                batch_loss = loss_dict['total_loss']
            else:
                raise ValueError(f"Unknown model type: {model_type}")
            optimizer.zero_grad()
            batch_loss.backward()
            # Clip gradients — essential for DreamerV3: without this, the
            # prior/posterior logits can blow up and destabilize training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            loss = batch_loss.item()
            batch_counter += 1
            if model_type == 'dreamer':
                # Dreamer: log the loss components for quick debugging.
                print(
                    f"Epoch [{epoch+1}/{cfg.max_iters}], Batch [{batch_counter}/{(len(dataset) + cfg.batch_size - 1) // cfg.batch_size}], "
                    f"Loss: {batch_loss.item():.4f}, recon: {loss_dict['recon_loss'].item():.4f}, "
                    f"reward: {loss_dict['reward_loss'].item():.4f}, cont: {loss_dict['continue_loss'].item():.4f}, "
                    f"dyn: {loss_dict['dyn_loss'].item():.4f}, rep: {loss_dict['rep_loss'].item():.4f}, policy_loss: {policy_loss:.4f}"
                )
            else:
                print(
                    f"Epoch [{epoch+1}/{cfg.max_iters}], Batch [{batch_counter}/{(len(dataset) + cfg.batch_size - 1) // cfg.batch_size}], "
                    f"Loss: {batch_loss.item():.4f}, policy_loss: {policy_loss:.4f}"
                )

            # Log training losses to wandb.
            if wandb is not None:
                if model_type == 'dreamer':
                    log_payload = {
                        "train_loss": loss,
                        "policy_loss": policy_loss,
                        "loss/recon": float(loss_dict['recon_loss'].detach().cpu()),
                        "loss/reward": float(loss_dict['reward_loss'].detach().cpu()),
                        "loss/continue": float(loss_dict['continue_loss'].detach().cpu()),
                        "loss/dyn": float(loss_dict['dyn_loss'].detach().cpu()),
                        "loss/rep": float(loss_dict['rep_loss'].detach().cpu()),
                    }
                else:
                    log_payload = {
                        "train_loss": loss,
                        "policy_loss": policy_loss,
                        "pose_loss": float(loss_dict['pose_loss'].detach().cpu()),
                        "reward_loss": float(loss_dict['reward_loss'].detach().cpu())
                    }
                wandb.log(log_payload)

        # Periodically checkpoint and evaluate.
        if epoch % cfg.eval_vid_iters == 0:
            torch.save(model.state_dict(),
                       os.path.join(subcheckpoint_dir, f'model_epoch_{epoch+1}_batch_{batch_counter}.pth'),
                       pickle_module=dill)
            # Save the policy model if using a policy-based planner.
            if cfg.use_policy:
                torch.save(planner.policy_model.state_dict(),
                           os.path.join(subcheckpoint_dir, 'policy.pth'), pickle_module=dill)
            # Evaluate the model using eval_libero from sim_eval.
            print("[info] Starting evaluation on LIBERO tasks...")
            # Import lazily so importing this module doesn't require robosuite/LIBERO deps.
            try:
                from .sim_eval import eval_libero
            except ImportError:
                from sim_eval import eval_libero
            data = eval_libero(planner, device, cfg, iter_=epoch, log_dir="./",
                               wandb=wandb)
            if cfg.use_random_data:
                # Add the newly collected trajectories to the buffer.
                for traj in data['traj']:
                    dones = np.zeros_like(traj['rewards'])
                    dones[-1] = 1
                    # Observations stay channel-last (T, H, W, C); the training
                    # loop converts to channel-first when batching.
                    observations = np.array(traj['observations'])
                    dataset.add_trajectory(observations, np.array(traj['actions']),
                                           np.array(traj['rewards']), np.array(dones), np.array(traj['poses']))
                print(
                    f"[info] Added new trajectories to buffer. Current buffer size: {len(dataset)}")

        # Step the learning-rate scheduler after each epoch.
        scheduler.step()
        print(
            f'Learning rate after epoch {epoch+1}: {scheduler.get_last_lr()[0]:.6f}')
        # Always keep the latest world model on disk.
        torch.save(model.state_dict(), os.path.join(subcheckpoint_dir, 'world_model.pth'), pickle_module=dill)


if __name__ == '__main__':
    my_main()
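
# Example invocations (illustrative; the exact keys live in the Hydra config
# referenced above, assumed to be conf/64pix-pose.yaml, so treat these
# overrides as assumptions):
#   python dreamer_model_trainer.py model_type=simple
#   python dreamer_model_trainer.py model_type=dreamer use_policy=true planner.type=policy_guided_cem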