t1an committed on
Commit
ad9aba4
·
verified ·
1 Parent(s): a6bd9fb

Upload folder using huggingface_hub

Browse files
wm/config/fulltraj_dit/franka.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for Franka (IsaacLab) World Model Training
2
+ # Full Trajectory Generation (Bidirectional)
3
+
4
+ # Dynamics Model Class
5
+ dynamics_class: "Bidirectional_FullTrajectory"
6
+
7
+ # Model identifier for DIT_CLASS_MAP
8
+ model_name: "VideoDiT"
9
+
10
+ # Configuration passed to the DiT model constructor
11
+ model_config:
12
+ in_channels: 16 # Latent channels from Wan VAE
13
+ patch_size: 2
14
+ dim: 1024 # Hidden dimension
15
+ num_layers: 16
16
+ num_heads: 16
17
+ action_dim: 7 # Franka action dimension (6-DoF + Gripper)
18
+ action_compress_rate: 4 # Compresses action sequence to latent sequence
19
+ max_frames: 33 # Franka sequence length (T=33, 1 + 4*8 windows)
20
+ action_dropout_prob: 0.1 # CFG for action conditioning
21
+ temporal_causal: false # Bidirectional temporal attention for fulltraj
22
+ vae_name: "WanVAE"
23
+ vae_config:
24
+ - "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
25
+ scheduler: "FlowMatch" # Will be instantiated in dynamics class
26
+ training_timesteps: 1000
27
+
28
+ # Dataset Configuration
29
+ dataset:
30
+ name: "franka"
31
+ seq_len: 33 # Matches max_frames (e.g., T=33 for 8 steps in latent space)
32
+ train_test_split: 50 # 50:1 split
33
+
34
+ # Training Hyperparameters
35
+ training:
36
+ batch_size: 4
37
+ learning_rate: 1e-4
38
+ num_epochs: 2000
39
+ grad_clip: 1.0
40
+ checkpoint_freq: 5000 # Numbered checkpoints for eval
41
+ latest_freq: 500 # Only updates latest.pt for resuming
42
+ val_freq: 1000 # Video Logging
43
+ eval_freq: 500 # MSE Rollout
44
+ log_freq: 10 # Steps
45
+ num_workers: 4
46
+
47
+ # Distributed Training
48
+ distributed:
49
+ use_ddp: true
50
+ use_fsdp: false
51
+
52
+ # WandB Configuration
53
wandb:
  project: "world_model"
  run_name: "franka_fulltraj_dit_v1"
  # SECURITY: a W&B API key was committed here in plaintext and has been
  # removed. Revoke the exposed key and provide a fresh one through the
  # WANDB_API_KEY environment variable (or `wandb login`) instead of this file.
  # NOTE(review): confirm the training code tolerates a null api_key.
  api_key: null
wm/dataset/__pycache__/data_config.cpython-39.pyc CHANGED
Binary files a/wm/dataset/__pycache__/data_config.cpython-39.pyc and b/wm/dataset/__pycache__/data_config.cpython-39.pyc differ
 
wm/dataset/__pycache__/dataset.cpython-39.pyc CHANGED
Binary files a/wm/dataset/__pycache__/dataset.cpython-39.pyc and b/wm/dataset/__pycache__/dataset.cpython-39.pyc differ
 
wm/dataset/dataset.py CHANGED
@@ -100,7 +100,7 @@ class BaseRoboticsDataset(Dataset):
100
 
101
  def _get_action_slice(self, entry: Dict[str, Any], start: int, end: int) -> torch.Tensor:
102
  """Extract raw action slice without padding."""
103
- if self.config.name in ["language_table", "rt1", "dreamer4"]:
104
  return entry['actions'][start:end]
105
  elif self.config.name == "recon":
106
  # RECON commands are linear_velocity and angular_velocity
 
100
 
101
  def _get_action_slice(self, entry: Dict[str, Any], start: int, end: int) -> torch.Tensor:
102
  """Extract raw action slice without padding."""
103
+ if self.config.name in ["language_table", "rt1", "dreamer4", "pusht", "franka", "lang_table_50k"]:
104
  return entry['actions'][start:end]
105
  elif self.config.name == "recon":
106
  # RECON commands are linear_velocity and angular_velocity
wm/scripts/get_franka_stats.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+
6
# Input metadata file and output histogram location for the Franka dataset.
METADATA_PATH = "/storage/ice-shared/ae8803che/hxue/data/dataset/franka/metadata.pt"
SAVE_PATH = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/stats/franka_dist.png"


def collect_stats(metadata):
    """Gather per-trajectory lengths and the set of action dimensions.

    Args:
        metadata: Either a dict whose values are per-trajectory info mappings,
            or an iterable of such mappings. Each mapping is expected to have
            a 'num_frames' entry and/or an 'actions' entry exposing `.shape`.

    Returns:
        A `(lengths, action_dims)` tuple: `lengths` is a list with one entry
        per processed trajectory, `action_dims` is the set of distinct sizes
        of the last axis of 'actions'.
    """
    # Handle both list and dict formats.
    entries = metadata.values() if isinstance(metadata, dict) else metadata

    lengths = []
    action_dims = set()
    for info in entries:
        if 'num_frames' in info:
            lengths.append(info['num_frames'])
        elif 'actions' in info:
            lengths.append(info['actions'].shape[0])
        else:
            # Unknown schema: show what is available and stop scanning.
            print(f"Keys in info: {info.keys()}")
            break
        # An entry may carry 'num_frames' without 'actions'; only record the
        # action dimension when the actions are actually present.
        if 'actions' in info:
            action_dims.add(info['actions'].shape[-1])
    return lengths, action_dims


def save_length_histogram(lengths, num_trajectories, save_path):
    """Plot the trajectory-length distribution and write it to `save_path`."""
    plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=30, color='skyblue', edgecolor='black')
    plt.title(f"Franka Video Length Distribution (N={num_trajectories})")
    plt.xlabel("Number of Frames")
    plt.ylabel("Frequency")
    plt.grid(axis='y', alpha=0.75)

    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()


def main():
    """Print summary statistics for the Franka dataset and save a histogram.

    Raises:
        SystemExit: With code 1 when the metadata file is missing or no
            trajectory length could be extracted from it.
    """
    if not os.path.exists(METADATA_PATH):
        print(f"Error: {METADATA_PATH} not found.")
        raise SystemExit(1)

    # NOTE(security): torch.load unpickles the file; only use trusted metadata.
    metadata = torch.load(METADATA_PATH)
    num_trajectories = len(metadata)

    lengths, action_dims = collect_stats(metadata)
    if not lengths:
        # Avoid a ZeroDivisionError below when nothing usable was found.
        print("Error: no trajectory lengths could be read from the metadata.")
        raise SystemExit(1)

    avg_len = sum(lengths) / len(lengths)
    median_len = np.median(lengths)
    # Report a single value when all trajectories agree, else the whole set.
    action_dim = next(iter(action_dims)) if len(action_dims) == 1 else str(action_dims)

    print(f"Trajectories: {num_trajectories}")
    print(f"Action Dim: {action_dim}")
    print(f"Avg. Video Len: {avg_len:.1f}")
    print(f"Median Video Len: {median_len:.1f}")

    save_length_histogram(lengths, num_trajectories, SAVE_PATH)
    print(f"Distribution plot saved to {SAVE_PATH}")


if __name__ == "__main__":
    main()
wm/test/test_franka_load.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from wm.dataset.dataset import RoboticsDatasetWrapper
3
+ from wm.dataset.data_config import get_config_by_name
4
+ import os
5
+
6
def test_franka_load():
    """Smoke-test the Franka dataset: build it and print the first sample's shapes."""
    franka_ds = RoboticsDatasetWrapper.get_dataset("franka", seq_len=10)
    num_samples = len(franka_ds)
    print(f"Dataset size: {num_samples}")

    # Fetch the first sample and report each tensor's shape in turn.
    first = franka_ds[0]

    obs_shape = first['obs'].shape  # expected layout per original note: (T, C, H, W)
    print(f"Video shape: {obs_shape}")

    action_shape = first['action'].shape  # expected layout per original note: (T, action_dim)
    print(f"Actions shape: {action_shape}")


if __name__ == "__main__":
    test_franka_load()
wm/utils/__pycache__/visualization.cpython-39.pyc CHANGED
Binary files a/wm/utils/__pycache__/visualization.cpython-39.pyc and b/wm/utils/__pycache__/visualization.cpython-39.pyc differ