# Source: remdm-craftax / src/planners/inference.py
# Author: MathisW78
# Commit 6140064: "Upload COMP0258 demo bundle (code + diffusion/PPO checkpoints + ablation assets)"
"""Evaluation: run a trained diffusion planner with MPC + historical inpainting."""
from __future__ import annotations
import time
from typing import Any
import jax
import jax.numpy as jnp
import numpy as np
import wandb
from craftax.craftax_env import make_craftax_env_from_name
from craftax.craftax.constants import Achievement as FullCraftaxAchievements
from craftax.craftax_classic.constants import Achievement as ClassicAchievements
from src.diffusion.sampling import sample_plan_inpainting
from .model import build_model, load_checkpoint, make_apply_fns
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def run_inference(config: dict[str, Any]) -> None:
    """Evaluate a trained diffusion planner in Craftax with receding-horizon MPC.

    At every environment step, a ``PLAN_HORIZON``-long action plan is sampled
    with ``sample_plan_inpainting``. The sampler receives the actions already
    executed in the current plan window (``history``, padded with the
    placeholder id ``num_actions``) and their count (``hist_len``). The plan
    entry at position ``hist_len`` is executed. When the window is full, or
    the episode ends, the history is cleared.

    Scoring is strict single-life: only each agent's first episode (up to and
    including its first ``done``) contributes reward, achievements and
    lifespan. Results are printed and optionally logged to Weights & Biases.

    Config keys read:
        Required: ``ENV_NAME``, ``PLAN_HORIZON``, ``SEED``, ``CHECKPOINT_PATH``.
        Optional (default): ``EVAL_NUM_ENVS`` (32), ``DIFFUSION_STEPS_EVAL`` (10),
        ``TEMPERATURE`` (0.5), ``TOP_P`` (0.95), ``EVAL_STEPS`` (10000),
        ``USE_WANDB`` (True), ``WANDB_PROJECT`` ("remdm-craftax").

    Side effects:
        Writes ``config["NUM_ACTIONS"]``. Prints a report to stdout. May create
        a W&B run.

    Raises:
        ValueError: if ``EVAL_STEPS`` or ``EVAL_NUM_ENVS`` is less than 1.
    """
    env_name = config["ENV_NAME"]
    env = make_craftax_env_from_name(env_name, auto_reset=True)
    env_params = env.default_params
    num_actions = env.action_space(env_params).n
    obs_dim = env.observation_space(env_params).shape[0]
    config["NUM_ACTIONS"] = num_actions

    num_envs = config.get("EVAL_NUM_ENVS", 32)
    plan_horizon = config["PLAN_HORIZON"]
    diffusion_steps = config.get("DIFFUSION_STEPS_EVAL", 10)
    temperature = config.get("TEMPERATURE", 0.5)
    top_p = config.get("TOP_P", 0.95)
    # float() first so values such as "1e4" coming from YAML/CLI are accepted.
    eval_steps = int(float(config.get("EVAL_STEPS", 10000)))
    if eval_steps < 1:
        raise ValueError(f"EVAL_STEPS must be >= 1, got {eval_steps}")
    if num_envs < 1:
        raise ValueError(f"EVAL_NUM_ENVS must be >= 1, got {num_envs}")

    model = build_model(config, num_actions)
    apply_eval, _ = make_apply_fns(model)
    rng = jax.random.PRNGKey(config["SEED"])
    rng, ckpt_rng = jax.random.split(rng)
    model_params = load_checkpoint(model, ckpt_rng, obs_dim, plan_horizon, config["CHECKPOINT_PATH"])
    env_indices = jnp.arange(num_envs)

    @jax.jit
    def mpc_step(carry, _step_idx):
        """One MPC step for all envs: replan, execute one action, update history."""
        obs, state, rng, history, hist_len = carry
        rng, plan_rng, env_rng = jax.random.split(rng, 3)

        # Start a fresh plan window once the previous one has been fully executed.
        seq_full = hist_len >= plan_horizon
        hist_len = jnp.where(seq_full, 0, hist_len)
        history = jnp.where(seq_full[:, None], num_actions, history)

        plan = sample_plan_inpainting(
            apply_eval, model_params, plan_rng, obs,
            history, hist_len, num_actions, plan_horizon,
            diffusion_steps, temperature, top_p,
        )
        # Execute the plan entry that follows the already-executed prefix.
        action = jnp.take_along_axis(plan, hist_len[:, None], axis=-1).squeeze(-1)
        history = history.at[env_indices, hist_len].set(action)
        hist_len = hist_len + 1

        obs_next, state_next, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            jax.random.split(env_rng, num_envs), state, action, env_params,
        )
        # The env auto-resets on done, so the action history must restart too.
        hist_len = jnp.where(done, 0, hist_len)
        history = jnp.where(done[:, None], num_actions, history)
        return (obs_next, state_next, rng, history, hist_len), (reward, done, state_next.achievements)

    print(f"\nRunning {num_envs} agents in {env_name} for {eval_steps} steps...")
    rng, env_rng = jax.random.split(rng)
    obs, state = jax.vmap(env.reset, in_axes=(0, None))(
        jax.random.split(env_rng, num_envs), env_params,
    )
    history = jnp.full((num_envs, plan_horizon), num_actions, dtype=jnp.int32)
    hist_len = jnp.zeros(num_envs, dtype=jnp.int32)

    t0 = time.perf_counter()
    _, outputs = jax.lax.scan(
        mpc_step, (obs, state, rng, history, hist_len), None, length=eval_steps,
    )
    # JAX dispatches asynchronously: wait for the rollout to actually finish so
    # the timing covers execution (it also includes tracing/compilation).
    rewards, dones, achievements = jax.block_until_ready(outputs)
    elapsed = time.perf_counter() - t0

    ep_rewards, ep_ach, ep_lengths = _first_episode_stats(
        np.asarray(rewards), np.asarray(dones), np.asarray(achievements),
    )
    # Exact integer counts. Rebuilding them from percentages truncates, e.g.
    # 29/100 -> 28.999... -> 28.
    counts = ep_ach.sum(axis=0).astype(int)
    pct = ep_ach.mean(axis=0) * 100.0

    ach_names = _achievement_names(env_name)
    unlocked_any = np.flatnonzero(counts)
    # Report achievements up to the highest-index one any agent unlocked
    # (or the first six when nothing was unlocked).
    top_idx = int(unlocked_any.max()) if unlocked_any.size else 5

    _print_report(elapsed, ep_rewards, counts, ach_names, top_idx, num_envs)

    if config.get("USE_WANDB", True):
        _log_to_wandb(
            config, temperature, top_p, ep_rewards, ep_ach, ep_lengths,
            pct, ach_names, top_idx,
        )


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _first_episode_stats(
    rewards: np.ndarray, dones: np.ndarray, achievements: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return per-env (reward, achievement flags, length) of the FIRST episode.

    Inputs are time-major as stacked by ``jax.lax.scan``: ``rewards`` and
    ``dones`` are indexed ``[step, env]``; ``achievements`` is indexed
    ``[step, env, achievement]``. An env that never reports ``done`` is
    credited with the whole rollout.
    """
    num_steps, num_envs = dones.shape
    ep_rewards = np.zeros(num_envs)
    ep_ach = np.zeros((num_envs, achievements.shape[2]))
    ep_lengths = np.zeros(num_envs, dtype=int)
    for i in range(num_envs):
        deaths = np.flatnonzero(dones[:, i])
        end = deaths[0] if deaths.size else num_steps - 1
        ep_rewards[i] = rewards[: end + 1, i].sum()
        # NOTE(review): with auto_reset=True the state returned on the done step
        # is presumably already the reset state. If so, an achievement unlocked
        # on the final step is not counted. Confirm against env.step.
        ep_ach[i] = achievements[: end + 1, i].max(axis=0)
        ep_lengths[i] = end + 1
    return ep_rewards, ep_ach, ep_lengths


def _achievement_names(env_name: str) -> list[tuple[str, str]]:
    """Return (display name, W&B key) pairs, in enum order, for the env's achievements."""
    ach_cls = ClassicAchievements if "Classic" in env_name else FullCraftaxAchievements
    return [(a.name.replace("_", " ").title(), a.name.lower()) for a in ach_cls]


def _print_report(
    elapsed: float,
    ep_rewards: np.ndarray,
    counts: np.ndarray,
    ach_names: list[tuple[str, str]],
    top_idx: int,
    num_envs: int,
) -> None:
    """Print the score summary and per-achievement unlock counts to stdout."""
    print(f"\n{'=' * 50}")
    print(f"EVALUATION COMPLETE ({elapsed:.1f}s)")
    print(f"{'=' * 50}")
    print(f"Average Score: {ep_rewards.mean():.1f} | Best: {ep_rewards.max():.1f}")
    for i in range(top_idx + 1):
        name, _ = ach_names[i]
        count = int(counts[i])
        icon = "+" if count > 0 else "-"
        print(f" [{icon}] {name}: {count}/{num_envs}")
    print(f"{'=' * 50}")


def _log_to_wandb(
    config: dict[str, Any],
    temperature: float,
    top_p: float,
    ep_rewards: np.ndarray,
    ep_ach: np.ndarray,
    ep_lengths: np.ndarray,
    pct: np.ndarray,
    ach_names: list[tuple[str, str]],
    top_idx: int,
) -> None:
    """Log summary metrics and a per-agent results table to a new W&B run."""
    wandb.init(
        project=config.get("WANDB_PROJECT", "remdm-craftax"),
        name=f"Eval-T{temperature}-P{top_p}",
        config=config, job_type="evaluation",
    )
    try:
        summary = {"eval/average_score": float(ep_rewards.mean())}
        for i in range(top_idx + 1):
            summary[f"eval/achievements/{ach_names[i][1]}"] = float(pct[i])
        wandb.log(summary)

        table = wandb.Table(columns=["Agent", "Score", "Achievements", "Lifespan"])
        unlocked = ep_ach.sum(axis=-1)
        for agent, (score, n_unlocked, length) in enumerate(
            zip(ep_rewards, unlocked, ep_lengths), start=1,
        ):
            table.add_data(f"Agent {agent}", float(score), int(n_unlocked), int(length))
        wandb.log({"Individual Results": table})
    finally:
        # Always close the run, even if logging fails part-way.
        wandb.finish()