"""Baseline evaluation harness for MiniGridEnv."""
from __future__ import annotations
import inspect
import math
from collections import Counter
from statistics import mean
from typing import Any, Callable
try:
from ..env.config import EnvConfig
from ..env.minigrid_env import MiniGridEnvironment
except ImportError:
from env.config import EnvConfig
from env.minigrid_env import MiniGridEnvironment


def _sem(values: list[float]) -> float:
    """Standard error of the mean, using the unbiased (n - 1) sample variance."""
    if len(values) < 2:
        return 0.0
    avg = mean(values)
    variance = sum((value - avg) ** 2 for value in values) / (len(values) - 1)
    return math.sqrt(variance / len(values))
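# Worked example (assumed values, for illustration only): for
# values = [0.0, 1.0, 1.0, 0.0], the sample variance is 1/3, so
# _sem returns sqrt((1/3) / 4) ≈ 0.289.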


def _resolve_episode_baseline(baseline: Any, env: MiniGridEnvironment) -> Any:
    """Return a per-episode agent: pass agent objects through, call factories."""
    if hasattr(baseline, "select_action"):
        return baseline
    if callable(baseline):
        # Prefer factories that take the underlying gym env; fall back to
        # zero-argument factories if that call signature is rejected.
        try:
            return baseline(env._gym_env)  # type: ignore[attr-defined]
        except TypeError:
            return baseline()
    raise TypeError("baseline must be an agent object or a callable factory")
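

# Both baseline shapes are accepted; the names below are hypothetical examples:
#   evaluate_baseline(MyAgent(), "level")                  # object with select_action
#   evaluate_baseline(lambda gym_env: MyAgent(), "level")  # factory, called per episode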


def _select_action(agent: Any, obs: Any, raw_obs: dict | None):
    """Invoke select_action, passing raw_obs only when the agent accepts it."""
    select_action = agent.select_action
    sig = inspect.signature(select_action)
    if len(sig.parameters) >= 2:
        return select_action(obs, raw_obs)
    return select_action(obs)


def evaluate_baseline(
    baseline: Any,
    level_name: str,
    n_episodes: int = 100,
    seed: int = 42,
) -> dict[str, Any]:
    """Evaluate one baseline on one level and return aggregate metrics."""
    completed_flags: list[float] = []
    completed_steps: list[float] = []
    rewards: list[float] = []
    efficiencies: list[float] = []
    action_counter: Counter[str] = Counter()
    total_valid = 0
    total_invalid = 0
    for episode_offset in range(n_episodes):
        # Fresh environment per episode so each run is seeded deterministically.
        env = MiniGridEnvironment(
            config=EnvConfig(level_name=level_name, seed=seed + episode_offset)
        )
        obs = env.reset(seed=seed + episode_offset)
        agent = _resolve_episode_baseline(baseline, env)
        while not obs.done:
            raw_obs = env._last_obs  # type: ignore[attr-defined]
            action = _select_action(agent, obs, raw_obs)
            obs = env.step(action)
        state = env.state
        completed_flags.append(1.0 if state.completed else 0.0)
        rewards.append(float(state.total_reward))
        total_valid += int(state.valid_actions)
        total_invalid += int(state.invalid_actions)
        action_counter.update(state.action_distribution)
        if state.completed:
            # Step and efficiency stats are only meaningful for solved episodes.
            completed_steps.append(float(state.steps_taken))
            if state.efficiency_ratio is not None:
                efficiencies.append(float(state.efficiency_ratio))
    total_actions = total_valid + total_invalid
    return {
        "level": level_name,
        "episodes": n_episodes,
        "completion_rate": mean(completed_flags) if completed_flags else 0.0,
        "completion_rate_sem": _sem(completed_flags),
        "mean_steps_completed": mean(completed_steps) if completed_steps else 0.0,
        "mean_steps_completed_sem": _sem(completed_steps),
        "mean_reward": mean(rewards) if rewards else 0.0,
        "mean_reward_sem": _sem(rewards),
        "efficiency": mean(efficiencies) if efficiencies else 0.0,
        "efficiency_sem": _sem(efficiencies),
        "action_parse_rate": (
            float(total_valid) / float(total_actions) if total_actions else 0.0
        ),
        "action_distribution": dict(action_counter),
    }
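

# Consuming the metrics dict (key names as returned above; MyAgent and the
# level name are hypothetical placeholders):
#   metrics = evaluate_baseline(MyAgent(), "some_level", n_episodes=20)
#   print(f"{metrics['completion_rate']:.2%} ± {metrics['completion_rate_sem']:.2%}")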


def evaluate_across_levels(
    baseline: Any,
    level_names: list[str],
    n_episodes_per_level: int = 100,
    seed: int = 42,
) -> dict[str, dict[str, Any]]:
    """Run evaluate_baseline for every level name, keyed by level."""
    return {
        level_name: evaluate_baseline(
            baseline=baseline,
            level_name=level_name,
            n_episodes=n_episodes_per_level,
            seed=seed,
        )
        for level_name in level_names
    }


def print_comparison_table(results: dict[str, dict[str, dict[str, Any]]]) -> None:
    """Print a Markdown table of completion rates: one row per level, one column per baseline."""
    all_levels: set[str] = set()
    for baseline_results in results.values():
        all_levels.update(baseline_results.keys())
    ordered_levels = sorted(all_levels)
    baseline_names = list(results.keys())
    header = ["Level", *baseline_names]
    row_sep = "| " + " | ".join(["---"] * len(header)) + " |"
    print("| " + " | ".join(header) + " |")
    print(row_sep)
    for level in ordered_levels:
        row = [level]
        for baseline_name in baseline_names:
            metrics = results.get(baseline_name, {}).get(level, {})
            value = float(metrics.get("completion_rate", 0.0)) * 100.0
            row.append(f"{value:.1f}%")
        print("| " + " | ".join(row) + " |")
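

# --- Usage sketch (illustrative only) ---
# A minimal demonstration wiring the harness together. RandomBaseline, its
# action vocabulary, and the "empty_room" level name are assumptions made for
# this example; substitute your project's real agent and level identifiers.
if __name__ == "__main__":
    import random

    class RandomBaseline:
        """Hypothetical baseline that samples uniformly from a fixed action list."""

        ACTIONS = ["left", "right", "forward"]  # assumed action vocabulary

        def select_action(self, obs, raw_obs=None):
            return random.choice(self.ACTIONS)

    demo_results = evaluate_across_levels(
        RandomBaseline(),
        level_names=["empty_room"],  # hypothetical level name
        n_episodes_per_level=5,
    )
    print_comparison_table({"random": demo_results})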