from __future__ import annotations

import json
import textwrap
from collections import deque
from typing import TYPE_CHECKING, Any, Callable

import textworld
from textworld.core import EnvInfos, GameState

from .base import INVENTORY_ID, normalize_answer_text, suppress_unsupported_game_warning
from .interface import InterfaceAdapter, SimpleInterfaceAdapter
from .schema import CompiledWorld, Turn

if TYPE_CHECKING:
    TurnListener = Callable[["EpisodeSession", Turn], None]


class EpisodeSession:
    """One playthrough of a compiled world on top of a TextWorld environment.

    Commands are handled in one of three ways (see ``step``):

    * ``submit <answer>`` is resolved entirely by the session,
    * ``read``/``talk``/``use``/``combine``/``give`` commands ("wrapper"
      commands) are resolved by the session without touching the engine,
    * everything else is forwarded to the TextWorld environment.

    Every handled command produces a ``Turn`` that is appended to
    ``transcript`` and passed to the optional ``turn_listener``.
    """

    # (verb prefix, required connective or None) for every command that the
    # session resolves itself.  Shared by ``_is_wrapper_command`` so that the
    # classifier can never accept a command that ``_apply_wrapper_command``
    # would refuse to dispatch.
    _WRAPPER_PATTERNS: tuple[tuple[str, str | None], ...] = (
        ("read ", None),
        ("talk ", None),
        ("use ", " on "),
        ("combine ", " with "),
        ("give ", " to "),
    )

    def __init__(
        self,
        compiled: CompiledWorld,
        # NOTE: this default is evaluated once, when the class body is executed,
        # so every session that omits the argument shares the same adapter object.
        interface_adapter: InterfaceAdapter = SimpleInterfaceAdapter(),
        turn_listener: "TurnListener | None" = None,
    ) -> None:
        """Start the TextWorld game for *compiled* and initialise session state.

        Args:
            compiled: The compiled world; supplies the game file and all lookup
                tables used by the wrapper commands.
            interface_adapter: Translates raw player commands and renders
                observations.
            turn_listener: Optional callback invoked as
                ``turn_listener(session, turn)`` after each recorded turn.

        Raises:
            ValueError: If *interface_adapter* is ``None``.
        """
        if interface_adapter is None:
            raise ValueError("interface_adapter must not be None.")
        self.compiled = compiled
        self.interface_adapter = interface_adapter
        self.turn_listener = turn_listener
        with suppress_unsupported_game_warning():
            self.env = textworld.start(str(compiled.game_file), request_infos=self._requested_infos())
            self.state = self.env.reset()
        self._closed = False
        self.done = False
        self.player_won = False
        self.steps_taken = 0
        self.invalid_command_count = 0
        self.wrong_submit_count = 0
        self.used_items: set[str] = set()
        self.discovered_clues: set[str] = set()
        self.consulted_npcs: set[str] = set()
        self.traded_npcs: set[str] = set()
        self.prepared_readables: set[str] = set()
        self.completed_recipe_outputs: set[str] = set()
        self.completed_use_targets: set[str] = set()
        self.unlocked_doors: set[str] = set()
        self.consulted_guardian = False
        # Readables that only appear after some use-effect fires.
        self.hidden_readables = {
            effect.reveals_readable_id
            for effect in compiled.use_effects.values()
            if effect.reveals_readable_id
        }
        # Every other readable is visible from the start.
        self.revealed_readables = {
            node.id
            for node in compiled.world.nodes
            if node.type == "readable" and node.id not in self.hidden_readables
        }
        self.item_locations = dict(compiled.item_start_locations)
        self.inventory = {
            item_id for item_id, location in self.item_locations.items() if location == INVENTORY_ID
        }
        self.open_nodes = {
            node.id
            for node in compiled.world.nodes
            if node.type in {"container", "door"} and getattr(node, "open", False)
        }
        self.locked_nodes = {
            node.id
            for node in compiled.world.nodes
            if node.type in {"container", "door"} and getattr(node, "locked", False)
        }
        self.current_room_id = compiled.world.meta.start_node_id
        self.visited_nodes: set[str] = {self.current_room_id}
        self.transcript: list[Turn] = []
        self.recent_normalized_commands: deque[str] = deque(maxlen=3)
        self._node_by_id = {node.id: node for node in compiled.world.nodes}
        self._label_by_id = {node.id: node.label for node in compiled.world.nodes}
        self._label_by_id.update({item.id: item.label for item in compiled.world.items})
        # Reverse of ``compiled.item_command_names``: command name -> item id.
        self._item_name_to_id = {name: item_id for item_id, name in compiled.item_command_names.items()}
        self.last_state_fingerprint = self.state_fingerprint()

    @staticmethod
    def _requested_infos() -> EnvInfos:
        """Return the ``EnvInfos`` request passed to ``textworld.start``."""
        return EnvInfos(
            feedback=True,
            description=True,
            inventory=True,
            location=True,
            facts=False,
            won=True,
            lost=True,
            score=True,
            moves=True,
            last_action=True,
            last_command=True,
            admissible_commands=True,
            policy_commands=True,
            extras=["walkthrough"],
        )

    def available_commands(self) -> list[str]:
        """Return the sorted union of engine-admissible and wrapper commands."""
        commands = set(self.state.admissible_commands or [])
        commands.update(self._custom_commands())
        return sorted(commands)

    def current_feedback(self) -> str:
        """Render the current engine feedback through the interface adapter."""
        return self.interface_adapter.render_observation(self.state.feedback or "", self.state, self)

    def state_fingerprint(self) -> str:
        """Return a deterministic JSON string summarising session progress."""
        return json.dumps(
            {
                "room": self.current_room_id,
                "inventory": sorted(self.inventory),
                "clues": sorted(self.discovered_clues),
                "opened": sorted(self.open_nodes),
                "traded": sorted(self.traded_npcs),
                "use_targets": sorted(self.completed_use_targets),
                "recipe_outputs": sorted(self.completed_recipe_outputs),
            },
            sort_keys=True,
        )

    def node_id_for_command_name(self, command_name: str, node_types: set[str] | None = None) -> str | None:
        """Return the id of the first node whose command name is *command_name*.

        Args:
            command_name: The name as it appears in a player command.
            node_types: If given, only nodes whose ``type`` is in this set match.

        Returns:
            The matching node id, or ``None`` when no node matches.
        """
        for node in self.compiled.world.nodes:
            safe_name = self.compiled.node_command_names.get(node.id)
            if safe_name != command_name:
                continue
            if node_types is None or node.type in node_types:
                return node.id
        return None

    def step(self, raw_command: str) -> Turn:
        """Execute one player command and return the resulting ``Turn``.

        The command is translated by the interface adapter, lower-cased and
        stripped, then routed to the submit handler, the wrapper handlers, or
        the TextWorld environment (in that order of precedence).

        Raises:
            RuntimeError: If the episode is already complete.
        """
        if self.done:
            raise RuntimeError("Episode is already complete.")

        lowered = self.interface_adapter.translate_command(raw_command, self).lower().strip()
        if turn := self._handle_submit(raw_command, lowered):
            return turn
        if self._is_wrapper_command(lowered):
            return self._step_wrapper(raw_command, lowered)
        return self._step_env(raw_command, lowered)

    def _handle_submit(self, raw_command: str, lowered: str) -> Turn | None:
        """Handle ``submit <answer>``; return ``None`` for any other command.

        A submission is rejected (still costing a step) when the player is not
        in the guardian's room or has not talked to the guardian, when the set
        of discovered clues differs from the required set, or when the
        normalised answer is wrong.  A correct submission ends the episode
        with ``player_won`` set.
        """
        if not lowered.startswith("submit "):
            return None

        answer = normalize_answer_text(lowered[7:])
        if (
            self.current_room_id != self.compiled.guardian_room_id
            or self.compiled.guardian_id not in self.consulted_npcs
        ):
            return self._wrapper_only_turn(
                raw_command,
                lowered,
                "The guardian has not asked for your answer yet.",
                {"wrapper": "submit_rejected", "reason": "guardian_not_ready"},
            )

        required_clues = set(self.compiled.clue_text_by_id)
        if self.discovered_clues != required_clues:
            return self._wrapper_only_turn(
                raw_command,
                lowered,
                "The guardian waits. You have not gathered enough evidence yet.",
                {
                    "wrapper": "submit_rejected",
                    "reason": "missing_clues",
                    "missing_clues": sorted(required_clues - self.discovered_clues),
                },
            )

        if answer != self.compiled.correct_answer_normalized:
            self.wrong_submit_count += 1
            return self._wrapper_only_turn(
                raw_command,
                lowered,
                "The guardian shakes their head. That answer is wrong.",
                {"wrapper": "submit_rejected", "reason": "wrong_answer", "submitted": answer},
            )

        self.steps_taken += 1
        self.done = True
        self.player_won = True
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=self.compiled.correct_submit_command,
            observation="The guardian weighs your answer, then nods.\n\nThe dungeon yields. You solved it.",
            game_state_delta={"wrapper": "submit_forwarded", "won": True, "location": self.current_room_id},
        )
        return self._record_turn(turn)

    def _step_env(self, raw_command: str, lowered: str) -> Turn:
        """Forward *lowered* to the TextWorld environment and record the turn.

        The command counts as successful only if it was in the admissible set
        *before* the step; only successful commands update the session's own
        bookkeeping.
        """
        previous = self.state
        admissible = set(previous.admissible_commands or [])
        self.state, _, env_done = self.env.step(lowered)
        self.steps_taken += 1

        succeeded = lowered in admissible
        if not succeeded:
            self.invalid_command_count += 1
        else:
            self._apply_env_side_effects(lowered)

        self.done = bool(env_done or self.state.won)
        observation = self.interface_adapter.render_observation(self.state.feedback or "", self.state, self)
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=lowered,
            observation=observation,
            game_state_delta=self._compute_delta(previous, self.state, succeeded, self.current_room_id),
        )
        return self._record_turn(turn)

    def _step_wrapper(self, raw_command: str, lowered: str) -> Turn:
        """Resolve a wrapper command inside the session and record the turn."""
        observation, delta = self._apply_wrapper_command(lowered)
        self.steps_taken += 1
        if delta.get("succeeded") is False:
            self.invalid_command_count += 1
        delta.setdefault("location", self.current_room_id)
        rendered = self.interface_adapter.render_observation(observation, self.state, self)
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=lowered,
            observation=rendered,
            game_state_delta=delta,
        )
        return self._record_turn(turn)

    def _apply_env_side_effects(self, command: str) -> None:
        """Mirror a successful engine command in the session's own state.

        Handles ``go``, ``open``, ``unlock ... with ...`` and ``take``; any
        other command leaves the session state untouched.
        """
        if command.startswith("go "):
            direction = command[3:].strip()
            edge = self.compiled.room_edges_by_direction.get((self.current_room_id, direction))
            if edge is not None:
                self.current_room_id = edge.to_node_id
                self.visited_nodes.add(edge.to_node_id)
            return

        if command.startswith("open "):
            node_id = self.node_id_for_command_name(command[5:].strip(), node_types={"container", "door"})
            if node_id:
                self.open_nodes.add(node_id)
                self.visited_nodes.add(node_id)
            return

        if command.startswith("unlock ") and " with " in command:
            target_name, _, key_name = command[7:].partition(" with ")
            target_id = self.node_id_for_command_name(target_name.strip(), node_types={"container", "door"})
            if target_id:
                self.locked_nodes.discard(target_id)
                if self._node_by_id[target_id].type == "door":
                    self.unlocked_doors.add(target_id)
                self.visited_nodes.add(target_id)
            self._mark_item_by_name(key_name.strip())
            return

        if command.startswith("take "):
            # "take X from Y" and plain "take X" both name the item first.
            item_name = command[5:].split(" from ", 1)[0].strip()
            item_id = self._item_name_to_id.get(item_name)
            if item_id:
                self.inventory.add(item_id)
                self.item_locations[item_id] = INVENTORY_ID
                self.used_items.add(item_id)
                self.visited_nodes.add(item_id)

    def _apply_wrapper_command(self, command: str) -> tuple[str, dict[str, Any]]:
        """Dispatch a wrapper command to its handler.

        Returns:
            ``(observation, delta)`` as produced by ``_success``/``_fail``.

        Raises:
            RuntimeError: If *command* matches none of the wrapper patterns.
                ``step`` never triggers this because ``_is_wrapper_command``
                applies the same patterns before dispatching here.
        """
        if command.startswith("read "):
            return self._apply_read(command)
        if command.startswith("talk "):
            return self._apply_talk(command)
        if command.startswith("use ") and " on " in command:
            return self._apply_use(command)
        if command.startswith("combine ") and " with " in command:
            return self._apply_combine(command)
        if command.startswith("give ") and " to " in command:
            return self._apply_give(command)
        raise RuntimeError(f"Unsupported wrapper command '{command}'.")

    def _describe_with_clue(self, description: str, clue_id: str) -> str:
        """Return *description* followed by the quoted text of *clue_id*."""
        # NOTE(review): confirm the intended separator between the description
        # and the quoted clue (blank line vs. single newline).
        return textwrap.dedent(
            f"""
            {description}

            "{self.compiled.clue_text_by_id[clue_id]}"
            """
        ).strip()

    def _apply_read(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``read <readable>``; on success the readable's clue is discovered."""
        readable_id = self.node_id_for_command_name(command[5:].strip(), node_types={"readable"})
        if not readable_id or readable_id not in self.revealed_readables:
            return self._fail("You can't read that right now.", command)

        node = self._node_by_id[readable_id]
        if node.parent_id != self.current_room_id:
            return self._fail("You are too far away to read that.", command)
        if node.requires_item_id and readable_id not in self.prepared_readables:
            return self._fail("You still need the right tool before the text becomes legible.", command)

        clue_id = self.compiled.readable_clue_by_id[readable_id]
        self.discovered_clues.add(clue_id)
        self.visited_nodes.add(readable_id)
        return self._success(self._describe_with_clue(node.description, clue_id), command)

    def _apply_talk(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``talk <npc>``; on success the NPC is marked as consulted."""
        npc_id = self.node_id_for_command_name(command[5:].strip(), node_types={"npc"})
        if not npc_id:
            return self._fail("You can't talk to that right now.", command)

        node = self._node_by_id[npc_id]
        if node.parent_id != self.current_room_id:
            return self._fail("You are too far away to talk to that.", command)

        self.consulted_npcs.add(npc_id)
        if npc_id == self.compiled.guardian_id:
            self.consulted_guardian = True
        self.visited_nodes.add(npc_id)
        return self._success(node.description, command)

    def _apply_use(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``use <item> on <target>`` for readable and fixture targets.

        On a matching use-effect the item is optionally consumed and the first
        applicable outcome is applied: a clue is discovered, a hidden readable
        is revealed, or an item is placed in the current room.
        """
        # ``partition`` (rather than a two-way unpack of ``split``) keeps input
        # such as "use on door" from raising: the connective is present in the
        # full command but not after the verb, so the target name is just empty.
        item_name, _, target_name = command[4:].partition(" on ")
        item_id = self._item_name_to_id.get(item_name.strip())
        target_id = self.node_id_for_command_name(target_name.strip(), node_types={"readable", "fixture"})
        if not item_id or item_id not in self.inventory:
            return self._fail("You don't have the item needed for that.", command)
        if not target_id:
            return self._fail("You can't use that here.", command)

        target = self._node_by_id[target_id]
        if target.parent_id != self.current_room_id:
            return self._fail("That target is not within reach.", command)

        effect = self.compiled.use_effects.get(target_id)
        if effect is None or effect.required_item_id != item_id:
            return self._fail("That item doesn't seem to work there.", command)

        if effect.consumes_item:
            self.inventory.discard(item_id)
            self.item_locations[item_id] = None
        self.used_items.add(item_id)
        self.visited_nodes.add(target_id)
        self.completed_use_targets.add(target_id)

        if effect.clue_id:
            self.prepared_readables.add(target_id)
            self.discovered_clues.add(effect.clue_id)
            return self._success(self._describe_with_clue(target.description, effect.clue_id), command)
        if effect.reveals_readable_id:
            self.revealed_readables.add(effect.reveals_readable_id)
            return self._success(f"The {self._label_by_id[effect.reveals_readable_id]} is revealed.", command)
        if effect.reveals_item_id:
            self.item_locations[effect.reveals_item_id] = self.current_room_id
            return self._success(f"The {self._label_by_id[effect.reveals_item_id]} is revealed.", command)
        return self._fail("Nothing happens.", command)

    def _apply_combine(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``combine <item> with <item>`` using ``compiled.recipe_map``.

        On success both inputs leave the inventory and the recipe output is
        added to it.
        """
        # See ``_apply_use`` for why ``partition`` is used here.
        item_a_name, _, item_b_name = command[8:].partition(" with ")
        item_a_id = self._item_name_to_id.get(item_a_name.strip())
        item_b_id = self._item_name_to_id.get(item_b_name.strip())
        if not item_a_id or not item_b_id or item_a_id not in self.inventory or item_b_id not in self.inventory:
            return self._fail("You do not have both pieces required to combine those.", command)

        # Recipes are keyed by the unordered pair of input ids.
        output_id = self.compiled.recipe_map.get(frozenset({item_a_id, item_b_id}))
        if not output_id:
            return self._fail("Those items do not fit together.", command)

        self.inventory.discard(item_a_id)
        self.inventory.discard(item_b_id)
        self.item_locations[item_a_id] = None
        self.item_locations[item_b_id] = None
        self.inventory.add(output_id)
        self.item_locations[output_id] = INVENTORY_ID
        self.used_items.update({item_a_id, item_b_id, output_id})
        self.completed_recipe_outputs.add(output_id)
        self.visited_nodes.add(output_id)
        return self._success(f"You assemble the {self._label_by_id[output_id]}.", command)

    def _apply_give(self, command: str) -> tuple[str, dict[str, Any]]:
        """Handle ``give <item> to <npc>`` using ``compiled.npc_trade_map``.

        On a matching, not-yet-completed trade the item leaves the inventory
        and the NPC hands over either an item or a clue.
        """
        # See ``_apply_use`` for why ``partition`` is used here.
        item_name, _, npc_name = command[5:].partition(" to ")
        item_id = self._item_name_to_id.get(item_name.strip())
        npc_id = self.node_id_for_command_name(npc_name.strip(), node_types={"npc"})
        if not item_id or item_id not in self.inventory:
            return self._fail("You do not have that item to give.", command)
        if not npc_id:
            return self._fail("There is no one here by that name.", command)

        npc = self._node_by_id[npc_id]
        if npc.parent_id != self.current_room_id:
            return self._fail("That person is not here.", command)

        trade = self.compiled.npc_trade_map.get(npc_id)
        if trade is None or trade.required_item_id != item_id:
            return self._fail("They are not interested in that item.", command)
        if npc_id in self.traded_npcs:
            return self._fail("That trade has already been completed.", command)

        self.inventory.discard(item_id)
        self.item_locations[item_id] = None
        self.used_items.add(item_id)
        self.traded_npcs.add(npc_id)

        if trade.gives_item_id:
            self.inventory.add(trade.gives_item_id)
            self.item_locations[trade.gives_item_id] = INVENTORY_ID
            self.used_items.add(trade.gives_item_id)
            return self._success(f"You receive the {self._label_by_id[trade.gives_item_id]}.", command)
        if trade.gives_clue_id:
            self.discovered_clues.add(trade.gives_clue_id)
            return self._success(f'"{self.compiled.clue_text_by_id[trade.gives_clue_id]}"', command)
        return self._fail("Nothing comes of the trade.", command)

    def _custom_commands(self) -> set[str]:
        """Return the wrapper commands that are currently worth advertising.

        Covers ``talk``/``give`` for NPCs in the room, ``read`` for revealed and
        prepared readables in the room, ``use`` for fixtures and readables in
        the room whose required item is held, and ``combine`` (both orders) for
        recipes whose inputs are all held.
        """
        item_names = self.compiled.item_command_names
        node_names = self.compiled.node_command_names
        commands: set[str] = set()

        for node in self.compiled.world.nodes:
            if node.type == "npc" and node.parent_id == self.current_room_id:
                commands.add(f"talk {node_names[node.id]}")
                trade = self.compiled.npc_trade_map.get(node.id)
                if trade and node.id not in self.traded_npcs and trade.required_item_id in self.inventory:
                    commands.add(f"give {item_names[trade.required_item_id]} to {node_names[node.id]}")
            elif (
                node.type == "readable"
                and node.parent_id == self.current_room_id
                and node.id in self.revealed_readables
            ):
                if not node.requires_item_id or node.id in self.prepared_readables:
                    commands.add(f"read {node_names[node.id]}")
            elif node.type == "fixture" and node.parent_id == self.current_room_id:
                effect = self.compiled.use_effects.get(node.id)
                if effect and effect.required_item_id in self.inventory:
                    commands.add(f"use {item_names[effect.required_item_id]} on {node_names[node.id]}")

        # Readables can be use-targets too; fixtures were handled above.
        for readable_id, effect in self.compiled.use_effects.items():
            node = self._node_by_id.get(readable_id)
            if (
                node
                and node.type == "readable"
                and node.parent_id == self.current_room_id
                and effect.required_item_id in self.inventory
            ):
                commands.add(f"use {item_names[effect.required_item_id]} on {node_names[readable_id]}")

        # Only the recipe inputs matter here, so iterate the keys.
        for recipe_inputs in self.compiled.recipe_map:
            item_ids = sorted(recipe_inputs)
            if all(item_id in self.inventory for item_id in item_ids):
                commands.add(f"combine {item_names[item_ids[0]]} with {item_names[item_ids[1]]}")
                commands.add(f"combine {item_names[item_ids[1]]} with {item_names[item_ids[0]]}")
        return commands

    def _is_wrapper_command(self, command: str) -> bool:
        """Return whether *command* is one ``_apply_wrapper_command`` can dispatch.

        Both the verb prefix and, where applicable, the connective must be
        present.  A command such as ``use key`` (no `` on ``) is therefore not
        a wrapper command and is forwarded to the engine like any other
        unrecognised input instead of raising from the dispatcher.
        """
        return any(
            command.startswith(prefix) and (connective is None or connective in command)
            for prefix, connective in self._WRAPPER_PATTERNS
        )

    def _mark_item_by_name(self, name: str) -> None:
        """Add the item with command name *name* to ``used_items``, if known."""
        item_id = self._item_name_to_id.get(name)
        if item_id:
            self.used_items.add(item_id)

    def _success(self, observation: str, command: str) -> tuple[str, dict[str, Any]]:
        """Return *observation* with a delta marking the wrapper command as succeeded."""
        return observation, {
            "wrapper": "custom",
            "command": command,
            "succeeded": True,
            "location": self.current_room_id,
        }

    def _fail(self, observation: str, command: str) -> tuple[str, dict[str, Any]]:
        """Return *observation* with a delta marking the wrapper command as failed."""
        return observation, {
            "wrapper": "custom",
            "command": command,
            "succeeded": False,
            "location": self.current_room_id,
        }

    @staticmethod
    def _compute_delta(
        previous: GameState,
        current: GameState,
        succeeded: bool,
        fallback_location: str | None,
    ) -> dict[str, Any]:
        """Build the ``game_state_delta`` for an engine step.

        ``added_facts``/``removed_facts`` are always empty (facts are not
        requested from the engine); *previous* is accepted but not consulted.
        """
        return {
            "added_facts": [],
            "removed_facts": [],
            "location": current.location or fallback_location,
            "score": current.score,
            "won": current.won,
            "lost": current.lost,
            "succeeded": succeeded,
        }

    def _wrapper_only_turn(
        self,
        raw_command: str,
        translated: str,
        observation: str,
        delta: dict[str, Any],
    ) -> Turn:
        """Record a turn that costs a step but never reaches the engine.

        *observation* is stored as given (it is not passed through the
        interface adapter); *delta* gains a ``location`` entry if it lacks one.
        """
        self.steps_taken += 1
        delta.setdefault("location", self.current_room_id)
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=translated,
            observation=observation,
            game_state_delta=delta,
        )
        return self._record_turn(turn)

    def _record_turn(self, turn: Turn) -> Turn:
        """Append *turn* to the transcript, refresh the fingerprint, notify the listener."""
        self.transcript.append(turn)
        self.last_state_fingerprint = self.state_fingerprint()
        if self.turn_listener is not None:
            self.turn_listener(self, turn)
        return turn

    def close(self) -> None:
        """Close the underlying environment once; later calls are no-ops."""
        if self._closed:
            return
        # Not every environment object is guaranteed to expose ``close``.
        close = getattr(self.env, "close", None)
        if callable(close):
            close()
        self._closed = True