aarushgupta's picture
Deploy FATHOM-DM Space bundle
2803d7e verified
from __future__ import annotations
import random
from typing import Iterable, Protocol, TYPE_CHECKING
from .base import DMInterfaceError
if TYPE_CHECKING:
from .session import EpisodeSession
class EpisodeRunner(Protocol):
    """Structural interface for objects that drive an ``EpisodeSession``.

    Any class with a matching ``run`` method satisfies this protocol; no
    inheritance is required.
    """

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Drive *session* forward, taking at most *max_steps* steps.

        The runners in this module stop once ``session.done`` is set or
        ``session.steps_taken`` reaches *max_steps*.
        """
class WalkthroughRunner:
    """Replay a scripted walkthrough against a session.

    Args:
        commands: Commands to replay, in order. ``None`` (the default) means
            "use ``session.compiled.solver_policy``". Any other iterable,
            including an empty one, is replayed exactly as given.
    """

    def __init__(self, commands: Iterable[str] | None = None) -> None:
        # Materialise eagerly so a one-shot iterable (e.g. a generator) can be
        # replayed on every run. Keep None distinct from [] so that an
        # explicit empty script is honoured.
        self._commands = list(commands) if commands is not None else None

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Step *session* through the walkthrough.

        Stops early once ``session.done`` is set or ``session.steps_taken``
        reaches *max_steps*.
        """
        # Use `is None`, not truthiness: `[]` is a valid (empty) script and
        # must not silently fall back to the solver policy.
        source = (
            self._commands
            if self._commands is not None
            else session.compiled.solver_policy
        )
        # Iterate over a copy so the loop is unaffected if the underlying
        # sequence changes while stepping.
        for command in list(source):
            if session.done or session.steps_taken >= max_steps:
                return
            session.step(command)
class CommandSequenceRunner:
    """Replay a fixed, caller-supplied list of commands against a session."""

    def __init__(self, commands: Iterable[str]) -> None:
        # Materialise once so one-shot iterables (generators) can be replayed
        # on every call to run().
        self._commands = list(commands)

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Issue each stored command in order.

        Stops before the next command once the session is done or the step
        budget (*max_steps*) has been used up.
        """
        for cmd in self._commands:
            exhausted = session.done or session.steps_taken >= max_steps
            if exhausted:
                break
            session.step(cmd)
class RandomAdmissibleRunner:
    """Baseline runner that picks a random admissible command at every step.

    Args:
        seed: Seed for a private ``random.Random`` instance. A fixed seed
            yields a reproducible sequence of choices without touching the
            global ``random`` state.
    """

    def __init__(self, seed: int | None = None) -> None:
        self._rng = random.Random(seed)

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Take random admissible steps until the session ends.

        Also stops when the step budget is spent or the session offers no
        admissible commands.
        """
        rng = self._rng
        while True:
            if session.done or session.steps_taken >= max_steps:
                break
            candidates = session.available_commands()
            if not candidates:
                break
            session.step(rng.choice(candidates))
class ManualRunner:
    """Interactive runner: read commands from stdin and print the results.

    Typing ``quit`` or ``exit``, or closing stdin (Ctrl-D / end of a piped
    input), ends the run early.
    """

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Prompt for commands until the session is done or *max_steps* is hit.

        A ``DMInterfaceError`` raised by ``session.step`` is reported as
        unrecognised input. The loop then continues.
        """
        print(session.current_feedback())
        while not session.done and session.steps_taken < max_steps:
            print()
            print(f"Step {session.steps_taken + 1}/{max_steps}")
            try:
                command = input("> ").strip()
            except EOFError:
                # stdin was closed: end the run like "quit" instead of
                # crashing with a traceback. The newline keeps the shell
                # prompt off the "> " line.
                print()
                return
            if command in {"quit", "exit"}:
                return
            try:
                turn = session.step(command)
            except DMInterfaceError:
                print("I'm not sure what you mean. Try rephrasing that command.")
            else:
                print(turn.observation)
            self._print_admissible(session)

    @staticmethod
    def _print_admissible(session: EpisodeSession) -> None:
        """Print the session's admissible commands, if there are any."""
        # Query once so the emptiness check and the printed list agree.
        options = session.available_commands()
        if options:
            print("Admissible:", ", ".join(options))