diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..415182fca3387bce886684eeaf9c4cdd3a6ac297 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,19 @@ +FROM python:3.11-slim + +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y --no-install-recommends build-essential git curl \ + && rm -rf /var/lib/apt/lists/* + +COPY . /app + +RUN pip install --no-cache-dir --upgrade pip \ + && pip install --no-cache-dir . + +EXPOSE 8000 + +CMD ["uvicorn", "agents.spaces.dm_space:create_app", "--factory", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md index d8250b320d18703e310790cc04443b8193d779de..fe6b38d43150331d3af9144dd125756af50a054a 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,19 @@ --- -title: FATHOM DM -emoji: 🏃 -colorFrom: yellow -colorTo: blue +title: DND-DM sdk: docker -pinned: false +app_port: 8000 +tags: + - openenv + - dnd + - textworld --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# DND-DM + +This Space hosts the CPU-only `DND-DM` environment. + +- OpenEnv API: `/env` +- Health check: `/healthz` +- Latest normalized world output: `/world-output/latest` + +`DND-DM` evaluates submitted world definitions. It does not generate worlds by itself. diff --git a/agents/__init__.py b/agents/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc2dd692e781d0a404a2cebabc4947dc3c5faac --- /dev/null +++ b/agents/__init__.py @@ -0,0 +1,2 @@ +"""Agent environments for the dungeon project.""" + diff --git a/agents/hero/__init__.py b/agents/hero/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc140bb1caa13bda440bb032b58bd1a560b40eb --- /dev/null +++ b/agents/hero/__init__.py @@ -0,0 +1,58 @@ +"""Hero agent environment and runner primitives.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +__all__ = [ + "HeroEnvironment", + "HeroLLMPolicy", + "HeroObservation", + "HeroPolicy", + "HeroPolicyError", + "HeroRunner", + "HeroServerAction", + "HeroState", + "HeroTraceEvent", + "ScriptedToolCallingPolicy", + "ToolCallingPolicy", +] + +if TYPE_CHECKING: + from .env import HeroEnvironment + from .policy import HeroLLMPolicy, HeroPolicy, HeroPolicyError, HeroTraceEvent + from .runner import HeroRunner, ScriptedToolCallingPolicy, ToolCallingPolicy + from .schema import HeroObservation, HeroServerAction, HeroState + + +def __getattr__(name: str) -> Any: + if name == "HeroEnvironment": + from .env import HeroEnvironment + + return HeroEnvironment + if name in {"HeroLLMPolicy", "HeroPolicy", "HeroPolicyError", "HeroTraceEvent"}: + from .policy import HeroLLMPolicy, HeroPolicy, HeroPolicyError, HeroTraceEvent + + return { + "HeroLLMPolicy": HeroLLMPolicy, + "HeroPolicy": HeroPolicy, + "HeroPolicyError": HeroPolicyError, + "HeroTraceEvent": HeroTraceEvent, + }[name] + if name in {"HeroRunner", "ScriptedToolCallingPolicy", "ToolCallingPolicy"}: + from .runner import HeroRunner, ScriptedToolCallingPolicy, ToolCallingPolicy + + return { + "HeroRunner": HeroRunner, + "ScriptedToolCallingPolicy": ScriptedToolCallingPolicy, + "ToolCallingPolicy": ToolCallingPolicy, + }[name] + if name in {"HeroObservation", "HeroServerAction", "HeroState"}: + from .schema import HeroObservation, HeroServerAction, HeroState + + return { + "HeroObservation": HeroObservation, + "HeroServerAction": HeroServerAction, + "HeroState": HeroState, + }[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/agents/hero/__main__.py b/agents/hero/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0da3fbc56a682eaa7e458bf3f64af49e277729 --- /dev/null +++ b/agents/hero/__main__.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from agents.master.sample import load_world +from agents.shared.runtime import build_interface_adapter, resolve_interface_config + +from .env import HeroEnvironment + + +def _manual_action(raw: str) -> dict[str, object]: + if raw == "/read": + return {"tool": "scratchpad_read"} + if raw.startswith("/write append "): + return {"tool": "scratchpad_write", "mode": "append", "content": raw[len("/write append ") :]} + if raw.startswith("/write replace "): + return {"tool": "scratchpad_write", "mode": "replace", "content": raw[len("/write replace ") :]} + return {"tool": "act", "command": raw} + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="Local hero environment smoke runner") + parser.add_argument("mode", choices=["manual", "scripted"]) + parser.add_argument("world", help="Path to a world-definition JSON file.") + parser.add_argument("--actions", help="JSON file containing a list of hero action objects.") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--interface-model") + parser.add_argument( + "--translate-corporate-env", + action="store_true", + help="Rewrite observations into a corporate app metaphor and translate parser-safe corporate commands back through Gemini.", + ) + args = parser.parse_args(argv) + + world = load_world(args.world) + interface_adapter = build_interface_adapter( + resolve_interface_config( + model_name=args.interface_model, + translation_mode="corporate_app" if args.translate_corporate_env else None, + ) + ) + env = HeroEnvironment(debug=args.debug, interface_adapter=interface_adapter) + observation = env.reset(world) + print(observation.message) + + if args.mode == "scripted": + if not args.actions: + parser.error("--actions is required for scripted mode.") + actions = json.loads(Path(args.actions).read_text(encoding="utf-8")) + for action in actions: + result = env.step(action) + print(result.observation.message) + if result.done: + print(json.dumps(result.observation.model_dump(), indent=2)) + return 0 + print(json.dumps(env.state.model_dump(), indent=2)) + return 0 + + while not observation.done: + try: + raw = input("hero> ").strip() + except EOFError: + print() + return 0 + if raw in {"quit", "exit"}: + return 0 + result = env.step(_manual_action(raw)) + observation = result.observation + print(observation.message) + if result.done: + print(json.dumps(observation.model_dump(), indent=2)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/agents/hero/cli.py b/agents/hero/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..6feb6e082284ca60fd8cf9cca82d3e84e4faad77 --- /dev/null +++ b/agents/hero/cli.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass + +from agents.master.base import SUPPORTED_DIRECTIONS + +_TOKEN_RE = re.compile(r"^[a-z0-9]+(?: [a-z0-9]+)*$") +_BANNED_OBJECT_TOKENS = {"a", "an", "the"} + + +@dataclass(frozen=True) +class CliCommandAst: + kind: str + normalized_command: str + arguments: tuple[str, ...] = () + + +@dataclass(frozen=True) +class CliCommandParseResult: + valid: bool + normalized_command: str | None = None + ast: CliCommandAst | None = None + error: str | None = None + + +def parse_cli_command(raw_command: str) -> CliCommandParseResult: + normalized = normalize_cli_command(raw_command) + if not normalized: + return CliCommandParseResult(valid=False, error="Command must not be empty.") + + if normalized in {"look", "inventory", "wait"}: + return _ok(normalized, normalized) + + if normalized in SUPPORTED_DIRECTIONS: + return _ok("move", f"go {normalized}", normalized) + if normalized.startswith("go "): + direction = normalized[3:].strip() + if direction in SUPPORTED_DIRECTIONS: + return _ok("move", f"go {direction}", direction) + return CliCommandParseResult(valid=False, error="Unknown direction.") + + if match := re.fullmatch(r"look in (?P.+)", normalized): + object_text = match.group("object").strip() + return _object_result("look_in", normalized, object_text) + if match := re.fullmatch(r"take (?P.+) from (?P.+)", normalized): + return _two_object_result("take_from", normalized, match.group("object"), match.group("source")) + + one_target_patterns = { + "open": r"open (?P.+)", + "read": r"read (?P.+)", + "talk": r"talk (?P.+)", + "examine": r"examine (?P.+)", + } + for kind, pattern in one_target_patterns.items(): + if match := re.fullmatch(pattern, normalized): + object_text = match.group("object").strip() + return _object_result(kind, normalized, object_text) + if match := re.fullmatch(r"take (?P.+)", normalized): + object_text = match.group("object").strip() + return _object_result("take", normalized, object_text) + if match := re.fullmatch(r"unlock (?P.+) with (?P.+)", normalized): + return _two_object_result("unlock", normalized, match.group("object"), match.group("tool")) + if match := re.fullmatch(r"use (?P.+) on (?P.+)", normalized): + return _two_object_result("use", normalized, match.group("object"), match.group("target")) + if match := re.fullmatch(r"combine (?P.+) with (?P.+)", normalized): + return _two_object_result("combine", normalized, match.group("object"), match.group("target")) + if match := re.fullmatch(r"give (?P.+) to (?P.+)", normalized): + return _two_object_result("give", normalized, match.group("object"), match.group("target")) + + if match := re.fullmatch(r"submit (?P[a-z0-9]+(?: [a-z0-9]+)*)", normalized): + answer = match.group("answer").strip() + return _ok("submit", normalized, answer) + + return CliCommandParseResult(valid=False, error="Command does not match the strict CLI grammar.") + + +def normalize_cli_command(raw_command: str) -> str: + return re.sub(r"\s+", " ", raw_command.strip().lower()) + + +def _object_result(kind: str, normalized_command: str, object_text: str) -> CliCommandParseResult: + object_error = _validate_object_text(object_text) + if object_error is not None: + return CliCommandParseResult(valid=False, error=object_error) + return _ok(kind, normalized_command, object_text) + + +def _two_object_result(kind: str, normalized_command: str, first: str, second: str) -> CliCommandParseResult: + first_error = _validate_object_text(first) + if first_error is not None: + return CliCommandParseResult(valid=False, error=first_error) + second_error = _validate_object_text(second) + if second_error is not None: + return CliCommandParseResult(valid=False, error=second_error) + return _ok(kind, normalized_command, first.strip(), second.strip()) + + +def _validate_object_text(value: str) -> str | None: + candidate = value.strip() + if not candidate: + return "Command target must not be empty." + if not _TOKEN_RE.fullmatch(candidate): + return "Command targets must use lowercase letters, numbers, and spaces only." + if any(token in _BANNED_OBJECT_TOKENS for token in candidate.split()): + return "Strict CLI commands must use exact parser-safe object names without articles." + return None + + +def _ok(kind: str, normalized_command: str, *arguments: str) -> CliCommandParseResult: + return CliCommandParseResult( + valid=True, + normalized_command=normalized_command, + ast=CliCommandAst(kind=kind, normalized_command=normalized_command, arguments=arguments), + ) diff --git a/agents/hero/env.py b/agents/hero/env.py new file mode 100644 index 0000000000000000000000000000000000000000..79727b4d5395e4db8de8783b6049458f4c811ca5 --- /dev/null +++ b/agents/hero/env.py @@ -0,0 +1,450 @@ +from __future__ import annotations + +from collections import deque +from pathlib import Path +from typing import Any + +from agents.master.base import DMInterfaceError, MAX_STEP_MULTIPLIER +from agents.master.build import WorldCompiler +from agents.master.interface import InterfaceAdapter, StrictCliInterfaceAdapter +from agents.master.schema import CompiledWorld, WorldDefinition +from agents.master.session import EpisodeSession +from agents.shared.openenv_compat import Environment, StepResult, build_step_result + +from .cli import parse_cli_command +from .schema import ( + ActAction, + HeroAction, + HeroAuxSignals, + HeroEpisodeStats, + HeroObservation, + HeroRewardBreakdown, + HeroState, + ScratchpadReadAction, + ScratchpadWriteAction, + validate_hero_action, +) + +_DENSE_PROGRESS_SCALE = 0.30 +_SYNTAX_PENALTY = -0.02 +_INVALID_ACTION_PENALTY = -0.02 +_REPEAT_NOOP_PENALTY = -0.01 +_WRONG_SUBMIT_PENALTY = -0.10 + + +class HeroEnvironment(Environment[HeroAction, HeroObservation, HeroState]): + def __init__( + self, + *, + artifacts_root: Path | None = None, + world_input: CompiledWorld | WorldDefinition | dict[str, Any] | None = None, + session: EpisodeSession | None = None, + interface_adapter: InterfaceAdapter | None = None, + model: str = "", + max_game_steps: int | None = None, + max_tool_calls: int | None = None, + scratchpad_max_chars: int = 8000, + debug: bool = False, + ) -> None: + super().__init__() + self.compiler = WorldCompiler(artifacts_root=artifacts_root) + self._initial_world_input = world_input + self._provided_session = session + self._provided_interface_adapter = interface_adapter + self.model = model + self._default_max_game_steps = max_game_steps + self._default_max_tool_calls = max_tool_calls + self.scratchpad_max_chars = scratchpad_max_chars + self.debug = debug + self._state = HeroState() + self._compiled: CompiledWorld | None = None + self._session: EpisodeSession | None = None + self._scratchpad = "" + self._max_game_steps = 0 + self._max_tool_calls = 0 + self._debug_dir: Path | None = None + self._episode_stats = HeroEpisodeStats() + self._recent_noop_signatures: deque[tuple[str, str, str]] = deque(maxlen=3) + + @classmethod + def from_session( + cls, + session: EpisodeSession, + *, + max_game_steps: int | None = None, + max_tool_calls: int | None = None, + scratchpad_max_chars: int = 8000, + debug: bool = False, + ) -> "HeroEnvironment": + return cls( + session=session, + max_game_steps=max_game_steps, + max_tool_calls=max_tool_calls, + scratchpad_max_chars=scratchpad_max_chars, + debug=debug, + ) + + def reset( + self, + world_input: CompiledWorld | WorldDefinition | dict[str, Any] | None = None, + *, + seed: int | None = None, + episode_id: str | None = None, + max_game_steps: int | None = None, + max_tool_calls: int | None = None, + scratchpad_max_chars: int | None = None, + debug: bool | None = None, + ) -> HeroObservation: + del seed, episode_id + if debug is not None: + self.debug = debug + if scratchpad_max_chars is not None: + self.scratchpad_max_chars = scratchpad_max_chars + self._scratchpad = "" + self._episode_stats = HeroEpisodeStats() + self._recent_noop_signatures.clear() + + if self._provided_session is not None: + self._session = self._provided_session + self._compiled = self._session.compiled + else: + selected_world = world_input if world_input is not None else self._initial_world_input + if selected_world is None: + raise ValueError("HeroEnvironment.reset requires a compiled world, world definition, or live session.") + self._compiled = ( + selected_world + if isinstance(selected_world, CompiledWorld) + else self.compiler.compile(selected_world) + ) + adapter = self._provided_interface_adapter or StrictCliInterfaceAdapter() + self._session = EpisodeSession(self._compiled, interface_adapter=adapter) + + self._max_game_steps = max_game_steps or self._default_max_game_steps or max( + 1, len(self._compiled.solver_policy) * MAX_STEP_MULTIPLIER + ) + self._max_tool_calls = max_tool_calls or self._default_max_tool_calls or (self._max_game_steps * 4) + self._state = HeroState( + episode_id=self._compiled.episode_id, + step_count=0, + game_steps_taken=self._session.steps_taken, + tool_calls_total=0, + max_game_steps=self._max_game_steps, + max_tool_calls=self._max_tool_calls, + game_steps_remaining=max(0, self._max_game_steps - self._session.steps_taken), + tool_calls_remaining=self._max_tool_calls, + status="running", + world_title=self._compiled.world.meta.title, + last_command=None, + scratchpad_chars=0, + ) + self._prepare_debug_dir() + reward_breakdown = self._empty_breakdown(self._progress_potential()) + observation = self._apply_transform( + HeroObservation( + message=self._session.current_feedback(), + reward=0.0, + done=False, + won=None, + reward_breakdown=reward_breakdown, + aux_signals=self._progress_signals(), + ) + ) + return observation + + def step( # type: ignore[override] + self, + action: HeroAction | dict[str, object], + timeout_s: float | None = None, + **kwargs: Any, + ) -> StepResult[HeroObservation]: + del timeout_s, kwargs + if self._session is None or self._compiled is None: + raise RuntimeError("HeroEnvironment.reset must be called before step().") + if self._state.status != "running": + observation = HeroObservation( + message="", + reward=1.0 if self._state.status == "won" else 0.0, + done=True, + won=self._state.status == "won", + terminal_reason="episode_complete", + reward_breakdown=self._empty_breakdown(self._progress_potential()), + aux_signals=self._progress_signals(), + ) + return build_step_result(self._apply_transform(observation)) + + parsed = validate_hero_action(action) + self._state.tool_calls_total += 1 + self._state.step_count = self._state.tool_calls_total + self._update_remaining_counters() + + if isinstance(parsed, ScratchpadReadAction): + observation = self._observation( + message=self._scratchpad, + tool=parsed.tool, + tool_success=True, + reward_breakdown=self._empty_breakdown(self._progress_potential()), + ) + return build_step_result(observation) + + if isinstance(parsed, ScratchpadWriteAction): + observation = self._handle_scratchpad_write(parsed) + return build_step_result(observation) + + observation = self._handle_act(parsed) + return build_step_result(observation) + + @property + def state(self) -> HeroState: + return self._state + + @property + def scratchpad(self) -> str: + return self._scratchpad + + @property + def session(self) -> EpisodeSession | None: + return self._session + + @property + def episode_stats(self) -> HeroEpisodeStats: + return self._episode_stats + + def _handle_scratchpad_write(self, action: ScratchpadWriteAction) -> HeroObservation: + new_value = ( + self._scratchpad + action.content + if action.mode == "append" + else action.content + ) + if len(new_value) > self.scratchpad_max_chars: + return self._observation( + message="Scratchpad write rejected: notebook size limit exceeded.", + tool=action.tool, + tool_success=False, + reward_breakdown=self._empty_breakdown(self._progress_potential()), + ) + + self._scratchpad = new_value + self._state.scratchpad_chars = len(self._scratchpad) + self._persist_debug_scratchpad() + return self._observation( + message="Scratchpad updated.", + tool=action.tool, + tool_success=True, + reward_breakdown=self._empty_breakdown(self._progress_potential()), + ) + + def _handle_act(self, action: ActAction) -> HeroObservation: + assert self._session is not None + parsed_command = parse_cli_command(action.command) + self._state.last_command = parsed_command.normalized_command or action.command + if not parsed_command.valid or parsed_command.normalized_command is None: + breakdown = self._empty_breakdown(self._progress_potential()) + breakdown.syntax_penalty = _SYNTAX_PENALTY + return self._observation( + message=parsed_command.error or "That command does not match the strict CLI grammar.", + tool=action.tool, + tool_success=False, + reward_breakdown=breakdown, + ) + + potential_before = self._progress_potential() + fingerprint_before = self._session.state_fingerprint() + room_before = self._session.current_room_id + try: + turn = self._session.step(parsed_command.normalized_command) + except DMInterfaceError: + breakdown = self._empty_breakdown(potential_before) + breakdown.syntax_penalty = _SYNTAX_PENALTY + return self._observation( + message="The interface could not interpret that action.", + tool=action.tool, + tool_success=False, + reward_breakdown=breakdown, + ) + + tool_success = self._turn_succeeded(turn.game_state_delta) + self._state.game_steps_taken = self._session.steps_taken + self._session.recent_normalized_commands.append(parsed_command.normalized_command) + potential_after = self._progress_potential() + breakdown = self._empty_breakdown(potential_before) + breakdown.progress_potential_after = potential_after + breakdown.dense_progress_reward = _DENSE_PROGRESS_SCALE * max(0.0, potential_after - potential_before) + if not tool_success: + breakdown.invalid_action_penalty = _INVALID_ACTION_PENALTY + if self._is_wrong_submit(turn.game_state_delta): + breakdown.wrong_submit_penalty = _WRONG_SUBMIT_PENALTY + if self._repeat_noop(parsed_command.normalized_command, fingerprint_before, room_before): + breakdown.repeat_noop_penalty = _REPEAT_NOOP_PENALTY + return self._observation( + message=turn.observation, + tool=action.tool, + tool_success=tool_success, + reward_breakdown=breakdown, + ) + + def _update_remaining_counters(self) -> None: + self._state.game_steps_remaining = max(0, self._max_game_steps - self._state.game_steps_taken) + self._state.tool_calls_remaining = max(0, self._max_tool_calls - self._state.tool_calls_total) + + def _turn_succeeded(self, delta: dict[str, Any]) -> bool: + if delta.get("wrapper") == "submit_rejected": + return False + if "succeeded" in delta: + return bool(delta["succeeded"]) + return True + + def _observation( + self, + *, + message: str, + tool: str, + tool_success: bool, + reward_breakdown: HeroRewardBreakdown, + ) -> HeroObservation: + assert self._session is not None + done = False + won: bool | None = None + terminal_reason: str | None = None + + if self._session.player_won: + self._state.status = "won" + done = True + won = True + reward_breakdown.base_terminal_reward = 1.0 + elif self._session.done: + self._state.status = "lost" + done = True + won = False + terminal_reason = "session_ended" + elif self._state.game_steps_taken >= self._max_game_steps: + self._state.status = "timed_out" + done = True + won = False + terminal_reason = "game_step_budget_exhausted" + elif self._state.tool_calls_total >= self._max_tool_calls: + self._state.status = "timed_out" + done = True + won = False + terminal_reason = "tool_budget_exhausted" + + reward_breakdown.total_reward = ( + reward_breakdown.base_terminal_reward + + reward_breakdown.dense_progress_reward + + reward_breakdown.syntax_penalty + + reward_breakdown.invalid_action_penalty + + reward_breakdown.repeat_noop_penalty + + reward_breakdown.wrong_submit_penalty + ) + self._update_remaining_counters() + aux_signals = self._progress_signals() + self._accumulate_episode_stats(reward_breakdown, won is True) + + observation = self._apply_transform( + HeroObservation( + message=message, + reward=reward_breakdown.total_reward, + done=done, + won=won, + tool=tool, + tool_success=tool_success, + terminal_reason=terminal_reason, + reward_breakdown=reward_breakdown, + aux_signals=aux_signals, + ) + ) + return observation + + def _prepare_debug_dir(self) -> None: + if not self.debug or self._compiled is None: + self._debug_dir = None + return + self._debug_dir = self._compiled.artifacts_dir / "hero_debug" + self._debug_dir.mkdir(parents=True, exist_ok=True) + self._persist_debug_scratchpad() + + def _persist_debug_scratchpad(self) -> None: + if self._debug_dir is None: + return + (self._debug_dir / "scratchpad.txt").write_text(self._scratchpad, encoding="utf-8") + + def _progress_signals(self) -> HeroAuxSignals: + assert self._session is not None + assert self._compiled is not None + room_ids = {node.id for node in self._compiled.world.nodes if node.type in {"location", "junction"}} + total_locked_doors = { + edge.door_node_id + for edge in self._compiled.world.edges + if edge.type == "locked_passage" and edge.door_node_id + } + total_clues = {clue.id for clue in self._compiled.world.clues} + answer_ready = float( + bool(total_clues) + and self._session.consulted_guardian + and self._session.discovered_clues == total_clues + ) + return HeroAuxSignals( + visited_room_progress=_fraction(len(self._session.visited_nodes & room_ids), len(room_ids)), + clue_progress=_fraction(len(self._session.discovered_clues), len(total_clues)), + locked_gate_progress=_fraction(len(self._session.unlocked_doors), len(total_locked_doors)), + trade_progress=_fraction(len(self._session.traded_npcs), len(self._compiled.npc_trade_map)), + recipe_progress=_fraction(len(self._session.completed_recipe_outputs), len(self._compiled.world.recipes)), + use_effect_progress=_fraction(len(self._session.completed_use_targets), len(self._compiled.use_effects)), + guardian_consulted_progress=1.0 if self._session.consulted_guardian else 0.0, + answer_ready_progress=answer_ready, + ) + + def _progress_potential(self) -> float: + signals = self._progress_signals() + potential = ( + 0.10 * signals.visited_room_progress + + 0.35 * signals.clue_progress + + 0.10 * signals.locked_gate_progress + + 0.10 * signals.trade_progress + + 0.10 * signals.recipe_progress + + 0.15 * signals.use_effect_progress + + 0.05 * signals.guardian_consulted_progress + + 0.05 * signals.answer_ready_progress + ) + return max(0.0, min(1.0, potential)) + + def _empty_breakdown(self, potential: float) -> HeroRewardBreakdown: + return HeroRewardBreakdown( + progress_potential_before=potential, + progress_potential_after=potential, + ) + + def _repeat_noop(self, command: str, fingerprint_before: str, room_before: str) -> bool: + assert self._session is not None + fingerprint_after = self._session.state_fingerprint() + room_after = self._session.current_room_id + if room_before == room_after and fingerprint_before == fingerprint_after: + self._recent_noop_signatures.append((command, room_after, fingerprint_after)) + else: + self._recent_noop_signatures.clear() + return ( + len(self._recent_noop_signatures) == 3 + and len({signature[0] for signature in self._recent_noop_signatures}) == 1 + and len({signature[1] for signature in self._recent_noop_signatures}) == 1 + and len({signature[2] for signature in self._recent_noop_signatures}) == 1 + ) + + @staticmethod + def _is_wrong_submit(delta: dict[str, Any]) -> bool: + return delta.get("wrapper") == "submit_rejected" and delta.get("reason") == "wrong_answer" + + def _accumulate_episode_stats(self, breakdown: HeroRewardBreakdown, player_won: bool) -> None: + self._episode_stats.player_won = player_won or self._episode_stats.player_won + self._episode_stats.total_reward += breakdown.total_reward + self._episode_stats.dense_return += breakdown.dense_progress_reward + self._episode_stats.syntax_penalty_total += breakdown.syntax_penalty + self._episode_stats.invalid_action_penalty_total += breakdown.invalid_action_penalty + self._episode_stats.repeat_noop_penalty_total += breakdown.repeat_noop_penalty + self._episode_stats.wrong_submit_penalty_total += breakdown.wrong_submit_penalty + self._episode_stats.steps_taken = self._state.game_steps_taken + self._episode_stats.tool_calls_total = self._state.tool_calls_total + + +def _fraction(done: int, total: int) -> float: + if total <= 0: + return 0.0 + return min(1.0, done / total) diff --git a/agents/hero/policy.py b/agents/hero/policy.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f4cc2f9e6881e9d88acf6b007a88f1ca4c1570 --- /dev/null +++ b/agents/hero/policy.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from typing import Literal, Protocol + +from pydantic import BaseModel + +from agents.shared.llm_client import StructuredModelClient +from agents.shared.model_schema import ModelMessage, StrictModel + +from .cli import parse_cli_command +from .prompt import format_hero_system_prompt, format_hero_turn_prompt +from .schema import ActAction, HeroAction, HeroObservation, HeroState, validate_hero_action + + +class HeroPolicyError(RuntimeError): + pass + + +class HeroPolicy(Protocol): + trace_events: list["HeroTraceEvent"] + last_error: str | None + + def reset(self) -> None: + ... + + def next_action( + self, + observation: HeroObservation, + state: HeroState, + scratchpad: str, + ) -> HeroAction: + ... + + +class HeroActionPayload(BaseModel): + tool: Literal["act", "scratchpad_read", "scratchpad_write"] + command: str | None = None + mode: Literal["append", "replace"] | None = None + content: str | None = None + + +class HeroActionResponse(BaseModel): + action: HeroActionPayload + + +class HeroTraceEvent(StrictModel): + turn_index: int + observation: str + scratchpad: str + state: dict[str, object] + action: dict[str, object] | None = None + repair_count: int = 0 + validation_error: str | None = None + + +class HeroLLMPolicy: + def __init__( + self, + client: StructuredModelClient, + *, + model_name: str, + temperature: float = 0.1, + max_output_tokens: int = 256, + max_repair_attempts: int = 1, + ) -> None: + self.client = client + self.model_name = model_name + self.temperature = temperature + self.max_output_tokens = max_output_tokens + self.max_repair_attempts = max_repair_attempts + self.trace_events: list[HeroTraceEvent] = [] + self.last_error: str | None = None + + def reset(self) -> None: + self.trace_events = [] + self.last_error = None + + def next_action( + self, + observation: HeroObservation, + state: HeroState, + scratchpad: str, + ) -> HeroAction: + repair_error: str | None = None + for attempt in range(self.max_repair_attempts + 1): + try: + response = self.client.generate_structured( + self._messages(observation, state, scratchpad, repair_error), + HeroActionResponse, + model_name=self.model_name, + temperature=self.temperature, + max_output_tokens=self.max_output_tokens, + ) + action = validate_hero_action(response.action.model_dump(mode="json", exclude_none=True)) + if isinstance(action, ActAction): + parsed_command = parse_cli_command(action.command) + if not parsed_command.valid or parsed_command.normalized_command is None: + raise ValueError(parsed_command.error or "Invalid strict CLI command.") + action = ActAction(command=parsed_command.normalized_command) + self.trace_events.append( + HeroTraceEvent( + turn_index=len(self.trace_events), + observation=observation.message, + scratchpad=scratchpad, + state=state.model_dump(mode="json"), + action=action.model_dump(mode="json"), + repair_count=attempt, + ) + ) + self.last_error = None + return action + except Exception as exc: + repair_error = self._normalize_error(exc) + if attempt >= self.max_repair_attempts: + self.last_error = repair_error + self.trace_events.append( + HeroTraceEvent( + turn_index=len(self.trace_events), + observation=observation.message, + scratchpad=scratchpad, + state=state.model_dump(mode="json"), + repair_count=attempt, + validation_error=repair_error, + ) + ) + raise HeroPolicyError(repair_error) from exc + raise HeroPolicyError("Hero policy failed without a usable action.") + + def _messages( + self, + observation: HeroObservation, + state: HeroState, + scratchpad: str, + repair_error: str | None, + ) -> list[ModelMessage]: + user_prompt = format_hero_turn_prompt(observation.message, state, scratchpad) + if repair_error is not None: + user_prompt += ( + "\nThe previous response did not match the action schema.\n" + f"Validation error: {repair_error}\n" + "Return one corrected action only.\n" + ) + return [ + ModelMessage( + role="system", + content=format_hero_system_prompt( + state.world_title, + state.max_game_steps, + state.max_tool_calls, + ), + ), + ModelMessage(role="user", content=user_prompt), + ] + + @staticmethod + def _normalize_error(exc: Exception) -> str: + return " ".join(str(exc).split()) or exc.__class__.__name__ diff --git a/agents/hero/prompt.py b/agents/hero/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..3f67ff913c3c22805bf717cf28497ef8ad910602 --- /dev/null +++ b/agents/hero/prompt.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from .schema import HeroState + + +HERO_SYSTEM_PROMPT = """You are the hero exploring a living dungeon. + +You can only act through tools. + +Rules: +- Use `act` for any in-world action with one strict parser-style CLI command. +- Use `scratchpad_read` and `scratchpad_write` to manage your own notebook. +- Track rooms, objects, clues, hypotheses, and failed attempts in the notebook. +- Do not assume the world is fair in obvious ways; verify. +- Do not expect command hints from the environment. Use `look` and `inventory` when needed. +- Prefer systematic play: open visible containers and doors, take portable items, read text, talk to NPCs, and backtrack when blocked. +- When a puzzle reveals a clue, record it immediately. +- Do not submit an answer until you have enough evidence and the guardian is ready. +- Winning requires gathering evidence and then answering the guardian correctly. +- Keep your notebook concise and update it when the world changes. +- Commands must be lowercase only, with no articles, no markdown, and no conversational text. +- Allowed command grammar: + look + inventory + wait + north|south|east|west|up|down|in|out + go north|go south|go east|go west|go up|go down|go in|go out + open + read + talk + examine + look in + take + take from + unlock with + use on + combine with + give to + submit +- Example valid commands: + open entry chest + take brass key from entry chest + unlock iron door with brass key + east + use torch on ash mural + talk stone guardian + submit mira +- Return JSON only. Never add prose, markdown fences, or explanations. +- Valid response shapes: + {"action":{"tool":"act","command":"look"}} + {"action":{"tool":"scratchpad_read"}} + {"action":{"tool":"scratchpad_write","mode":"append","content":"room notes"}} +""" + +HERO_GRPO_SYSTEM_PROMPT = """You are the hero exploring a living dungeon. + +You can only act through tool calls. + +Rules: +- Call exactly one tool for each turn. +- Use `act` for any in-world action with one strict parser-style CLI command. +- Use `scratchpad_read` and `scratchpad_write` to manage your own notebook. +- Track rooms, objects, clues, hypotheses, and failed attempts in the notebook. +- Do not assume the world is fair in obvious ways; verify. +- Do not expect command hints from the environment. Use `look` and `inventory` when needed. +- Prefer systematic play: open visible containers and doors, take portable items, read text, talk to NPCs, and backtrack when blocked. +- When a puzzle reveals a clue, record it immediately. +- Do not submit an answer until you have enough evidence and the guardian is ready. +- Winning requires gathering evidence and then answering the guardian correctly. +- Keep your notebook concise and update it when the world changes. +- Commands must be lowercase only, with no articles, no markdown, and no conversational text. +- Allowed command grammar: + look + inventory + wait + north|south|east|west|up|down|in|out + go north|go south|go east|go west|go up|go down|go in|go out + open + read + talk + examine + look in + take + take from + unlock with + use on + combine with + give to + submit +- Example valid commands: + open entry chest + take brass key from entry chest + unlock iron door with brass key + east + use torch on ash mural + talk stone guardian + submit mira +- Do not write prose, plans, or plain JSON action objects. +- The runtime provides the tool schema; emit a tool call only. +""" + + +def format_hero_system_prompt(world_title: str, max_game_steps: int, max_tool_calls: int) -> str: + return ( + f"{HERO_SYSTEM_PROMPT}\n\n" + f"World: {world_title}\n" + f"Game-step budget: {max_game_steps}\n" + f"Total tool-call budget: {max_tool_calls}\n" + ) + + +def format_hero_grpo_system_prompt(world_title: str, max_game_steps: int, max_tool_calls: int) -> str: + return ( + f"{HERO_GRPO_SYSTEM_PROMPT}\n\n" + f"World: {world_title}\n" + f"Game-step budget: {max_game_steps}\n" + f"Total tool-call budget: {max_tool_calls}\n" + ) + + +def format_hero_turn_prompt(message: str, state: HeroState, scratchpad: str) -> str: + notebook = scratchpad if scratchpad else "" + return ( + "Choose exactly one next tool call.\n" + f"Observation:\n{message.strip() or ''}\n\n" + f"World: {state.world_title}\n" + f"Status: {state.status}\n" + f"Game steps taken: {state.game_steps_taken}/{state.max_game_steps}\n" + f"Tool calls used: {state.tool_calls_total}/{state.max_tool_calls}\n" + f"Game steps remaining: {state.game_steps_remaining}\n" + f"Tool calls remaining: {state.tool_calls_remaining}\n" + f"Last command: {state.last_command or ''}\n\n" + f"Scratchpad:\n{notebook}\n" + ) diff --git a/agents/hero/runner.py b/agents/hero/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..d01cc0076a68ee52abc713d0c8ab6107834a49e5 --- /dev/null +++ b/agents/hero/runner.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Protocol + +from agents.master.session import EpisodeSession + +from .env import HeroEnvironment +from .policy import HeroPolicyError +from .schema import HeroAction, HeroEpisodeStats, HeroObservation, HeroState + + +class ToolCallingPolicy(Protocol): + def reset(self) -> None: + ... + + def next_action( + self, + observation: HeroObservation, + state: HeroState, + scratchpad: str, + ) -> HeroAction | dict[str, object] | None: + ... + + +class ScriptedToolCallingPolicy: + def __init__(self, actions: Iterable[HeroAction | dict[str, object]]) -> None: + self._initial_actions = list(actions) + self._remaining_actions = list(self._initial_actions) + + def reset(self) -> None: + self._remaining_actions = list(self._initial_actions) + + def next_action( + self, + observation: HeroObservation, + state: HeroState, + scratchpad: str, + ) -> HeroAction | dict[str, object] | None: + del observation, state, scratchpad + if not self._remaining_actions: + return None + return self._remaining_actions.pop(0) + + +class HeroRunner: + def __init__( + self, + policy: ToolCallingPolicy, + *, + max_game_steps: int | None = 40, + max_tool_calls: int | None = None, + scratchpad_max_chars: int = 8000, + debug: bool = False, + ) -> None: + self.policy = policy + self.max_game_steps = max_game_steps + self.max_tool_calls = max_tool_calls + self.scratchpad_max_chars = scratchpad_max_chars + self.debug = debug + self.last_error: str | None = None + self.last_observation: HeroObservation | None = None + self.episode_stats: HeroEpisodeStats | None = None + + def run(self, session: EpisodeSession, max_steps: int) -> None: + self.last_error = None + self.last_observation = None + self.episode_stats = None + self.policy.reset() + env = HeroEnvironment.from_session( + session, + max_game_steps=max_steps if self.max_game_steps is None else min(max_steps, self.max_game_steps), + max_tool_calls=self.max_tool_calls, + scratchpad_max_chars=self.scratchpad_max_chars, + debug=self.debug, + ) + observation = env.reset() + self.last_observation = observation + while not observation.done: + try: + action = self.policy.next_action(observation, env.state, env.scratchpad) + except HeroPolicyError as exc: + self.last_error = str(exc) + self.episode_stats = env.episode_stats + return + if action is None: + self.episode_stats = env.episode_stats + return + result = env.step(action) + observation = result.observation + self.last_observation = observation + self.episode_stats = env.episode_stats diff --git a/agents/hero/schema.py b/agents/hero/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7b48585bd1a31e6cd0d19afe44fab2bbf3642a --- /dev/null +++ b/agents/hero/schema.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from typing import Any +from typing import Annotated, Literal, TypeAlias + +from pydantic import Field, TypeAdapter + +from agents.shared.openenv_compat import Action, Observation, State +from agents.shared.model_schema import StrictModel + + +class ActAction(Action): + tool: Literal["act"] = "act" + command: str + + +class ScratchpadReadAction(Action): + tool: Literal["scratchpad_read"] = "scratchpad_read" + + +class ScratchpadWriteAction(Action): + tool: Literal["scratchpad_write"] = "scratchpad_write" + mode: Literal["append", "replace"] + content: str + + +class HeroServerAction(Action): + tool: Literal["act", "scratchpad_read", "scratchpad_write"] + command: str | None = None + mode: Literal["append", "replace"] | None = None + content: str | None = None + + +HeroAction: TypeAlias = Annotated[ + ActAction | ScratchpadReadAction | ScratchpadWriteAction, + Field(discriminator="tool"), +] + +HERO_ACTION_ADAPTER = TypeAdapter(HeroAction) + + +def validate_hero_action(value: HeroAction | HeroServerAction | dict[str, Any]) -> HeroAction: + if isinstance(value, Action): + value = value.model_dump(mode="json", exclude_none=True) + return HERO_ACTION_ADAPTER.validate_python(value) + + +class HeroObservation(Observation): + message: str = "" + won: bool | None = None + tool: str | None = None + tool_success: bool | None = None + terminal_reason: str | None = None + reward_breakdown: "HeroRewardBreakdown | None" = None + aux_signals: "HeroAuxSignals | None" = None + + +class HeroAuxSignals(StrictModel): + visited_room_progress: float = 0.0 + clue_progress: float = 0.0 + locked_gate_progress: float = 0.0 + trade_progress: float = 0.0 + recipe_progress: float = 0.0 + use_effect_progress: float = 0.0 + guardian_consulted_progress: float = 0.0 + answer_ready_progress: float = 0.0 + + +class HeroRewardBreakdown(StrictModel): + base_terminal_reward: float = 0.0 + dense_progress_reward: float = 0.0 + syntax_penalty: float = 0.0 + invalid_action_penalty: float = 0.0 + repeat_noop_penalty: float = 0.0 + wrong_submit_penalty: float = 0.0 + total_reward: float = 0.0 + progress_potential_before: float = 0.0 + progress_potential_after: float = 0.0 + + +class HeroEpisodeStats(StrictModel): + player_won: bool = False + total_reward: float = 0.0 + dense_return: float = 0.0 + syntax_penalty_total: float = 0.0 + invalid_action_penalty_total: float = 0.0 + repeat_noop_penalty_total: float = 0.0 + wrong_submit_penalty_total: float = 0.0 + steps_taken: int = 0 + tool_calls_total: int = 0 + + +class HeroState(State): + game_steps_taken: int = 0 + tool_calls_total: int = 0 + max_game_steps: int = 0 + max_tool_calls: int = 0 + game_steps_remaining: int = 0 + tool_calls_remaining: int = 0 + status: Literal["ready", "running", "won", "lost", "timed_out", "error"] = "ready" + world_title: str = "" + last_command: str | None = None + scratchpad_chars: int = 0 diff --git a/agents/loop/__init__.py b/agents/loop/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8590edb0e43d527fc4992adfbb00c923edec962 --- /dev/null +++ b/agents/loop/__init__.py @@ -0,0 +1,11 @@ +"""Closed-loop orchestration for hero and dungeon master policies.""" + +from .runner import ClosedLoopRunner +from .schema import ClosedLoopEpisodeArtifacts, ClosedLoopEpisodeRecord, ClosedLoopEpisodeSummary + +__all__ = [ + "ClosedLoopEpisodeArtifacts", + "ClosedLoopEpisodeRecord", + "ClosedLoopEpisodeSummary", + "ClosedLoopRunner", +] diff --git a/agents/loop/__main__.py b/agents/loop/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0691f6f2d6ff840893c9cc4f2b83a9bd107634 --- /dev/null +++ b/agents/loop/__main__.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from agents.hero.policy import HeroLLMPolicy +from agents.master.interface import DEFAULT_GEMINI_MODEL +from agents.master.env import DMEnvironment +from agents.master.policy import DungeonMasterLLMPolicy +from agents.shared.runtime import ( + build_interface_adapter, + create_structured_client, + resolve_interface_config, + resolve_structured_client_config, +) + +from .runner import ClosedLoopRunner + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="Closed-loop dungeon master and hero harness") + parser.add_argument("--episodes", type=int, default=1) + parser.add_argument("--seed", type=int) + parser.add_argument("--target-ratio", type=float) + parser.add_argument("--dm-provider", choices=["gemini", "hf_local"]) + parser.add_argument("--dm-model") + parser.add_argument("--dm-adapter-path") + parser.add_argument("--hero-provider", choices=["gemini", "hf_local"]) + parser.add_argument("--hero-model") + parser.add_argument("--hero-adapter-path") + parser.add_argument("--interface-provider", choices=["strict", "simple", "gemini"]) + parser.add_argument("--interface-model", default=DEFAULT_GEMINI_MODEL) + parser.add_argument("--interface-narrate", action="store_true") + parser.add_argument( + "--translate-corporate-env", + action="store_true", + help="Rewrite hero-facing observations into a corporate app metaphor and map translated commands back through Gemini.", + ) + parser.add_argument("--artifacts-root", type=Path) + parser.add_argument("--dm-artifacts-root", type=Path) + parser.add_argument("--dm-repair-attempts", type=int, default=2) + parser.add_argument("--hero-max-game-steps", type=int, default=40) + parser.add_argument("--hero-max-tool-calls", type=int, default=80) + parser.add_argument("--live", action="store_true") + parser.add_argument("--live-dir", type=Path) + args = parser.parse_args(argv) + + dm_config = resolve_structured_client_config( + "dm", + provider=args.dm_provider, + model_name=args.dm_model, + adapter_path=args.dm_adapter_path, + ) + hero_config = resolve_structured_client_config( + "hero", + provider=args.hero_provider, + model_name=args.hero_model, + adapter_path=args.hero_adapter_path, + ) + interface_config = resolve_interface_config( + provider=args.interface_provider, + model_name=args.interface_model, + narrate_observations=args.interface_narrate, + translation_mode="corporate_app" if args.translate_corporate_env else None, + ) + runner = ClosedLoopRunner( + dm_env=DMEnvironment(artifacts_root=args.dm_artifacts_root), + dm_policy=DungeonMasterLLMPolicy(create_structured_client(dm_config), model_name=dm_config.model_name), + hero_policy=HeroLLMPolicy(create_structured_client(hero_config), model_name=hero_config.model_name), + artifacts_root=args.artifacts_root, + live_dir=args.live_dir, + max_dm_repair_attempts=args.dm_repair_attempts, + hero_runner_kwargs={ + "max_game_steps": args.hero_max_game_steps, + "max_tool_calls": args.hero_max_tool_calls, + }, + hero_interface_adapter=build_interface_adapter(interface_config), + ) + records = [] + for index in range(args.episodes): + seed = None if args.seed is None else args.seed + index + record = runner.run_episode(seed=seed, target_ratio=args.target_ratio, live=args.live) + records.append(record) + print(json.dumps(ClosedLoopRunner.summary(record).model_dump(mode="json"))) + if records: + print(json.dumps(ClosedLoopRunner.aggregate(records).model_dump(mode="json"))) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/agents/loop/runner.py b/agents/loop/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..5b24c876d898a44aa82bbab385931e6cc9b95e64 --- /dev/null +++ b/agents/loop/runner.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +import json +from pathlib import Path + +from agents.hero.policy import HeroPolicy +from agents.hero.runner import HeroRunner +from agents.master.env import DMEnvironment +from agents.master.interface import InterfaceAdapter, StrictCliInterfaceAdapter +from agents.master.policy import DMRepairContext, DungeonMasterPolicy, DungeonMasterPolicyError +from agents.master.schema import DMObservation, DMRewardBreakdown, WorldDefinition +from agents.master.snapshots import LiveObserver, LiveSnapshotWriter + +from .schema import ( + ClosedLoopAggregateReport, + ClosedLoopEpisodeArtifacts, + ClosedLoopEpisodeRecord, + ClosedLoopEpisodeSummary, +) + +DEFAULT_CLOSED_LOOP_ROOT = Path(__file__).resolve().parents[2] / ".play_runs" / "closed_loop" + + +class ClosedLoopRunner: + def __init__( + self, + *, + dm_env: DMEnvironment, + dm_policy: DungeonMasterPolicy, + hero_policy: HeroPolicy, + artifacts_root: Path | None = None, + live_dir: Path | None = None, + max_dm_repair_attempts: int = 2, + hero_runner_kwargs: dict[str, object] | None = None, + hero_interface_adapter: InterfaceAdapter | None = None, + ) -> None: + self.dm_env = dm_env + self.dm_policy = dm_policy + self.hero_policy = hero_policy + self.artifacts_root = artifacts_root or DEFAULT_CLOSED_LOOP_ROOT + self.live_dir = live_dir + self.max_dm_repair_attempts = max_dm_repair_attempts + self.hero_runner_kwargs = hero_runner_kwargs or {"max_game_steps": 40, "max_tool_calls": 80} + self.hero_interface_adapter = hero_interface_adapter or StrictCliInterfaceAdapter() + + def run_episode( + self, + *, + seed: int | None = None, + target_ratio: float | None = None, + live: bool = False, + ) -> ClosedLoopEpisodeRecord: + self.dm_env.reset(seed=seed, difficulty_hint=target_ratio) + episode_id = self.dm_env.state.episode_id + if episode_id is None: + raise RuntimeError("DM environment did not assign an episode id.") + episode_dir = self.artifacts_root / episode_id + episode_dir.mkdir(parents=True, exist_ok=True) + artifacts = ClosedLoopEpisodeArtifacts.from_episode_dir(episode_dir) + observer = self._observer(live) + + world: WorldDefinition | None = None + errors: list[str] = [] + compile_attempts = 0 + repair_context: DMRepairContext | None = None + previous_candidate_json: str | None = None + attempt_rows: list[dict[str, object]] = [] + + for attempt in range(1, self.max_dm_repair_attempts + 2): + compile_attempts = attempt + try: + candidate = self.dm_policy.generate_world( + target_ratio=self.dm_env.state.target_ratio, + repair_context=repair_context, + ) + previous_candidate_json = candidate.model_dump_json(indent=2) + self._write_json(Path(artifacts.world_definition_path), previous_candidate_json) + self.dm_env.compile_world(candidate, episode_id=episode_id) + world = candidate + attempt_rows.append( + { + "attempt_number": attempt, + "status": "compiled", + "world_title": candidate.meta.title, + "difficulty_target": candidate.meta.difficulty_target, + } + ) + break + except Exception as exc: + normalized_error = self._normalize_error(exc) + errors.append(normalized_error) + attempt_rows.append( + { + "attempt_number": attempt, + "status": "failed", + "error": normalized_error, + } + ) + repair_context = DMRepairContext( + attempt_number=attempt, + error_message=normalized_error, + previous_candidate_json=previous_candidate_json, + ) + + self._write_jsonl(Path(artifacts.world_generation_attempts_path), attempt_rows) + + if world is None: + observation = self._compile_failure_observation(errors[-1] if errors else "world compilation failed") + record = ClosedLoopEpisodeRecord( + episode_id=episode_id, + status="compile_failed", + target_ratio=self.dm_env.state.target_ratio, + compile_attempts=compile_attempts, + dm_repair_errors=errors, + world_definition=None, + declared_difficulty_target=None, + difficulty_target_matches_target_ratio=None, + observation=observation, + artifacts=artifacts, + ) + self._persist_record(record) + self._write_jsonl(Path(artifacts.hero_trace_path), []) + self._write_jsonl(Path(artifacts.transcript_path), []) + return record + + hero_runner = HeroRunner(policy=self.hero_policy, **self.hero_runner_kwargs) + previous_adapter = self.dm_env.interface_adapter + self.dm_env.interface_adapter = self.hero_interface_adapter + try: + result = self.dm_env.step(world, runner=hero_runner, observer=observer) + finally: + self.dm_env.interface_adapter = previous_adapter + observation = result.observation + status = "policy_error" if hero_runner.last_error else ("complete" if observation.player_won else "failed") + record = ClosedLoopEpisodeRecord( + episode_id=episode_id, + status=status, + target_ratio=self.dm_env.state.target_ratio, + compile_attempts=compile_attempts, + dm_repair_errors=errors, + hero_policy_error=hero_runner.last_error, + hero_episode_stats=hero_runner.episode_stats, + world_definition=world, + declared_difficulty_target=world.meta.difficulty_target, + difficulty_target_matches_target_ratio=(world.meta.difficulty_target == self.dm_env.state.target_ratio), + observation=observation, + artifacts=artifacts, + ) + self._persist_record(record) + self._write_jsonl( + Path(artifacts.hero_trace_path), + [event.model_dump(mode="json") for event in self.hero_policy.trace_events], + ) + self._write_jsonl( + Path(artifacts.transcript_path), + [turn.model_dump(mode="json") for turn in observation.episode_transcript], + ) + return record + + @staticmethod + def summary(record: ClosedLoopEpisodeRecord) -> ClosedLoopEpisodeSummary: + return ClosedLoopEpisodeSummary( + episode_id=record.episode_id, + status=record.status, + reward=record.observation.reward, + player_won=record.observation.player_won, + ratio=record.observation.ratio, + compile_error=record.observation.compile_error, + hero_policy_error=record.hero_policy_error, + ) + + @staticmethod + def aggregate(records: list[ClosedLoopEpisodeRecord]) -> ClosedLoopAggregateReport: + episodes = len(records) + dense_returns = [ + record.hero_episode_stats.dense_return + for record in records + if record.hero_episode_stats is not None + ] + invalid_penalties = [ + record.hero_episode_stats.invalid_action_penalty_total + for record in records + if record.hero_episode_stats is not None + ] + repeat_penalties = [ + record.hero_episode_stats.repeat_noop_penalty_total + for record in records + if record.hero_episode_stats is not None + ] + return ClosedLoopAggregateReport( + episodes=episodes, + compile_valid_rate=_rate(sum(record.status != "compile_failed" for record in records), episodes), + policy_error_rate=_rate(sum(record.status == "policy_error" for record in records), episodes), + playable_rate=_rate(sum(record.world_definition is not None for record in records), episodes), + solve_rate=_rate(sum(record.status == "complete" for record in records), episodes), + mean_dense_return=_mean(dense_returns), + mean_invalid_action_penalty=_mean(invalid_penalties), + mean_repeat_noop_penalty=_mean(repeat_penalties), + ) + + def _compile_failure_observation(self, error: str) -> DMObservation: + breakdown = DMRewardBreakdown( + reward_mode="compile_failure_penalty", + player_won=False, + target_ratio=self.dm_env.state.target_ratio, + quality_score=0.0, + reward=0.0, + ) + return DMObservation( + player_won=False, + compile_error=error, + reward=0.0, + done=True, + reward_breakdown=breakdown, + target_ratio_used=self.dm_env.state.target_ratio, + ) + + def _observer(self, live: bool) -> LiveObserver | None: + if not live: + return None + return LiveSnapshotWriter(live_dir=self.live_dir, runner_name="hero_llm") + + def _persist_record(self, record: ClosedLoopEpisodeRecord) -> None: + self._write_json(Path(record.artifacts.run_record_path), record.model_dump_json(indent=2)) + + @staticmethod + def _write_json(path: Path, payload: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(payload + "\n", encoding="utf-8") + + @staticmethod + def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + payload = "".join(json.dumps(row) + "\n" for row in rows) + path.write_text(payload, encoding="utf-8") + + @staticmethod + def _normalize_error(exc: Exception) -> str: + if isinstance(exc, DungeonMasterPolicyError): + return str(exc) + return " ".join(str(exc).split()) or exc.__class__.__name__ + + +def _mean(values: list[float]) -> float: + if not values: + return 0.0 + return sum(values) / len(values) + + +def _rate(count: int, total: int) -> float: + if total <= 0: + return 0.0 + return count / total diff --git a/agents/loop/schema.py b/agents/loop/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..bd2034ee3845166229d5b3e36e735da19148381e --- /dev/null +++ b/agents/loop/schema.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Literal + +from agents.hero.schema import HeroEpisodeStats +from agents.master.schema import DMObservation, WorldDefinition +from agents.shared.model_schema import StrictModel + + +class ClosedLoopEpisodeArtifacts(StrictModel): + episode_dir: str + world_generation_attempts_path: str + world_definition_path: str + run_record_path: str + hero_trace_path: str + transcript_path: str + + @classmethod + def from_episode_dir(cls, episode_dir: Path) -> "ClosedLoopEpisodeArtifacts": + return cls( + episode_dir=str(episode_dir), + world_generation_attempts_path=str(episode_dir / "world_generation_attempts.jsonl"), + world_definition_path=str(episode_dir / "world_definition.json"), + run_record_path=str(episode_dir / "run_record.json"), + hero_trace_path=str(episode_dir / "hero_trace.jsonl"), + transcript_path=str(episode_dir / "transcript.jsonl"), + ) + + +class ClosedLoopEpisodeRecord(StrictModel): + episode_id: str + status: Literal["complete", "failed", "compile_failed", "policy_error"] + target_ratio: float + compile_attempts: int + dm_repair_errors: list[str] + hero_policy_error: str | None = None + hero_episode_stats: HeroEpisodeStats | None = None + declared_difficulty_target: float | None = None + difficulty_target_matches_target_ratio: bool | None = None + world_definition: WorldDefinition | None = None + observation: DMObservation + artifacts: ClosedLoopEpisodeArtifacts + + +class ClosedLoopEpisodeSummary(StrictModel): + episode_id: str + status: str + reward: float | None = None + player_won: bool | None = None + ratio: float | None = None + compile_error: str | None = None + hero_policy_error: str | None = None + + +class ClosedLoopAggregateReport(StrictModel): + episodes: int + compile_valid_rate: float + policy_error_rate: float + playable_rate: float + solve_rate: float + mean_dense_return: float + mean_invalid_action_penalty: float + mean_repeat_noop_penalty: float diff --git a/agents/master/__init__.py b/agents/master/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68d915f1860a5810fe19b630f2dd7c09b9214715 --- /dev/null +++ b/agents/master/__init__.py @@ -0,0 +1,15 @@ +"""DM environment source package.""" + +from .policy import ( + DMRepairContext, + DungeonMasterLLMPolicy, + DungeonMasterPolicy, + DungeonMasterPolicyError, +) + +__all__ = [ + "DMRepairContext", + "DungeonMasterLLMPolicy", + "DungeonMasterPolicy", + "DungeonMasterPolicyError", +] diff --git a/agents/master/__main__.py b/agents/master/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb3861e6c1b330e1930ed00632a66d1ec669e4d --- /dev/null +++ b/agents/master/__main__.py @@ -0,0 +1,5 @@ +from .main import main + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/agents/master/base.py b/agents/master/base.py new file mode 100644 index 0000000000000000000000000000000000000000..11399be85b6d4e77eae7e63fa9d78f465a50d000 --- /dev/null +++ b/agents/master/base.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from contextlib import contextmanager +import re +import warnings +from pathlib import Path + + +MAX_NODES = 40 +MAX_ITEMS = 32 +MAX_QUEST_STEPS = 64 +MIN_NODES = 5 +MIN_QUEST_STEPS = 2 +MIN_CLUES = 3 +MAX_CLUES = 5 +TARGET_RATIO = 1.5 +TARGET_RATIO_SIGMA = 0.4 +MAX_STEP_MULTIPLIER = 5 +INVENTORY_ID = "__inventory__" +STORED_ID = "__stored__" +ROOT_DIR = Path(__file__).resolve().parents[2] +ARTIFACTS_ROOT = ROOT_DIR / ".artifacts" / "dm_env" +CUSTOM_LOGIC_DIR = ROOT_DIR / "textworld_data" / "dnd" / "logic" +CUSTOM_GRAMMAR_DIR = ROOT_DIR / "textworld_data" / "dnd" / "text_grammars" +SUPPORTED_DIRECTIONS = ("north", "south", "east", "west", "up", "down", "in", "out") +OPPOSITE_DIRECTION = { + "north": "south", + "south": "north", + "east": "west", + "west": "east", + "up": "down", + "down": "up", + "in": "out", + "out": "in", +} + +GO_RE = re.compile(r"^go\((?P[a-z0-9_]+)\)$") +OPEN_RE = re.compile(r"^open\((?P[a-z0-9_]+)\)$") +UNLOCK_RE = re.compile(r"^unlock\((?P[a-z0-9_]+),(?P[a-z0-9_]+)\)$") +TAKE_RE = re.compile(r"^take\((?P[a-z0-9_]+),(?P[a-z0-9_]+)\)$") +READ_RE = re.compile(r"^read\((?P[a-z0-9_]+)\)$") +USE_RE = re.compile(r"^use\((?P[a-z0-9_]+),(?P[a-z0-9_]+)\)$") +COMBINE_RE = re.compile(r"^combine\((?P[a-z0-9_]+),(?P[a-z0-9_]+)\)$") +GIVE_RE = re.compile(r"^give\((?P[a-z0-9_]+),(?P[a-z0-9_]+)\)$") +TALK_RE = re.compile(r"^talk\((?P[a-z0-9_]+)\)$") +SUBMIT_RE = re.compile(r"^submit\((?P[\"'])(?P.+)(?P=quote)\)$") + + +class DMCompileError(RuntimeError): + pass + + +class DMInterfaceError(RuntimeError): + pass + + +@contextmanager +def suppress_unsupported_game_warning(): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=r"Game '.*' is not fully supported\..*", + category=Warning, + ) + yield + + +def normalize_snake_id(value: str, kind: str) -> str: + if not re.fullmatch(r"[a-z][a-z0-9_]*", value): + raise DMCompileError(f"{kind} '{value}' must be snake_case.") + return value + + +def parser_safe_text(value: str) -> str: + collapsed = re.sub(r"[^A-Za-z0-9 ]+", " ", value).strip().lower() + collapsed = re.sub(r"\s+", " ", collapsed) + if not collapsed: + raise DMCompileError(f"Unable to derive a parser-safe name from '{value}'.") + return collapsed + + +def normalize_answer_text(value: str) -> str: + collapsed = re.sub(r"[^A-Za-z0-9 ]+", " ", value).strip().lower() + return re.sub(r"\s+", " ", collapsed) diff --git a/agents/master/build.py b/agents/master/build.py new file mode 100644 index 0000000000000000000000000000000000000000..b645ea9ba889543aa0f6f87ea2c626200bb87621 --- /dev/null +++ b/agents/master/build.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +import uuid +from collections import defaultdict +from pathlib import Path +from typing import Any + +from textworld.generator import GameMaker, GameOptions, compile_game +from textworld.generator.data import KnowledgeBase + +from .base import ARTIFACTS_ROOT, DMCompileError, parser_safe_text +from .check import validate_and_normalize +from .graph import ( + door_room_mapping, + hidden_readable_ids, + npc_trade_mapping, + produced_item_ids, + readable_clue_mapping, + recipe_mapping, + use_effect_mapping, +) +from .logic import build_grammar_dir, build_logic_dir, solver_policy, submit_command_text, write_artifacts +from .quest import parse_quest_action, simulate_walkthrough, topological_linearize +from .schema import CompiledWorld, WorldDefinition + + +class WorldCompiler: + def __init__(self, artifacts_root: Path | None = None) -> None: + self.artifacts_root = artifacts_root or ARTIFACTS_ROOT + + def compile(self, world_input: WorldDefinition | dict[str, Any], episode_id: str | None = None) -> CompiledWorld: + world = validate_and_normalize(world_input) + episode_id = episode_id or uuid.uuid4().hex[:12] + artifacts_dir = self.artifacts_root / episode_id + artifacts_dir.mkdir(parents=True, exist_ok=True) + parsed_steps = [parse_quest_action(step.action) for step in topological_linearize(world.quest_chain)] + entity_names = self._assign_command_names(world) + + options = GameOptions() + options.kb = KnowledgeBase.load( + logic_path=str(build_logic_dir(artifacts_dir, world)), + grammar_path=str(build_grammar_dir(artifacts_dir)), + ) + options.path = str(artifacts_dir / "game.z8") + options.force_recompile = True + maker = GameMaker(options=options) + + rooms, entities = self._build_entities(maker, world, entity_names) + maker.set_player(rooms[world.meta.start_node_id]) + self._compile_edges(maker, world, rooms, entities) + self._compile_clue_sources(maker, world, entities) + self._compile_fixtures(maker, world, entities) + self._compile_npcs(maker, world, entities) + self._compile_recipes(maker, world, entities) + + guardian = entities[world.meta.win_condition.target_npc_id] + answer = maker.new(type="answer", name="final answer token") + maker.nowhere.append(answer) + entities["__answer__"] = answer + maker.add_fact("guardian", guardian) + maker.add_fact("correct", answer, guardian) + + walkthrough_commands = simulate_walkthrough(world, parsed_steps, entity_names) + game = maker.build() + game.objective = ( + f"Explore {world.meta.title}, manipulate the dungeon's tools, gather every clue, " + f"speak to {entities[world.meta.win_condition.target_npc_id].name}, and submit the answer." + ) + game.metadata.update( + {"episode_id": episode_id, "dm_title": world.meta.title, "start_node_id": world.meta.start_node_id} + ) + compile_game(game, options) + write_artifacts(artifacts_dir, world, walkthrough_commands) + policy = solver_policy(str(options.path)) + if not policy: + policy = list(walkthrough_commands) + return self._compiled_world( + episode_id, + artifacts_dir, + Path(options.path), + world, + entity_names, + walkthrough_commands, + policy, + ) + + def _build_entities( + self, + maker: GameMaker, + world: WorldDefinition, + entity_names: dict[str, str], + ) -> tuple[dict[str, Any], dict[str, Any]]: + rooms = { + node.id: maker.new(type="r", name=entity_names[node.id], desc=node.description) + for node in world.nodes + if node.type in {"location", "junction"} + } + entities: dict[str, Any] = {} + hidden_readables = hidden_readable_ids(world) + recipe_outputs = {recipe.output_item_id for recipe in world.recipes} + produced_items = produced_item_ids(world) + + for node in world.nodes: + if node.type in {"location", "junction"}: + continue + entity = self._make_node_entity(maker, node, entity_names[node.id]) + entities[node.id] = entity + if node.type == "door": + maker.nowhere.append(entity) + elif node.type == "readable" and node.id in hidden_readables: + maker.nowhere.append(entity) + maker.add_fact("hidden_readable", entity) + else: + rooms[node.parent_id].add(entity) + + for item in world.items: + item_type = "k" if item.subtype == "key" else "o" + entity = maker.new(type=item_type, name=entity_names[item.id], desc=item.description) + entities[item.id] = entity + if item.id in produced_items: + maker.nowhere.append(entity) + if item.id in recipe_outputs: + maker.add_fact("fresh", entity) + else: + maker.add_fact("stored_item", entity) + continue + holder = item.start_node_id + if holder is None: + raise DMCompileError(f"Placed item '{item.id}' is missing start_node_id.") + if holder in rooms: + rooms[holder].add(entity) + else: + entities[holder].add(entity) + + return rooms, entities + + @staticmethod + def _make_node_entity(maker: GameMaker, node: object, name: str) -> Any: + if node.type == "container": + entity = maker.new(type="c", name=name, desc=node.description) + entity.add_property("open" if node.open else "locked" if node.locked else "closed") + return entity + if node.type == "door": + entity = maker.new(type="d", name=name, desc=node.description) + entity.add_property("open" if node.open else "locked" if node.locked else "closed") + return entity + if node.type == "readable": + return maker.new(type="readable", name=name, desc=node.description) + if node.type == "fixture": + return maker.new(type="fixture", name=name, desc=node.description) + if node.type == "npc": + return maker.new(type="npc", name=name, desc=node.description) + raise DMCompileError(f"Unsupported node type '{node.type}'.") + + def _compile_clue_sources( + self, + maker: GameMaker, + world: WorldDefinition, + entities: dict[str, Any], + ) -> None: + hidden_readables = hidden_readable_ids(world) + for node in world.nodes: + if node.type != "readable": + continue + readable = entities[node.id] + if node.requires_item_id: + maker.add_fact("read_requires", readable, entities[node.requires_item_id]) + maker.add_fact("read_consumes_use" if node.consumes_item else "read_keeps_use", readable) + else: + maker.add_fact("free_read", readable) + if node.id in hidden_readables: + continue + + def _compile_fixtures(self, maker: GameMaker, world: WorldDefinition, entities: dict[str, Any]) -> None: + for node in world.nodes: + if node.type != "fixture": + continue + fixture = entities[node.id] + maker.add_fact("fixture_requires", fixture, entities[node.requires_item_id]) + maker.add_fact("sealed", fixture) + maker.add_fact("fixture_consumes_use" if node.consumes_item else "fixture_keeps_use", fixture) + if node.reveals_item_id: + maker.add_fact("reveals_item", fixture, entities[node.reveals_item_id]) + if node.reveals_readable_id: + maker.add_fact("reveals_readable", fixture, entities[node.reveals_readable_id]) + + def _compile_npcs( + self, + maker: GameMaker, + world: WorldDefinition, + entities: dict[str, Any], + ) -> None: + guardian_id = world.meta.win_condition.target_npc_id + for node in world.nodes: + if node.type != "npc": + continue + npc = entities[node.id] + if node.id == guardian_id: + continue + maker.add_fact("trade_pending", npc) + maker.add_fact("trade_requires", npc, entities[node.requires_item_id]) + if node.gives_item_id: + maker.add_fact("trade_gives_item", npc, entities[node.gives_item_id]) + if node.gives_clue_id: + maker.add_fact("trade_gives_clue", npc) + + def _compile_recipes(self, maker: GameMaker, world: WorldDefinition, entities: dict[str, Any]) -> None: + for recipe in world.recipes: + a_id, b_id = recipe.input_item_ids + output = entities[recipe.output_item_id] + maker.add_fact("combines_with", entities[a_id], entities[b_id], output) + maker.add_fact("combines_with", entities[b_id], entities[a_id], output) + + @staticmethod + def _compile_edges( + maker: GameMaker, + world: WorldDefinition, + rooms: dict[str, Any], + entities: dict[str, Any], + ) -> None: + pair_groups: dict[frozenset[str], list[Any]] = defaultdict(list) + for edge in world.edges: + pair_groups.setdefault(frozenset({edge.from_node_id, edge.to_node_id}), []).append(edge) + for edges in pair_groups.values(): + forward, backward = sorted(edges, key=lambda edge: edge.id) + for edge in (forward, backward): + maker.add_fact(f"{edge.direction}_of", rooms[edge.to_node_id], rooms[edge.from_node_id]) + if forward.door_node_id: + door = entities[forward.door_node_id] + room_a = rooms[forward.from_node_id] + room_b = rooms[forward.to_node_id] + maker.add_fact("link", room_a, door, room_b) + maker.add_fact("link", room_b, door, room_a) + if forward.required_item_id: + maker.add_fact("match", entities[forward.required_item_id], door) + door_is_open = door.has_property("open") + if door_is_open: + maker.add_fact("free", room_a, room_b) + maker.add_fact("free", room_b, room_a) + else: + maker.add_fact("free", rooms[forward.from_node_id], rooms[forward.to_node_id]) + maker.add_fact("free", rooms[forward.to_node_id], rooms[forward.from_node_id]) + + def _compiled_world( + self, + episode_id: str, + artifacts_dir: Path, + game_file: Path, + world: WorldDefinition, + entity_names: dict[str, str], + walkthrough_commands: list[str], + policy: list[str], + ) -> CompiledWorld: + node_by_id = {node.id: node for node in world.nodes} + return CompiledWorld( + episode_id=episode_id, + world=world, + artifacts_dir=artifacts_dir, + game_file=game_file, + walkthrough_commands=walkthrough_commands, + solver_policy=policy, + correct_answer_normalized=submit_command_text(world).replace("submit ", "", 1), + correct_submit_command=submit_command_text(world), + guardian_id=world.meta.win_condition.target_npc_id, + guardian_room_id=node_by_id[world.meta.win_condition.target_npc_id].parent_id, + room_name_to_id={ + entity_names[node.id]: node.id for node in world.nodes if node.type in {"location", "junction"} + }, + node_command_names={node.id: entity_names[node.id] for node in world.nodes}, + item_command_names={item.id: entity_names[item.id] for item in world.items}, + item_start_locations={item.id: item.start_node_id for item in world.items}, + clue_text_by_id={clue.id: clue.text for clue in world.clues}, + readable_clue_by_id=readable_clue_mapping(world), + npc_trade_map=npc_trade_mapping(world), + recipe_map=recipe_mapping(world), + use_effects=use_effect_mapping(world), + produced_item_ids=produced_item_ids(world), + room_edges_by_target={(edge.from_node_id, edge.to_node_id): edge for edge in world.edges}, + room_edges_by_direction={(edge.from_node_id, edge.direction): edge for edge in world.edges}, + door_rooms=door_room_mapping(world), + ) + + @staticmethod + def _assign_command_names(world: WorldDefinition) -> dict[str, str]: + names = {node.id: parser_safe_text(node.label) for node in world.nodes} + names.update({item.id: parser_safe_text(item.label) for item in world.items}) + return names diff --git a/agents/master/check.py b/agents/master/check.py new file mode 100644 index 0000000000000000000000000000000000000000..7392f40ac9b1d6b7bf53bb9083cae10e41f91f60 --- /dev/null +++ b/agents/master/check.py @@ -0,0 +1,435 @@ +from __future__ import annotations + +from collections import defaultdict, deque +from typing import Any + +from pydantic import ValidationError + +from .base import ( + DMCompileError, + MAX_CLUES, + MAX_ITEMS, + MAX_NODES, + MAX_QUEST_STEPS, + MIN_CLUES, + MIN_NODES, + MIN_QUEST_STEPS, + OPPOSITE_DIRECTION, + normalize_answer_text, + normalize_snake_id, + parser_safe_text, +) +from .graph import hidden_readable_ids, produced_item_ids +from .quest import parse_quest_action, simulate_walkthrough, topological_linearize +from .schema import ( + CombineAction, + ContainerNode, + DoorNode, + GiveAction, + NpcNode, + ReadableNode, + SubmitAction, + TakeAction, + TalkAction, + UnlockAction, + UseAction, + WorldDefinition, +) + + +def validate_and_normalize(world_input: WorldDefinition | dict[str, Any]) -> WorldDefinition: + if isinstance(world_input, dict): + _reject_legacy_shapes(world_input) + try: + world = WorldDefinition.model_validate(world_input) + except ValidationError as exc: # pragma: no cover - exercised indirectly in compile paths + raise DMCompileError(str(exc)) from exc + _validate_ids(world) + _validate_shape(world) + _validate_nodes(world) + _validate_edges(world) + _validate_items(world) + _validate_clues(world) + _validate_visibility(world) + _validate_answer_leaks(world) + _validate_guardian_path(world) + _validate_clue_gates(world) + _validate_item_usage(world) + _validate_quest_shape(world) + return world + + +def infer_start_room(world: WorldDefinition) -> str: + return world.meta.start_node_id + + +def _reject_legacy_shapes(world_input: dict[str, Any]) -> None: + for node in world_input.get("nodes", []): + if node.get("type") == "clue": + raise DMCompileError("Legacy clue nodes are not supported in v2. Use top-level clues[].") + if node.get("state", {}).get("npc_dialogue") is not None: + raise DMCompileError("Legacy npc_dialogue is not supported in v2.") + for edge in world_input.get("edges", []): + if edge.get("type") == "conditional_passage": + raise DMCompileError("conditional_passage is not supported in v2.") + + +def _validate_ids(world: WorldDefinition) -> None: + global_ids: set[str] = set() + collections = { + "node": [node.id for node in world.nodes], + "item": [item.id for item in world.items], + "clue": [clue.id for clue in world.clues], + "recipe": [recipe.id for recipe in world.recipes], + "quest step": [step.step_id for step in world.quest_chain], + } + for kind, values in collections.items(): + seen: set[str] = set() + for value in values: + normalize_snake_id(value, kind) + if value in seen: + raise DMCompileError(f"Duplicate {kind} id '{value}'.") + if value in global_ids: + raise DMCompileError(f"Duplicate world id '{value}' across collections.") + seen.add(value) + global_ids.add(value) + + +def _validate_shape(world: WorldDefinition) -> None: + room_nodes = [node for node in world.nodes if node.type in {"location", "junction"}] + if len(world.nodes) < MIN_NODES: + raise DMCompileError(f"Worlds need at least {MIN_NODES} nodes.") + if len(world.nodes) > MAX_NODES: + raise DMCompileError(f"Worlds support at most {MAX_NODES} nodes.") + if len(world.items) > MAX_ITEMS: + raise DMCompileError(f"Worlds support at most {MAX_ITEMS} items.") + if len(world.clues) < MIN_CLUES or len(world.clues) > MAX_CLUES: + raise DMCompileError(f"Worlds must define between {MIN_CLUES} and {MAX_CLUES} clues.") + if len(world.quest_chain) < MIN_QUEST_STEPS or len(world.quest_chain) > MAX_QUEST_STEPS: + raise DMCompileError(f"quest_chain must contain between {MIN_QUEST_STEPS} and {MAX_QUEST_STEPS} steps.") + if world.meta.start_node_id not in {node.id for node in room_nodes}: + raise DMCompileError("meta.start_node_id must reference a location or junction.") + if world.meta.win_condition.type != "deduce": + raise DMCompileError("Only deduce win conditions are supported in v2.") + if not normalize_answer_text(world.meta.win_condition.answer_string): + raise DMCompileError("answer_string cannot normalize to an empty command.") + + +def _validate_nodes(world: WorldDefinition) -> None: + node_by_id = {node.id: node for node in world.nodes} + item_ids = {item.id for item in world.items} + clue_ids = {clue.id for clue in world.clues} + hidden_readables = hidden_readable_ids(world) + guardian_id = world.meta.win_condition.target_npc_id + + guardian_seen = False + for node in world.nodes: + if node.type in {"location", "junction"}: + continue + if node.type == "door": + _validate_lockable(node, item_ids) + continue + parent = node_by_id.get(node.parent_id) + if parent is None or parent.type not in {"location", "junction"}: + raise DMCompileError(f"Node '{node.id}' must live in a location or junction.") + if node.type == "container": + _validate_lockable(node, item_ids) + elif node.type == "readable": + if node.clue_id not in clue_ids: + raise DMCompileError(f"Readable '{node.id}' references unknown clue '{node.clue_id}'.") + if node.requires_item_id and node.requires_item_id not in item_ids: + raise DMCompileError(f"Readable '{node.id}' references unknown item '{node.requires_item_id}'.") + elif node.type == "fixture": + if node.requires_item_id not in item_ids: + raise DMCompileError(f"Fixture '{node.id}' references unknown item '{node.requires_item_id}'.") + if bool(node.reveals_item_id) == bool(node.reveals_readable_id): + raise DMCompileError(f"Fixture '{node.id}' must reveal exactly one item or readable.") + if node.reveals_item_id and node.reveals_item_id not in item_ids: + raise DMCompileError(f"Fixture '{node.id}' reveals unknown item '{node.reveals_item_id}'.") + if node.reveals_readable_id and node.reveals_readable_id not in node_by_id: + raise DMCompileError(f"Fixture '{node.id}' reveals unknown readable '{node.reveals_readable_id}'.") + if node.reveals_readable_id: + readable = node_by_id[node.reveals_readable_id] + if not isinstance(readable, ReadableNode): + raise DMCompileError(f"Fixture '{node.id}' can only reveal readable nodes.") + if readable.parent_id != node.parent_id: + raise DMCompileError( + f"Fixture '{node.id}' must reveal readable '{readable.id}' in the same room." + ) + elif node.type == "npc": + if node.id == guardian_id: + guardian_seen = True + if node.requires_item_id or node.gives_item_id or node.gives_clue_id: + raise DMCompileError("Guardian NPC cannot have trade fields.") + else: + if not node.requires_item_id: + raise DMCompileError(f"NPC '{node.id}' requires requires_item_id in v2.") + if node.requires_item_id not in item_ids: + raise DMCompileError(f"NPC '{node.id}' references unknown item '{node.requires_item_id}'.") + if bool(node.gives_item_id) == bool(node.gives_clue_id): + raise DMCompileError( + f"NPC '{node.id}' must define exactly one of gives_item_id or gives_clue_id." + ) + if node.gives_item_id and node.gives_item_id not in item_ids: + raise DMCompileError(f"NPC '{node.id}' gives unknown item '{node.gives_item_id}'.") + if node.gives_clue_id and node.gives_clue_id not in clue_ids: + raise DMCompileError(f"NPC '{node.id}' gives unknown clue '{node.gives_clue_id}'.") + else: # pragma: no cover + raise AssertionError(f"Unhandled node type {node.type}") + + if not guardian_seen: + raise DMCompileError(f"Guardian NPC '{guardian_id}' does not exist.") + for readable_id in hidden_readables: + readable = node_by_id[readable_id] + if not isinstance(readable, ReadableNode): + raise DMCompileError(f"Only readable nodes can be hidden, not '{readable_id}'.") + + +def _validate_lockable(node: ContainerNode | DoorNode, item_ids: set[str]) -> None: + if node.open and node.locked: + raise DMCompileError(f"Lockable node '{node.id}' cannot be both open and locked.") + if node.locked and not node.lock_key_id: + raise DMCompileError(f"Lockable node '{node.id}' is locked but has no lock_key_id.") + if node.lock_key_id and node.lock_key_id not in item_ids: + raise DMCompileError(f"Lockable node '{node.id}' references unknown key '{node.lock_key_id}'.") + + +def _validate_edges(world: WorldDefinition) -> None: + room_ids = {node.id for node in world.nodes if node.type in {"location", "junction"}} + node_by_id = {node.id: node for node in world.nodes} + item_ids = {item.id for item in world.items} + pair_groups: dict[frozenset[str], list[Any]] = defaultdict(list) + graph: dict[str, set[str]] = defaultdict(set) + direction_map: dict[tuple[str, str], str] = {} + + for edge in world.edges: + if edge.from_node_id not in room_ids or edge.to_node_id not in room_ids: + raise DMCompileError(f"Edge '{edge.id}' must connect location or junction nodes only.") + if edge.from_node_id == edge.to_node_id: + raise DMCompileError(f"Edge '{edge.id}' cannot be self-referential.") + if edge.required_item_id and edge.required_item_id not in item_ids: + raise DMCompileError(f"Edge '{edge.id}' references unknown item '{edge.required_item_id}'.") + if edge.required_item_id and edge.required_item_id not in { + item.id for item in world.items if item.subtype == "key" + }: + raise DMCompileError(f"Edge '{edge.id}' must use a key item, not '{edge.required_item_id}'.") + if edge.type == "locked_passage": + if not edge.door_node_id: + raise DMCompileError(f"Locked edge '{edge.id}' requires door_node_id.") + if not edge.required_item_id: + raise DMCompileError(f"Locked edge '{edge.id}' requires required_item_id.") + elif edge.required_item_id is not None: + raise DMCompileError(f"Only locked_passage edges can reference required_item_id (edge '{edge.id}').") + if edge.door_node_id: + door = node_by_id.get(edge.door_node_id) + if not isinstance(door, DoorNode): + raise DMCompileError(f"Edge '{edge.id}' references unknown door '{edge.door_node_id}'.") + if edge.required_item_id and door.lock_key_id != edge.required_item_id: + raise DMCompileError(f"Edge '{edge.id}' and door '{door.id}' disagree on the key.") + key = (edge.from_node_id, edge.direction) + if key in direction_map: + raise DMCompileError( + f"Edges '{direction_map[key]}' and '{edge.id}' both leave '{edge.from_node_id}' via '{edge.direction}'." + ) + direction_map[key] = edge.id + graph[edge.from_node_id].add(edge.to_node_id) + pair_groups[frozenset({edge.from_node_id, edge.to_node_id})].append(edge) + + for pair, edges in pair_groups.items(): + if len(edges) != 2: + raise DMCompileError(f"Edges between {', '.join(sorted(pair))} must be explicitly bidirectional.") + a, b = edges + if OPPOSITE_DIRECTION[a.direction] != b.direction: + raise DMCompileError(f"Edges '{a.id}' and '{b.id}' must use opposite directions.") + if a.type != b.type or a.required_item_id != b.required_item_id or a.door_node_id != b.door_node_id: + raise DMCompileError(f"Edge pair '{a.id}'/'{b.id}' must agree on type, key, and door.") + + reachable = _reachable_rooms(graph, world.meta.start_node_id) + if reachable != room_ids: + raise DMCompileError(f"Some rooms are unreachable from the start node: {sorted(room_ids - reachable)}") + + +def _validate_items(world: WorldDefinition) -> None: + node_by_id = {node.id: node for node in world.nodes} + produced = produced_item_ids(world) + recipe_outputs: set[str] = set() + recipe_inputs: set[frozenset[str]] = set() + for recipe in world.recipes: + inputs = frozenset(recipe.input_item_ids) + if len(inputs) != 2: + raise DMCompileError(f"Recipe '{recipe.id}' must have exactly two distinct input items.") + if inputs in recipe_inputs: + raise DMCompileError(f"Duplicate recipe inputs in '{recipe.id}'.") + recipe_inputs.add(inputs) + if recipe.output_item_id in recipe_outputs: + raise DMCompileError(f"Item '{recipe.output_item_id}' is produced by multiple recipes.") + recipe_outputs.add(recipe.output_item_id) + + for item in world.items: + if item.id in produced and item.start_node_id is not None: + raise DMCompileError(f"Produced item '{item.id}' must not be initially placed.") + if item.id not in produced and item.start_node_id is None: + raise DMCompileError(f"Placed item '{item.id}' requires start_node_id.") + if item.start_node_id is None: + continue + holder = node_by_id.get(item.start_node_id) + if holder is None: + raise DMCompileError(f"Item '{item.id}' starts in unknown node '{item.start_node_id}'.") + if holder.type not in {"location", "junction", "container"}: + raise DMCompileError(f"Item '{item.id}' must start in a room or container.") + if item.subtype not in {"key", "puzzle"}: + raise DMCompileError(f"Item '{item.id}' uses unsupported subtype '{item.subtype}'.") + + +def _validate_clues(world: WorldDefinition) -> None: + clue_sources: dict[str, list[str]] = defaultdict(list) + for node in world.nodes: + if isinstance(node, ReadableNode): + clue_sources[node.clue_id].append(node.id) + elif isinstance(node, NpcNode) and node.gives_clue_id: + clue_sources[node.gives_clue_id].append(node.id) + + clue_ids = {clue.id for clue in world.clues} + if set(clue_sources) != clue_ids: + missing = sorted(clue_ids - set(clue_sources)) + raise DMCompileError(f"Every clue needs exactly one source. Missing: {missing}") + for clue_id, source_ids in sorted(clue_sources.items()): + if len(source_ids) > 1: + raise DMCompileError( + f"Clue '{clue_id}' has multiple sources: {', '.join(sorted(source_ids))}." + ) + + +def _validate_visibility(world: WorldDefinition) -> None: + names: dict[str, str] = {} + for label in [node.label for node in world.nodes] + [item.label for item in world.items]: + safe = parser_safe_text(label) + if safe in names: + raise DMCompileError( + f"Visible labels '{label}' and '{names[safe]}' collapse to the same parser name '{safe}'." + ) + names[safe] = label + + +def _validate_answer_leaks(world: WorldDefinition) -> None: + answer = normalize_answer_text(world.meta.win_condition.answer_string) + forbidden = {f"the answer is {answer}", f"answer is {answer}", f"submit {answer}"} + text_fragments = [world.meta.title] + text_fragments.extend(clue.text for clue in world.clues) + for node in world.nodes: + text_fragments.extend([node.label, node.description]) + if isinstance(node, ReadableNode): + text_fragments.append(node.text_content) + for text in text_fragments: + normalized = normalize_answer_text(text) + if any(phrase in normalized for phrase in forbidden): + raise DMCompileError("World leaks the final answer too directly. Clues must stay partial.") + + +def _validate_guardian_path(world: WorldDefinition) -> None: + node_by_id = {node.id: node for node in world.nodes} + guardian = node_by_id[world.meta.win_condition.target_npc_id] + graph: dict[str, set[str]] = defaultdict(set) + for edge in world.edges: + if edge.type == "passage": + graph[edge.from_node_id].add(edge.to_node_id) + reachable = _reachable_rooms(graph, world.meta.start_node_id) + if guardian.parent_id not in reachable: + raise DMCompileError("Guardian room must be reachable from the start without item gates.") + + +def _validate_clue_gates(world: WorldDefinition) -> None: + reachable = _reachable_zero_item_rooms(world) + hidden_readables = hidden_readable_ids(world) + for node in world.nodes: + if isinstance(node, ReadableNode): + if node.id in hidden_readables: + continue + if node.parent_id not in reachable: + continue + if node.requires_item_id: + continue + raise DMCompileError( + f"Readable '{node.id}' exposes clue '{node.clue_id}' without any item interaction." + ) + if isinstance(node, NpcNode) and node.gives_clue_id and not node.requires_item_id: + raise DMCompileError(f"NPC '{node.id}' gives clue '{node.gives_clue_id}' without an item gate.") + + +def _validate_item_usage(world: WorldDefinition) -> None: + quest_items: set[str] = set() + ordered = topological_linearize(world.quest_chain) + for action in (parse_quest_action(step.action) for step in ordered): + if isinstance(action, UnlockAction): + quest_items.add(action.key_id) + elif isinstance(action, (UseAction, GiveAction)): + quest_items.add(action.item_id) + elif isinstance(action, CombineAction): + quest_items.update({action.item_a_id, action.item_b_id}) + elif isinstance(action, TakeAction): + quest_items.add(action.item_id) + + mechanical_items = { + edge.required_item_id + for edge in world.edges + if edge.required_item_id + } + for node in world.nodes: + if node.type == "container" and node.lock_key_id: + mechanical_items.add(node.lock_key_id) + elif node.type == "door" and node.lock_key_id: + mechanical_items.add(node.lock_key_id) + elif node.type == "readable" and node.requires_item_id: + mechanical_items.add(node.requires_item_id) + elif node.type == "fixture": + mechanical_items.add(node.requires_item_id) + if node.reveals_item_id: + mechanical_items.add(node.reveals_item_id) + elif node.type == "npc": + if node.requires_item_id: + mechanical_items.add(node.requires_item_id) + if node.gives_item_id: + mechanical_items.add(node.gives_item_id) + for recipe in world.recipes: + mechanical_items.update(recipe.input_item_ids) + mechanical_items.add(recipe.output_item_id) + + for item in world.items: + if item.id not in quest_items and item.id not in mechanical_items: + raise DMCompileError(f"Unused decorative items are not supported in v2: '{item.id}'.") + + +def _validate_quest_shape(world: WorldDefinition) -> None: + ordered = topological_linearize(world.quest_chain) + parsed = [parse_quest_action(step.action) for step in ordered] + if not isinstance(parsed[-1], SubmitAction): + raise DMCompileError('The final quest step must be submit("answer").') + if len(parsed) < 2 or not isinstance(parsed[-2], TalkAction): + raise DMCompileError("The penultimate quest step must be talk(guardian).") + if parsed[-2].target_node_id != world.meta.win_condition.target_npc_id: + raise DMCompileError("The final talk step must target the guardian NPC.") + if normalize_answer_text(parsed[-1].answer_text) != normalize_answer_text(world.meta.win_condition.answer_string): + raise DMCompileError("The final submit step must match win_condition.answer_string.") + entity_names = {node.id: parser_safe_text(node.label) for node in world.nodes} + entity_names.update({item.id: parser_safe_text(item.label) for item in world.items}) + simulate_walkthrough(world, parsed, entity_names) + + +def _reachable_rooms(graph: dict[str, set[str]], start: str) -> set[str]: + seen = {start} + queue = deque([start]) + while queue: + current = queue.popleft() + for nxt in graph.get(current, set()): + if nxt not in seen: + seen.add(nxt) + queue.append(nxt) + return seen + + +def _reachable_zero_item_rooms(world: WorldDefinition) -> set[str]: + graph: dict[str, set[str]] = defaultdict(set) + for edge in world.edges: + if edge.type == "passage": + graph[edge.from_node_id].add(edge.to_node_id) + return _reachable_rooms(graph, world.meta.start_node_id) diff --git a/agents/master/env.py b/agents/master/env.py new file mode 100644 index 0000000000000000000000000000000000000000..240b6010c86bacaff06c2dfdd7fcce63af1e221f --- /dev/null +++ b/agents/master/env.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import math +import uuid +from pathlib import Path +from typing import Any + +from .base import DMCompileError, DMInterfaceError, MAX_STEP_MULTIPLIER, TARGET_RATIO, TARGET_RATIO_SIGMA +from .build import WorldCompiler +from .interface import InterfaceAdapter, SimpleInterfaceAdapter +from .play import EpisodeRunner, WalkthroughRunner +from .schema import ( + CompiledWorld, + DMAction, + DMFeedback, + DMObservation, + DMRewardBreakdown, + DMState, + Turn, + WorldDefinition, +) +from .session import EpisodeSession +from .snapshots import LiveObserver +from agents.shared.openenv_compat import Environment, StepResult, build_step_result + + +class DMEnvironment(Environment[DMAction, DMObservation, DMState]): + def __init__( + self, + artifacts_root: Path | None = None, + target_ratio: float = TARGET_RATIO, + reward_sigma: float = TARGET_RATIO_SIGMA, + max_step_multiplier: int = MAX_STEP_MULTIPLIER, + interface_adapter: InterfaceAdapter = SimpleInterfaceAdapter(), + default_runner: EpisodeRunner | None = None, + ) -> None: + super().__init__() + if interface_adapter is None: + raise ValueError("interface_adapter must not be None.") + self.compiler = WorldCompiler(artifacts_root=artifacts_root) + self.target_ratio = target_ratio + self.reward_sigma = reward_sigma + self.max_step_multiplier = max_step_multiplier + self.interface_adapter = interface_adapter + self.default_runner = default_runner or WalkthroughRunner() + self.episode_count = 0 + self.success_count = 0 + self._state = DMState( + episode_id=uuid.uuid4().hex[:12], + target_ratio=target_ratio, + ) + self.last_compiled_world: CompiledWorld | None = None + + def reset(self, difficulty_hint: float | None = None, seed: int | None = None) -> DMObservation: + del seed + episode_target_ratio = self.target_ratio if difficulty_hint is None else difficulty_hint + self._state = DMState( + episode_id=uuid.uuid4().hex[:12], + compile_status="pending", + episode_status="running", + cumulative_success_rate=self._running_success_rate(), + target_ratio=episode_target_ratio, + difficulty_hint=difficulty_hint, + ) + self.last_compiled_world = None + return self._apply_transform( + DMObservation( + done=False, + reward=None, + target_ratio_used=episode_target_ratio, + ) + ) + + def step( # type: ignore[override] + self, + action: DMAction | WorldDefinition | dict[str, Any], + runner: EpisodeRunner | None = None, + observer: LiveObserver | None = None, + timeout_s: float | None = None, + ) -> StepResult[DMObservation]: + del timeout_s + world_input = action.world_definition if isinstance(action, DMAction) else action + compiled: CompiledWorld | None = None + session: EpisodeSession | None = None + if observer is not None: + observer.on_run_start(self._state.episode_id, world_input) + self.last_compiled_world = None + self._state.current_world = None + try: + compiled = self.compiler.compile(world_input, episode_id=self._state.episode_id) + self.last_compiled_world = compiled + self._state.current_world = compiled.world + self._state.compile_status = "valid" + max_steps = max(1, len(compiled.solver_policy) * self.max_step_multiplier) + + def on_turn(current_session: EpisodeSession, turn: Turn) -> None: + self._state.step_count = current_session.steps_taken + if observer is not None: + observer.on_turn(current_session, turn) + + session = EpisodeSession( + compiled, + interface_adapter=self.interface_adapter, + turn_listener=on_turn, + ) + if observer is not None: + observer.on_compile_success(compiled, session) + (runner or self.default_runner).run(session, max_steps=max_steps) + player_won = bool(session.player_won) + min_steps = len(compiled.solver_policy) + reward_breakdown = self._reward_breakdown(player_won, session.steps_taken, min_steps) + reward = reward_breakdown.reward + self.episode_count += 1 + self.success_count += int(player_won) + self._state.step_count = session.steps_taken + self._state.episode_status = "complete" if player_won else "failed" + self._state.cumulative_success_rate = self._running_success_rate() + observation = self._apply_transform( + DMObservation( + episode_transcript=session.transcript, + player_won=player_won, + steps_taken=session.steps_taken, + min_steps=min_steps, + ratio=(session.steps_taken / min_steps) if min_steps else None, + reward=reward, + done=True, + feedback=self._build_feedback(compiled, session), + reward_breakdown=reward_breakdown, + target_ratio_used=self._state.target_ratio, + ) + ) + if observer is not None: + observer.on_complete(compiled, session, observation) + return build_step_result(observation) + except (DMCompileError, DMInterfaceError, ValueError) as exc: + self.last_compiled_world = None + self._state.current_world = None + self._state.compile_status = "invalid" + self._state.episode_status = "failed" + if observer is not None: + observer.on_error( + episode_id=self._state.episode_id, + error=str(exc), + world_input=world_input, + compiled=compiled, + session=session, + ) + observation = self._apply_transform( + DMObservation( + player_won=False, + compile_error=str(exc), + reward=0.0, + done=True, + reward_breakdown=DMRewardBreakdown( + reward_mode="compile_failure_penalty", + player_won=False, + target_ratio=self._state.target_ratio, + quality_score=0.0, + reward=0.0, + ), + target_ratio_used=self._state.target_ratio, + ) + ) + return build_step_result(observation) + finally: + if session is not None: + session.close() + + def compile_world( + self, + world_input: WorldDefinition | dict[str, Any], + *, + episode_id: str | None = None, + ) -> CompiledWorld: + return self.compiler.compile(world_input, episode_id=episode_id) + + def play( + self, + world_input: WorldDefinition | dict[str, Any], + runner: EpisodeRunner | None = None, + observer: LiveObserver | None = None, + ) -> StepResult[DMObservation]: + self.reset() + return self.step(world_input, runner=runner, observer=observer) + + @property + def state(self) -> DMState: + return self._state + + def _reward_breakdown( + self, + player_won: bool, + steps_taken: int | None, + min_steps: int | None, + ) -> DMRewardBreakdown: + raw_ratio: float | None = None + clamped_ratio: float | None = None + target_ratio_delta: float | None = None + efficiency_score: float | None = None + quality_score = 0.0 + if steps_taken is not None and min_steps is not None and min_steps > 0: + raw_ratio = steps_taken / min_steps + clamped_ratio = max(raw_ratio, 1.0) + target_ratio_delta = abs(clamped_ratio - self._state.target_ratio) + if player_won and steps_taken > 0: + efficiency_score = min(1.0, min_steps / steps_taken) + sigma_sq = max(self.reward_sigma, 1e-6) ** 2 + quality_score = math.exp(-((clamped_ratio - self._state.target_ratio) ** 2) / (2.0 * sigma_sq)) + reward = quality_score if player_won else 0.0 + return DMRewardBreakdown( + reward_mode="gaussian_target_ratio", + player_won=player_won, + raw_ratio=raw_ratio, + clamped_ratio=clamped_ratio, + target_ratio=self._state.target_ratio, + target_ratio_delta=target_ratio_delta, + efficiency_score=efficiency_score, + quality_score=quality_score, + reward=reward, + ) + + def _build_feedback(self, compiled: CompiledWorld, session: EpisodeSession) -> DMFeedback: + room_ids = [node.id for node in compiled.world.nodes if node.type in {"location", "junction"}] + clue_ids = [clue.id for clue in compiled.world.clues] + unique_rooms = [node_id for node_id in session.visited_nodes if node_id in room_ids] + return DMFeedback( + unreachable_nodes=sorted(set(room_ids) - set(unique_rooms)), + unused_items=sorted({item.id for item in compiled.world.items} - session.used_items), + clues_missed=sorted(set(clue_ids) - session.discovered_clues), + mean_steps_per_room=session.steps_taken / max(1, len(set(unique_rooms))), + invalid_command_count=session.invalid_command_count, + wrong_submit_count=session.wrong_submit_count, + ) + + def _running_success_rate(self) -> float: + return 0.0 if self.episode_count == 0 else self.success_count / self.episode_count diff --git a/agents/master/graph.py b/agents/master/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..1609cdf2a4a91e37e326a01c0459b495a8123838 --- /dev/null +++ b/agents/master/graph.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from collections import defaultdict + +from .schema import DoorNode, Edge, NpcTrade, ReadableNode, UseEffect, WorldDefinition + + +def readable_clue_mapping(world: WorldDefinition) -> dict[str, str]: + return {node.id: node.clue_id for node in world.nodes if isinstance(node, ReadableNode)} + + +def clue_source_mapping(world: WorldDefinition) -> dict[str, str]: + mapping = {node.clue_id: node.id for node in world.nodes if isinstance(node, ReadableNode)} + for node in world.nodes: + if node.type == "npc" and node.gives_clue_id: + mapping[node.gives_clue_id] = node.id + return mapping + + +def npc_trade_mapping(world: WorldDefinition) -> dict[str, NpcTrade]: + trades: dict[str, NpcTrade] = {} + for node in world.nodes: + if node.type != "npc" or node.id == world.meta.win_condition.target_npc_id: + continue + trades[node.id] = NpcTrade( + required_item_id=node.requires_item_id or "", + gives_item_id=node.gives_item_id, + gives_clue_id=node.gives_clue_id, + ) + return trades + + +def use_effect_mapping(world: WorldDefinition) -> dict[str, UseEffect]: + effects: dict[str, UseEffect] = {} + for node in world.nodes: + if node.type == "readable" and node.requires_item_id: + effects[node.id] = UseEffect( + required_item_id=node.requires_item_id, + clue_id=node.clue_id, + consumes_item=node.consumes_item, + ) + elif node.type == "fixture": + effects[node.id] = UseEffect( + required_item_id=node.requires_item_id, + reveals_item_id=node.reveals_item_id, + reveals_readable_id=node.reveals_readable_id, + consumes_item=node.consumes_item, + ) + return effects + + +def recipe_mapping(world: WorldDefinition) -> dict[frozenset[str], str]: + return {frozenset(recipe.input_item_ids): recipe.output_item_id for recipe in world.recipes} + + +def produced_item_ids(world: WorldDefinition) -> set[str]: + produced = {recipe.output_item_id for recipe in world.recipes} + for node in world.nodes: + if node.type == "npc" and node.gives_item_id: + produced.add(node.gives_item_id) + if node.type == "fixture" and node.reveals_item_id: + produced.add(node.reveals_item_id) + return produced + + +def hidden_readable_ids(world: WorldDefinition) -> set[str]: + return {node.reveals_readable_id for node in world.nodes if node.type == "fixture" and node.reveals_readable_id} + + +def door_room_mapping(world: WorldDefinition) -> dict[str, frozenset[str]]: + mapping: dict[str, set[str]] = defaultdict(set) + for edge in world.edges: + if edge.door_node_id: + mapping[edge.door_node_id].add(edge.from_node_id) + mapping[edge.door_node_id].add(edge.to_node_id) + return {door_id: frozenset(rooms) for door_id, rooms in mapping.items()} + + +def edge_for_door(world: WorldDefinition, door_id: str) -> Edge | None: + for edge in world.edges: + if edge.door_node_id == door_id: + return edge + return None + + +def door_nodes(world: WorldDefinition) -> dict[str, DoorNode]: + return {node.id: node for node in world.nodes if isinstance(node, DoorNode)} diff --git a/agents/master/interface.py b/agents/master/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..f309ddf437b461a9887a064ef84a26f5e05ffb30 --- /dev/null +++ b/agents/master/interface.py @@ -0,0 +1,831 @@ +from __future__ import annotations + +import json +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Literal, Protocol + +from dotenv import load_dotenv +from google import genai +from google.genai import types +from textworld.core import GameState + +from agents.hero.cli import parse_cli_command + +from .base import DMInterfaceError, SUPPORTED_DIRECTIONS + +if TYPE_CHECKING: + from .session import EpisodeSession + + +DEFAULT_GEMINI_MODEL = "gemini-2.5-flash-lite" +_TEXTWORLD_PROMPT_LINE_RE = re.compile(r"^\s*>\s.*-\=\s.*=\-(?:\d+/\d+)?\s*$") +_TEXTWORLD_BANNER_CHAR_RE = re.compile(r"[\\|$_/]") +_TEXTWORLD_ROOM_HEADER_RE = re.compile(r"^\s*-\=\s*(?P