"""Baseline evaluation harness for MiniGridEnv."""
from __future__ import annotations

import inspect
import math
from collections import Counter
from statistics import mean
from typing import Any

try:
    from ..env.config import EnvConfig
    from ..env.minigrid_env import MiniGridEnvironment
except ImportError:
    from env.config import EnvConfig
    from env.minigrid_env import MiniGridEnvironment


def _sem(values: list[float]) -> float:
    """Standard error of the mean; 0.0 for fewer than two samples."""
    if len(values) < 2:
        return 0.0
    avg = mean(values)
    variance = sum((value - avg) ** 2 for value in values) / (len(values) - 1)
    return math.sqrt(variance / len(values))
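
# Worked example (hypothetical values): _sem([1.0, 2.0, 3.0]) returns
# sqrt(1.0 / 3) ~= 0.577, since the sample variance of those points is 1.0.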


def _resolve_episode_baseline(
    baseline: Any, env: MiniGridEnvironment
) -> Any:
    """Accept either a ready agent or a factory (env-aware or zero-arg)."""
    if hasattr(baseline, "select_action"):
        return baseline
    if callable(baseline):
        try:
            return baseline(env._gym_env)  # type: ignore[attr-defined]
        except TypeError:
            return baseline()
    raise TypeError("baseline must be an agent object or a callable factory")
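
# Both invocation styles are accepted (agent names here are hypothetical):
#   evaluate_baseline(MyAgent(), "some-level")                         # ready agent
#   evaluate_baseline(lambda gym_env: MyAgent(gym_env), "some-level")  # factory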


def _select_action(agent: Any, obs: Any, raw_obs: dict | None):
    """Pass raw_obs to the agent only if its select_action accepts it."""
    select_action = agent.select_action
    sig = inspect.signature(select_action)
    if len(sig.parameters) >= 2:
        return select_action(obs, raw_obs)
    return select_action(obs)
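
# Minimal compliant agent sketch (assumed interface; the returned action must
# be whatever MiniGridEnvironment.step expects, shown as a string here):
#
#     class ForwardBaseline:
#         def select_action(self, obs, raw_obs=None):
#             return "forward"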


def evaluate_baseline(
    baseline: Any,
    level_name: str,
    n_episodes: int = 100,
    seed: int = 42,
) -> dict[str, Any]:
    """Evaluate one baseline on one level and return aggregate metrics."""
    completed_flags: list[float] = []
    completed_steps: list[float] = []
    rewards: list[float] = []
    efficiencies: list[float] = []
    action_counter: Counter[str] = Counter()
    total_valid = 0
    total_invalid = 0
    for episode_offset in range(n_episodes):
        env = MiniGridEnvironment(
            config=EnvConfig(level_name=level_name, seed=seed + episode_offset)
        )
        obs = env.reset(seed=seed + episode_offset)
        agent = _resolve_episode_baseline(baseline, env)
        while not obs.done:
            raw_obs = env._last_obs  # type: ignore[attr-defined]
            action = _select_action(agent, obs, raw_obs)
            obs = env.step(action)
        state = env.state
        completed = 1.0 if state.completed else 0.0
        completed_flags.append(completed)
        rewards.append(float(state.total_reward))
        total_valid += int(state.valid_actions)
        total_invalid += int(state.invalid_actions)
        action_counter.update(state.action_distribution)
        if state.completed:
            completed_steps.append(float(state.steps_taken))
        if state.efficiency_ratio is not None:
            efficiencies.append(float(state.efficiency_ratio))
    total_actions = total_valid + total_invalid
    return {
        "level": level_name,
        "episodes": n_episodes,
        "completion_rate": mean(completed_flags) if completed_flags else 0.0,
        "completion_rate_sem": _sem(completed_flags),
        "mean_steps_completed": mean(completed_steps) if completed_steps else 0.0,
        "mean_steps_completed_sem": _sem(completed_steps),
        "mean_reward": mean(rewards) if rewards else 0.0,
        "mean_reward_sem": _sem(rewards),
        "efficiency": mean(efficiencies) if efficiencies else 0.0,
        "efficiency_sem": _sem(efficiencies),
        "action_parse_rate": (
            float(total_valid) / float(total_actions) if total_actions else 0.0
        ),
        "action_distribution": dict(action_counter),
    }


def evaluate_across_levels(
    baseline: Any,
    level_names: list[str],
    n_episodes_per_level: int = 100,
    seed: int = 42,
) -> dict[str, dict[str, Any]]:
    """Run evaluate_baseline for all level names, keyed by level."""
    return {
        level_name: evaluate_baseline(
            baseline=baseline,
            level_name=level_name,
            n_episodes=n_episodes_per_level,
            seed=seed,
        )
        for level_name in level_names
    }


def print_comparison_table(results: dict[str, dict[str, dict[str, Any]]]) -> None:
    """Pretty-print a markdown table of completion rates per baseline and level."""
    all_levels: set[str] = set()
    for baseline_results in results.values():
        all_levels.update(baseline_results.keys())
    ordered_levels = sorted(all_levels)
    baseline_names = list(results.keys())
    header = ["Level", *baseline_names]
    row_sep = "| " + " | ".join(["---"] * len(header)) + " |"
    print("| " + " | ".join(header) + " |")
    print(row_sep)
    for level in ordered_levels:
        row = [level]
        for baseline_name in baseline_names:
            metrics = results.get(baseline_name, {}).get(level, {})
            value = float(metrics.get("completion_rate", 0.0)) * 100.0
            row.append(f"{value:.1f}%")
        print("| " + " | ".join(row) + " |")
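

# Quick smoke-test sketch. Assumptions: "MiniGrid-Empty-5x5" is a placeholder
# for a real level name accepted by EnvConfig, and the random baseline below
# returns raw gym actions, which may need mapping for MiniGridEnvironment.step.
if __name__ == "__main__":
    class _RandomBaseline:
        """Samples uniformly from the wrapped gym env's action space."""

        def __init__(self, gym_env):
            self._gym_env = gym_env

        def select_action(self, obs, raw_obs=None):
            return self._gym_env.action_space.sample()

    results = {
        "random": evaluate_across_levels(
            baseline=lambda gym_env: _RandomBaseline(gym_env),
            level_names=["MiniGrid-Empty-5x5"],
            n_episodes_per_level=5,
        )
    }
    print_comparison_table(results)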