{ "cells": [ { "cell_type": "markdown", "id": "6adc68e0-a943-44ab-9af5-4bc62cc19f34", "metadata": { "editable": true, "id": "6adc68e0-a943-44ab-9af5-4bc62cc19f34", "tags": [] }, "source": [ "![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)\n", "\n", "#

Rollout Tutorial

\n", "\n", "This notebook provides a tutorial for [**MuJoCo** physics](https://github.com/google-deepmind/mujoco#readme), using the native Python bindings.\n", "\n", "This notebook describes the `rollout` module included in the MuJoCo Python library. It performs simulation \"rollouts\" with an underlying C++ function. The rollouts can be multithreaded.\n", "\n", "Below, the usage of each argument is explained with examples. Then some examples for advanced use cases are provided. Finally, `rollout` is benchmarked against pure python and MJX.\n", "\n", "Note the benchmarks were designed to run on >16 thread CPU and an RTX 4090 or A100. They do not run in a reasonable amount of time on a typical free colab runtime.\n", "\n", "" ] }, { "cell_type": "markdown", "id": "5d8a6604-0948-4a42-a48d-249c7f0c462b", "metadata": { "editable": true, "id": "5d8a6604-0948-4a42-a48d-249c7f0c462b", "tags": [] }, "source": [ "# All Imports" ] }, { "cell_type": "code", "execution_count": 0, "id": "0f9fbad1-59d0-40ac-b2b6-99f37313670f", "metadata": { "editable": true, "id": "0f9fbad1-59d0-40ac-b2b6-99f37313670f", "tags": [ "hide-input" ] }, "outputs": [], "source": [ "!pip install mujoco\n", "!pip install mujoco_mjx\n", "!pip install brax\n", "\n", "# Set up GPU rendering.\n", "#from google.colab import files\n", "import distutils.util\n", "import os\n", "import subprocess\n", "if subprocess.run('nvidia-smi').returncode:\n", " raise RuntimeError(\n", " 'Cannot communicate with GPU. '\n", " 'Make sure you are using a GPU Colab runtime. 
'\n", " 'Go to the Runtime menu and select Choose runtime type.')\n", "\n", "# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n", "# This is usually installed as part of an Nvidia driver package, but the Colab\n", "# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n", "# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n", "NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n", "if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n", " with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n", " f.write(\"\"\"{\n", " \"file_format_version\" : \"1.0.0\",\n", " \"ICD\" : {\n", " \"library_path\" : \"libEGL_nvidia.so.0\"\n", " }\n", "}\n", "\"\"\")\n", "\n", "# Configure MuJoCo to use the EGL rendering backend (requires GPU)\n", "print('Setting environment variable to use GPU rendering:')\n", "%env MUJOCO_GL=egl\n", "\n", "# Check if installation was successful.\n", "try:\n", " print('Checking that the installation succeeded:')\n", " import mujoco\n", " from mujoco import rollout\n", " from mujoco import mjx\n", " mujoco.MjModel.from_xml_string('')\n", "except Exception as e:\n", " raise e from RuntimeError(\n", " 'Something went wrong during installation. 
Check the shell output above '\n", " 'for more information.\\n'\n", " 'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n", " 'by going to the Runtime menu and selecting \"Choose runtime type\".')\n", "\n", "print('Installation successful.')\n", "\n", "# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs\n", "xla_flags = os.environ.get('XLA_FLAGS', '')\n", "xla_flags += ' --xla_gpu_triton_gemm_any=True'\n", "os.environ['XLA_FLAGS'] = xla_flags\n", "\n", "# Other imports and helper functions\n", "import copy\n", "import time\n", "from multiprocessing import cpu_count\n", "import threading\n", "import numpy as np\n", "import jax\n", "import jax.numpy as jp\n", "\n", "# Graphics and plotting.\n", "print('Installing mediapy:')\n", "!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n", "!pip install -q mediapy\n", "import mediapy as media\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "\n", "# More legible printing from numpy.\n", "np.set_printoptions(precision=3, suppress=True, linewidth=100)\n", "\n", "# Set the number of threads to the number of cpu's that the multiprocessing module reports\n", "nthread = cpu_count()\n", "\n", "# Get MuJoCo's standard humanoid and humanoid_100 models.\n", "print('Getting MuJoCo humanoid XML description from GitHub:')\n", "!git clone https://github.com/google-deepmind/mujoco\n", "humanoid_path = 'mujoco/model/humanoid/humanoid.xml'\n", "humanoid100_path = 'mujoco/model/humanoid/humanoid100.xml'\n", "print('Getting hopper XML description from GitHub:')\n", "!git clone https://github.com/google-deepmind/dm_control\n", "hopper_path ='dm_control/dm_control/suite/hopper.xml'\n", "\n", "# clear installation printouts\n", "from IPython.display import clear_output\n", "clear_output()" ] }, { "cell_type": "markdown", "id": "fc69d0f4", "metadata": { "id": "fc69d0f4" }, "source": [ "# Helper Functions" ] }, { "cell_type": "code", "execution_count": 0, "id": 
"082482c7", "metadata": { "editable": true, "id": "082482c7", "tags": [ "hide-input" ] }, "outputs": [], "source": [ "def get_state(model, data, nbatch=1):\n", " full_physics = mujoco.mjtState.mjSTATE_FULLPHYSICS\n", " state = np.zeros((mujoco.mj_stateSize(model, full_physics),))\n", " mujoco.mj_getState(model, data, state, full_physics)\n", " return np.tile(state, (nbatch, 1))\n", "\n", "def xy_grid(nbatch, ncols=10, spacing=0.05):\n", " nrows = nbatch // ncols\n", " assert nbatch == nrows * ncols\n", " xmax = (nrows-1)*spacing/2\n", " rows = np.linspace(-xmax, xmax, nrows)\n", " ymax = (ncols-1)*spacing/2\n", " cols = np.linspace(-ymax, ymax, ncols)\n", " x, y = np.meshgrid(rows, cols)\n", " return np.stack((x.flatten(), y.flatten())).T\n", "\n", "def benchmark(f, x_list=[None], ntiming=1, f_init=None):\n", " x_times_list = []\n", " for x in x_list:\n", " times = []\n", " for i in range(ntiming):\n", " if f_init is not None:\n", " x_init = f_init(x)\n", "\n", " start = time.perf_counter()\n", " if f_init is not None:\n", " f(x, x_init)\n", " else:\n", " f(x)\n", " end = time.perf_counter()\n", " times.append(end - start)\n", "\n", " x_times_list.append(np.mean(times))\n", " return np.array(x_times_list)\n", "\n", "def render_many(model, data, state, framerate, camera=-1, shape=(480, 640),\n", " transparent=False, light_pos=None):\n", " nbatch = state.shape[0]\n", "\n", " if not isinstance(model, mujoco.MjModel):\n", " model = list(model)\n", "\n", " if isinstance(model, list) and len(model) == 1:\n", " model = model * nbatch\n", " elif isinstance(model, list):\n", " assert len(model) == nbatch\n", " else:\n", " model = [model] * nbatch\n", "\n", " # Visual options\n", " vopt = mujoco.MjvOption()\n", " vopt.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = transparent\n", " pert = mujoco.MjvPerturb() # Empty MjvPerturb object\n", " catmask = mujoco.mjtCatBit.mjCAT_DYNAMIC\n", "\n", " # Simulate and render.\n", " frames = []\n", " with mujoco.Renderer(model[0], *shape) 
as renderer:\n", " for i in range(state.shape[1]):\n", " if len(frames) < i * model[0].opt.timestep * framerate:\n", " for j in range(state.shape[0]):\n", " mujoco.mj_setState(model[j], data, state[j, i, :],\n", " mujoco.mjtState.mjSTATE_FULLPHYSICS)\n", " mujoco.mj_forward(model[j], data)\n", "\n", " # Use first model to make the scene, add subsequent models\n", " if j == 0:\n", " renderer.update_scene(data, camera, scene_option=vopt)\n", " else:\n", " mujoco.mjv_addGeoms(model[j], data, vopt, pert, catmask, renderer.scene)\n", "\n", " # Add light, if requested\n", " if light_pos is not None:\n", " light = renderer.scene.lights[renderer.scene.nlight]\n", " light.ambient = [0, 0, 0]\n", " light.attenuation = [1, 0, 0]\n", " light.castshadow = 1\n", " light.cutoff = 45\n", " light.diffuse = [0.8, 0.8, 0.8]\n", " light.dir = [0, 0, -1]\n", " light.type = mujoco.mjtLightType.mjLIGHT_SPOT\n", " light.exponent = 10\n", " light.headlight = 0\n", " light.specular = [0.3, 0.3, 0.3]\n", " light.pos = light_pos\n", " renderer.scene.nlight += 1\n", "\n", " # Render and add the frame.\n", " pixels = renderer.render()\n", " frames.append(pixels)\n", " return frames" ] }, { "cell_type": "markdown", "id": "c0570c2c", "metadata": { "id": "c0570c2c" }, "source": [ "# Using `rollout`\n", "\n", "The `rollout.rollout` function in the `mujoco` Python library runs batches of simulations for a fixed number steps. It can run in single or multi-threaded modes. The speedup over pure Python is significant because `rollout` users can easily enable the usage of a lightweight threadpool.\n", "\n", "Below we load the \"tippe top\", \"humanoid\", and \"humanoid100\" models which will be used in the following usage examples and benchmarks.\n", "\n", "The tippe top is copied from the [tutorial notebook](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/python/tutorial.ipynb). The humanoid and humanoid100 models are distributed with MuJoCo." 
] }, { "cell_type": "code", "execution_count": 0, "id": "849b93e5", "metadata": { "id": "849b93e5" }, "outputs": [], "source": [ "#@title Benchmarked models\n", "tippe_top = \"\"\"\n", "\n", " \n", "\"\"\"\n", "\n", "# Create and initialize top model\n", "top_model = mujoco.MjModel.from_xml_string(tippe_top)\n", "top_data = mujoco.MjData(top_model)\n", "# Set to the state to a spinning top (keyframe 0)\n", "mujoco.mj_resetDataKeyframe(top_model, top_data, 0)\n", "top_state = get_state(top_model, top_data)\n", "\n", "# Create and initialize humanoid model\n", "humanoid_model = mujoco.MjModel.from_xml_path(humanoid_path)\n", "humanoid_data = mujoco.MjData(humanoid_model)\n", "humanoid_data.qvel[2] = 4 # Make the humanoid jump\n", "humanoid_state = get_state(humanoid_model, humanoid_data)\n", "\n", "# Create and initialize humanoid100 model\n", "humanoid100_model = mujoco.MjModel.from_xml_path(humanoid100_path)\n", "humanoid100_data = mujoco.MjData(humanoid100_model)\n", "h100_state = get_state(humanoid100_model, humanoid100_data)\n", "\n", "start = time.time()\n", "top_nstep = int(6 / top_model.opt.timestep)\n", "top_state, _ = rollout.rollout(top_model, top_data, top_state, nstep=top_nstep)\n", "\n", "humanoid_nstep = int(3 / humanoid_model.opt.timestep)\n", "humanoid_state, _ = rollout.rollout(humanoid_model, humanoid_data,\n", " humanoid_state, nstep=humanoid_nstep)\n", "\n", "humanoid100_nstep = int(3 / humanoid100_model.opt.timestep)\n", "h100_state, _ = rollout.rollout(humanoid100_model, humanoid100_data,\n", " h100_state, nstep=humanoid100_nstep)\n", "end = time.time()\n", "\n", "start_render = time.time()\n", "top_frames = render_many(top_model, top_data, top_state, framerate=60, shape=(240, 320))\n", "humanoid_frames = render_many(humanoid_model, humanoid_data, humanoid_state, framerate=120, shape=(240, 320))\n", "humanoid100_frames = render_many(humanoid100_model, humanoid100_data, h100_state, framerate=120, shape=(240, 320))\n", "\n", "# humanoid and 
humanoid100 are shown at half speed\n", "media.show_video(np.concatenate((top_frames, humanoid_frames, humanoid100_frames), axis=2), fps=60)\n", "end_render = time.time()\n", "\n", "print(f'Rollout took {end-start:.1f} seconds')\n", "print(f'Rendering took {end_render-start_render:.1f} seconds')" ] }, { "cell_type": "markdown", "id": "55d171f7-541b-4441-aa18-da86d6716410", "metadata": { "id": "55d171f7-541b-4441-aa18-da86d6716410" }, "source": [ "## Usage\n", "\n", "It is helpful to read `rollout`'s docstring before beginning. The main takeaways are that `rollout` runs `nbatch` rollouts for `nstep` steps. Each `MjModel` can be different but should be the same up to parameter values. Passing multiple `MjData` enables multithreading, one thread per `MjData`.\n", "Further documentation can be found [here](https://mujoco.readthedocs.io/en/latest/python.html#rollout).\n", "\n", "Next we give usage examples of the most common arguments. The more advanced arguments are discussed in the \"Advanced Usage\" section." ] }, { "cell_type": "code", "execution_count": 0, "id": "9cd2f94a-11df-4247-986c-5a56af69a1f5", "metadata": { "id": "9cd2f94a-11df-4247-986c-5a56af69a1f5" }, "outputs": [], "source": [ "print(rollout.rollout.__doc__)" ] }, { "cell_type": "markdown", "id": "b6f7a094-8352-4b07-99ee-5278e3036cd5", "metadata": { "id": "b6f7a094-8352-4b07-99ee-5278e3036cd5", "tags": [] }, "source": [ "### Example: different initial states\n", "`rollout` is designed to run `nbatch` rollouts in parallel for `nstep` steps. Let's simulate 100 tippe tops with different initial rotation speeds.\n", "\n", "**Note:** Using multithreading with rollout is enabled by passing one MjData per thread, as is done below." 
] }, { "cell_type": "code", "execution_count": 0, "id": "849af5f2-9de1-4cb9-bc3a-c9b7acf0e3fe", "metadata": { "id": "849af5f2-9de1-4cb9-bc3a-c9b7acf0e3fe" }, "outputs": [], "source": [ "nbatch = 100 # Simulate this many tops\n", "\n", "# Get nbatch initial states and scale the initial speed of the tippe top using the batch index\n", "top_data = mujoco.MjData(top_model)\n", "mujoco.mj_resetDataKeyframe(top_model, top_data, 0)\n", "initial_states = get_state(top_model, top_data, nbatch)\n", "initial_states[:, -1] *= np.linspace(0.5, 1.5, num=nbatch)\n", "\n", "# Run the rollout\n", "start = time.time()\n", "top_datas = [copy.copy(top_data) for _ in range(nthread)] # 1 MjData per thread\n", "state, sensordata = rollout.rollout(top_model, top_datas, initial_states,\n", " nstep=int(top_nstep*1.5))\n", "end = time.time()\n", "\n", "# Use state to render all the tops at once\n", "start_render = time.time()\n", "framerate = 60\n", "frames = render_many(top_model, top_data, state, framerate, transparent=True)\n", "media.show_video(frames, fps=framerate)\n", "end_render = time.time()\n", "\n", "print(f'Rollout time {end-start:.1f} seconds')\n", "print(f'Rendering time {end_render-start_render:.1f} seconds')" ] }, { "cell_type": "markdown", "id": "aa2cf151-bf9a-4a23-b7fe-6a766979d93f", "metadata": { "id": "aa2cf151-bf9a-4a23-b7fe-6a766979d93f" }, "source": [ "Our model has an angular velocity sensor the middle of the top. Let's plot the response using the `sensordata` array that rollout returns." 
] }, { "cell_type": "code", "execution_count": 0, "id": "957b8566-da31-410b-b385-e78241c5247a", "metadata": { "id": "957b8566-da31-410b-b385-e78241c5247a" }, "outputs": [], "source": [ "plt.figure(figsize=(12, 8))\n", "plt.subplot(3,1,1)\n", "for i in range(nbatch): plt.plot(sensordata[i, :, 0])\n", "plt.subplot(3,1,2)\n", "for i in range(nbatch): plt.plot(sensordata[i, :, 1])\n", "plt.subplot(3,1,3)\n", "for i in range(nbatch): plt.plot(sensordata[i, :, 2])\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "58044bc1-f98c-4bbf-a703-40ba075552a0", "metadata": { "id": "58044bc1-f98c-4bbf-a703-40ba075552a0" }, "source": [ "### Example: different models\n", "100 gray tops is kind of boring. It would be better if they were colorful and different sizes!\n", "\n", "`rollout` supports using different models for each rollout, so long as their dimensions are the same (i.e., floating point parameters can be different). Let's simulate 100 tippe tops with the same initial condition, but different sizes and colors.\n", "\n", "**Note:** Strictly speaking, the models must have the same number of states, controls, degrees of freedom, and sensor outputs. The most common use case is multiple models of the same thing, with different parameter values." 
] }, { "cell_type": "code", "execution_count": 0, "id": "7c39e79e-8942-4fea-b306-ea0cb3c826e2", "metadata": { "id": "7c39e79e-8942-4fea-b306-ea0cb3c826e2" }, "outputs": [], "source": [ "# Make 100 tippe tops with different colors and sizes\n", "nbatch = 100\n", "spec = mujoco.MjSpec.from_string(tippe_top)\n", "spec.lights[0].pos[2] = 2\n", "models = []\n", "for i in range(nbatch):\n", " for geom in spec.geoms:\n", " if geom.name in ['ball', 'stem', 'ballast']:\n", " geom.rgba[:3] = np.random.rand(3)\n", " if geom.name == 'stem':\n", " stem_geom = geom\n", " if geom.name == 'ball':\n", " ball_geom = geom\n", "\n", " # Save original geom size\n", " stem_geom_size = np.copy(stem_geom.size)\n", " ball_geom_size = np.copy(ball_geom.size)\n", "\n", " # Scale geoms and compile model\n", " size_scale = 0.4*np.random.rand(1) + 0.75\n", " stem_geom.size *= size_scale\n", " ball_geom.size *= size_scale\n", " models.append(spec.compile())\n", "\n", " # Restore original geom size\n", " stem_geom.size = stem_geom_size\n", " ball_geom.size = ball_geom_size\n", "\n", "# Set the initial states of all the tops, placing them on a grid\n", "top_data = mujoco.MjData(top_model)\n", "mujoco.mj_resetDataKeyframe(top_model, top_data, 0)\n", "initial_states = get_state(top_model, top_data, nbatch)\n", "# index 0 is time, so x and y qpos values are at 1 and 2\n", "initial_states[:, 1:3] = xy_grid(nbatch, ncols=10, spacing=.05)\n", "\n", "\n", "# Run the rollout\n", "start = time.time()\n", "top_datas = [copy.copy(top_data) for _ in range(nthread)]\n", "nstep = int(9 / top_model.opt.timestep)\n", "state, sensordata = rollout.rollout(models, top_datas, initial_states,\n", " nstep=nstep)\n", "end = time.time()\n", "\n", "# Render video\n", "start_render = time.time()\n", "framerate = 60\n", "cam = mujoco.MjvCamera()\n", "mujoco.mjv_defaultCamera(cam)\n", "cam.distance = 0.2\n", "cam.azimuth = 135\n", "cam.elevation = -25\n", "cam.lookat = [.2, -.2, 0.07]\n", "models[0].vis.global_.fovy = 60\n", 
"frames = render_many(models, top_data, state, framerate, camera=cam)\n", "media.show_video(frames, fps=framerate)\n", "end_render = time.time()\n", "\n", "print(f'Rollout time {end-start:.1f} seconds')\n", "print(f'Rendering time {end_render-start_render:.1f} seconds')" ] }, { "cell_type": "markdown", "id": "cf485c08-72be-4169-89b6-9d93df8ebbe3", "metadata": { "id": "cf485c08-72be-4169-89b6-9d93df8ebbe3" }, "source": [ "Because the models are now different, the measurements of the gyro sensor are not consistent even though the initial state for each rollout was the same." ] }, { "cell_type": "code", "execution_count": 0, "id": "b8a5d3d4-24e7-41a1-b3bd-7b63c1812b03", "metadata": { "id": "b8a5d3d4-24e7-41a1-b3bd-7b63c1812b03" }, "outputs": [], "source": [ "plt.figure(figsize=(12, 8))\n", "plt.subplot(3,1,1)\n", "for i in range(nbatch): plt.plot(sensordata[i, :, 0])\n", "plt.subplot(3,1,2)\n", "for i in range(nbatch): plt.plot(sensordata[i, :, 1])\n", "plt.subplot(3,1,3)\n", "for i in range(nbatch): plt.plot(sensordata[i, :, 2])\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "3841a669-6cd1-427e-a629-20a10a6e3a34", "metadata": { "id": "3841a669-6cd1-427e-a629-20a10a6e3a34" }, "source": [ "### Example: control inputs\n", "Open loop controls can be passed to `rollout` via the `control` argument. If passed, `nstep` no longer needs to be specified as it can be inferred from the size of `control`.\n", "\n", "Below we simulate 100 of the flailing humanoids from the [tutorial notebook](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/python/tutorial.ipynb). Each humanoid uses a different control signal." 
] }, { "cell_type": "code", "execution_count": 0, "id": "2a184873-8d24-45da-b444-8d21f5dcd733", "metadata": { "id": "2a184873-8d24-45da-b444-8d21f5dcd733" }, "outputs": [], "source": [ "# Episode parameters.\n", "duration = 3 # (seconds)\n", "framerate = 120 # (Hz)\n", "\n", "# Generate 100 different control sequences\n", "nbatch = 100\n", "nstep = int(duration / humanoid_model.opt.timestep)\n", "times = np.linspace(0.0, duration, nstep)\n", "ctrl_phase = 2 * np.pi * np.random.rand(nbatch, 1, humanoid_model.nu)\n", "control = np.sin((2 * np.pi * times).reshape(nstep, 1) + ctrl_phase)\n", "\n", "# Make initial states\n", "humanoid_data = mujoco.MjData(humanoid_model)\n", "humanoid_data.qvel[2] = 4 # Make the humanoid jump\n", "initial_states = get_state(humanoid_model, humanoid_data, nbatch)\n", "# index 0 is time, so x and y qpos values are at 1 and 2\n", "initial_states[:, 1:3] = xy_grid(nbatch, ncols=10, spacing=1.0)\n", "\n", "\n", "# Run the rollout\n", "start = time.time()\n", "humanoid_datas = [copy.copy(humanoid_data) for _ in range(nthread)]\n", "state, _ = rollout.rollout(humanoid_model, humanoid_datas,\n", " initial_states, control)\n", "end = time.time()\n", "\n", "# Render the rollout\n", "start_render = time.time()\n", "framerate = 120\n", "cam = mujoco.MjvCamera()\n", "mujoco.mjv_defaultCamera(cam)\n", "cam.distance = 10\n", "cam.azimuth = 45\n", "cam.elevation = -15\n", "cam.lookat = [0, 0, 0]\n", "humanoid_model.vis.global_.fovy = 60\n", "frames = render_many(humanoid_model, humanoid_data, state, framerate,\n", " camera=cam, light_pos=[0, 0, 10])\n", "media.show_video(frames, fps=framerate/2) # Show the video at half speed\n", "end_render = time.time()\n", "\n", "print(f'Rollout time {end-start:.1f} seconds')\n", "print(f'Render time {end_render-start_render:.1f} seconds')" ] }, { "cell_type": "markdown", "id": "4d89e4fb-7711-4d23-8ff3-eb5030fa8bf7", "metadata": { "id": "4d89e4fb-7711-4d23-8ff3-eb5030fa8bf7" }, "source": [ "`rollout`'s 
`control_spec` argument can be used to indicate `control` contains values for actuators, generalized forces, cartesian forces, mocap poses, and/or the activation/deactivation of equality constraints. Internally, this is managed through [mj_setState](https://mujoco.readthedocs.io/en/stable/APIreference/APIfunctions.html#mj-setstate) and `control_spec` corresponds to `mj_setState`'s `spec` argument.\n", "\n", "Let's try applying cartesian forces in addition to the control inputs. This will make the humanoids look like they are being dragged while waving their limbs." ] }, { "cell_type": "code", "execution_count": 0, "id": "4b02bb61-912d-47de-a956-aadfcd4c5cd5", "metadata": { "id": "4b02bb61-912d-47de-a956-aadfcd4c5cd5" }, "outputs": [], "source": [ "xfrc_size = mujoco.mj_stateSize(humanoid_model, mujoco.mjtState.mjSTATE_XFRC_APPLIED)\n", "xfrc = np.zeros((nbatch, nstep, xfrc_size))\n", "head_id = humanoid_model.body('head').id\n", "\n", "# Apply a constant but different force to each model\n", "force = np.random.normal(scale=150.0, size=(nbatch, 1, 3))\n", "force[:,:,2] = 150 # Fixed vertical force\n", "xfrc[:, :, 3*head_id:3*head_id+3] = force\n", "\n", "control_xfrc = np.concatenate((control, xfrc), axis=2)\n", "control_spec = mujoco.mjtState.mjSTATE_XFRC_APPLIED.value\n", "\n", "start = time.time()\n", "state, _ = rollout.rollout(humanoid_model, humanoid_datas,\n", " initial_states, xfrc, control_spec=control_spec)\n", "end = time.time()\n", "\n", "start_render = time.time()\n", "frames = render_many(humanoid_model, humanoid_data, state, framerate,\n", " camera=cam, light_pos=[0, 0, 10])\n", "media.show_video(frames, fps=framerate/2) # Show the video at half speed\n", "end_render = time.time()\n", "\n", "print(f'Rollout time {end-start:.1f} seconds')\n", "print(f'Render time {end_render-start_render:.1f} seconds')" ] }, { "cell_type": "markdown", "id": "0961c3ec-a691-4875-9a55-227a3d29c472", "metadata": { "id": "0961c3ec-a691-4875-9a55-227a3d29c472" }, "source": [ 
"# Advanced usage" ] }, { "cell_type": "markdown", "id": "VfYIyXWcLKfg", "metadata": { "id": "VfYIyXWcLKfg" }, "source": [ "## skip_checks\n", "\n", "By default rollout performs many checks on the dimensions of its arguments. This it allows it to infer dimensions such as `nbatch` and `nstep`, tile arguments that were not fully specified, and allocate the returned `state` and `sensordata` arrays.\n", "\n", "However, these check take time, particularly if `state` and `sensordata` are large or if there are many models and `nstep` is low. So advanced users may want to use the `skip_checks=True` argument in order to achieve additional performance.\n", "\n", "If used, certain arguments become non-optional, and all signals must be fully defined (no implicit tiling). In particular:\n", "* `model` must be a list of length `nbatch`\n", "* `data` must be a list of length `nthread`\n", "* `nstep` must be specified\n", "* `initial_state` must be an array of shape `nbatch x nstate`\n", "* `control` is optional, but if passed must be an array of shape `nbatch x nstep x ncontrol`\n", "* `state` is optional, but must be passed if state is to be returned and must be of shape `nbatch x nstep x nstate`\n", "* `sensordata` is optional, but must be passed if sensor data is to be returned and must be of shape `nbatch x nstep x nsensordata`\n", "\n", "As an extreme example, we pass 10,000 humanoid models to `rollout` and simulate 1 step each with and without checks." 
] }, { "cell_type": "code", "execution_count": 0, "id": "d02cc8e8-63cd-4852-ab3c-364a18025a95", "metadata": { "id": "d02cc8e8-63cd-4852-ab3c-364a18025a95" }, "outputs": [], "source": [ "nbatch = 1000\n", "nstep = [1, 10, 100, 500]\n", "ntiming = 5\n", "\n", "top_data = mujoco.MjData(top_model)\n", "mujoco.mj_resetDataKeyframe(top_model, top_data, 0)\n", "top_datas = [copy.copy(top_data) for _ in range(nthread)]\n", "initial_state = get_state(top_model, top_data)\n", "initial_state_tiled = get_state(top_model, top_data, nbatch)\n", "\n", "# Note: state, sensordata array automatically allocated and return\n", "def rollout_with_checks(nstep):\n", " state, sensordata = rollout.rollout([top_model]*nbatch, top_datas, initial_state, nstep=nstep)\n", "\n", "# Note: state, sensordata arrays have to be preallocated\n", "state = None\n", "sensordata = None\n", "def rollout_skip_checks(nstep):\n", " # Note initial state must be tiled\n", " rollout.rollout([top_model]*nbatch, top_datas, initial_state_tiled, nstep=nstep,\n", " state=state, sensordata=sensordata, skip_checks=True)\n", "\n", "t_with_checks = benchmark(lambda x: rollout_with_checks(x), nstep, ntiming=ntiming)\n", "t_skip_checks = benchmark(lambda x: rollout_skip_checks(x), nstep, ntiming=ntiming)\n", "\n", "steps_per_second = (nbatch * np.array(nstep)) / np.array(t_with_checks)\n", "steps_per_second_skip_checks = (nbatch * np.array(nstep)) / np.array(t_skip_checks)\n", "\n", "plt.loglog(nstep, steps_per_second, label='with checks')\n", "plt.loglog(nstep, steps_per_second_skip_checks, label='skip checks')\n", "plt.ylabel('steps per second')\n", "plt.xlabel('nstep')\n", "ticker = matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ','))\n", "plt.gca().yaxis.set_minor_formatter(ticker)\n", "plt.legend()\n", "plt.grid(True, which=\"both\", axis=\"both\")" ] }, { "cell_type": "markdown", "id": "92627030-4726-4689-be8b-f1ba75905104", "metadata": { "id": "92627030-4726-4689-be8b-f1ba75905104" }, "source": [ "As 
expected, as `nstep` increases, the benefits of using skip checks fade quickly. However, at low nstep and high batch sizes, it can make a significant difference.\n", "\n", "Notice that the version with checks can use the non-tiled `initial_state`, however the skip checks version must use the tiled version, `initial_state_tiled`." ] }, { "cell_type": "markdown", "id": "d32a77b5-24bd-4d17-80ac-15cc4d03731c", "metadata": { "id": "d32a77b5-24bd-4d17-80ac-15cc4d03731c" }, "source": [ "## Reusing threadpools (`Rollout` class)\n", "\n", "The `rollout` module provides the class `Rollout` in addition to the method `rollout`. The class `Rollout` is designed to allow safe reuse of the internally managed thread pool.\n", "\n", "Reuse can speed things up considerably when rollouts are short. Let's find out how the speedup changes for the tippe top model by rolling it out with increasing numbers of steps." ] }, { "cell_type": "code", "execution_count": 0, "id": "dd05bbdf-f389-4e4e-b389-d47fe976cb49", "metadata": { "id": "dd05bbdf-f389-4e4e-b389-d47fe976cb49" }, "outputs": [], "source": [ "nbatch = 100\n", "nsteps = [2**i for i in [2, 3, 4, 5, 6, 7]]\n", "ntiming = 5\n", "\n", "top_data = mujoco.MjData(top_model)\n", "mujoco.mj_resetDataKeyframe(top_model, top_data, 0)\n", "top_datas = [copy.copy(top_data) for _ in range(nthread)]\n", "\n", "initial_states = get_state(top_model, top_data, nbatch)\n", "\n", "def rollout_method(nstep):\n", " for i in range(20):\n", " rollout.rollout(top_model, top_datas, initial_states, nstep=nstep)\n", "\n", "def rollout_class(nstep):\n", " with rollout.Rollout(nthread=nthread) as rollout_:\n", " for i in range(20):\n", " rollout_.rollout(top_model, top_datas, initial_states, nstep=nstep)\n", "\n", "t_method = benchmark(lambda x: rollout_method(x), nsteps, ntiming)\n", "t_class = benchmark(lambda x: rollout_class(x), nsteps, ntiming)\n", "\n", "plt.loglog(nsteps, nbatch * np.array(nsteps) / t_method, label='recreating threadpools')\n", 
"plt.loglog(nsteps, nbatch * np.array(nsteps) / t_class, label='reusing threadpool')\n", "plt.xlabel('nstep')\n", "plt.ylabel('steps per second')\n", "ticker = matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ','))\n", "plt.gca().yaxis.set_minor_formatter(ticker)\n", "plt.legend()\n", "plt.grid(True, which=\"both\", axis=\"both\")" ] }, { "cell_type": "markdown", "id": "9b3e14a1-71f3-430d-a3c2-1aadcf6c2671", "metadata": { "id": "9b3e14a1-71f3-430d-a3c2-1aadcf6c2671" }, "source": [ "## Reusing threadpools (`rollout` method)\n", "\n", "`rollout` will create and reuse a persistent threadpool by passing `persistent_pool=True`. However there are some caveats.\n", "\n", "First, because `rollout` is a function and does not know when the user is done calling it, the threadpool pool needs to be shutdown manually like this:" ] }, { "cell_type": "code", "execution_count": 0, "id": "b6aa6801", "metadata": { "id": "b6aa6801" }, "outputs": [], "source": [ "nbatch = 1000\n", "nstep = 1\n", "\n", "top_data = mujoco.MjData(top_model)\n", "mujoco.mj_resetDataKeyframe(top_model, top_data, 0)\n", "top_datas = [copy.copy(top_data) for _ in range(nthread)]\n", "\n", "initial_states = get_state(top_model, top_data, nbatch)\n", "\n", "rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True) # Creates a pool\n", "rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True) # Reuses the previously created pool\n", "rollout.shutdown_persistent_pool() # Shutdown the pool manually when finished" ] }, { "cell_type": "markdown", "id": "144378d3", "metadata": { "id": "144378d3" }, "source": [ "Second, if `rollout` reuses the same threadpool between calls, it is no longer safe to call `rollout` from multiple threads. 
For example, the following is not allowed (the offending lines are commented out to avoid crashing the interpreter):" ] }, { "cell_type": "code", "execution_count": 0, "id": "7f46a6d8", "metadata": { "id": "7f46a6d8" }, "outputs": [], "source": [ "thread1 = threading.Thread(target=lambda: rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True))\n", "thread2 = threading.Thread(target=lambda: rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True))\n", "\n", "thread1.start()\n", "#thread2.start() # Do not do this! rollout will be using the same persistent threadpool from two threads and may crash the interpreter\n", "thread1.join()\n", "#thread2.join()\n", "rollout.shutdown_persistent_pool()" ] }, { "cell_type": "markdown", "id": "78c1f864-5238-4e27-a7ec-d03c45484d9a", "metadata": { "id": "78c1f864-5238-4e27-a7ec-d03c45484d9a" }, "source": [ "## chunk_size\n", "\n", "To minimize communication overhead, `rollout` distributes rollouts to threads in groups of rollouts called chunks. By default, `max(1, 0.1 * (nbatch / nthread))` rollouts are assigned to each chunk. While this chunking rule works well for most workloads it is not always optimal, especially when doing short rollouts with small models.\n", "\n", "Below we plot the steps per second versus chunk_size when running 100 hoppers for 1 step each. In this case, the default chunk_size turns out to be quite a bit slower than using an increased chunk size."
] }, { "cell_type": "code", "execution_count": 0, "id": "a1be8f93", "metadata": { "id": "a1be8f93" }, "outputs": [], "source": [ "nbatch = 100\n", "nstep = 1\n", "ntiming = 20\n", "\n", "# Load model\n", "hopper_model = mujoco.MjModel.from_xml_path(hopper_path)\n", "hopper_data = mujoco.MjData(hopper_model)\n", "hopper_datas = [copy.copy(hopper_data) for _ in range(nthread)]\n", "\n", "# Get initial states\n", "initial_states = get_state(hopper_model, hopper_data, nbatch)\n", "\n", "def rollout_chunk_size(chunk_size=None):\n", " rollout.rollout(hopper_model, hopper_datas, initial_states, nstep=nstep, chunk_size=chunk_size)\n", "\n", "# Rollout with different chunk sizes\n", "default_chunk_size = int(max(1.0, 0.1 * nbatch / nthread))\n", "chunk_sizes = sorted([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, default_chunk_size])\n", "t_chunk_size = benchmark(lambda x: rollout_chunk_size(x), chunk_sizes, ntiming=ntiming)\n", "\n", "# Get optimal chunk size\n", "steps_per_second = nbatch * nstep / t_chunk_size\n", "default_index = [i for i, c in enumerate(chunk_sizes) if c == default_chunk_size][0]\n", "optimal_index = np.argmax(steps_per_second)\n", "plt.loglog(chunk_sizes, steps_per_second, color='b')\n", "plt.plot(chunk_sizes[default_index], steps_per_second[default_index], marker='o', color='r', label='default chunk size')\n", "plt.plot(chunk_sizes[optimal_index], steps_per_second[optimal_index], marker='o', color='g', label='optimal chunk size')\n", "plt.ylabel('steps per second')\n", "plt.xlabel('chunk size')\n", "ticker = matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ','))\n", "plt.gca().yaxis.set_minor_formatter(ticker)\n", "plt.legend()\n", "plt.grid(True, which=\"both\", axis=\"both\")\n", "\n", "print(f'default chunk size: {default_chunk_size} \\t steps per second: {steps_per_second[default_index]:0.1f}')\n", "print(f'optimal chunk size: {chunk_sizes[optimal_index]} \\t steps per second: {steps_per_second[optimal_index]:0.1f}')" ] }, { 
"cell_type": "markdown", "id": "84c746e0-5d2e-47dc-ac76-f2b6f790b7c7", "metadata": { "id": "84c746e0-5d2e-47dc-ac76-f2b6f790b7c7" }, "source": [ "## Warmstarting\n", "\n", "The `initial_warmstart` parameter can be used to warmstart the constraint solver as described in the [computation chapter](https://mujoco.readthedocs.io/en/stable/computation/index.html#warmstart-acceleration) of the documentation. This can be useful when rolling out models in chunks of steps. Without warmstarting, chaotic systems involving multi-body contact may diverge.\n", "\n", "Below we demonstrate this with the tippe top model where the contact solver was changed to CG. This makes the contact force calculation less repeatable than if the default, Newton's method, were used and allows demonstrating the benefits of warmstarting.\n", "\n", "The simulation is run three times. Once with a 6000 step rollout, once with 100 chunks of 60 steps with warmstarting, and once more in 100 chunks of 60 steps without warmstarting."
] }, { "cell_type": "code", "execution_count": 0, "id": "d4d9f660-f83c-432e-a579-124a7ecab4fb", "metadata": { "id": "d4d9f660-f83c-432e-a579-124a7ecab4fb" }, "outputs": [], "source": [ "top_model_cg = copy.copy(top_model)\n", "\n", "# Change to CG solver, the Newton solver converges too well for\n", "# warmstarting to have an appreciable effect\n", "top_model_cg.opt.solver = mujoco.mjtSolver.mjSOL_CG\n", "\n", "chunks = 100\n", "steps_per_chunk = 60\n", "nstep = steps_per_chunk*chunks\n", "\n", "# Get initial states\n", "top_data_cg = mujoco.MjData(top_model_cg)\n", "mujoco.mj_resetDataKeyframe(top_model_cg, top_data_cg, 0)\n", "initial_state = get_state(top_model_cg, top_data_cg)\n", "\n", "start = time.time()\n", "# Rollout with nstep steps\n", "state_all, _ = rollout.rollout(top_model_cg, top_data_cg, initial_state, nstep=nstep)\n", "\n", "# Rollout in chunks with warmstarting\n", "state_chunks = []\n", "state_chunk, _ = rollout.rollout(top_model_cg, top_data_cg, initial_state, nstep=steps_per_chunk)\n", "state_chunks.append(state_chunk)\n", "for _ in range(chunks-1):\n", " state_chunk, _ = rollout.rollout(top_model_cg, top_data_cg, state_chunks[-1][0, -1, :],\n", " nstep=steps_per_chunk, initial_warmstart=top_data_cg.qacc_warmstart)\n", " state_chunks.append(state_chunk)\n", "state_all_chunked_warmstart = np.concatenate(state_chunks, axis=1)\n", "\n", "# Rollout in chunks without warmstarting\n", "state_chunks = []\n", "state_chunk, _ = rollout.rollout(top_model_cg, top_data_cg, initial_state, nstep=steps_per_chunk)\n", "state_chunks.append(state_chunk)\n", "first_warmstart = None\n", "for i in range(chunks-1):\n", " state_chunk, _ = rollout.rollout(top_model_cg, top_data_cg, state_chunks[-1][0, -1, :], nstep=steps_per_chunk)\n", " state_chunks.append(state_chunk)\n", "state_all_chunked = np.concatenate(state_chunks, axis=1)\n", "end = time.time()\n", "\n", "# Render the rollouts\n", "start_render = time.time()\n", "framerate = 60\n", "state_render = 
np.concatenate((state_all, state_all_chunked, state_all_chunked_warmstart), axis=0)\n", "camera = 'distant'\n", "frames1 = render_many(top_model_cg, top_data_cg, state_all, framerate, shape=(240, 320), transparent=False, camera=camera)\n", "frames2 = render_many(top_model_cg, top_data_cg, state_all_chunked_warmstart, framerate, shape=(240, 320), transparent=False, camera=camera)\n", "frames3 = render_many(top_model_cg, top_data_cg, state_all_chunked, framerate, shape=(240, 320), transparent=False, camera=camera)\n", "media.show_video(np.concatenate((frames1, frames2, frames3), axis=2))\n", "end_render = time.time()\n", "\n", "print(f'Rollout took {end-start:.1f} seconds')\n", "print(f'Rendering took {end_render-start_render:.1f} seconds')" ] }, { "cell_type": "markdown", "id": "7c2cf4fa", "metadata": { "id": "7c2cf4fa" }, "source": [ "As expected, the middle animation (with warmstarting) matches the continuous rollout on the left. However, the model that did not use warmstarting diverged." ] }, { "cell_type": "markdown", "id": "7944637f", "metadata": { "id": "7944637f" }, "source": [ "# Benchmarks\n", "\n", "The `rollout.rollout` function in the `mujoco` Python library runs batches of simulations for a fixed number steps. It can run in single or multi-threaded modes. 
The speedup over pure Python is significant because `rollout` can be easily configured to use multithreading.\n", "\n", "To show the speedup, we will run benchmarks with the \"tippe top\", \"humanoid\", and \"humanoid100\" models.\n", "\n", "## Python rollouts versus `rollout`\n", "\n", "The benchmark runs the three models with varying batch and step counts.\n", "\n", "The Python code for nbatch rollouts of nstep steps is:" ] }, { "cell_type": "code", "execution_count": 0, "id": "cb6355dd", "metadata": { "id": "cb6355dd" }, "outputs": [], "source": [ "def python_rollout(model, data, nbatch, nstep):\n", " for i in range(nbatch):\n", " for i in range(nstep):\n", " mujoco.mj_step(model, data)" ] }, { "cell_type": "markdown", "id": "6fe4a78b", "metadata": { "id": "6fe4a78b" }, "source": [ "To run nbatch rollouts with `rollout`, we need to make an array of nbatch initial states to start the rollouts from.\n", "\n", "\n", "Additionally, to use `rollout`'s parallelism, we must pass one `MjData` per thread.\n", "\n", "The resulting `rollout` call parameterized by `nbatch`, `nstep`, and `nthread` is:" ] }, { "cell_type": "code", "execution_count": 0, "id": "74f143e2", "metadata": { "id": "74f143e2" }, "outputs": [], "source": [ "def nthread_rollout(model, data, nbatch, nstep, nthread, rollout_):\n", " rollout_.rollout([model]*nbatch,\n", " [copy.copy(data) for _ in range(nthread)], # Create one MjData per thread\n", " np.tile(get_state(model, data), (nbatch, 1)), # Tile the initial condition nbatch times\n", " nstep=nstep,\n", " skip_checks=True)" ] }, { "cell_type": "markdown", "id": "b75dc44c", "metadata": { "id": "b75dc44c" }, "source": [ "Next, we benchmark the Python loop and `rollout` in both single threaded and multithreaded modes. The three benchmarks take about 2.5 minutes in total to run in total on an AMD 5800X3D." 
] }, { "cell_type": "code", "execution_count": 0, "id": "0301e3ee", "metadata": { "cellView": "form", "id": "0301e3ee" }, "outputs": [], "source": [ "#@title Benchmarking utilities\n", "\n", "top_model = mujoco.MjModel.from_xml_string(tippe_top)\n", "def init_top(model):\n", " data = mujoco.MjData(model)\n", " # Set to the state to a spinning top (keyframe 0)\n", " mujoco.mj_resetDataKeyframe(model, data, 0)\n", " return data\n", "\n", "# Create and initialize humanoid model\n", "# Step for 2 seconds to get a stable set of contacts to benchmark\n", "humanoid_model = mujoco.MjModel.from_xml_path(humanoid_path)\n", "humanoid_data = mujoco.MjData(humanoid_model)\n", "humanoid_data.qvel[2] = 4 # Make the humanoid jump\n", "while humanoid_data.time < 2.0:\n", " mujoco.mj_step(humanoid_model, humanoid_data)\n", "humanoid_initial_state = get_state(humanoid_model, humanoid_data)\n", "def init_humanoid(model):\n", " data = mujoco.MjData(model)\n", " mujoco.mj_setState(model, data, humanoid_initial_state.flatten(),\n", " mujoco.mjtState.mjSTATE_FULLPHYSICS)\n", " return data\n", "\n", "# Create and initialize humanoid100 model\n", "# Step for 4 seconds to get a stable set of contacts to benchmark\n", "humanoid100_model = mujoco.MjModel.from_xml_path(humanoid100_path)\n", "humanoid100_data = mujoco.MjData(humanoid100_model)\n", "while humanoid100_data.time < 4.0:\n", " mujoco.mj_step(humanoid100_model, humanoid100_data)\n", "humanoid100_initial_state = get_state(humanoid100_model, humanoid100_data)\n", "def init_humanoid100(model):\n", " data = mujoco.MjData(model)\n", " mujoco.mj_setState(model, data, humanoid100_initial_state.flatten(),\n", " mujoco.mjtState.mjSTATE_FULLPHYSICS)\n", " return data\n", "\n", "def benchmark_rollout(model, init_model, nbatch, nstep, nominal_nbatch, nominal_nstep, ntiming=1):\n", " print('Benchmarking pure python', end='\\r')\n", " start = time.time()\n", " t_python_nbatch = benchmark(lambda x, data: python_rollout(model, data, x, 
nominal_nstep), nbatch, ntiming,\n", " f_init=lambda x: init_model(model))\n", " t_python_nstep = benchmark(lambda x, data: python_rollout(model, data, nominal_nbatch, x), nstep, ntiming,\n", " f_init=lambda x: init_model(model))\n", " end = time.time()\n", " print(f'Benchmarking pure python took {end-start:0.1f} seconds')\n", "\n", " print('Benchmarking single threaded rollout', end='\\r')\n", " with rollout.Rollout(nthread=0) as rollout_:\n", " start = time.time()\n", " t_rollout_single_nbatch = benchmark(lambda x, data: nthread_rollout(model, data, x, nominal_nstep, nthread=1, rollout_=rollout_),\n", " nbatch, ntiming,\n", " f_init=lambda x: init_model(model))\n", " t_rollout_single_nstep = benchmark(lambda x, data: nthread_rollout(model, data, nominal_nbatch, x, nthread=1, rollout_=rollout_),\n", " nstep, ntiming, f_init=lambda x: init_model(model))\n", " end = time.time()\n", " print(f'Benchmarking single threaded rollout took {end-start:0.1f} seconds')\n", "\n", " print(f'Benchmarking multithreaded rollout using {nthread} threads', end='\\r')\n", " with rollout.Rollout(nthread=nthread) as rollout_:\n", " start = time.time()\n", " t_rollout_multi_nbatch = benchmark(lambda x, data: nthread_rollout(model, data, x, nominal_nstep, nthread, rollout_=rollout_),\n", " nbatch, ntiming, f_init=lambda x: init_model(model))\n", " t_rollout_multi_nstep = benchmark(lambda x, data: nthread_rollout(model, data, nominal_nbatch, x, nthread, rollout_=rollout_),\n", " nstep, ntiming, f_init=lambda x: init_model(model))\n", " end = time.time()\n", " print(f'Benchmarking multithreaded rollout using {nthread} threads took {end-start:0.1f} seconds')\n", "\n", " return (t_python_nbatch, t_rollout_single_nbatch, t_rollout_multi_nbatch,\n", " t_python_nstep, t_rollout_single_nstep, t_rollout_multi_nstep)\n", "\n", "def plot_benchmark(results, nbatch, nstep, nominal_nbatch, nominal_nstep, title):\n", " (t_python_nbatch, t_rollout_single_nbatch, t_rollout_multi_nbatch,\n", " 
t_python_nstep, t_rollout_single_nstep, t_rollout_multi_nstep) = results\n", "\n", " width = 0.25\n", " x = np.array([i for i in range(len(nbatch))])\n", "\n", " ticker = matplotlib.ticker.EngFormatter(unit='')\n", "\n", " fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)\n", " steps_per_t = np.array(nbatch) * nominal_nstep\n", " steps_per_t_python = steps_per_t / t_python_nbatch\n", " steps_per_t_single = steps_per_t / t_rollout_single_nbatch\n", " steps_per_t_multi = steps_per_t / t_rollout_multi_nbatch\n", " ax1.bar(x + 0*width, steps_per_t_python, width=width, label='python')\n", " ax1.bar(x + 1*width, steps_per_t_single, width=width, label='rollout single threaded')\n", " ax1.bar(x + 2*width, steps_per_t_multi, width=width, label='rollout multithreaded')\n", " ax1.set_xticks(x + width, nbatch)\n", " ax1.yaxis.set_major_formatter(ticker)\n", " ax1.grid()\n", " ax1.set_axisbelow(True)\n", " ax1.set_xlabel('nbatch')\n", " ax1.set_ylabel('steps per second')\n", " ax1.set_title(f'nbatch varied, nstep = {nominal_nstep}')\n", "\n", " x = np.array([i for i in range(len(nstep))])\n", " steps_per_t = np.array(nstep) * nominal_nbatch\n", " steps_per_t_python = steps_per_t / t_python_nstep\n", " steps_per_t_single = steps_per_t / t_rollout_single_nstep\n", " steps_per_t_multi = steps_per_t / t_rollout_multi_nstep\n", " ax2.bar(x + 0*width, steps_per_t_python, width=width, label='python')\n", " ax2.bar(x + 1*width, steps_per_t_single, width=width, label='rollout single threaded')\n", " ax2.bar(x + 2*width, steps_per_t_multi, width=width, label='rollout multithreaded')\n", " ax2.set_xticks(x + width, nstep)\n", " ax2.yaxis.set_major_formatter(ticker)\n", " ax2.grid()\n", " ax2.set_axisbelow(True)\n", " ax2.set_xlabel('nstep')\n", " ax2.set_title(f'nstep varied, nbatch = {nominal_nbatch}')\n", "\n", " ax1.legend(loc=(0.03, 0.8))\n", " fig.set_size_inches(10, 5)\n", " plt.suptitle(title)\n", " plt.tight_layout()" ] }, { "cell_type": "markdown", "id": "08fb0c12", "metadata": { 
"id": "08fb0c12" }, "source": [ "### Tippe Top Benchmark" ] }, { "cell_type": "code", "execution_count": 0, "id": "f7e54830", "metadata": { "id": "f7e54830" }, "outputs": [], "source": [ "nominal_nbatch = 256 # Batch size to use when testing different nstep\n", "nominal_nstep = 5 # Step count to use when testing different nbatch\n", "nbatch = [1, 256, 2048, 8192]\n", "nstep = [1, 10, 100, 1000]\n", "\n", "top_benchmark_results = benchmark_rollout(top_model, init_top,\n", " nbatch, nstep,\n", " nominal_nbatch, nominal_nstep)\n", "plot_benchmark(top_benchmark_results, nbatch, nstep,\n", " nominal_nbatch, nominal_nstep,\n", " title='Tippe Top')" ] }, { "cell_type": "markdown", "id": "edefb26e", "metadata": { "id": "edefb26e" }, "source": [ "### Humanoid Benchmark" ] }, { "cell_type": "code", "execution_count": 0, "id": "c9e58c6c", "metadata": { "id": "c9e58c6c" }, "outputs": [], "source": [ "nominal_nbatch = 256 # Batch size to use when testing different nstep\n", "nominal_nstep = 5 # Step count to use when testing different nbatch\n", "nbatch = [1, 256, 2048, 8192] # Batch sizes to benchmark\n", "nstep = [1, 10, 100, 1000] # Step counts to benchmark\n", "\n", "humanoid_benchmark_results = benchmark_rollout(humanoid_model, init_humanoid,\n", " nbatch, nstep,\n", " nominal_nbatch, nominal_nstep)\n", "plot_benchmark(humanoid_benchmark_results, nbatch, nstep,\n", " nominal_nbatch, nominal_nstep,\n", " title='Humanoid')" ] }, { "cell_type": "markdown", "id": "468903bb", "metadata": { "id": "468903bb" }, "source": [ "### Humanoid100 Benchmark" ] }, { "cell_type": "code", "execution_count": 0, "id": "83d775d4", "metadata": { "id": "83d775d4" }, "outputs": [], "source": [ "nominal_nbatch = 128 # Batch size to use when testing different nstep\n", "nominal_nstep = 5 # Step count to use when testing different nbatch\n", "nbatch = [1, 64, 128, 256] # Batch sizes to benchmark\n", "nstep = [1, 10, 100, 1000] # Step counts to benchmark\n", "\n", "humanoid100_benchmark_results = 
benchmark_rollout(\n", " humanoid100_model,\n", " init_humanoid100,\n", " nbatch,\n", " nstep,\n", " nominal_nbatch,\n", " nominal_nstep,\n", ")\n", "plot_benchmark(humanoid100_benchmark_results, nbatch, nstep,\n", " nominal_nbatch, nominal_nstep,\n", " title='Humanoid100')" ] }, { "cell_type": "markdown", "id": "d1133084", "metadata": { "id": "d1133084" }, "source": [ "# MJX versus `rollout`" ] }, { "cell_type": "markdown", "id": "c1638f2d", "metadata": { "id": "c1638f2d" }, "source": [ "Next we will benchmark `rollout` and MJX using the tippe top and humanoid models (humanoid100 is not supported by MJX).\n", "\n", "The next two benchmarks take about 16.5 minutes total on an AMD 5800X3D and an NVIDIA 4090. Most of the time is spent JIT compiling the MJX functions. The JIT functions are cached so that subsequent runs of the benchmark run much faster.\n", "\n", "**Note:** MJX is most useful when coupled with something else that runs best on a GPU, like a neural network. Without any such additional workload, CPU based simulation will sometimes be faster, especially when using less than state-of-the-art GPUs." 
] }, { "cell_type": "code", "execution_count": 0, "id": "7c86d157", "metadata": { "cellView": "form", "id": "7c86d157" }, "outputs": [], "source": [ "#@title MJX helper functions\n", "def init_mjx_batch(model, init_model, nbatch, nstep, skip_jit=False):\n", " data = init_model(model)\n", "\n", " # Make MJX versions of model and data\n", " mjx_model = mjx.put_model(model)\n", " mjx_data = mjx.put_data(model, data)\n", "\n", " batch = jax.vmap(lambda x: mjx_data)(jp.array(list(range(nbatch))))\n", " jax.block_until_ready(batch)\n", "\n", " if not skip_jit:\n", " start = time.time()\n", " jit_step = jax.vmap(mjx.step, in_axes=(None, 0))\n", " def unroll(d, _):\n", " d = jit_step(mjx_model, d)\n", " return d, None\n", " jit_unroll = jax.jit(lambda d: jax.lax.scan(unroll, d, None, length=nstep, unroll=4)[0])\n", " jit_unroll = jit_unroll.lower(batch).compile()\n", " end = time.time()\n", " jit_time = end - start\n", " else:\n", " jit_unroll = None\n", " jit_time = 0.0\n", "\n", " return mjx_model, mjx_data, jit_unroll, batch, jit_time\n", "\n", "def mjx_rollout(batch, jit_unroll):\n", " batch = jit_unroll(batch)\n", " jax.block_until_ready(batch)\n", "\n", "def benchmark_mjx(model, init_model, nbatch, nstep, nominal_nbatch, nominal_nstep, ntiming=1, jit_unroll_cache=None):\n", " print(f'Benchmarking multithreaded rollout using {nthread} threads', end=\"\\r\")\n", " with rollout.Rollout(nthread=nthread) as rollout_:\n", " start = time.time()\n", " t_rollout_multi_nbatch = benchmark(lambda x, data: nthread_rollout(model, data, x, nominal_nstep, nthread, rollout_),\n", " nbatch, ntiming, f_init=lambda x: init_model(model))\n", " t_rollout_multi_nstep = benchmark(lambda x, data: nthread_rollout(model, data, nominal_nbatch, x, nthread, rollout_),\n", " nstep, ntiming, f_init=lambda x: init_model(model))\n", " end = time.time()\n", " print(f'Benchmarking multithreaded rollout using {nthread} threads took {end-start:0.1f} seconds')\n", "\n", " print('Running JIT for MJX', 
end='\\r')\n", " total_jit = 0.0\n", " if jit_unroll_cache is None:\n", " jit_unroll_cache = {}\n", " if f'nbatch_{nominal_nstep}' not in jit_unroll_cache:\n", " jit_unroll_cache[f'nbatch_{nominal_nstep}'] = {}\n", " if f'nstep_{nominal_nbatch}' not in jit_unroll_cache:\n", " jit_unroll_cache[f'nstep_{nominal_nbatch}'] = {}\n", " for n in nbatch:\n", " if n not in jit_unroll_cache[f'nbatch_{nominal_nstep}']:\n", " _, _, jit_unroll_cache[f'nbatch_{nominal_nstep}'][n], _, jit_time = init_mjx_batch(model, init_model, n, nominal_nstep)\n", " total_jit += jit_time\n", " for n in nstep:\n", " if n not in jit_unroll_cache[f'nstep_{nominal_nbatch}']:\n", " _, _, jit_unroll_cache[f'nstep_{nominal_nbatch}'][n], _, jit_time = init_mjx_batch(model, init_model, nominal_nbatch, n)\n", " total_jit += jit_time\n", " print(f'Running JIT for MJX took {total_jit:0.1f} seconds')\n", "\n", " print('Benchmarking MJX', end='\\r')\n", " start = time.time()\n", " t_mjx_nbatch = benchmark(lambda x, x_init: mjx_rollout(x_init[3], jit_unroll_cache[f'nbatch_{nominal_nstep}'][x]),\n", " nbatch, ntiming, f_init=lambda x: init_mjx_batch(model, init_model, x, nominal_nstep, skip_jit=True))\n", " t_mjx_nstep = benchmark(lambda x, x_init: mjx_rollout(x_init[3], jit_unroll_cache[f'nstep_{nominal_nbatch}'][x]),\n", " nstep, ntiming, f_init=lambda x: init_mjx_batch(model, init_model, nominal_nbatch, x, skip_jit=True))\n", " end = time.time()\n", " print(f'Benchmarking MJX took {end-start:0.1f} seconds')\n", "\n", " return t_rollout_multi_nbatch, t_rollout_multi_nstep, t_mjx_nbatch, t_mjx_nstep\n", "\n", "def plot_mjx_benchmark(results, nbatch, nstep, nominal_nbatch, nominal_nstep, title):\n", " t_rollout_multi_nbatch, t_rollout_multi_nstep, t_mjx_nbatch, t_mjx_nstep = results\n", "\n", " width = 0.333\n", " x = np.array([i for i in range(len(nbatch))])\n", "\n", " ticker = matplotlib.ticker.EngFormatter(unit='')\n", "\n", " fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)\n", " steps_per_t = 
np.array(nbatch) * nominal_nstep\n", " steps_per_t_mjx = steps_per_t / t_mjx_nbatch\n", " steps_per_t_multi = steps_per_t / t_rollout_multi_nbatch\n", " ax1.bar(x + 0*width, steps_per_t_mjx, width=width, label='mjx')\n", " ax1.bar(x + 1*width, steps_per_t_multi, width=width, label='rollout multithreaded')\n", " ax1.set_xticks(x + width / 2, nbatch)\n", " ax1.yaxis.set_major_formatter(ticker)\n", " ax1.grid()\n", " ax1.set_xlabel('nbatch')\n", " ax1.set_ylabel('steps per second')\n", " ax1.set_title(f'nbatch varied, nstep = {nominal_nstep}')\n", "\n", " x = np.array([i for i in range(len(nstep))])\n", " steps_per_t = np.array(nstep) * nominal_nbatch\n", " steps_per_t_mjx = steps_per_t / t_mjx_nstep\n", " steps_per_t_multi = steps_per_t / t_rollout_multi_nstep\n", " ax2.bar(x + 0*width, steps_per_t_mjx, width=width, label='mjx')\n", " ax2.bar(x + 1*width, steps_per_t_multi, width=width, label='rollout multithreaded')\n", " ax2.set_xticks(x + width / 2, nstep)\n", " ax2.yaxis.set_major_formatter(ticker)\n", " ax2.grid()\n", " ax2.set_xlabel('nstep')\n", " ax2.set_title(f'nstep varied, nbatch = {nominal_nbatch}')\n", "\n", " ax2.legend(loc=(1.04, 0.0))\n", " fig.set_size_inches(10, 4)\n", " plt.suptitle(title)\n", " plt.tight_layout()\n", "\n", "# Caches for jit_step functions, they take a long time to compile\n", "top_jit_unroll_cache = {}\n", "humanoid_jit_unroll_cache = {}" ] }, { "cell_type": "markdown", "id": "a2dafd2e", "metadata": { "id": "a2dafd2e" }, "source": [ "### MJX Tippe Top Benchmark" ] }, { "cell_type": "code", "execution_count": 0, "id": "98c580b0", "metadata": { "id": "98c580b0" }, "outputs": [], "source": [ "nominal_nbatch = 16384 # Batch size to use when testing different nstep\n", "nominal_nstep = 5 # Step count to use when testing different nbatch\n", "nbatch = [4096, 16384, 65536, 131072] # Batch sizes to benchmark\n", "nstep = [1, 10, 100, 200] # Step counts to benchmark\n", "\n", "mjx_top_results = benchmark_mjx(top_model, init_top, nbatch, 
nstep, nominal_nbatch, nominal_nstep,\n", " jit_unroll_cache=top_jit_unroll_cache)\n", "plot_mjx_benchmark(mjx_top_results, nbatch, nstep, nominal_nbatch, nominal_nstep, title='MJX Tippe Top')" ] }, { "cell_type": "markdown", "id": "205da5da", "metadata": { "id": "205da5da" }, "source": [ "### MJX Humanoid Benchmark" ] }, { "cell_type": "code", "execution_count": 0, "id": "53166ae1", "metadata": { "id": "53166ae1" }, "outputs": [], "source": [ "nominal_nbatch = 4096 # Batch size to use when testing different nstep\n", "nominal_nstep = 5 # Step count to use when testing different nbatch\n", "nbatch = [1024, 4096, 16384, 32768] # Batch sizes to benchmark\n", "nstep = [1, 10, 100, 200] # Step counts to benchmark\n", "\n", "mjx_humanoid_results = benchmark_mjx(humanoid_model, init_humanoid, nbatch, nstep, nominal_nbatch, nominal_nstep,\n", " jit_unroll_cache=humanoid_jit_unroll_cache)\n", "plot_mjx_benchmark(mjx_humanoid_results, nbatch, nstep, nominal_nbatch, nominal_nstep, title='MJX Humanoid')" ] }, { "cell_type": "markdown", "id": "fb2caa72", "metadata": { "id": "fb2caa72" }, "source": [ "### MJX Multiple Humanoids in one model\n", "\n", "The MJX [documentation](https://mujoco.readthedocs.io/en/stable/mjx.html#mjx-the-sharp-bits) contains a chart comparing the speed of native MuJoCo vs MJX on a variety of devices.\n", "\n", "Here we will produce a similar plot to compare MJX with `rollout`. On a 5800X3D and 4090 the benchmark takes about 16.5 minutes to run.\n", "\n", "**Note:** These results are not directly comparable with the plot in the documentation because, in particular, the batch size was reduced from 8192 to 4096 in order to fit the batch on a 4090."
] }, { "cell_type": "code", "execution_count": 0, "id": "3d6be608", "metadata": { "id": "3d6be608" }, "outputs": [], "source": [ "max_humanoids = 10\n", "nbatch = 8192 // 2 # The original benchmark ran with a batch size of 8192, but on a 4090 we can only fit about 4096 humanoids\n", "nstep = 200\n", "\n", "jit_step = jax.vmap(mjx.step, in_axes=(None, 0))\n", "t_rollout = []\n", "t_mjx = []\n", "for i in range(1, max_humanoids+1):\n", " print(f'Running benchmark on {i} humanoids')\n", " nhumanoid_model = mujoco.MjModel.from_xml_path(\n", " f'mujoco/mjx/mujoco/mjx/test_data/humanoid/{i:02d}_humanoids.xml'\n", " )\n", " nhumanoid_data = mujoco.MjData(nhumanoid_model)\n", "\n", " mjx_model = mjx.put_model(nhumanoid_model)\n", " mjx_data = mjx.put_data(nhumanoid_model, nhumanoid_data)\n", " batch = jax.vmap(lambda x: mjx_data)(jp.array(list(range(nbatch))))\n", " jax.block_until_ready(batch)\n", "\n", " with rollout.Rollout(nthread=nthread) as rollout_:\n", " initial_state = get_state(nhumanoid_model, nhumanoid_data, nbatch)\n", " start = time.perf_counter()\n", " rollout_.rollout([nhumanoid_model]*nbatch,\n", " [copy.copy(nhumanoid_data) for _ in range(nthread)],\n", " initial_state=initial_state,\n", " nstep=nstep, skip_checks=True)\n", " end = time.perf_counter()\n", " t_rollout.append(end-start)\n", "\n", " # Trigger JIT for model/batch so as not to include JIT time in benchmarking information\n", " def unroll(d, _):\n", " d = jit_step(mjx_model, d)\n", " return d, None\n", " jit_unroll = jax.jit(lambda d: jax.lax.scan(unroll, d, None, length=nstep, unroll=4)[0])\n", " jit_unroll = jit_unroll.lower(batch).compile()\n", "\n", " start = time.perf_counter()\n", " jit_unroll(batch)\n", " jax.block_until_ready(batch)\n", " end = time.perf_counter()\n", " t_mjx.append(end-start)" ] }, { "cell_type": "code", "execution_count": 0, "id": "b6c5fc2e", "metadata": { "id": "b6c5fc2e" }, "outputs": [], "source": [ "#@title Plot MJX nhumanoid benchmark\n", "\n", "def 
plot_mjx_nhumanoid_benchmark(t_rollout, t_mjx, nbatch, nstep, max_humanoids):\n", " nhumanoids = [i for i in range(1, max_humanoids+1)]\n", "\n", " width = 0.333\n", " x = np.array([i for i in range(len(nhumanoids))])\n", "\n", " ticker = matplotlib.ticker.EngFormatter(unit='')\n", "\n", " fig, ax1 = plt.subplots(1, 1, sharey=True)\n", " steps_per_t = nbatch * nstep\n", " steps_per_t_mjx = steps_per_t / np.array(t_mjx)\n", " steps_per_t_multi = steps_per_t / np.array(t_rollout)\n", " ax1.bar(x + 0*width, steps_per_t_mjx, width=width, label='mjx')\n", " ax1.bar(x + 1*width, steps_per_t_multi, width=width, label='rollout multithreaded')\n", " ax1.set_xticks(x + width / 2, nhumanoids)\n", " ax1.yaxis.set_major_formatter(ticker)\n", " ax1.set_yscale('log')\n", " ax1.grid()\n", " ax1.set_xlabel('number of humanoids')\n", " ax1.set_ylabel('steps per second')\n", " ax1.set_title(f'nhumanoids varied, nbatch = {nbatch}, nstep = {nstep}')\n", "\n", " ax1.legend(loc=(1.04, 0.0))\n", " fig.set_size_inches(8, 4)\n", " plt.tight_layout()\n", "\n", "plot_mjx_nhumanoid_benchmark(t_rollout, t_mjx, nbatch, nstep, max_humanoids)" ] }, { "cell_type": "code", "execution_count": 0, "id": "UW0aoKXK7ALd", "metadata": { "id": "UW0aoKXK7ALd" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "fc69d0f4", "d32a77b5-24bd-4d17-80ac-15cc4d03731c", "9b3e14a1-71f3-430d-a3c2-1aadcf6c2671", "78c1f864-5238-4e27-a7ec-d03c45484d9a", "84c746e0-5d2e-47dc-ac76-f2b6f790b7c7", "a2dafd2e", "205da5da" ], "gpuType": "A100", "private_outputs": true, "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }