"""Evaluation: run a trained diffusion planner with MPC + historical inpainting."""
from __future__ import annotations
import time
from typing import Any
import jax
import jax.numpy as jnp
import numpy as np
import wandb
from craftax.craftax_env import make_craftax_env_from_name
from craftax.craftax.constants import Achievement as FullCraftaxAchievements
from craftax.craftax_classic.constants import Achievement as ClassicAchievements
from src.diffusion.sampling import sample_plan_inpainting
from .model import build_model, load_checkpoint, make_apply_fns
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def run_inference(config: dict[str, Any]) -> None:
    """Roll out a trained diffusion planner and report first-episode results.

    A batch of environments is stepped for a fixed number of steps inside a
    single ``jax.lax.scan``. At every step a plan is sampled with
    ``sample_plan_inpainting`` conditioned on the actions already executed
    from the current plan (the "history"), and the plan entry at the current
    history length is executed. Only the first episode of every environment
    is scored, even though the environments are created with
    ``auto_reset=True`` and keep running after the first ``done``.

    Args:
        config: Run configuration. Required keys: ``ENV_NAME``,
            ``PLAN_HORIZON``, ``SEED``, ``CHECKPOINT_PATH``. Optional keys
            (default): ``EVAL_NUM_ENVS`` (32), ``DIFFUSION_STEPS_EVAL`` (10),
            ``TEMPERATURE`` (0.5), ``TOP_P`` (0.95), ``EVAL_STEPS`` (10000),
            ``USE_WANDB`` (True), ``WANDB_PROJECT`` ("remdm-craftax").
            ``NUM_ACTIONS`` is written into the dict in place.

    Raises:
        KeyError: If a required config key is missing.
        ValueError: If ``EVAL_NUM_ENVS`` or ``EVAL_STEPS`` is not positive.
    """
    env_name = config["ENV_NAME"]
    env = make_craftax_env_from_name(env_name, auto_reset=True)
    env_params = env.default_params
    num_actions = env.action_space(env_params).n
    obs_dim = env.observation_space(env_params).shape[0]
    config["NUM_ACTIONS"] = num_actions

    num_envs = config.get("EVAL_NUM_ENVS", 32)
    plan_horizon = config["PLAN_HORIZON"]
    diffusion_steps = config.get("DIFFUSION_STEPS_EVAL", 10)
    temperature = config.get("TEMPERATURE", 0.5)
    top_p = config.get("TOP_P", 0.95)
    # int(float(...)) also accepts values written as 1e4 or "1e4".
    eval_steps = int(float(config.get("EVAL_STEPS", 10000)))

    # Fail early with a clear message instead of an opaque zero-size
    # reduction error during result extraction.
    if num_envs <= 0:
        raise ValueError(f"EVAL_NUM_ENVS must be positive, got {num_envs}")
    if eval_steps <= 0:
        raise ValueError(f"EVAL_STEPS must be positive, got {eval_steps}")

    model = build_model(config, num_actions)
    apply_eval, _ = make_apply_fns(model)

    rng = jax.random.PRNGKey(config["SEED"])
    rng, ckpt_rng = jax.random.split(rng)
    model_params = load_checkpoint(
        model, ckpt_rng, obs_dim, plan_horizon, config["CHECKPOINT_PATH"],
    )

    env_indices = jnp.arange(num_envs)

    @jax.jit
    def mpc_step(carry, _step_idx):
        """Sample a plan, execute one action per environment, update history."""
        obs, state, rng, history, hist_len = carry
        rng, plan_rng, env_rng = jax.random.split(rng, 3)

        # Reset history when plan is exhausted. `num_actions` is the fill
        # value for unused history slots (presumably the mask token expected
        # by sample_plan_inpainting -- TODO confirm).
        seq_full = hist_len >= plan_horizon
        hist_len = jnp.where(seq_full, 0, hist_len)
        history = jnp.where(seq_full[:, None], num_actions, history)

        plan = sample_plan_inpainting(
            apply_eval, model_params, plan_rng, obs,
            history, hist_len, num_actions, plan_horizon,
            diffusion_steps, temperature, top_p,
        )

        # Execute the plan entry at the current history length and record it.
        action = jnp.take_along_axis(plan, hist_len[:, None], axis=-1).squeeze(-1)
        history = history.at[env_indices, hist_len].set(action)
        hist_len = hist_len + 1

        obs_next, state_next, reward, done, _info = jax.vmap(
            env.step, in_axes=(0, 0, 0, None),
        )(jax.random.split(env_rng, num_envs), state, action, env_params)

        # A finished episode starts over with an empty history.
        hist_len = jnp.where(done, 0, hist_len)
        history = jnp.where(done[:, None], num_actions, history)

        return (
            (obs_next, state_next, rng, history, hist_len),
            (reward, done, state_next.achievements),
        )

    print(f"\nRunning {num_envs} agents in {env_name} for {eval_steps} steps...")

    rng, env_rng = jax.random.split(rng)
    obs, state = jax.vmap(env.reset, in_axes=(0, None))(
        jax.random.split(env_rng, num_envs), env_params,
    )
    history = jnp.full((num_envs, plan_horizon), num_actions, dtype=jnp.int32)
    hist_len = jnp.zeros(num_envs, dtype=jnp.int32)

    t0 = time.perf_counter()
    _, (rewards, dones, achievements) = jax.lax.scan(
        mpc_step, (obs, state, rng, history, hist_len), jnp.arange(eval_steps),
    )
    # JAX dispatches work asynchronously; wait for the results so that the
    # reported time covers the rollout itself and not just its dispatch.
    rewards, dones, achievements = jax.block_until_ready(
        (rewards, dones, achievements)
    )
    elapsed = time.perf_counter() - t0

    # First-episode extraction (strict single-life evaluation)
    ep_rewards, ep_ach, ep_lengths = _first_episode_stats(
        np.asarray(rewards), np.asarray(dones), np.asarray(achievements),
    )
    pct = ep_ach.mean(axis=0) * 100.0
    # Count unlocks directly. Rebuilding the count from the percentage
    # (pct / 100 * num_envs) and truncating with int() can undercount by one
    # because of floating-point rounding.
    unlock_counts = np.rint(ep_ach.sum(axis=0)).astype(int)

    # Report
    print(f"\n{'=' * 50}")
    print(f"EVALUATION COMPLETE ({elapsed:.1f}s)")
    print(f"{'=' * 50}")
    print(f"Average Score: {ep_rewards.mean():.1f} | Best: {ep_rewards.max():.1f}")

    ach_cls = ClassicAchievements if "Classic" in env_name else FullCraftaxAchievements
    ach_names = [(a.name.replace("_", " ").title(), a.name.lower()) for a in ach_cls]

    # Only report indices that exist both in the enum and in the recorded
    # achievement vector, so a length mismatch cannot raise IndexError.
    num_reportable = min(len(ach_names), len(pct))
    valid = [i for i in range(num_reportable) if pct[i] > 0]
    # Report up to the highest unlocked achievement; when nothing was
    # unlocked, fall back to listing the first six entries.
    top_idx = max(valid) if valid else min(5, num_reportable - 1)

    for i in range(top_idx + 1):
        name, _ = ach_names[i]
        count = int(unlock_counts[i])
        icon = "+" if count > 0 else "-"
        print(f" [{icon}] {name}: {count}/{num_envs}")
    print(f"{'=' * 50}")

    if config.get("USE_WANDB", True):
        wandb.init(
            project=config.get("WANDB_PROJECT", "remdm-craftax"),
            name=f"Eval-T{temperature}-P{top_p}",
            config=config, job_type="evaluation",
        )
        summary = {"eval/average_score": float(ep_rewards.mean())}
        for i in range(top_idx + 1):
            summary[f"eval/achievements/{ach_names[i][1]}"] = float(pct[i])
        wandb.log(summary)

        table = wandb.Table(columns=["Agent", "Score", "Achievements", "Lifespan"])
        unlocked = ep_ach.sum(axis=-1)
        for i in range(num_envs):
            table.add_data(
                f"Agent {i + 1}",
                float(ep_rewards[i]),
                int(unlocked[i]),
                int(ep_lengths[i]),
            )
        wandb.log({"Individual Results": table})
        wandb.finish()


def _first_episode_stats(rewards, dones, achievements):
    """Reduce a stacked rollout to per-environment first-episode statistics.

    Args:
        rewards: Array indexed as ``[step, env]``.
        dones: Array indexed as ``[step, env]``; truthy where an episode ended.
        achievements: Array indexed as ``[step, env, achievement]``.

    Returns:
        A tuple ``(ep_rewards, ep_ach, ep_lengths)``: the summed reward, the
        per-achievement maximum, and the number of steps, each taken over the
        steps up to and including the first ``done`` of every environment
        (or over the whole rollout when the environment never finished).
    """
    num_steps, num_envs = dones.shape[0], dones.shape[1]
    ep_rewards = np.zeros(num_envs)
    ep_ach = np.zeros((num_envs, achievements.shape[2]))
    ep_lengths = np.zeros(num_envs, dtype=int)

    for env_idx in range(num_envs):
        done_steps = np.flatnonzero(dones[:, env_idx])
        # Last step of the first episode (inclusive).
        last = done_steps[0] if done_steps.size > 0 else num_steps - 1
        ep_rewards[env_idx] = rewards[: last + 1, env_idx].sum()
        # NOTE(review): the recorded achievements come from the post-step
        # state; with auto_reset=True the entry at the terminal step may
        # already belong to the reset state -- confirm against the env wrapper.
        ep_ach[env_idx] = achievements[: last + 1, env_idx].max(axis=0)
        ep_lengths[env_idx] = last + 1

    return ep_rewards, ep_ach, ep_lengths
|