"""Baseline evaluation harness for MiniGridEnv.""" from __future__ import annotations import inspect import math from collections import Counter from statistics import mean from typing import Any, Callable try: from ..env.config import EnvConfig from ..env.minigrid_env import MiniGridEnvironment except ImportError: from env.config import EnvConfig from env.minigrid_env import MiniGridEnvironment def _sem(values: list[float]) -> float: if len(values) < 2: return 0.0 avg = mean(values) variance = sum((value - avg) ** 2 for value in values) / (len(values) - 1) return math.sqrt(variance / len(values)) def _resolve_episode_baseline( baseline: Any, env: MiniGridEnvironment ) -> Any: if hasattr(baseline, "select_action"): return baseline if callable(baseline): try: return baseline(env._gym_env) # type: ignore[attr-defined] except TypeError: return baseline() raise TypeError("baseline must be an agent object or a callable factory") def _select_action(agent: Any, obs: Any, raw_obs: dict | None): select_action = getattr(agent, "select_action") sig = inspect.signature(select_action) if len(sig.parameters) >= 2: return select_action(obs, raw_obs) return select_action(obs) def evaluate_baseline( baseline: Any, level_name: str, n_episodes: int = 100, seed: int = 42, ) -> dict[str, Any]: """Evaluate one baseline on one level and return aggregate metrics.""" completed_flags: list[float] = [] completed_steps: list[float] = [] rewards: list[float] = [] efficiencies: list[float] = [] action_counter: Counter[str] = Counter() total_valid = 0 total_invalid = 0 for episode_offset in range(n_episodes): env = MiniGridEnvironment( config=EnvConfig(level_name=level_name, seed=seed + episode_offset) ) obs = env.reset(seed=seed + episode_offset) agent = _resolve_episode_baseline(baseline, env) while not obs.done: raw_obs = env._last_obs # type: ignore[attr-defined] action = _select_action(agent, obs, raw_obs) obs = env.step(action) state = env.state completed = 1.0 if state.completed else 0.0 completed_flags.append(completed) rewards.append(float(state.total_reward)) total_valid += int(state.valid_actions) total_invalid += int(state.invalid_actions) action_counter.update(state.action_distribution) if state.completed: completed_steps.append(float(state.steps_taken)) if state.efficiency_ratio is not None: efficiencies.append(float(state.efficiency_ratio)) total_actions = total_valid + total_invalid return { "level": level_name, "episodes": n_episodes, "completion_rate": mean(completed_flags) if completed_flags else 0.0, "completion_rate_sem": _sem(completed_flags), "mean_steps_completed": mean(completed_steps) if completed_steps else 0.0, "mean_steps_completed_sem": _sem(completed_steps), "mean_reward": mean(rewards) if rewards else 0.0, "mean_reward_sem": _sem(rewards), "efficiency": mean(efficiencies) if efficiencies else 0.0, "efficiency_sem": _sem(efficiencies), "action_parse_rate": (float(total_valid) / float(total_actions)) if total_actions else 0.0, "action_distribution": dict(action_counter), } def evaluate_across_levels( baseline: Any, level_names: list[str], n_episodes_per_level: int = 100, seed: int = 42, ) -> dict[str, dict[str, Any]]: """Run evaluate_baseline for all level names.""" return { level_name: evaluate_baseline( baseline=baseline, level_name=level_name, n_episodes=n_episodes_per_level, seed=seed, ) for level_name in level_names } def print_comparison_table(results: dict[str, dict[str, dict[str, Any]]]) -> None: """Pretty-print baseline comparison by completion rate.""" all_levels: set[str] = set() for baseline_results in results.values(): all_levels.update(baseline_results.keys()) ordered_levels = sorted(all_levels) baseline_names = list(results.keys()) header = ["Level", *baseline_names] row_sep = "| " + " | ".join(["---"] * len(header)) + " |" print("| " + " | ".join(header) + " |") print(row_sep) for level in ordered_levels: row = [level] for baseline_name in baseline_names: metrics = results.get(baseline_name, {}).get(level, {}) value = float(metrics.get("completion_rate", 0.0)) * 100.0 row.append(f"{value:.1f}%") print("| " + " | ".join(row) + " |")