Upload folder using huggingface_hub
Browse files- wm/config/fulltraj_dit/franka.yaml +56 -0
- wm/dataset/__pycache__/data_config.cpython-39.pyc +0 -0
- wm/dataset/__pycache__/dataset.cpython-39.pyc +0 -0
- wm/dataset/dataset.py +1 -1
- wm/scripts/get_franka_stats.py +53 -0
- wm/test/test_franka_load.py +18 -0
- wm/utils/__pycache__/visualization.cpython-39.pyc +0 -0
wm/config/fulltraj_dit/franka.yaml
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration for Franka (IsaacLab) World Model Training
|
| 2 |
+
# Full Trajectory Generation (Bidirectional)
|
| 3 |
+
|
| 4 |
+
# Dynamics Model Class
|
| 5 |
+
dynamics_class: "Bidirectional_FullTrajectory"
|
| 6 |
+
|
| 7 |
+
# Model identifier for DIT_CLASS_MAP
|
| 8 |
+
model_name: "VideoDiT"
|
| 9 |
+
|
| 10 |
+
# Configuration passed to the DiT model constructor
|
| 11 |
+
model_config:
|
| 12 |
+
in_channels: 16 # Latent channels from Wan VAE
|
| 13 |
+
patch_size: 2
|
| 14 |
+
dim: 1024 # Hidden dimension
|
| 15 |
+
num_layers: 16
|
| 16 |
+
num_heads: 16
|
| 17 |
+
action_dim: 7 # Franka action dimension (6-DoF + Gripper)
|
| 18 |
+
action_compress_rate: 4 # Compresses action sequence to latent sequence
|
| 19 |
+
max_frames: 33 # Franka sequence length (T=33, 1 + 4*8 windows)
|
| 20 |
+
action_dropout_prob: 0.1 # CFG for action conditioning
|
| 21 |
+
temporal_causal: false # Bidirectional temporal attention for fulltraj
|
| 22 |
+
vae_name: "WanVAE"
|
| 23 |
+
vae_config:
|
| 24 |
+
- "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
|
| 25 |
+
scheduler: "FlowMatch" # Will be instantiated in dynamics class
|
| 26 |
+
training_timesteps: 1000
|
| 27 |
+
|
| 28 |
+
# Dataset Configuration
|
| 29 |
+
dataset:
|
| 30 |
+
name: "franka"
|
| 31 |
+
seq_len: 33 # Matches max_frames (e.g., T=33 for 8 steps in latent space)
|
| 32 |
+
train_test_split: 50 # 50:1 split
|
| 33 |
+
|
| 34 |
+
# Training Hyperparameters
|
| 35 |
+
training:
|
| 36 |
+
batch_size: 4
|
| 37 |
+
learning_rate: 1e-4
|
| 38 |
+
num_epochs: 2000
|
| 39 |
+
grad_clip: 1.0
|
| 40 |
+
checkpoint_freq: 5000 # Numbered checkpoints for eval
|
| 41 |
+
latest_freq: 500 # Only updates latest.pt for resuming
|
| 42 |
+
val_freq: 1000 # Video Logging
|
| 43 |
+
eval_freq: 500 # MSE Rollout
|
| 44 |
+
log_freq: 10 # Steps
|
| 45 |
+
num_workers: 4
|
| 46 |
+
|
| 47 |
+
# Distributed Training
|
| 48 |
+
distributed:
|
| 49 |
+
use_ddp: true
|
| 50 |
+
use_fsdp: false
|
| 51 |
+
|
| 52 |
+
# WandB Configuration
|
| 53 |
+
wandb:
|
| 54 |
+
project: "world_model"
|
| 55 |
+
run_name: "franka_fulltraj_dit_v1"
|
| 56 |
+
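api_key: null # Secret removed: never commit credentials. Provide the key via the WANDB_API_KEY environment variable (and revoke the previously committed key).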
|
wm/dataset/__pycache__/data_config.cpython-39.pyc
CHANGED
|
Binary files a/wm/dataset/__pycache__/data_config.cpython-39.pyc and b/wm/dataset/__pycache__/data_config.cpython-39.pyc differ
|
|
|
wm/dataset/__pycache__/dataset.cpython-39.pyc
CHANGED
|
Binary files a/wm/dataset/__pycache__/dataset.cpython-39.pyc and b/wm/dataset/__pycache__/dataset.cpython-39.pyc differ
|
|
|
wm/dataset/dataset.py
CHANGED
|
@@ -100,7 +100,7 @@ class BaseRoboticsDataset(Dataset):
|
|
| 100 |
|
| 101 |
def _get_action_slice(self, entry: Dict[str, Any], start: int, end: int) -> torch.Tensor:
|
| 102 |
"""Extract raw action slice without padding."""
|
| 103 |
-
if self.config.name in ["language_table", "rt1", "dreamer4"]:
|
| 104 |
return entry['actions'][start:end]
|
| 105 |
elif self.config.name == "recon":
|
| 106 |
# RECON commands are linear_velocity and angular_velocity
|
|
|
|
| 100 |
|
| 101 |
def _get_action_slice(self, entry: Dict[str, Any], start: int, end: int) -> torch.Tensor:
|
| 102 |
"""Extract raw action slice without padding."""
|
| 103 |
+
if self.config.name in ["language_table", "rt1", "dreamer4", "pusht", "franka", "lang_table_50k"]:
|
| 104 |
return entry['actions'][start:end]
|
| 105 |
elif self.config.name == "recon":
|
| 106 |
# RECON commands are linear_velocity and angular_velocity
|
wm/scripts/get_franka_stats.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
# Summarize the Franka dataset metadata: print the trajectory count, the action
# dimension and video-length statistics, then save a histogram of the lengths.
metadata_path = "/storage/ice-shared/ae8803che/hxue/data/dataset/franka/metadata.pt"
if not os.path.exists(metadata_path):
    print(f"Error: {metadata_path} not found.")
    # SystemExit works even when the `exit` helper injected by the `site`
    # module is unavailable (e.g. `python -S`).
    raise SystemExit(1)

# map_location="cpu" keeps the script usable on hosts without a GPU even if the
# metadata holds tensors that were saved from a CUDA device.
# NOTE(review): torch.load unpickles data; only point this at trusted files.
metadata = torch.load(metadata_path, map_location="cpu")
num_trajectories = len(metadata)

lengths = []
action_dims = set()

# Handle both list and dict formats
entries = metadata.values() if isinstance(metadata, dict) else metadata

for info in entries:
    if 'num_frames' in info:
        lengths.append(info['num_frames'])
    elif 'actions' in info:
        lengths.append(info['actions'].shape[0])
    else:
        # Unknown entry layout: show what is available and stop scanning.
        print(f"Keys in info: {info.keys()}")
        break
    # Entries that only carry 'num_frames' have no action tensor to inspect.
    if 'actions' in info:
        action_dims.add(info['actions'].shape[-1])

if not lengths:
    # Without this guard the average below would raise ZeroDivisionError.
    print("Error: no trajectory lengths could be read from the metadata.")
    raise SystemExit(1)

avg_len = sum(lengths) / len(lengths)
median_len = np.median(lengths)
# Report a single value when all trajectories agree, otherwise the whole set.
if len(action_dims) == 1:
    action_dim = next(iter(action_dims))
else:
    action_dim = str(action_dims)

print(f"Trajectories: {num_trajectories}")
print(f"Action Dim: {action_dim}")
print(f"Avg. Video Len: {avg_len:.1f}")
print(f"Median Video Len: {median_len:.1f}")

# Generate distribution plot
plt.figure(figsize=(10, 6))
plt.hist(lengths, bins=30, color='skyblue', edgecolor='black')
plt.title(f"Franka Video Length Distribution (N={num_trajectories})")
plt.xlabel("Number of Frames")
plt.ylabel("Frequency")
plt.grid(axis='y', alpha=0.75)

save_path = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/stats/franka_dist.png"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path)
# Release the figure once it is on disk.
plt.close()
print(f"Distribution plot saved to {save_path}")
|
wm/test/test_franka_load.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from wm.dataset.dataset import RoboticsDatasetWrapper
|
| 3 |
+
from wm.dataset.data_config import get_config_by_name
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
def test_franka_load():
    """Smoke-test the Franka dataset: build it and print the first sample's shapes."""
    franka_ds = RoboticsDatasetWrapper.get_dataset("franka", seq_len=10)

    num_items = len(franka_ds)
    print(f"Dataset size: {num_items}")

    # Load first sample
    first = franka_ds[0]

    obs_shape = first['obs'].shape
    print(f"Video shape: {obs_shape}")  # (T, C, H, W)
    action_shape = first['action'].shape
    print(f"Actions shape: {action_shape}")  # (T, action_dim)


if __name__ == "__main__":
    test_franka_load()
|
wm/utils/__pycache__/visualization.cpython-39.pyc
CHANGED
|
Binary files a/wm/utils/__pycache__/visualization.cpython-39.pyc and b/wm/utils/__pycache__/visualization.cpython-39.pyc differ
|
|
|