"""Craftax environment construction and trajectory data structures.""" from __future__ import annotations import jax.numpy as jnp from craftax.craftax_env import make_craftax_env_from_name from typing import NamedTuple from Craftax_Baselines.wrappers import ( AutoResetEnvWrapper, BatchEnvWrapper, LogWrapper, OptimisticResetVecEnvWrapper, ) class Transition(NamedTuple): done: jnp.ndarray action: jnp.ndarray reward: jnp.ndarray obs: jnp.ndarray info: dict def make_env(config: dict, num_envs: int): """Build a wrapped Craftax environment. Args: config: Upper-cased config dict with ``ENV_NAME``, ``USE_OPTIMISTIC_RESETS``, ``OPTIMISTIC_RESET_RATIO``. num_envs: Number of parallel environments. Returns: Tuple of ``(env, env_params)``. """ env = make_craftax_env_from_name(config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]) env_params = env.default_params env = LogWrapper(env) if config["USE_OPTIMISTIC_RESETS"]: env = OptimisticResetVecEnvWrapper( env, num_envs=num_envs, reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], num_envs), ) else: env = AutoResetEnvWrapper(env) env = BatchEnvWrapper(env, num_envs=num_envs) return env, env_params