Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| import textwrap | |
| from collections import deque | |
| from typing import TYPE_CHECKING, Any, Callable | |
| import textworld | |
| from textworld.core import EnvInfos, GameState | |
| from .base import INVENTORY_ID, normalize_answer_text, suppress_unsupported_game_warning | |
| from .interface import InterfaceAdapter, SimpleInterfaceAdapter | |
| from .schema import CompiledWorld, Turn | |
if TYPE_CHECKING:
    # Callback type for the optional per-turn listener: it is invoked with the
    # session and the Turn that was just recorded. The alias exists only for
    # type checkers and is not defined at runtime.
    TurnListener = Callable[["EpisodeSession", Turn], None]
class EpisodeSession:
    """One playthrough of a compiled world, layered over a TextWorld environment.

    Commands beginning with ``read``, ``talk``, ``use``, ``combine`` or
    ``give``, as well as ``submit`` answers, are resolved by the session itself
    against the compiled world. Every other command is forwarded to the
    TextWorld environment. Each handled command is appended to ``transcript``
    as a ``Turn`` and passed to the optional ``turn_listener``.
    """

    # Sentinel meaning "no adapter supplied". A real adapter is created per
    # session instead of once at definition time, so sessions never share one.
    _DEFAULT_ADAPTER: Any = object()

    # Command prefixes the session resolves itself instead of sending to TextWorld.
    _WRAPPER_PREFIXES = ("read ", "talk ", "use ", "combine ", "give ")

    def __init__(
        self,
        compiled: CompiledWorld,
        interface_adapter: InterfaceAdapter = _DEFAULT_ADAPTER,
        turn_listener: "TurnListener | None" = None,
    ) -> None:
        """Start the TextWorld game for ``compiled`` and initialise tracking state.

        Args:
            compiled: The compiled world; its ``game_file`` is started in TextWorld.
            interface_adapter: Translates raw commands and renders observations.
                When omitted, a new ``SimpleInterfaceAdapter`` is created.
            turn_listener: Optional callback invoked as ``listener(session, turn)``
                after every recorded turn.

        Raises:
            ValueError: If ``interface_adapter`` is explicitly ``None``.
        """
        if interface_adapter is None:
            raise ValueError("interface_adapter must not be None.")
        if interface_adapter is self._DEFAULT_ADAPTER:
            interface_adapter = SimpleInterfaceAdapter()
        self.compiled = compiled
        self.interface_adapter = interface_adapter
        self.turn_listener = turn_listener
        with suppress_unsupported_game_warning():
            self.env = textworld.start(str(compiled.game_file), request_infos=self._requested_infos())
            self.state = self.env.reset()
        self._closed = False
        self.done = False
        self.player_won = False
        self.steps_taken = 0
        self.invalid_command_count = 0
        self.wrong_submit_count = 0
        self.used_items: set[str] = set()
        self.discovered_clues: set[str] = set()
        self.consulted_npcs: set[str] = set()
        self.traded_npcs: set[str] = set()
        self.prepared_readables: set[str] = set()
        self.completed_recipe_outputs: set[str] = set()
        self.completed_use_targets: set[str] = set()
        self.unlocked_doors: set[str] = set()
        self.consulted_guardian = False
        # Readables that only appear once a use-effect reveals them start hidden.
        self.hidden_readables = {
            effect.reveals_readable_id for effect in compiled.use_effects.values() if effect.reveals_readable_id
        }
        self.revealed_readables = {
            node.id for node in compiled.world.nodes if node.type == "readable" and node.id not in self.hidden_readables
        }
        self.item_locations = dict(compiled.item_start_locations)
        self.inventory = {item_id for item_id, location in self.item_locations.items() if location == INVENTORY_ID}
        self.open_nodes = {
            node.id for node in compiled.world.nodes if node.type in {"container", "door"} and getattr(node, "open", False)
        }
        self.locked_nodes = {
            node.id for node in compiled.world.nodes if node.type in {"container", "door"} and getattr(node, "locked", False)
        }
        self.current_room_id = compiled.world.meta.start_node_id
        self.visited_nodes: set[str] = {self.current_room_id}
        self.transcript: list[Turn] = []
        self.recent_normalized_commands: deque[str] = deque(maxlen=3)
        self._node_by_id = {node.id: node for node in compiled.world.nodes}
        self._label_by_id = {node.id: node.label for node in compiled.world.nodes}
        self._label_by_id.update({item.id: item.label for item in compiled.world.items})
        # Reverse lookup: command-safe item name -> item id.
        self._item_name_to_id = {name: item_id for item_id, name in compiled.item_command_names.items()}
        self.last_state_fingerprint = self.state_fingerprint()

    @staticmethod
    def _requested_infos() -> EnvInfos:
        """Return the ``EnvInfos`` request passed to ``textworld.start``.

        Declared static because it takes no instance state but is called as
        ``self._requested_infos()``.
        """
        return EnvInfos(
            feedback=True,
            description=True,
            inventory=True,
            location=True,
            facts=False,
            won=True,
            lost=True,
            score=True,
            moves=True,
            last_action=True,
            last_command=True,
            admissible_commands=True,
            policy_commands=True,
            extras=["walkthrough"],
        )

    def available_commands(self) -> list[str]:
        """Return the sorted union of TextWorld's admissible commands and the custom ones."""
        commands = set(self.state.admissible_commands or [])
        commands.update(self._custom_commands())
        return sorted(commands)

    def current_feedback(self) -> str:
        """Render the current state's feedback through the interface adapter."""
        return self.interface_adapter.render_observation(self.state.feedback or "", self.state, self)

    def state_fingerprint(self) -> str:
        """Return a JSON string summarising the tracked progress.

        Covers the current room, inventory, discovered clues, opened nodes,
        completed trades, completed use targets and recipe outputs. Sets are
        sorted and keys are sorted so equal states give equal strings.
        """
        return json.dumps(
            {
                "room": self.current_room_id,
                "inventory": sorted(self.inventory),
                "clues": sorted(self.discovered_clues),
                "opened": sorted(self.open_nodes),
                "traded": sorted(self.traded_npcs),
                "use_targets": sorted(self.completed_use_targets),
                "recipe_outputs": sorted(self.completed_recipe_outputs),
            },
            sort_keys=True,
        )

    def node_id_for_command_name(self, command_name: str, node_types: set[str] | None = None) -> str | None:
        """Return the id of the first world node whose command name matches.

        Args:
            command_name: The command-safe name to look up.
            node_types: When given, only nodes whose ``type`` is in this set match.

        Returns:
            The matching node id, or ``None`` when no node matches.
        """
        for node in self.compiled.world.nodes:
            safe_name = self.compiled.node_command_names.get(node.id)
            if safe_name != command_name:
                continue
            if node_types is None or node.type in node_types:
                return node.id
        return None

    def step(self, raw_command: str) -> Turn:
        """Process one player command and return the recorded ``Turn``.

        The command is translated by the interface adapter, lower-cased and
        stripped. It is then handled as a ``submit`` attempt, a wrapper
        command, or a TextWorld command, in that order of precedence.

        Raises:
            RuntimeError: If the episode is already complete.
        """
        if self.done:
            raise RuntimeError("Episode is already complete.")
        lowered = self.interface_adapter.translate_command(raw_command, self).lower().strip()
        if turn := self._handle_submit(raw_command, lowered):
            return turn
        if self._is_wrapper_command(lowered):
            return self._step_wrapper(raw_command, lowered)
        return self._step_env(raw_command, lowered)

    def _handle_submit(self, raw_command: str, lowered: str) -> Turn | None:
        """Handle a ``submit <answer>`` command.

        Returns ``None`` when ``lowered`` is not a submit command. Otherwise a
        turn is recorded. The answer is rejected when the player is not in the
        guardian's room or has not talked to the guardian, when not every clue
        has been discovered, or when the answer is wrong. A correct answer ends
        the episode as won.
        """
        if not lowered.startswith("submit "):
            return None
        answer = normalize_answer_text(lowered[7:])
        if self.current_room_id != self.compiled.guardian_room_id or self.compiled.guardian_id not in self.consulted_npcs:
            return self._wrapper_only_turn(
                raw_command,
                lowered,
                "The guardian has not asked for your answer yet.",
                {"wrapper": "submit_rejected", "reason": "guardian_not_ready"},
            )
        required_clues = set(self.compiled.clue_text_by_id)
        if self.discovered_clues != required_clues:
            return self._wrapper_only_turn(
                raw_command,
                lowered,
                "The guardian waits. You have not gathered enough evidence yet.",
                {
                    "wrapper": "submit_rejected",
                    "reason": "missing_clues",
                    "missing_clues": sorted(required_clues - self.discovered_clues),
                },
            )
        if answer != self.compiled.correct_answer_normalized:
            self.wrong_submit_count += 1
            return self._wrapper_only_turn(
                raw_command,
                lowered,
                "The guardian shakes their head. That answer is wrong.",
                {"wrapper": "submit_rejected", "reason": "wrong_answer", "submitted": answer},
            )
        self.steps_taken += 1
        self.done = True
        self.player_won = True
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=self.compiled.correct_submit_command,
            observation="The guardian weighs your answer, then nods.\n\nThe dungeon yields. You solved it.",
            game_state_delta={"wrapper": "submit_forwarded", "won": True, "location": self.current_room_id},
        )
        return self._record_turn(turn)

    def _step_env(self, raw_command: str, lowered: str) -> Turn:
        """Forward ``lowered`` to the TextWorld environment and record the turn.

        A command counts as successful only if it was in the admissible
        commands of the state *before* the step. Only successful commands
        update the session's own tracking state.
        """
        previous = self.state
        admissible = set(previous.admissible_commands or [])
        self.state, _, env_done = self.env.step(lowered)
        self.steps_taken += 1
        succeeded = lowered in admissible
        if succeeded:
            self._apply_env_side_effects(lowered)
        else:
            self.invalid_command_count += 1
        self.done = bool(env_done or self.state.won)
        observation = self.interface_adapter.render_observation(self.state.feedback or "", self.state, self)
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=lowered,
            observation=observation,
            game_state_delta=self._compute_delta(previous, self.state, succeeded, self.current_room_id),
        )
        return self._record_turn(turn)

    def _step_wrapper(self, raw_command: str, lowered: str) -> Turn:
        """Resolve a wrapper command inside the session and record the turn."""
        observation, delta = self._apply_wrapper_command(lowered)
        self.steps_taken += 1
        if delta.get("succeeded") is False:
            self.invalid_command_count += 1
        delta.setdefault("location", self.current_room_id)
        rendered = self.interface_adapter.render_observation(observation, self.state, self)
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=lowered,
            observation=rendered,
            game_state_delta=delta,
        )
        return self._record_turn(turn)

    def _apply_env_side_effects(self, command: str) -> None:
        """Mirror a successful TextWorld command in the session's own state.

        Handles ``go``, ``open``, ``unlock ... with ...`` and ``take``; any
        other command leaves the tracking state unchanged.
        """
        if command.startswith("go "):
            direction = command[3:].strip()
            edge = self.compiled.room_edges_by_direction.get((self.current_room_id, direction))
            if edge is not None:
                self.current_room_id = edge.to_node_id
                self.visited_nodes.add(edge.to_node_id)
            return
        if command.startswith("open "):
            node_id = self.node_id_for_command_name(command[5:].strip(), node_types={"container", "door"})
            if node_id:
                self.open_nodes.add(node_id)
                self.visited_nodes.add(node_id)
            return
        if command.startswith("unlock "):
            # partition never raises, unlike unpacking a split that found no separator.
            target_name, separator, key_name = command[7:].partition(" with ")
            if separator:
                target_id = self.node_id_for_command_name(target_name.strip(), node_types={"container", "door"})
                if target_id:
                    self.locked_nodes.discard(target_id)
                    if self._node_by_id[target_id].type == "door":
                        self.unlocked_doors.add(target_id)
                    self.visited_nodes.add(target_id)
                self._mark_item_by_name(key_name.strip())
                return
        if command.startswith("take "):
            # "take X from Y" and "take X" both name the item first.
            item_name = command[5:].split(" from ", 1)[0].strip()
            item_id = self._item_name_to_id.get(item_name)
            if item_id:
                self.inventory.add(item_id)
                self.item_locations[item_id] = INVENTORY_ID
                self.used_items.add(item_id)
                self.visited_nodes.add(item_id)

    def _apply_wrapper_command(self, command: str) -> tuple[str, dict[str, Any]]:
        """Dispatch a wrapper command to its handler.

        Returns:
            ``(observation, delta)`` as produced by ``_success`` or ``_fail``.
            Incomplete ``use``/``combine``/``give`` commands give a failed
            result rather than an exception.

        Raises:
            RuntimeError: If ``command`` has none of the wrapper prefixes.
        """
        if command.startswith("read "):
            return self._apply_read(command)
        if command.startswith("talk "):
            return self._apply_talk(command)
        if command.startswith("use "):
            return self._apply_use(command)
        if command.startswith("combine "):
            return self._apply_combine(command)
        if command.startswith("give "):
            return self._apply_give(command)
        raise RuntimeError(f"Unsupported wrapper command '{command}'.")

    def _clue_observation(self, description: str, clue_id: str) -> str:
        """Return ``description`` followed by the quoted clue text on the next line."""
        return f'{description}\n"{self.compiled.clue_text_by_id[clue_id]}"'.strip()

    def _apply_read(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``read <readable>``: on success, record the readable's clue.

        Fails when the readable is unknown or not yet revealed, is not in the
        current room, or requires an item and has not been prepared.
        """
        readable_id = self.node_id_for_command_name(command[5:].strip(), node_types={"readable"})
        if not readable_id or readable_id not in self.revealed_readables:
            return self._fail("You can't read that right now.", command)
        node = self._node_by_id[readable_id]
        if node.parent_id != self.current_room_id:
            return self._fail("You are too far away to read that.", command)
        if node.requires_item_id and readable_id not in self.prepared_readables:
            return self._fail("You still need the right tool before the text becomes legible.", command)
        clue_id = self.compiled.readable_clue_by_id[readable_id]
        self.discovered_clues.add(clue_id)
        self.visited_nodes.add(readable_id)
        return self._success(self._clue_observation(node.description, clue_id), command)

    def _apply_talk(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``talk <npc>``: on success, mark the NPC as consulted.

        Fails when no NPC has that command name or the NPC is not in the
        current room. Talking to the guardian also sets ``consulted_guardian``.
        """
        npc_id = self.node_id_for_command_name(command[5:].strip(), node_types={"npc"})
        if not npc_id:
            return self._fail("You can't talk to that right now.", command)
        node = self._node_by_id[npc_id]
        if node.parent_id != self.current_room_id:
            return self._fail("You are too far away to talk to that.", command)
        self.consulted_npcs.add(npc_id)
        if npc_id == self.compiled.guardian_id:
            self.consulted_guardian = True
        self.visited_nodes.add(npc_id)
        return self._success(node.description, command)

    def _apply_use(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``use <item> on <target>`` for readable and fixture targets.

        Fails when the command is incomplete, the item is not in the
        inventory, the target is unknown or not in the current room, or the
        target's use-effect does not require this item. Otherwise the effect
        is applied: it may consume the item, then yields a clue, reveals a
        readable, or reveals an item, checked in that order.
        """
        item_name, separator, target_name = command[4:].partition(" on ")
        if not separator:
            return self._fail("Use what on what? Try 'use <item> on <target>'.", command)
        item_id = self._item_name_to_id.get(item_name.strip())
        target_id = self.node_id_for_command_name(target_name.strip(), node_types={"readable", "fixture"})
        if not item_id or item_id not in self.inventory:
            return self._fail("You don't have the item needed for that.", command)
        if not target_id:
            return self._fail("You can't use that here.", command)
        target = self._node_by_id[target_id]
        if target.parent_id != self.current_room_id:
            return self._fail("That target is not within reach.", command)
        effect = self.compiled.use_effects.get(target_id)
        if effect is None or effect.required_item_id != item_id:
            return self._fail("That item doesn't seem to work there.", command)
        if effect.consumes_item:
            self.inventory.discard(item_id)
            self.item_locations[item_id] = None
        self.used_items.add(item_id)
        self.visited_nodes.add(target_id)
        self.completed_use_targets.add(target_id)
        if effect.clue_id:
            self.prepared_readables.add(target_id)
            self.discovered_clues.add(effect.clue_id)
            return self._success(self._clue_observation(target.description, effect.clue_id), command)
        if effect.reveals_readable_id:
            self.revealed_readables.add(effect.reveals_readable_id)
            return self._success(f"The {self._label_by_id[effect.reveals_readable_id]} is revealed.", command)
        if effect.reveals_item_id:
            self.item_locations[effect.reveals_item_id] = self.current_room_id
            return self._success(f"The {self._label_by_id[effect.reveals_item_id]} is revealed.", command)
        return self._fail("Nothing happens.", command)

    def _apply_combine(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``combine <item> with <item>`` using the compiled recipe map.

        Fails when the command is incomplete, either item is unknown or not in
        the inventory, or no recipe exists for the pair. On success both
        inputs leave the inventory and the recipe output is added.
        """
        item_a_name, separator, item_b_name = command[8:].partition(" with ")
        if not separator:
            return self._fail("Combine what with what? Try 'combine <item> with <item>'.", command)
        item_a_id = self._item_name_to_id.get(item_a_name.strip())
        item_b_id = self._item_name_to_id.get(item_b_name.strip())
        if not item_a_id or not item_b_id or item_a_id not in self.inventory or item_b_id not in self.inventory:
            return self._fail("You do not have both pieces required to combine those.", command)
        output_id = self.compiled.recipe_map.get(frozenset({item_a_id, item_b_id}))
        if not output_id:
            return self._fail("Those items do not fit together.", command)
        self.inventory.discard(item_a_id)
        self.inventory.discard(item_b_id)
        self.item_locations[item_a_id] = None
        self.item_locations[item_b_id] = None
        self.inventory.add(output_id)
        self.item_locations[output_id] = INVENTORY_ID
        self.used_items.update({item_a_id, item_b_id, output_id})
        self.completed_recipe_outputs.add(output_id)
        self.visited_nodes.add(output_id)
        return self._success(f"You assemble the {self._label_by_id[output_id]}.", command)

    def _apply_give(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``give <item> to <npc>`` using the compiled NPC trade map.

        Fails when the command is incomplete, the item is not in the
        inventory, the NPC is unknown or not in the current room, the NPC has
        no trade for that item, or the trade was already completed. On success
        the item is removed and the trade yields an item or, failing that, a
        clue.
        """
        item_name, separator, npc_name = command[5:].partition(" to ")
        if not separator:
            return self._fail("Give what to whom? Try 'give <item> to <person>'.", command)
        item_id = self._item_name_to_id.get(item_name.strip())
        npc_id = self.node_id_for_command_name(npc_name.strip(), node_types={"npc"})
        if not item_id or item_id not in self.inventory:
            return self._fail("You do not have that item to give.", command)
        if not npc_id:
            return self._fail("There is no one here by that name.", command)
        npc = self._node_by_id[npc_id]
        if npc.parent_id != self.current_room_id:
            return self._fail("That person is not here.", command)
        trade = self.compiled.npc_trade_map.get(npc_id)
        if trade is None or trade.required_item_id != item_id:
            return self._fail("They are not interested in that item.", command)
        if npc_id in self.traded_npcs:
            return self._fail("That trade has already been completed.", command)
        self.inventory.discard(item_id)
        self.item_locations[item_id] = None
        self.used_items.add(item_id)
        self.traded_npcs.add(npc_id)
        if trade.gives_item_id:
            self.inventory.add(trade.gives_item_id)
            self.item_locations[trade.gives_item_id] = INVENTORY_ID
            self.used_items.add(trade.gives_item_id)
            return self._success(f"You receive the {self._label_by_id[trade.gives_item_id]}.", command)
        if trade.gives_clue_id:
            self.discovered_clues.add(trade.gives_clue_id)
            return self._success(f'"{self.compiled.clue_text_by_id[trade.gives_clue_id]}"', command)
        return self._fail("Nothing comes of the trade.", command)

    def _custom_commands(self) -> set[str]:
        """Return the wrapper commands that are currently worth offering.

        Covers ``talk``/``give`` for NPCs in the room, ``read`` for revealed
        and legible readables in the room, ``use`` for fixtures and readables
        in the room whose required item is held, and ``combine`` (in both
        orders) for two-item recipes whose inputs are both held.
        """
        node_names = self.compiled.node_command_names
        item_names = self.compiled.item_command_names
        commands: set[str] = set()
        for node in self.compiled.world.nodes:
            if node.parent_id != self.current_room_id:
                continue
            if node.type == "npc":
                commands.add(f"talk {node_names[node.id]}")
                trade = self.compiled.npc_trade_map.get(node.id)
                if trade and node.id not in self.traded_npcs and trade.required_item_id in self.inventory:
                    commands.add(f"give {item_names[trade.required_item_id]} to {node_names[node.id]}")
            elif node.type == "readable":
                if node.id in self.revealed_readables and (
                    not node.requires_item_id or node.id in self.prepared_readables
                ):
                    commands.add(f"read {node_names[node.id]}")
            elif node.type == "fixture":
                effect = self.compiled.use_effects.get(node.id)
                if effect and effect.required_item_id in self.inventory:
                    commands.add(f"use {item_names[effect.required_item_id]} on {node_names[node.id]}")
        # Readables can be use-targets too; those are keyed directly in use_effects.
        for readable_id, effect in self.compiled.use_effects.items():
            node = self._node_by_id.get(readable_id)
            if (
                node
                and node.type == "readable"
                and node.parent_id == self.current_room_id
                and effect.required_item_id in self.inventory
            ):
                commands.add(f"use {item_names[effect.required_item_id]} on {node_names[readable_id]}")
        for recipe_inputs in self.compiled.recipe_map:
            item_ids = sorted(recipe_inputs)
            # "combine A with B" names exactly two items; skip any other recipe
            # size instead of indexing past the end of the list.
            if len(item_ids) != 2:
                continue
            if all(item_id in self.inventory for item_id in item_ids):
                first, second = item_names[item_ids[0]], item_names[item_ids[1]]
                commands.add(f"combine {first} with {second}")
                commands.add(f"combine {second} with {first}")
        return commands

    def _is_wrapper_command(self, command: str) -> bool:
        """Return whether ``command`` starts with one of the wrapper prefixes."""
        return command.startswith(self._WRAPPER_PREFIXES)

    def _mark_item_by_name(self, name: str) -> None:
        """Add the item with command name ``name`` to ``used_items`` if it is known."""
        item_id = self._item_name_to_id.get(name)
        if item_id:
            self.used_items.add(item_id)

    def _success(self, observation: str, command: str) -> tuple[str, dict[str, Any]]:
        """Pair ``observation`` with a delta marking the wrapper command as succeeded."""
        return observation, {"wrapper": "custom", "command": command, "succeeded": True, "location": self.current_room_id}

    def _fail(self, observation: str, command: str) -> tuple[str, dict[str, Any]]:
        """Pair ``observation`` with a delta marking the wrapper command as failed."""
        return observation, {"wrapper": "custom", "command": command, "succeeded": False, "location": self.current_room_id}

    @staticmethod
    def _compute_delta(previous: GameState, current: GameState, succeeded: bool, fallback_location: str | None) -> dict[str, Any]:
        """Build the state-delta dictionary for a forwarded TextWorld command.

        Declared static because it takes no instance state but is called as
        ``self._compute_delta(...)``. ``previous`` is accepted but not used;
        the fact lists are always empty. ``fallback_location`` is reported
        when ``current.location`` is falsy.
        """
        return {
            "added_facts": [],
            "removed_facts": [],
            "location": current.location or fallback_location,
            "score": current.score,
            "won": current.won,
            "lost": current.lost,
            "succeeded": succeeded,
        }

    def _wrapper_only_turn(
        self,
        raw_command: str,
        translated: str,
        observation: str,
        delta: dict[str, Any],
    ) -> Turn:
        """Count a step and record a turn that never reaches TextWorld.

        ``observation`` is stored as given (it is not passed through the
        interface adapter) and ``delta`` gets a ``location`` entry if missing.
        """
        self.steps_taken += 1
        delta.setdefault("location", self.current_room_id)
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=translated,
            observation=observation,
            game_state_delta=delta,
        )
        return self._record_turn(turn)

    def _record_turn(self, turn: Turn) -> Turn:
        """Append ``turn`` to the transcript, refresh the fingerprint, notify the listener."""
        self.transcript.append(turn)
        self.last_state_fingerprint = self.state_fingerprint()
        if self.turn_listener is not None:
            self.turn_listener(self, turn)
        return turn

    def close(self) -> None:
        """Close the underlying environment once; later calls do nothing."""
        if self._closed:
            return
        close = getattr(self.env, "close", None)
        if callable(close):
            close()
        self._closed = True