| """Evaluation: run a trained diffusion planner with MPC + historical inpainting.""" |
|
|
| from __future__ import annotations |
|
|
| import time |
| from typing import Any |
|
|
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| import wandb |
| from craftax.craftax_env import make_craftax_env_from_name |
| from craftax.craftax.constants import Achievement as FullCraftaxAchievements |
| from craftax.craftax_classic.constants import Achievement as ClassicAchievements |
|
|
| from src.diffusion.sampling import sample_plan_inpainting |
| from .model import build_model, load_checkpoint, make_apply_fns |
|
|
|
|
| |
| |
| |
|
|
def _first_episode_stats(
    rewards: np.ndarray,
    dones: np.ndarray,
    achievements: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Summarise only the first episode of every environment in a rollout.

    Args:
        rewards: Per-step rewards, indexed ``[step, env]``.
        dones: Per-step termination flags, indexed ``[step, env]``.
        achievements: Per-step achievement flags, indexed
            ``[step, env, achievement]``.

    Returns:
        A tuple ``(ep_rewards, ep_ach, ep_lengths)`` holding, per environment,
        the summed reward, the element-wise maximum of the achievement flags,
        and the number of steps, all measured up to and including the first
        step flagged as done (or over the whole rollout if none is).
    """
    num_steps, num_envs = dones.shape[:2]
    ep_rewards = np.zeros(num_envs)
    ep_ach = np.zeros((num_envs, achievements.shape[2]))
    ep_lengths = np.zeros(num_envs, dtype=int)

    for env_idx in range(num_envs):
        done_steps = np.flatnonzero(dones[:, env_idx])
        # Without a termination the whole rollout counts as one episode.
        last_step = done_steps[0] if done_steps.size else num_steps - 1
        stop = last_step + 1
        ep_rewards[env_idx] = rewards[:stop, env_idx].sum()
        ep_ach[env_idx] = achievements[:stop, env_idx].max(axis=0)
        ep_lengths[env_idx] = stop

    return ep_rewards, ep_ach, ep_lengths


def run_inference(config: dict[str, Any]) -> None:
    """Roll out a trained diffusion planner and report first-episode results.

    Loads the checkpoint, steps ``EVAL_NUM_ENVS`` environments in parallel for
    ``EVAL_STEPS`` steps inside a single ``jax.lax.scan``, scores the first
    episode of each environment, prints a summary and, unless disabled, logs
    the results to Weights & Biases.

    Args:
        config: Evaluation settings. Required keys: ``ENV_NAME``,
            ``PLAN_HORIZON``, ``SEED``, ``CHECKPOINT_PATH``. Optional keys
            (default): ``EVAL_NUM_ENVS`` (32), ``DIFFUSION_STEPS_EVAL`` (10),
            ``TEMPERATURE`` (0.5), ``TOP_P`` (0.95), ``EVAL_STEPS`` (10000),
            ``USE_WANDB`` (True), ``WANDB_PROJECT`` ("remdm-craftax").
            ``config["NUM_ACTIONS"]`` is overwritten in place.

    Raises:
        KeyError: If a required key is missing from ``config``.
        ValueError: If ``EVAL_STEPS`` is smaller than 1.
    """
    env_name = config["ENV_NAME"]
    env = make_craftax_env_from_name(env_name, auto_reset=True)
    env_params = env.default_params
    num_actions = env.action_space(env_params).n
    obs_dim = env.observation_space(env_params).shape[0]
    # Written back into the caller's dict; the same dict is handed to build_model.
    config["NUM_ACTIONS"] = num_actions

    num_envs = config.get("EVAL_NUM_ENVS", 32)
    plan_horizon = config["PLAN_HORIZON"]
    diffusion_steps = config.get("DIFFUSION_STEPS_EVAL", 10)
    temperature = config.get("TEMPERATURE", 0.5)
    top_p = config.get("TOP_P", 0.95)
    # int(float(...)) also accepts scientific notation such as 1e4 or "1e4".
    eval_steps = int(float(config.get("EVAL_STEPS", 10000)))
    if eval_steps < 1:
        # Fail before loading the checkpoint instead of on an empty reduction later.
        raise ValueError(f"EVAL_STEPS must be at least 1, got {eval_steps}")

    model = build_model(config, num_actions)
    apply_eval, _ = make_apply_fns(model)

    rng = jax.random.PRNGKey(config["SEED"])
    rng, ckpt_rng = jax.random.split(rng)
    model_params = load_checkpoint(model, ckpt_rng, obs_dim, plan_horizon, config["CHECKPOINT_PATH"])
    env_indices = jnp.arange(num_envs)

    @jax.jit
    def mpc_step(carry, _step_idx):
        """Plan, execute one action per environment, and update the history."""
        obs, state, rng, history, hist_len = carry
        rng, plan_rng, env_rng = jax.random.split(rng, 3)

        # A full history window is cleared before planning; empty slots hold
        # `num_actions` (presumably the planner's mask token - confirm against
        # sample_plan_inpainting).
        seq_full = hist_len >= plan_horizon
        hist_len = jnp.where(seq_full, 0, hist_len)
        history = jnp.where(seq_full[:, None], num_actions, history)

        plan = sample_plan_inpainting(
            apply_eval, model_params, plan_rng, obs,
            history, hist_len, num_actions, plan_horizon,
            diffusion_steps, temperature, top_p,
        )

        # Execute the plan entry that follows the recorded history, then record it.
        action = jnp.take_along_axis(plan, hist_len[:, None], axis=-1).squeeze(-1)
        history = history.at[env_indices, hist_len].set(action)
        hist_len = hist_len + 1

        obs_next, state_next, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            jax.random.split(env_rng, num_envs), state, action, env_params,
        )

        # A finished episode starts over with an empty history.
        hist_len = jnp.where(done, 0, hist_len)
        history = jnp.where(done[:, None], num_actions, history)
        # NOTE(review): achievements are read from the post-step state; with
        # auto_reset=True the terminal step may already hold the reset state,
        # which would drop an achievement unlocked on the final step - confirm.
        return (obs_next, state_next, rng, history, hist_len), (reward, done, state_next.achievements)

    print(f"\nRunning {num_envs} agents in {env_name} for {eval_steps} steps...")

    rng, env_rng = jax.random.split(rng)
    obs, state = jax.vmap(env.reset, in_axes=(0, None))(
        jax.random.split(env_rng, num_envs), env_params,
    )
    history = jnp.full((num_envs, plan_horizon), num_actions, dtype=jnp.int32)
    hist_len = jnp.zeros(num_envs, dtype=jnp.int32)

    t0 = time.perf_counter()
    _, (rewards, dones, achievements) = jax.lax.scan(
        mpc_step, (obs, state, rng, history, hist_len), jnp.arange(eval_steps),
    )
    # JAX dispatches asynchronously: wait for the results so that the reported
    # time covers the rollout itself (it also includes tracing/compilation).
    rewards, dones, achievements = jax.block_until_ready((rewards, dones, achievements))
    elapsed = time.perf_counter() - t0

    ep_rewards, ep_ach, ep_lengths = _first_episode_stats(
        np.array(rewards), np.array(dones), np.array(achievements),
    )
    pct = ep_ach.mean(axis=0) * 100.0
    # Count directly instead of reversing the percentage, which can truncate
    # (e.g. 2.9999 -> 2) after the float round trip.
    counts = np.rint(ep_ach.sum(axis=0)).astype(int)

    print(f"\n{'=' * 50}")
    print(f"EVALUATION COMPLETE ({elapsed:.1f}s)")
    print(f"{'=' * 50}")
    print(f"Average Score: {ep_rewards.mean():.1f} | Best: {ep_rewards.max():.1f}")

    ach_cls = ClassicAchievements if "Classic" in env_name else FullCraftaxAchievements
    ach_names = [(a.name.replace("_", " ").title(), a.name.lower()) for a in ach_cls]
    valid = [i for i, p in enumerate(pct) if p > 0]
    # Report everything up to the highest unlocked achievement, or the first
    # six entries when nothing was unlocked.
    top_idx = max(valid) if valid else 5
    # Never index past the names or the statistics that actually exist.
    shown = min(top_idx + 1, len(ach_names), len(pct))

    for (name, _), count in zip(ach_names[:shown], counts[:shown]):
        count = int(count)
        icon = "+" if count > 0 else "-"
        print(f" [{icon}] {name}: {count}/{num_envs}")
    print(f"{'=' * 50}")

    if config.get("USE_WANDB", True):
        wandb.init(
            project=config.get("WANDB_PROJECT", "remdm-craftax"),
            name=f"Eval-T{temperature}-P{top_p}",
            config=config, job_type="evaluation",
        )
        try:
            summary = {"eval/average_score": float(ep_rewards.mean())}
            for (_, key), value in zip(ach_names[:shown], pct[:shown]):
                summary[f"eval/achievements/{key}"] = float(value)
            wandb.log(summary)

            table = wandb.Table(columns=["Agent", "Score", "Achievements", "Lifespan"])
            unlocked = ep_ach.sum(axis=-1)
            for i in range(num_envs):
                table.add_data(f"Agent {i + 1}", float(ep_rewards[i]), int(unlocked[i]), int(ep_lengths[i]))
            wandb.log({"Individual Results": table})
        finally:
            # Close the run even if logging fails part-way through.
            wandb.finish()
|
|