from __future__ import annotations import random from typing import Iterable, Protocol, TYPE_CHECKING from .base import DMInterfaceError if TYPE_CHECKING: from .session import EpisodeSession class EpisodeRunner(Protocol): def run(self, session: EpisodeSession, max_steps: int) -> None: ... class WalkthroughRunner: def __init__(self, commands: Iterable[str] | None = None) -> None: self._commands = list(commands) if commands is not None else None def run(self, session: EpisodeSession, max_steps: int) -> None: commands = list(self._commands or session.compiled.solver_policy) for command in commands: if session.done or session.steps_taken >= max_steps: return session.step(command) class CommandSequenceRunner: def __init__(self, commands: Iterable[str]) -> None: self._commands = list(commands) def run(self, session: EpisodeSession, max_steps: int) -> None: for command in self._commands: if session.done or session.steps_taken >= max_steps: return session.step(command) class RandomAdmissibleRunner: def __init__(self, seed: int | None = None) -> None: self._rng = random.Random(seed) def run(self, session: EpisodeSession, max_steps: int) -> None: while not session.done and session.steps_taken < max_steps: options = session.available_commands() if not options: return session.step(self._rng.choice(options)) class ManualRunner: def run(self, session: EpisodeSession, max_steps: int) -> None: print(session.current_feedback()) while not session.done and session.steps_taken < max_steps: print() print(f"Step {session.steps_taken + 1}/{max_steps}") command = input("> ").strip() if command in {"quit", "exit"}: return try: turn = session.step(command) except DMInterfaceError: print("I'm not sure what you mean. Try rephrasing that command.") if session.available_commands(): print("Admissible:", ", ".join(session.available_commands())) continue print(turn.observation) if session.available_commands(): print("Admissible:", ", ".join(session.available_commands()))