Spaces:
Runtime error
Runtime error
| from cmath import rect | |
| from functools import partial | |
| import jax | |
| import jax.numpy as jnp | |
| from flax import struct | |
| from jax2d.engine import get_pairwise_interaction_indices | |
| from kinetix.environment.env_state import EnvState | |
| from kinetix.render.renderer_symbolic_common import ( | |
| make_circle_features, | |
| make_joint_features, | |
| make_polygon_features, | |
| make_thruster_features, | |
| make_unified_shape_features, | |
| ) | |
@struct.dataclass
class EntityObservation:
    """Entity-based observation produced by ``render_entities``.

    Declared as a ``flax.struct.dataclass`` so that it gets a keyword
    constructor (it is instantiated with keyword arguments by
    ``render_entities``) and can be passed through JAX transformations as a
    pytree.
    """

    # Per-shape feature rows; ``render_entities`` appends one gravity column
    # to the features returned by the circle/polygon feature builders.
    circles: jnp.ndarray
    polygons: jnp.ndarray

    # Feature rows returned by the joint / thruster feature builders.
    joints: jnp.ndarray
    thrusters: jnp.ndarray

    # Masks returned next to the corresponding features
    # (presumably True where the entity is active -- TODO confirm).
    circle_mask: jnp.ndarray
    polygon_mask: jnp.ndarray
    joint_mask: jnp.ndarray
    thruster_mask: jnp.ndarray

    # Four stacked N x N masks over the shapes (polygons first, then circles),
    # in the order: fully connected, multi-hop, one-hop (joints),
    # collision manifold.
    attention_mask: jnp.ndarray
    # collision_mask: jnp.ndarray

    # Index arrays returned by the joint / thruster feature builders.
    joint_indexes: jnp.ndarray
    thruster_indexes: jnp.ndarray
def make_render_entities(params, static_params):
    """Create a renderer that converts an ``EnvState`` into an ``EntityObservation``.

    The pairwise interaction index tables are computed once here and closed
    over by the returned function.

    Args:
        params: Environment parameters, forwarded to the feature builders.
        static_params: Static environment parameters; ``num_polygons`` and
            ``num_circles`` are read from it.

    Returns:
        A function ``render_entities(state)`` returning an ``EntityObservation``.
    """
    _, _, _, circle_circle_pairs, circle_rect_pairs, rect_rect_pairs = get_pairwise_interaction_indices(static_params)

    # Shapes are laid out as polygons followed by circles, so circle indices
    # are shifted past the polygon block (column 0 of the circle/rect pairs is
    # presumably the circle index -- TODO confirm).
    polygon_count = static_params.num_polygons
    circle_rect_pairs = circle_rect_pairs.at[:, 0].add(polygon_count)
    circle_circle_pairs = circle_circle_pairs + polygon_count

    def render_entities(state: EnvState):
        """Build the entity observation for a single environment state."""
        # Sanitise every leaf of the state so NaN/inf values do not leak into
        # the features.
        state = jax.tree_util.tree_map(jnp.nan_to_num, state)

        joint_features, joint_indexes, joint_mask = make_joint_features(state, params, static_params)
        thruster_features, thruster_indexes, thruster_mask = make_thruster_features(state, params, static_params)
        polygon_features, polygon_mask = make_polygon_features(state, params, static_params)
        circle_features, circle_mask = make_circle_features(state, params, static_params)

        def _with_gravity(features):
            # Append the (scaled) vertical gravity component as one extra
            # feature column on every shape.
            gravity_column = jnp.zeros((features.shape[0], 1)) + state.gravity[1] / 10
            return jnp.concatenate([features, gravity_column], axis=-1)

        polygon_features = _with_gravity(polygon_features)
        circle_features = _with_gravity(circle_features)

        # Validity of every shape, polygons first and circles after them.
        shape_valid = jnp.concatenate([polygon_mask, circle_mask], axis=0)
        num_shapes = static_params.num_polygons + static_params.num_circles

        def _build_attention_masks(valid):
            """Stack the four pairwise boolean masks for the given validity vector."""
            n = valid.shape[0]
            # An entry survives only if both its row and its column are valid.
            both_valid = valid[:, None] & valid[None, :]

            # Identity (self-attention) plus a fully connected block over all shapes.
            fully_connected = jnp.eye(n, n, dtype=bool).at[:num_shapes, :num_shapes].set(True)
            fully_connected = fully_connected & both_valid

            # Pairs listed in the joint index table.
            jointed = jnp.zeros((n, n), dtype=bool)
            jointed = jointed.at[joint_indexes[:, 0], joint_indexes[:, 1]].set(True)
            # Invalid joints carry the indices (0, 0), so that entry is cleared.
            jointed = jointed.at[0, 0].set(False)
            jointed = jointed & both_valid

            # Pairs that are NOT flagged in the collision matrix.
            non_colliding = jnp.logical_not(state.collision_matrix) & both_valid

            colliding = state.collision_matrix & both_valid

            # Pairs with an active accumulated collision manifold. The updates
            # are applied in this order: rect/rect, circle/rect, circle/circle.
            rr_active = state.acc_rr_manifolds.active
            manifold_updates = (
                (rect_rect_pairs, jnp.logical_or(rr_active[..., 0], rr_active[..., 1])),
                (circle_rect_pairs, state.acc_cr_manifolds.active),
                (circle_circle_pairs, state.acc_cc_manifolds.active),
            )
            in_contact = jnp.zeros_like(colliding)
            for pairs, active in manifold_updates:
                in_contact = in_contact.at[pairs[:, 0], pairs[:, 1]].set(active)
            in_contact = in_contact & both_valid

            return jnp.stack([fully_connected, non_colliding, jointed, in_contact], axis=0)

        return EntityObservation(
            circles=circle_features,
            polygons=polygon_features,
            joints=joint_features,
            thrusters=thruster_features,
            circle_mask=circle_mask,
            polygon_mask=polygon_mask,
            joint_mask=joint_mask,
            thruster_mask=thruster_mask,
            attention_mask=_build_attention_masks(shape_valid),
            joint_indexes=joint_indexes,
            thruster_indexes=thruster_indexes,
        )

    return render_entities