File size: 2,662 Bytes
f6d8768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import annotations

import argparse
from pathlib import Path

from minidreamer.config import load_config
from minidreamer.data.replay_buffer import ReplayBuffer
from minidreamer.evaluation import evaluate_random_policy, evaluate_world_model
from minidreamer.envs.make_env import make_env_from_config
from minidreamer.planning.evaluate_planner import evaluate_planner
from minidreamer.serialization import load_world_model_checkpoint


def _unit_interval_float(value: str) -> float:
    """Parse *value* as a float in the closed interval [0, 1] (argparse ``type=``).

    Raises:
        argparse.ArgumentTypeError: if *value* is not a number, or is outside
            [0, 1]. NaN fails the range check and is rejected too.
    """
    message = f"expected a number in [0, 1], got {value!r}"
    try:
        parsed = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(message) from None
    # Chained comparison is False for NaN, so NaN is rejected along with +/-inf.
    if not 0.0 <= parsed <= 1.0:
        raise argparse.ArgumentTypeError(message)
    return parsed


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the evaluation CLI parser.

    Sub-commands (one is required):

    * ``random`` -- requires ``--config``.
    * ``planner`` -- requires ``--config`` and ``--checkpoint``; optional
      ``--random-action-fraction`` (default 0.0, must lie in [0, 1]).
    * ``world-model`` -- requires ``--config``, ``--checkpoint`` and
      ``--replay-dir``; optional ``--split`` (one of train/val/test, default
      ``"val"``).

    The chosen sub-command name is stored in ``args.command``.
    """
    parser = argparse.ArgumentParser(description="Evaluate MiniDreamer components.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    random_parser = subparsers.add_parser("random", help="Evaluate a random policy.")
    random_parser.add_argument("--config", type=Path, required=True)

    planner_parser = subparsers.add_parser("planner", help="Evaluate a trained planner.")
    planner_parser.add_argument("--config", type=Path, required=True)
    planner_parser.add_argument("--checkpoint", type=Path, required=True)
    planner_parser.add_argument(
        "--random-action-fraction",
        # Validate at the CLI boundary so out-of-range noise fails fast with a
        # clear message instead of silently skewing the evaluation.
        type=_unit_interval_float,
        default=0.0,
        help=(
            "Optional evaluation-time action noise in [0, 1]. "
            "Defaults to 0.0 for a clean planner evaluation."
        ),
    )

    world_model_parser = subparsers.add_parser("world-model", help="Evaluate held-out world model metrics.")
    world_model_parser.add_argument("--config", type=Path, required=True)
    world_model_parser.add_argument("--checkpoint", type=Path, required=True)
    world_model_parser.add_argument("--replay-dir", type=Path, required=True)
    world_model_parser.add_argument("--split", type=str, default="val", choices=["train", "val", "test"])
    return parser


def _infer_action_dim(config) -> int:
    """Return ``action_space.n`` of the environment built from *config*.

    A throwaway environment is created only to read the action count and is
    always closed afterwards, even if reading the attribute fails. The seed is
    taken from ``config["project"]["seed"`` when present, otherwise 0; a missing
    or null ``project`` section falls back to the default.
    """
    seed = (config.get("project") or {}).get("seed", 0)
    env = make_env_from_config(config, seed=seed)
    try:
        return env.action_space.n
    finally:
        env.close()


def main() -> None:
    """Run the evaluation CLI and print the result dict for the chosen sub-command.

    * ``random`` -- prints ``evaluate_random_policy(config)``; no checkpoint
      is loaded.
    * ``planner`` -- loads the checkpoint on CPU and prints its metadata merged
      with ``evaluate_planner`` results.
    * ``world-model`` -- loads the checkpoint on CPU and the replay buffer from
      ``--replay-dir``, then prints metadata merged with ``evaluate_world_model``
      results for ``--split``.
    """
    args = build_arg_parser().parse_args()
    config = load_config(args.config)

    if args.command == "random":
        print(evaluate_random_policy(config))
        return

    # The checkpoint loader needs the action dimension, which is read from a
    # temporary environment instance.
    action_dim = _infer_action_dim(config)
    model, _, metadata = load_world_model_checkpoint(args.checkpoint, action_dim=action_dim, map_location="cpu")

    if args.command == "planner":
        results = evaluate_planner(
            config,
            model,
            random_action_fraction=args.random_action_fraction,
        )
    else:
        # argparse restricts commands, so the only remaining one is "world-model".
        replay = ReplayBuffer.load(args.replay_dir)
        results = evaluate_world_model(config, model, replay, split=args.split)

    print({"metadata": metadata, **results})


if __name__ == "__main__":
    main()