from __future__ import annotations

from collections import deque
from collections.abc import Iterable
from typing import Protocol

from agents.master.session import EpisodeSession

from .env import HeroEnvironment
from .policy import HeroPolicyError
from .schema import HeroAction, HeroEpisodeStats, HeroObservation, HeroState


class ToolCallingPolicy(Protocol):
    """Structural interface for policies that ``HeroRunner`` can drive."""

    def reset(self) -> None:
        """Prepare the policy for a new episode.

        ``HeroRunner.run`` calls this once before it builds the environment.
        """
        ...

    def next_action(
        self,
        observation: HeroObservation,
        state: HeroState,
        scratchpad: str,
    ) -> HeroAction | dict[str, object] | None:
        """Choose the next action for the current step.

        Args:
            observation: Latest observation produced by the environment.
            state: The environment's current ``state`` attribute.
            scratchpad: The environment's current ``scratchpad`` attribute.

        Returns:
            The action to pass to ``env.step`` (a ``HeroAction`` or a plain
            dict), or ``None`` to make the runner stop the episode loop.
        """
        ...


class ScriptedToolCallingPolicy:
    """Policy that replays a fixed sequence of actions in order."""

    def __init__(self, actions: Iterable[HeroAction | dict[str, object]]) -> None:
        """Snapshot ``actions`` so the script can be replayed after ``reset``.

        Args:
            actions: Actions to emit, in order. The iterable is consumed once
                and copied into a list here.
        """
        self._initial_actions = list(actions)
        # A deque gives O(1) pops from the front; list.pop(0) is O(n).
        self._remaining_actions: deque[HeroAction | dict[str, object]] = deque(
            self._initial_actions
        )

    def reset(self) -> None:
        """Rewind the script to its first action."""
        self._remaining_actions = deque(self._initial_actions)

    def next_action(
        self,
        observation: HeroObservation,
        state: HeroState,
        scratchpad: str,
    ) -> HeroAction | dict[str, object] | None:
        """Return the next scripted action, or ``None`` once the script is exhausted.

        The arguments are accepted to satisfy ``ToolCallingPolicy`` but are
        ignored: the script does not react to the environment.
        """
        del observation, state, scratchpad
        if not self._remaining_actions:
            return None
        return self._remaining_actions.popleft()


class HeroRunner:
    """Drive a ``ToolCallingPolicy`` through one ``HeroEnvironment`` episode.

    After ``run`` returns, the outcome is available on the instance:

    * ``last_error`` -- message of the ``HeroPolicyError`` that stopped the
      episode, or ``None`` if no such error occurred.
    * ``last_observation`` -- most recent observation received from the
      environment.
    * ``episode_stats`` -- the environment's ``episode_stats`` at the point
      the loop ended.
    """

    def __init__(
        self,
        policy: ToolCallingPolicy,
        *,
        max_game_steps: int | None = 40,
        max_tool_calls: int | None = None,
        scratchpad_max_chars: int = 8000,
        debug: bool = False,
    ) -> None:
        """Store the policy and the environment settings used by ``run``.

        Args:
            policy: Policy queried for an action at each step.
            max_game_steps: Upper bound combined with ``run``'s ``max_steps``
                (the smaller of the two is used). ``None`` means only
                ``max_steps`` applies.
            max_tool_calls: Forwarded unchanged to
                ``HeroEnvironment.from_session``.
            scratchpad_max_chars: Forwarded unchanged to
                ``HeroEnvironment.from_session``.
            debug: Forwarded unchanged to ``HeroEnvironment.from_session``.
        """
        self.policy = policy
        self.max_game_steps = max_game_steps
        self.max_tool_calls = max_tool_calls
        self.scratchpad_max_chars = scratchpad_max_chars
        self.debug = debug
        self.last_error: str | None = None
        self.last_observation: HeroObservation | None = None
        self.episode_stats: HeroEpisodeStats | None = None

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Run one episode against ``session`` and record the outcome.

        The loop ends when the observation reports ``done``, when the policy
        returns ``None``, or when the policy raises ``HeroPolicyError`` (its
        message is stored in ``last_error``). Any other exception propagates
        to the caller.

        Args:
            session: Episode session used to construct the environment.
            max_steps: Step budget for this run; capped by
                ``self.max_game_steps`` when that is not ``None``.
        """
        # Clear results from any previous run before starting.
        self.last_error = None
        self.last_observation = None
        self.episode_stats = None
        self.policy.reset()

        if self.max_game_steps is None:
            step_limit = max_steps
        else:
            step_limit = min(max_steps, self.max_game_steps)

        env = HeroEnvironment.from_session(
            session,
            max_game_steps=step_limit,
            max_tool_calls=self.max_tool_calls,
            scratchpad_max_chars=self.scratchpad_max_chars,
            debug=self.debug,
        )

        observation = env.reset()
        self.last_observation = observation
        while not observation.done:
            try:
                action = self.policy.next_action(observation, env.state, env.scratchpad)
            except HeroPolicyError as exc:
                self.last_error = str(exc)
                break
            if action is None:
                # The policy has nothing more to do; stop without stepping.
                break
            observation = env.step(action).observation
            self.last_observation = observation

        # Every exit path from the loop records the final stats exactly once.
        self.episode_stats = env.episode_stats