Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from jax2d import joint | |
| from jax2d.engine import select_shape | |
| from jax2d.maths import rmat | |
| from jax2d.sim_state import RigidBody | |
| from jaxgl.maths import dist_from_line | |
| from jaxgl.renderer import clear_screen, make_renderer | |
| from jaxgl.shaders import ( | |
| fragment_shader_quad, | |
| fragment_shader_edged_quad, | |
| make_fragment_shader_texture, | |
| nearest_neighbour, | |
| make_fragment_shader_quad_textured, | |
| ) | |
| from kinetix.render.renderer_symbolic_common import ( | |
| make_circle_features, | |
| make_joint_features, | |
| make_polygon_features, | |
| make_thruster_features, | |
| ) | |
| from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState | |
| from flax import struct | |
def _masked_flatten(features, mask):
    """Zero the rows of ``features`` whose ``mask`` entry is False, then flatten.

    Args:
        features: Array of per-entity feature rows (rows on axis 0).
        mask: Boolean array with one entry per row of ``features``.

    Returns:
        A 1-D array with ``features.size`` elements.
    """
    return jnp.where(mask[:, None], features, 0.0).flatten()


def make_render_symbolic(params, static_params: StaticEnvParams):
    """Build a function that flattens an environment state into a symbolic observation.

    The observation is the concatenation of:
    1. polygon features (polygons 1, 2 and 3 removed);
    2. circle features;
    3. joint features, each followed by two one-hot shape indices;
    4. thruster features, each followed by one one-hot shape index;
    5. ``state.gravity[1] / 10``.

    The feature groups are masked by their validity masks. The whole vector is
    then clipped to [-10, 10] and NaNs are replaced with 0.

    Args:
        params: Environment parameters, forwarded to the feature builders.
        static_params: Static environment parameters. ``num_polygons`` and
            ``num_circles`` fix the one-hot width and the polygon mask size.

    Returns:
        ``render_symbolic(state)``, which returns a 1-D ``jnp`` array.
    """
    # One-hot width covers every shape slot: all polygons (including the
    # ones dropped below) plus all circles. Shape indices therefore refer to
    # the unfiltered shape list.
    num_shapes = static_params.num_polygons + static_params.num_circles

    # Polygons 1, 2 and 3 are excluded from the observation. The original name
    # calls them the walls and ceiling (presumably fixed level geometry;
    # confirm against the level layout). The mask depends only on static
    # params, so it is built once here rather than on every traced call.
    keep_polygon = np.ones(static_params.num_polygons, dtype=bool)
    keep_polygon[np.array([1, 2, 3])] = False

    def render_symbolic(state):
        """Return the clipped, NaN-free symbolic observation vector for ``state``."""
        polygon_features, polygon_mask = make_polygon_features(state, params, static_params)
        polygon_features = polygon_features[keep_polygon]
        polygon_mask = polygon_mask[keep_polygon]

        circle_features, circle_mask = make_circle_features(state, params, static_params)
        joint_features, joint_idxs, joint_mask = make_joint_features(state, params, static_params)
        thruster_features, thruster_idxs, thruster_mask = make_thruster_features(
            state, params, static_params
        )

        # The joint builder yields 2 * J rows; only the first J are used here.
        # Presumably the second half repeats each joint from the other shape's
        # side -- TODO confirm against make_joint_features.
        num_joints = joint_features.shape[0] // 2
        joint_features = jnp.concatenate(
            [
                joint_features[:num_joints],  # (J, K)
                jax.nn.one_hot(joint_idxs[:num_joints, 0], num_shapes),  # (J, N)
                jax.nn.one_hot(joint_idxs[:num_joints, 1], num_shapes),  # (J, N)
            ],
            axis=1,
        )

        # Append a one-hot of thruster_idxs to each thruster row (presumably the
        # shape the thruster is attached to; verify against the feature builder).
        thruster_features = jnp.concatenate(
            [thruster_features, jax.nn.one_hot(thruster_idxs, num_shapes)],
            axis=1,
        )

        obs = jnp.concatenate(
            [
                _masked_flatten(polygon_features, polygon_mask),
                _masked_flatten(circle_features, circle_mask),
                _masked_flatten(joint_features, joint_mask[:num_joints]),
                _masked_flatten(thruster_features, thruster_mask),
                jnp.array([state.gravity[1]]) / 10,
            ],
            axis=0,
        )
        # Bounds are passed positionally. The ``a_min``/``a_max`` keywords are
        # deprecated in recent JAX (replaced by ``min``/``max``), and the
        # positional form is accepted by both old and new releases.
        obs = jnp.clip(obs, -10.0, 10.0)
        # clip already maps +/-inf to +/-10 but leaves NaN as NaN; zero those out.
        return jnp.nan_to_num(obs)

    return render_symbolic