"""Roll out the pretrained ``lerobot/diffusion_pusht`` policy in PushT and save a video."""

import os

import gymnasium as gym
import gym_pusht  # noqa: F401 - not referenced by name; presumably imported to register 'gym_pusht/PushT-v0'
import torch
import imageio
from huggingface_hub import hf_hub_download
import safetensors.torch
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.factory import make_pre_post_processors
from lerobot.envs.utils import preprocess_observation

REPO_ID = 'lerobot/diffusion_pusht'
WEIGHTS_FILENAME = 'model.safetensors'
POS_GRID_KEY = 'diffusion.rgb_encoder.pool.pos_grid'
NUM_STEPS = 300
VIDEO_PATH = "pusht_policy.mp4"
VIDEO_FPS = 10


def _load_policy(cfg):
    """Build a DiffusionPolicy from ``cfg`` and load the pretrained weights into it.

    Args:
        cfg: Policy config returned by ``PreTrainedConfig.from_pretrained``.

    Returns:
        A ``(policy, state_dict)`` tuple: the policy (moved to ``cfg.device`` and
        put in eval mode) and the raw checkpoint tensors, which the caller still
        needs for the normalization statistics stored in them.
    """
    print("Building DiffusionPolicy...")
    policy = DiffusionPolicy(cfg)
    pool = policy.diffusion.rgb_encoder.pool
    print("Initial pos_grid shape in model:", pool.pos_grid.shape)

    print("Downloading and loading safetensors model weights...")
    model_file = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILENAME)
    state_dict = safetensors.torch.load_file(model_file)

    # strict=False only tolerates missing/unexpected keys; a tensor whose shape
    # differs from the model's still makes load_state_dict raise RuntimeError.
    # Drop such tensors up front (pos_grid is the expected one) and keep
    # everything else, including keys the model does not know about, so the
    # behaviour for all other entries is unchanged.
    model_shapes = {name: tensor.shape for name, tensor in policy.state_dict().items()}
    loadable = {
        name: tensor
        for name, tensor in state_dict.items()
        if name not in model_shapes or model_shapes[name] == tensor.shape
    }
    policy.load_state_dict(loadable, strict=False)

    # 2. Patch the pos_grid shape mismatch so inference works
    # NOTE(review): this assumes the checkpoint's grid matches the feature map
    # the encoder actually produces at inference time - confirm.
    print("Patching the pos_grid shape mismatch...")
    pool.register_buffer('pos_grid', state_dict[POS_GRID_KEY])
    print("Patched pos_grid shape in model:", pool.pos_grid.shape)

    # Move policy to correct device and set to eval mode
    policy.to(cfg.device)
    policy.eval()
    return policy, state_dict


def _dataset_stats_from_checkpoint(state_dict):
    """Collect the normalization statistics stored alongside the weights.

    Args:
        state_dict: Checkpoint tensors keyed by parameter/buffer name.

    Returns:
        A nested dict ``{feature_name: {stat_name: tensor}}`` in the layout
        passed to ``make_pre_post_processors`` as ``dataset_stats``.

    Raises:
        KeyError: If the checkpoint lacks one of the expected normalization buffers.
    """
    return {
        'observation.image': {
            'mean': state_dict['normalize_inputs.buffer_observation_image.mean'],
            'std': state_dict['normalize_inputs.buffer_observation_image.std'],
        },
        'observation.state': {
            'max': state_dict['normalize_inputs.buffer_observation_state.max'],
            'min': state_dict['normalize_inputs.buffer_observation_state.min'],
        },
        'action': {
            'max': state_dict['normalize_targets.buffer_action.max'],
            'min': state_dict['normalize_targets.buffer_action.min'],
        },
    }


def _rollout(env, policy, preprocessor, postprocessor, num_steps):
    """Step ``env`` with actions from ``policy`` and collect rendered frames.

    Args:
        env: Gymnasium environment created with ``render_mode='rgb_array'``.
        policy: Policy exposing ``reset()`` and ``select_action()``.
        preprocessor: Callable applied to the observation before the policy.
        postprocessor: Callable applied to the policy's action output.
        num_steps: Total number of environment steps, across episode resets.

    Returns:
        List of frames: the initial render plus one render per step.
    """
    # Reset env and cache initial frame
    policy.reset()
    obs, _ = env.reset()
    frames = [env.render()]

    for _ in range(num_steps):
        # Format observations to LeRobot format
        batch = preprocessor(preprocess_observation(obs))

        # Select action
        with torch.no_grad():
            action = policy.select_action(batch)
        action = postprocessor(action)

        # Extract numpy action and apply to env (drop batch dimension)
        action_numpy = action.to("cpu").numpy()[0]
        obs, _, terminated, truncated, _ = env.step(action_numpy)

        # Render frame
        frames.append(env.render())

        if terminated or truncated:
            obs, _ = env.reset()
            # Reset the policy together with the env so state cached from the
            # finished episode is not reused in the new one.
            policy.reset()

    return frames


def main():
    """Run the pretrained PushT diffusion policy and write the rollout to a video.

    Downloads the config and weights for ``lerobot/diffusion_pusht``, builds the
    policy, runs it for ``NUM_STEPS`` steps in ``gym_pusht/PushT-v0`` and saves
    the rendered frames to ``VIDEO_PATH`` at ``VIDEO_FPS`` frames per second.
    """
    # 1. Download checkpoint and load config
    print("Downloading config from lerobot/diffusion_pusht...")
    cfg = PreTrainedConfig.from_pretrained(REPO_ID)

    # We override the observation.image feature shape to (3, 384, 384) to match the environment defaults,
    # which instantiates the model's pos_grid as [144, 2] instead of [9, 2] (checkpoint size).
    # NOTE(review): confirm 384x384 is the resolution the environment actually emits.
    cfg.input_features['observation.image'].shape = (3, 384, 384)

    policy, state_dict = _load_policy(cfg)

    # 3. Create preprocessor / postprocessor with the extracted dataset stats
    print("Creating preprocessor and postprocessor...")
    dataset_stats = _dataset_stats_from_checkpoint(state_dict)
    preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_stats)

    # 4. Instantiate the gym environment
    print("Creating PushT environment...")
    env = gym.make('gym_pusht/PushT-v0', render_mode='rgb_array', obs_type='pixels_agent_pos')
    try:
        print(f"Running {NUM_STEPS} steps rollout...")
        frames = _rollout(env, policy, preprocessor, postprocessor, NUM_STEPS)
    finally:
        # Close env even when the rollout raises.
        env.close()

    # 5. Save the frames as pusht_policy.mp4
    print(f"Saving video to {VIDEO_PATH}...")
    imageio.mimsave(VIDEO_PATH, frames, fps=VIDEO_FPS)
    print("Done! Video saved successfully.")


if __name__ == "__main__":
    main()