# Provenance: commit c782fbf (verified) by aarushgupta - "Deploy FATHOM-Hero Space bundle".
from __future__ import annotations
import json
import textwrap
from collections import deque
from typing import TYPE_CHECKING, Any, Callable
import textworld
from textworld.core import EnvInfos, GameState
from .base import INVENTORY_ID, normalize_answer_text, suppress_unsupported_game_warning
from .interface import InterfaceAdapter, SimpleInterfaceAdapter
from .schema import CompiledWorld, Turn
if TYPE_CHECKING:
    # Signature of the optional EpisodeSession(turn_listener=...) callback: it is
    # called as listener(session, turn) after each turn is recorded. Defined only
    # for type checkers; the parameter annotation that uses it is a string (and the
    # module uses `from __future__ import annotations`), so no runtime name is needed.
    TurnListener = Callable[["EpisodeSession", Turn], None]
class EpisodeSession:
    """Run one TextWorld episode and simulate the wrapper-only game mechanics.

    ``step`` translates each raw command through ``interface_adapter`` and then
    dispatches it:

    * ``submit <answer>`` is judged by the session itself,
    * ``read`` / ``talk`` / ``use`` / ``combine`` / ``give`` are simulated
      without touching the environment,
    * every other command is forwarded to the TextWorld environment.

    The session keeps its own mirror of the relevant world state (room,
    inventory, item locations, open/locked nodes, clues, trades) so that wrapper
    commands and ``state_fingerprint`` stay consistent with the environment.
    Every turn is appended to ``transcript`` and passed to ``turn_listener``.
    Sessions can be used as context managers; ``close`` releases the environment.
    """

    # Prefixes of commands simulated by the wrapper rather than by TextWorld.
    _WRAPPER_PREFIXES = ("read ", "talk ", "use ", "combine ", "give ")

    # Default-argument sentinel: each session builds its own SimpleInterfaceAdapter
    # instead of sharing a single instance created when the class is defined.
    _DEFAULT_INTERFACE_ADAPTER: Any = object()

    def __init__(
        self,
        compiled: CompiledWorld,
        interface_adapter: InterfaceAdapter = _DEFAULT_INTERFACE_ADAPTER,
        turn_listener: "TurnListener | None" = None,
    ) -> None:
        """Start the compiled game and initialise the mirrored episode state.

        Args:
            compiled: The compiled world; ``compiled.game_file`` is started with TextWorld.
            interface_adapter: Translates commands and renders observations.
                Defaults to a fresh ``SimpleInterfaceAdapter`` for this session.
            turn_listener: Optional callback called as ``listener(session, turn)``
                after every recorded turn.

        Raises:
            ValueError: If ``interface_adapter`` is explicitly ``None``.
        """
        if interface_adapter is None:
            raise ValueError("interface_adapter must not be None.")
        if interface_adapter is EpisodeSession._DEFAULT_INTERFACE_ADAPTER:
            interface_adapter = SimpleInterfaceAdapter()
        self.compiled = compiled
        self.interface_adapter = interface_adapter
        self.turn_listener = turn_listener
        with suppress_unsupported_game_warning():
            self.env = textworld.start(str(compiled.game_file), request_infos=self._requested_infos())
        self._closed = False
        try:
            self.state = self.env.reset()
        except BaseException:
            # Do not leak the freshly started environment if the first reset fails.
            self.close()
            raise
        self.done = False
        self.player_won = False
        self.steps_taken = 0
        self.invalid_command_count = 0
        self.wrong_submit_count = 0
        self.used_items: set[str] = set()
        self.discovered_clues: set[str] = set()
        self.consulted_npcs: set[str] = set()
        self.traded_npcs: set[str] = set()
        self.prepared_readables: set[str] = set()
        self.completed_recipe_outputs: set[str] = set()
        self.completed_use_targets: set[str] = set()
        self.unlocked_doors: set[str] = set()
        self.consulted_guardian = False
        # Readables that some use-effect reveals start hidden; all others are readable.
        self.hidden_readables = {
            effect.reveals_readable_id for effect in compiled.use_effects.values() if effect.reveals_readable_id
        }
        self.revealed_readables = {
            node.id for node in compiled.world.nodes if node.type == "readable" and node.id not in self.hidden_readables
        }
        self.item_locations = dict(compiled.item_start_locations)
        self.inventory = {item_id for item_id, location in self.item_locations.items() if location == INVENTORY_ID}
        self.open_nodes = {
            node.id for node in compiled.world.nodes if node.type in {"container", "door"} and getattr(node, "open", False)
        }
        self.locked_nodes = {
            node.id for node in compiled.world.nodes if node.type in {"container", "door"} and getattr(node, "locked", False)
        }
        self.current_room_id = compiled.world.meta.start_node_id
        self.visited_nodes: set[str] = {self.current_room_id}
        self.transcript: list[Turn] = []
        self.recent_normalized_commands: deque[str] = deque(maxlen=3)
        self._node_by_id = {node.id: node for node in compiled.world.nodes}
        self._label_by_id = {node.id: node.label for node in compiled.world.nodes}
        self._label_by_id.update({item.id: item.label for item in compiled.world.items})
        # Reverse of compiled.item_command_names: command-safe name -> item id.
        self._item_name_to_id = {name: item_id for item_id, name in compiled.item_command_names.items()}
        self.last_state_fingerprint = self.state_fingerprint()

    def __enter__(self) -> "EpisodeSession":
        """Return the session itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the session when leaving a ``with`` block (exceptions propagate)."""
        self.close()

    @staticmethod
    def _requested_infos() -> EnvInfos:
        """Return the TextWorld info fields requested from every game state."""
        return EnvInfos(
            feedback=True,
            description=True,
            inventory=True,
            location=True,
            facts=False,
            won=True,
            lost=True,
            score=True,
            moves=True,
            last_action=True,
            last_command=True,
            admissible_commands=True,
            policy_commands=True,
            extras=["walkthrough"],
        )

    def available_commands(self) -> list[str]:
        """Return the sorted union of TextWorld's admissible and the wrapper's commands."""
        commands = set(self.state.admissible_commands or [])
        commands.update(self._custom_commands())
        return sorted(commands)

    def current_feedback(self) -> str:
        """Render the current state's feedback text through the interface adapter."""
        return self.interface_adapter.render_observation(self.state.feedback or "", self.state, self)

    def state_fingerprint(self) -> str:
        """Return a deterministic JSON summary of the mirrored progress state."""
        return json.dumps(
            {
                "room": self.current_room_id,
                "inventory": sorted(self.inventory),
                "clues": sorted(self.discovered_clues),
                "opened": sorted(self.open_nodes),
                "traded": sorted(self.traded_npcs),
                "use_targets": sorted(self.completed_use_targets),
                "recipe_outputs": sorted(self.completed_recipe_outputs),
            },
            sort_keys=True,
        )

    def node_id_for_command_name(self, command_name: str, node_types: set[str] | None = None) -> str | None:
        """Return the id of the first node whose command name is ``command_name``.

        Args:
            command_name: Name as it appears in commands (``compiled.node_command_names``).
            node_types: If given, only nodes whose ``type`` is in this set match.

        Returns:
            The matching node id, or ``None`` when no node matches.
        """
        for node in self.compiled.world.nodes:
            if self.compiled.node_command_names.get(node.id) != command_name:
                continue
            if node_types is None or node.type in node_types:
                return node.id
        return None

    def step(self, raw_command: str) -> Turn:
        """Execute one player command and return the recorded turn.

        Raises:
            RuntimeError: If the episode has already finished.
        """
        if self.done:
            raise RuntimeError("Episode is already complete.")
        lowered = self.interface_adapter.translate_command(raw_command, self).lower().strip()
        if turn := self._handle_submit(raw_command, lowered):
            return turn
        if self._is_wrapper_command(lowered):
            return self._step_wrapper(raw_command, lowered)
        return self._step_env(raw_command, lowered)

    def _handle_submit(self, raw_command: str, lowered: str) -> Turn | None:
        """Judge a ``submit <answer>`` command; return ``None`` for any other command.

        A submission is rejected (without ending the episode) unless the player is
        in the guardian's room, has talked to the guardian, and has discovered
        every clue listed in ``compiled.clue_text_by_id``. A wrong answer
        increments ``wrong_submit_count``; a correct one ends the episode as a win.
        """
        if not lowered.startswith("submit "):
            return None
        if self.current_room_id != self.compiled.guardian_room_id or self.compiled.guardian_id not in self.consulted_npcs:
            return self._wrapper_only_turn(
                raw_command,
                lowered,
                "The guardian has not asked for your answer yet.",
                {"wrapper": "submit_rejected", "reason": "guardian_not_ready"},
            )
        # Every required clue must be known; extra discovered clues must not block.
        missing_clues = set(self.compiled.clue_text_by_id) - self.discovered_clues
        if missing_clues:
            return self._wrapper_only_turn(
                raw_command,
                lowered,
                "The guardian waits. You have not gathered enough evidence yet.",
                {
                    "wrapper": "submit_rejected",
                    "reason": "missing_clues",
                    "missing_clues": sorted(missing_clues),
                },
            )
        answer = normalize_answer_text(lowered[len("submit "):])
        if answer != self.compiled.correct_answer_normalized:
            self.wrong_submit_count += 1
            return self._wrapper_only_turn(
                raw_command,
                lowered,
                "The guardian shakes their head. That answer is wrong.",
                {"wrapper": "submit_rejected", "reason": "wrong_answer", "submitted": answer},
            )
        self.steps_taken += 1
        self.done = True
        self.player_won = True
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=self.compiled.correct_submit_command,
            observation="The guardian weighs your answer, then nods.\n\nThe dungeon yields. You solved it.",
            game_state_delta={"wrapper": "submit_forwarded", "won": True, "location": self.current_room_id},
        )
        return self._record_turn(turn)

    def _step_env(self, raw_command: str, lowered: str) -> Turn:
        """Forward a command to TextWorld, mirror its effects, and record the turn.

        A command counts as successful only if it was admissible in the state
        before the step; otherwise ``invalid_command_count`` is incremented.
        """
        previous = self.state
        admissible = set(previous.admissible_commands or [])
        self.state, _, env_done = self.env.step(lowered)
        self.steps_taken += 1
        succeeded = lowered in admissible
        if succeeded:
            self._apply_env_side_effects(lowered)
        else:
            self.invalid_command_count += 1
        if self.state.won:
            # Keep player_won consistent with a win reported by the environment.
            self.player_won = True
        self.done = bool(env_done or self.state.won)
        observation = self.interface_adapter.render_observation(self.state.feedback or "", self.state, self)
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=lowered,
            observation=observation,
            game_state_delta=self._compute_delta(previous, self.state, succeeded, self.current_room_id),
        )
        return self._record_turn(turn)

    def _step_wrapper(self, raw_command: str, lowered: str) -> Turn:
        """Simulate a wrapper-only command and record the resulting turn."""
        observation, delta = self._apply_wrapper_command(lowered)
        self.steps_taken += 1
        if delta.get("succeeded") is False:
            self.invalid_command_count += 1
        delta.setdefault("location", self.current_room_id)
        rendered = self.interface_adapter.render_observation(observation, self.state, self)
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=lowered,
            observation=rendered,
            game_state_delta=delta,
        )
        return self._record_turn(turn)

    def _apply_env_side_effects(self, command: str) -> None:
        """Mirror a successful (admissible) TextWorld command in the wrapper's state.

        Handles movement, opening/closing, unlocking, and moving items into or out
        of the inventory so that wrapper commands only see items the player holds.
        Unrecognised commands leave the mirrored state unchanged.
        """
        if command.startswith("go "):
            direction = command[3:].strip()
            edge = self.compiled.room_edges_by_direction.get((self.current_room_id, direction))
            if edge is not None:
                self.current_room_id = edge.to_node_id
                self.visited_nodes.add(edge.to_node_id)
            return
        if command.startswith("open "):
            node_id = self.node_id_for_command_name(command[5:].strip(), node_types={"container", "door"})
            if node_id:
                self.open_nodes.add(node_id)
                self.visited_nodes.add(node_id)
            return
        if command.startswith("close "):
            node_id = self.node_id_for_command_name(command[6:].strip(), node_types={"container", "door"})
            if node_id:
                self.open_nodes.discard(node_id)
            return
        if command.startswith("unlock ") and " with " in command:
            target_name, key_name = command[7:].split(" with ", 1)
            target_id = self.node_id_for_command_name(target_name.strip(), node_types={"container", "door"})
            if target_id:
                self.locked_nodes.discard(target_id)
                if self._node_by_id[target_id].type == "door":
                    self.unlocked_doors.add(target_id)
                self.visited_nodes.add(target_id)
            self._mark_item_by_name(key_name.strip())
            return
        if command.startswith("take "):
            item_name = command[5:].split(" from ", 1)[0].strip()
            item_id = self._item_name_to_id.get(item_name)
            if item_id:
                self.inventory.add(item_id)
                self.item_locations[item_id] = INVENTORY_ID
                self.used_items.add(item_id)
                self.visited_nodes.add(item_id)
            return
        if command.startswith("drop "):
            self._release_item(command[5:].strip(), self.current_room_id)
            return
        if command.startswith("insert ") and " into " in command:
            item_name, container_name = command[7:].split(" into ", 1)
            container_id = self.node_id_for_command_name(container_name.strip(), node_types={"container"})
            self._release_item(item_name.strip(), container_id or self.current_room_id)
            return
        if command.startswith("put ") and " on " in command:
            item_name, support_name = command[4:].split(" on ", 1)
            # Supporters may not be compiled nodes; fall back to the current room.
            support_id = self.node_id_for_command_name(support_name.strip())
            self._release_item(item_name.strip(), support_id or self.current_room_id)

    def _release_item(self, item_name: str, location: str) -> None:
        """Move the item named ``item_name`` out of the inventory to ``location``."""
        item_id = self._item_name_to_id.get(item_name)
        if item_id:
            self.inventory.discard(item_id)
            self.item_locations[item_id] = location

    def _apply_wrapper_command(self, command: str) -> tuple[str, dict[str, Any]]:
        """Dispatch a wrapper command to its handler.

        Raises:
            RuntimeError: If the command is not a recognised wrapper command.
        """
        if command.startswith("read "):
            return self._apply_read(command)
        if command.startswith("talk "):
            return self._apply_talk(command)
        if command.startswith("use ") and " on " in command:
            return self._apply_use(command)
        if command.startswith("combine ") and " with " in command:
            return self._apply_combine(command)
        if command.startswith("give ") and " to " in command:
            return self._apply_give(command)
        raise RuntimeError(f"Unsupported wrapper command '{command}'.")

    @staticmethod
    def _format_clue_reveal(description: str, clue_text: str) -> str:
        """Return ``description`` followed by the quoted clue on its own line.

        Built directly instead of dedenting an interpolated f-string, which could
        mangle descriptions or clues that span several indented lines.
        """
        return f'{description}\n"{clue_text}"'.strip()

    def _apply_read(self, command: str) -> tuple[str, dict[str, Any]]:
        """Read a revealed, in-room readable and record the clue it carries."""
        readable_id = self.node_id_for_command_name(command[5:].strip(), node_types={"readable"})
        if not readable_id or readable_id not in self.revealed_readables:
            return self._fail("You can't read that right now.", command)
        node = self._node_by_id[readable_id]
        if node.parent_id != self.current_room_id:
            return self._fail("You are too far away to read that.", command)
        if node.requires_item_id and readable_id not in self.prepared_readables:
            return self._fail("You still need the right tool before the text becomes legible.", command)
        clue_id = self.compiled.readable_clue_by_id[readable_id]
        self.discovered_clues.add(clue_id)
        self.visited_nodes.add(readable_id)
        return self._success(
            self._format_clue_reveal(node.description, self.compiled.clue_text_by_id[clue_id]),
            command,
        )

    def _apply_talk(self, command: str) -> tuple[str, dict[str, Any]]:
        """Talk to an in-room NPC, marking it (and possibly the guardian) consulted."""
        npc_id = self.node_id_for_command_name(command[5:].strip(), node_types={"npc"})
        if not npc_id:
            return self._fail("You can't talk to that right now.", command)
        node = self._node_by_id[npc_id]
        if node.parent_id != self.current_room_id:
            return self._fail("You are too far away to talk to that.", command)
        self.consulted_npcs.add(npc_id)
        if npc_id == self.compiled.guardian_id:
            self.consulted_guardian = True
        self.visited_nodes.add(npc_id)
        return self._success(node.description, command)

    def _apply_use(self, command: str) -> tuple[str, dict[str, Any]]:
        """Apply ``use <item> on <target>`` according to ``compiled.use_effects``.

        The effect may consume the item, then yields (in priority order) a clue,
        a revealed readable, or a revealed item.
        """
        item_name, target_name = command[4:].split(" on ", 1)
        item_id = self._item_name_to_id.get(item_name.strip())
        target_id = self.node_id_for_command_name(target_name.strip(), node_types={"readable", "fixture"})
        if not item_id or item_id not in self.inventory:
            return self._fail("You don't have the item needed for that.", command)
        if not target_id:
            return self._fail("You can't use that here.", command)
        target = self._node_by_id[target_id]
        if target.parent_id != self.current_room_id:
            return self._fail("That target is not within reach.", command)
        effect = self.compiled.use_effects.get(target_id)
        if effect is None or effect.required_item_id != item_id:
            return self._fail("That item doesn't seem to work there.", command)
        if effect.consumes_item:
            self.inventory.discard(item_id)
            self.item_locations[item_id] = None
        self.used_items.add(item_id)
        self.visited_nodes.add(target_id)
        self.completed_use_targets.add(target_id)
        if effect.clue_id:
            self.prepared_readables.add(target_id)
            self.discovered_clues.add(effect.clue_id)
            return self._success(
                self._format_clue_reveal(target.description, self.compiled.clue_text_by_id[effect.clue_id]),
                command,
            )
        if effect.reveals_readable_id:
            self.revealed_readables.add(effect.reveals_readable_id)
            return self._success(f"The {self._label_by_id[effect.reveals_readable_id]} is revealed.", command)
        if effect.reveals_item_id:
            self.item_locations[effect.reveals_item_id] = self.current_room_id
            return self._success(f"The {self._label_by_id[effect.reveals_item_id]} is revealed.", command)
        return self._fail("Nothing happens.", command)

    def _apply_combine(self, command: str) -> tuple[str, dict[str, Any]]:
        """Combine two held items into the recipe output from ``compiled.recipe_map``."""
        item_a_name, item_b_name = command[8:].split(" with ", 1)
        item_a_id = self._item_name_to_id.get(item_a_name.strip())
        item_b_id = self._item_name_to_id.get(item_b_name.strip())
        if not item_a_id or not item_b_id or item_a_id not in self.inventory or item_b_id not in self.inventory:
            return self._fail("You do not have both pieces required to combine those.", command)
        output_id = self.compiled.recipe_map.get(frozenset({item_a_id, item_b_id}))
        if not output_id:
            return self._fail("Those items do not fit together.", command)
        self.inventory.discard(item_a_id)
        self.inventory.discard(item_b_id)
        self.item_locations[item_a_id] = None
        self.item_locations[item_b_id] = None
        self.inventory.add(output_id)
        self.item_locations[output_id] = INVENTORY_ID
        self.used_items.update({item_a_id, item_b_id, output_id})
        self.completed_recipe_outputs.add(output_id)
        self.visited_nodes.add(output_id)
        return self._success(f"You assemble the {self._label_by_id[output_id]}.", command)

    def _apply_give(self, command: str) -> tuple[str, dict[str, Any]]:
        """Trade a held item with an in-room NPC per ``compiled.npc_trade_map`` (once per NPC)."""
        item_name, npc_name = command[5:].split(" to ", 1)
        item_id = self._item_name_to_id.get(item_name.strip())
        npc_id = self.node_id_for_command_name(npc_name.strip(), node_types={"npc"})
        if not item_id or item_id not in self.inventory:
            return self._fail("You do not have that item to give.", command)
        if not npc_id:
            return self._fail("There is no one here by that name.", command)
        npc = self._node_by_id[npc_id]
        if npc.parent_id != self.current_room_id:
            return self._fail("That person is not here.", command)
        trade = self.compiled.npc_trade_map.get(npc_id)
        if trade is None or trade.required_item_id != item_id:
            return self._fail("They are not interested in that item.", command)
        if npc_id in self.traded_npcs:
            return self._fail("That trade has already been completed.", command)
        self.inventory.discard(item_id)
        self.item_locations[item_id] = None
        self.used_items.add(item_id)
        self.traded_npcs.add(npc_id)
        if trade.gives_item_id:
            self.inventory.add(trade.gives_item_id)
            self.item_locations[trade.gives_item_id] = INVENTORY_ID
            self.used_items.add(trade.gives_item_id)
            return self._success(f"You receive the {self._label_by_id[trade.gives_item_id]}.", command)
        if trade.gives_clue_id:
            self.discovered_clues.add(trade.gives_clue_id)
            return self._success(f'"{self.compiled.clue_text_by_id[trade.gives_clue_id]}"', command)
        return self._fail("Nothing comes of the trade.", command)

    def _custom_commands(self) -> set[str]:
        """Return the wrapper commands currently worth offering to the player.

        Covers talking to / trading with in-room NPCs, reading legible in-room
        readables, using held items on in-room fixtures or readables, and both
        orderings of every two-input recipe whose inputs are all held.
        """
        names = self.compiled.node_command_names
        item_names = self.compiled.item_command_names
        commands: set[str] = set()
        for node in self.compiled.world.nodes:
            if node.parent_id != self.current_room_id:
                continue
            if node.type == "npc":
                commands.add(f"talk {names[node.id]}")
                trade = self.compiled.npc_trade_map.get(node.id)
                if trade and node.id not in self.traded_npcs and trade.required_item_id in self.inventory:
                    commands.add(f"give {item_names[trade.required_item_id]} to {names[node.id]}")
            elif node.type == "readable" and node.id in self.revealed_readables:
                if not node.requires_item_id or node.id in self.prepared_readables:
                    commands.add(f"read {names[node.id]}")
            elif node.type == "fixture":
                effect = self.compiled.use_effects.get(node.id)
                if effect and effect.required_item_id in self.inventory:
                    commands.add(f"use {item_names[effect.required_item_id]} on {names[node.id]}")
        for readable_id, effect in self.compiled.use_effects.items():
            node = self._node_by_id.get(readable_id)
            if node and node.type == "readable" and node.parent_id == self.current_room_id and effect.required_item_id in self.inventory:
                commands.add(f"use {item_names[effect.required_item_id]} on {names[readable_id]}")
        for recipe_inputs in self.compiled.recipe_map:
            item_ids = sorted(recipe_inputs)
            # "combine X with Y" can only express two distinct inputs.
            if len(item_ids) != 2 or not all(item_id in self.inventory for item_id in item_ids):
                continue
            first, second = (item_names[item_id] for item_id in item_ids)
            commands.add(f"combine {first} with {second}")
            commands.add(f"combine {second} with {first}")
        return commands

    def _is_wrapper_command(self, command: str) -> bool:
        """Return True when ``command`` is simulated by the wrapper, not TextWorld."""
        return command.startswith(self._WRAPPER_PREFIXES)

    def _mark_item_by_name(self, name: str) -> None:
        """Record the item with command name ``name`` (if known) as used."""
        item_id = self._item_name_to_id.get(name)
        if item_id:
            self.used_items.add(item_id)

    def _success(self, observation: str, command: str) -> tuple[str, dict[str, Any]]:
        """Return ``observation`` with a successful wrapper-command delta."""
        return observation, {"wrapper": "custom", "command": command, "succeeded": True, "location": self.current_room_id}

    def _fail(self, observation: str, command: str) -> tuple[str, dict[str, Any]]:
        """Return ``observation`` with a failed wrapper-command delta."""
        return observation, {"wrapper": "custom", "command": command, "succeeded": False, "location": self.current_room_id}

    @staticmethod
    def _compute_delta(previous: GameState, current: GameState, succeeded: bool, fallback_location: str | None) -> dict[str, Any]:
        """Summarise an environment step for the turn record.

        ``previous`` is accepted for interface stability but not used. Fact lists
        are always empty because facts are not requested (``facts=False``).
        """
        return {
            "added_facts": [],
            "removed_facts": [],
            "location": current.location or fallback_location,
            "score": current.score,
            "won": current.won,
            "lost": current.lost,
            "succeeded": succeeded,
        }

    def _wrapper_only_turn(
        self,
        raw_command: str,
        translated: str,
        observation: str,
        delta: dict[str, Any],
    ) -> Turn:
        """Record a turn whose observation is produced by the wrapper verbatim."""
        self.steps_taken += 1
        delta.setdefault("location", self.current_room_id)
        turn = Turn(
            step=self.steps_taken,
            player_action=raw_command,
            textworld_command=translated,
            observation=observation,
            game_state_delta=delta,
        )
        return self._record_turn(turn)

    def _record_turn(self, turn: Turn) -> Turn:
        """Append ``turn`` to the transcript, refresh the fingerprint, notify the listener."""
        self.transcript.append(turn)
        self.last_state_fingerprint = self.state_fingerprint()
        if self.turn_listener is not None:
            self.turn_listener(self, turn)
        return turn

    def close(self) -> None:
        """Close the underlying environment once; later calls are no-ops."""
        if self._closed:
            return
        close = getattr(self.env, "close", None)
        if callable(close):
            close()
        self._closed = True