<!doctype html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <meta name="description" content="Rollout video and reproduction script for the pre-trained LeRobot Diffusion Policy (lerobot/diffusion_pusht) evaluated in the gym_pusht PushT-v0 environment.">
  <title>LeRobot Diffusion Policy - PushT-v0</title>
  <!-- Tailwind CSS v4 browser build: provides the utility classes used throughout this page. -->
  <script src="https://cdn.jsdelivr.net/npm/@tailwindcss/browser@4"></script>
  <link rel="preconnect" href="https://fonts.googleapis.com">
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
  <!-- Prism syntax-highlighting theme for the run_pusht.py listing. -->
  <link href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/themes/prism-tomorrow.min.css" rel="stylesheet">
  <style>
    body {
      font-family: 'Inter', sans-serif;
    }

    /* The Prism theme linked above sets its own font-family on
       pre/code[class*="language-"], which out-ranks a bare `pre, code` rule.
       Repeating those selectors here (later in the cascade, equal specificity)
       lets JetBrains Mono apply to the highlighted listing as well. */
    pre,
    code,
    pre[class*="language-"],
    code[class*="language-"] {
      font-family: 'JetBrains Mono', monospace;
    }
  </style>
</head>
<body class="bg-slate-950 text-slate-100 min-h-screen selection:bg-indigo-500 selection:text-white">
<header class="border-b border-slate-800 bg-slate-900/50 backdrop-blur sticky top-0 z-50">
  <div class="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 h-16 flex items-center justify-between">
    <div class="flex items-center space-x-3">
      <!-- Purely decorative "live" indicator: hidden from assistive technology,
           and the ping animation is switched off for users who prefer reduced motion. -->
      <span class="flex h-3 w-3 relative" aria-hidden="true">
        <span class="animate-ping motion-reduce:animate-none absolute inline-flex h-full w-full rounded-full bg-emerald-400 opacity-75"></span>
        <span class="relative inline-flex rounded-full h-3 w-3 bg-emerald-500"></span>
      </span>
      <h1 class="text-lg font-semibold tracking-tight text-white">LeRobot PushT Evaluation</h1>
    </div>
    <div class="flex items-center space-x-2">
      <span class="px-2.5 py-1 text-xs font-medium rounded-md bg-indigo-500/10 text-indigo-400 border border-indigo-500/20">Diffusion Policy</span>
      <span class="px-2.5 py-1 text-xs font-medium rounded-md bg-slate-800 text-slate-400 border border-slate-700">Gymnasium</span>
    </div>
  </div>
</header>
<main class="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-10 space-y-12">
  <div class="grid grid-cols-1 lg:grid-cols-12 gap-8 items-start">
    <!-- Left column: recorded rollout video. -->
    <div class="lg:col-span-5 space-y-4">
      <div class="bg-slate-900 border border-slate-800 rounded-xl overflow-hidden shadow-2xl p-4">
        <div class="flex items-center justify-between mb-3 px-1">
          <span class="text-xs font-medium uppercase tracking-wider text-slate-400">Policy Rollout (300 Steps)</span>
          <span class="text-xs text-slate-500">pusht_policy.mp4</span>
        </div>
        <div class="relative aspect-square rounded-lg overflow-hidden bg-white border border-slate-700">
          <!-- Muted so browsers permit autoplay. The clip is assembled from rendered
               environment frames only (imageio.mimsave in run_pusht.py), so there is
               no audio track that would need captions. -->
          <video class="w-full h-full object-contain" autoplay loop muted controls playsinline aria-label="Diffusion Policy rollout in PushT-v0: the agent pushes the gray T-shaped block toward the green target zone">
            <source src="pusht_policy.mp4" type="video/mp4">
            Your browser does not support the video tag.
            <a href="pusht_policy.mp4">Download the rollout video</a>.
          </video>
        </div>
      </div>
    </div>
    <!-- Right column: description, environment/model summary cards, and engineering note. -->
    <div class="lg:col-span-7 space-y-6 lg:pt-2">
      <div>
        <h2 class="text-3xl font-bold tracking-tight text-white mb-3">Autonomous Multi-Modal Manipulation</h2>
        <p class="text-slate-400 leading-relaxed">
          This Space showcases a trained <strong>Diffusion Policy</strong> operating within the <code>PushT-v0</code> simulation environment using Hugging Face's <strong>LeRobot</strong> ecosystem. The agent learns multi-modal trajectories to effectively guide the gray T-shaped block completely into the target green silhouette zone.
        </p>
      </div>
      <div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
        <div class="p-4 bg-slate-900/60 border border-slate-800 rounded-xl">
          <h3 class="text-xs font-semibold text-slate-400 uppercase tracking-wider mb-2">Observation Space</h3>
          <p class="text-sm font-medium text-slate-200">Pixels &amp; Agent Position</p>
          <p class="text-xs text-slate-500 mt-1">Image Shape: (3, 384, 384)</p>
        </div>
        <div class="p-4 bg-slate-900/60 border border-slate-800 rounded-xl">
          <h3 class="text-xs font-semibold text-slate-400 uppercase tracking-wider mb-2">Action Space</h3>
          <p class="text-sm font-medium text-slate-200">2D Continuous Control</p>
          <p class="text-xs text-slate-500 mt-1">End-effector delta position</p>
        </div>
        <div class="p-4 bg-slate-900/60 border border-slate-800 rounded-xl">
          <h3 class="text-xs font-semibold text-slate-400 uppercase tracking-wider mb-2">Model Source</h3>
          <p class="text-sm font-medium text-slate-200">lerobot/diffusion_pusht</p>
          <p class="text-xs text-slate-500 mt-1">Pre-trained Checkpoint via HF Hub</p>
        </div>
        <div class="p-4 bg-slate-900/60 border border-slate-800 rounded-xl">
          <h3 class="text-xs font-semibold text-slate-400 uppercase tracking-wider mb-2">Pipeline Optimization</h3>
          <p class="text-sm font-medium text-emerald-400">Dynamic Buffer Patching</p>
          <p class="text-xs text-slate-500 mt-1">Overrides state dict pos_grid mismatch</p>
        </div>
      </div>
      <div class="p-4 bg-amber-500/5 border border-amber-500/20 rounded-xl flex space-x-3" role="note">
        <!-- Decorative warning icon; the "Engineering Note:" label carries the meaning. -->
        <svg class="h-5 w-5 text-amber-500 shrink-0 mt-0.5" aria-hidden="true" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
          <path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0L2.697 16.126zM12 15.75h.007v.008H12v-.008z" />
        </svg>
        <div class="text-xs text-slate-400 leading-relaxed">
          <strong class="text-slate-200 font-medium">Engineering Note:</strong> The default environment instantiation sets the visual input size to 384×384 pixels, forcing the model's position grid (<code class="text-amber-400">pos_grid</code>) to shape <code class="text-amber-400">[144, 2]</code>. To preserve checkpoint compatibility, the execution engine explicitly overwrites this buffer right after model initialization, re-registering the checkpoint's <code class="text-slate-300">[9, 2]</code> <code class="text-slate-300">pos_grid</code> layout.
        </div>
      </div>
    </div>
  </div>
<hr class="border-slate-800">
<div class="space-y-4">
  <div>
    <h2 class="text-xl font-bold text-white tracking-tight">Deployment &amp; Rollout Script</h2>
    <p class="text-sm text-slate-400 mt-1">The clean implementation pipeline used to extract dataset normalization statistics from the checkpoint, patch the architecture shapes, step through environment dynamics, and generate the rollout asset.</p>
  </div>
  <!-- Scrollable source viewer. role/aria-label name the region, and tabindex="0"
       lets keyboard-only users focus it and scroll with the arrow keys. -->
  <div class="relative bg-slate-900 border border-slate-800 rounded-xl overflow-hidden shadow-xl max-h-[600px] overflow-y-auto" role="region" aria-label="run_pusht.py source code" tabindex="0">
    <div class="sticky top-0 bg-slate-900 border-b border-slate-800 px-4 py-2 flex items-center justify-between text-xs text-slate-400 z-10">
      <span class="font-mono">run_pusht.py</span>
      <!-- The inline script at the end of <body> binds the click handler via #copyBtn
           and swaps the first <span> inside it between "Copy Code" and a status. -->
      <button class="hover:text-white transition flex items-center space-x-1 cursor-pointer" id="copyBtn" type="button">
        <svg class="h-3.5 w-3.5" aria-hidden="true" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
          <path stroke-linecap="round" stroke-linejoin="round" d="M15.75 17.25v3.375c0 .621-.504 1.125-1.125 1.125h-9.75a1.125 1.125 0 0 1-1.125-1.125V7.875c0-.621.504-1.125 1.125-1.125H6.75a9.06 9.06 0 0 1 1.5.124m7.5 10.376A8.965 8.965 0 0 0 12 12.75a8.965 8.965 0 0 0-3.75 4.625M18 4.75V3.375c0-.621-.504-1.125-1.125-1.125h-9.75a1.125 1.125 0 0 0-1.125 1.125V4.75m12.75 0V19.5a1.125 1.125 0 0 1-1.125 1.125H18M9 4.75v1.5a1.125 1.125 0 0 0 1.125 1.125h3.75A1.125 1.125 0 0 0 15 6.25v-1.5M9 4.75h6" />
        </svg>
        <span aria-live="polite">Copy Code</span>
      </button>
    </div>
<pre class="m-0 p-4 bg-slate-900 text-sm"><code class="language-python" id="codeBlock">import gymnasium as gym
import gym_pusht  # noqa: F401 -- imported for its side effect of registering gym_pusht/PushT-v0
import torch
import imageio
from huggingface_hub import hf_hub_download
import safetensors.torch
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.factory import make_pre_post_processors
from lerobot.envs.utils import preprocess_observation


def main():
    # 1. Download checkpoint and load config
    print("Downloading config from lerobot/diffusion_pusht...")
    cfg = PreTrainedConfig.from_pretrained('lerobot/diffusion_pusht')
    # Override the observation.image feature shape to (3, 384, 384) to match the environment
    # defaults; this makes the model instantiate pos_grid as [144, 2] instead of the
    # checkpoint's [9, 2].
    cfg.input_features['observation.image'].shape = (3, 384, 384)

    # Build the DiffusionPolicy
    print("Building DiffusionPolicy...")
    policy = DiffusionPolicy(cfg)
    print("Initial pos_grid shape in model:", policy.diffusion.rgb_encoder.pool.pos_grid.shape)

    print("Downloading and loading safetensors model weights...")
    model_file = hf_hub_download(repo_id='lerobot/diffusion_pusht', filename='model.safetensors')
    state_dict = safetensors.torch.load_file(model_file)

    # 2. Patch the pos_grid shape mismatch BEFORE loading the weights.
    # load_state_dict(strict=False) only tolerates missing/unexpected keys; it still
    # raises on size mismatches, so the buffer must already have the checkpoint's
    # shape when the state dict is copied in.
    print("Patching the pos_grid shape mismatch...")
    checkpoint_pos_grid = state_dict['diffusion.rgb_encoder.pool.pos_grid']
    policy.diffusion.rgb_encoder.pool.register_buffer('pos_grid', checkpoint_pos_grid)
    print("Patched pos_grid shape in model:", policy.diffusion.rgb_encoder.pool.pos_grid.shape)

    # Load weights with strict=False: tolerates keys the policy module may not
    # define (e.g. the normalize_* statistics, which are fed to the
    # pre/post-processors below instead).
    policy.load_state_dict(state_dict, strict=False)

    # Move policy to correct device and set to eval mode
    policy.to(cfg.device)
    policy.eval()

    # 3. Create preprocessor / postprocessor with the extracted dataset stats
    print("Creating preprocessor and postprocessor...")
    dataset_stats = {
        'observation.image': {
            'mean': state_dict['normalize_inputs.buffer_observation_image.mean'],
            'std': state_dict['normalize_inputs.buffer_observation_image.std'],
        },
        'observation.state': {
            'max': state_dict['normalize_inputs.buffer_observation_state.max'],
            'min': state_dict['normalize_inputs.buffer_observation_state.min'],
        },
        'action': {
            'max': state_dict['normalize_targets.buffer_action.max'],
            'min': state_dict['normalize_targets.buffer_action.min'],
        },
    }
    preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_stats)

    # 4. Instantiate the gym environment
    print("Creating PushT environment...")
    env = gym.make('gym_pusht/PushT-v0', render_mode='rgb_array', obs_type='pixels_agent_pos')
    try:
        # Reset policy and env, and cache the initial frame
        policy.reset()
        obs, info = env.reset()
        frames = [env.render()]

        # Run rollout for 300 steps
        print("Running 300 steps rollout...")
        for _ in range(300):
            # Format observations to LeRobot format
            obs_t = preprocess_observation(obs)
            obs_t = preprocessor(obs_t)

            # Select action
            with torch.no_grad():
                action = policy.select_action(obs_t)
                action = postprocessor(action)

            # Extract numpy action and apply to env (drop batch dimension)
            action_numpy = action.to("cpu").numpy()[0]
            obs, reward, terminated, truncated, info = env.step(action_numpy)

            # Render frame
            frames.append(env.render())

            if terminated or truncated:
                # Start a new episode. Mirror the policy.reset() done before the first
                # episode so any per-episode state the policy keeps (e.g. queued
                # actions) does not leak into the next one.
                policy.reset()
                obs, info = env.reset()
    finally:
        # Close env even if the rollout raises
        env.close()

    # 5. Save the frames as pusht_policy.mp4
    print("Saving video to pusht_policy.mp4...")
    imageio.mimsave("pusht_policy.mp4", frames, fps=10)
    print("Done! Video saved successfully.")


if __name__ == "__main__":
    main()</code></pre>
</div>
</div>
</main>
<!-- text-slate-400 (rather than slate-600) keeps this small text legible against the
     slate-950 page background; slate-600 falls well below the WCAG AA 4.5:1 ratio.
     border-slate-800 matches the header and section dividers. -->
<footer class="text-center py-8 text-xs text-slate-400 border-t border-slate-800 mt-12">
  Powered by LeRobot, Gymnasium, and Hugging Face Static Spaces.
</footer>
<!-- Prism core plus the autoloader plugin, used to highlight the language-python listing. -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/components/prism-core.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/plugins/autoloader/prism-autoloader.min.js"></script>
<script>
  // Copy-to-clipboard for the run_pusht.py listing (#codeBlock), triggered by #copyBtn.
  // The first <span> inside the button shows transient feedback, then reverts.
  (() => {
    const copyBtn = document.getElementById('copyBtn');
    const codeBlock = document.getElementById('codeBlock');
    if (!copyBtn || !codeBlock) return;

    const label = copyBtn.querySelector('span');
    const DEFAULT_LABEL = 'Copy Code';
    let resetTimer = null;

    // Show a status message, then restore the default label after 2 s.
    // Clearing the pending timer keeps repeated clicks from reverting early.
    const flash = (text) => {
      label.textContent = text;
      clearTimeout(resetTimer);
      resetTimer = setTimeout(() => { label.textContent = DEFAULT_LABEL; }, 2000);
    };

    copyBtn.addEventListener('click', () => {
      // navigator.clipboard is only exposed in secure contexts, and writeText can
      // reject (e.g. clipboard permission denied inside an embedding iframe), so
      // both cases report failure instead of silently doing nothing.
      if (!navigator.clipboard) {
        flash('Copy failed');
        return;
      }
      navigator.clipboard.writeText(codeBlock.textContent)
        .then(() => flash('Copied!'))
        .catch(() => flash('Copy failed'));
    });
  })();
</script>
</body>
</html>