| """Baseline evaluation harness for MiniGridEnv.""" | |
| from __future__ import annotations | |
| import inspect | |
| import math | |
| from collections import Counter | |
| from statistics import mean | |
| from typing import Any, Callable | |
| try: | |
| from ..env.config import EnvConfig | |
| from ..env.minigrid_env import MiniGridEnvironment | |
| except ImportError: | |
| from env.config import EnvConfig | |
| from env.minigrid_env import MiniGridEnvironment | |


def _sem(values: list[float]) -> float:
    """Standard error of the mean, using the sample variance (n - 1 denominator)."""
    if len(values) < 2:
        return 0.0
    avg = mean(values)
    variance = sum((value - avg) ** 2 for value in values) / (len(values) - 1)
    return math.sqrt(variance / len(values))


def _resolve_episode_baseline(baseline: Any, env: MiniGridEnvironment) -> Any:
    """Return an agent instance, instantiating a factory callable if needed."""
    if hasattr(baseline, "select_action"):
        return baseline
    if callable(baseline):
        try:
            # Prefer factories that accept the underlying gym env.
            return baseline(env._gym_env)  # type: ignore[attr-defined]
        except TypeError:
            return baseline()
    raise TypeError("baseline must be an agent object or a callable factory")


def _select_action(agent: Any, obs: Any, raw_obs: dict | None):
    """Call agent.select_action, forwarding raw_obs only if the signature accepts it."""
    select_action = agent.select_action
    sig = inspect.signature(select_action)
    if len(sig.parameters) >= 2:
        return select_action(obs, raw_obs)
    return select_action(obs)
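

# A minimal sketch of the duck-typed agent interface the helpers above
# expect. Nothing below comes from the real codebase: the class name and the
# placeholder action strings are assumptions; substitute whatever values
# your MiniGridEnvironment.step actually accepts. The import sits here so
# the sketch stays self-contained.
import random


class RandomBaseline:
    """Picks uniformly at random from a fixed list of candidate actions."""

    def __init__(self, actions: list[str] | None = None) -> None:
        self._actions = actions or ["forward", "left", "right"]

    def select_action(self, obs: Any, raw_obs: dict | None = None) -> str:
        # Two-parameter signature, so _select_action also forwards raw_obs.
        return random.choice(self._actions)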


def evaluate_baseline(
    baseline: Any,
    level_name: str,
    n_episodes: int = 100,
    seed: int = 42,
) -> dict[str, Any]:
    """Evaluate one baseline on one level and return aggregate metrics."""
    completed_flags: list[float] = []
    completed_steps: list[float] = []
    rewards: list[float] = []
    efficiencies: list[float] = []
    action_counter: Counter[str] = Counter()
    total_valid = 0
    total_invalid = 0

    for episode_offset in range(n_episodes):
        # Fresh env per episode; offsetting the seed varies episodes while
        # keeping the whole run reproducible.
        env = MiniGridEnvironment(
            config=EnvConfig(level_name=level_name, seed=seed + episode_offset)
        )
        obs = env.reset(seed=seed + episode_offset)
        agent = _resolve_episode_baseline(baseline, env)
        while not obs.done:
            raw_obs = env._last_obs  # type: ignore[attr-defined]
            action = _select_action(agent, obs, raw_obs)
            obs = env.step(action)

        state = env.state
        completed = 1.0 if state.completed else 0.0
        completed_flags.append(completed)
        rewards.append(float(state.total_reward))
        total_valid += int(state.valid_actions)
        total_invalid += int(state.invalid_actions)
        action_counter.update(state.action_distribution)
        if state.completed:
            completed_steps.append(float(state.steps_taken))
            if state.efficiency_ratio is not None:
                efficiencies.append(float(state.efficiency_ratio))

    total_actions = total_valid + total_invalid
    return {
        "level": level_name,
        "episodes": n_episodes,
        "completion_rate": mean(completed_flags) if completed_flags else 0.0,
        "completion_rate_sem": _sem(completed_flags),
        "mean_steps_completed": mean(completed_steps) if completed_steps else 0.0,
        "mean_steps_completed_sem": _sem(completed_steps),
        "mean_reward": mean(rewards) if rewards else 0.0,
        "mean_reward_sem": _sem(rewards),
        "efficiency": mean(efficiencies) if efficiencies else 0.0,
        "efficiency_sem": _sem(efficiencies),
        "action_parse_rate": (
            float(total_valid) / float(total_actions) if total_actions else 0.0
        ),
        "action_distribution": dict(action_counter),
    }
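

# Usage sketch (hypothetical level name "empty_room", plus the RandomBaseline
# placeholder defined above; swap in your real levels and agents):
#
#     metrics = evaluate_baseline(RandomBaseline(), "empty_room", n_episodes=10)
#     print(metrics["completion_rate"], "+/-", metrics["completion_rate_sem"])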


def evaluate_across_levels(
    baseline: Any,
    level_names: list[str],
    n_episodes_per_level: int = 100,
    seed: int = 42,
) -> dict[str, dict[str, Any]]:
    """Run evaluate_baseline for each level name, keyed by level."""
    return {
        level_name: evaluate_baseline(
            baseline=baseline,
            level_name=level_name,
            n_episodes=n_episodes_per_level,
            seed=seed,
        )
        for level_name in level_names
    }


def print_comparison_table(results: dict[str, dict[str, dict[str, Any]]]) -> None:
    """Print a Markdown table of completion rates, one row per level, one column per baseline."""
    all_levels: set[str] = set()
    for baseline_results in results.values():
        all_levels.update(baseline_results.keys())
    ordered_levels = sorted(all_levels)

    baseline_names = list(results.keys())
    header = ["Level", *baseline_names]
    row_sep = "| " + " | ".join(["---"] * len(header)) + " |"
    print("| " + " | ".join(header) + " |")
    print(row_sep)
    for level in ordered_levels:
        row = [level]
        for baseline_name in baseline_names:
            metrics = results.get(baseline_name, {}).get(level, {})
            value = float(metrics.get("completion_rate", 0.0)) * 100.0
            row.append(f"{value:.1f}%")
        print("| " + " | ".join(row) + " |")