File size: 2,449 Bytes
2803d7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from __future__ import annotations

import random
from typing import Iterable, Protocol, TYPE_CHECKING

from .base import DMInterfaceError

if TYPE_CHECKING:
    from .session import EpisodeSession


class EpisodeRunner(Protocol):
    """Structural type for objects that can drive an episode session.

    Anything exposing a ``run(session, max_steps)`` method with this
    signature satisfies the protocol; no inheritance is required.
    """

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Drive ``session``, given ``max_steps`` as the caller's step limit."""


class WalkthroughRunner:
    """Replay a command list, defaulting to the session's solver policy."""

    def __init__(self, commands: Iterable[str] | None = None) -> None:
        """Store an optional command override.

        Args:
            commands: Commands to replay. ``None`` (the default) defers to
                ``session.compiled.solver_policy`` when ``run`` is called.
                Any other iterable, including an empty one, is copied into
                a list and replayed exactly as given.
        """
        # Copy into a list so a one-shot iterable can be replayed on every run.
        self._commands = list(commands) if commands is not None else None

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Step ``session`` through the commands, in order.

        Stops early, before issuing the next command, as soon as
        ``session.done`` is true or ``session.steps_taken`` has reached
        ``max_steps``.

        Args:
            session: The episode to drive.
            max_steps: Upper bound on ``session.steps_taken``.
        """
        # Test against None rather than truthiness: an explicitly empty
        # override means "issue no commands", not "use the solver policy".
        if self._commands is None:
            commands = list(session.compiled.solver_policy)
        else:
            commands = list(self._commands)
        for command in commands:
            if session.done or session.steps_taken >= max_steps:
                return
            session.step(command)


class CommandSequenceRunner:
    """Feed a caller-supplied sequence of commands to a session, in order."""

    def __init__(self, commands: Iterable[str]) -> None:
        """Copy ``commands`` into a list kept for later ``run`` calls."""
        self._commands = [*commands]

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Issue the stored commands one by one.

        Before each command the session is checked: the run ends once
        ``session.done`` is true or ``session.steps_taken`` has reached
        ``max_steps``. It also ends when the commands run out.
        """
        for queued in self._commands:
            if session.done:
                break
            if session.steps_taken >= max_steps:
                break
            session.step(queued)


class RandomAdmissibleRunner:
    """Play randomly chosen commands from the session's admissible set."""

    def __init__(self, seed: int | None = None) -> None:
        """Create a private ``random.Random`` generator initialised with ``seed``."""
        generator = random.Random(seed)
        self._rng = generator

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Step the session with random admissible commands.

        Loops while the session is not done and ``session.steps_taken`` is
        below ``max_steps``. Each iteration fetches
        ``session.available_commands()`` and steps with one entry chosen by
        the runner's RNG. The run ends early if nothing is available.
        """
        while not session.done and session.steps_taken < max_steps:
            candidates = session.available_commands()
            if candidates:
                session.step(self._rng.choice(candidates))
                continue
            # No admissible command is on offer, so there is nothing to pick.
            break


class ManualRunner:
    """Interactive runner: reads commands via ``input`` and prints feedback."""

    def run(self, session: EpisodeSession, max_steps: int) -> None:
        """Run a read-step-print loop against ``session``.

        Prints ``session.current_feedback()`` once, then prompts for commands
        while the session is not done and ``session.steps_taken`` is below
        ``max_steps``. The loop ends early when the user enters ``quit`` or
        ``exit``, or when standard input reaches end-of-file.

        A command rejected with ``DMInterfaceError`` prints a hint instead of
        an observation. Either way the currently available commands, if any,
        are listed afterwards.

        Args:
            session: The episode to drive.
            max_steps: Upper bound on ``session.steps_taken``.
        """
        print(session.current_feedback())
        while not session.done and session.steps_taken < max_steps:
            print()
            print(f"Step {session.steps_taken + 1}/{max_steps}")
            try:
                command = input("> ").strip()
            except EOFError:
                # stdin was closed (Ctrl-D or an exhausted pipe). Treat it
                # like "quit" rather than crashing with a traceback.
                print()  # finish the dangling prompt line
                return
            if command in {"quit", "exit"}:
                return
            try:
                turn = session.step(command)
            except DMInterfaceError:
                print("I'm not sure what you mean. Try rephrasing that command.")
            else:
                print(turn.observation)
            self._print_admissible(session)

    @staticmethod
    def _print_admissible(session: EpisodeSession) -> None:
        """Print the session's available commands on one line, if any exist."""
        # Query once so the emptiness check and the printed list agree.
        options = session.available_commands()
        if options:
            print("Admissible:", ", ".join(options))