File size: 3,874 Bytes
2803d7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import annotations

import argparse
import json
from pathlib import Path

from agents.hero.policy import HeroLLMPolicy
from agents.master.interface import DEFAULT_GEMINI_MODEL
from agents.master.env import DMEnvironment
from agents.master.policy import DungeonMasterLLMPolicy
from agents.shared.runtime import (
    build_interface_adapter,
    create_structured_client,
    resolve_interface_config,
    resolve_structured_client_config,
)

from .runner import ClosedLoopRunner


def _positive_int(raw: str) -> int:
    """Parse *raw* as an integer >= 1; used as an argparse ``type=`` hook.

    Raises:
        argparse.ArgumentTypeError: if *raw* is not an integer or is < 1, so
            argparse reports a normal usage error instead of a traceback.
    """
    try:
        value = int(raw)
    except ValueError:
        # Mirror argparse's own wording for a plain ``type=int`` failure.
        raise argparse.ArgumentTypeError(f"invalid int value: {raw!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the closed-loop harness."""
    parser = argparse.ArgumentParser(description="Closed-loop dungeon master and hero harness")
    # Validated at parse time: a non-positive count would otherwise silently run
    # zero episodes, but only after the LLM clients had already been constructed.
    parser.add_argument("--episodes", type=_positive_int, default=1)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--target-ratio", type=float)
    # Dungeon-master model selection.
    parser.add_argument("--dm-provider", choices=["gemini", "hf_local"])
    parser.add_argument("--dm-model")
    parser.add_argument("--dm-adapter-path")
    # Hero model selection.
    parser.add_argument("--hero-provider", choices=["gemini", "hf_local"])
    parser.add_argument("--hero-model")
    parser.add_argument("--hero-adapter-path")
    # Hero-facing interface adapter.
    parser.add_argument("--interface-provider", choices=["strict", "simple", "gemini"])
    parser.add_argument("--interface-model", default=DEFAULT_GEMINI_MODEL)
    parser.add_argument("--interface-narrate", action="store_true")
    parser.add_argument(
        "--translate-corporate-env",
        action="store_true",
        help="Rewrite hero-facing observations into a corporate app metaphor and map translated commands back through Gemini.",
    )
    # Output locations and run limits.
    parser.add_argument("--artifacts-root", type=Path)
    parser.add_argument("--dm-artifacts-root", type=Path)
    parser.add_argument("--dm-repair-attempts", type=int, default=2)
    parser.add_argument("--hero-max-game-steps", type=int, default=40)
    parser.add_argument("--hero-max-tool-calls", type=int, default=80)
    parser.add_argument("--live", action="store_true")
    parser.add_argument("--live-dir", type=Path)
    return parser


def _build_runner(args: argparse.Namespace) -> ClosedLoopRunner:
    """Resolve DM, hero, and interface configs from *args* and wire up a ClosedLoopRunner."""
    dm_config = resolve_structured_client_config(
        "dm",
        provider=args.dm_provider,
        model_name=args.dm_model,
        adapter_path=args.dm_adapter_path,
    )
    hero_config = resolve_structured_client_config(
        "hero",
        provider=args.hero_provider,
        model_name=args.hero_model,
        adapter_path=args.hero_adapter_path,
    )
    interface_config = resolve_interface_config(
        provider=args.interface_provider,
        model_name=args.interface_model,
        narrate_observations=args.interface_narrate,
        translation_mode="corporate_app" if args.translate_corporate_env else None,
    )
    return ClosedLoopRunner(
        dm_env=DMEnvironment(artifacts_root=args.dm_artifacts_root),
        dm_policy=DungeonMasterLLMPolicy(create_structured_client(dm_config), model_name=dm_config.model_name),
        hero_policy=HeroLLMPolicy(create_structured_client(hero_config), model_name=hero_config.model_name),
        artifacts_root=args.artifacts_root,
        live_dir=args.live_dir,
        max_dm_repair_attempts=args.dm_repair_attempts,
        hero_runner_kwargs={
            "max_game_steps": args.hero_max_game_steps,
            "max_tool_calls": args.hero_max_tool_calls,
        },
        hero_interface_adapter=build_interface_adapter(interface_config),
    )


def main(argv: list[str] | None = None) -> int:
    """Run the closed-loop dungeon-master/hero harness from the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]`` (argparse default).

    Returns:
        0 once all episodes have completed. Argument errors exit via argparse
        (``SystemExit``) before any client is constructed.

    Output:
        One JSON summary line on stdout per episode, followed by one JSON
        aggregate line across all episodes.
    """
    args = _build_parser().parse_args(argv)
    runner = _build_runner(args)
    records = []
    for index in range(args.episodes):
        # Consecutive seeds keep a multi-episode run reproducible from one --seed.
        seed = None if args.seed is None else args.seed + index
        record = runner.run_episode(seed=seed, target_ratio=args.target_ratio, live=args.live)
        records.append(record)
        print(json.dumps(ClosedLoopRunner.summary(record).model_dump(mode="json")))
    if records:
        print(json.dumps(ClosedLoopRunner.aggregate(records).model_dump(mode="json")))
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)