Spaces:
Runtime error
Runtime error
from __future__ import annotations

from collections import deque
from collections.abc import Iterable
from typing import Protocol

from agents.master.session import EpisodeSession

from .env import HeroEnvironment
from .policy import HeroPolicyError
from .schema import HeroAction, HeroEpisodeStats, HeroObservation, HeroState
class ToolCallingPolicy(Protocol):
    """Structural interface for objects that choose the next hero action.

    HeroRunner calls ``reset`` once at the start of each run, then calls
    ``next_action`` repeatedly. It stops when the observation is done, when
    the policy returns None, or when the policy raises HeroPolicyError.
    """

    def reset(self) -> None:
        """Prepare for a new episode (e.g. restore any per-episode state)."""

    def next_action(
        self,
        observation: HeroObservation,
        state: HeroState,
        scratchpad: str,
    ) -> HeroAction | dict[str, object] | None:
        """Choose the next action given the latest environment view.

        Args:
            observation: The most recent observation from the environment.
            state: The environment's current state.
            scratchpad: The environment's current scratchpad text.

        Returns:
            A HeroAction, a plain dict (presumably a raw tool-call payload;
            confirm against HeroEnvironment.step), or None to end the
            episode early.
        """
class ScriptedToolCallingPolicy:
    """Policy that replays a fixed, pre-recorded sequence of actions.

    Each ``next_action`` call returns the next scripted action in order.
    Once the script is exhausted it returns None. ``reset`` rewinds to the
    start of the script.
    """

    def __init__(self, actions: Iterable[HeroAction | dict[str, object]]) -> None:
        """Store the script.

        Args:
            actions: Actions to replay, in order. Any iterable is accepted;
                it is consumed once and copied, so generators work.
        """
        self._initial_actions = list(actions)
        # A deque gives O(1) pops from the front; list.pop(0) is O(n) per
        # call, which makes replaying a long script quadratic.
        self._remaining_actions = deque(self._initial_actions)

    def reset(self) -> None:
        """Rewind so the next call returns the first scripted action again."""
        self._remaining_actions = deque(self._initial_actions)

    def next_action(
        self,
        observation: HeroObservation,
        state: HeroState,
        scratchpad: str,
    ) -> HeroAction | dict[str, object] | None:
        """Return the next scripted action, or None when the script is used up.

        The observation, state and scratchpad are ignored; they are accepted
        only to satisfy the ToolCallingPolicy interface.
        """
        del observation, state, scratchpad
        if not self._remaining_actions:
            return None
        return self._remaining_actions.popleft()
class HeroRunner:
    """Drive a ToolCallingPolicy through one HeroEnvironment episode.

    The outcome of the most recent ``run`` is exposed on the instance as
    ``last_error``, ``last_observation`` and ``episode_stats``.
    """

    def __init__(
        self,
        policy: ToolCallingPolicy,
        *,
        max_game_steps: int | None = 40,
        max_tool_calls: int | None = None,
        scratchpad_max_chars: int = 8000,
        debug: bool = False,
    ) -> None:
        """Configure the runner.

        Args:
            policy: The policy queried for each action.
            max_game_steps: Optional cap applied on top of ``run``'s
                ``max_steps``; None means no extra cap.
            max_tool_calls: Forwarded to HeroEnvironment.from_session.
            scratchpad_max_chars: Forwarded to HeroEnvironment.from_session.
            debug: Forwarded to HeroEnvironment.from_session.
        """
        self.policy = policy
        self.max_game_steps = max_game_steps
        self.max_tool_calls = max_tool_calls
        self.scratchpad_max_chars = scratchpad_max_chars
        self.debug = debug
        # Results of the latest run(); all are cleared when run() starts.
        self.last_error: str | None = None
        self.last_observation: HeroObservation | None = None
        self.episode_stats: HeroEpisodeStats | None = None

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Play one episode built from ``session``.

        The episode ends when the observation reports ``done``, when the
        policy returns None, or when the policy raises HeroPolicyError. In
        the last case the error message is stored in ``last_error`` rather
        than re-raised. Any other exception propagates, and ``episode_stats``
        is then left as None.

        Args:
            session: The session the environment is built from.
            max_steps: Step limit passed to the environment as
                ``max_game_steps``. It is further capped by
                ``self.max_game_steps`` when that is not None.
        """
        self.last_error = None
        self.last_observation = None
        self.episode_stats = None
        self.policy.reset()

        step_budget = max_steps
        if self.max_game_steps is not None:
            step_budget = min(max_steps, self.max_game_steps)

        env = HeroEnvironment.from_session(
            session,
            max_game_steps=step_budget,
            max_tool_calls=self.max_tool_calls,
            scratchpad_max_chars=self.scratchpad_max_chars,
            debug=self.debug,
        )

        current = env.reset()
        self.last_observation = current

        while not current.done:
            try:
                chosen = self.policy.next_action(current, env.state, env.scratchpad)
            except HeroPolicyError as policy_error:
                # A policy failure stops the episode; keep the message for callers.
                self.last_error = str(policy_error)
                break
            if chosen is None:
                # The policy has signalled it has nothing more to do.
                break
            current = env.step(chosen).observation
            self.last_observation = current

        # Single exit point: every normal way out of the loop records stats.
        self.episode_stats = env.episode_stats