from __future__ import annotations

from collections.abc import Iterable
from typing import Protocol

from agents.master.session import EpisodeSession

from .env import HeroEnvironment
from .policy import HeroPolicyError
from .schema import HeroAction, HeroEpisodeStats, HeroObservation, HeroState


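class ToolCallingPolicy(Protocol):
    """Interface for policies that choose the next action for a HeroRunner."""

    def reset(self) -> None:
        """Clear per-episode state; HeroRunner.run calls this before each episode."""
        ...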

    def next_action(
        self,
        observation: HeroObservation,
        state: HeroState,
        scratchpad: str,
    ) -> HeroAction | dict[str, object] | None:
        """Return the next action, or None to end the episode."""
        ...


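class ScriptedToolCallingPolicy:
    """Policy that replays a fixed sequence of actions, then returns None."""

    def __init__(self, actions: Iterable[HeroAction | dict[str, object]]) -> None:
        # Keep the original script untouched so reset() can restart it.
        self._initial_actions = list(actions)
        self._remaining_actions = list(self._initial_actions)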

    def reset(self) -> None:
        self._remaining_actions = list(self._initial_actions)

    def next_action(
        self,
        observation: HeroObservation,
        state: HeroState,
        scratchpad: str,
    ) -> HeroAction | dict[str, object] | None:
        # The script ignores all inputs; del marks them as intentionally unused.
        del observation, state, scratchpad
        if not self._remaining_actions:
            return None
        return self._remaining_actions.pop(0)


class HeroRunner:
    """Runs one episode by feeding a policy's actions to a HeroEnvironment."""

    def __init__(
        self,
        policy: ToolCallingPolicy,
        *,
        max_game_steps: int | None = 40,
        max_tool_calls: int | None = None,
        scratchpad_max_chars: int = 8000,
        debug: bool = False,
    ) -> None:
        self.policy = policy
        self.max_game_steps = max_game_steps
        self.max_tool_calls = max_tool_calls
        self.scratchpad_max_chars = scratchpad_max_chars
        self.debug = debug
        self.last_error: str | None = None
        self.last_observation: HeroObservation | None = None
        self.episode_stats: HeroEpisodeStats | None = None

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Run one episode; results land in last_error, last_observation, and episode_stats."""
        self.last_error = None
        self.last_observation = None
        self.episode_stats = None
        self.policy.reset()
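        # max_steps is the caller's upper bound; the runner's own limit can only tighten it.
        if self.max_game_steps is None:
            game_step_limit = max_steps
        else:
            game_step_limit = min(max_steps, self.max_game_steps)
        env = HeroEnvironment.from_session(
            session,
            max_game_steps=game_step_limit,
            max_tool_calls=self.max_tool_calls,
            scratchpad_max_chars=self.scratchpad_max_chars,
            debug=self.debug,
        )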
        observation = env.reset()
        self.last_observation = observation
        while not observation.done:
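            try:
                action = self.policy.next_action(observation, env.state, env.scratchpad)
            except HeroPolicyError as exc:
                # Policy failure: record the message and end the episode early.
                self.last_error = str(exc)
                self.episode_stats = env.episode_stats
                return
            if action is None:
                # The policy has no further actions, so stop before the environment is done.
                self.episode_stats = env.episode_stats
                return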
            result = env.step(action)
            observation = result.observation
            self.last_observation = observation
        self.episode_stats = env.episode_stats