Spaces:
Running
Running
| import os | |
| import gymnasium as gym | |
| import gym_pusht | |
| import torch | |
| import imageio | |
| from huggingface_hub import hf_hub_download | |
| import safetensors.torch | |
| from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy | |
| from lerobot.configs.policies import PreTrainedConfig | |
| from lerobot.policies.factory import make_pre_post_processors | |
| from lerobot.envs.utils import preprocess_observation | |
def _load_policy(repo_id: str):
    """Build a DiffusionPolicy from ``repo_id``'s config and load its checkpoint weights.

    Returns ``(cfg, policy, state_dict)``. The raw checkpoint ``state_dict`` is
    returned as well so the caller can read the normalization buffers stored in it.
    The policy is moved to ``cfg.device`` and put in eval mode before returning.
    """
    print(f"Downloading config from {repo_id}...")
    cfg = PreTrainedConfig.from_pretrained(repo_id)
    # NOTE(review): per the original author, this override matches the environment
    # defaults and makes the freshly built model's pos_grid [144, 2], while the
    # checkpoint stores [9, 2]. The checkpoint grid is copied in below *before*
    # the weights are loaded. Confirm the image size against the env's observation space.
    cfg.input_features["observation.image"].shape = (3, 384, 384)

    print("Building DiffusionPolicy...")
    policy = DiffusionPolicy(cfg)
    pool = policy.diffusion.rgb_encoder.pool
    print("Initial pos_grid shape in model:", pool.pos_grid.shape)

    print("Downloading and loading safetensors model weights...")
    model_file = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    state_dict = safetensors.torch.load_file(model_file)

    # Patch pos_grid BEFORE load_state_dict. PyTorch reports shape mismatches as
    # errors even with strict=False (strict only relaxes missing/unexpected keys),
    # so loading first would raise before a later patch could ever run.
    print("Patching the pos_grid shape mismatch...")
    pool.register_buffer("pos_grid", state_dict["diffusion.rgb_encoder.pool.pos_grid"])
    print("Patched pos_grid shape in model:", pool.pos_grid.shape)

    # strict=False tolerates key-set differences between checkpoint and model
    # (presumably the normalize_* buffers that are read as dataset stats — confirm).
    # Report them instead of silently discarding the load result.
    load_result = policy.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print("Missing keys (kept at initialization):", load_result.missing_keys)
    if load_result.unexpected_keys:
        print("Unexpected checkpoint keys (ignored):", load_result.unexpected_keys)

    policy.to(cfg.device)
    policy.eval()
    return cfg, policy, state_dict


def _dataset_stats_from_checkpoint(state_dict):
    """Rebuild the per-feature normalization stats from the checkpoint's normalize_* buffers."""
    return {
        "observation.image": {
            "mean": state_dict["normalize_inputs.buffer_observation_image.mean"],
            "std": state_dict["normalize_inputs.buffer_observation_image.std"],
        },
        "observation.state": {
            "max": state_dict["normalize_inputs.buffer_observation_state.max"],
            "min": state_dict["normalize_inputs.buffer_observation_state.min"],
        },
        "action": {
            "max": state_dict["normalize_targets.buffer_action.max"],
            "min": state_dict["normalize_targets.buffer_action.min"],
        },
    }


def _rollout(env, policy, preprocessor, postprocessor, num_steps: int):
    """Drive ``env`` with ``policy`` for ``num_steps`` steps and return the rendered frames.

    The first frame is rendered right after the initial reset. After that, one
    frame is rendered after every step, plus one after every episode-ending reset.
    """
    policy.reset()
    obs, _ = env.reset()
    frames = [env.render()]
    for _ in range(num_steps):
        # Convert the raw env observation to LeRobot format, then normalize it.
        obs_t = preprocessor(preprocess_observation(obs))
        with torch.no_grad():
            action = postprocessor(policy.select_action(obs_t))
        # Drop the batch dimension before handing the action to the env.
        action_numpy = action.to("cpu").numpy()[0]
        obs, _, terminated, truncated, _ = env.step(action_numpy)
        frames.append(env.render())
        if terminated or truncated:
            # New episode: reset the policy too (as before the first episode)
            # so any per-episode state it keeps does not carry over.
            policy.reset()
            obs, _ = env.reset()
            frames.append(env.render())
    return frames


def main(
    *,
    repo_id: str = "lerobot/diffusion_pusht",
    num_steps: int = 300,
    output_path: str = "pusht_policy.mp4",
    fps: int = 10,
):
    """Roll out a pretrained diffusion policy in PushT and save the rollout as a video.

    Args:
        repo_id: Hugging Face Hub repo holding the policy config and ``model.safetensors``.
        num_steps: Number of environment steps to run (episodes are reset as they end).
        output_path: Where the rendered video is written.
        fps: Frame rate of the saved video.
    """
    # 1-2. Download checkpoint, build the policy, patch pos_grid, load weights.
    cfg, policy, state_dict = _load_policy(repo_id)

    # 3. Pre/post-processors normalized with the stats stored in the checkpoint.
    print("Creating preprocessor and postprocessor...")
    preprocessor, postprocessor = make_pre_post_processors(
        cfg, dataset_stats=_dataset_stats_from_checkpoint(state_dict)
    )

    # 4. Environment rollout; always close the env, even if the rollout fails.
    print("Creating PushT environment...")
    env = gym.make("gym_pusht/PushT-v0", render_mode="rgb_array", obs_type="pixels_agent_pos")
    try:
        print(f"Running {num_steps} steps rollout...")
        frames = _rollout(env, policy, preprocessor, postprocessor, num_steps)
    finally:
        env.close()

    # 5. Save the collected frames as a video.
    print(f"Saving video to {output_path}...")
    imageio.mimsave(output_path, frames, fps=fps)
    print("Done! Video saved successfully.")
if __name__ == "__main__":
    # Script entry point: run the rollout only when executed directly, not on import.
    main()