Spaces:
Runtime error
Runtime error
| import jax | |
| import jax.numpy as jnp | |
| import jax.random | |
| from jax2d.engine import PhysicsEngine | |
| from matplotlib import pyplot as plt | |
| from kinetix.environment.env import make_kinetix_env_from_args | |
| from kinetix.environment.env_state import StaticEnvParams, EnvParams | |
| from kinetix.environment.ued.distributions import sample_kinetix_level | |
| from kinetix.environment.ued.ued_state import UEDParams | |
| from kinetix.render.renderer_pixels import make_render_pixels | |
def unwrap_env_state(state):
    """Strip environment wrappers by following ``.env_state`` links.

    The state returned by the wrapped Kinetix environment nests the real
    environment state inside one or more wrapper states, each of which
    exposes the wrapped state as ``.env_state``. The pixel renderer needs the
    innermost state, so descend until reaching an object that has no
    ``env_state`` attribute. This avoids hard-coding the wrapper depth, which
    presumably depends on the options given to ``make_kinetix_env_from_args``
    (e.g. ``reset_type``) — confirm against that function.

    Args:
        state: A (possibly wrapped) environment state.

    Returns:
        The first object along the ``.env_state`` chain that has no
        ``env_state`` attribute (``state`` itself if it is unwrapped).
    """
    while hasattr(state, "env_state"):
        state = state.env_state
    return state


def main(seed: int = 0, save_path=None):
    """Sample a random Kinetix level, take one random step, and render it.

    Args:
        seed: Integer seed for ``jax.random.PRNGKey``. Defaults to 0, the
            value the script previously hard-coded.
        save_path: If given, the rendered frame is written to this path with
            ``plt.savefig`` instead of being displayed with ``plt.show()``
            (useful on machines without a display).
    """
    # Use default parameters
    env_params = EnvParams()
    static_env_params = StaticEnvParams()
    ued_params = UEDParams()

    # Create the environment
    env = make_kinetix_env_from_args(
        obs_type="pixels",
        action_type="continuous",
        reset_type="replay",
        static_env_params=static_env_params,
    )

    # Sample a random level. Every consumer of randomness gets a fresh
    # subkey from jax.random.split so that no key is reused.
    rng = jax.random.PRNGKey(seed)
    rng, _rng = jax.random.split(rng)
    level = sample_kinetix_level(_rng, env.physics_engine, env_params, static_env_params, ued_params)

    # Reset the environment state to this level
    rng, _rng = jax.random.split(rng)
    _obs, env_state = env.reset_to_level(_rng, level, env_params)

    # Take a step in the environment with a randomly sampled action
    rng, _rng = jax.random.split(rng)
    action = env.action_space(env_params).sample(_rng)
    rng, _rng = jax.random.split(rng)
    _obs, env_state, _reward, _done, _info = env.step(_rng, env_state, action, env_params)

    # Render environment. The environment state is wrapped several times;
    # the renderer is given the innermost state rather than a fixed-depth
    # attribute chain.
    renderer = make_render_pixels(env_params, static_env_params)
    pixels = renderer(unwrap_env_state(env_state))

    # transpose(1, 0, 2) swaps the first two axes and [::-1] flips the new
    # first axis. NOTE(review): this presumably converts the renderer's
    # (x, y, channel), y-up layout into imshow's (row, col, channel),
    # top-down layout; astype(uint8) assumes values already lie in [0, 255]
    # — confirm against make_render_pixels.
    image = pixels.astype(jnp.uint8).transpose(1, 0, 2)[::-1]
    plt.imshow(image)
    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
# Entry point: run the demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()