| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| from minidreamer.config import load_config |
| from minidreamer.data.replay_buffer import ReplayBuffer |
| from minidreamer.evaluation import evaluate_random_policy, evaluate_world_model |
| from minidreamer.envs.make_env import make_env_from_config |
| from minidreamer.planning.evaluate_planner import evaluate_planner |
| from minidreamer.serialization import load_world_model_checkpoint |
|
|
|
|
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the evaluation command line.

    The parser requires exactly one sub-command, stored on the parsed
    namespace as ``command``:

    * ``random`` -- takes ``--config``.
    * ``planner`` -- takes ``--config``, ``--checkpoint`` and the optional
      ``--random-action-fraction`` (float, default ``0.0``).
    * ``world-model`` -- takes ``--config``, ``--checkpoint``,
      ``--replay-dir`` and the optional ``--split`` (one of ``train``,
      ``val``, ``test``; default ``val``).

    ``--config``, ``--checkpoint`` and ``--replay-dir`` are mandatory for the
    sub-commands that declare them and are converted to ``pathlib.Path``.

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    root = argparse.ArgumentParser(description="Evaluate MiniDreamer components.")
    commands = root.add_subparsers(dest="command", required=True)

    def add_required_paths(sub: argparse.ArgumentParser, *flags: str) -> None:
        # Every path-like option in this CLI is mandatory and parsed as a Path.
        for flag in flags:
            sub.add_argument(flag, type=Path, required=True)

    random_cmd = commands.add_parser("random", help="Evaluate a random policy.")
    add_required_paths(random_cmd, "--config")

    planner_cmd = commands.add_parser("planner", help="Evaluate a trained planner.")
    add_required_paths(planner_cmd, "--config", "--checkpoint")
    planner_cmd.add_argument(
        "--random-action-fraction",
        type=float,
        default=0.0,
        help="Optional evaluation-time action noise. Defaults to 0.0 for a clean planner evaluation.",
    )

    world_model_cmd = commands.add_parser(
        "world-model",
        help="Evaluate held-out world model metrics.",
    )
    add_required_paths(world_model_cmd, "--config", "--checkpoint", "--replay-dir")
    world_model_cmd.add_argument(
        "--split",
        type=str,
        default="val",
        choices=["train", "val", "test"],
    )

    return root
|
|
|
|
def main() -> None:
    """Parse the command line, run the selected evaluation and print its result.

    Dispatch is on the ``command`` sub-command:

    * ``random`` prints the value returned by ``evaluate_random_policy`` and
      returns without loading a checkpoint.
    * ``planner`` prints a dict holding the checkpoint ``metadata`` merged
      with the mapping returned by ``evaluate_planner``.
    * any other command (``world-model`` per the parser) loads a replay
      buffer from ``--replay-dir`` and prints ``metadata`` merged with the
      mapping returned by ``evaluate_world_model`` for the chosen split.

    For the checkpoint-based commands, a temporary environment is created
    only to read ``env.action_space.n``, which is passed as ``action_dim``
    when loading the world model checkpoint onto the CPU.
    """
    args = build_arg_parser().parse_args()
    config = load_config(args.config)

    # The random-policy baseline needs neither the action count nor a checkpoint.
    if args.command == "random":
        print(evaluate_random_policy(config))
        return

    seed = config.get("project", {}).get("seed", 0)
    env = make_env_from_config(config, seed=seed)
    try:
        # Reading ``.n`` can fail (presumably for non-discrete action spaces --
        # TODO confirm); the ``finally`` ensures the environment is still closed.
        action_dim = env.action_space.n
    finally:
        env.close()

    model, _, metadata = load_world_model_checkpoint(
        args.checkpoint,
        action_dim=action_dim,
        map_location="cpu",
    )

    if args.command == "planner":
        planner_metrics = evaluate_planner(
            config,
            model,
            random_action_fraction=args.random_action_fraction,
        )
        print({"metadata": metadata, **planner_metrics})
        return

    # Remaining command: held-out world model metrics over a stored replay buffer.
    replay = ReplayBuffer.load(args.replay_dir)
    print({"metadata": metadata, **evaluate_world_model(config, model, replay, split=args.split)})
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|