Spaces:
Runtime error
Runtime error
File size: 2,449 Bytes
2803d7e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | from __future__ import annotations
import random
from typing import Iterable, Protocol, TYPE_CHECKING
from .base import DMInterfaceError
if TYPE_CHECKING:
from .session import EpisodeSession
class EpisodeRunner(Protocol):
    """Structural type for objects that can play out an episode.

    Any class exposing a compatible ``run`` method satisfies this protocol;
    explicit inheritance is not required.
    """

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Drive ``session`` forward until the runner decides to stop.

        The concrete runners in this module stop once ``session.done`` is set
        or ``session.steps_taken`` reaches ``max_steps``.
        """
        ...
class WalkthroughRunner:
    """Replay a scripted walkthrough against a session.

    Args:
        commands: Commands to replay. When ``None`` (the default), the
            session's ``compiled.solver_policy`` is replayed instead. An
            explicit empty iterable replays nothing.
    """

    def __init__(self, commands: Iterable[str] | None = None) -> None:
        # Keep ``None`` distinct from an empty list: ``None`` selects the
        # solver policy, while ``[]`` means "no commands".
        self._commands = list(commands) if commands is not None else None

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Step through the walkthrough.

        Stops when the commands run out, ``session.done`` becomes true, or
        ``session.steps_taken`` reaches ``max_steps``.
        """
        # Use an explicit ``is not None`` test. A truthiness test (``or``)
        # would treat an explicitly empty command list as "use the solver
        # policy".
        source = (
            self._commands
            if self._commands is not None
            else session.compiled.solver_policy
        )
        # Iterate over a snapshot so the sequence cannot change underneath
        # the loop while the session is being stepped.
        for command in list(source):
            if session.done or session.steps_taken >= max_steps:
                return
            session.step(command)
class CommandSequenceRunner:
    """Replay a fixed, caller-supplied sequence of commands against a session."""

    def __init__(self, commands: Iterable[str]) -> None:
        # Materialise the iterable once so the same runner can be reused for
        # several sessions (a generator would otherwise be exhausted).
        self._commands = list(commands)

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Feed the stored commands to ``session`` in order.

        Replay halts early once ``session.done`` is true or
        ``session.steps_taken`` reaches ``max_steps``.
        """
        for cmd in self._commands:
            if not session.done and session.steps_taken < max_steps:
                session.step(cmd)
                continue
            break
class RandomAdmissibleRunner:
    """Play an episode by picking uniformly among the currently admissible commands.

    Args:
        seed: Seed for a private ``random.Random`` instance. Passing the same
            seed makes the sequence of choices reproducible.
    """

    def __init__(self, seed: int | None = None) -> None:
        self._rng = random.Random(seed)

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Take random admissible steps.

        Stops when the session is done, the step budget is used up, or the
        session reports no admissible commands.
        """
        rng = self._rng
        while True:
            if session.done or session.steps_taken >= max_steps:
                break
            candidates = session.available_commands()
            if not candidates:
                # Nothing legal to try; stop instead of spinning forever.
                break
            session.step(rng.choice(candidates))
class ManualRunner:
    """Drive a session interactively from the terminal.

    Each turn prompts for a command with ``input()``. Typing ``quit`` or
    ``exit``, or signalling end of input (e.g. Ctrl-D), ends the loop early.
    """

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Run an interactive loop until the session ends or ``max_steps`` is reached.

        Commands rejected by the session with ``DMInterfaceError`` print a
        hint and re-prompt; the admissible commands, if any, are listed after
        every turn.
        """
        print(session.current_feedback())
        while not session.done and session.steps_taken < max_steps:
            print()
            print(f"Step {session.steps_taken + 1}/{max_steps}")
            try:
                command = input("> ").strip()
            except EOFError:
                # Closed stdin / Ctrl-D: leave the loop the same way "quit"
                # does instead of crashing with a traceback.
                print()
                return
            if command in {"quit", "exit"}:
                return
            try:
                turn = session.step(command)
            except DMInterfaceError:
                print("I'm not sure what you mean. Try rephrasing that command.")
            else:
                print(turn.observation)
            self._print_admissible(session)

    @staticmethod
    def _print_admissible(session: EpisodeSession) -> None:
        """Print the session's admissible commands, if it reports any."""
        # Query once. The original code called available_commands() twice
        # per print.
        options = session.available_commands()
        if options:
            print("Admissible:", ", ".join(options))
|