Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import random | |
| from typing import Iterable, Protocol, TYPE_CHECKING | |
| from .base import DMInterfaceError | |
| if TYPE_CHECKING: | |
| from .session import EpisodeSession | |
class EpisodeRunner(Protocol):
    """Structural interface for objects that can drive an episode session.

    Any class providing a compatible ``run`` method satisfies this protocol;
    no explicit inheritance is needed.
    """

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Drive ``session`` forward, using at most ``max_steps`` as the step budget."""
class WalkthroughRunner:
    """Play an episode by replaying a scripted list of commands.

    When constructed without commands, the session's
    ``compiled.solver_policy`` is replayed instead.
    """

    def __init__(self, commands: Iterable[str] | None = None) -> None:
        """Store the commands to replay.

        Args:
            commands: Explicit commands to replay. ``None`` (the default)
                means "use the session's ``compiled.solver_policy``"; an
                empty iterable means "replay nothing".
        """
        # Materialise the iterable once so the runner can be reused.
        self._commands = list(commands) if commands is not None else None

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Issue the walkthrough commands to ``session`` in order.

        Stops early as soon as the session reports ``done`` or
        ``max_steps`` steps have been taken.
        """
        # Check against None rather than truthiness: an explicitly supplied
        # empty command list must not silently fall back to the solver policy.
        if self._commands is not None:
            commands = list(self._commands)
        else:
            commands = list(session.compiled.solver_policy)
        for command in commands:
            if session.done or session.steps_taken >= max_steps:
                return
            session.step(command)
class CommandSequenceRunner:
    """Replays a fixed sequence of commands supplied at construction time."""

    def __init__(self, commands: Iterable[str]) -> None:
        # Snapshot the iterable so the same runner can drive several episodes.
        self._commands = list(commands)

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Send the stored commands to ``session`` one at a time.

        Replay halts once the session reports ``done`` or ``max_steps``
        steps have been taken.
        """
        for cmd in self._commands:
            if session.done:
                break
            if session.steps_taken >= max_steps:
                break
            session.step(cmd)
class RandomAdmissibleRunner:
    """Baseline agent that picks uniformly among the admissible commands.

    Supplying ``seed`` makes the sequence of picks reproducible for the same
    sequence of admissible-command lists.
    """

    def __init__(self, seed: int | None = None) -> None:
        # A private RNG keeps this runner independent of the global random state.
        self._rng = random.Random(seed)

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Step ``session`` with randomly chosen admissible commands.

        Stops when the session is done, when ``max_steps`` steps have been
        taken, or when no admissible command is offered.
        """
        pick = self._rng.choice
        while True:
            if session.done:
                break
            if session.steps_taken >= max_steps:
                break
            candidates = session.available_commands()
            if not candidates:
                break
            session.step(pick(candidates))
class ManualRunner:
    """Interactive runner: reads commands from stdin and prints the results.

    Typing ``quit`` or ``exit``, or sending end-of-input (e.g. Ctrl-D),
    ends the session early.
    """

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Drive ``session`` from keyboard input for at most ``max_steps`` steps."""
        print(session.current_feedback())
        while not session.done and session.steps_taken < max_steps:
            print()
            print(f"Step {session.steps_taken + 1}/{max_steps}")
            try:
                command = input("> ").strip()
            except EOFError:
                # End of input (Ctrl-D / closed pipe): stop cleanly, like
                # "quit", instead of crashing with a traceback.
                print()
                return
            if command in {"quit", "exit"}:
                return
            try:
                turn = session.step(command)
            except DMInterfaceError:
                print("I'm not sure what you mean. Try rephrasing that command.")
            else:
                print(turn.observation)
            # Shown after both successful and rejected commands.
            self._print_admissible(session)

    @staticmethod
    def _print_admissible(session: EpisodeSession) -> None:
        """Print the session's admissible commands, if it offers any."""
        # Query once instead of twice per prompt.
        options = session.available_commands()
        if options:
            print("Admissible:", ", ".join(options))