"""Baseline evaluation harness for MiniGridEnv."""

from __future__ import annotations

import inspect
import math
from collections import Counter
from statistics import mean
from typing import Any

try:
    from ..env.config import EnvConfig
    from ..env.minigrid_env import MiniGridEnvironment
except ImportError:
    from env.config import EnvConfig
    from env.minigrid_env import MiniGridEnvironment


def _sem(values: list[float]) -> float:
    """Standard error of the mean: sample standard deviation divided by sqrt(n)."""
    if len(values) < 2:
        return 0.0
    avg = mean(values)
    variance = sum((value - avg) ** 2 for value in values) / (len(values) - 1)
    return math.sqrt(variance / len(values))


def _resolve_episode_baseline(
    baseline: Any, env: MiniGridEnvironment
) -> Any:
    """Return the agent to use for this episode.

    Accepts either a ready-made agent exposing ``select_action`` or a callable
    factory; the factory is first tried with the underlying gym environment and
    falls back to a zero-argument call.
    """
    if hasattr(baseline, "select_action"):
        return baseline
    if callable(baseline):
        try:
            return baseline(env._gym_env)  # type: ignore[attr-defined]
        except TypeError:
            return baseline()
    raise TypeError("baseline must be an agent object or a callable factory")


def _select_action(agent: Any, obs: Any, raw_obs: dict | None):
    """Call the agent's ``select_action``, passing the raw observation only if
    the method's signature accepts a second argument."""
    select_action = agent.select_action
    sig = inspect.signature(select_action)
    if len(sig.parameters) >= 2:
        return select_action(obs, raw_obs)
    return select_action(obs)


def evaluate_baseline(
    baseline: Any,
    level_name: str,
    n_episodes: int = 100,
    seed: int = 42,
) -> dict[str, Any]:
    """Evaluate one baseline on one level and return aggregate metrics."""
    completed_flags: list[float] = []
    completed_steps: list[float] = []
    rewards: list[float] = []
    efficiencies: list[float] = []
    action_counter: Counter[str] = Counter()
    total_valid = 0
    total_invalid = 0

    for episode_offset in range(n_episodes):
        env = MiniGridEnvironment(
            config=EnvConfig(level_name=level_name, seed=seed + episode_offset)
        )
        obs = env.reset(seed=seed + episode_offset)
        agent = _resolve_episode_baseline(baseline, env)

        while not obs.done:
            raw_obs = env._last_obs  # type: ignore[attr-defined]
            action = _select_action(agent, obs, raw_obs)
            obs = env.step(action)

        state = env.state
        completed = 1.0 if state.completed else 0.0
        completed_flags.append(completed)
        rewards.append(float(state.total_reward))
        total_valid += int(state.valid_actions)
        total_invalid += int(state.invalid_actions)
        action_counter.update(state.action_distribution)
        if state.completed:
            completed_steps.append(float(state.steps_taken))
            if state.efficiency_ratio is not None:
                efficiencies.append(float(state.efficiency_ratio))

    total_actions = total_valid + total_invalid
    return {
        "level": level_name,
        "episodes": n_episodes,
        "completion_rate": mean(completed_flags) if completed_flags else 0.0,
        "completion_rate_sem": _sem(completed_flags),
        "mean_steps_completed": mean(completed_steps) if completed_steps else 0.0,
        "mean_steps_completed_sem": _sem(completed_steps),
        "mean_reward": mean(rewards) if rewards else 0.0,
        "mean_reward_sem": _sem(rewards),
        "efficiency": mean(efficiencies) if efficiencies else 0.0,
        "efficiency_sem": _sem(efficiencies),
        "action_parse_rate": (float(total_valid) / float(total_actions)) if total_actions else 0.0,
        "action_distribution": dict(action_counter),
    }


def evaluate_across_levels(
    baseline: Any,
    level_names: list[str],
    n_episodes_per_level: int = 100,
    seed: int = 42,
) -> dict[str, dict[str, Any]]:
    """Run evaluate_baseline for all level names."""
    return {
        level_name: evaluate_baseline(
            baseline=baseline,
            level_name=level_name,
            n_episodes=n_episodes_per_level,
            seed=seed,
        )
        for level_name in level_names
    }


def print_comparison_table(results: dict[str, dict[str, dict[str, Any]]]) -> None:
    """Pretty-print baseline comparison by completion rate."""
    all_levels: set[str] = set()
    for baseline_results in results.values():
        all_levels.update(baseline_results.keys())
    ordered_levels = sorted(all_levels)
    baseline_names = list(results.keys())

    header = ["Level", *baseline_names]
    row_sep = "| " + " | ".join(["---"] * len(header)) + " |"
    print("| " + " | ".join(header) + " |")
    print(row_sep)
    for level in ordered_levels:
        row = [level]
        for baseline_name in baseline_names:
            metrics = results.get(baseline_name, {}).get(level, {})
            value = float(metrics.get("completion_rate", 0.0)) * 100.0
            row.append(f"{value:.1f}%")
        print("| " + " | ".join(row) + " |")