# Source: FATHOM-DM/agents/hero/runner.py (commit 2803d7e, "Deploy FATHOM-DM Space bundle", uploaded by aarushgupta)
from __future__ import annotations
from collections.abc import Iterable
from typing import Protocol
from agents.master.session import EpisodeSession
from .env import HeroEnvironment
from .policy import HeroPolicyError
from .schema import HeroAction, HeroEpisodeStats, HeroObservation, HeroState
class ToolCallingPolicy(Protocol):
    """Structural interface for the policies that ``HeroRunner`` drives.

    Any object providing ``reset`` and ``next_action`` with these signatures
    satisfies the protocol; inheriting from this class is not required.
    """

    def reset(self) -> None:
        """Clear per-episode state; ``HeroRunner.run`` calls this first."""

    def next_action(
        self,
        observation: HeroObservation,
        state: HeroState,
        scratchpad: str,
    ) -> HeroAction | dict[str, object] | None:
        """Pick the next action for the current turn.

        Args:
            observation: Most recent observation produced by the environment.
            state: Current environment state.
            scratchpad: Current scratchpad text held by the environment.

        Returns:
            The action to execute, either as a ``HeroAction`` or as a plain
            dict, or ``None`` to tell the runner to stop the episode.
        """
class ScriptedToolCallingPolicy:
    """Policy that replays a fixed, pre-recorded sequence of actions.

    Each ``next_action`` call yields the next scripted action in order and
    ignores its inputs. Once the script is exhausted every further call
    returns ``None`` until ``reset`` rewinds the script to the beginning.
    """

    def __init__(self, actions: Iterable[HeroAction | dict[str, object]]) -> None:
        """Snapshot ``actions`` so the script can be replayed after ``reset``.

        Args:
            actions: Actions to emit, in order. The iterable is consumed
                exactly once here, so one-shot iterables such as generators
                are supported.
        """
        self._initial_actions = list(actions)
        # Walk the snapshot with an iterator rather than popping from the
        # front of a copied list: list.pop(0) is O(n) per call, which made
        # replaying a script quadratic in its length.
        self._remaining_actions = iter(self._initial_actions)

    def reset(self) -> None:
        """Rewind the script so the next call returns the first action."""
        self._remaining_actions = iter(self._initial_actions)

    def next_action(
        self,
        observation: HeroObservation,
        state: HeroState,
        scratchpad: str,
    ) -> HeroAction | dict[str, object] | None:
        """Return the next scripted action, or ``None`` when none remain.

        The arguments exist only to satisfy the policy interface; a scripted
        policy does not look at them.
        """
        del observation, state, scratchpad
        # An exhausted list iterator stays exhausted, so repeated calls after
        # the end keep returning None.
        return next(self._remaining_actions, None)
class HeroRunner:
    """Drive a tool-calling policy through a single hero episode.

    After ``run`` returns, the outcome is available on the instance:

    * ``last_error`` - message of the ``HeroPolicyError`` that ended the run,
      or ``None`` if the policy raised no such error.
    * ``last_observation`` - the most recent observation from the environment.
    * ``episode_stats`` - the environment's episode stats read when the run
      ended.
    """

    def __init__(
        self,
        policy: ToolCallingPolicy,
        *,
        max_game_steps: int | None = 40,
        max_tool_calls: int | None = None,
        scratchpad_max_chars: int = 8000,
        debug: bool = False,
    ) -> None:
        """Store the policy and the limits handed to the environment.

        Args:
            policy: Policy asked for an action on every turn.
            max_game_steps: Upper bound combined (via ``min``) with the
                ``max_steps`` given to ``run``; ``None`` means only
                ``max_steps`` applies.
            max_tool_calls: Forwarded unchanged to
                ``HeroEnvironment.from_session``.
            scratchpad_max_chars: Forwarded unchanged to
                ``HeroEnvironment.from_session``.
            debug: Forwarded unchanged to ``HeroEnvironment.from_session``.
        """
        # Configuration.
        self.policy = policy
        self.max_game_steps = max_game_steps
        self.max_tool_calls = max_tool_calls
        self.scratchpad_max_chars = scratchpad_max_chars
        self.debug = debug
        # Results of the most recent run; empty until ``run`` is called.
        self.last_error: str | None = None
        self.last_observation: HeroObservation | None = None
        self.episode_stats: HeroEpisodeStats | None = None

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Play one episode of ``session`` with the configured policy.

        The loop ends when the environment reports ``done``, when the policy
        returns ``None``, or when the policy raises ``HeroPolicyError`` (its
        message is then stored in ``last_error``). Any other exception
        propagates to the caller.

        Args:
            session: Episode session the environment is built from.
            max_steps: Step budget for this run; further limited by
                ``max_game_steps`` when that is not ``None``.
        """
        # Forget whatever a previous run left behind.
        self.last_error = None
        self.last_observation = None
        self.episode_stats = None
        self.policy.reset()

        step_cap = max_steps
        if self.max_game_steps is not None:
            step_cap = min(max_steps, self.max_game_steps)

        env = HeroEnvironment.from_session(
            session,
            max_game_steps=step_cap,
            max_tool_calls=self.max_tool_calls,
            scratchpad_max_chars=self.scratchpad_max_chars,
            debug=self.debug,
        )

        current = env.reset()
        self.last_observation = current
        while not current.done:
            try:
                chosen = self.policy.next_action(current, env.state, env.scratchpad)
            except HeroPolicyError as exc:
                self.last_error = str(exc)
                break
            if chosen is None:
                # The policy has nothing more to do; stop without an error.
                break
            current = env.step(chosen).observation
            self.last_observation = current

        # Every exit from the loop records the stats exactly once.
        self.episode_stats = env.episode_stats