File size: 5,504 Bytes
6140064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Evaluation: run a trained diffusion planner with MPC + historical inpainting."""

from __future__ import annotations

import time
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
import wandb
from craftax.craftax_env import make_craftax_env_from_name
from craftax.craftax.constants import Achievement as FullCraftaxAchievements
from craftax.craftax_classic.constants import Achievement as ClassicAchievements

from src.diffusion.sampling import sample_plan_inpainting
from .model import build_model, load_checkpoint, make_apply_fns


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def run_inference(config: dict[str, Any]) -> None:
    """Roll out a trained diffusion planner in Craftax and report first-episode results.

    A batch of environments is stepped for a fixed number of steps. At every
    step a full action plan is re-sampled with ``sample_plan_inpainting``,
    conditioned on the actions already executed in the current plan window,
    and only the action at the current window position is executed.

    Config keys read:
        ENV_NAME, PLAN_HORIZON, SEED, CHECKPOINT_PATH (required);
        EVAL_NUM_ENVS (default 32), DIFFUSION_STEPS_EVAL (default 10),
        TEMPERATURE (default 0.5), TOP_P (default 0.95),
        EVAL_STEPS (default 10000), USE_WANDB (default True),
        WANDB_PROJECT (default "remdm-craftax").

    Side effects:
        Writes ``config["NUM_ACTIONS"]`` (before the model is built), prints a
        report to stdout and, when ``USE_WANDB`` is truthy, logs a summary and
        a per-agent table to Weights & Biases.

    Args:
        config: Evaluation configuration; mutated in place as described above.
    """
    env_name = config["ENV_NAME"]
    env = make_craftax_env_from_name(env_name, auto_reset=True)
    env_params = env.default_params
    num_actions = env.action_space(env_params).n
    obs_dim = env.observation_space(env_params).shape[0]
    config["NUM_ACTIONS"] = num_actions

    num_envs = config.get("EVAL_NUM_ENVS", 32)
    plan_horizon = config["PLAN_HORIZON"]
    diffusion_steps = config.get("DIFFUSION_STEPS_EVAL", 10)
    temperature = config.get("TEMPERATURE", 0.5)
    top_p = config.get("TOP_P", 0.95)
    # int(float(...)) also accepts values such as "1e4" or 10000.0.
    eval_steps = int(float(config.get("EVAL_STEPS", 10000)))

    model = build_model(config, num_actions)
    apply_eval, _ = make_apply_fns(model)

    rng = jax.random.PRNGKey(config["SEED"])
    rng, ckpt_rng = jax.random.split(rng)
    model_params = load_checkpoint(model, ckpt_rng, obs_dim, plan_horizon, config["CHECKPOINT_PATH"])
    env_indices = jnp.arange(num_envs)

    @jax.jit
    def mpc_step(carry, _step_idx):
        """Advance every environment by one step; used as the ``lax.scan`` body."""
        obs, state, rng, history, hist_len = carry
        rng, plan_rng, env_rng = jax.random.split(rng, 3)

        # Reset history when plan is exhausted. ``num_actions`` is the fill
        # value for "no action recorded" slots (one past the last action id;
        # presumably the mask token expected by the sampler - TODO confirm).
        seq_full = hist_len >= plan_horizon
        hist_len = jnp.where(seq_full, 0, hist_len)
        history = jnp.where(seq_full[:, None], num_actions, history)

        plan = sample_plan_inpainting(
            apply_eval, model_params, plan_rng, obs,
            history, hist_len, num_actions, plan_horizon,
            diffusion_steps, temperature, top_p,
        )

        # Execute only the plan entry at the current window position and
        # record it so the next re-plan is conditioned on it.
        action = jnp.take_along_axis(plan, hist_len[:, None], axis=-1).squeeze(-1)
        history = history.at[env_indices, hist_len].set(action)
        hist_len = hist_len + 1

        obs_next, state_next, reward, done, _ = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            jax.random.split(env_rng, num_envs), state, action, env_params,
        )

        # A finished episode starts a fresh plan window.
        hist_len = jnp.where(done, 0, hist_len)
        history = jnp.where(done[:, None], num_actions, history)
        # NOTE(review): the env is created with auto_reset=True, so on a
        # ``done`` step ``state_next`` may already be the reset state and its
        # achievements may not include those of the final step - confirm.
        return (obs_next, state_next, rng, history, hist_len), (reward, done, state_next.achievements)

    print(f"\nRunning {num_envs} agents in {env_name} for {eval_steps} steps...")

    rng, env_rng = jax.random.split(rng)
    obs, state = jax.vmap(env.reset, in_axes=(0, None))(
        jax.random.split(env_rng, num_envs), env_params,
    )
    history = jnp.full((num_envs, plan_horizon), num_actions, dtype=jnp.int32)
    hist_len = jnp.zeros(num_envs, dtype=jnp.int32)

    t0 = time.time()
    _, (rewards, dones, achievements) = jax.lax.scan(
        mpc_step, (obs, state, rng, history, hist_len), jnp.arange(eval_steps),
    )
    # JAX dispatches work asynchronously: without blocking here the clock
    # would stop before the rollout has actually finished on the device.
    # The measured time still includes tracing/compilation of the scan.
    rewards, dones, achievements = jax.block_until_ready((rewards, dones, achievements))
    elapsed = time.time() - t0

    # First-episode extraction (strict single-life evaluation)
    ep_rewards, ep_ach, ep_lengths = _first_episode_stats(
        np.array(rewards), np.array(dones), np.array(achievements),
    )

    pct = ep_ach.mean(axis=0) * 100.0
    # Exact number of agents per achievement. Deriving it back from the
    # float percentage and truncating can under-count by one.
    counts = ep_ach.sum(axis=0).astype(int)

    # Report
    print(f"\n{'=' * 50}")
    print(f"EVALUATION COMPLETE ({elapsed:.1f}s)")
    print(f"{'=' * 50}")
    print(f"Average Score: {ep_rewards.mean():.1f}  |  Best: {ep_rewards.max():.1f}")

    ach_cls = ClassicAchievements if "Classic" in env_name else FullCraftaxAchievements
    ach_names = [(a.name.replace("_", " ").title(), a.name.lower()) for a in ach_cls]
    # Report achievements up to the highest-indexed one that any agent
    # unlocked; fall back to the first six when nothing was unlocked.
    valid = [i for i, c in enumerate(counts) if c > 0]
    top_idx = max(valid) if valid else 5

    for i in range(top_idx + 1):
        name, _ = ach_names[i]
        count = int(counts[i])
        icon = "+" if count > 0 else "-"
        print(f"  [{icon}] {name}: {count}/{num_envs}")
    print(f"{'=' * 50}")

    if config.get("USE_WANDB", True):
        wandb.init(
            project=config.get("WANDB_PROJECT", "remdm-craftax"),
            name=f"Eval-T{temperature}-P{top_p}",
            config=config, job_type="evaluation",
        )
        summary = {"eval/average_score": float(ep_rewards.mean())}
        for i in range(top_idx + 1):
            summary[f"eval/achievements/{ach_names[i][1]}"] = float(pct[i])
        wandb.log(summary)

        table = wandb.Table(columns=["Agent", "Score", "Achievements", "Lifespan"])
        unlocked = ep_ach.sum(axis=-1)
        for i in range(num_envs):
            table.add_data(f"Agent {i + 1}", float(ep_rewards[i]), int(unlocked[i]), int(ep_lengths[i]))
        wandb.log({"Individual Results": table})
        wandb.finish()


def _first_episode_stats(
    rewards: np.ndarray, dones: np.ndarray, achievements: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Reduce per-step rollout data to statistics of each agent's first episode.

    Args:
        rewards: Per-step rewards, indexed ``[step, env]``.
        dones: Per-step termination flags, indexed ``[step, env]``.
        achievements: Per-step achievement flags, indexed
            ``[step, env, achievement]``.

    Returns:
        ``(ep_rewards, ep_ach, ep_lengths)``: the summed reward, the
        per-achievement 0/1 indicator (float) and the number of steps of the
        first episode of every env. An env that never terminates is scored
        over the whole rollout.
    """
    num_steps, num_envs = dones.shape
    ep_rewards = np.zeros(num_envs)
    ep_ach = np.zeros((num_envs, achievements.shape[2]))
    ep_lengths = np.zeros(num_envs, dtype=int)

    for i in range(num_envs):
        death = np.flatnonzero(dones[:, i])
        # Index of the last step that belongs to the first episode.
        end = death[0] if death.size else num_steps - 1
        ep_rewards[i] = rewards[:end + 1, i].sum()
        ep_ach[i] = achievements[:end + 1, i].max(axis=0)
        ep_lengths[i] = end + 1

    return ep_rewards, ep_ach, ep_lengths