Spaces:
Runtime error
Runtime error
Deploy FATHOM-DM Space bundle
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +19 -0
- README.md +15 -6
- agents/__init__.py +2 -0
- agents/hero/__init__.py +58 -0
- agents/hero/__main__.py +78 -0
- agents/hero/cli.py +115 -0
- agents/hero/env.py +450 -0
- agents/hero/policy.py +157 -0
- agents/hero/prompt.py +134 -0
- agents/hero/runner.py +92 -0
- agents/hero/schema.py +103 -0
- agents/loop/__init__.py +11 -0
- agents/loop/__main__.py +92 -0
- agents/loop/runner.py +253 -0
- agents/loop/schema.py +64 -0
- agents/master/__init__.py +15 -0
- agents/master/__main__.py +5 -0
- agents/master/base.py +84 -0
- agents/master/build.py +287 -0
- agents/master/check.py +435 -0
- agents/master/env.py +236 -0
- agents/master/graph.py +87 -0
- agents/master/interface.py +831 -0
- agents/master/logic.py +92 -0
- agents/master/main.py +72 -0
- agents/master/play.py +70 -0
- agents/master/policy.py +147 -0
- agents/master/prompt.py +371 -0
- agents/master/quest.py +418 -0
- agents/master/sample.py +499 -0
- agents/master/schema.py +316 -0
- agents/master/server.py +370 -0
- agents/master/session.py +484 -0
- agents/master/snapshots.py +308 -0
- agents/master/templates.py +44 -0
- agents/openenv_server/__init__.py +2 -0
- agents/openenv_server/__main__.py +72 -0
- agents/shared/__init__.py +43 -0
- agents/shared/llm_client.py +415 -0
- agents/shared/model_schema.py +14 -0
- agents/shared/openenv_compat.py +125 -0
- agents/shared/runtime.py +165 -0
- agents/spaces/__init__.py +13 -0
- agents/spaces/dm_space.py +194 -0
- agents/spaces/hero_space.py +271 -0
- agents/train/__init__.py +2 -0
- agents/train/__main__.py +361 -0
- agents/train/grpo.py +0 -0
- agents/train/joint.py +278 -0
- pyproject.toml +63 -0
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 4 |
+
ENV PYTHONUNBUFFERED=1
|
| 5 |
+
|
| 6 |
+
WORKDIR /app
|
| 7 |
+
|
| 8 |
+
RUN apt-get update \
|
| 9 |
+
&& apt-get install -y --no-install-recommends build-essential git curl \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
COPY . /app
|
| 13 |
+
|
| 14 |
+
RUN pip install --no-cache-dir --upgrade pip \
|
| 15 |
+
&& pip install --no-cache-dir .
|
| 16 |
+
|
| 17 |
+
EXPOSE 8000
|
| 18 |
+
|
| 19 |
+
CMD ["uvicorn", "agents.spaces.dm_space:create_app", "--factory", "--host", "0.0.0.0", "--port", "8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,19 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji: 🏃
|
| 4 |
-
colorFrom: yellow
|
| 5 |
-
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: DND-DM
|
|
|
|
|
|
|
|
|
|
| 3 |
sdk: docker
|
| 4 |
+
app_port: 8000
|
| 5 |
+
tags:
|
| 6 |
+
- openenv
|
| 7 |
+
- dnd
|
| 8 |
+
- textworld
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# DND-DM
|
| 12 |
+
|
| 13 |
+
This Space hosts the CPU-only `DND-DM` environment.
|
| 14 |
+
|
| 15 |
+
- OpenEnv API: `/env`
|
| 16 |
+
- Health check: `/healthz`
|
| 17 |
+
- Latest normalized world output: `/world-output/latest`
|
| 18 |
+
|
| 19 |
+
`DND-DM` evaluates submitted world definitions. It does not generate worlds by itself.
|
agents/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Agent environments for the dungeon project."""
|
| 2 |
+
|
agents/hero/__init__.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hero agent environment and runner primitives."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import TYPE_CHECKING, Any
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"HeroEnvironment",
|
| 9 |
+
"HeroLLMPolicy",
|
| 10 |
+
"HeroObservation",
|
| 11 |
+
"HeroPolicy",
|
| 12 |
+
"HeroPolicyError",
|
| 13 |
+
"HeroRunner",
|
| 14 |
+
"HeroServerAction",
|
| 15 |
+
"HeroState",
|
| 16 |
+
"HeroTraceEvent",
|
| 17 |
+
"ScriptedToolCallingPolicy",
|
| 18 |
+
"ToolCallingPolicy",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
from .env import HeroEnvironment
|
| 23 |
+
from .policy import HeroLLMPolicy, HeroPolicy, HeroPolicyError, HeroTraceEvent
|
| 24 |
+
from .runner import HeroRunner, ScriptedToolCallingPolicy, ToolCallingPolicy
|
| 25 |
+
from .schema import HeroObservation, HeroServerAction, HeroState
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def __getattr__(name: str) -> Any:
|
| 29 |
+
if name == "HeroEnvironment":
|
| 30 |
+
from .env import HeroEnvironment
|
| 31 |
+
|
| 32 |
+
return HeroEnvironment
|
| 33 |
+
if name in {"HeroLLMPolicy", "HeroPolicy", "HeroPolicyError", "HeroTraceEvent"}:
|
| 34 |
+
from .policy import HeroLLMPolicy, HeroPolicy, HeroPolicyError, HeroTraceEvent
|
| 35 |
+
|
| 36 |
+
return {
|
| 37 |
+
"HeroLLMPolicy": HeroLLMPolicy,
|
| 38 |
+
"HeroPolicy": HeroPolicy,
|
| 39 |
+
"HeroPolicyError": HeroPolicyError,
|
| 40 |
+
"HeroTraceEvent": HeroTraceEvent,
|
| 41 |
+
}[name]
|
| 42 |
+
if name in {"HeroRunner", "ScriptedToolCallingPolicy", "ToolCallingPolicy"}:
|
| 43 |
+
from .runner import HeroRunner, ScriptedToolCallingPolicy, ToolCallingPolicy
|
| 44 |
+
|
| 45 |
+
return {
|
| 46 |
+
"HeroRunner": HeroRunner,
|
| 47 |
+
"ScriptedToolCallingPolicy": ScriptedToolCallingPolicy,
|
| 48 |
+
"ToolCallingPolicy": ToolCallingPolicy,
|
| 49 |
+
}[name]
|
| 50 |
+
if name in {"HeroObservation", "HeroServerAction", "HeroState"}:
|
| 51 |
+
from .schema import HeroObservation, HeroServerAction, HeroState
|
| 52 |
+
|
| 53 |
+
return {
|
| 54 |
+
"HeroObservation": HeroObservation,
|
| 55 |
+
"HeroServerAction": HeroServerAction,
|
| 56 |
+
"HeroState": HeroState,
|
| 57 |
+
}[name]
|
| 58 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
agents/hero/__main__.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from agents.master.sample import load_world
|
| 8 |
+
from agents.shared.runtime import build_interface_adapter, resolve_interface_config
|
| 9 |
+
|
| 10 |
+
from .env import HeroEnvironment
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _manual_action(raw: str) -> dict[str, object]:
|
| 14 |
+
if raw == "/read":
|
| 15 |
+
return {"tool": "scratchpad_read"}
|
| 16 |
+
if raw.startswith("/write append "):
|
| 17 |
+
return {"tool": "scratchpad_write", "mode": "append", "content": raw[len("/write append ") :]}
|
| 18 |
+
if raw.startswith("/write replace "):
|
| 19 |
+
return {"tool": "scratchpad_write", "mode": "replace", "content": raw[len("/write replace ") :]}
|
| 20 |
+
return {"tool": "act", "command": raw}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main(argv: list[str] | None = None) -> int:
|
| 24 |
+
parser = argparse.ArgumentParser(description="Local hero environment smoke runner")
|
| 25 |
+
parser.add_argument("mode", choices=["manual", "scripted"])
|
| 26 |
+
parser.add_argument("world", help="Path to a world-definition JSON file.")
|
| 27 |
+
parser.add_argument("--actions", help="JSON file containing a list of hero action objects.")
|
| 28 |
+
parser.add_argument("--debug", action="store_true")
|
| 29 |
+
parser.add_argument("--interface-model")
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--translate-corporate-env",
|
| 32 |
+
action="store_true",
|
| 33 |
+
help="Rewrite observations into a corporate app metaphor and translate parser-safe corporate commands back through Gemini.",
|
| 34 |
+
)
|
| 35 |
+
args = parser.parse_args(argv)
|
| 36 |
+
|
| 37 |
+
world = load_world(args.world)
|
| 38 |
+
interface_adapter = build_interface_adapter(
|
| 39 |
+
resolve_interface_config(
|
| 40 |
+
model_name=args.interface_model,
|
| 41 |
+
translation_mode="corporate_app" if args.translate_corporate_env else None,
|
| 42 |
+
)
|
| 43 |
+
)
|
| 44 |
+
env = HeroEnvironment(debug=args.debug, interface_adapter=interface_adapter)
|
| 45 |
+
observation = env.reset(world)
|
| 46 |
+
print(observation.message)
|
| 47 |
+
|
| 48 |
+
if args.mode == "scripted":
|
| 49 |
+
if not args.actions:
|
| 50 |
+
parser.error("--actions is required for scripted mode.")
|
| 51 |
+
actions = json.loads(Path(args.actions).read_text(encoding="utf-8"))
|
| 52 |
+
for action in actions:
|
| 53 |
+
result = env.step(action)
|
| 54 |
+
print(result.observation.message)
|
| 55 |
+
if result.done:
|
| 56 |
+
print(json.dumps(result.observation.model_dump(), indent=2))
|
| 57 |
+
return 0
|
| 58 |
+
print(json.dumps(env.state.model_dump(), indent=2))
|
| 59 |
+
return 0
|
| 60 |
+
|
| 61 |
+
while not observation.done:
|
| 62 |
+
try:
|
| 63 |
+
raw = input("hero> ").strip()
|
| 64 |
+
except EOFError:
|
| 65 |
+
print()
|
| 66 |
+
return 0
|
| 67 |
+
if raw in {"quit", "exit"}:
|
| 68 |
+
return 0
|
| 69 |
+
result = env.step(_manual_action(raw))
|
| 70 |
+
observation = result.observation
|
| 71 |
+
print(observation.message)
|
| 72 |
+
if result.done:
|
| 73 |
+
print(json.dumps(observation.model_dump(), indent=2))
|
| 74 |
+
return 0
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
raise SystemExit(main())
|
agents/hero/cli.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
from agents.master.base import SUPPORTED_DIRECTIONS
|
| 7 |
+
|
| 8 |
+
_TOKEN_RE = re.compile(r"^[a-z0-9]+(?: [a-z0-9]+)*$")
|
| 9 |
+
_BANNED_OBJECT_TOKENS = {"a", "an", "the"}
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(frozen=True)
|
| 13 |
+
class CliCommandAst:
|
| 14 |
+
kind: str
|
| 15 |
+
normalized_command: str
|
| 16 |
+
arguments: tuple[str, ...] = ()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass(frozen=True)
|
| 20 |
+
class CliCommandParseResult:
|
| 21 |
+
valid: bool
|
| 22 |
+
normalized_command: str | None = None
|
| 23 |
+
ast: CliCommandAst | None = None
|
| 24 |
+
error: str | None = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def parse_cli_command(raw_command: str) -> CliCommandParseResult:
|
| 28 |
+
normalized = normalize_cli_command(raw_command)
|
| 29 |
+
if not normalized:
|
| 30 |
+
return CliCommandParseResult(valid=False, error="Command must not be empty.")
|
| 31 |
+
|
| 32 |
+
if normalized in {"look", "inventory", "wait"}:
|
| 33 |
+
return _ok(normalized, normalized)
|
| 34 |
+
|
| 35 |
+
if normalized in SUPPORTED_DIRECTIONS:
|
| 36 |
+
return _ok("move", f"go {normalized}", normalized)
|
| 37 |
+
if normalized.startswith("go "):
|
| 38 |
+
direction = normalized[3:].strip()
|
| 39 |
+
if direction in SUPPORTED_DIRECTIONS:
|
| 40 |
+
return _ok("move", f"go {direction}", direction)
|
| 41 |
+
return CliCommandParseResult(valid=False, error="Unknown direction.")
|
| 42 |
+
|
| 43 |
+
if match := re.fullmatch(r"look in (?P<object>.+)", normalized):
|
| 44 |
+
object_text = match.group("object").strip()
|
| 45 |
+
return _object_result("look_in", normalized, object_text)
|
| 46 |
+
if match := re.fullmatch(r"take (?P<object>.+) from (?P<source>.+)", normalized):
|
| 47 |
+
return _two_object_result("take_from", normalized, match.group("object"), match.group("source"))
|
| 48 |
+
|
| 49 |
+
one_target_patterns = {
|
| 50 |
+
"open": r"open (?P<object>.+)",
|
| 51 |
+
"read": r"read (?P<object>.+)",
|
| 52 |
+
"talk": r"talk (?P<object>.+)",
|
| 53 |
+
"examine": r"examine (?P<object>.+)",
|
| 54 |
+
}
|
| 55 |
+
for kind, pattern in one_target_patterns.items():
|
| 56 |
+
if match := re.fullmatch(pattern, normalized):
|
| 57 |
+
object_text = match.group("object").strip()
|
| 58 |
+
return _object_result(kind, normalized, object_text)
|
| 59 |
+
if match := re.fullmatch(r"take (?P<object>.+)", normalized):
|
| 60 |
+
object_text = match.group("object").strip()
|
| 61 |
+
return _object_result("take", normalized, object_text)
|
| 62 |
+
if match := re.fullmatch(r"unlock (?P<object>.+) with (?P<tool>.+)", normalized):
|
| 63 |
+
return _two_object_result("unlock", normalized, match.group("object"), match.group("tool"))
|
| 64 |
+
if match := re.fullmatch(r"use (?P<object>.+) on (?P<target>.+)", normalized):
|
| 65 |
+
return _two_object_result("use", normalized, match.group("object"), match.group("target"))
|
| 66 |
+
if match := re.fullmatch(r"combine (?P<object>.+) with (?P<target>.+)", normalized):
|
| 67 |
+
return _two_object_result("combine", normalized, match.group("object"), match.group("target"))
|
| 68 |
+
if match := re.fullmatch(r"give (?P<object>.+) to (?P<target>.+)", normalized):
|
| 69 |
+
return _two_object_result("give", normalized, match.group("object"), match.group("target"))
|
| 70 |
+
|
| 71 |
+
if match := re.fullmatch(r"submit (?P<answer>[a-z0-9]+(?: [a-z0-9]+)*)", normalized):
|
| 72 |
+
answer = match.group("answer").strip()
|
| 73 |
+
return _ok("submit", normalized, answer)
|
| 74 |
+
|
| 75 |
+
return CliCommandParseResult(valid=False, error="Command does not match the strict CLI grammar.")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def normalize_cli_command(raw_command: str) -> str:
|
| 79 |
+
return re.sub(r"\s+", " ", raw_command.strip().lower())
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _object_result(kind: str, normalized_command: str, object_text: str) -> CliCommandParseResult:
|
| 83 |
+
object_error = _validate_object_text(object_text)
|
| 84 |
+
if object_error is not None:
|
| 85 |
+
return CliCommandParseResult(valid=False, error=object_error)
|
| 86 |
+
return _ok(kind, normalized_command, object_text)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _two_object_result(kind: str, normalized_command: str, first: str, second: str) -> CliCommandParseResult:
|
| 90 |
+
first_error = _validate_object_text(first)
|
| 91 |
+
if first_error is not None:
|
| 92 |
+
return CliCommandParseResult(valid=False, error=first_error)
|
| 93 |
+
second_error = _validate_object_text(second)
|
| 94 |
+
if second_error is not None:
|
| 95 |
+
return CliCommandParseResult(valid=False, error=second_error)
|
| 96 |
+
return _ok(kind, normalized_command, first.strip(), second.strip())
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _validate_object_text(value: str) -> str | None:
|
| 100 |
+
candidate = value.strip()
|
| 101 |
+
if not candidate:
|
| 102 |
+
return "Command target must not be empty."
|
| 103 |
+
if not _TOKEN_RE.fullmatch(candidate):
|
| 104 |
+
return "Command targets must use lowercase letters, numbers, and spaces only."
|
| 105 |
+
if any(token in _BANNED_OBJECT_TOKENS for token in candidate.split()):
|
| 106 |
+
return "Strict CLI commands must use exact parser-safe object names without articles."
|
| 107 |
+
return None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _ok(kind: str, normalized_command: str, *arguments: str) -> CliCommandParseResult:
|
| 111 |
+
return CliCommandParseResult(
|
| 112 |
+
valid=True,
|
| 113 |
+
normalized_command=normalized_command,
|
| 114 |
+
ast=CliCommandAst(kind=kind, normalized_command=normalized_command, arguments=arguments),
|
| 115 |
+
)
|
agents/hero/env.py
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import deque
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from agents.master.base import DMInterfaceError, MAX_STEP_MULTIPLIER
|
| 8 |
+
from agents.master.build import WorldCompiler
|
| 9 |
+
from agents.master.interface import InterfaceAdapter, StrictCliInterfaceAdapter
|
| 10 |
+
from agents.master.schema import CompiledWorld, WorldDefinition
|
| 11 |
+
from agents.master.session import EpisodeSession
|
| 12 |
+
from agents.shared.openenv_compat import Environment, StepResult, build_step_result
|
| 13 |
+
|
| 14 |
+
from .cli import parse_cli_command
|
| 15 |
+
from .schema import (
|
| 16 |
+
ActAction,
|
| 17 |
+
HeroAction,
|
| 18 |
+
HeroAuxSignals,
|
| 19 |
+
HeroEpisodeStats,
|
| 20 |
+
HeroObservation,
|
| 21 |
+
HeroRewardBreakdown,
|
| 22 |
+
HeroState,
|
| 23 |
+
ScratchpadReadAction,
|
| 24 |
+
ScratchpadWriteAction,
|
| 25 |
+
validate_hero_action,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
_DENSE_PROGRESS_SCALE = 0.30
|
| 29 |
+
_SYNTAX_PENALTY = -0.02
|
| 30 |
+
_INVALID_ACTION_PENALTY = -0.02
|
| 31 |
+
_REPEAT_NOOP_PENALTY = -0.01
|
| 32 |
+
_WRONG_SUBMIT_PENALTY = -0.10
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class HeroEnvironment(Environment[HeroAction, HeroObservation, HeroState]):
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
*,
|
| 39 |
+
artifacts_root: Path | None = None,
|
| 40 |
+
world_input: CompiledWorld | WorldDefinition | dict[str, Any] | None = None,
|
| 41 |
+
session: EpisodeSession | None = None,
|
| 42 |
+
interface_adapter: InterfaceAdapter | None = None,
|
| 43 |
+
model: str = "",
|
| 44 |
+
max_game_steps: int | None = None,
|
| 45 |
+
max_tool_calls: int | None = None,
|
| 46 |
+
scratchpad_max_chars: int = 8000,
|
| 47 |
+
debug: bool = False,
|
| 48 |
+
) -> None:
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.compiler = WorldCompiler(artifacts_root=artifacts_root)
|
| 51 |
+
self._initial_world_input = world_input
|
| 52 |
+
self._provided_session = session
|
| 53 |
+
self._provided_interface_adapter = interface_adapter
|
| 54 |
+
self.model = model
|
| 55 |
+
self._default_max_game_steps = max_game_steps
|
| 56 |
+
self._default_max_tool_calls = max_tool_calls
|
| 57 |
+
self.scratchpad_max_chars = scratchpad_max_chars
|
| 58 |
+
self.debug = debug
|
| 59 |
+
self._state = HeroState()
|
| 60 |
+
self._compiled: CompiledWorld | None = None
|
| 61 |
+
self._session: EpisodeSession | None = None
|
| 62 |
+
self._scratchpad = ""
|
| 63 |
+
self._max_game_steps = 0
|
| 64 |
+
self._max_tool_calls = 0
|
| 65 |
+
self._debug_dir: Path | None = None
|
| 66 |
+
self._episode_stats = HeroEpisodeStats()
|
| 67 |
+
self._recent_noop_signatures: deque[tuple[str, str, str]] = deque(maxlen=3)
|
| 68 |
+
|
| 69 |
+
@classmethod
|
| 70 |
+
def from_session(
|
| 71 |
+
cls,
|
| 72 |
+
session: EpisodeSession,
|
| 73 |
+
*,
|
| 74 |
+
max_game_steps: int | None = None,
|
| 75 |
+
max_tool_calls: int | None = None,
|
| 76 |
+
scratchpad_max_chars: int = 8000,
|
| 77 |
+
debug: bool = False,
|
| 78 |
+
) -> "HeroEnvironment":
|
| 79 |
+
return cls(
|
| 80 |
+
session=session,
|
| 81 |
+
max_game_steps=max_game_steps,
|
| 82 |
+
max_tool_calls=max_tool_calls,
|
| 83 |
+
scratchpad_max_chars=scratchpad_max_chars,
|
| 84 |
+
debug=debug,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
def reset(
|
| 88 |
+
self,
|
| 89 |
+
world_input: CompiledWorld | WorldDefinition | dict[str, Any] | None = None,
|
| 90 |
+
*,
|
| 91 |
+
seed: int | None = None,
|
| 92 |
+
episode_id: str | None = None,
|
| 93 |
+
max_game_steps: int | None = None,
|
| 94 |
+
max_tool_calls: int | None = None,
|
| 95 |
+
scratchpad_max_chars: int | None = None,
|
| 96 |
+
debug: bool | None = None,
|
| 97 |
+
) -> HeroObservation:
|
| 98 |
+
del seed, episode_id
|
| 99 |
+
if debug is not None:
|
| 100 |
+
self.debug = debug
|
| 101 |
+
if scratchpad_max_chars is not None:
|
| 102 |
+
self.scratchpad_max_chars = scratchpad_max_chars
|
| 103 |
+
self._scratchpad = ""
|
| 104 |
+
self._episode_stats = HeroEpisodeStats()
|
| 105 |
+
self._recent_noop_signatures.clear()
|
| 106 |
+
|
| 107 |
+
if self._provided_session is not None:
|
| 108 |
+
self._session = self._provided_session
|
| 109 |
+
self._compiled = self._session.compiled
|
| 110 |
+
else:
|
| 111 |
+
selected_world = world_input if world_input is not None else self._initial_world_input
|
| 112 |
+
if selected_world is None:
|
| 113 |
+
raise ValueError("HeroEnvironment.reset requires a compiled world, world definition, or live session.")
|
| 114 |
+
self._compiled = (
|
| 115 |
+
selected_world
|
| 116 |
+
if isinstance(selected_world, CompiledWorld)
|
| 117 |
+
else self.compiler.compile(selected_world)
|
| 118 |
+
)
|
| 119 |
+
adapter = self._provided_interface_adapter or StrictCliInterfaceAdapter()
|
| 120 |
+
self._session = EpisodeSession(self._compiled, interface_adapter=adapter)
|
| 121 |
+
|
| 122 |
+
self._max_game_steps = max_game_steps or self._default_max_game_steps or max(
|
| 123 |
+
1, len(self._compiled.solver_policy) * MAX_STEP_MULTIPLIER
|
| 124 |
+
)
|
| 125 |
+
self._max_tool_calls = max_tool_calls or self._default_max_tool_calls or (self._max_game_steps * 4)
|
| 126 |
+
self._state = HeroState(
|
| 127 |
+
episode_id=self._compiled.episode_id,
|
| 128 |
+
step_count=0,
|
| 129 |
+
game_steps_taken=self._session.steps_taken,
|
| 130 |
+
tool_calls_total=0,
|
| 131 |
+
max_game_steps=self._max_game_steps,
|
| 132 |
+
max_tool_calls=self._max_tool_calls,
|
| 133 |
+
game_steps_remaining=max(0, self._max_game_steps - self._session.steps_taken),
|
| 134 |
+
tool_calls_remaining=self._max_tool_calls,
|
| 135 |
+
status="running",
|
| 136 |
+
world_title=self._compiled.world.meta.title,
|
| 137 |
+
last_command=None,
|
| 138 |
+
scratchpad_chars=0,
|
| 139 |
+
)
|
| 140 |
+
self._prepare_debug_dir()
|
| 141 |
+
reward_breakdown = self._empty_breakdown(self._progress_potential())
|
| 142 |
+
observation = self._apply_transform(
|
| 143 |
+
HeroObservation(
|
| 144 |
+
message=self._session.current_feedback(),
|
| 145 |
+
reward=0.0,
|
| 146 |
+
done=False,
|
| 147 |
+
won=None,
|
| 148 |
+
reward_breakdown=reward_breakdown,
|
| 149 |
+
aux_signals=self._progress_signals(),
|
| 150 |
+
)
|
| 151 |
+
)
|
| 152 |
+
return observation
|
| 153 |
+
|
| 154 |
+
def step( # type: ignore[override]
|
| 155 |
+
self,
|
| 156 |
+
action: HeroAction | dict[str, object],
|
| 157 |
+
timeout_s: float | None = None,
|
| 158 |
+
**kwargs: Any,
|
| 159 |
+
) -> StepResult[HeroObservation]:
|
| 160 |
+
del timeout_s, kwargs
|
| 161 |
+
if self._session is None or self._compiled is None:
|
| 162 |
+
raise RuntimeError("HeroEnvironment.reset must be called before step().")
|
| 163 |
+
if self._state.status != "running":
|
| 164 |
+
observation = HeroObservation(
|
| 165 |
+
message="",
|
| 166 |
+
reward=1.0 if self._state.status == "won" else 0.0,
|
| 167 |
+
done=True,
|
| 168 |
+
won=self._state.status == "won",
|
| 169 |
+
terminal_reason="episode_complete",
|
| 170 |
+
reward_breakdown=self._empty_breakdown(self._progress_potential()),
|
| 171 |
+
aux_signals=self._progress_signals(),
|
| 172 |
+
)
|
| 173 |
+
return build_step_result(self._apply_transform(observation))
|
| 174 |
+
|
| 175 |
+
parsed = validate_hero_action(action)
|
| 176 |
+
self._state.tool_calls_total += 1
|
| 177 |
+
self._state.step_count = self._state.tool_calls_total
|
| 178 |
+
self._update_remaining_counters()
|
| 179 |
+
|
| 180 |
+
if isinstance(parsed, ScratchpadReadAction):
|
| 181 |
+
observation = self._observation(
|
| 182 |
+
message=self._scratchpad,
|
| 183 |
+
tool=parsed.tool,
|
| 184 |
+
tool_success=True,
|
| 185 |
+
reward_breakdown=self._empty_breakdown(self._progress_potential()),
|
| 186 |
+
)
|
| 187 |
+
return build_step_result(observation)
|
| 188 |
+
|
| 189 |
+
if isinstance(parsed, ScratchpadWriteAction):
|
| 190 |
+
observation = self._handle_scratchpad_write(parsed)
|
| 191 |
+
return build_step_result(observation)
|
| 192 |
+
|
| 193 |
+
observation = self._handle_act(parsed)
|
| 194 |
+
return build_step_result(observation)
|
| 195 |
+
|
| 196 |
+
@property
|
| 197 |
+
def state(self) -> HeroState:
|
| 198 |
+
return self._state
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def scratchpad(self) -> str:
|
| 202 |
+
return self._scratchpad
|
| 203 |
+
|
| 204 |
+
@property
|
| 205 |
+
def session(self) -> EpisodeSession | None:
|
| 206 |
+
return self._session
|
| 207 |
+
|
| 208 |
+
@property
|
| 209 |
+
def episode_stats(self) -> HeroEpisodeStats:
|
| 210 |
+
return self._episode_stats
|
| 211 |
+
|
| 212 |
+
def _handle_scratchpad_write(self, action: ScratchpadWriteAction) -> HeroObservation:
|
| 213 |
+
new_value = (
|
| 214 |
+
self._scratchpad + action.content
|
| 215 |
+
if action.mode == "append"
|
| 216 |
+
else action.content
|
| 217 |
+
)
|
| 218 |
+
if len(new_value) > self.scratchpad_max_chars:
|
| 219 |
+
return self._observation(
|
| 220 |
+
message="Scratchpad write rejected: notebook size limit exceeded.",
|
| 221 |
+
tool=action.tool,
|
| 222 |
+
tool_success=False,
|
| 223 |
+
reward_breakdown=self._empty_breakdown(self._progress_potential()),
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
self._scratchpad = new_value
|
| 227 |
+
self._state.scratchpad_chars = len(self._scratchpad)
|
| 228 |
+
self._persist_debug_scratchpad()
|
| 229 |
+
return self._observation(
|
| 230 |
+
message="Scratchpad updated.",
|
| 231 |
+
tool=action.tool,
|
| 232 |
+
tool_success=True,
|
| 233 |
+
reward_breakdown=self._empty_breakdown(self._progress_potential()),
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
def _handle_act(self, action: ActAction) -> HeroObservation:
|
| 237 |
+
assert self._session is not None
|
| 238 |
+
parsed_command = parse_cli_command(action.command)
|
| 239 |
+
self._state.last_command = parsed_command.normalized_command or action.command
|
| 240 |
+
if not parsed_command.valid or parsed_command.normalized_command is None:
|
| 241 |
+
breakdown = self._empty_breakdown(self._progress_potential())
|
| 242 |
+
breakdown.syntax_penalty = _SYNTAX_PENALTY
|
| 243 |
+
return self._observation(
|
| 244 |
+
message=parsed_command.error or "That command does not match the strict CLI grammar.",
|
| 245 |
+
tool=action.tool,
|
| 246 |
+
tool_success=False,
|
| 247 |
+
reward_breakdown=breakdown,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
potential_before = self._progress_potential()
|
| 251 |
+
fingerprint_before = self._session.state_fingerprint()
|
| 252 |
+
room_before = self._session.current_room_id
|
| 253 |
+
try:
|
| 254 |
+
turn = self._session.step(parsed_command.normalized_command)
|
| 255 |
+
except DMInterfaceError:
|
| 256 |
+
breakdown = self._empty_breakdown(potential_before)
|
| 257 |
+
breakdown.syntax_penalty = _SYNTAX_PENALTY
|
| 258 |
+
return self._observation(
|
| 259 |
+
message="The interface could not interpret that action.",
|
| 260 |
+
tool=action.tool,
|
| 261 |
+
tool_success=False,
|
| 262 |
+
reward_breakdown=breakdown,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
tool_success = self._turn_succeeded(turn.game_state_delta)
|
| 266 |
+
self._state.game_steps_taken = self._session.steps_taken
|
| 267 |
+
self._session.recent_normalized_commands.append(parsed_command.normalized_command)
|
| 268 |
+
potential_after = self._progress_potential()
|
| 269 |
+
breakdown = self._empty_breakdown(potential_before)
|
| 270 |
+
breakdown.progress_potential_after = potential_after
|
| 271 |
+
breakdown.dense_progress_reward = _DENSE_PROGRESS_SCALE * max(0.0, potential_after - potential_before)
|
| 272 |
+
if not tool_success:
|
| 273 |
+
breakdown.invalid_action_penalty = _INVALID_ACTION_PENALTY
|
| 274 |
+
if self._is_wrong_submit(turn.game_state_delta):
|
| 275 |
+
breakdown.wrong_submit_penalty = _WRONG_SUBMIT_PENALTY
|
| 276 |
+
if self._repeat_noop(parsed_command.normalized_command, fingerprint_before, room_before):
|
| 277 |
+
breakdown.repeat_noop_penalty = _REPEAT_NOOP_PENALTY
|
| 278 |
+
return self._observation(
|
| 279 |
+
message=turn.observation,
|
| 280 |
+
tool=action.tool,
|
| 281 |
+
tool_success=tool_success,
|
| 282 |
+
reward_breakdown=breakdown,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
def _update_remaining_counters(self) -> None:
|
| 286 |
+
self._state.game_steps_remaining = max(0, self._max_game_steps - self._state.game_steps_taken)
|
| 287 |
+
self._state.tool_calls_remaining = max(0, self._max_tool_calls - self._state.tool_calls_total)
|
| 288 |
+
|
| 289 |
+
def _turn_succeeded(self, delta: dict[str, Any]) -> bool:
|
| 290 |
+
if delta.get("wrapper") == "submit_rejected":
|
| 291 |
+
return False
|
| 292 |
+
if "succeeded" in delta:
|
| 293 |
+
return bool(delta["succeeded"])
|
| 294 |
+
return True
|
| 295 |
+
|
| 296 |
+
def _observation(
|
| 297 |
+
self,
|
| 298 |
+
*,
|
| 299 |
+
message: str,
|
| 300 |
+
tool: str,
|
| 301 |
+
tool_success: bool,
|
| 302 |
+
reward_breakdown: HeroRewardBreakdown,
|
| 303 |
+
) -> HeroObservation:
|
| 304 |
+
assert self._session is not None
|
| 305 |
+
done = False
|
| 306 |
+
won: bool | None = None
|
| 307 |
+
terminal_reason: str | None = None
|
| 308 |
+
|
| 309 |
+
if self._session.player_won:
|
| 310 |
+
self._state.status = "won"
|
| 311 |
+
done = True
|
| 312 |
+
won = True
|
| 313 |
+
reward_breakdown.base_terminal_reward = 1.0
|
| 314 |
+
elif self._session.done:
|
| 315 |
+
self._state.status = "lost"
|
| 316 |
+
done = True
|
| 317 |
+
won = False
|
| 318 |
+
terminal_reason = "session_ended"
|
| 319 |
+
elif self._state.game_steps_taken >= self._max_game_steps:
|
| 320 |
+
self._state.status = "timed_out"
|
| 321 |
+
done = True
|
| 322 |
+
won = False
|
| 323 |
+
terminal_reason = "game_step_budget_exhausted"
|
| 324 |
+
elif self._state.tool_calls_total >= self._max_tool_calls:
|
| 325 |
+
self._state.status = "timed_out"
|
| 326 |
+
done = True
|
| 327 |
+
won = False
|
| 328 |
+
terminal_reason = "tool_budget_exhausted"
|
| 329 |
+
|
| 330 |
+
reward_breakdown.total_reward = (
|
| 331 |
+
reward_breakdown.base_terminal_reward
|
| 332 |
+
+ reward_breakdown.dense_progress_reward
|
| 333 |
+
+ reward_breakdown.syntax_penalty
|
| 334 |
+
+ reward_breakdown.invalid_action_penalty
|
| 335 |
+
+ reward_breakdown.repeat_noop_penalty
|
| 336 |
+
+ reward_breakdown.wrong_submit_penalty
|
| 337 |
+
)
|
| 338 |
+
self._update_remaining_counters()
|
| 339 |
+
aux_signals = self._progress_signals()
|
| 340 |
+
self._accumulate_episode_stats(reward_breakdown, won is True)
|
| 341 |
+
|
| 342 |
+
observation = self._apply_transform(
|
| 343 |
+
HeroObservation(
|
| 344 |
+
message=message,
|
| 345 |
+
reward=reward_breakdown.total_reward,
|
| 346 |
+
done=done,
|
| 347 |
+
won=won,
|
| 348 |
+
tool=tool,
|
| 349 |
+
tool_success=tool_success,
|
| 350 |
+
terminal_reason=terminal_reason,
|
| 351 |
+
reward_breakdown=reward_breakdown,
|
| 352 |
+
aux_signals=aux_signals,
|
| 353 |
+
)
|
| 354 |
+
)
|
| 355 |
+
return observation
|
| 356 |
+
|
| 357 |
+
def _prepare_debug_dir(self) -> None:
|
| 358 |
+
if not self.debug or self._compiled is None:
|
| 359 |
+
self._debug_dir = None
|
| 360 |
+
return
|
| 361 |
+
self._debug_dir = self._compiled.artifacts_dir / "hero_debug"
|
| 362 |
+
self._debug_dir.mkdir(parents=True, exist_ok=True)
|
| 363 |
+
self._persist_debug_scratchpad()
|
| 364 |
+
|
| 365 |
+
def _persist_debug_scratchpad(self) -> None:
|
| 366 |
+
if self._debug_dir is None:
|
| 367 |
+
return
|
| 368 |
+
(self._debug_dir / "scratchpad.txt").write_text(self._scratchpad, encoding="utf-8")
|
| 369 |
+
|
| 370 |
+
def _progress_signals(self) -> HeroAuxSignals:
|
| 371 |
+
assert self._session is not None
|
| 372 |
+
assert self._compiled is not None
|
| 373 |
+
room_ids = {node.id for node in self._compiled.world.nodes if node.type in {"location", "junction"}}
|
| 374 |
+
total_locked_doors = {
|
| 375 |
+
edge.door_node_id
|
| 376 |
+
for edge in self._compiled.world.edges
|
| 377 |
+
if edge.type == "locked_passage" and edge.door_node_id
|
| 378 |
+
}
|
| 379 |
+
total_clues = {clue.id for clue in self._compiled.world.clues}
|
| 380 |
+
answer_ready = float(
|
| 381 |
+
bool(total_clues)
|
| 382 |
+
and self._session.consulted_guardian
|
| 383 |
+
and self._session.discovered_clues == total_clues
|
| 384 |
+
)
|
| 385 |
+
return HeroAuxSignals(
|
| 386 |
+
visited_room_progress=_fraction(len(self._session.visited_nodes & room_ids), len(room_ids)),
|
| 387 |
+
clue_progress=_fraction(len(self._session.discovered_clues), len(total_clues)),
|
| 388 |
+
locked_gate_progress=_fraction(len(self._session.unlocked_doors), len(total_locked_doors)),
|
| 389 |
+
trade_progress=_fraction(len(self._session.traded_npcs), len(self._compiled.npc_trade_map)),
|
| 390 |
+
recipe_progress=_fraction(len(self._session.completed_recipe_outputs), len(self._compiled.world.recipes)),
|
| 391 |
+
use_effect_progress=_fraction(len(self._session.completed_use_targets), len(self._compiled.use_effects)),
|
| 392 |
+
guardian_consulted_progress=1.0 if self._session.consulted_guardian else 0.0,
|
| 393 |
+
answer_ready_progress=answer_ready,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
def _progress_potential(self) -> float:
|
| 397 |
+
signals = self._progress_signals()
|
| 398 |
+
potential = (
|
| 399 |
+
0.10 * signals.visited_room_progress
|
| 400 |
+
+ 0.35 * signals.clue_progress
|
| 401 |
+
+ 0.10 * signals.locked_gate_progress
|
| 402 |
+
+ 0.10 * signals.trade_progress
|
| 403 |
+
+ 0.10 * signals.recipe_progress
|
| 404 |
+
+ 0.15 * signals.use_effect_progress
|
| 405 |
+
+ 0.05 * signals.guardian_consulted_progress
|
| 406 |
+
+ 0.05 * signals.answer_ready_progress
|
| 407 |
+
)
|
| 408 |
+
return max(0.0, min(1.0, potential))
|
| 409 |
+
|
| 410 |
+
def _empty_breakdown(self, potential: float) -> HeroRewardBreakdown:
|
| 411 |
+
return HeroRewardBreakdown(
|
| 412 |
+
progress_potential_before=potential,
|
| 413 |
+
progress_potential_after=potential,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
def _repeat_noop(self, command: str, fingerprint_before: str, room_before: str) -> bool:
|
| 417 |
+
assert self._session is not None
|
| 418 |
+
fingerprint_after = self._session.state_fingerprint()
|
| 419 |
+
room_after = self._session.current_room_id
|
| 420 |
+
if room_before == room_after and fingerprint_before == fingerprint_after:
|
| 421 |
+
self._recent_noop_signatures.append((command, room_after, fingerprint_after))
|
| 422 |
+
else:
|
| 423 |
+
self._recent_noop_signatures.clear()
|
| 424 |
+
return (
|
| 425 |
+
len(self._recent_noop_signatures) == 3
|
| 426 |
+
and len({signature[0] for signature in self._recent_noop_signatures}) == 1
|
| 427 |
+
and len({signature[1] for signature in self._recent_noop_signatures}) == 1
|
| 428 |
+
and len({signature[2] for signature in self._recent_noop_signatures}) == 1
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
@staticmethod
|
| 432 |
+
def _is_wrong_submit(delta: dict[str, Any]) -> bool:
|
| 433 |
+
return delta.get("wrapper") == "submit_rejected" and delta.get("reason") == "wrong_answer"
|
| 434 |
+
|
| 435 |
+
def _accumulate_episode_stats(self, breakdown: HeroRewardBreakdown, player_won: bool) -> None:
|
| 436 |
+
self._episode_stats.player_won = player_won or self._episode_stats.player_won
|
| 437 |
+
self._episode_stats.total_reward += breakdown.total_reward
|
| 438 |
+
self._episode_stats.dense_return += breakdown.dense_progress_reward
|
| 439 |
+
self._episode_stats.syntax_penalty_total += breakdown.syntax_penalty
|
| 440 |
+
self._episode_stats.invalid_action_penalty_total += breakdown.invalid_action_penalty
|
| 441 |
+
self._episode_stats.repeat_noop_penalty_total += breakdown.repeat_noop_penalty
|
| 442 |
+
self._episode_stats.wrong_submit_penalty_total += breakdown.wrong_submit_penalty
|
| 443 |
+
self._episode_stats.steps_taken = self._state.game_steps_taken
|
| 444 |
+
self._episode_stats.tool_calls_total = self._state.tool_calls_total
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def _fraction(done: int, total: int) -> float:
|
| 448 |
+
if total <= 0:
|
| 449 |
+
return 0.0
|
| 450 |
+
return min(1.0, done / total)
|
agents/hero/policy.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Literal, Protocol
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
|
| 7 |
+
from agents.shared.llm_client import StructuredModelClient
|
| 8 |
+
from agents.shared.model_schema import ModelMessage, StrictModel
|
| 9 |
+
|
| 10 |
+
from .cli import parse_cli_command
|
| 11 |
+
from .prompt import format_hero_system_prompt, format_hero_turn_prompt
|
| 12 |
+
from .schema import ActAction, HeroAction, HeroObservation, HeroState, validate_hero_action
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class HeroPolicyError(RuntimeError):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class HeroPolicy(Protocol):
|
| 20 |
+
trace_events: list["HeroTraceEvent"]
|
| 21 |
+
last_error: str | None
|
| 22 |
+
|
| 23 |
+
def reset(self) -> None:
|
| 24 |
+
...
|
| 25 |
+
|
| 26 |
+
def next_action(
|
| 27 |
+
self,
|
| 28 |
+
observation: HeroObservation,
|
| 29 |
+
state: HeroState,
|
| 30 |
+
scratchpad: str,
|
| 31 |
+
) -> HeroAction:
|
| 32 |
+
...
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class HeroActionPayload(BaseModel):
|
| 36 |
+
tool: Literal["act", "scratchpad_read", "scratchpad_write"]
|
| 37 |
+
command: str | None = None
|
| 38 |
+
mode: Literal["append", "replace"] | None = None
|
| 39 |
+
content: str | None = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class HeroActionResponse(BaseModel):
|
| 43 |
+
action: HeroActionPayload
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class HeroTraceEvent(StrictModel):
|
| 47 |
+
turn_index: int
|
| 48 |
+
observation: str
|
| 49 |
+
scratchpad: str
|
| 50 |
+
state: dict[str, object]
|
| 51 |
+
action: dict[str, object] | None = None
|
| 52 |
+
repair_count: int = 0
|
| 53 |
+
validation_error: str | None = None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class HeroLLMPolicy:
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
client: StructuredModelClient,
|
| 60 |
+
*,
|
| 61 |
+
model_name: str,
|
| 62 |
+
temperature: float = 0.1,
|
| 63 |
+
max_output_tokens: int = 256,
|
| 64 |
+
max_repair_attempts: int = 1,
|
| 65 |
+
) -> None:
|
| 66 |
+
self.client = client
|
| 67 |
+
self.model_name = model_name
|
| 68 |
+
self.temperature = temperature
|
| 69 |
+
self.max_output_tokens = max_output_tokens
|
| 70 |
+
self.max_repair_attempts = max_repair_attempts
|
| 71 |
+
self.trace_events: list[HeroTraceEvent] = []
|
| 72 |
+
self.last_error: str | None = None
|
| 73 |
+
|
| 74 |
+
def reset(self) -> None:
|
| 75 |
+
self.trace_events = []
|
| 76 |
+
self.last_error = None
|
| 77 |
+
|
| 78 |
+
def next_action(
|
| 79 |
+
self,
|
| 80 |
+
observation: HeroObservation,
|
| 81 |
+
state: HeroState,
|
| 82 |
+
scratchpad: str,
|
| 83 |
+
) -> HeroAction:
|
| 84 |
+
repair_error: str | None = None
|
| 85 |
+
for attempt in range(self.max_repair_attempts + 1):
|
| 86 |
+
try:
|
| 87 |
+
response = self.client.generate_structured(
|
| 88 |
+
self._messages(observation, state, scratchpad, repair_error),
|
| 89 |
+
HeroActionResponse,
|
| 90 |
+
model_name=self.model_name,
|
| 91 |
+
temperature=self.temperature,
|
| 92 |
+
max_output_tokens=self.max_output_tokens,
|
| 93 |
+
)
|
| 94 |
+
action = validate_hero_action(response.action.model_dump(mode="json", exclude_none=True))
|
| 95 |
+
if isinstance(action, ActAction):
|
| 96 |
+
parsed_command = parse_cli_command(action.command)
|
| 97 |
+
if not parsed_command.valid or parsed_command.normalized_command is None:
|
| 98 |
+
raise ValueError(parsed_command.error or "Invalid strict CLI command.")
|
| 99 |
+
action = ActAction(command=parsed_command.normalized_command)
|
| 100 |
+
self.trace_events.append(
|
| 101 |
+
HeroTraceEvent(
|
| 102 |
+
turn_index=len(self.trace_events),
|
| 103 |
+
observation=observation.message,
|
| 104 |
+
scratchpad=scratchpad,
|
| 105 |
+
state=state.model_dump(mode="json"),
|
| 106 |
+
action=action.model_dump(mode="json"),
|
| 107 |
+
repair_count=attempt,
|
| 108 |
+
)
|
| 109 |
+
)
|
| 110 |
+
self.last_error = None
|
| 111 |
+
return action
|
| 112 |
+
except Exception as exc:
|
| 113 |
+
repair_error = self._normalize_error(exc)
|
| 114 |
+
if attempt >= self.max_repair_attempts:
|
| 115 |
+
self.last_error = repair_error
|
| 116 |
+
self.trace_events.append(
|
| 117 |
+
HeroTraceEvent(
|
| 118 |
+
turn_index=len(self.trace_events),
|
| 119 |
+
observation=observation.message,
|
| 120 |
+
scratchpad=scratchpad,
|
| 121 |
+
state=state.model_dump(mode="json"),
|
| 122 |
+
repair_count=attempt,
|
| 123 |
+
validation_error=repair_error,
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
+
raise HeroPolicyError(repair_error) from exc
|
| 127 |
+
raise HeroPolicyError("Hero policy failed without a usable action.")
|
| 128 |
+
|
| 129 |
+
def _messages(
|
| 130 |
+
self,
|
| 131 |
+
observation: HeroObservation,
|
| 132 |
+
state: HeroState,
|
| 133 |
+
scratchpad: str,
|
| 134 |
+
repair_error: str | None,
|
| 135 |
+
) -> list[ModelMessage]:
|
| 136 |
+
user_prompt = format_hero_turn_prompt(observation.message, state, scratchpad)
|
| 137 |
+
if repair_error is not None:
|
| 138 |
+
user_prompt += (
|
| 139 |
+
"\nThe previous response did not match the action schema.\n"
|
| 140 |
+
f"Validation error: {repair_error}\n"
|
| 141 |
+
"Return one corrected action only.\n"
|
| 142 |
+
)
|
| 143 |
+
return [
|
| 144 |
+
ModelMessage(
|
| 145 |
+
role="system",
|
| 146 |
+
content=format_hero_system_prompt(
|
| 147 |
+
state.world_title,
|
| 148 |
+
state.max_game_steps,
|
| 149 |
+
state.max_tool_calls,
|
| 150 |
+
),
|
| 151 |
+
),
|
| 152 |
+
ModelMessage(role="user", content=user_prompt),
|
| 153 |
+
]
|
| 154 |
+
|
| 155 |
+
@staticmethod
|
| 156 |
+
def _normalize_error(exc: Exception) -> str:
|
| 157 |
+
return " ".join(str(exc).split()) or exc.__class__.__name__
|
agents/hero/prompt.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from .schema import HeroState
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
HERO_SYSTEM_PROMPT = """You are the hero exploring a living dungeon.
|
| 7 |
+
|
| 8 |
+
You can only act through tools.
|
| 9 |
+
|
| 10 |
+
Rules:
|
| 11 |
+
- Use `act` for any in-world action with one strict parser-style CLI command.
|
| 12 |
+
- Use `scratchpad_read` and `scratchpad_write` to manage your own notebook.
|
| 13 |
+
- Track rooms, objects, clues, hypotheses, and failed attempts in the notebook.
|
| 14 |
+
- Do not assume the world is fair in obvious ways; verify.
|
| 15 |
+
- Do not expect command hints from the environment. Use `look` and `inventory` when needed.
|
| 16 |
+
- Prefer systematic play: open visible containers and doors, take portable items, read text, talk to NPCs, and backtrack when blocked.
|
| 17 |
+
- When a puzzle reveals a clue, record it immediately.
|
| 18 |
+
- Do not submit an answer until you have enough evidence and the guardian is ready.
|
| 19 |
+
- Winning requires gathering evidence and then answering the guardian correctly.
|
| 20 |
+
- Keep your notebook concise and update it when the world changes.
|
| 21 |
+
- Commands must be lowercase only, with no articles, no markdown, and no conversational text.
|
| 22 |
+
- Allowed command grammar:
|
| 23 |
+
look
|
| 24 |
+
inventory
|
| 25 |
+
wait
|
| 26 |
+
north|south|east|west|up|down|in|out
|
| 27 |
+
go north|go south|go east|go west|go up|go down|go in|go out
|
| 28 |
+
open <object>
|
| 29 |
+
read <object>
|
| 30 |
+
talk <npc>
|
| 31 |
+
examine <object>
|
| 32 |
+
look in <object>
|
| 33 |
+
take <item>
|
| 34 |
+
take <item> from <container>
|
| 35 |
+
unlock <door> with <key>
|
| 36 |
+
use <item> on <target>
|
| 37 |
+
combine <item_a> with <item_b>
|
| 38 |
+
give <item> to <npc>
|
| 39 |
+
submit <answer>
|
| 40 |
+
- Example valid commands:
|
| 41 |
+
open entry chest
|
| 42 |
+
take brass key from entry chest
|
| 43 |
+
unlock iron door with brass key
|
| 44 |
+
east
|
| 45 |
+
use torch on ash mural
|
| 46 |
+
talk stone guardian
|
| 47 |
+
submit mira
|
| 48 |
+
- Return JSON only. Never add prose, markdown fences, or explanations.
|
| 49 |
+
- Valid response shapes:
|
| 50 |
+
{"action":{"tool":"act","command":"look"}}
|
| 51 |
+
{"action":{"tool":"scratchpad_read"}}
|
| 52 |
+
{"action":{"tool":"scratchpad_write","mode":"append","content":"room notes"}}
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
HERO_GRPO_SYSTEM_PROMPT = """You are the hero exploring a living dungeon.
|
| 56 |
+
|
| 57 |
+
You can only act through tool calls.
|
| 58 |
+
|
| 59 |
+
Rules:
|
| 60 |
+
- Call exactly one tool for each turn.
|
| 61 |
+
- Use `act` for any in-world action with one strict parser-style CLI command.
|
| 62 |
+
- Use `scratchpad_read` and `scratchpad_write` to manage your own notebook.
|
| 63 |
+
- Track rooms, objects, clues, hypotheses, and failed attempts in the notebook.
|
| 64 |
+
- Do not assume the world is fair in obvious ways; verify.
|
| 65 |
+
- Do not expect command hints from the environment. Use `look` and `inventory` when needed.
|
| 66 |
+
- Prefer systematic play: open visible containers and doors, take portable items, read text, talk to NPCs, and backtrack when blocked.
|
| 67 |
+
- When a puzzle reveals a clue, record it immediately.
|
| 68 |
+
- Do not submit an answer until you have enough evidence and the guardian is ready.
|
| 69 |
+
- Winning requires gathering evidence and then answering the guardian correctly.
|
| 70 |
+
- Keep your notebook concise and update it when the world changes.
|
| 71 |
+
- Commands must be lowercase only, with no articles, no markdown, and no conversational text.
|
| 72 |
+
- Allowed command grammar:
|
| 73 |
+
look
|
| 74 |
+
inventory
|
| 75 |
+
wait
|
| 76 |
+
north|south|east|west|up|down|in|out
|
| 77 |
+
go north|go south|go east|go west|go up|go down|go in|go out
|
| 78 |
+
open <object>
|
| 79 |
+
read <object>
|
| 80 |
+
talk <npc>
|
| 81 |
+
examine <object>
|
| 82 |
+
look in <object>
|
| 83 |
+
take <item>
|
| 84 |
+
take <item> from <container>
|
| 85 |
+
unlock <door> with <key>
|
| 86 |
+
use <item> on <target>
|
| 87 |
+
combine <item_a> with <item_b>
|
| 88 |
+
give <item> to <npc>
|
| 89 |
+
submit <answer>
|
| 90 |
+
- Example valid commands:
|
| 91 |
+
open entry chest
|
| 92 |
+
take brass key from entry chest
|
| 93 |
+
unlock iron door with brass key
|
| 94 |
+
east
|
| 95 |
+
use torch on ash mural
|
| 96 |
+
talk stone guardian
|
| 97 |
+
submit mira
|
| 98 |
+
- Do not write prose, plans, or plain JSON action objects.
|
| 99 |
+
- The runtime provides the tool schema; emit a tool call only.
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def format_hero_system_prompt(world_title: str, max_game_steps: int, max_tool_calls: int) -> str:
|
| 104 |
+
return (
|
| 105 |
+
f"{HERO_SYSTEM_PROMPT}\n\n"
|
| 106 |
+
f"World: {world_title}\n"
|
| 107 |
+
f"Game-step budget: {max_game_steps}\n"
|
| 108 |
+
f"Total tool-call budget: {max_tool_calls}\n"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def format_hero_grpo_system_prompt(world_title: str, max_game_steps: int, max_tool_calls: int) -> str:
|
| 113 |
+
return (
|
| 114 |
+
f"{HERO_GRPO_SYSTEM_PROMPT}\n\n"
|
| 115 |
+
f"World: {world_title}\n"
|
| 116 |
+
f"Game-step budget: {max_game_steps}\n"
|
| 117 |
+
f"Total tool-call budget: {max_tool_calls}\n"
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def format_hero_turn_prompt(message: str, state: HeroState, scratchpad: str) -> str:
|
| 122 |
+
notebook = scratchpad if scratchpad else "<empty>"
|
| 123 |
+
return (
|
| 124 |
+
"Choose exactly one next tool call.\n"
|
| 125 |
+
f"Observation:\n{message.strip() or '<empty>'}\n\n"
|
| 126 |
+
f"World: {state.world_title}\n"
|
| 127 |
+
f"Status: {state.status}\n"
|
| 128 |
+
f"Game steps taken: {state.game_steps_taken}/{state.max_game_steps}\n"
|
| 129 |
+
f"Tool calls used: {state.tool_calls_total}/{state.max_tool_calls}\n"
|
| 130 |
+
f"Game steps remaining: {state.game_steps_remaining}\n"
|
| 131 |
+
f"Tool calls remaining: {state.tool_calls_remaining}\n"
|
| 132 |
+
f"Last command: {state.last_command or '<none>'}\n\n"
|
| 133 |
+
f"Scratchpad:\n{notebook}\n"
|
| 134 |
+
)
|
agents/hero/runner.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
from typing import Protocol
|
| 5 |
+
|
| 6 |
+
from agents.master.session import EpisodeSession
|
| 7 |
+
|
| 8 |
+
from .env import HeroEnvironment
|
| 9 |
+
from .policy import HeroPolicyError
|
| 10 |
+
from .schema import HeroAction, HeroEpisodeStats, HeroObservation, HeroState
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ToolCallingPolicy(Protocol):
|
| 14 |
+
def reset(self) -> None:
|
| 15 |
+
...
|
| 16 |
+
|
| 17 |
+
def next_action(
|
| 18 |
+
self,
|
| 19 |
+
observation: HeroObservation,
|
| 20 |
+
state: HeroState,
|
| 21 |
+
scratchpad: str,
|
| 22 |
+
) -> HeroAction | dict[str, object] | None:
|
| 23 |
+
...
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ScriptedToolCallingPolicy:
|
| 27 |
+
def __init__(self, actions: Iterable[HeroAction | dict[str, object]]) -> None:
|
| 28 |
+
self._initial_actions = list(actions)
|
| 29 |
+
self._remaining_actions = list(self._initial_actions)
|
| 30 |
+
|
| 31 |
+
def reset(self) -> None:
|
| 32 |
+
self._remaining_actions = list(self._initial_actions)
|
| 33 |
+
|
| 34 |
+
def next_action(
|
| 35 |
+
self,
|
| 36 |
+
observation: HeroObservation,
|
| 37 |
+
state: HeroState,
|
| 38 |
+
scratchpad: str,
|
| 39 |
+
) -> HeroAction | dict[str, object] | None:
|
| 40 |
+
del observation, state, scratchpad
|
| 41 |
+
if not self._remaining_actions:
|
| 42 |
+
return None
|
| 43 |
+
return self._remaining_actions.pop(0)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class HeroRunner:
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
policy: ToolCallingPolicy,
|
| 50 |
+
*,
|
| 51 |
+
max_game_steps: int | None = 40,
|
| 52 |
+
max_tool_calls: int | None = None,
|
| 53 |
+
scratchpad_max_chars: int = 8000,
|
| 54 |
+
debug: bool = False,
|
| 55 |
+
) -> None:
|
| 56 |
+
self.policy = policy
|
| 57 |
+
self.max_game_steps = max_game_steps
|
| 58 |
+
self.max_tool_calls = max_tool_calls
|
| 59 |
+
self.scratchpad_max_chars = scratchpad_max_chars
|
| 60 |
+
self.debug = debug
|
| 61 |
+
self.last_error: str | None = None
|
| 62 |
+
self.last_observation: HeroObservation | None = None
|
| 63 |
+
self.episode_stats: HeroEpisodeStats | None = None
|
| 64 |
+
|
| 65 |
+
def run(self, session: EpisodeSession, max_steps: int) -> None:
|
| 66 |
+
self.last_error = None
|
| 67 |
+
self.last_observation = None
|
| 68 |
+
self.episode_stats = None
|
| 69 |
+
self.policy.reset()
|
| 70 |
+
env = HeroEnvironment.from_session(
|
| 71 |
+
session,
|
| 72 |
+
max_game_steps=max_steps if self.max_game_steps is None else min(max_steps, self.max_game_steps),
|
| 73 |
+
max_tool_calls=self.max_tool_calls,
|
| 74 |
+
scratchpad_max_chars=self.scratchpad_max_chars,
|
| 75 |
+
debug=self.debug,
|
| 76 |
+
)
|
| 77 |
+
observation = env.reset()
|
| 78 |
+
self.last_observation = observation
|
| 79 |
+
while not observation.done:
|
| 80 |
+
try:
|
| 81 |
+
action = self.policy.next_action(observation, env.state, env.scratchpad)
|
| 82 |
+
except HeroPolicyError as exc:
|
| 83 |
+
self.last_error = str(exc)
|
| 84 |
+
self.episode_stats = env.episode_stats
|
| 85 |
+
return
|
| 86 |
+
if action is None:
|
| 87 |
+
self.episode_stats = env.episode_stats
|
| 88 |
+
return
|
| 89 |
+
result = env.step(action)
|
| 90 |
+
observation = result.observation
|
| 91 |
+
self.last_observation = observation
|
| 92 |
+
self.episode_stats = env.episode_stats
|
agents/hero/schema.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
from typing import Annotated, Literal, TypeAlias
|
| 5 |
+
|
| 6 |
+
from pydantic import Field, TypeAdapter
|
| 7 |
+
|
| 8 |
+
from agents.shared.openenv_compat import Action, Observation, State
|
| 9 |
+
from agents.shared.model_schema import StrictModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ActAction(Action):
|
| 13 |
+
tool: Literal["act"] = "act"
|
| 14 |
+
command: str
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ScratchpadReadAction(Action):
|
| 18 |
+
tool: Literal["scratchpad_read"] = "scratchpad_read"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ScratchpadWriteAction(Action):
|
| 22 |
+
tool: Literal["scratchpad_write"] = "scratchpad_write"
|
| 23 |
+
mode: Literal["append", "replace"]
|
| 24 |
+
content: str
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class HeroServerAction(Action):
|
| 28 |
+
tool: Literal["act", "scratchpad_read", "scratchpad_write"]
|
| 29 |
+
command: str | None = None
|
| 30 |
+
mode: Literal["append", "replace"] | None = None
|
| 31 |
+
content: str | None = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
HeroAction: TypeAlias = Annotated[
|
| 35 |
+
ActAction | ScratchpadReadAction | ScratchpadWriteAction,
|
| 36 |
+
Field(discriminator="tool"),
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
HERO_ACTION_ADAPTER = TypeAdapter(HeroAction)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def validate_hero_action(value: HeroAction | HeroServerAction | dict[str, Any]) -> HeroAction:
|
| 43 |
+
if isinstance(value, Action):
|
| 44 |
+
value = value.model_dump(mode="json", exclude_none=True)
|
| 45 |
+
return HERO_ACTION_ADAPTER.validate_python(value)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class HeroObservation(Observation):
|
| 49 |
+
message: str = ""
|
| 50 |
+
won: bool | None = None
|
| 51 |
+
tool: str | None = None
|
| 52 |
+
tool_success: bool | None = None
|
| 53 |
+
terminal_reason: str | None = None
|
| 54 |
+
reward_breakdown: "HeroRewardBreakdown | None" = None
|
| 55 |
+
aux_signals: "HeroAuxSignals | None" = None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class HeroAuxSignals(StrictModel):
|
| 59 |
+
visited_room_progress: float = 0.0
|
| 60 |
+
clue_progress: float = 0.0
|
| 61 |
+
locked_gate_progress: float = 0.0
|
| 62 |
+
trade_progress: float = 0.0
|
| 63 |
+
recipe_progress: float = 0.0
|
| 64 |
+
use_effect_progress: float = 0.0
|
| 65 |
+
guardian_consulted_progress: float = 0.0
|
| 66 |
+
answer_ready_progress: float = 0.0
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class HeroRewardBreakdown(StrictModel):
|
| 70 |
+
base_terminal_reward: float = 0.0
|
| 71 |
+
dense_progress_reward: float = 0.0
|
| 72 |
+
syntax_penalty: float = 0.0
|
| 73 |
+
invalid_action_penalty: float = 0.0
|
| 74 |
+
repeat_noop_penalty: float = 0.0
|
| 75 |
+
wrong_submit_penalty: float = 0.0
|
| 76 |
+
total_reward: float = 0.0
|
| 77 |
+
progress_potential_before: float = 0.0
|
| 78 |
+
progress_potential_after: float = 0.0
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class HeroEpisodeStats(StrictModel):
|
| 82 |
+
player_won: bool = False
|
| 83 |
+
total_reward: float = 0.0
|
| 84 |
+
dense_return: float = 0.0
|
| 85 |
+
syntax_penalty_total: float = 0.0
|
| 86 |
+
invalid_action_penalty_total: float = 0.0
|
| 87 |
+
repeat_noop_penalty_total: float = 0.0
|
| 88 |
+
wrong_submit_penalty_total: float = 0.0
|
| 89 |
+
steps_taken: int = 0
|
| 90 |
+
tool_calls_total: int = 0
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class HeroState(State):
|
| 94 |
+
game_steps_taken: int = 0
|
| 95 |
+
tool_calls_total: int = 0
|
| 96 |
+
max_game_steps: int = 0
|
| 97 |
+
max_tool_calls: int = 0
|
| 98 |
+
game_steps_remaining: int = 0
|
| 99 |
+
tool_calls_remaining: int = 0
|
| 100 |
+
status: Literal["ready", "running", "won", "lost", "timed_out", "error"] = "ready"
|
| 101 |
+
world_title: str = ""
|
| 102 |
+
last_command: str | None = None
|
| 103 |
+
scratchpad_chars: int = 0
|
agents/loop/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Closed-loop orchestration for hero and dungeon master policies."""
|
| 2 |
+
|
| 3 |
+
from .runner import ClosedLoopRunner
|
| 4 |
+
from .schema import ClosedLoopEpisodeArtifacts, ClosedLoopEpisodeRecord, ClosedLoopEpisodeSummary
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"ClosedLoopEpisodeArtifacts",
|
| 8 |
+
"ClosedLoopEpisodeRecord",
|
| 9 |
+
"ClosedLoopEpisodeSummary",
|
| 10 |
+
"ClosedLoopRunner",
|
| 11 |
+
]
|
agents/loop/__main__.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from agents.hero.policy import HeroLLMPolicy
|
| 8 |
+
from agents.master.interface import DEFAULT_GEMINI_MODEL
|
| 9 |
+
from agents.master.env import DMEnvironment
|
| 10 |
+
from agents.master.policy import DungeonMasterLLMPolicy
|
| 11 |
+
from agents.shared.runtime import (
|
| 12 |
+
build_interface_adapter,
|
| 13 |
+
create_structured_client,
|
| 14 |
+
resolve_interface_config,
|
| 15 |
+
resolve_structured_client_config,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from .runner import ClosedLoopRunner
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main(argv: list[str] | None = None) -> int:
|
| 22 |
+
parser = argparse.ArgumentParser(description="Closed-loop dungeon master and hero harness")
|
| 23 |
+
parser.add_argument("--episodes", type=int, default=1)
|
| 24 |
+
parser.add_argument("--seed", type=int)
|
| 25 |
+
parser.add_argument("--target-ratio", type=float)
|
| 26 |
+
parser.add_argument("--dm-provider", choices=["gemini", "hf_local"])
|
| 27 |
+
parser.add_argument("--dm-model")
|
| 28 |
+
parser.add_argument("--dm-adapter-path")
|
| 29 |
+
parser.add_argument("--hero-provider", choices=["gemini", "hf_local"])
|
| 30 |
+
parser.add_argument("--hero-model")
|
| 31 |
+
parser.add_argument("--hero-adapter-path")
|
| 32 |
+
parser.add_argument("--interface-provider", choices=["strict", "simple", "gemini"])
|
| 33 |
+
parser.add_argument("--interface-model", default=DEFAULT_GEMINI_MODEL)
|
| 34 |
+
parser.add_argument("--interface-narrate", action="store_true")
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--translate-corporate-env",
|
| 37 |
+
action="store_true",
|
| 38 |
+
help="Rewrite hero-facing observations into a corporate app metaphor and map translated commands back through Gemini.",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument("--artifacts-root", type=Path)
|
| 41 |
+
parser.add_argument("--dm-artifacts-root", type=Path)
|
| 42 |
+
parser.add_argument("--dm-repair-attempts", type=int, default=2)
|
| 43 |
+
parser.add_argument("--hero-max-game-steps", type=int, default=40)
|
| 44 |
+
parser.add_argument("--hero-max-tool-calls", type=int, default=80)
|
| 45 |
+
parser.add_argument("--live", action="store_true")
|
| 46 |
+
parser.add_argument("--live-dir", type=Path)
|
| 47 |
+
args = parser.parse_args(argv)
|
| 48 |
+
|
| 49 |
+
dm_config = resolve_structured_client_config(
|
| 50 |
+
"dm",
|
| 51 |
+
provider=args.dm_provider,
|
| 52 |
+
model_name=args.dm_model,
|
| 53 |
+
adapter_path=args.dm_adapter_path,
|
| 54 |
+
)
|
| 55 |
+
hero_config = resolve_structured_client_config(
|
| 56 |
+
"hero",
|
| 57 |
+
provider=args.hero_provider,
|
| 58 |
+
model_name=args.hero_model,
|
| 59 |
+
adapter_path=args.hero_adapter_path,
|
| 60 |
+
)
|
| 61 |
+
interface_config = resolve_interface_config(
|
| 62 |
+
provider=args.interface_provider,
|
| 63 |
+
model_name=args.interface_model,
|
| 64 |
+
narrate_observations=args.interface_narrate,
|
| 65 |
+
translation_mode="corporate_app" if args.translate_corporate_env else None,
|
| 66 |
+
)
|
| 67 |
+
runner = ClosedLoopRunner(
|
| 68 |
+
dm_env=DMEnvironment(artifacts_root=args.dm_artifacts_root),
|
| 69 |
+
dm_policy=DungeonMasterLLMPolicy(create_structured_client(dm_config), model_name=dm_config.model_name),
|
| 70 |
+
hero_policy=HeroLLMPolicy(create_structured_client(hero_config), model_name=hero_config.model_name),
|
| 71 |
+
artifacts_root=args.artifacts_root,
|
| 72 |
+
live_dir=args.live_dir,
|
| 73 |
+
max_dm_repair_attempts=args.dm_repair_attempts,
|
| 74 |
+
hero_runner_kwargs={
|
| 75 |
+
"max_game_steps": args.hero_max_game_steps,
|
| 76 |
+
"max_tool_calls": args.hero_max_tool_calls,
|
| 77 |
+
},
|
| 78 |
+
hero_interface_adapter=build_interface_adapter(interface_config),
|
| 79 |
+
)
|
| 80 |
+
records = []
|
| 81 |
+
for index in range(args.episodes):
|
| 82 |
+
seed = None if args.seed is None else args.seed + index
|
| 83 |
+
record = runner.run_episode(seed=seed, target_ratio=args.target_ratio, live=args.live)
|
| 84 |
+
records.append(record)
|
| 85 |
+
print(json.dumps(ClosedLoopRunner.summary(record).model_dump(mode="json")))
|
| 86 |
+
if records:
|
| 87 |
+
print(json.dumps(ClosedLoopRunner.aggregate(records).model_dump(mode="json")))
|
| 88 |
+
return 0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
raise SystemExit(main())
|
agents/loop/runner.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from agents.hero.policy import HeroPolicy
|
| 7 |
+
from agents.hero.runner import HeroRunner
|
| 8 |
+
from agents.master.env import DMEnvironment
|
| 9 |
+
from agents.master.interface import InterfaceAdapter, StrictCliInterfaceAdapter
|
| 10 |
+
from agents.master.policy import DMRepairContext, DungeonMasterPolicy, DungeonMasterPolicyError
|
| 11 |
+
from agents.master.schema import DMObservation, DMRewardBreakdown, WorldDefinition
|
| 12 |
+
from agents.master.snapshots import LiveObserver, LiveSnapshotWriter
|
| 13 |
+
|
| 14 |
+
from .schema import (
|
| 15 |
+
ClosedLoopAggregateReport,
|
| 16 |
+
ClosedLoopEpisodeArtifacts,
|
| 17 |
+
ClosedLoopEpisodeRecord,
|
| 18 |
+
ClosedLoopEpisodeSummary,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
DEFAULT_CLOSED_LOOP_ROOT = Path(__file__).resolve().parents[2] / ".play_runs" / "closed_loop"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ClosedLoopRunner:
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
*,
|
| 28 |
+
dm_env: DMEnvironment,
|
| 29 |
+
dm_policy: DungeonMasterPolicy,
|
| 30 |
+
hero_policy: HeroPolicy,
|
| 31 |
+
artifacts_root: Path | None = None,
|
| 32 |
+
live_dir: Path | None = None,
|
| 33 |
+
max_dm_repair_attempts: int = 2,
|
| 34 |
+
hero_runner_kwargs: dict[str, object] | None = None,
|
| 35 |
+
hero_interface_adapter: InterfaceAdapter | None = None,
|
| 36 |
+
) -> None:
|
| 37 |
+
self.dm_env = dm_env
|
| 38 |
+
self.dm_policy = dm_policy
|
| 39 |
+
self.hero_policy = hero_policy
|
| 40 |
+
self.artifacts_root = artifacts_root or DEFAULT_CLOSED_LOOP_ROOT
|
| 41 |
+
self.live_dir = live_dir
|
| 42 |
+
self.max_dm_repair_attempts = max_dm_repair_attempts
|
| 43 |
+
self.hero_runner_kwargs = hero_runner_kwargs or {"max_game_steps": 40, "max_tool_calls": 80}
|
| 44 |
+
self.hero_interface_adapter = hero_interface_adapter or StrictCliInterfaceAdapter()
|
| 45 |
+
|
| 46 |
+
def run_episode(
|
| 47 |
+
self,
|
| 48 |
+
*,
|
| 49 |
+
seed: int | None = None,
|
| 50 |
+
target_ratio: float | None = None,
|
| 51 |
+
live: bool = False,
|
| 52 |
+
) -> ClosedLoopEpisodeRecord:
|
| 53 |
+
self.dm_env.reset(seed=seed, difficulty_hint=target_ratio)
|
| 54 |
+
episode_id = self.dm_env.state.episode_id
|
| 55 |
+
if episode_id is None:
|
| 56 |
+
raise RuntimeError("DM environment did not assign an episode id.")
|
| 57 |
+
episode_dir = self.artifacts_root / episode_id
|
| 58 |
+
episode_dir.mkdir(parents=True, exist_ok=True)
|
| 59 |
+
artifacts = ClosedLoopEpisodeArtifacts.from_episode_dir(episode_dir)
|
| 60 |
+
observer = self._observer(live)
|
| 61 |
+
|
| 62 |
+
world: WorldDefinition | None = None
|
| 63 |
+
errors: list[str] = []
|
| 64 |
+
compile_attempts = 0
|
| 65 |
+
repair_context: DMRepairContext | None = None
|
| 66 |
+
previous_candidate_json: str | None = None
|
| 67 |
+
attempt_rows: list[dict[str, object]] = []
|
| 68 |
+
|
| 69 |
+
for attempt in range(1, self.max_dm_repair_attempts + 2):
|
| 70 |
+
compile_attempts = attempt
|
| 71 |
+
try:
|
| 72 |
+
candidate = self.dm_policy.generate_world(
|
| 73 |
+
target_ratio=self.dm_env.state.target_ratio,
|
| 74 |
+
repair_context=repair_context,
|
| 75 |
+
)
|
| 76 |
+
previous_candidate_json = candidate.model_dump_json(indent=2)
|
| 77 |
+
self._write_json(Path(artifacts.world_definition_path), previous_candidate_json)
|
| 78 |
+
self.dm_env.compile_world(candidate, episode_id=episode_id)
|
| 79 |
+
world = candidate
|
| 80 |
+
attempt_rows.append(
|
| 81 |
+
{
|
| 82 |
+
"attempt_number": attempt,
|
| 83 |
+
"status": "compiled",
|
| 84 |
+
"world_title": candidate.meta.title,
|
| 85 |
+
"difficulty_target": candidate.meta.difficulty_target,
|
| 86 |
+
}
|
| 87 |
+
)
|
| 88 |
+
break
|
| 89 |
+
except Exception as exc:
|
| 90 |
+
normalized_error = self._normalize_error(exc)
|
| 91 |
+
errors.append(normalized_error)
|
| 92 |
+
attempt_rows.append(
|
| 93 |
+
{
|
| 94 |
+
"attempt_number": attempt,
|
| 95 |
+
"status": "failed",
|
| 96 |
+
"error": normalized_error,
|
| 97 |
+
}
|
| 98 |
+
)
|
| 99 |
+
repair_context = DMRepairContext(
|
| 100 |
+
attempt_number=attempt,
|
| 101 |
+
error_message=normalized_error,
|
| 102 |
+
previous_candidate_json=previous_candidate_json,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
self._write_jsonl(Path(artifacts.world_generation_attempts_path), attempt_rows)
|
| 106 |
+
|
| 107 |
+
if world is None:
|
| 108 |
+
observation = self._compile_failure_observation(errors[-1] if errors else "world compilation failed")
|
| 109 |
+
record = ClosedLoopEpisodeRecord(
|
| 110 |
+
episode_id=episode_id,
|
| 111 |
+
status="compile_failed",
|
| 112 |
+
target_ratio=self.dm_env.state.target_ratio,
|
| 113 |
+
compile_attempts=compile_attempts,
|
| 114 |
+
dm_repair_errors=errors,
|
| 115 |
+
world_definition=None,
|
| 116 |
+
declared_difficulty_target=None,
|
| 117 |
+
difficulty_target_matches_target_ratio=None,
|
| 118 |
+
observation=observation,
|
| 119 |
+
artifacts=artifacts,
|
| 120 |
+
)
|
| 121 |
+
self._persist_record(record)
|
| 122 |
+
self._write_jsonl(Path(artifacts.hero_trace_path), [])
|
| 123 |
+
self._write_jsonl(Path(artifacts.transcript_path), [])
|
| 124 |
+
return record
|
| 125 |
+
|
| 126 |
+
hero_runner = HeroRunner(policy=self.hero_policy, **self.hero_runner_kwargs)
|
| 127 |
+
previous_adapter = self.dm_env.interface_adapter
|
| 128 |
+
self.dm_env.interface_adapter = self.hero_interface_adapter
|
| 129 |
+
try:
|
| 130 |
+
result = self.dm_env.step(world, runner=hero_runner, observer=observer)
|
| 131 |
+
finally:
|
| 132 |
+
self.dm_env.interface_adapter = previous_adapter
|
| 133 |
+
observation = result.observation
|
| 134 |
+
status = "policy_error" if hero_runner.last_error else ("complete" if observation.player_won else "failed")
|
| 135 |
+
record = ClosedLoopEpisodeRecord(
|
| 136 |
+
episode_id=episode_id,
|
| 137 |
+
status=status,
|
| 138 |
+
target_ratio=self.dm_env.state.target_ratio,
|
| 139 |
+
compile_attempts=compile_attempts,
|
| 140 |
+
dm_repair_errors=errors,
|
| 141 |
+
hero_policy_error=hero_runner.last_error,
|
| 142 |
+
hero_episode_stats=hero_runner.episode_stats,
|
| 143 |
+
world_definition=world,
|
| 144 |
+
declared_difficulty_target=world.meta.difficulty_target,
|
| 145 |
+
difficulty_target_matches_target_ratio=(world.meta.difficulty_target == self.dm_env.state.target_ratio),
|
| 146 |
+
observation=observation,
|
| 147 |
+
artifacts=artifacts,
|
| 148 |
+
)
|
| 149 |
+
self._persist_record(record)
|
| 150 |
+
self._write_jsonl(
|
| 151 |
+
Path(artifacts.hero_trace_path),
|
| 152 |
+
[event.model_dump(mode="json") for event in self.hero_policy.trace_events],
|
| 153 |
+
)
|
| 154 |
+
self._write_jsonl(
|
| 155 |
+
Path(artifacts.transcript_path),
|
| 156 |
+
[turn.model_dump(mode="json") for turn in observation.episode_transcript],
|
| 157 |
+
)
|
| 158 |
+
return record
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
def summary(record: ClosedLoopEpisodeRecord) -> ClosedLoopEpisodeSummary:
|
| 162 |
+
return ClosedLoopEpisodeSummary(
|
| 163 |
+
episode_id=record.episode_id,
|
| 164 |
+
status=record.status,
|
| 165 |
+
reward=record.observation.reward,
|
| 166 |
+
player_won=record.observation.player_won,
|
| 167 |
+
ratio=record.observation.ratio,
|
| 168 |
+
compile_error=record.observation.compile_error,
|
| 169 |
+
hero_policy_error=record.hero_policy_error,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
@staticmethod
|
| 173 |
+
def aggregate(records: list[ClosedLoopEpisodeRecord]) -> ClosedLoopAggregateReport:
|
| 174 |
+
episodes = len(records)
|
| 175 |
+
dense_returns = [
|
| 176 |
+
record.hero_episode_stats.dense_return
|
| 177 |
+
for record in records
|
| 178 |
+
if record.hero_episode_stats is not None
|
| 179 |
+
]
|
| 180 |
+
invalid_penalties = [
|
| 181 |
+
record.hero_episode_stats.invalid_action_penalty_total
|
| 182 |
+
for record in records
|
| 183 |
+
if record.hero_episode_stats is not None
|
| 184 |
+
]
|
| 185 |
+
repeat_penalties = [
|
| 186 |
+
record.hero_episode_stats.repeat_noop_penalty_total
|
| 187 |
+
for record in records
|
| 188 |
+
if record.hero_episode_stats is not None
|
| 189 |
+
]
|
| 190 |
+
return ClosedLoopAggregateReport(
|
| 191 |
+
episodes=episodes,
|
| 192 |
+
compile_valid_rate=_rate(sum(record.status != "compile_failed" for record in records), episodes),
|
| 193 |
+
policy_error_rate=_rate(sum(record.status == "policy_error" for record in records), episodes),
|
| 194 |
+
playable_rate=_rate(sum(record.world_definition is not None for record in records), episodes),
|
| 195 |
+
solve_rate=_rate(sum(record.status == "complete" for record in records), episodes),
|
| 196 |
+
mean_dense_return=_mean(dense_returns),
|
| 197 |
+
mean_invalid_action_penalty=_mean(invalid_penalties),
|
| 198 |
+
mean_repeat_noop_penalty=_mean(repeat_penalties),
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def _compile_failure_observation(self, error: str) -> DMObservation:
|
| 202 |
+
breakdown = DMRewardBreakdown(
|
| 203 |
+
reward_mode="compile_failure_penalty",
|
| 204 |
+
player_won=False,
|
| 205 |
+
target_ratio=self.dm_env.state.target_ratio,
|
| 206 |
+
quality_score=0.0,
|
| 207 |
+
reward=0.0,
|
| 208 |
+
)
|
| 209 |
+
return DMObservation(
|
| 210 |
+
player_won=False,
|
| 211 |
+
compile_error=error,
|
| 212 |
+
reward=0.0,
|
| 213 |
+
done=True,
|
| 214 |
+
reward_breakdown=breakdown,
|
| 215 |
+
target_ratio_used=self.dm_env.state.target_ratio,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
def _observer(self, live: bool) -> LiveObserver | None:
|
| 219 |
+
if not live:
|
| 220 |
+
return None
|
| 221 |
+
return LiveSnapshotWriter(live_dir=self.live_dir, runner_name="hero_llm")
|
| 222 |
+
|
| 223 |
+
def _persist_record(self, record: ClosedLoopEpisodeRecord) -> None:
|
| 224 |
+
self._write_json(Path(record.artifacts.run_record_path), record.model_dump_json(indent=2))
|
| 225 |
+
|
| 226 |
+
@staticmethod
|
| 227 |
+
def _write_json(path: Path, payload: str) -> None:
|
| 228 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 229 |
+
path.write_text(payload + "\n", encoding="utf-8")
|
| 230 |
+
|
| 231 |
+
@staticmethod
|
| 232 |
+
def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None:
|
| 233 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 234 |
+
payload = "".join(json.dumps(row) + "\n" for row in rows)
|
| 235 |
+
path.write_text(payload, encoding="utf-8")
|
| 236 |
+
|
| 237 |
+
@staticmethod
|
| 238 |
+
def _normalize_error(exc: Exception) -> str:
|
| 239 |
+
if isinstance(exc, DungeonMasterPolicyError):
|
| 240 |
+
return str(exc)
|
| 241 |
+
return " ".join(str(exc).split()) or exc.__class__.__name__
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def _mean(values: list[float]) -> float:
|
| 245 |
+
if not values:
|
| 246 |
+
return 0.0
|
| 247 |
+
return sum(values) / len(values)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _rate(count: int, total: int) -> float:
|
| 251 |
+
if total <= 0:
|
| 252 |
+
return 0.0
|
| 253 |
+
return count / total
|
agents/loop/schema.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Literal
|
| 5 |
+
|
| 6 |
+
from agents.hero.schema import HeroEpisodeStats
|
| 7 |
+
from agents.master.schema import DMObservation, WorldDefinition
|
| 8 |
+
from agents.shared.model_schema import StrictModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ClosedLoopEpisodeArtifacts(StrictModel):
|
| 12 |
+
episode_dir: str
|
| 13 |
+
world_generation_attempts_path: str
|
| 14 |
+
world_definition_path: str
|
| 15 |
+
run_record_path: str
|
| 16 |
+
hero_trace_path: str
|
| 17 |
+
transcript_path: str
|
| 18 |
+
|
| 19 |
+
@classmethod
|
| 20 |
+
def from_episode_dir(cls, episode_dir: Path) -> "ClosedLoopEpisodeArtifacts":
|
| 21 |
+
return cls(
|
| 22 |
+
episode_dir=str(episode_dir),
|
| 23 |
+
world_generation_attempts_path=str(episode_dir / "world_generation_attempts.jsonl"),
|
| 24 |
+
world_definition_path=str(episode_dir / "world_definition.json"),
|
| 25 |
+
run_record_path=str(episode_dir / "run_record.json"),
|
| 26 |
+
hero_trace_path=str(episode_dir / "hero_trace.jsonl"),
|
| 27 |
+
transcript_path=str(episode_dir / "transcript.jsonl"),
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ClosedLoopEpisodeRecord(StrictModel):
|
| 32 |
+
episode_id: str
|
| 33 |
+
status: Literal["complete", "failed", "compile_failed", "policy_error"]
|
| 34 |
+
target_ratio: float
|
| 35 |
+
compile_attempts: int
|
| 36 |
+
dm_repair_errors: list[str]
|
| 37 |
+
hero_policy_error: str | None = None
|
| 38 |
+
hero_episode_stats: HeroEpisodeStats | None = None
|
| 39 |
+
declared_difficulty_target: float | None = None
|
| 40 |
+
difficulty_target_matches_target_ratio: bool | None = None
|
| 41 |
+
world_definition: WorldDefinition | None = None
|
| 42 |
+
observation: DMObservation
|
| 43 |
+
artifacts: ClosedLoopEpisodeArtifacts
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ClosedLoopEpisodeSummary(StrictModel):
|
| 47 |
+
episode_id: str
|
| 48 |
+
status: str
|
| 49 |
+
reward: float | None = None
|
| 50 |
+
player_won: bool | None = None
|
| 51 |
+
ratio: float | None = None
|
| 52 |
+
compile_error: str | None = None
|
| 53 |
+
hero_policy_error: str | None = None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class ClosedLoopAggregateReport(StrictModel):
|
| 57 |
+
episodes: int
|
| 58 |
+
compile_valid_rate: float
|
| 59 |
+
policy_error_rate: float
|
| 60 |
+
playable_rate: float
|
| 61 |
+
solve_rate: float
|
| 62 |
+
mean_dense_return: float
|
| 63 |
+
mean_invalid_action_penalty: float
|
| 64 |
+
mean_repeat_noop_penalty: float
|
agents/master/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DM environment source package."""
|
| 2 |
+
|
| 3 |
+
from .policy import (
|
| 4 |
+
DMRepairContext,
|
| 5 |
+
DungeonMasterLLMPolicy,
|
| 6 |
+
DungeonMasterPolicy,
|
| 7 |
+
DungeonMasterPolicyError,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"DMRepairContext",
|
| 12 |
+
"DungeonMasterLLMPolicy",
|
| 13 |
+
"DungeonMasterPolicy",
|
| 14 |
+
"DungeonMasterPolicyError",
|
| 15 |
+
]
|
agents/master/__main__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .main import main
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
if __name__ == "__main__":
|
| 5 |
+
raise SystemExit(main())
|
agents/master/base.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
import re
|
| 5 |
+
import warnings
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
MAX_NODES = 40
|
| 10 |
+
MAX_ITEMS = 32
|
| 11 |
+
MAX_QUEST_STEPS = 64
|
| 12 |
+
MIN_NODES = 5
|
| 13 |
+
MIN_QUEST_STEPS = 2
|
| 14 |
+
MIN_CLUES = 3
|
| 15 |
+
MAX_CLUES = 5
|
| 16 |
+
TARGET_RATIO = 1.5
|
| 17 |
+
TARGET_RATIO_SIGMA = 0.4
|
| 18 |
+
MAX_STEP_MULTIPLIER = 5
|
| 19 |
+
INVENTORY_ID = "__inventory__"
|
| 20 |
+
STORED_ID = "__stored__"
|
| 21 |
+
ROOT_DIR = Path(__file__).resolve().parents[2]
|
| 22 |
+
ARTIFACTS_ROOT = ROOT_DIR / ".artifacts" / "dm_env"
|
| 23 |
+
CUSTOM_LOGIC_DIR = ROOT_DIR / "textworld_data" / "dnd" / "logic"
|
| 24 |
+
CUSTOM_GRAMMAR_DIR = ROOT_DIR / "textworld_data" / "dnd" / "text_grammars"
|
| 25 |
+
SUPPORTED_DIRECTIONS = ("north", "south", "east", "west", "up", "down", "in", "out")
|
| 26 |
+
OPPOSITE_DIRECTION = {
|
| 27 |
+
"north": "south",
|
| 28 |
+
"south": "north",
|
| 29 |
+
"east": "west",
|
| 30 |
+
"west": "east",
|
| 31 |
+
"up": "down",
|
| 32 |
+
"down": "up",
|
| 33 |
+
"in": "out",
|
| 34 |
+
"out": "in",
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
GO_RE = re.compile(r"^go\((?P<target>[a-z0-9_]+)\)$")
|
| 38 |
+
OPEN_RE = re.compile(r"^open\((?P<target>[a-z0-9_]+)\)$")
|
| 39 |
+
UNLOCK_RE = re.compile(r"^unlock\((?P<door>[a-z0-9_]+),(?P<key>[a-z0-9_]+)\)$")
|
| 40 |
+
TAKE_RE = re.compile(r"^take\((?P<item>[a-z0-9_]+),(?P<source>[a-z0-9_]+)\)$")
|
| 41 |
+
READ_RE = re.compile(r"^read\((?P<target>[a-z0-9_]+)\)$")
|
| 42 |
+
USE_RE = re.compile(r"^use\((?P<item>[a-z0-9_]+),(?P<target>[a-z0-9_]+)\)$")
|
| 43 |
+
COMBINE_RE = re.compile(r"^combine\((?P<item_a>[a-z0-9_]+),(?P<item_b>[a-z0-9_]+)\)$")
|
| 44 |
+
GIVE_RE = re.compile(r"^give\((?P<item>[a-z0-9_]+),(?P<npc>[a-z0-9_]+)\)$")
|
| 45 |
+
TALK_RE = re.compile(r"^talk\((?P<target>[a-z0-9_]+)\)$")
|
| 46 |
+
SUBMIT_RE = re.compile(r"^submit\((?P<quote>[\"'])(?P<answer>.+)(?P=quote)\)$")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DMCompileError(RuntimeError):
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class DMInterfaceError(RuntimeError):
|
| 54 |
+
pass
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@contextmanager
|
| 58 |
+
def suppress_unsupported_game_warning():
|
| 59 |
+
with warnings.catch_warnings():
|
| 60 |
+
warnings.filterwarnings(
|
| 61 |
+
"ignore",
|
| 62 |
+
message=r"Game '.*' is not fully supported\..*",
|
| 63 |
+
category=Warning,
|
| 64 |
+
)
|
| 65 |
+
yield
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def normalize_snake_id(value: str, kind: str) -> str:
|
| 69 |
+
if not re.fullmatch(r"[a-z][a-z0-9_]*", value):
|
| 70 |
+
raise DMCompileError(f"{kind} '{value}' must be snake_case.")
|
| 71 |
+
return value
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def parser_safe_text(value: str) -> str:
|
| 75 |
+
collapsed = re.sub(r"[^A-Za-z0-9 ]+", " ", value).strip().lower()
|
| 76 |
+
collapsed = re.sub(r"\s+", " ", collapsed)
|
| 77 |
+
if not collapsed:
|
| 78 |
+
raise DMCompileError(f"Unable to derive a parser-safe name from '{value}'.")
|
| 79 |
+
return collapsed
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def normalize_answer_text(value: str) -> str:
|
| 83 |
+
collapsed = re.sub(r"[^A-Za-z0-9 ]+", " ", value).strip().lower()
|
| 84 |
+
return re.sub(r"\s+", " ", collapsed)
|
agents/master/build.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import uuid
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from textworld.generator import GameMaker, GameOptions, compile_game
|
| 9 |
+
from textworld.generator.data import KnowledgeBase
|
| 10 |
+
|
| 11 |
+
from .base import ARTIFACTS_ROOT, DMCompileError, parser_safe_text
|
| 12 |
+
from .check import validate_and_normalize
|
| 13 |
+
from .graph import (
|
| 14 |
+
door_room_mapping,
|
| 15 |
+
hidden_readable_ids,
|
| 16 |
+
npc_trade_mapping,
|
| 17 |
+
produced_item_ids,
|
| 18 |
+
readable_clue_mapping,
|
| 19 |
+
recipe_mapping,
|
| 20 |
+
use_effect_mapping,
|
| 21 |
+
)
|
| 22 |
+
from .logic import build_grammar_dir, build_logic_dir, solver_policy, submit_command_text, write_artifacts
|
| 23 |
+
from .quest import parse_quest_action, simulate_walkthrough, topological_linearize
|
| 24 |
+
from .schema import CompiledWorld, WorldDefinition
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class WorldCompiler:
|
| 28 |
+
def __init__(self, artifacts_root: Path | None = None) -> None:
|
| 29 |
+
self.artifacts_root = artifacts_root or ARTIFACTS_ROOT
|
| 30 |
+
|
| 31 |
+
def compile(self, world_input: WorldDefinition | dict[str, Any], episode_id: str | None = None) -> CompiledWorld:
|
| 32 |
+
world = validate_and_normalize(world_input)
|
| 33 |
+
episode_id = episode_id or uuid.uuid4().hex[:12]
|
| 34 |
+
artifacts_dir = self.artifacts_root / episode_id
|
| 35 |
+
artifacts_dir.mkdir(parents=True, exist_ok=True)
|
| 36 |
+
parsed_steps = [parse_quest_action(step.action) for step in topological_linearize(world.quest_chain)]
|
| 37 |
+
entity_names = self._assign_command_names(world)
|
| 38 |
+
|
| 39 |
+
options = GameOptions()
|
| 40 |
+
options.kb = KnowledgeBase.load(
|
| 41 |
+
logic_path=str(build_logic_dir(artifacts_dir, world)),
|
| 42 |
+
grammar_path=str(build_grammar_dir(artifacts_dir)),
|
| 43 |
+
)
|
| 44 |
+
options.path = str(artifacts_dir / "game.z8")
|
| 45 |
+
options.force_recompile = True
|
| 46 |
+
maker = GameMaker(options=options)
|
| 47 |
+
|
| 48 |
+
rooms, entities = self._build_entities(maker, world, entity_names)
|
| 49 |
+
maker.set_player(rooms[world.meta.start_node_id])
|
| 50 |
+
self._compile_edges(maker, world, rooms, entities)
|
| 51 |
+
self._compile_clue_sources(maker, world, entities)
|
| 52 |
+
self._compile_fixtures(maker, world, entities)
|
| 53 |
+
self._compile_npcs(maker, world, entities)
|
| 54 |
+
self._compile_recipes(maker, world, entities)
|
| 55 |
+
|
| 56 |
+
guardian = entities[world.meta.win_condition.target_npc_id]
|
| 57 |
+
answer = maker.new(type="answer", name="final answer token")
|
| 58 |
+
maker.nowhere.append(answer)
|
| 59 |
+
entities["__answer__"] = answer
|
| 60 |
+
maker.add_fact("guardian", guardian)
|
| 61 |
+
maker.add_fact("correct", answer, guardian)
|
| 62 |
+
|
| 63 |
+
walkthrough_commands = simulate_walkthrough(world, parsed_steps, entity_names)
|
| 64 |
+
game = maker.build()
|
| 65 |
+
game.objective = (
|
| 66 |
+
f"Explore {world.meta.title}, manipulate the dungeon's tools, gather every clue, "
|
| 67 |
+
f"speak to {entities[world.meta.win_condition.target_npc_id].name}, and submit the answer."
|
| 68 |
+
)
|
| 69 |
+
game.metadata.update(
|
| 70 |
+
{"episode_id": episode_id, "dm_title": world.meta.title, "start_node_id": world.meta.start_node_id}
|
| 71 |
+
)
|
| 72 |
+
compile_game(game, options)
|
| 73 |
+
write_artifacts(artifacts_dir, world, walkthrough_commands)
|
| 74 |
+
policy = solver_policy(str(options.path))
|
| 75 |
+
if not policy:
|
| 76 |
+
policy = list(walkthrough_commands)
|
| 77 |
+
return self._compiled_world(
|
| 78 |
+
episode_id,
|
| 79 |
+
artifacts_dir,
|
| 80 |
+
Path(options.path),
|
| 81 |
+
world,
|
| 82 |
+
entity_names,
|
| 83 |
+
walkthrough_commands,
|
| 84 |
+
policy,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
def _build_entities(
|
| 88 |
+
self,
|
| 89 |
+
maker: GameMaker,
|
| 90 |
+
world: WorldDefinition,
|
| 91 |
+
entity_names: dict[str, str],
|
| 92 |
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
| 93 |
+
rooms = {
|
| 94 |
+
node.id: maker.new(type="r", name=entity_names[node.id], desc=node.description)
|
| 95 |
+
for node in world.nodes
|
| 96 |
+
if node.type in {"location", "junction"}
|
| 97 |
+
}
|
| 98 |
+
entities: dict[str, Any] = {}
|
| 99 |
+
hidden_readables = hidden_readable_ids(world)
|
| 100 |
+
recipe_outputs = {recipe.output_item_id for recipe in world.recipes}
|
| 101 |
+
produced_items = produced_item_ids(world)
|
| 102 |
+
|
| 103 |
+
for node in world.nodes:
|
| 104 |
+
if node.type in {"location", "junction"}:
|
| 105 |
+
continue
|
| 106 |
+
entity = self._make_node_entity(maker, node, entity_names[node.id])
|
| 107 |
+
entities[node.id] = entity
|
| 108 |
+
if node.type == "door":
|
| 109 |
+
maker.nowhere.append(entity)
|
| 110 |
+
elif node.type == "readable" and node.id in hidden_readables:
|
| 111 |
+
maker.nowhere.append(entity)
|
| 112 |
+
maker.add_fact("hidden_readable", entity)
|
| 113 |
+
else:
|
| 114 |
+
rooms[node.parent_id].add(entity)
|
| 115 |
+
|
| 116 |
+
for item in world.items:
|
| 117 |
+
item_type = "k" if item.subtype == "key" else "o"
|
| 118 |
+
entity = maker.new(type=item_type, name=entity_names[item.id], desc=item.description)
|
| 119 |
+
entities[item.id] = entity
|
| 120 |
+
if item.id in produced_items:
|
| 121 |
+
maker.nowhere.append(entity)
|
| 122 |
+
if item.id in recipe_outputs:
|
| 123 |
+
maker.add_fact("fresh", entity)
|
| 124 |
+
else:
|
| 125 |
+
maker.add_fact("stored_item", entity)
|
| 126 |
+
continue
|
| 127 |
+
holder = item.start_node_id
|
| 128 |
+
if holder is None:
|
| 129 |
+
raise DMCompileError(f"Placed item '{item.id}' is missing start_node_id.")
|
| 130 |
+
if holder in rooms:
|
| 131 |
+
rooms[holder].add(entity)
|
| 132 |
+
else:
|
| 133 |
+
entities[holder].add(entity)
|
| 134 |
+
|
| 135 |
+
return rooms, entities
|
| 136 |
+
|
| 137 |
+
@staticmethod
|
| 138 |
+
def _make_node_entity(maker: GameMaker, node: object, name: str) -> Any:
|
| 139 |
+
if node.type == "container":
|
| 140 |
+
entity = maker.new(type="c", name=name, desc=node.description)
|
| 141 |
+
entity.add_property("open" if node.open else "locked" if node.locked else "closed")
|
| 142 |
+
return entity
|
| 143 |
+
if node.type == "door":
|
| 144 |
+
entity = maker.new(type="d", name=name, desc=node.description)
|
| 145 |
+
entity.add_property("open" if node.open else "locked" if node.locked else "closed")
|
| 146 |
+
return entity
|
| 147 |
+
if node.type == "readable":
|
| 148 |
+
return maker.new(type="readable", name=name, desc=node.description)
|
| 149 |
+
if node.type == "fixture":
|
| 150 |
+
return maker.new(type="fixture", name=name, desc=node.description)
|
| 151 |
+
if node.type == "npc":
|
| 152 |
+
return maker.new(type="npc", name=name, desc=node.description)
|
| 153 |
+
raise DMCompileError(f"Unsupported node type '{node.type}'.")
|
| 154 |
+
|
| 155 |
+
def _compile_clue_sources(
|
| 156 |
+
self,
|
| 157 |
+
maker: GameMaker,
|
| 158 |
+
world: WorldDefinition,
|
| 159 |
+
entities: dict[str, Any],
|
| 160 |
+
) -> None:
|
| 161 |
+
hidden_readables = hidden_readable_ids(world)
|
| 162 |
+
for node in world.nodes:
|
| 163 |
+
if node.type != "readable":
|
| 164 |
+
continue
|
| 165 |
+
readable = entities[node.id]
|
| 166 |
+
if node.requires_item_id:
|
| 167 |
+
maker.add_fact("read_requires", readable, entities[node.requires_item_id])
|
| 168 |
+
maker.add_fact("read_consumes_use" if node.consumes_item else "read_keeps_use", readable)
|
| 169 |
+
else:
|
| 170 |
+
maker.add_fact("free_read", readable)
|
| 171 |
+
if node.id in hidden_readables:
|
| 172 |
+
continue
|
| 173 |
+
|
| 174 |
+
def _compile_fixtures(self, maker: GameMaker, world: WorldDefinition, entities: dict[str, Any]) -> None:
|
| 175 |
+
for node in world.nodes:
|
| 176 |
+
if node.type != "fixture":
|
| 177 |
+
continue
|
| 178 |
+
fixture = entities[node.id]
|
| 179 |
+
maker.add_fact("fixture_requires", fixture, entities[node.requires_item_id])
|
| 180 |
+
maker.add_fact("sealed", fixture)
|
| 181 |
+
maker.add_fact("fixture_consumes_use" if node.consumes_item else "fixture_keeps_use", fixture)
|
| 182 |
+
if node.reveals_item_id:
|
| 183 |
+
maker.add_fact("reveals_item", fixture, entities[node.reveals_item_id])
|
| 184 |
+
if node.reveals_readable_id:
|
| 185 |
+
maker.add_fact("reveals_readable", fixture, entities[node.reveals_readable_id])
|
| 186 |
+
|
| 187 |
+
def _compile_npcs(
|
| 188 |
+
self,
|
| 189 |
+
maker: GameMaker,
|
| 190 |
+
world: WorldDefinition,
|
| 191 |
+
entities: dict[str, Any],
|
| 192 |
+
) -> None:
|
| 193 |
+
guardian_id = world.meta.win_condition.target_npc_id
|
| 194 |
+
for node in world.nodes:
|
| 195 |
+
if node.type != "npc":
|
| 196 |
+
continue
|
| 197 |
+
npc = entities[node.id]
|
| 198 |
+
if node.id == guardian_id:
|
| 199 |
+
continue
|
| 200 |
+
maker.add_fact("trade_pending", npc)
|
| 201 |
+
maker.add_fact("trade_requires", npc, entities[node.requires_item_id])
|
| 202 |
+
if node.gives_item_id:
|
| 203 |
+
maker.add_fact("trade_gives_item", npc, entities[node.gives_item_id])
|
| 204 |
+
if node.gives_clue_id:
|
| 205 |
+
maker.add_fact("trade_gives_clue", npc)
|
| 206 |
+
|
| 207 |
+
def _compile_recipes(self, maker: GameMaker, world: WorldDefinition, entities: dict[str, Any]) -> None:
|
| 208 |
+
for recipe in world.recipes:
|
| 209 |
+
a_id, b_id = recipe.input_item_ids
|
| 210 |
+
output = entities[recipe.output_item_id]
|
| 211 |
+
maker.add_fact("combines_with", entities[a_id], entities[b_id], output)
|
| 212 |
+
maker.add_fact("combines_with", entities[b_id], entities[a_id], output)
|
| 213 |
+
|
| 214 |
+
@staticmethod
|
| 215 |
+
def _compile_edges(
|
| 216 |
+
maker: GameMaker,
|
| 217 |
+
world: WorldDefinition,
|
| 218 |
+
rooms: dict[str, Any],
|
| 219 |
+
entities: dict[str, Any],
|
| 220 |
+
) -> None:
|
| 221 |
+
pair_groups: dict[frozenset[str], list[Any]] = defaultdict(list)
|
| 222 |
+
for edge in world.edges:
|
| 223 |
+
pair_groups.setdefault(frozenset({edge.from_node_id, edge.to_node_id}), []).append(edge)
|
| 224 |
+
for edges in pair_groups.values():
|
| 225 |
+
forward, backward = sorted(edges, key=lambda edge: edge.id)
|
| 226 |
+
for edge in (forward, backward):
|
| 227 |
+
maker.add_fact(f"{edge.direction}_of", rooms[edge.to_node_id], rooms[edge.from_node_id])
|
| 228 |
+
if forward.door_node_id:
|
| 229 |
+
door = entities[forward.door_node_id]
|
| 230 |
+
room_a = rooms[forward.from_node_id]
|
| 231 |
+
room_b = rooms[forward.to_node_id]
|
| 232 |
+
maker.add_fact("link", room_a, door, room_b)
|
| 233 |
+
maker.add_fact("link", room_b, door, room_a)
|
| 234 |
+
if forward.required_item_id:
|
| 235 |
+
maker.add_fact("match", entities[forward.required_item_id], door)
|
| 236 |
+
door_is_open = door.has_property("open")
|
| 237 |
+
if door_is_open:
|
| 238 |
+
maker.add_fact("free", room_a, room_b)
|
| 239 |
+
maker.add_fact("free", room_b, room_a)
|
| 240 |
+
else:
|
| 241 |
+
maker.add_fact("free", rooms[forward.from_node_id], rooms[forward.to_node_id])
|
| 242 |
+
maker.add_fact("free", rooms[forward.to_node_id], rooms[forward.from_node_id])
|
| 243 |
+
|
| 244 |
+
def _compiled_world(
|
| 245 |
+
self,
|
| 246 |
+
episode_id: str,
|
| 247 |
+
artifacts_dir: Path,
|
| 248 |
+
game_file: Path,
|
| 249 |
+
world: WorldDefinition,
|
| 250 |
+
entity_names: dict[str, str],
|
| 251 |
+
walkthrough_commands: list[str],
|
| 252 |
+
policy: list[str],
|
| 253 |
+
) -> CompiledWorld:
|
| 254 |
+
node_by_id = {node.id: node for node in world.nodes}
|
| 255 |
+
return CompiledWorld(
|
| 256 |
+
episode_id=episode_id,
|
| 257 |
+
world=world,
|
| 258 |
+
artifacts_dir=artifacts_dir,
|
| 259 |
+
game_file=game_file,
|
| 260 |
+
walkthrough_commands=walkthrough_commands,
|
| 261 |
+
solver_policy=policy,
|
| 262 |
+
correct_answer_normalized=submit_command_text(world).replace("submit ", "", 1),
|
| 263 |
+
correct_submit_command=submit_command_text(world),
|
| 264 |
+
guardian_id=world.meta.win_condition.target_npc_id,
|
| 265 |
+
guardian_room_id=node_by_id[world.meta.win_condition.target_npc_id].parent_id,
|
| 266 |
+
room_name_to_id={
|
| 267 |
+
entity_names[node.id]: node.id for node in world.nodes if node.type in {"location", "junction"}
|
| 268 |
+
},
|
| 269 |
+
node_command_names={node.id: entity_names[node.id] for node in world.nodes},
|
| 270 |
+
item_command_names={item.id: entity_names[item.id] for item in world.items},
|
| 271 |
+
item_start_locations={item.id: item.start_node_id for item in world.items},
|
| 272 |
+
clue_text_by_id={clue.id: clue.text for clue in world.clues},
|
| 273 |
+
readable_clue_by_id=readable_clue_mapping(world),
|
| 274 |
+
npc_trade_map=npc_trade_mapping(world),
|
| 275 |
+
recipe_map=recipe_mapping(world),
|
| 276 |
+
use_effects=use_effect_mapping(world),
|
| 277 |
+
produced_item_ids=produced_item_ids(world),
|
| 278 |
+
room_edges_by_target={(edge.from_node_id, edge.to_node_id): edge for edge in world.edges},
|
| 279 |
+
room_edges_by_direction={(edge.from_node_id, edge.direction): edge for edge in world.edges},
|
| 280 |
+
door_rooms=door_room_mapping(world),
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
@staticmethod
|
| 284 |
+
def _assign_command_names(world: WorldDefinition) -> dict[str, str]:
|
| 285 |
+
names = {node.id: parser_safe_text(node.label) for node in world.nodes}
|
| 286 |
+
names.update({item.id: parser_safe_text(item.label) for item in world.items})
|
| 287 |
+
return names
|
agents/master/check.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict, deque
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from pydantic import ValidationError
|
| 7 |
+
|
| 8 |
+
from .base import (
|
| 9 |
+
DMCompileError,
|
| 10 |
+
MAX_CLUES,
|
| 11 |
+
MAX_ITEMS,
|
| 12 |
+
MAX_NODES,
|
| 13 |
+
MAX_QUEST_STEPS,
|
| 14 |
+
MIN_CLUES,
|
| 15 |
+
MIN_NODES,
|
| 16 |
+
MIN_QUEST_STEPS,
|
| 17 |
+
OPPOSITE_DIRECTION,
|
| 18 |
+
normalize_answer_text,
|
| 19 |
+
normalize_snake_id,
|
| 20 |
+
parser_safe_text,
|
| 21 |
+
)
|
| 22 |
+
from .graph import hidden_readable_ids, produced_item_ids
|
| 23 |
+
from .quest import parse_quest_action, simulate_walkthrough, topological_linearize
|
| 24 |
+
from .schema import (
|
| 25 |
+
CombineAction,
|
| 26 |
+
ContainerNode,
|
| 27 |
+
DoorNode,
|
| 28 |
+
GiveAction,
|
| 29 |
+
NpcNode,
|
| 30 |
+
ReadableNode,
|
| 31 |
+
SubmitAction,
|
| 32 |
+
TakeAction,
|
| 33 |
+
TalkAction,
|
| 34 |
+
UnlockAction,
|
| 35 |
+
UseAction,
|
| 36 |
+
WorldDefinition,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def validate_and_normalize(world_input: WorldDefinition | dict[str, Any]) -> WorldDefinition:
|
| 41 |
+
if isinstance(world_input, dict):
|
| 42 |
+
_reject_legacy_shapes(world_input)
|
| 43 |
+
try:
|
| 44 |
+
world = WorldDefinition.model_validate(world_input)
|
| 45 |
+
except ValidationError as exc: # pragma: no cover - exercised indirectly in compile paths
|
| 46 |
+
raise DMCompileError(str(exc)) from exc
|
| 47 |
+
_validate_ids(world)
|
| 48 |
+
_validate_shape(world)
|
| 49 |
+
_validate_nodes(world)
|
| 50 |
+
_validate_edges(world)
|
| 51 |
+
_validate_items(world)
|
| 52 |
+
_validate_clues(world)
|
| 53 |
+
_validate_visibility(world)
|
| 54 |
+
_validate_answer_leaks(world)
|
| 55 |
+
_validate_guardian_path(world)
|
| 56 |
+
_validate_clue_gates(world)
|
| 57 |
+
_validate_item_usage(world)
|
| 58 |
+
_validate_quest_shape(world)
|
| 59 |
+
return world
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def infer_start_room(world: WorldDefinition) -> str:
|
| 63 |
+
return world.meta.start_node_id
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _reject_legacy_shapes(world_input: dict[str, Any]) -> None:
|
| 67 |
+
for node in world_input.get("nodes", []):
|
| 68 |
+
if node.get("type") == "clue":
|
| 69 |
+
raise DMCompileError("Legacy clue nodes are not supported in v2. Use top-level clues[].")
|
| 70 |
+
if node.get("state", {}).get("npc_dialogue") is not None:
|
| 71 |
+
raise DMCompileError("Legacy npc_dialogue is not supported in v2.")
|
| 72 |
+
for edge in world_input.get("edges", []):
|
| 73 |
+
if edge.get("type") == "conditional_passage":
|
| 74 |
+
raise DMCompileError("conditional_passage is not supported in v2.")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _validate_ids(world: WorldDefinition) -> None:
|
| 78 |
+
global_ids: set[str] = set()
|
| 79 |
+
collections = {
|
| 80 |
+
"node": [node.id for node in world.nodes],
|
| 81 |
+
"item": [item.id for item in world.items],
|
| 82 |
+
"clue": [clue.id for clue in world.clues],
|
| 83 |
+
"recipe": [recipe.id for recipe in world.recipes],
|
| 84 |
+
"quest step": [step.step_id for step in world.quest_chain],
|
| 85 |
+
}
|
| 86 |
+
for kind, values in collections.items():
|
| 87 |
+
seen: set[str] = set()
|
| 88 |
+
for value in values:
|
| 89 |
+
normalize_snake_id(value, kind)
|
| 90 |
+
if value in seen:
|
| 91 |
+
raise DMCompileError(f"Duplicate {kind} id '{value}'.")
|
| 92 |
+
if value in global_ids:
|
| 93 |
+
raise DMCompileError(f"Duplicate world id '{value}' across collections.")
|
| 94 |
+
seen.add(value)
|
| 95 |
+
global_ids.add(value)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _validate_shape(world: WorldDefinition) -> None:
|
| 99 |
+
room_nodes = [node for node in world.nodes if node.type in {"location", "junction"}]
|
| 100 |
+
if len(world.nodes) < MIN_NODES:
|
| 101 |
+
raise DMCompileError(f"Worlds need at least {MIN_NODES} nodes.")
|
| 102 |
+
if len(world.nodes) > MAX_NODES:
|
| 103 |
+
raise DMCompileError(f"Worlds support at most {MAX_NODES} nodes.")
|
| 104 |
+
if len(world.items) > MAX_ITEMS:
|
| 105 |
+
raise DMCompileError(f"Worlds support at most {MAX_ITEMS} items.")
|
| 106 |
+
if len(world.clues) < MIN_CLUES or len(world.clues) > MAX_CLUES:
|
| 107 |
+
raise DMCompileError(f"Worlds must define between {MIN_CLUES} and {MAX_CLUES} clues.")
|
| 108 |
+
if len(world.quest_chain) < MIN_QUEST_STEPS or len(world.quest_chain) > MAX_QUEST_STEPS:
|
| 109 |
+
raise DMCompileError(f"quest_chain must contain between {MIN_QUEST_STEPS} and {MAX_QUEST_STEPS} steps.")
|
| 110 |
+
if world.meta.start_node_id not in {node.id for node in room_nodes}:
|
| 111 |
+
raise DMCompileError("meta.start_node_id must reference a location or junction.")
|
| 112 |
+
if world.meta.win_condition.type != "deduce":
|
| 113 |
+
raise DMCompileError("Only deduce win conditions are supported in v2.")
|
| 114 |
+
if not normalize_answer_text(world.meta.win_condition.answer_string):
|
| 115 |
+
raise DMCompileError("answer_string cannot normalize to an empty command.")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _validate_nodes(world: WorldDefinition) -> None:
|
| 119 |
+
node_by_id = {node.id: node for node in world.nodes}
|
| 120 |
+
item_ids = {item.id for item in world.items}
|
| 121 |
+
clue_ids = {clue.id for clue in world.clues}
|
| 122 |
+
hidden_readables = hidden_readable_ids(world)
|
| 123 |
+
guardian_id = world.meta.win_condition.target_npc_id
|
| 124 |
+
|
| 125 |
+
guardian_seen = False
|
| 126 |
+
for node in world.nodes:
|
| 127 |
+
if node.type in {"location", "junction"}:
|
| 128 |
+
continue
|
| 129 |
+
if node.type == "door":
|
| 130 |
+
_validate_lockable(node, item_ids)
|
| 131 |
+
continue
|
| 132 |
+
parent = node_by_id.get(node.parent_id)
|
| 133 |
+
if parent is None or parent.type not in {"location", "junction"}:
|
| 134 |
+
raise DMCompileError(f"Node '{node.id}' must live in a location or junction.")
|
| 135 |
+
if node.type == "container":
|
| 136 |
+
_validate_lockable(node, item_ids)
|
| 137 |
+
elif node.type == "readable":
|
| 138 |
+
if node.clue_id not in clue_ids:
|
| 139 |
+
raise DMCompileError(f"Readable '{node.id}' references unknown clue '{node.clue_id}'.")
|
| 140 |
+
if node.requires_item_id and node.requires_item_id not in item_ids:
|
| 141 |
+
raise DMCompileError(f"Readable '{node.id}' references unknown item '{node.requires_item_id}'.")
|
| 142 |
+
elif node.type == "fixture":
|
| 143 |
+
if node.requires_item_id not in item_ids:
|
| 144 |
+
raise DMCompileError(f"Fixture '{node.id}' references unknown item '{node.requires_item_id}'.")
|
| 145 |
+
if bool(node.reveals_item_id) == bool(node.reveals_readable_id):
|
| 146 |
+
raise DMCompileError(f"Fixture '{node.id}' must reveal exactly one item or readable.")
|
| 147 |
+
if node.reveals_item_id and node.reveals_item_id not in item_ids:
|
| 148 |
+
raise DMCompileError(f"Fixture '{node.id}' reveals unknown item '{node.reveals_item_id}'.")
|
| 149 |
+
if node.reveals_readable_id and node.reveals_readable_id not in node_by_id:
|
| 150 |
+
raise DMCompileError(f"Fixture '{node.id}' reveals unknown readable '{node.reveals_readable_id}'.")
|
| 151 |
+
if node.reveals_readable_id:
|
| 152 |
+
readable = node_by_id[node.reveals_readable_id]
|
| 153 |
+
if not isinstance(readable, ReadableNode):
|
| 154 |
+
raise DMCompileError(f"Fixture '{node.id}' can only reveal readable nodes.")
|
| 155 |
+
if readable.parent_id != node.parent_id:
|
| 156 |
+
raise DMCompileError(
|
| 157 |
+
f"Fixture '{node.id}' must reveal readable '{readable.id}' in the same room."
|
| 158 |
+
)
|
| 159 |
+
elif node.type == "npc":
|
| 160 |
+
if node.id == guardian_id:
|
| 161 |
+
guardian_seen = True
|
| 162 |
+
if node.requires_item_id or node.gives_item_id or node.gives_clue_id:
|
| 163 |
+
raise DMCompileError("Guardian NPC cannot have trade fields.")
|
| 164 |
+
else:
|
| 165 |
+
if not node.requires_item_id:
|
| 166 |
+
raise DMCompileError(f"NPC '{node.id}' requires requires_item_id in v2.")
|
| 167 |
+
if node.requires_item_id not in item_ids:
|
| 168 |
+
raise DMCompileError(f"NPC '{node.id}' references unknown item '{node.requires_item_id}'.")
|
| 169 |
+
if bool(node.gives_item_id) == bool(node.gives_clue_id):
|
| 170 |
+
raise DMCompileError(
|
| 171 |
+
f"NPC '{node.id}' must define exactly one of gives_item_id or gives_clue_id."
|
| 172 |
+
)
|
| 173 |
+
if node.gives_item_id and node.gives_item_id not in item_ids:
|
| 174 |
+
raise DMCompileError(f"NPC '{node.id}' gives unknown item '{node.gives_item_id}'.")
|
| 175 |
+
if node.gives_clue_id and node.gives_clue_id not in clue_ids:
|
| 176 |
+
raise DMCompileError(f"NPC '{node.id}' gives unknown clue '{node.gives_clue_id}'.")
|
| 177 |
+
else: # pragma: no cover
|
| 178 |
+
raise AssertionError(f"Unhandled node type {node.type}")
|
| 179 |
+
|
| 180 |
+
if not guardian_seen:
|
| 181 |
+
raise DMCompileError(f"Guardian NPC '{guardian_id}' does not exist.")
|
| 182 |
+
for readable_id in hidden_readables:
|
| 183 |
+
readable = node_by_id[readable_id]
|
| 184 |
+
if not isinstance(readable, ReadableNode):
|
| 185 |
+
raise DMCompileError(f"Only readable nodes can be hidden, not '{readable_id}'.")
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _validate_lockable(node: ContainerNode | DoorNode, item_ids: set[str]) -> None:
|
| 189 |
+
if node.open and node.locked:
|
| 190 |
+
raise DMCompileError(f"Lockable node '{node.id}' cannot be both open and locked.")
|
| 191 |
+
if node.locked and not node.lock_key_id:
|
| 192 |
+
raise DMCompileError(f"Lockable node '{node.id}' is locked but has no lock_key_id.")
|
| 193 |
+
if node.lock_key_id and node.lock_key_id not in item_ids:
|
| 194 |
+
raise DMCompileError(f"Lockable node '{node.id}' references unknown key '{node.lock_key_id}'.")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _validate_edges(world: WorldDefinition) -> None:
|
| 198 |
+
room_ids = {node.id for node in world.nodes if node.type in {"location", "junction"}}
|
| 199 |
+
node_by_id = {node.id: node for node in world.nodes}
|
| 200 |
+
item_ids = {item.id for item in world.items}
|
| 201 |
+
pair_groups: dict[frozenset[str], list[Any]] = defaultdict(list)
|
| 202 |
+
graph: dict[str, set[str]] = defaultdict(set)
|
| 203 |
+
direction_map: dict[tuple[str, str], str] = {}
|
| 204 |
+
|
| 205 |
+
for edge in world.edges:
|
| 206 |
+
if edge.from_node_id not in room_ids or edge.to_node_id not in room_ids:
|
| 207 |
+
raise DMCompileError(f"Edge '{edge.id}' must connect location or junction nodes only.")
|
| 208 |
+
if edge.from_node_id == edge.to_node_id:
|
| 209 |
+
raise DMCompileError(f"Edge '{edge.id}' cannot be self-referential.")
|
| 210 |
+
if edge.required_item_id and edge.required_item_id not in item_ids:
|
| 211 |
+
raise DMCompileError(f"Edge '{edge.id}' references unknown item '{edge.required_item_id}'.")
|
| 212 |
+
if edge.required_item_id and edge.required_item_id not in {
|
| 213 |
+
item.id for item in world.items if item.subtype == "key"
|
| 214 |
+
}:
|
| 215 |
+
raise DMCompileError(f"Edge '{edge.id}' must use a key item, not '{edge.required_item_id}'.")
|
| 216 |
+
if edge.type == "locked_passage":
|
| 217 |
+
if not edge.door_node_id:
|
| 218 |
+
raise DMCompileError(f"Locked edge '{edge.id}' requires door_node_id.")
|
| 219 |
+
if not edge.required_item_id:
|
| 220 |
+
raise DMCompileError(f"Locked edge '{edge.id}' requires required_item_id.")
|
| 221 |
+
elif edge.required_item_id is not None:
|
| 222 |
+
raise DMCompileError(f"Only locked_passage edges can reference required_item_id (edge '{edge.id}').")
|
| 223 |
+
if edge.door_node_id:
|
| 224 |
+
door = node_by_id.get(edge.door_node_id)
|
| 225 |
+
if not isinstance(door, DoorNode):
|
| 226 |
+
raise DMCompileError(f"Edge '{edge.id}' references unknown door '{edge.door_node_id}'.")
|
| 227 |
+
if edge.required_item_id and door.lock_key_id != edge.required_item_id:
|
| 228 |
+
raise DMCompileError(f"Edge '{edge.id}' and door '{door.id}' disagree on the key.")
|
| 229 |
+
key = (edge.from_node_id, edge.direction)
|
| 230 |
+
if key in direction_map:
|
| 231 |
+
raise DMCompileError(
|
| 232 |
+
f"Edges '{direction_map[key]}' and '{edge.id}' both leave '{edge.from_node_id}' via '{edge.direction}'."
|
| 233 |
+
)
|
| 234 |
+
direction_map[key] = edge.id
|
| 235 |
+
graph[edge.from_node_id].add(edge.to_node_id)
|
| 236 |
+
pair_groups[frozenset({edge.from_node_id, edge.to_node_id})].append(edge)
|
| 237 |
+
|
| 238 |
+
for pair, edges in pair_groups.items():
|
| 239 |
+
if len(edges) != 2:
|
| 240 |
+
raise DMCompileError(f"Edges between {', '.join(sorted(pair))} must be explicitly bidirectional.")
|
| 241 |
+
a, b = edges
|
| 242 |
+
if OPPOSITE_DIRECTION[a.direction] != b.direction:
|
| 243 |
+
raise DMCompileError(f"Edges '{a.id}' and '{b.id}' must use opposite directions.")
|
| 244 |
+
if a.type != b.type or a.required_item_id != b.required_item_id or a.door_node_id != b.door_node_id:
|
| 245 |
+
raise DMCompileError(f"Edge pair '{a.id}'/'{b.id}' must agree on type, key, and door.")
|
| 246 |
+
|
| 247 |
+
reachable = _reachable_rooms(graph, world.meta.start_node_id)
|
| 248 |
+
if reachable != room_ids:
|
| 249 |
+
raise DMCompileError(f"Some rooms are unreachable from the start node: {sorted(room_ids - reachable)}")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _validate_items(world: WorldDefinition) -> None:
|
| 253 |
+
node_by_id = {node.id: node for node in world.nodes}
|
| 254 |
+
produced = produced_item_ids(world)
|
| 255 |
+
recipe_outputs: set[str] = set()
|
| 256 |
+
recipe_inputs: set[frozenset[str]] = set()
|
| 257 |
+
for recipe in world.recipes:
|
| 258 |
+
inputs = frozenset(recipe.input_item_ids)
|
| 259 |
+
if len(inputs) != 2:
|
| 260 |
+
raise DMCompileError(f"Recipe '{recipe.id}' must have exactly two distinct input items.")
|
| 261 |
+
if inputs in recipe_inputs:
|
| 262 |
+
raise DMCompileError(f"Duplicate recipe inputs in '{recipe.id}'.")
|
| 263 |
+
recipe_inputs.add(inputs)
|
| 264 |
+
if recipe.output_item_id in recipe_outputs:
|
| 265 |
+
raise DMCompileError(f"Item '{recipe.output_item_id}' is produced by multiple recipes.")
|
| 266 |
+
recipe_outputs.add(recipe.output_item_id)
|
| 267 |
+
|
| 268 |
+
for item in world.items:
|
| 269 |
+
if item.id in produced and item.start_node_id is not None:
|
| 270 |
+
raise DMCompileError(f"Produced item '{item.id}' must not be initially placed.")
|
| 271 |
+
if item.id not in produced and item.start_node_id is None:
|
| 272 |
+
raise DMCompileError(f"Placed item '{item.id}' requires start_node_id.")
|
| 273 |
+
if item.start_node_id is None:
|
| 274 |
+
continue
|
| 275 |
+
holder = node_by_id.get(item.start_node_id)
|
| 276 |
+
if holder is None:
|
| 277 |
+
raise DMCompileError(f"Item '{item.id}' starts in unknown node '{item.start_node_id}'.")
|
| 278 |
+
if holder.type not in {"location", "junction", "container"}:
|
| 279 |
+
raise DMCompileError(f"Item '{item.id}' must start in a room or container.")
|
| 280 |
+
if item.subtype not in {"key", "puzzle"}:
|
| 281 |
+
raise DMCompileError(f"Item '{item.id}' uses unsupported subtype '{item.subtype}'.")
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _validate_clues(world: WorldDefinition) -> None:
|
| 285 |
+
clue_sources: dict[str, list[str]] = defaultdict(list)
|
| 286 |
+
for node in world.nodes:
|
| 287 |
+
if isinstance(node, ReadableNode):
|
| 288 |
+
clue_sources[node.clue_id].append(node.id)
|
| 289 |
+
elif isinstance(node, NpcNode) and node.gives_clue_id:
|
| 290 |
+
clue_sources[node.gives_clue_id].append(node.id)
|
| 291 |
+
|
| 292 |
+
clue_ids = {clue.id for clue in world.clues}
|
| 293 |
+
if set(clue_sources) != clue_ids:
|
| 294 |
+
missing = sorted(clue_ids - set(clue_sources))
|
| 295 |
+
raise DMCompileError(f"Every clue needs exactly one source. Missing: {missing}")
|
| 296 |
+
for clue_id, source_ids in sorted(clue_sources.items()):
|
| 297 |
+
if len(source_ids) > 1:
|
| 298 |
+
raise DMCompileError(
|
| 299 |
+
f"Clue '{clue_id}' has multiple sources: {', '.join(sorted(source_ids))}."
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _validate_visibility(world: WorldDefinition) -> None:
|
| 304 |
+
names: dict[str, str] = {}
|
| 305 |
+
for label in [node.label for node in world.nodes] + [item.label for item in world.items]:
|
| 306 |
+
safe = parser_safe_text(label)
|
| 307 |
+
if safe in names:
|
| 308 |
+
raise DMCompileError(
|
| 309 |
+
f"Visible labels '{label}' and '{names[safe]}' collapse to the same parser name '{safe}'."
|
| 310 |
+
)
|
| 311 |
+
names[safe] = label
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def _validate_answer_leaks(world: WorldDefinition) -> None:
|
| 315 |
+
answer = normalize_answer_text(world.meta.win_condition.answer_string)
|
| 316 |
+
forbidden = {f"the answer is {answer}", f"answer is {answer}", f"submit {answer}"}
|
| 317 |
+
text_fragments = [world.meta.title]
|
| 318 |
+
text_fragments.extend(clue.text for clue in world.clues)
|
| 319 |
+
for node in world.nodes:
|
| 320 |
+
text_fragments.extend([node.label, node.description])
|
| 321 |
+
if isinstance(node, ReadableNode):
|
| 322 |
+
text_fragments.append(node.text_content)
|
| 323 |
+
for text in text_fragments:
|
| 324 |
+
normalized = normalize_answer_text(text)
|
| 325 |
+
if any(phrase in normalized for phrase in forbidden):
|
| 326 |
+
raise DMCompileError("World leaks the final answer too directly. Clues must stay partial.")
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def _validate_guardian_path(world: WorldDefinition) -> None:
|
| 330 |
+
node_by_id = {node.id: node for node in world.nodes}
|
| 331 |
+
guardian = node_by_id[world.meta.win_condition.target_npc_id]
|
| 332 |
+
graph: dict[str, set[str]] = defaultdict(set)
|
| 333 |
+
for edge in world.edges:
|
| 334 |
+
if edge.type == "passage":
|
| 335 |
+
graph[edge.from_node_id].add(edge.to_node_id)
|
| 336 |
+
reachable = _reachable_rooms(graph, world.meta.start_node_id)
|
| 337 |
+
if guardian.parent_id not in reachable:
|
| 338 |
+
raise DMCompileError("Guardian room must be reachable from the start without item gates.")
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def _validate_clue_gates(world: WorldDefinition) -> None:
|
| 342 |
+
reachable = _reachable_zero_item_rooms(world)
|
| 343 |
+
hidden_readables = hidden_readable_ids(world)
|
| 344 |
+
for node in world.nodes:
|
| 345 |
+
if isinstance(node, ReadableNode):
|
| 346 |
+
if node.id in hidden_readables:
|
| 347 |
+
continue
|
| 348 |
+
if node.parent_id not in reachable:
|
| 349 |
+
continue
|
| 350 |
+
if node.requires_item_id:
|
| 351 |
+
continue
|
| 352 |
+
raise DMCompileError(
|
| 353 |
+
f"Readable '{node.id}' exposes clue '{node.clue_id}' without any item interaction."
|
| 354 |
+
)
|
| 355 |
+
if isinstance(node, NpcNode) and node.gives_clue_id and not node.requires_item_id:
|
| 356 |
+
raise DMCompileError(f"NPC '{node.id}' gives clue '{node.gives_clue_id}' without an item gate.")
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def _validate_item_usage(world: WorldDefinition) -> None:
|
| 360 |
+
quest_items: set[str] = set()
|
| 361 |
+
ordered = topological_linearize(world.quest_chain)
|
| 362 |
+
for action in (parse_quest_action(step.action) for step in ordered):
|
| 363 |
+
if isinstance(action, UnlockAction):
|
| 364 |
+
quest_items.add(action.key_id)
|
| 365 |
+
elif isinstance(action, (UseAction, GiveAction)):
|
| 366 |
+
quest_items.add(action.item_id)
|
| 367 |
+
elif isinstance(action, CombineAction):
|
| 368 |
+
quest_items.update({action.item_a_id, action.item_b_id})
|
| 369 |
+
elif isinstance(action, TakeAction):
|
| 370 |
+
quest_items.add(action.item_id)
|
| 371 |
+
|
| 372 |
+
mechanical_items = {
|
| 373 |
+
edge.required_item_id
|
| 374 |
+
for edge in world.edges
|
| 375 |
+
if edge.required_item_id
|
| 376 |
+
}
|
| 377 |
+
for node in world.nodes:
|
| 378 |
+
if node.type == "container" and node.lock_key_id:
|
| 379 |
+
mechanical_items.add(node.lock_key_id)
|
| 380 |
+
elif node.type == "door" and node.lock_key_id:
|
| 381 |
+
mechanical_items.add(node.lock_key_id)
|
| 382 |
+
elif node.type == "readable" and node.requires_item_id:
|
| 383 |
+
mechanical_items.add(node.requires_item_id)
|
| 384 |
+
elif node.type == "fixture":
|
| 385 |
+
mechanical_items.add(node.requires_item_id)
|
| 386 |
+
if node.reveals_item_id:
|
| 387 |
+
mechanical_items.add(node.reveals_item_id)
|
| 388 |
+
elif node.type == "npc":
|
| 389 |
+
if node.requires_item_id:
|
| 390 |
+
mechanical_items.add(node.requires_item_id)
|
| 391 |
+
if node.gives_item_id:
|
| 392 |
+
mechanical_items.add(node.gives_item_id)
|
| 393 |
+
for recipe in world.recipes:
|
| 394 |
+
mechanical_items.update(recipe.input_item_ids)
|
| 395 |
+
mechanical_items.add(recipe.output_item_id)
|
| 396 |
+
|
| 397 |
+
for item in world.items:
|
| 398 |
+
if item.id not in quest_items and item.id not in mechanical_items:
|
| 399 |
+
raise DMCompileError(f"Unused decorative items are not supported in v2: '{item.id}'.")
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def _validate_quest_shape(world: WorldDefinition) -> None:
|
| 403 |
+
ordered = topological_linearize(world.quest_chain)
|
| 404 |
+
parsed = [parse_quest_action(step.action) for step in ordered]
|
| 405 |
+
if not isinstance(parsed[-1], SubmitAction):
|
| 406 |
+
raise DMCompileError('The final quest step must be submit("answer").')
|
| 407 |
+
if len(parsed) < 2 or not isinstance(parsed[-2], TalkAction):
|
| 408 |
+
raise DMCompileError("The penultimate quest step must be talk(guardian).")
|
| 409 |
+
if parsed[-2].target_node_id != world.meta.win_condition.target_npc_id:
|
| 410 |
+
raise DMCompileError("The final talk step must target the guardian NPC.")
|
| 411 |
+
if normalize_answer_text(parsed[-1].answer_text) != normalize_answer_text(world.meta.win_condition.answer_string):
|
| 412 |
+
raise DMCompileError("The final submit step must match win_condition.answer_string.")
|
| 413 |
+
entity_names = {node.id: parser_safe_text(node.label) for node in world.nodes}
|
| 414 |
+
entity_names.update({item.id: parser_safe_text(item.label) for item in world.items})
|
| 415 |
+
simulate_walkthrough(world, parsed, entity_names)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _reachable_rooms(graph: dict[str, set[str]], start: str) -> set[str]:
|
| 419 |
+
seen = {start}
|
| 420 |
+
queue = deque([start])
|
| 421 |
+
while queue:
|
| 422 |
+
current = queue.popleft()
|
| 423 |
+
for nxt in graph.get(current, set()):
|
| 424 |
+
if nxt not in seen:
|
| 425 |
+
seen.add(nxt)
|
| 426 |
+
queue.append(nxt)
|
| 427 |
+
return seen
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def _reachable_zero_item_rooms(world: WorldDefinition) -> set[str]:
|
| 431 |
+
graph: dict[str, set[str]] = defaultdict(set)
|
| 432 |
+
for edge in world.edges:
|
| 433 |
+
if edge.type == "passage":
|
| 434 |
+
graph[edge.from_node_id].add(edge.to_node_id)
|
| 435 |
+
return _reachable_rooms(graph, world.meta.start_node_id)
|
agents/master/env.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import uuid
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from .base import DMCompileError, DMInterfaceError, MAX_STEP_MULTIPLIER, TARGET_RATIO, TARGET_RATIO_SIGMA
|
| 9 |
+
from .build import WorldCompiler
|
| 10 |
+
from .interface import InterfaceAdapter, SimpleInterfaceAdapter
|
| 11 |
+
from .play import EpisodeRunner, WalkthroughRunner
|
| 12 |
+
from .schema import (
|
| 13 |
+
CompiledWorld,
|
| 14 |
+
DMAction,
|
| 15 |
+
DMFeedback,
|
| 16 |
+
DMObservation,
|
| 17 |
+
DMRewardBreakdown,
|
| 18 |
+
DMState,
|
| 19 |
+
Turn,
|
| 20 |
+
WorldDefinition,
|
| 21 |
+
)
|
| 22 |
+
from .session import EpisodeSession
|
| 23 |
+
from .snapshots import LiveObserver
|
| 24 |
+
from agents.shared.openenv_compat import Environment, StepResult, build_step_result
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DMEnvironment(Environment[DMAction, DMObservation, DMState]):
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
artifacts_root: Path | None = None,
|
| 31 |
+
target_ratio: float = TARGET_RATIO,
|
| 32 |
+
reward_sigma: float = TARGET_RATIO_SIGMA,
|
| 33 |
+
max_step_multiplier: int = MAX_STEP_MULTIPLIER,
|
| 34 |
+
interface_adapter: InterfaceAdapter = SimpleInterfaceAdapter(),
|
| 35 |
+
default_runner: EpisodeRunner | None = None,
|
| 36 |
+
) -> None:
|
| 37 |
+
super().__init__()
|
| 38 |
+
if interface_adapter is None:
|
| 39 |
+
raise ValueError("interface_adapter must not be None.")
|
| 40 |
+
self.compiler = WorldCompiler(artifacts_root=artifacts_root)
|
| 41 |
+
self.target_ratio = target_ratio
|
| 42 |
+
self.reward_sigma = reward_sigma
|
| 43 |
+
self.max_step_multiplier = max_step_multiplier
|
| 44 |
+
self.interface_adapter = interface_adapter
|
| 45 |
+
self.default_runner = default_runner or WalkthroughRunner()
|
| 46 |
+
self.episode_count = 0
|
| 47 |
+
self.success_count = 0
|
| 48 |
+
self._state = DMState(
|
| 49 |
+
episode_id=uuid.uuid4().hex[:12],
|
| 50 |
+
target_ratio=target_ratio,
|
| 51 |
+
)
|
| 52 |
+
self.last_compiled_world: CompiledWorld | None = None
|
| 53 |
+
|
| 54 |
+
def reset(self, difficulty_hint: float | None = None, seed: int | None = None) -> DMObservation:
|
| 55 |
+
del seed
|
| 56 |
+
episode_target_ratio = self.target_ratio if difficulty_hint is None else difficulty_hint
|
| 57 |
+
self._state = DMState(
|
| 58 |
+
episode_id=uuid.uuid4().hex[:12],
|
| 59 |
+
compile_status="pending",
|
| 60 |
+
episode_status="running",
|
| 61 |
+
cumulative_success_rate=self._running_success_rate(),
|
| 62 |
+
target_ratio=episode_target_ratio,
|
| 63 |
+
difficulty_hint=difficulty_hint,
|
| 64 |
+
)
|
| 65 |
+
self.last_compiled_world = None
|
| 66 |
+
return self._apply_transform(
|
| 67 |
+
DMObservation(
|
| 68 |
+
done=False,
|
| 69 |
+
reward=None,
|
| 70 |
+
target_ratio_used=episode_target_ratio,
|
| 71 |
+
)
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
def step( # type: ignore[override]
|
| 75 |
+
self,
|
| 76 |
+
action: DMAction | WorldDefinition | dict[str, Any],
|
| 77 |
+
runner: EpisodeRunner | None = None,
|
| 78 |
+
observer: LiveObserver | None = None,
|
| 79 |
+
timeout_s: float | None = None,
|
| 80 |
+
) -> StepResult[DMObservation]:
|
| 81 |
+
del timeout_s
|
| 82 |
+
world_input = action.world_definition if isinstance(action, DMAction) else action
|
| 83 |
+
compiled: CompiledWorld | None = None
|
| 84 |
+
session: EpisodeSession | None = None
|
| 85 |
+
if observer is not None:
|
| 86 |
+
observer.on_run_start(self._state.episode_id, world_input)
|
| 87 |
+
self.last_compiled_world = None
|
| 88 |
+
self._state.current_world = None
|
| 89 |
+
try:
|
| 90 |
+
compiled = self.compiler.compile(world_input, episode_id=self._state.episode_id)
|
| 91 |
+
self.last_compiled_world = compiled
|
| 92 |
+
self._state.current_world = compiled.world
|
| 93 |
+
self._state.compile_status = "valid"
|
| 94 |
+
max_steps = max(1, len(compiled.solver_policy) * self.max_step_multiplier)
|
| 95 |
+
|
| 96 |
+
def on_turn(current_session: EpisodeSession, turn: Turn) -> None:
|
| 97 |
+
self._state.step_count = current_session.steps_taken
|
| 98 |
+
if observer is not None:
|
| 99 |
+
observer.on_turn(current_session, turn)
|
| 100 |
+
|
| 101 |
+
session = EpisodeSession(
|
| 102 |
+
compiled,
|
| 103 |
+
interface_adapter=self.interface_adapter,
|
| 104 |
+
turn_listener=on_turn,
|
| 105 |
+
)
|
| 106 |
+
if observer is not None:
|
| 107 |
+
observer.on_compile_success(compiled, session)
|
| 108 |
+
(runner or self.default_runner).run(session, max_steps=max_steps)
|
| 109 |
+
player_won = bool(session.player_won)
|
| 110 |
+
min_steps = len(compiled.solver_policy)
|
| 111 |
+
reward_breakdown = self._reward_breakdown(player_won, session.steps_taken, min_steps)
|
| 112 |
+
reward = reward_breakdown.reward
|
| 113 |
+
self.episode_count += 1
|
| 114 |
+
self.success_count += int(player_won)
|
| 115 |
+
self._state.step_count = session.steps_taken
|
| 116 |
+
self._state.episode_status = "complete" if player_won else "failed"
|
| 117 |
+
self._state.cumulative_success_rate = self._running_success_rate()
|
| 118 |
+
observation = self._apply_transform(
|
| 119 |
+
DMObservation(
|
| 120 |
+
episode_transcript=session.transcript,
|
| 121 |
+
player_won=player_won,
|
| 122 |
+
steps_taken=session.steps_taken,
|
| 123 |
+
min_steps=min_steps,
|
| 124 |
+
ratio=(session.steps_taken / min_steps) if min_steps else None,
|
| 125 |
+
reward=reward,
|
| 126 |
+
done=True,
|
| 127 |
+
feedback=self._build_feedback(compiled, session),
|
| 128 |
+
reward_breakdown=reward_breakdown,
|
| 129 |
+
target_ratio_used=self._state.target_ratio,
|
| 130 |
+
)
|
| 131 |
+
)
|
| 132 |
+
if observer is not None:
|
| 133 |
+
observer.on_complete(compiled, session, observation)
|
| 134 |
+
return build_step_result(observation)
|
| 135 |
+
except (DMCompileError, DMInterfaceError, ValueError) as exc:
|
| 136 |
+
self.last_compiled_world = None
|
| 137 |
+
self._state.current_world = None
|
| 138 |
+
self._state.compile_status = "invalid"
|
| 139 |
+
self._state.episode_status = "failed"
|
| 140 |
+
if observer is not None:
|
| 141 |
+
observer.on_error(
|
| 142 |
+
episode_id=self._state.episode_id,
|
| 143 |
+
error=str(exc),
|
| 144 |
+
world_input=world_input,
|
| 145 |
+
compiled=compiled,
|
| 146 |
+
session=session,
|
| 147 |
+
)
|
| 148 |
+
observation = self._apply_transform(
|
| 149 |
+
DMObservation(
|
| 150 |
+
player_won=False,
|
| 151 |
+
compile_error=str(exc),
|
| 152 |
+
reward=0.0,
|
| 153 |
+
done=True,
|
| 154 |
+
reward_breakdown=DMRewardBreakdown(
|
| 155 |
+
reward_mode="compile_failure_penalty",
|
| 156 |
+
player_won=False,
|
| 157 |
+
target_ratio=self._state.target_ratio,
|
| 158 |
+
quality_score=0.0,
|
| 159 |
+
reward=0.0,
|
| 160 |
+
),
|
| 161 |
+
target_ratio_used=self._state.target_ratio,
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
return build_step_result(observation)
|
| 165 |
+
finally:
|
| 166 |
+
if session is not None:
|
| 167 |
+
session.close()
|
| 168 |
+
|
| 169 |
+
def compile_world(
|
| 170 |
+
self,
|
| 171 |
+
world_input: WorldDefinition | dict[str, Any],
|
| 172 |
+
*,
|
| 173 |
+
episode_id: str | None = None,
|
| 174 |
+
) -> CompiledWorld:
|
| 175 |
+
return self.compiler.compile(world_input, episode_id=episode_id)
|
| 176 |
+
|
| 177 |
+
def play(
|
| 178 |
+
self,
|
| 179 |
+
world_input: WorldDefinition | dict[str, Any],
|
| 180 |
+
runner: EpisodeRunner | None = None,
|
| 181 |
+
observer: LiveObserver | None = None,
|
| 182 |
+
) -> StepResult[DMObservation]:
|
| 183 |
+
self.reset()
|
| 184 |
+
return self.step(world_input, runner=runner, observer=observer)
|
| 185 |
+
|
| 186 |
+
@property
|
| 187 |
+
def state(self) -> DMState:
|
| 188 |
+
return self._state
|
| 189 |
+
|
| 190 |
+
def _reward_breakdown(
|
| 191 |
+
self,
|
| 192 |
+
player_won: bool,
|
| 193 |
+
steps_taken: int | None,
|
| 194 |
+
min_steps: int | None,
|
| 195 |
+
) -> DMRewardBreakdown:
|
| 196 |
+
raw_ratio: float | None = None
|
| 197 |
+
clamped_ratio: float | None = None
|
| 198 |
+
target_ratio_delta: float | None = None
|
| 199 |
+
efficiency_score: float | None = None
|
| 200 |
+
quality_score = 0.0
|
| 201 |
+
if steps_taken is not None and min_steps is not None and min_steps > 0:
|
| 202 |
+
raw_ratio = steps_taken / min_steps
|
| 203 |
+
clamped_ratio = max(raw_ratio, 1.0)
|
| 204 |
+
target_ratio_delta = abs(clamped_ratio - self._state.target_ratio)
|
| 205 |
+
if player_won and steps_taken > 0:
|
| 206 |
+
efficiency_score = min(1.0, min_steps / steps_taken)
|
| 207 |
+
sigma_sq = max(self.reward_sigma, 1e-6) ** 2
|
| 208 |
+
quality_score = math.exp(-((clamped_ratio - self._state.target_ratio) ** 2) / (2.0 * sigma_sq))
|
| 209 |
+
reward = quality_score if player_won else 0.0
|
| 210 |
+
return DMRewardBreakdown(
|
| 211 |
+
reward_mode="gaussian_target_ratio",
|
| 212 |
+
player_won=player_won,
|
| 213 |
+
raw_ratio=raw_ratio,
|
| 214 |
+
clamped_ratio=clamped_ratio,
|
| 215 |
+
target_ratio=self._state.target_ratio,
|
| 216 |
+
target_ratio_delta=target_ratio_delta,
|
| 217 |
+
efficiency_score=efficiency_score,
|
| 218 |
+
quality_score=quality_score,
|
| 219 |
+
reward=reward,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
def _build_feedback(self, compiled: CompiledWorld, session: EpisodeSession) -> DMFeedback:
|
| 223 |
+
room_ids = [node.id for node in compiled.world.nodes if node.type in {"location", "junction"}]
|
| 224 |
+
clue_ids = [clue.id for clue in compiled.world.clues]
|
| 225 |
+
unique_rooms = [node_id for node_id in session.visited_nodes if node_id in room_ids]
|
| 226 |
+
return DMFeedback(
|
| 227 |
+
unreachable_nodes=sorted(set(room_ids) - set(unique_rooms)),
|
| 228 |
+
unused_items=sorted({item.id for item in compiled.world.items} - session.used_items),
|
| 229 |
+
clues_missed=sorted(set(clue_ids) - session.discovered_clues),
|
| 230 |
+
mean_steps_per_room=session.steps_taken / max(1, len(set(unique_rooms))),
|
| 231 |
+
invalid_command_count=session.invalid_command_count,
|
| 232 |
+
wrong_submit_count=session.wrong_submit_count,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
def _running_success_rate(self) -> float:
|
| 236 |
+
return 0.0 if self.episode_count == 0 else self.success_count / self.episode_count
|
agents/master/graph.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
|
| 5 |
+
from .schema import DoorNode, Edge, NpcTrade, ReadableNode, UseEffect, WorldDefinition
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def readable_clue_mapping(world: WorldDefinition) -> dict[str, str]:
|
| 9 |
+
return {node.id: node.clue_id for node in world.nodes if isinstance(node, ReadableNode)}
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def clue_source_mapping(world: WorldDefinition) -> dict[str, str]:
|
| 13 |
+
mapping = {node.clue_id: node.id for node in world.nodes if isinstance(node, ReadableNode)}
|
| 14 |
+
for node in world.nodes:
|
| 15 |
+
if node.type == "npc" and node.gives_clue_id:
|
| 16 |
+
mapping[node.gives_clue_id] = node.id
|
| 17 |
+
return mapping
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def npc_trade_mapping(world: WorldDefinition) -> dict[str, NpcTrade]:
|
| 21 |
+
trades: dict[str, NpcTrade] = {}
|
| 22 |
+
for node in world.nodes:
|
| 23 |
+
if node.type != "npc" or node.id == world.meta.win_condition.target_npc_id:
|
| 24 |
+
continue
|
| 25 |
+
trades[node.id] = NpcTrade(
|
| 26 |
+
required_item_id=node.requires_item_id or "",
|
| 27 |
+
gives_item_id=node.gives_item_id,
|
| 28 |
+
gives_clue_id=node.gives_clue_id,
|
| 29 |
+
)
|
| 30 |
+
return trades
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def use_effect_mapping(world: WorldDefinition) -> dict[str, UseEffect]:
|
| 34 |
+
effects: dict[str, UseEffect] = {}
|
| 35 |
+
for node in world.nodes:
|
| 36 |
+
if node.type == "readable" and node.requires_item_id:
|
| 37 |
+
effects[node.id] = UseEffect(
|
| 38 |
+
required_item_id=node.requires_item_id,
|
| 39 |
+
clue_id=node.clue_id,
|
| 40 |
+
consumes_item=node.consumes_item,
|
| 41 |
+
)
|
| 42 |
+
elif node.type == "fixture":
|
| 43 |
+
effects[node.id] = UseEffect(
|
| 44 |
+
required_item_id=node.requires_item_id,
|
| 45 |
+
reveals_item_id=node.reveals_item_id,
|
| 46 |
+
reveals_readable_id=node.reveals_readable_id,
|
| 47 |
+
consumes_item=node.consumes_item,
|
| 48 |
+
)
|
| 49 |
+
return effects
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def recipe_mapping(world: WorldDefinition) -> dict[frozenset[str], str]:
|
| 53 |
+
return {frozenset(recipe.input_item_ids): recipe.output_item_id for recipe in world.recipes}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def produced_item_ids(world: WorldDefinition) -> set[str]:
|
| 57 |
+
produced = {recipe.output_item_id for recipe in world.recipes}
|
| 58 |
+
for node in world.nodes:
|
| 59 |
+
if node.type == "npc" and node.gives_item_id:
|
| 60 |
+
produced.add(node.gives_item_id)
|
| 61 |
+
if node.type == "fixture" and node.reveals_item_id:
|
| 62 |
+
produced.add(node.reveals_item_id)
|
| 63 |
+
return produced
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def hidden_readable_ids(world: WorldDefinition) -> set[str]:
|
| 67 |
+
return {node.reveals_readable_id for node in world.nodes if node.type == "fixture" and node.reveals_readable_id}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def door_room_mapping(world: WorldDefinition) -> dict[str, frozenset[str]]:
|
| 71 |
+
mapping: dict[str, set[str]] = defaultdict(set)
|
| 72 |
+
for edge in world.edges:
|
| 73 |
+
if edge.door_node_id:
|
| 74 |
+
mapping[edge.door_node_id].add(edge.from_node_id)
|
| 75 |
+
mapping[edge.door_node_id].add(edge.to_node_id)
|
| 76 |
+
return {door_id: frozenset(rooms) for door_id, rooms in mapping.items()}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def edge_for_door(world: WorldDefinition, door_id: str) -> Edge | None:
|
| 80 |
+
for edge in world.edges:
|
| 81 |
+
if edge.door_node_id == door_id:
|
| 82 |
+
return edge
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def door_nodes(world: WorldDefinition) -> dict[str, DoorNode]:
|
| 87 |
+
return {node.id: node for node in world.nodes if isinstance(node, DoorNode)}
|
agents/master/interface.py
ADDED
|
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import TYPE_CHECKING, Literal, Protocol
|
| 9 |
+
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
from google import genai
|
| 12 |
+
from google.genai import types
|
| 13 |
+
from textworld.core import GameState
|
| 14 |
+
|
| 15 |
+
from agents.hero.cli import parse_cli_command
|
| 16 |
+
|
| 17 |
+
from .base import DMInterfaceError, SUPPORTED_DIRECTIONS
|
| 18 |
+
|
| 19 |
+
if TYPE_CHECKING:
|
| 20 |
+
from .session import EpisodeSession
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
DEFAULT_GEMINI_MODEL = "gemini-2.5-flash-lite"
|
| 24 |
+
_TEXTWORLD_PROMPT_LINE_RE = re.compile(r"^\s*>\s.*-\=\s.*=\-(?:\d+/\d+)?\s*$")
|
| 25 |
+
_TEXTWORLD_BANNER_CHAR_RE = re.compile(r"[\\|$_/]")
|
| 26 |
+
_TEXTWORLD_ROOM_HEADER_RE = re.compile(r"^\s*-\=\s*(?P<label>.+?)\s*\=-\s*$")
|
| 27 |
+
_TEXTWORLD_META_LINE_RE = re.compile(r"^\s*(?:score:|moves:|available commands:|type 'help')", re.IGNORECASE)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class InterfaceAdapter(Protocol):
|
| 31 |
+
def translate_command(self, raw_command: str, session: EpisodeSession) -> str:
|
| 32 |
+
...
|
| 33 |
+
|
| 34 |
+
def render_observation(self, feedback: str, state: GameState | None, session: EpisodeSession) -> str:
|
| 35 |
+
...
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SimpleInterfaceAdapter:
|
| 39 |
+
"""A deterministic parser for explicit non-LLM play."""
|
| 40 |
+
|
| 41 |
+
_ARTICLE_RE = re.compile(r"\b(the|a|an)\b", re.IGNORECASE)
|
| 42 |
+
|
| 43 |
+
def translate_command(self, raw_command: str, session: EpisodeSession) -> str:
|
| 44 |
+
command = raw_command.strip()
|
| 45 |
+
lowered = command.lower()
|
| 46 |
+
if lowered in SUPPORTED_DIRECTIONS:
|
| 47 |
+
return "go " + lowered
|
| 48 |
+
if lowered in {"look", "look around"}:
|
| 49 |
+
return "look"
|
| 50 |
+
if lowered in {"inventory", "check inventory", "show inventory"}:
|
| 51 |
+
return "inventory"
|
| 52 |
+
if lowered in {"wait", "pass"}:
|
| 53 |
+
return "wait"
|
| 54 |
+
if lowered.startswith("answer "):
|
| 55 |
+
return "submit " + command[7:].strip()
|
| 56 |
+
if lowered.startswith("say "):
|
| 57 |
+
return "submit " + command[4:].strip().strip("\"'")
|
| 58 |
+
if lowered.startswith("talk to "):
|
| 59 |
+
return "talk " + command[8:].strip()
|
| 60 |
+
if lowered.startswith("speak to "):
|
| 61 |
+
return "talk " + command[9:].strip()
|
| 62 |
+
if lowered.startswith("use ") and " on " in lowered:
|
| 63 |
+
item_text, target_text = re.split(r"\s+on\s+", command[4:].strip(), maxsplit=1, flags=re.IGNORECASE)
|
| 64 |
+
return "use " + self._normalize_object_text(item_text) + " on " + self._normalize_object_text(target_text)
|
| 65 |
+
if lowered.startswith("give ") and " to " in lowered:
|
| 66 |
+
item_text, target_text = re.split(r"\s+to\s+", command[5:].strip(), maxsplit=1, flags=re.IGNORECASE)
|
| 67 |
+
return "give " + self._normalize_object_text(item_text) + " to " + self._normalize_object_text(target_text)
|
| 68 |
+
if lowered.startswith("combine ") and " with " in lowered:
|
| 69 |
+
item_a, item_b = re.split(r"\s+with\s+", command[8:].strip(), maxsplit=1, flags=re.IGNORECASE)
|
| 70 |
+
return "combine " + self._normalize_object_text(item_a) + " with " + self._normalize_object_text(item_b)
|
| 71 |
+
if lowered.startswith("combine ") and " and " in lowered:
|
| 72 |
+
item_a, item_b = re.split(r"\s+and\s+", command[8:].strip(), maxsplit=1, flags=re.IGNORECASE)
|
| 73 |
+
return "combine " + self._normalize_object_text(item_a) + " with " + self._normalize_object_text(item_b)
|
| 74 |
+
|
| 75 |
+
parts = command.split(maxsplit=1)
|
| 76 |
+
if len(parts) != 2:
|
| 77 |
+
return lowered
|
| 78 |
+
|
| 79 |
+
verb = parts[0].lower()
|
| 80 |
+
if verb not in {"read", "talk", "open", "take", "unlock", "examine"}:
|
| 81 |
+
return lowered
|
| 82 |
+
|
| 83 |
+
normalized = self._normalize_object_text(parts[1])
|
| 84 |
+
if verb == "examine":
|
| 85 |
+
if session.node_id_for_command_name(normalized, node_types={"readable"}):
|
| 86 |
+
return "read " + normalized
|
| 87 |
+
if session.node_id_for_command_name(normalized, node_types={"npc"}):
|
| 88 |
+
return "talk " + normalized
|
| 89 |
+
|
| 90 |
+
return verb + " " + normalized
|
| 91 |
+
|
| 92 |
+
def _normalize_object_text(self, text: str) -> str:
|
| 93 |
+
object_text = self._ARTICLE_RE.sub(" ", text)
|
| 94 |
+
return re.sub(r"\s+", " ", object_text).strip().lower()
|
| 95 |
+
|
| 96 |
+
def render_observation(self, feedback: str, state: GameState | None, session: EpisodeSession) -> str:
|
| 97 |
+
del state
|
| 98 |
+
return enrich_feedback_text(sanitize_feedback_text(feedback), session)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class StrictCliInterfaceAdapter:
|
| 102 |
+
"""A deterministic adapter for parser-style CLI commands."""
|
| 103 |
+
|
| 104 |
+
def translate_command(self, raw_command: str, session: EpisodeSession) -> str:
|
| 105 |
+
del session
|
| 106 |
+
parsed = parse_cli_command(raw_command)
|
| 107 |
+
if not parsed.valid or parsed.normalized_command is None:
|
| 108 |
+
raise DMInterfaceError(parsed.error or "Command does not match the strict CLI grammar.")
|
| 109 |
+
return parsed.normalized_command
|
| 110 |
+
|
| 111 |
+
def render_observation(self, feedback: str, state: GameState | None, session: EpisodeSession) -> str:
|
| 112 |
+
del state
|
| 113 |
+
return enrich_feedback_text(sanitize_feedback_text(feedback), session)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@dataclass(frozen=True)
|
| 117 |
+
class _TranslationGlossary:
|
| 118 |
+
canonical_to_alias: dict[str, str]
|
| 119 |
+
alias_to_canonical: dict[str, str]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class GeminiInterfaceAdapter:
|
| 123 |
+
_ARTICLE_RE = re.compile(r"\b(the|a|an)\b", re.IGNORECASE)
|
| 124 |
+
_PARSER_SAFE_NAME_RE = re.compile(r"^[a-z0-9]+(?: [a-z0-9]+)*$")
|
| 125 |
+
_TRAILING_POLITENESS_RE = re.compile(r"(?:\s+(?:please|for me|thanks|thank you))+[.!?]*$", re.IGNORECASE)
|
| 126 |
+
_COMMAND_SYSTEM = (
|
| 127 |
+
"Translate the player's text into exactly one canonical dungeon command. "
|
| 128 |
+
"Return only the command and nothing else."
|
| 129 |
+
)
|
| 130 |
+
_OBSERVATION_SYSTEM = (
|
| 131 |
+
"Rewrite dungeon feedback in at most two short sentences. "
|
| 132 |
+
"Preserve facts exactly. Do not infer, solve, explain, or add implications."
|
| 133 |
+
)
|
| 134 |
+
_TRANSLATED_COMMAND_SYSTEM = (
|
| 135 |
+
"The player is using a corporate app metaphor layered over a fantasy dungeon. "
|
| 136 |
+
"Translate the player's text back into exactly one canonical dungeon command from the underlying fantasy world. "
|
| 137 |
+
"Return only the canonical command and nothing else."
|
| 138 |
+
)
|
| 139 |
+
_TRANSLATED_OBSERVATION_SYSTEM = (
|
| 140 |
+
"Rewrite the dungeon observation as a corporate app interface while preserving facts one-to-one. "
|
| 141 |
+
"Use the provided aliases exactly, keep directions unchanged, and do not add hints, solutions, or new mechanics."
|
| 142 |
+
)
|
| 143 |
+
_TRANSLATION_GLOSSARY_SYSTEM = (
|
| 144 |
+
"Create a one-to-one alias glossary that maps fantasy dungeon terms into a corporate app metaphor. "
|
| 145 |
+
"Return JSON only."
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
api_key: str | None = None,
|
| 151 |
+
model: str = DEFAULT_GEMINI_MODEL,
|
| 152 |
+
narrate_observations: bool = False,
|
| 153 |
+
translation_mode: Literal["none", "corporate_app"] = "none",
|
| 154 |
+
max_admissible_commands: int = 18,
|
| 155 |
+
) -> None:
|
| 156 |
+
if translation_mode not in {"none", "corporate_app"}:
|
| 157 |
+
raise ValueError(f"Unsupported Gemini translation mode: {translation_mode}")
|
| 158 |
+
self.model = model
|
| 159 |
+
self.narrate_observations = narrate_observations
|
| 160 |
+
self.translation_mode = translation_mode
|
| 161 |
+
self.max_admissible_commands = max_admissible_commands
|
| 162 |
+
self._client = self._create_client(api_key)
|
| 163 |
+
self._translation_glossary_cache: dict[str, _TranslationGlossary] = {}
|
| 164 |
+
self._translation_observation_cache: dict[tuple[str, str], str] = {}
|
| 165 |
+
|
| 166 |
+
def translate_command(self, raw_command: str, session: EpisodeSession) -> str:
|
| 167 |
+
lowered = raw_command.strip().lower()
|
| 168 |
+
if not lowered:
|
| 169 |
+
raise DMInterfaceError("Command must not be empty.")
|
| 170 |
+
admissible = set(session.available_commands())
|
| 171 |
+
direct = self._normalize_generated_command(self._preprocess_player_text(lowered))
|
| 172 |
+
if resolved := self._resolve_candidate_command(direct, session, admissible):
|
| 173 |
+
return resolved
|
| 174 |
+
movement = self._extract_direction_command(lowered, admissible)
|
| 175 |
+
if movement is not None:
|
| 176 |
+
return movement
|
| 177 |
+
|
| 178 |
+
prompt = self._command_prompt(raw_command, session, admissible)
|
| 179 |
+
generated = self._generate_command(
|
| 180 |
+
system_instruction=self._TRANSLATED_COMMAND_SYSTEM if self._translation_enabled() else self._COMMAND_SYSTEM,
|
| 181 |
+
prompt=prompt,
|
| 182 |
+
max_output_tokens=48,
|
| 183 |
+
temperature=0.1,
|
| 184 |
+
)
|
| 185 |
+
if resolved := self._resolve_candidate_command(generated, session, admissible):
|
| 186 |
+
return resolved
|
| 187 |
+
raise DMInterfaceError(f"Gemini returned an invalid command: {generated or '<empty>'}")
|
| 188 |
+
|
| 189 |
+
def render_observation(self, feedback: str, state: GameState | None, session: EpisodeSession) -> str:
|
| 190 |
+
sanitized = sanitize_feedback_text(feedback)
|
| 191 |
+
enriched = enrich_feedback_text(sanitized, session)
|
| 192 |
+
if not sanitized:
|
| 193 |
+
return enriched
|
| 194 |
+
if self._translation_enabled():
|
| 195 |
+
cache_key = (self._translation_cache_key(session), enriched)
|
| 196 |
+
cached = self._translation_observation_cache.get(cache_key)
|
| 197 |
+
if cached is not None:
|
| 198 |
+
return cached
|
| 199 |
+
prompt = self._observation_prompt(enriched, session)
|
| 200 |
+
generated = self._generate_observation(
|
| 201 |
+
system_instruction=self._TRANSLATED_OBSERVATION_SYSTEM,
|
| 202 |
+
prompt=prompt,
|
| 203 |
+
max_output_tokens=220 if not self.narrate_observations else 120,
|
| 204 |
+
temperature=0.2,
|
| 205 |
+
)
|
| 206 |
+
if not generated:
|
| 207 |
+
raise DMInterfaceError("Gemini returned an empty translated observation.")
|
| 208 |
+
self._translation_observation_cache[cache_key] = generated
|
| 209 |
+
return generated
|
| 210 |
+
if not self.narrate_observations:
|
| 211 |
+
return enriched
|
| 212 |
+
if self._should_preserve_feedback(sanitized, state):
|
| 213 |
+
return enriched
|
| 214 |
+
|
| 215 |
+
prompt = self._observation_prompt(sanitized, session)
|
| 216 |
+
generated = self._generate_observation(
|
| 217 |
+
system_instruction=self._OBSERVATION_SYSTEM,
|
| 218 |
+
prompt=prompt,
|
| 219 |
+
max_output_tokens=80,
|
| 220 |
+
temperature=0.2,
|
| 221 |
+
)
|
| 222 |
+
if not generated:
|
| 223 |
+
raise DMInterfaceError("Gemini returned an empty observation.")
|
| 224 |
+
return enrich_feedback_text(generated, session)
|
| 225 |
+
|
| 226 |
+
def _create_client(self, api_key: str | None) -> genai.Client:
|
| 227 |
+
load_dotenv(self._repo_root() / ".env", override=False)
|
| 228 |
+
key = api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
|
| 229 |
+
if not key:
|
| 230 |
+
raise DMInterfaceError("Missing GEMINI_API_KEY or GOOGLE_API_KEY.")
|
| 231 |
+
return genai.Client(api_key=key)
|
| 232 |
+
|
| 233 |
+
@staticmethod
|
| 234 |
+
def _repo_root() -> Path:
|
| 235 |
+
return Path(__file__).resolve().parents[2]
|
| 236 |
+
|
| 237 |
+
def _command_prompt(self, raw_command: str, session: EpisodeSession, admissible: set[str]) -> str:
|
| 238 |
+
commands = sorted(admissible)[: self.max_admissible_commands]
|
| 239 |
+
interactables = self._interactables(session)
|
| 240 |
+
current_room = session.state.location or session.current_room_id
|
| 241 |
+
lines: list[str] = []
|
| 242 |
+
if self._translation_enabled():
|
| 243 |
+
glossary = self._translation_glossary(session)
|
| 244 |
+
lines.extend(
|
| 245 |
+
[
|
| 246 |
+
"The player only sees the translated corporate-app interface.",
|
| 247 |
+
"Map their request back to the underlying dungeon command.",
|
| 248 |
+
"Treat rooms as apps/workspaces, NPCs as coworkers or reviewers, and items as files, tools, credentials, or tickets.",
|
| 249 |
+
"Translated aliases (alias => canonical):",
|
| 250 |
+
*[f"- {alias} => {canonical}" for alias, canonical in sorted(glossary.alias_to_canonical.items())],
|
| 251 |
+
]
|
| 252 |
+
)
|
| 253 |
+
lines.extend(
|
| 254 |
+
[
|
| 255 |
+
"Use an exact visible command whenever possible.",
|
| 256 |
+
"Allowed verbs: go, open, unlock, take, read, use, combine, give, talk, submit, look, inventory, wait",
|
| 257 |
+
f"Room: {current_room}",
|
| 258 |
+
"Visible commands:",
|
| 259 |
+
*[f"- {command}" for command in commands],
|
| 260 |
+
]
|
| 261 |
+
)
|
| 262 |
+
if interactables:
|
| 263 |
+
lines.append(f"Objects here: {', '.join(interactables)}")
|
| 264 |
+
lines.append("If the player is answering the guardian, use: submit <answer>")
|
| 265 |
+
lines.append("If no valid mapping exists, return INVALID")
|
| 266 |
+
lines.append(f"Player text: {raw_command.strip()}")
|
| 267 |
+
return "\n".join(lines)
|
| 268 |
+
|
| 269 |
+
def _observation_prompt(self, feedback: str, session: EpisodeSession) -> str:
|
| 270 |
+
current_room = session.state.location or session.current_room_id
|
| 271 |
+
if self._translation_enabled():
|
| 272 |
+
glossary = self._translation_glossary(session)
|
| 273 |
+
lines = [
|
| 274 |
+
f"Canonical room: {current_room}",
|
| 275 |
+
"Use this exact alias glossary (canonical => alias):",
|
| 276 |
+
*[f"- {canonical} => {alias}" for canonical, alias in sorted(glossary.canonical_to_alias.items())],
|
| 277 |
+
"Preserve the same facts, object counts, and navigation affordances.",
|
| 278 |
+
"Keep any 'Visible here:' and 'Exits:' sections, but rewrite the entity names with the aliases above.",
|
| 279 |
+
]
|
| 280 |
+
if self.narrate_observations:
|
| 281 |
+
lines.append("Keep the response compact.")
|
| 282 |
+
lines.append("Canonical observation:")
|
| 283 |
+
lines.append(feedback)
|
| 284 |
+
return "\n".join(lines)
|
| 285 |
+
return (
|
| 286 |
+
f"Room: {current_room}\n"
|
| 287 |
+
"Describe only what the game text explicitly says.\n"
|
| 288 |
+
"Never reveal what a clue means or what answer it implies.\n"
|
| 289 |
+
f"Feedback: {feedback}"
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
def _translation_glossary_prompt(self, session: EpisodeSession) -> str:
|
| 293 |
+
lines = [
|
| 294 |
+
"Return JSON with shape: {\"aliases\": [{\"source\": \"...\", \"alias\": \"...\"}]}",
|
| 295 |
+
"Rules:",
|
| 296 |
+
"- Every alias must be unique.",
|
| 297 |
+
"- Use lowercase letters, numbers, and spaces only.",
|
| 298 |
+
"- Do not use articles like a, an, or the.",
|
| 299 |
+
"- Keep aliases short and parser-safe.",
|
| 300 |
+
"- Rooms should feel like apps, dashboards, workspaces, portals, or queues.",
|
| 301 |
+
"- NPCs should feel like coworkers, reviewers, owners, admins, or operators.",
|
| 302 |
+
"- Items should feel like files, tickets, tokens, credentials, tools, or documents.",
|
| 303 |
+
"- Preserve identity one-to-one. Do not merge multiple source terms into one alias.",
|
| 304 |
+
"Terms:",
|
| 305 |
+
]
|
| 306 |
+
for kind, source in self._translation_terms(session):
|
| 307 |
+
lines.append(f"- {kind}: {source}")
|
| 308 |
+
return "\n".join(lines)
|
| 309 |
+
|
| 310 |
+
def _interactables(self, session: EpisodeSession) -> list[str]:
|
| 311 |
+
names: list[str] = []
|
| 312 |
+
for node in session.compiled.world.nodes:
|
| 313 |
+
if getattr(node, "parent_id", None) != session.current_room_id:
|
| 314 |
+
continue
|
| 315 |
+
safe_name = session.compiled.node_command_names.get(node.id)
|
| 316 |
+
if safe_name is not None and node.type in {"container", "readable", "npc", "door", "fixture"}:
|
| 317 |
+
names.append(safe_name)
|
| 318 |
+
return sorted(names)[:8]
|
| 319 |
+
|
| 320 |
+
def _generate_response(
|
| 321 |
+
self,
|
| 322 |
+
*,
|
| 323 |
+
system_instruction: str,
|
| 324 |
+
prompt: str,
|
| 325 |
+
max_output_tokens: int,
|
| 326 |
+
temperature: float,
|
| 327 |
+
) -> str:
|
| 328 |
+
response = self._client.models.generate_content(
|
| 329 |
+
model=self.model,
|
| 330 |
+
contents=f"{system_instruction}\n\n{prompt}",
|
| 331 |
+
config=types.GenerateContentConfig(
|
| 332 |
+
temperature=temperature,
|
| 333 |
+
max_output_tokens=max_output_tokens,
|
| 334 |
+
candidate_count=1,
|
| 335 |
+
),
|
| 336 |
+
)
|
| 337 |
+
return getattr(response, "text", "") or ""
|
| 338 |
+
|
| 339 |
+
def _generate_command(
|
| 340 |
+
self,
|
| 341 |
+
*,
|
| 342 |
+
system_instruction: str,
|
| 343 |
+
prompt: str,
|
| 344 |
+
max_output_tokens: int,
|
| 345 |
+
temperature: float,
|
| 346 |
+
) -> str:
|
| 347 |
+
return self._sanitize_command_response(
|
| 348 |
+
self._generate_response(
|
| 349 |
+
system_instruction=system_instruction,
|
| 350 |
+
prompt=prompt,
|
| 351 |
+
max_output_tokens=max_output_tokens,
|
| 352 |
+
temperature=temperature,
|
| 353 |
+
)
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
def _generate_observation(
|
| 357 |
+
self,
|
| 358 |
+
*,
|
| 359 |
+
system_instruction: str,
|
| 360 |
+
prompt: str,
|
| 361 |
+
max_output_tokens: int,
|
| 362 |
+
temperature: float,
|
| 363 |
+
) -> str:
|
| 364 |
+
return self._sanitize_multiline_response(
|
| 365 |
+
self._generate_response(
|
| 366 |
+
system_instruction=system_instruction,
|
| 367 |
+
prompt=prompt,
|
| 368 |
+
max_output_tokens=max_output_tokens,
|
| 369 |
+
temperature=temperature,
|
| 370 |
+
)
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
def _generate_json(
|
| 374 |
+
self,
|
| 375 |
+
*,
|
| 376 |
+
system_instruction: str,
|
| 377 |
+
prompt: str,
|
| 378 |
+
max_output_tokens: int,
|
| 379 |
+
temperature: float,
|
| 380 |
+
) -> str:
|
| 381 |
+
return self._sanitize_json_response(
|
| 382 |
+
self._generate_response(
|
| 383 |
+
system_instruction=system_instruction,
|
| 384 |
+
prompt=prompt,
|
| 385 |
+
max_output_tokens=max_output_tokens,
|
| 386 |
+
temperature=temperature,
|
| 387 |
+
)
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
def _resolve_candidate_command(
|
| 391 |
+
self,
|
| 392 |
+
candidate: str,
|
| 393 |
+
session: EpisodeSession,
|
| 394 |
+
admissible: set[str],
|
| 395 |
+
) -> str | None:
|
| 396 |
+
for option in self._candidate_variants(candidate, session):
|
| 397 |
+
if not option:
|
| 398 |
+
continue
|
| 399 |
+
if option == "invalid":
|
| 400 |
+
continue
|
| 401 |
+
if resolved := self._resolve_admissible_command(option, admissible):
|
| 402 |
+
return resolved
|
| 403 |
+
if self._allow_unlisted_canonical(option):
|
| 404 |
+
return option
|
| 405 |
+
return None
|
| 406 |
+
|
| 407 |
+
def _candidate_variants(self, candidate: str, session: EpisodeSession) -> list[str]:
|
| 408 |
+
variants = [self._normalize_generated_command(candidate)]
|
| 409 |
+
if self._translation_enabled():
|
| 410 |
+
canonicalized = self._canonicalize_translated_command(variants[0], session)
|
| 411 |
+
if canonicalized not in variants:
|
| 412 |
+
variants.insert(0, canonicalized)
|
| 413 |
+
return variants
|
| 414 |
+
|
| 415 |
+
def _canonicalize_translated_command(self, command: str, session: EpisodeSession) -> str:
|
| 416 |
+
glossary = self._translation_glossary(session)
|
| 417 |
+
rewritten = command
|
| 418 |
+
for alias, canonical in sorted(glossary.alias_to_canonical.items(), key=lambda item: (-len(item[0]), item[0])):
|
| 419 |
+
rewritten = re.sub(
|
| 420 |
+
rf"(?<![a-z0-9]){re.escape(alias)}(?![a-z0-9])",
|
| 421 |
+
canonical,
|
| 422 |
+
rewritten,
|
| 423 |
+
)
|
| 424 |
+
return self._normalize_generated_command(rewritten)
|
| 425 |
+
|
| 426 |
+
def _translation_glossary(self, session: EpisodeSession) -> _TranslationGlossary:
|
| 427 |
+
cache_key = self._translation_cache_key(session)
|
| 428 |
+
cached = self._translation_glossary_cache.get(cache_key)
|
| 429 |
+
if cached is not None:
|
| 430 |
+
return cached
|
| 431 |
+
terms = self._translation_terms(session)
|
| 432 |
+
generated = self._generate_json(
|
| 433 |
+
system_instruction=self._TRANSLATION_GLOSSARY_SYSTEM,
|
| 434 |
+
prompt=self._translation_glossary_prompt(session),
|
| 435 |
+
max_output_tokens=700,
|
| 436 |
+
temperature=0.2,
|
| 437 |
+
)
|
| 438 |
+
glossary = self._parse_translation_glossary(generated, terms)
|
| 439 |
+
self._translation_glossary_cache[cache_key] = glossary
|
| 440 |
+
return glossary
|
| 441 |
+
|
| 442 |
+
def _parse_translation_glossary(
|
| 443 |
+
self,
|
| 444 |
+
payload: str,
|
| 445 |
+
terms: list[tuple[str, str]],
|
| 446 |
+
) -> _TranslationGlossary:
|
| 447 |
+
try:
|
| 448 |
+
data = json.loads(payload)
|
| 449 |
+
except json.JSONDecodeError as exc:
|
| 450 |
+
raise DMInterfaceError("Gemini returned invalid translation glossary JSON.") from exc
|
| 451 |
+
|
| 452 |
+
raw_aliases: dict[str, str] = {}
|
| 453 |
+
if isinstance(data, dict):
|
| 454 |
+
aliases = data.get("aliases", data)
|
| 455 |
+
if isinstance(aliases, dict):
|
| 456 |
+
raw_aliases = {
|
| 457 |
+
self._normalize_object_text(str(source)): str(alias)
|
| 458 |
+
for source, alias in aliases.items()
|
| 459 |
+
if isinstance(source, str)
|
| 460 |
+
}
|
| 461 |
+
elif isinstance(aliases, list):
|
| 462 |
+
for entry in aliases:
|
| 463 |
+
if not isinstance(entry, dict):
|
| 464 |
+
continue
|
| 465 |
+
source = entry.get("source")
|
| 466 |
+
alias = entry.get("alias")
|
| 467 |
+
if isinstance(source, str) and isinstance(alias, str):
|
| 468 |
+
raw_aliases[self._normalize_object_text(source)] = alias
|
| 469 |
+
if not raw_aliases:
|
| 470 |
+
raise DMInterfaceError("Gemini returned an empty translation glossary.")
|
| 471 |
+
|
| 472 |
+
canonical_to_alias: dict[str, str] = {}
|
| 473 |
+
alias_to_canonical: dict[str, str] = {}
|
| 474 |
+
used_aliases: set[str] = set()
|
| 475 |
+
for _kind, source in terms:
|
| 476 |
+
requested_alias = self._normalize_parser_safe_alias(raw_aliases.get(source, ""))
|
| 477 |
+
alias = self._dedupe_alias(source, requested_alias, used_aliases)
|
| 478 |
+
canonical_to_alias[source] = alias
|
| 479 |
+
alias_to_canonical[alias] = source
|
| 480 |
+
used_aliases.add(alias)
|
| 481 |
+
return _TranslationGlossary(
|
| 482 |
+
canonical_to_alias=canonical_to_alias,
|
| 483 |
+
alias_to_canonical=alias_to_canonical,
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
def _translation_terms(self, session: EpisodeSession) -> list[tuple[str, str]]:
|
| 487 |
+
terms: list[tuple[str, str]] = []
|
| 488 |
+
seen: set[str] = set()
|
| 489 |
+
for node in session.compiled.world.nodes:
|
| 490 |
+
source = session.compiled.node_command_names.get(node.id)
|
| 491 |
+
if source is None or source in seen:
|
| 492 |
+
continue
|
| 493 |
+
kind = "room" if node.type in {"location", "junction"} else node.type
|
| 494 |
+
seen.add(source)
|
| 495 |
+
terms.append((kind, source))
|
| 496 |
+
for item in session.compiled.world.items:
|
| 497 |
+
source = session.compiled.item_command_names.get(item.id)
|
| 498 |
+
if source is None or source in seen:
|
| 499 |
+
continue
|
| 500 |
+
seen.add(source)
|
| 501 |
+
terms.append(("item", source))
|
| 502 |
+
answer = session.compiled.correct_answer_normalized
|
| 503 |
+
if answer and answer not in seen:
|
| 504 |
+
terms.append(("answer", answer))
|
| 505 |
+
return sorted(terms, key=lambda item: (item[0], item[1]))
|
| 506 |
+
|
| 507 |
+
def _dedupe_alias(self, source: str, alias: str, used_aliases: set[str]) -> str:
|
| 508 |
+
for candidate in (alias, source):
|
| 509 |
+
if candidate and candidate not in used_aliases:
|
| 510 |
+
return candidate
|
| 511 |
+
suffix = 2
|
| 512 |
+
while True:
|
| 513 |
+
candidate = f"{source} {suffix}"
|
| 514 |
+
if candidate not in used_aliases and self._PARSER_SAFE_NAME_RE.fullmatch(candidate):
|
| 515 |
+
return candidate
|
| 516 |
+
suffix += 1
|
| 517 |
+
|
| 518 |
+
def _normalize_parser_safe_alias(self, value: str) -> str:
|
| 519 |
+
alias = self._normalize_object_text(value)
|
| 520 |
+
if not alias or not self._PARSER_SAFE_NAME_RE.fullmatch(alias):
|
| 521 |
+
return ""
|
| 522 |
+
return alias
|
| 523 |
+
|
| 524 |
+
def _translation_cache_key(self, session: EpisodeSession) -> str:
|
| 525 |
+
episode_id = getattr(session.compiled, "episode_id", "") or "session"
|
| 526 |
+
return f"{episode_id}:{session.compiled.game_file}"
|
| 527 |
+
|
| 528 |
+
def _translation_enabled(self) -> bool:
|
| 529 |
+
return self.translation_mode != "none"
|
| 530 |
+
|
| 531 |
+
@classmethod
|
| 532 |
+
def _preprocess_player_text(cls, text: str) -> str:
|
| 533 |
+
normalized = re.sub(r"\s+", " ", text.strip().lower())
|
| 534 |
+
replacements = (
|
| 535 |
+
("pick up ", "take "),
|
| 536 |
+
("grab ", "take "),
|
| 537 |
+
("using ", "with "),
|
| 538 |
+
("talk to ", "talk "),
|
| 539 |
+
("speak to ", "talk "),
|
| 540 |
+
)
|
| 541 |
+
for source, target in replacements:
|
| 542 |
+
normalized = normalized.replace(source, target)
|
| 543 |
+
|
| 544 |
+
prefixes = (
|
| 545 |
+
"please ",
|
| 546 |
+
"please, ",
|
| 547 |
+
"can you ",
|
| 548 |
+
"could you ",
|
| 549 |
+
"would you ",
|
| 550 |
+
"will you ",
|
| 551 |
+
"go ahead and ",
|
| 552 |
+
"i want to ",
|
| 553 |
+
"i'd like to ",
|
| 554 |
+
"try to ",
|
| 555 |
+
)
|
| 556 |
+
stripped = True
|
| 557 |
+
while stripped:
|
| 558 |
+
stripped = False
|
| 559 |
+
for prefix in prefixes:
|
| 560 |
+
if normalized.startswith(prefix):
|
| 561 |
+
normalized = normalized[len(prefix) :].strip()
|
| 562 |
+
stripped = True
|
| 563 |
+
|
| 564 |
+
normalized = cls._TRAILING_POLITENESS_RE.sub("", normalized).strip()
|
| 565 |
+
return normalized
|
| 566 |
+
|
| 567 |
+
@staticmethod
|
| 568 |
+
def _extract_direction_command(text: str, admissible: set[str]) -> str | None:
|
| 569 |
+
directions = [direction for direction in SUPPORTED_DIRECTIONS if re.search(rf"\b{direction}\b", text)]
|
| 570 |
+
if len(directions) != 1:
|
| 571 |
+
return None
|
| 572 |
+
if not re.search(r"\b(go|head|move|walk|run|travel|enter|step)\b", text):
|
| 573 |
+
return None
|
| 574 |
+
candidate = f"go {directions[0]}"
|
| 575 |
+
return candidate if candidate in admissible else None
|
| 576 |
+
|
| 577 |
+
@staticmethod
|
| 578 |
+
def _allow_unlisted_canonical(command: str) -> bool:
|
| 579 |
+
return GeminiInterfaceAdapter._is_canonical_command(command) and not GeminiInterfaceAdapter._contains_conversational_fluff(command)
|
| 580 |
+
|
| 581 |
+
@staticmethod
|
| 582 |
+
def _contains_conversational_fluff(command: str) -> bool:
|
| 583 |
+
return bool(
|
| 584 |
+
re.search(
|
| 585 |
+
r"\b(for me|please|thanks|thank you|could you|can you|would you|will you)\b",
|
| 586 |
+
command,
|
| 587 |
+
)
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
@staticmethod
|
| 591 |
+
def _normalize_generated_command(text: str) -> str:
|
| 592 |
+
normalized = re.sub(r"\s+", " ", text.strip().lower())
|
| 593 |
+
normalized = normalized.removeprefix("command: ").removeprefix("response: ").strip()
|
| 594 |
+
normalized = normalized.rstrip(".!?")
|
| 595 |
+
if normalized in SUPPORTED_DIRECTIONS:
|
| 596 |
+
return "go " + normalized
|
| 597 |
+
if normalized.startswith("talk to "):
|
| 598 |
+
return "talk " + GeminiInterfaceAdapter._normalize_object_text(normalized[8:].strip())
|
| 599 |
+
if normalized.startswith("speak to "):
|
| 600 |
+
return "talk " + GeminiInterfaceAdapter._normalize_object_text(normalized[9:].strip())
|
| 601 |
+
if normalized.startswith("answer "):
|
| 602 |
+
return "submit " + normalized[7:].strip()
|
| 603 |
+
if normalized.startswith("say "):
|
| 604 |
+
return "submit " + normalized[4:].strip().strip("\"'")
|
| 605 |
+
if normalized.startswith("combine ") and " and " in normalized:
|
| 606 |
+
item_a, item_b = normalized[8:].split(" and ", 1)
|
| 607 |
+
return "combine " + GeminiInterfaceAdapter._normalize_object_text(item_a) + " with " + GeminiInterfaceAdapter._normalize_object_text(item_b)
|
| 608 |
+
if normalized.startswith("unlock ") and " with " in normalized:
|
| 609 |
+
target, key = normalized[7:].split(" with ", 1)
|
| 610 |
+
return "unlock " + GeminiInterfaceAdapter._normalize_object_text(target) + " with " + GeminiInterfaceAdapter._normalize_object_text(key)
|
| 611 |
+
if normalized.startswith("use ") and " on " in normalized:
|
| 612 |
+
item, target = normalized[4:].split(" on ", 1)
|
| 613 |
+
return "use " + GeminiInterfaceAdapter._normalize_object_text(item) + " on " + GeminiInterfaceAdapter._normalize_object_text(target)
|
| 614 |
+
if normalized.startswith("give ") and " to " in normalized:
|
| 615 |
+
item, target = normalized[5:].split(" to ", 1)
|
| 616 |
+
return "give " + GeminiInterfaceAdapter._normalize_object_text(item) + " to " + GeminiInterfaceAdapter._normalize_object_text(target)
|
| 617 |
+
if normalized.startswith("combine ") and " with " in normalized:
|
| 618 |
+
item_a, item_b = normalized[8:].split(" with ", 1)
|
| 619 |
+
return "combine " + GeminiInterfaceAdapter._normalize_object_text(item_a) + " with " + GeminiInterfaceAdapter._normalize_object_text(item_b)
|
| 620 |
+
if normalized.startswith(("open ", "read ", "talk ", "take ", "examine ")):
|
| 621 |
+
verb, obj = normalized.split(" ", 1)
|
| 622 |
+
return verb + " " + GeminiInterfaceAdapter._normalize_object_text(obj)
|
| 623 |
+
return normalized
|
| 624 |
+
|
| 625 |
+
@staticmethod
|
| 626 |
+
def _normalize_object_text(text: str) -> str:
|
| 627 |
+
object_text = GeminiInterfaceAdapter._ARTICLE_RE.sub(" ", text)
|
| 628 |
+
return re.sub(r"\s+", " ", object_text).strip().lower()
|
| 629 |
+
|
| 630 |
+
@staticmethod
|
| 631 |
+
def _is_canonical_command(command: str) -> bool:
|
| 632 |
+
if command in {"look", "inventory", "wait"}:
|
| 633 |
+
return True
|
| 634 |
+
if command.startswith("go "):
|
| 635 |
+
return command[3:] in SUPPORTED_DIRECTIONS
|
| 636 |
+
if command.startswith(("open ", "read ", "talk ", "submit ")):
|
| 637 |
+
return bool(command.split(maxsplit=1)[1].strip())
|
| 638 |
+
if command.startswith("use "):
|
| 639 |
+
return " on " in command and all(part.strip() for part in command[4:].split(" on ", 1))
|
| 640 |
+
if command.startswith("combine "):
|
| 641 |
+
return " with " in command and all(part.strip() for part in command[8:].split(" with ", 1))
|
| 642 |
+
if command.startswith("give "):
|
| 643 |
+
return " to " in command and all(part.strip() for part in command[5:].split(" to ", 1))
|
| 644 |
+
if command.startswith("take "):
|
| 645 |
+
return bool(command.split(maxsplit=1)[1].strip())
|
| 646 |
+
if command.startswith("unlock "):
|
| 647 |
+
if " with " not in command:
|
| 648 |
+
return False
|
| 649 |
+
door_text, key_text = command[7:].split(" with ", 1)
|
| 650 |
+
return bool(door_text.strip() and key_text.strip())
|
| 651 |
+
return False
|
| 652 |
+
|
| 653 |
+
@staticmethod
|
| 654 |
+
def _sanitize_command_response(text: str) -> str:
|
| 655 |
+
cleaned = text.strip().strip("`").strip().strip("\"'")
|
| 656 |
+
if not cleaned:
|
| 657 |
+
return ""
|
| 658 |
+
first_line = cleaned.splitlines()[0].strip()
|
| 659 |
+
if ":" in first_line:
|
| 660 |
+
prefix, suffix = first_line.split(":", 1)
|
| 661 |
+
if prefix.lower() in {"command", "response"}:
|
| 662 |
+
first_line = suffix.strip()
|
| 663 |
+
return re.sub(r"\s+", " ", first_line).strip().lower()
|
| 664 |
+
|
| 665 |
+
@staticmethod
|
| 666 |
+
def _sanitize_multiline_response(text: str) -> str:
|
| 667 |
+
cleaned = GeminiInterfaceAdapter._sanitize_json_response(text)
|
| 668 |
+
if not cleaned:
|
| 669 |
+
return ""
|
| 670 |
+
lines: list[str] = []
|
| 671 |
+
blank_run = 0
|
| 672 |
+
for raw_line in cleaned.splitlines():
|
| 673 |
+
line = raw_line.strip()
|
| 674 |
+
if not line:
|
| 675 |
+
blank_run += 1
|
| 676 |
+
if blank_run <= 1:
|
| 677 |
+
lines.append("")
|
| 678 |
+
continue
|
| 679 |
+
blank_run = 0
|
| 680 |
+
if ":" in line:
|
| 681 |
+
prefix, suffix = line.split(":", 1)
|
| 682 |
+
if prefix.lower() == "observation":
|
| 683 |
+
line = suffix.strip()
|
| 684 |
+
lines.append(line)
|
| 685 |
+
return "\n".join(lines).strip().strip("\"'")
|
| 686 |
+
|
| 687 |
+
@staticmethod
|
| 688 |
+
def _sanitize_json_response(text: str) -> str:
|
| 689 |
+
cleaned = text.strip()
|
| 690 |
+
if cleaned.startswith("```"):
|
| 691 |
+
cleaned = re.sub(r"^```(?:json|text)?\s*", "", cleaned)
|
| 692 |
+
cleaned = re.sub(r"\s*```$", "", cleaned)
|
| 693 |
+
return cleaned.strip()
|
| 694 |
+
|
| 695 |
+
@staticmethod
|
| 696 |
+
def _should_preserve_feedback(feedback: str, state: GameState | None) -> bool:
|
| 697 |
+
if '"' in feedback or "'" in feedback:
|
| 698 |
+
return True
|
| 699 |
+
if state is not None and (state.last_command or "").startswith("read"):
|
| 700 |
+
return True
|
| 701 |
+
return False
|
| 702 |
+
|
| 703 |
+
@staticmethod
|
| 704 |
+
def _resolve_admissible_command(candidate: str, admissible: set[str]) -> str | None:
|
| 705 |
+
if candidate in admissible:
|
| 706 |
+
return candidate
|
| 707 |
+
if " " not in candidate:
|
| 708 |
+
return None
|
| 709 |
+
verb, remainder = candidate.split(" ", 1)
|
| 710 |
+
candidate_tokens = [token for token in re.split(r"\s+", remainder) if token and token not in {"from", "with", "on", "to"}]
|
| 711 |
+
matches: list[tuple[int, str]] = []
|
| 712 |
+
for option in admissible:
|
| 713 |
+
if not option.startswith(verb + " "):
|
| 714 |
+
continue
|
| 715 |
+
option_tokens = [token for token in re.split(r"\s+", option[len(verb) + 1 :]) if token and token not in {"from", "with", "on", "to"}]
|
| 716 |
+
if candidate_tokens and all(token in option_tokens for token in candidate_tokens):
|
| 717 |
+
matches.append((len(option_tokens), option))
|
| 718 |
+
if not matches:
|
| 719 |
+
return None
|
| 720 |
+
matches.sort(key=lambda item: (item[0], item[1]))
|
| 721 |
+
return matches[0][1]
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def sanitize_feedback_text(feedback: str) -> str:
|
| 725 |
+
lines = feedback.replace("\r\n", "\n").splitlines()
|
| 726 |
+
cleaned_lines: list[str] = []
|
| 727 |
+
for raw_line in lines:
|
| 728 |
+
line = raw_line.rstrip()
|
| 729 |
+
stripped = line.strip()
|
| 730 |
+
if not stripped:
|
| 731 |
+
cleaned_lines.append("")
|
| 732 |
+
continue
|
| 733 |
+
if _TEXTWORLD_PROMPT_LINE_RE.match(line):
|
| 734 |
+
continue
|
| 735 |
+
if stripped.startswith(">"):
|
| 736 |
+
continue
|
| 737 |
+
if _TEXTWORLD_META_LINE_RE.match(stripped):
|
| 738 |
+
continue
|
| 739 |
+
room_match = _TEXTWORLD_ROOM_HEADER_RE.match(stripped)
|
| 740 |
+
if room_match:
|
| 741 |
+
cleaned_lines.append(f"Location: {room_match.group('label').strip()}")
|
| 742 |
+
continue
|
| 743 |
+
if _is_probable_banner_line(stripped):
|
| 744 |
+
continue
|
| 745 |
+
cleaned_lines.append(stripped)
|
| 746 |
+
|
| 747 |
+
start_index = 0
|
| 748 |
+
for index, line in enumerate(cleaned_lines):
|
| 749 |
+
stripped = line.strip()
|
| 750 |
+
if not stripped:
|
| 751 |
+
continue
|
| 752 |
+
if stripped.startswith("Explore ") or stripped.startswith("Location: ") or not _is_probable_banner_line(stripped):
|
| 753 |
+
start_index = index
|
| 754 |
+
break
|
| 755 |
+
useful_lines = cleaned_lines[start_index:]
|
| 756 |
+
|
| 757 |
+
collapsed: list[str] = []
|
| 758 |
+
blank_run = 0
|
| 759 |
+
for line in useful_lines:
|
| 760 |
+
stripped = line.strip()
|
| 761 |
+
if not stripped:
|
| 762 |
+
blank_run += 1
|
| 763 |
+
if blank_run <= 1:
|
| 764 |
+
collapsed.append("")
|
| 765 |
+
continue
|
| 766 |
+
blank_run = 0
|
| 767 |
+
collapsed.append(stripped)
|
| 768 |
+
return "\n".join(collapsed).strip()
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
def enrich_feedback_text(feedback: str, session: EpisodeSession) -> str:
|
| 772 |
+
supplement_lines = _observation_context_lines(session)
|
| 773 |
+
if not supplement_lines:
|
| 774 |
+
return feedback.strip()
|
| 775 |
+
merged: list[str] = []
|
| 776 |
+
base = feedback.strip()
|
| 777 |
+
if base:
|
| 778 |
+
merged.append(base)
|
| 779 |
+
for line in supplement_lines:
|
| 780 |
+
if line not in base:
|
| 781 |
+
merged.append(line)
|
| 782 |
+
return "\n\n".join(merged).strip()
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
def _observation_context_lines(session: EpisodeSession) -> list[str]:
|
| 786 |
+
visible = _visible_entities(session)
|
| 787 |
+
exits = sorted(command[3:] for command in session.available_commands() if command.startswith("go "))
|
| 788 |
+
lines: list[str] = []
|
| 789 |
+
if visible:
|
| 790 |
+
lines.append("Visible here: " + ", ".join(visible))
|
| 791 |
+
if exits:
|
| 792 |
+
lines.append("Exits: " + ", ".join(exits))
|
| 793 |
+
return lines
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
def _visible_entities(session: EpisodeSession) -> list[str]:
|
| 797 |
+
visible: list[str] = []
|
| 798 |
+
seen: set[str] = set()
|
| 799 |
+
for node in session.compiled.world.nodes:
|
| 800 |
+
if getattr(node, "parent_id", None) != session.current_room_id:
|
| 801 |
+
continue
|
| 802 |
+
if node.type == "readable" and node.id not in session.revealed_readables:
|
| 803 |
+
continue
|
| 804 |
+
name = session.compiled.node_command_names.get(node.id)
|
| 805 |
+
if name and name not in seen:
|
| 806 |
+
seen.add(name)
|
| 807 |
+
visible.append(name)
|
| 808 |
+
for edge in session.compiled.world.edges:
|
| 809 |
+
if edge.from_node_id != session.current_room_id or not edge.door_node_id:
|
| 810 |
+
continue
|
| 811 |
+
name = session.compiled.node_command_names.get(edge.door_node_id)
|
| 812 |
+
if name and name not in seen:
|
| 813 |
+
seen.add(name)
|
| 814 |
+
visible.append(name)
|
| 815 |
+
for item in session.compiled.world.items:
|
| 816 |
+
if session.item_locations.get(item.id) != session.current_room_id:
|
| 817 |
+
continue
|
| 818 |
+
name = session.compiled.item_command_names.get(item.id)
|
| 819 |
+
if name and name not in seen:
|
| 820 |
+
seen.add(name)
|
| 821 |
+
visible.append(name)
|
| 822 |
+
return visible
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
def _is_probable_banner_line(line: str) -> bool:
|
| 826 |
+
if len(line) < 12:
|
| 827 |
+
return False
|
| 828 |
+
if line.startswith("Explore ") or line.startswith("Location: "):
|
| 829 |
+
return False
|
| 830 |
+
banner_chars = len(_TEXTWORLD_BANNER_CHAR_RE.findall(line))
|
| 831 |
+
return banner_chars >= max(4, len(line) // 6)
|
agents/master/logic.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import shutil
|
| 5 |
+
import textwrap
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import textworld
|
| 9 |
+
from textworld.core import EnvInfos
|
| 10 |
+
from textworld.generator.data import LOGIC_DATA_PATH, TEXT_GRAMMARS_PATH
|
| 11 |
+
|
| 12 |
+
from .base import (
|
| 13 |
+
CUSTOM_GRAMMAR_DIR,
|
| 14 |
+
CUSTOM_LOGIC_DIR,
|
| 15 |
+
normalize_answer_text,
|
| 16 |
+
suppress_unsupported_game_warning,
|
| 17 |
+
)
|
| 18 |
+
from .schema import WorldDefinition
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_logic_dir(artifacts_dir: Path, world: WorldDefinition) -> Path:
|
| 22 |
+
logic_dir = artifacts_dir / "kb_logic"
|
| 23 |
+
logic_dir.mkdir(parents=True, exist_ok=True)
|
| 24 |
+
overrides = {path.name for path in CUSTOM_LOGIC_DIR.glob("*.twl")}
|
| 25 |
+
for builtin in Path(LOGIC_DATA_PATH).glob("*.twl"):
|
| 26 |
+
if builtin.name not in overrides:
|
| 27 |
+
shutil.copy(builtin, logic_dir / builtin.name)
|
| 28 |
+
for custom in CUSTOM_LOGIC_DIR.glob("*.twl"):
|
| 29 |
+
shutil.copy(custom, logic_dir / custom.name)
|
| 30 |
+
(logic_dir / "world_submit_overlay.twl").write_text(submission_overlay(world), encoding="utf-8")
|
| 31 |
+
return logic_dir
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def build_grammar_dir(artifacts_dir: Path) -> Path:
|
| 35 |
+
grammar_dir = artifacts_dir / "kb_grammar"
|
| 36 |
+
grammar_dir.mkdir(parents=True, exist_ok=True)
|
| 37 |
+
overrides = {path.name for path in CUSTOM_GRAMMAR_DIR.glob("*.twg")}
|
| 38 |
+
for builtin in Path(TEXT_GRAMMARS_PATH).glob("*.twg"):
|
| 39 |
+
if builtin.name not in overrides:
|
| 40 |
+
shutil.copy(builtin, grammar_dir / builtin.name)
|
| 41 |
+
for custom in CUSTOM_GRAMMAR_DIR.glob("*.twg"):
|
| 42 |
+
shutil.copy(custom, grammar_dir / custom.name)
|
| 43 |
+
return grammar_dir
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def submit_command_text(world: WorldDefinition) -> str:
|
| 47 |
+
return "submit " + normalize_answer_text(world.meta.win_condition.answer_string)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def submission_overlay(world: WorldDefinition) -> str:
|
| 51 |
+
answer = submit_command_text(world).replace('"', '\\"')
|
| 52 |
+
return textwrap.dedent(
|
| 53 |
+
f'''
|
| 54 |
+
type submission {{
|
| 55 |
+
rules {{
|
| 56 |
+
submit/final :: $at(P, r) & $at(npc, r) & $guardian(npc) & $consulted(npc) & $correct(answer, npc) -> solved(answer);
|
| 57 |
+
}}
|
| 58 |
+
reverse_rules {{
|
| 59 |
+
submit/final :: submit/final;
|
| 60 |
+
}}
|
| 61 |
+
inform7 {{
|
| 62 |
+
commands {{
|
| 63 |
+
submit/final :: "{answer}" :: "taking inventory";
|
| 64 |
+
}}
|
| 65 |
+
code :: """
|
| 66 |
+
Understand "{answer}" as taking inventory.
|
| 67 |
+
After taking inventory:
|
| 68 |
+
if the player's command matches the text "{answer}":
|
| 69 |
+
repeat with candidate running through answer-likes:
|
| 70 |
+
now candidate is solved;
|
| 71 |
+
""";
|
| 72 |
+
}}
|
| 73 |
+
}}
|
| 74 |
+
'''
|
| 75 |
+
).strip() + "\n"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def write_artifacts(artifacts_dir: Path, world: WorldDefinition, walkthrough_commands: list[str]) -> None:
|
| 79 |
+
(artifacts_dir / "world_definition.normalized.json").write_text(world.model_dump_json(indent=2), encoding="utf-8")
|
| 80 |
+
(artifacts_dir / "walkthrough.json").write_text(json.dumps(walkthrough_commands, indent=2), encoding="utf-8")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def solver_policy(game_file: str) -> list[str]:
|
| 84 |
+
with suppress_unsupported_game_warning():
|
| 85 |
+
env = textworld.start(game_file, request_infos=EnvInfos(policy_commands=True, extras=["walkthrough"]))
|
| 86 |
+
try:
|
| 87 |
+
state = env.reset()
|
| 88 |
+
finally:
|
| 89 |
+
close = getattr(env, "close", None)
|
| 90 |
+
if callable(close):
|
| 91 |
+
close()
|
| 92 |
+
return list(state.policy_commands or state.get("extra.walkthrough") or [])
|
agents/master/main.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from .base import DMCompileError, DMInterfaceError
|
| 9 |
+
from .env import DMEnvironment
|
| 10 |
+
from .interface import DEFAULT_GEMINI_MODEL, GeminiInterfaceAdapter, SimpleInterfaceAdapter
|
| 11 |
+
from .play import ManualRunner, RandomAdmissibleRunner, WalkthroughRunner
|
| 12 |
+
from .sample import load_world, sample_world_definition
|
| 13 |
+
from .server import run_server
|
| 14 |
+
from .snapshots import DEFAULT_LIVE_DIR, LiveSnapshotWriter
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main(argv: list[str] | None = None) -> int:
|
| 18 |
+
parser = argparse.ArgumentParser(description="Dungeon DM environment harness")
|
| 19 |
+
parser.add_argument("mode", choices=["validate", "play", "sample", "serve"], help="What to do.")
|
| 20 |
+
parser.add_argument("world", nargs="?", help="Path to a world-definition JSON file.")
|
| 21 |
+
parser.add_argument("--runner", choices=["walkthrough", "random", "manual"], default="walkthrough")
|
| 22 |
+
parser.add_argument("--interface", choices=["simple", "gemini"], default="simple")
|
| 23 |
+
parser.add_argument("--model", default=DEFAULT_GEMINI_MODEL)
|
| 24 |
+
parser.add_argument("--narrate", action="store_true", help="Narrate observations through Gemini.")
|
| 25 |
+
parser.add_argument("--live", action="store_true", help="Write live viewer snapshots while playing.")
|
| 26 |
+
parser.add_argument("--live-dir", type=Path, default=DEFAULT_LIVE_DIR)
|
| 27 |
+
parser.add_argument("--port", type=int, default=8000)
|
| 28 |
+
args = parser.parse_args(argv)
|
| 29 |
+
if args.mode == "serve":
|
| 30 |
+
run_server(port=args.port, live_dir=args.live_dir)
|
| 31 |
+
return 0
|
| 32 |
+
|
| 33 |
+
if args.mode == "sample":
|
| 34 |
+
print(json.dumps(sample_world_definition(), indent=2))
|
| 35 |
+
return 0
|
| 36 |
+
if not args.world:
|
| 37 |
+
parser.error("A world-definition JSON file is required for validate/play.")
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
adapter = SimpleInterfaceAdapter()
|
| 41 |
+
if args.interface == "gemini":
|
| 42 |
+
adapter = GeminiInterfaceAdapter(model=args.model, narrate_observations=args.narrate)
|
| 43 |
+
env = DMEnvironment(interface_adapter=adapter)
|
| 44 |
+
world = load_world(args.world)
|
| 45 |
+
if args.mode == "validate":
|
| 46 |
+
compiled = env.compile_world(world)
|
| 47 |
+
print(f"Compiled successfully: {compiled.game_file}")
|
| 48 |
+
print(f"Solver policy: {compiled.solver_policy}")
|
| 49 |
+
return 0
|
| 50 |
+
|
| 51 |
+
runner = {"manual": ManualRunner(), "random": RandomAdmissibleRunner(), "walkthrough": WalkthroughRunner()}[
|
| 52 |
+
args.runner
|
| 53 |
+
]
|
| 54 |
+
observer = LiveSnapshotWriter(live_dir=args.live_dir, runner_name=args.runner) if args.live else None
|
| 55 |
+
result = env.play(world, runner=runner, observer=observer)
|
| 56 |
+
if result.observation.compile_error is not None:
|
| 57 |
+
print(result.observation.compile_error, file=sys.stderr)
|
| 58 |
+
return 1
|
| 59 |
+
print(
|
| 60 |
+
json.dumps(
|
| 61 |
+
{
|
| 62 |
+
"reward": result.reward,
|
| 63 |
+
"done": result.done,
|
| 64 |
+
"observation": result.observation.model_dump(),
|
| 65 |
+
},
|
| 66 |
+
indent=2,
|
| 67 |
+
)
|
| 68 |
+
)
|
| 69 |
+
return 0
|
| 70 |
+
except (DMCompileError, DMInterfaceError, ValueError) as exc:
|
| 71 |
+
print(str(exc), file=sys.stderr)
|
| 72 |
+
return 1
|
agents/master/play.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from typing import Iterable, Protocol, TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
from .base import DMInterfaceError
|
| 7 |
+
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
from .session import EpisodeSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class EpisodeRunner(Protocol):
|
| 13 |
+
def run(self, session: EpisodeSession, max_steps: int) -> None:
|
| 14 |
+
...
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class WalkthroughRunner:
|
| 18 |
+
def __init__(self, commands: Iterable[str] | None = None) -> None:
|
| 19 |
+
self._commands = list(commands) if commands is not None else None
|
| 20 |
+
|
| 21 |
+
def run(self, session: EpisodeSession, max_steps: int) -> None:
|
| 22 |
+
commands = list(self._commands or session.compiled.solver_policy)
|
| 23 |
+
for command in commands:
|
| 24 |
+
if session.done or session.steps_taken >= max_steps:
|
| 25 |
+
return
|
| 26 |
+
session.step(command)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CommandSequenceRunner:
|
| 30 |
+
def __init__(self, commands: Iterable[str]) -> None:
|
| 31 |
+
self._commands = list(commands)
|
| 32 |
+
|
| 33 |
+
def run(self, session: EpisodeSession, max_steps: int) -> None:
|
| 34 |
+
for command in self._commands:
|
| 35 |
+
if session.done or session.steps_taken >= max_steps:
|
| 36 |
+
return
|
| 37 |
+
session.step(command)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class RandomAdmissibleRunner:
|
| 41 |
+
def __init__(self, seed: int | None = None) -> None:
|
| 42 |
+
self._rng = random.Random(seed)
|
| 43 |
+
|
| 44 |
+
def run(self, session: EpisodeSession, max_steps: int) -> None:
|
| 45 |
+
while not session.done and session.steps_taken < max_steps:
|
| 46 |
+
options = session.available_commands()
|
| 47 |
+
if not options:
|
| 48 |
+
return
|
| 49 |
+
session.step(self._rng.choice(options))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ManualRunner:
|
| 53 |
+
def run(self, session: EpisodeSession, max_steps: int) -> None:
|
| 54 |
+
print(session.current_feedback())
|
| 55 |
+
while not session.done and session.steps_taken < max_steps:
|
| 56 |
+
print()
|
| 57 |
+
print(f"Step {session.steps_taken + 1}/{max_steps}")
|
| 58 |
+
command = input("> ").strip()
|
| 59 |
+
if command in {"quit", "exit"}:
|
| 60 |
+
return
|
| 61 |
+
try:
|
| 62 |
+
turn = session.step(command)
|
| 63 |
+
except DMInterfaceError:
|
| 64 |
+
print("I'm not sure what you mean. Try rephrasing that command.")
|
| 65 |
+
if session.available_commands():
|
| 66 |
+
print("Admissible:", ", ".join(session.available_commands()))
|
| 67 |
+
continue
|
| 68 |
+
print(turn.observation)
|
| 69 |
+
if session.available_commands():
|
| 70 |
+
print("Admissible:", ", ".join(session.available_commands()))
|
agents/master/policy.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Protocol
|
| 4 |
+
|
| 5 |
+
from pydantic import Field
|
| 6 |
+
|
| 7 |
+
from agents.shared.llm_client import StructuredModelClient
|
| 8 |
+
from agents.shared.model_schema import StrictModel
|
| 9 |
+
|
| 10 |
+
from .schema import WorldDefinition
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DungeonMasterPolicyError(RuntimeError):
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DungeonMasterPolicy(Protocol):
|
| 18 |
+
def generate_world(
|
| 19 |
+
self,
|
| 20 |
+
*,
|
| 21 |
+
target_ratio: float,
|
| 22 |
+
repair_context: "DMRepairContext | None" = None,
|
| 23 |
+
) -> WorldDefinition:
|
| 24 |
+
...
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DMRepairContext(StrictModel):
|
| 28 |
+
attempt_number: int
|
| 29 |
+
error_message: str
|
| 30 |
+
previous_candidate_json: str | None = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class WinConditionCandidate(StrictModel):
|
| 34 |
+
type: str
|
| 35 |
+
target_npc_id: str
|
| 36 |
+
answer_string: str
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class WorldMetaCandidate(StrictModel):
|
| 40 |
+
title: str
|
| 41 |
+
difficulty_target: float
|
| 42 |
+
start_node_id: str
|
| 43 |
+
win_condition: WinConditionCandidate
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class WorldNodeCandidate(StrictModel):
|
| 47 |
+
id: str
|
| 48 |
+
type: str
|
| 49 |
+
label: str
|
| 50 |
+
description: str
|
| 51 |
+
parent_id: str | None = None
|
| 52 |
+
open: bool | None = None
|
| 53 |
+
locked: bool | None = None
|
| 54 |
+
lock_key_id: str | None = None
|
| 55 |
+
clue_id: str | None = None
|
| 56 |
+
requires_item_id: str | None = None
|
| 57 |
+
consumes_item: bool | None = None
|
| 58 |
+
text_content: str | None = None
|
| 59 |
+
reveals_item_id: str | None = None
|
| 60 |
+
reveals_readable_id: str | None = None
|
| 61 |
+
gives_item_id: str | None = None
|
| 62 |
+
gives_clue_id: str | None = None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class EdgeCandidate(StrictModel):
|
| 66 |
+
id: str
|
| 67 |
+
from_node_id: str
|
| 68 |
+
to_node_id: str
|
| 69 |
+
direction: str
|
| 70 |
+
type: str
|
| 71 |
+
required_item_id: str | None = None
|
| 72 |
+
door_node_id: str | None = None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ItemCandidate(StrictModel):
|
| 76 |
+
id: str
|
| 77 |
+
label: str
|
| 78 |
+
description: str
|
| 79 |
+
subtype: str
|
| 80 |
+
start_node_id: str | None = None
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ClueCandidate(StrictModel):
|
| 84 |
+
id: str
|
| 85 |
+
text: str
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class RecipeCandidate(StrictModel):
|
| 89 |
+
id: str
|
| 90 |
+
input_item_ids: list[str]
|
| 91 |
+
output_item_id: str
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class QuestStepCandidate(StrictModel):
|
| 95 |
+
step_id: str
|
| 96 |
+
description: str
|
| 97 |
+
requires_step_ids: list[str] = Field(default_factory=list)
|
| 98 |
+
action: str
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class WorldDefinitionCandidate(StrictModel):
|
| 102 |
+
meta: WorldMetaCandidate
|
| 103 |
+
nodes: list[WorldNodeCandidate]
|
| 104 |
+
edges: list[EdgeCandidate]
|
| 105 |
+
items: list[ItemCandidate]
|
| 106 |
+
clues: list[ClueCandidate]
|
| 107 |
+
recipes: list[RecipeCandidate] = Field(default_factory=list)
|
| 108 |
+
quest_chain: list[QuestStepCandidate]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class DungeonMasterLLMPolicy:
|
| 112 |
+
def __init__(
|
| 113 |
+
self,
|
| 114 |
+
client: StructuredModelClient,
|
| 115 |
+
*,
|
| 116 |
+
model_name: str,
|
| 117 |
+
temperature: float = 0.0,
|
| 118 |
+
max_output_tokens: int = 8192,
|
| 119 |
+
) -> None:
|
| 120 |
+
self.client = client
|
| 121 |
+
self.model_name = model_name
|
| 122 |
+
self.temperature = temperature
|
| 123 |
+
self.max_output_tokens = max_output_tokens
|
| 124 |
+
|
| 125 |
+
def generate_world(
|
| 126 |
+
self,
|
| 127 |
+
*,
|
| 128 |
+
target_ratio: float,
|
| 129 |
+
repair_context: DMRepairContext | None = None,
|
| 130 |
+
) -> WorldDefinition:
|
| 131 |
+
from .prompt import build_dm_world_messages
|
| 132 |
+
|
| 133 |
+
try:
|
| 134 |
+
candidate = self.client.generate_structured(
|
| 135 |
+
build_dm_world_messages(target_ratio=target_ratio, repair_context=repair_context),
|
| 136 |
+
WorldDefinitionCandidate,
|
| 137 |
+
model_name=self.model_name,
|
| 138 |
+
temperature=self.temperature,
|
| 139 |
+
max_output_tokens=self.max_output_tokens,
|
| 140 |
+
)
|
| 141 |
+
return WorldDefinition.model_validate(candidate.model_dump(mode="json", exclude_none=True))
|
| 142 |
+
except Exception as exc:
|
| 143 |
+
raise DungeonMasterPolicyError(self._normalize_error(exc)) from exc
|
| 144 |
+
|
| 145 |
+
@staticmethod
|
| 146 |
+
def _normalize_error(exc: Exception) -> str:
|
| 147 |
+
return " ".join(str(exc).split()) or exc.__class__.__name__
|
agents/master/prompt.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
from agents.shared.model_schema import ModelMessage
|
| 7 |
+
|
| 8 |
+
from .sample import sample_world_definition
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from .policy import DMRepairContext
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
DM_WORLD_SYSTEM_PROMPT = """You are the dungeon master policy for a structured text adventure generator.
|
| 15 |
+
|
| 16 |
+
Return exactly one valid WorldDefinition JSON object as minified JSON on a single line.
|
| 17 |
+
Do not use markdown fences, indentation, comments, or extra prose.
|
| 18 |
+
|
| 19 |
+
World requirements:
|
| 20 |
+
- Build a fair, solvable mystery dungeon with 4 to 6 rooms.
|
| 21 |
+
- Use only the supported schema fields and node types.
|
| 22 |
+
- Use ids in snake_case.
|
| 23 |
+
- Set meta.difficulty_target equal to the requested target ratio.
|
| 24 |
+
- The win condition must be deduce with a short lowercase answer string.
|
| 25 |
+
- The final answer must never be leaked directly in clue text.
|
| 26 |
+
- The world must be mechanically consistent: all references must point to real ids and every puzzle chain must be completable.
|
| 27 |
+
- Do not add unsupported fields to node variants. In particular, location, junction, and door nodes must not include `parent_id`.
|
| 28 |
+
- Every readable must include `text_content`.
|
| 29 |
+
- Keep the world compact enough to fit in one response: short labels, short descriptions, and a concise quest chain.
|
| 30 |
+
|
| 31 |
+
Supported mechanics:
|
| 32 |
+
- Containers and doors can be opened.
|
| 33 |
+
- Locked doors require a real key item.
|
| 34 |
+
- Readables may require an item before they become legible.
|
| 35 |
+
- Fixtures may reveal an item or a readable after a correct use action.
|
| 36 |
+
- NPCs may trade one required item for one item or one clue.
|
| 37 |
+
- Recipes combine exactly two items into one output item.
|
| 38 |
+
- Navigation uses only passage and locked_passage edges.
|
| 39 |
+
|
| 40 |
+
Quest-chain rules:
|
| 41 |
+
- Every quest action must be one of:
|
| 42 |
+
open(node_id)
|
| 43 |
+
take(item_id,source_node_id)
|
| 44 |
+
unlock(door_id,key_id)
|
| 45 |
+
go(room_id)
|
| 46 |
+
read(readable_id)
|
| 47 |
+
use(item_id,target_node_id)
|
| 48 |
+
combine(item_a_id,item_b_id)
|
| 49 |
+
give(item_id,npc_id)
|
| 50 |
+
talk(npc_id)
|
| 51 |
+
submit("answer")
|
| 52 |
+
- Do not invent unsupported actions such as inspect(), search(), solve(), explore(), or win().
|
| 53 |
+
- The quest chain must be topologically valid and correspond to a real solvable playthrough.
|
| 54 |
+
- Every quest step object must use exactly these keys: step_id, description, requires_step_ids, action.
|
| 55 |
+
- Use requires_step_ids (plural) even for one dependency. Never use requires_step_id.
|
| 56 |
+
- Include exactly 3 clues that narrow the answer without stating it directly.
|
| 57 |
+
- Include a guardian NPC for the final submission.
|
| 58 |
+
- Every clue id in clues[] must have exactly one real source: either one readable.clue_id or one non-guardian npc.gives_clue_id.
|
| 59 |
+
- Do not include unused clue ids and do not leave readables without clue_id.
|
| 60 |
+
- Clue text and readable text must never contain the exact answer_string.
|
| 61 |
+
- Every room-to-room connection must include the reverse edge explicitly.
|
| 62 |
+
- Every locked_passage pair must reference a real door node id that already exists in nodes[].
|
| 63 |
+
- Any item used by required_item_id or lock_key_id must have subtype key.
|
| 64 |
+
- Keep descriptions and clue texts short, concrete, and under 14 words when possible.
|
| 65 |
+
- Prefer 4 rooms, 9 to 11 nodes, 4 to 5 items, 3 clues, 0 recipes, and 6 to 9 quest steps.
|
| 66 |
+
- Use the shortest valid quest chain that still supports the target difficulty.
|
| 67 |
+
- meta must include title, difficulty_target, start_node_id, and win_condition.
|
| 68 |
+
- item objects use subtype, never type.
|
| 69 |
+
- clue objects use id and text, never clue_id.
|
| 70 |
+
- fixture objects use reveals_item_id or reveals_readable_id.
|
| 71 |
+
- NPC trade objects use requires_item_id plus gives_item_id or gives_clue_id.
|
| 72 |
+
|
| 73 |
+
Reliability matters more than novelty. Stay close to the reference world's mechanical bundle unless repair feedback requires a different fix.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
_DM_WORLD_USER_PROMPTS = (
|
| 77 |
+
(
|
| 78 |
+
"Generate one full WorldDefinition JSON object as minified one-line JSON.\n"
|
| 79 |
+
"Requested target ratio: {target_ratio}\n\n"
|
| 80 |
+
"Hard output requirements:\n"
|
| 81 |
+
"- Required top-level fields: meta, nodes, edges, items, clues, recipes, quest_chain.\n"
|
| 82 |
+
"- Supported node types: location, junction, container, door, readable, fixture, npc.\n"
|
| 83 |
+
"- Supported edge types: passage, locked_passage.\n"
|
| 84 |
+
"- Supported item subtypes: key, puzzle.\n"
|
| 85 |
+
"- Every locked_passage must reference a real door_node_id and a real required_item_id.\n"
|
| 86 |
+
"- Every locked door must have a matching lock_key_id.\n"
|
| 87 |
+
"- Every fixture must have requires_item_id and reveal at most one item or one readable.\n"
|
| 88 |
+
"- Location, junction, and door nodes must not include parent_id.\n"
|
| 89 |
+
"- Every readable must include text_content.\n"
|
| 90 |
+
"- Every non-guardian NPC trade must require a real item.\n"
|
| 91 |
+
"- Use 6 to 9 quest steps unless a shorter valid chain is clearly enough.\n"
|
| 92 |
+
"- meta must include title, start_node_id, and win_condition.\n"
|
| 93 |
+
"- items use subtype, not type.\n"
|
| 94 |
+
"- clues use id, not clue_id.\n"
|
| 95 |
+
"- every clue id must have exactly one readable or non-guardian npc source.\n"
|
| 96 |
+
"- fixtures use reveals_item_id or reveals_readable_id.\n"
|
| 97 |
+
"- NPC trades use requires_item_id plus gives_item_id or gives_clue_id.\n"
|
| 98 |
+
"- every locked_passage must reference a real door node id and a key item.\n"
|
| 99 |
+
"- The final answer must stay implicit until the player gathers clues and speaks to the guardian.\n\n"
|
| 100 |
+
"Compact structural snippets to mimic exactly:\n"
|
| 101 |
+
"meta={meta_example_json}\n"
|
| 102 |
+
"item={item_example_json}\n"
|
| 103 |
+
"clue={clue_example_json}\n"
|
| 104 |
+
"fixture={fixture_example_json}\n"
|
| 105 |
+
"npc={npc_example_json}\n"
|
| 106 |
+
"quest_step={quest_step_example_json}\n"
|
| 107 |
+
"edge_pair={edge_pair_example_json}\n"
|
| 108 |
+
"readable={readable_example_json}\n"
|
| 109 |
+
),
|
| 110 |
+
(
|
| 111 |
+
"Produce a compact but fully valid WorldDefinition JSON object as minified one-line JSON.\n"
|
| 112 |
+
"Target difficulty ratio: {target_ratio}\n\n"
|
| 113 |
+
"Mechanical constraints:\n"
|
| 114 |
+
"- Output minified JSON only on one line.\n"
|
| 115 |
+
"- Keep the graph solvable and internally consistent.\n"
|
| 116 |
+
"- Keep all ids in snake_case and all references real.\n"
|
| 117 |
+
"- Preserve the supported node, edge, and item types exactly.\n"
|
| 118 |
+
"- Do not add unsupported fields to node variants.\n"
|
| 119 |
+
"- Every readable must include text_content.\n"
|
| 120 |
+
"- The world must require clue gathering before the guardian submission.\n"
|
| 121 |
+
"- Use exactly 3 clues.\n"
|
| 122 |
+
"- Every clue id must appear exactly once in a readable.clue_id or npc.gives_clue_id.\n"
|
| 123 |
+
"- Every edge pair must include both directions explicitly.\n"
|
| 124 |
+
"- Every locked_passage must reference a real door node id already present in nodes[].\n"
|
| 125 |
+
"- Any required_item_id on a locked_passage must be a key item.\n"
|
| 126 |
+
"- Quest steps must use requires_step_ids (plural).\n\n"
|
| 127 |
+
"Exact meta example:\n{meta_example_json}\n"
|
| 128 |
+
"Exact item example:\n{item_example_json}\n"
|
| 129 |
+
"Exact clue example:\n{clue_example_json}\n"
|
| 130 |
+
"Exact fixture example:\n{fixture_example_json}\n"
|
| 131 |
+
"Exact NPC example:\n{npc_example_json}\n"
|
| 132 |
+
"Exact quest step example:\n{quest_step_example_json}\n"
|
| 133 |
+
"Exact bidirectional edge example:\n{edge_pair_example_json}\n"
|
| 134 |
+
"Exact readable example:\n{readable_example_json}\n"
|
| 135 |
+
),
|
| 136 |
+
(
|
| 137 |
+
"Return one original WorldDefinition JSON object for a mystery dungeon as minified one-line JSON.\n"
|
| 138 |
+
"Requested target ratio: {target_ratio}\n\n"
|
| 139 |
+
"Checklist:\n"
|
| 140 |
+
"- 4 to 6 rooms.\n"
|
| 141 |
+
"- 3 to 5 clues.\n"
|
| 142 |
+
"- A real guardian NPC for the final answer.\n"
|
| 143 |
+
"- A quest chain that compiles into a real walkthrough.\n"
|
| 144 |
+
"- No unsupported extra fields and no missing required fields like readable.text_content.\n"
|
| 145 |
+
"- No unsupported mechanics, no unsupported actions, no prose.\n"
|
| 146 |
+
"- Use requires_step_ids (plural), not requires_step_id.\n"
|
| 147 |
+
"- Use exactly 3 clues and explicit reverse edges.\n"
|
| 148 |
+
"- Every clue id must have exactly one source and no clue may be orphaned.\n"
|
| 149 |
+
"- Every locked_passage must use an existing door node id and a key item.\n"
|
| 150 |
+
"- Prefer 6 to 9 quest steps, not long walkthroughs.\n\n"
|
| 151 |
+
"Mini schema examples:\n"
|
| 152 |
+
"meta={meta_example_json}\n"
|
| 153 |
+
"item={item_example_json}\n"
|
| 154 |
+
"clue={clue_example_json}\n"
|
| 155 |
+
"fixture={fixture_example_json}\n"
|
| 156 |
+
"npc={npc_example_json}\n"
|
| 157 |
+
"quest_step={quest_step_example_json}\n"
|
| 158 |
+
"edge_pair={edge_pair_example_json}\n"
|
| 159 |
+
"readable={readable_example_json}\n"
|
| 160 |
+
),
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
_DM_META_EXAMPLE = {
|
| 164 |
+
"title": "The Ember Vault",
|
| 165 |
+
"difficulty_target": 1.75,
|
| 166 |
+
"start_node_id": "foyer",
|
| 167 |
+
"win_condition": {
|
| 168 |
+
"type": "deduce",
|
| 169 |
+
"target_npc_id": "stone_guardian",
|
| 170 |
+
"answer_string": "vesna",
|
| 171 |
+
},
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
_DM_ITEM_EXAMPLE = {
|
| 175 |
+
"id": "brass_key",
|
| 176 |
+
"subtype": "key",
|
| 177 |
+
"start_node_id": "entry_chest",
|
| 178 |
+
"label": "Brass Key",
|
| 179 |
+
"description": "short key description",
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
_DM_CLUE_EXAMPLE = {
|
| 183 |
+
"id": "initial_clue",
|
| 184 |
+
"text": "short clue text",
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
_DM_FIXTURE_EXAMPLE = {
|
| 188 |
+
"id": "stone_well",
|
| 189 |
+
"type": "fixture",
|
| 190 |
+
"parent_id": "courtyard",
|
| 191 |
+
"requires_item_id": "full_map",
|
| 192 |
+
"consumes_item": False,
|
| 193 |
+
"reveals_item_id": None,
|
| 194 |
+
"reveals_readable_id": "water_plaque",
|
| 195 |
+
"label": "Stone Well",
|
| 196 |
+
"description": "short fixture description",
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
_DM_NPC_EXAMPLE = {
|
| 200 |
+
"id": "cartographer",
|
| 201 |
+
"type": "npc",
|
| 202 |
+
"parent_id": "gallery",
|
| 203 |
+
"requires_item_id": "full_map",
|
| 204 |
+
"gives_item_id": "lens",
|
| 205 |
+
"gives_clue_id": None,
|
| 206 |
+
"label": "Cartographer",
|
| 207 |
+
"description": "short npc description",
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
_DM_QUEST_STEP_EXAMPLE = {
|
| 211 |
+
"step_id": "open_entry_chest",
|
| 212 |
+
"description": "open the chest",
|
| 213 |
+
"requires_step_ids": [],
|
| 214 |
+
"action": "open(entry_chest)",
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
_DM_EDGE_PAIR_EXAMPLE = [
|
| 218 |
+
{
|
| 219 |
+
"id": "foyer_east",
|
| 220 |
+
"from_node_id": "foyer",
|
| 221 |
+
"to_node_id": "workshop",
|
| 222 |
+
"direction": "east",
|
| 223 |
+
"type": "locked_passage",
|
| 224 |
+
"required_item_id": "brass_key",
|
| 225 |
+
"door_node_id": "iron_door",
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"id": "workshop_west",
|
| 229 |
+
"from_node_id": "workshop",
|
| 230 |
+
"to_node_id": "foyer",
|
| 231 |
+
"direction": "west",
|
| 232 |
+
"type": "locked_passage",
|
| 233 |
+
"required_item_id": "brass_key",
|
| 234 |
+
"door_node_id": "iron_door",
|
| 235 |
+
},
|
| 236 |
+
]
|
| 237 |
+
|
| 238 |
+
_DM_READABLE_EXAMPLE = {
|
| 239 |
+
"id": "ash_mural",
|
| 240 |
+
"type": "readable",
|
| 241 |
+
"parent_id": "workshop",
|
| 242 |
+
"clue_id": "initial_clue",
|
| 243 |
+
"requires_item_id": "torch",
|
| 244 |
+
"consumes_item": False,
|
| 245 |
+
"label": "Ash Mural",
|
| 246 |
+
"description": "short readable description",
|
| 247 |
+
"text_content": "short readable text",
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _compress_reference_world_for_prompt(reference_world: dict[str, Any]) -> dict[str, Any]:
|
| 252 |
+
return {
|
| 253 |
+
"meta": reference_world.get("meta", {}),
|
| 254 |
+
"nodes": [
|
| 255 |
+
_compress_world_node(node)
|
| 256 |
+
for node in reference_world.get("nodes", [])
|
| 257 |
+
if isinstance(node, dict)
|
| 258 |
+
],
|
| 259 |
+
"edges": [
|
| 260 |
+
{
|
| 261 |
+
key: edge[key]
|
| 262 |
+
for key in ("id", "from_node_id", "to_node_id", "direction", "type", "required_item_id", "door_node_id")
|
| 263 |
+
if key in edge
|
| 264 |
+
}
|
| 265 |
+
for edge in reference_world.get("edges", [])
|
| 266 |
+
if isinstance(edge, dict)
|
| 267 |
+
],
|
| 268 |
+
"items": [
|
| 269 |
+
{
|
| 270 |
+
**{
|
| 271 |
+
key: item[key]
|
| 272 |
+
for key in ("id", "subtype", "start_node_id")
|
| 273 |
+
if key in item
|
| 274 |
+
},
|
| 275 |
+
"label": str(item.get("label") or item.get("id") or "item"),
|
| 276 |
+
"description": "short item description",
|
| 277 |
+
}
|
| 278 |
+
for item in reference_world.get("items", [])
|
| 279 |
+
if isinstance(item, dict)
|
| 280 |
+
],
|
| 281 |
+
"clues": [
|
| 282 |
+
{"id": clue["id"], "text": "short clue text"}
|
| 283 |
+
for clue in reference_world.get("clues", [])
|
| 284 |
+
if isinstance(clue, dict) and "id" in clue
|
| 285 |
+
],
|
| 286 |
+
"recipes": [
|
| 287 |
+
{
|
| 288 |
+
key: recipe[key]
|
| 289 |
+
for key in ("id", "input_item_ids", "output_item_id")
|
| 290 |
+
if key in recipe
|
| 291 |
+
}
|
| 292 |
+
for recipe in reference_world.get("recipes", [])
|
| 293 |
+
if isinstance(recipe, dict)
|
| 294 |
+
],
|
| 295 |
+
"quest_chain": [
|
| 296 |
+
{
|
| 297 |
+
**{
|
| 298 |
+
key: step[key]
|
| 299 |
+
for key in ("step_id", "requires_step_ids", "action")
|
| 300 |
+
if key in step
|
| 301 |
+
},
|
| 302 |
+
"description": "short quest step",
|
| 303 |
+
}
|
| 304 |
+
for step in reference_world.get("quest_chain", [])
|
| 305 |
+
if isinstance(step, dict)
|
| 306 |
+
],
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _compress_world_node(node: dict[str, Any]) -> dict[str, Any]:
|
| 311 |
+
compressed = {
|
| 312 |
+
key: node[key]
|
| 313 |
+
for key in (
|
| 314 |
+
"id",
|
| 315 |
+
"type",
|
| 316 |
+
"parent_id",
|
| 317 |
+
"open",
|
| 318 |
+
"locked",
|
| 319 |
+
"lock_key_id",
|
| 320 |
+
"clue_id",
|
| 321 |
+
"requires_item_id",
|
| 322 |
+
"consumes_item",
|
| 323 |
+
"reveals_item_id",
|
| 324 |
+
"reveals_readable_id",
|
| 325 |
+
"gives_item_id",
|
| 326 |
+
"gives_clue_id",
|
| 327 |
+
)
|
| 328 |
+
if key in node
|
| 329 |
+
}
|
| 330 |
+
compressed["label"] = str(node.get("label") or node.get("id") or node.get("type") or "node")
|
| 331 |
+
compressed["description"] = f"short {str(node.get('type') or 'node')} description"
|
| 332 |
+
if node.get("type") == "readable":
|
| 333 |
+
compressed["text_content"] = "short readable text"
|
| 334 |
+
return compressed
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def build_dm_world_messages(
|
| 338 |
+
*,
|
| 339 |
+
target_ratio: float,
|
| 340 |
+
repair_context: "DMRepairContext | None" = None,
|
| 341 |
+
reference_world: dict[str, Any] | None = None,
|
| 342 |
+
prompt_style: int = 0,
|
| 343 |
+
) -> list[ModelMessage]:
|
| 344 |
+
exemplar_world = reference_world or sample_world_definition()
|
| 345 |
+
structural_exemplar = _compress_reference_world_for_prompt(exemplar_world)
|
| 346 |
+
template = _DM_WORLD_USER_PROMPTS[prompt_style % len(_DM_WORLD_USER_PROMPTS)]
|
| 347 |
+
prompt = template.format(
|
| 348 |
+
target_ratio=target_ratio,
|
| 349 |
+
reference_world_json=json.dumps(structural_exemplar, separators=(",", ":")),
|
| 350 |
+
meta_example_json=json.dumps(_DM_META_EXAMPLE, separators=(",", ":")),
|
| 351 |
+
item_example_json=json.dumps(_DM_ITEM_EXAMPLE, separators=(",", ":")),
|
| 352 |
+
clue_example_json=json.dumps(_DM_CLUE_EXAMPLE, separators=(",", ":")),
|
| 353 |
+
fixture_example_json=json.dumps(_DM_FIXTURE_EXAMPLE, separators=(",", ":")),
|
| 354 |
+
npc_example_json=json.dumps(_DM_NPC_EXAMPLE, separators=(",", ":")),
|
| 355 |
+
quest_step_example_json=json.dumps(_DM_QUEST_STEP_EXAMPLE, separators=(",", ":")),
|
| 356 |
+
edge_pair_example_json=json.dumps(_DM_EDGE_PAIR_EXAMPLE, separators=(",", ":")),
|
| 357 |
+
readable_example_json=json.dumps(_DM_READABLE_EXAMPLE, separators=(",", ":")),
|
| 358 |
+
)
|
| 359 |
+
if repair_context is not None:
|
| 360 |
+
prompt += (
|
| 361 |
+
"\nThe previous WorldDefinition failed schema validation or compilation.\n"
|
| 362 |
+
f"Repair attempt: {repair_context.attempt_number}\n"
|
| 363 |
+
f"Normalized error: {repair_context.error_message}\n"
|
| 364 |
+
"Return a fully corrected WorldDefinition only.\n"
|
| 365 |
+
)
|
| 366 |
+
if repair_context.previous_candidate_json:
|
| 367 |
+
prompt += f"Previous invalid WorldDefinition JSON:\n{repair_context.previous_candidate_json}\n"
|
| 368 |
+
return [
|
| 369 |
+
ModelMessage(role="system", content=DM_WORLD_SYSTEM_PROMPT),
|
| 370 |
+
ModelMessage(role="user", content=prompt),
|
| 371 |
+
]
|
agents/master/quest.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from collections import defaultdict, deque
|
| 5 |
+
|
| 6 |
+
from .base import (
|
| 7 |
+
COMBINE_RE,
|
| 8 |
+
DMCompileError,
|
| 9 |
+
GIVE_RE,
|
| 10 |
+
GO_RE,
|
| 11 |
+
INVENTORY_ID,
|
| 12 |
+
OPEN_RE,
|
| 13 |
+
READ_RE,
|
| 14 |
+
STORED_ID,
|
| 15 |
+
SUBMIT_RE,
|
| 16 |
+
TALK_RE,
|
| 17 |
+
TAKE_RE,
|
| 18 |
+
UNLOCK_RE,
|
| 19 |
+
USE_RE,
|
| 20 |
+
normalize_answer_text,
|
| 21 |
+
)
|
| 22 |
+
from .graph import door_room_mapping, hidden_readable_ids, recipe_mapping, use_effect_mapping
|
| 23 |
+
from .schema import (
|
| 24 |
+
CombineAction,
|
| 25 |
+
ContainerNode,
|
| 26 |
+
FixtureNode,
|
| 27 |
+
GiveAction,
|
| 28 |
+
GoAction,
|
| 29 |
+
Item,
|
| 30 |
+
NpcNode,
|
| 31 |
+
OpenAction,
|
| 32 |
+
QuestAction,
|
| 33 |
+
QuestStep,
|
| 34 |
+
ReadAction,
|
| 35 |
+
ReadableNode,
|
| 36 |
+
SimulationState,
|
| 37 |
+
SubmitAction,
|
| 38 |
+
TalkAction,
|
| 39 |
+
TakeAction,
|
| 40 |
+
UnlockAction,
|
| 41 |
+
UseAction,
|
| 42 |
+
WorldDefinition,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def topological_linearize(steps: list[QuestStep]) -> list[QuestStep]:
|
| 47 |
+
by_id = {step.step_id: step for step in steps}
|
| 48 |
+
for step in steps:
|
| 49 |
+
for dependency in step.requires_step_ids:
|
| 50 |
+
if dependency not in by_id:
|
| 51 |
+
raise DMCompileError(f"Quest step '{step.step_id}' depends on unknown step '{dependency}'.")
|
| 52 |
+
|
| 53 |
+
visiting: set[str] = set()
|
| 54 |
+
visited: set[str] = set()
|
| 55 |
+
|
| 56 |
+
def visit(step_id: str) -> None:
|
| 57 |
+
if step_id in visited:
|
| 58 |
+
return
|
| 59 |
+
if step_id in visiting:
|
| 60 |
+
raise DMCompileError("quest_chain contains a cycle.")
|
| 61 |
+
visiting.add(step_id)
|
| 62 |
+
for dependency in by_id[step_id].requires_step_ids:
|
| 63 |
+
visit(dependency)
|
| 64 |
+
visiting.remove(step_id)
|
| 65 |
+
visited.add(step_id)
|
| 66 |
+
|
| 67 |
+
for step in steps:
|
| 68 |
+
visit(step.step_id)
|
| 69 |
+
|
| 70 |
+
seen: set[str] = set()
|
| 71 |
+
for step in steps:
|
| 72 |
+
missing = [dependency for dependency in step.requires_step_ids if dependency not in seen]
|
| 73 |
+
if missing:
|
| 74 |
+
raise DMCompileError(
|
| 75 |
+
f"Quest step '{step.step_id}' appears before its required steps: {', '.join(sorted(missing))}."
|
| 76 |
+
)
|
| 77 |
+
seen.add(step.step_id)
|
| 78 |
+
return steps
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def parse_quest_action(text: str) -> QuestAction:
|
| 82 |
+
compact = re.sub(r"\s+", "", text)
|
| 83 |
+
if match := GO_RE.fullmatch(compact):
|
| 84 |
+
return GoAction(target_node_id=match.group("target"))
|
| 85 |
+
if match := OPEN_RE.fullmatch(compact):
|
| 86 |
+
return OpenAction(target_node_id=match.group("target"))
|
| 87 |
+
if match := UNLOCK_RE.fullmatch(compact):
|
| 88 |
+
return UnlockAction(door_id=match.group("door"), key_id=match.group("key"))
|
| 89 |
+
if match := TAKE_RE.fullmatch(compact):
|
| 90 |
+
return TakeAction(item_id=match.group("item"), source_node_id=match.group("source"))
|
| 91 |
+
if match := READ_RE.fullmatch(compact):
|
| 92 |
+
return ReadAction(target_node_id=match.group("target"))
|
| 93 |
+
if match := USE_RE.fullmatch(compact):
|
| 94 |
+
return UseAction(item_id=match.group("item"), target_node_id=match.group("target"))
|
| 95 |
+
if match := COMBINE_RE.fullmatch(compact):
|
| 96 |
+
return CombineAction(item_a_id=match.group("item_a"), item_b_id=match.group("item_b"))
|
| 97 |
+
if match := GIVE_RE.fullmatch(compact):
|
| 98 |
+
return GiveAction(item_id=match.group("item"), npc_id=match.group("npc"))
|
| 99 |
+
if match := TALK_RE.fullmatch(compact):
|
| 100 |
+
return TalkAction(target_node_id=match.group("target"))
|
| 101 |
+
if match := SUBMIT_RE.fullmatch(text.strip()):
|
| 102 |
+
return SubmitAction(answer_text=match.group("answer"))
|
| 103 |
+
raise DMCompileError(f"Unsupported quest action DSL '{text}'.")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def simulate_walkthrough(
|
| 107 |
+
world: WorldDefinition,
|
| 108 |
+
actions: list[QuestAction],
|
| 109 |
+
entity_names: dict[str, str],
|
| 110 |
+
) -> list[str]:
|
| 111 |
+
node_by_id = {node.id: node for node in world.nodes}
|
| 112 |
+
item_by_id = {item.id: item for item in world.items}
|
| 113 |
+
edge_by_target = {(edge.from_node_id, edge.to_node_id): edge for edge in world.edges}
|
| 114 |
+
door_rooms = door_room_mapping(world)
|
| 115 |
+
hidden_readables = hidden_readable_ids(world)
|
| 116 |
+
use_effects = use_effect_mapping(world)
|
| 117 |
+
recipes = recipe_mapping(world)
|
| 118 |
+
clue_ids = {clue.id for clue in world.clues}
|
| 119 |
+
|
| 120 |
+
state = SimulationState(
|
| 121 |
+
current_room_id=world.meta.start_node_id,
|
| 122 |
+
item_locations={item.id: item.start_node_id or STORED_ID for item in world.items},
|
| 123 |
+
visited_nodes={world.meta.start_node_id},
|
| 124 |
+
revealed_readables={node.id for node in world.nodes if node.type == "readable" and node.id not in hidden_readables},
|
| 125 |
+
)
|
| 126 |
+
for node in world.nodes:
|
| 127 |
+
if node.type in {"container", "door"}:
|
| 128 |
+
if node.open:
|
| 129 |
+
state.open_nodes.add(node.id)
|
| 130 |
+
if node.locked:
|
| 131 |
+
state.locked_nodes.add(node.id)
|
| 132 |
+
|
| 133 |
+
commands: list[str] = []
|
| 134 |
+
for action in actions:
|
| 135 |
+
if isinstance(action, GoAction):
|
| 136 |
+
_apply_go(action, edge_by_target, state, commands)
|
| 137 |
+
elif isinstance(action, OpenAction):
|
| 138 |
+
_apply_open(action, node_by_id, door_rooms, state, entity_names, commands)
|
| 139 |
+
elif isinstance(action, UnlockAction):
|
| 140 |
+
_apply_unlock(action, node_by_id, item_by_id, door_rooms, state, entity_names, commands)
|
| 141 |
+
elif isinstance(action, TakeAction):
|
| 142 |
+
_apply_take(action, node_by_id, item_by_id, state, entity_names, commands)
|
| 143 |
+
elif isinstance(action, ReadAction):
|
| 144 |
+
_apply_read(action, node_by_id, state, entity_names, commands)
|
| 145 |
+
elif isinstance(action, UseAction):
|
| 146 |
+
_apply_use(action, node_by_id, state, entity_names, commands, use_effects)
|
| 147 |
+
elif isinstance(action, CombineAction):
|
| 148 |
+
_apply_combine(action, state, entity_names, commands, recipes)
|
| 149 |
+
elif isinstance(action, GiveAction):
|
| 150 |
+
_apply_give(action, node_by_id, state, entity_names, commands)
|
| 151 |
+
elif isinstance(action, TalkAction):
|
| 152 |
+
_apply_talk(action, node_by_id, state, entity_names, commands)
|
| 153 |
+
elif isinstance(action, SubmitAction):
|
| 154 |
+
_apply_submit(action, world, node_by_id, state, commands, clue_ids)
|
| 155 |
+
else: # pragma: no cover
|
| 156 |
+
raise AssertionError(f"Unhandled quest action {action!r}")
|
| 157 |
+
|
| 158 |
+
return commands
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _apply_go(
|
| 162 |
+
action: GoAction,
|
| 163 |
+
edge_by_target: dict[tuple[str, str], object],
|
| 164 |
+
state: SimulationState,
|
| 165 |
+
commands: list[str],
|
| 166 |
+
) -> None:
|
| 167 |
+
edge = edge_by_target.get((state.current_room_id, action.target_node_id))
|
| 168 |
+
if edge is None:
|
| 169 |
+
raise DMCompileError(
|
| 170 |
+
f"Quest moves from '{state.current_room_id}' to non-adjacent room '{action.target_node_id}'."
|
| 171 |
+
)
|
| 172 |
+
if edge.door_node_id and edge.door_node_id not in state.open_nodes:
|
| 173 |
+
raise DMCompileError(f"Quest moves through closed door '{edge.door_node_id}'.")
|
| 174 |
+
if edge.type == "locked_passage" and edge.door_node_id in state.locked_nodes:
|
| 175 |
+
raise DMCompileError(f"Quest moves through locked door '{edge.door_node_id}'.")
|
| 176 |
+
state.current_room_id = edge.to_node_id
|
| 177 |
+
state.visited_nodes.add(edge.to_node_id)
|
| 178 |
+
commands.append(f"go {edge.direction}")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _apply_open(
|
| 182 |
+
action: OpenAction,
|
| 183 |
+
node_by_id: dict[str, object],
|
| 184 |
+
door_rooms: dict[str, frozenset[str]],
|
| 185 |
+
state: SimulationState,
|
| 186 |
+
entity_names: dict[str, str],
|
| 187 |
+
commands: list[str],
|
| 188 |
+
) -> None:
|
| 189 |
+
node = node_by_id.get(action.target_node_id)
|
| 190 |
+
if node is None or node.type not in {"container", "door"}:
|
| 191 |
+
raise DMCompileError(f"open(...) targets unknown lockable '{action.target_node_id}'.")
|
| 192 |
+
if node.id in state.locked_nodes:
|
| 193 |
+
raise DMCompileError(f"Quest opens locked '{node.id}' before unlocking it.")
|
| 194 |
+
if node.type == "door":
|
| 195 |
+
if state.current_room_id not in door_rooms.get(node.id, frozenset()):
|
| 196 |
+
raise DMCompileError(f"Door '{node.id}' is not reachable from room '{state.current_room_id}'.")
|
| 197 |
+
else:
|
| 198 |
+
_require_parent_room(node.parent_id, node.id, state.current_room_id)
|
| 199 |
+
state.open_nodes.add(node.id)
|
| 200 |
+
state.visited_nodes.add(node.id)
|
| 201 |
+
commands.append(f"open {entity_names[node.id]}")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _apply_unlock(
|
| 205 |
+
action: UnlockAction,
|
| 206 |
+
node_by_id: dict[str, object],
|
| 207 |
+
item_by_id: dict[str, Item],
|
| 208 |
+
door_rooms: dict[str, frozenset[str]],
|
| 209 |
+
state: SimulationState,
|
| 210 |
+
entity_names: dict[str, str],
|
| 211 |
+
commands: list[str],
|
| 212 |
+
) -> None:
|
| 213 |
+
if action.key_id not in item_by_id:
|
| 214 |
+
raise DMCompileError(f"Quest references unknown key '{action.key_id}'.")
|
| 215 |
+
if action.key_id not in state.inventory:
|
| 216 |
+
raise DMCompileError(f"Quest unlocks '{action.door_id}' without key '{action.key_id}'.")
|
| 217 |
+
node = node_by_id.get(action.door_id)
|
| 218 |
+
if node is None or node.type not in {"door", "container"}:
|
| 219 |
+
raise DMCompileError(f"unlock(...) targets unknown lockable '{action.door_id}'.")
|
| 220 |
+
if node.lock_key_id != action.key_id:
|
| 221 |
+
raise DMCompileError(f"'{node.id}' does not match key '{action.key_id}'.")
|
| 222 |
+
if node.type == "door":
|
| 223 |
+
if state.current_room_id not in door_rooms.get(node.id, frozenset()):
|
| 224 |
+
raise DMCompileError(f"Door '{node.id}' is not reachable from room '{state.current_room_id}'.")
|
| 225 |
+
else:
|
| 226 |
+
_require_parent_room(node.parent_id, node.id, state.current_room_id)
|
| 227 |
+
state.locked_nodes.discard(node.id)
|
| 228 |
+
state.visited_nodes.add(node.id)
|
| 229 |
+
commands.append(f"unlock {entity_names[node.id]} with {entity_names[action.key_id]}")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _apply_take(
|
| 233 |
+
action: TakeAction,
|
| 234 |
+
node_by_id: dict[str, object],
|
| 235 |
+
item_by_id: dict[str, Item],
|
| 236 |
+
state: SimulationState,
|
| 237 |
+
entity_names: dict[str, str],
|
| 238 |
+
commands: list[str],
|
| 239 |
+
) -> None:
|
| 240 |
+
item = item_by_id.get(action.item_id)
|
| 241 |
+
if item is None:
|
| 242 |
+
raise DMCompileError(f"Quest references unknown item '{action.item_id}'.")
|
| 243 |
+
actual_location = state.item_locations.get(item.id)
|
| 244 |
+
if actual_location != action.source_node_id:
|
| 245 |
+
raise DMCompileError(
|
| 246 |
+
f"Quest expects item '{item.id}' in '{action.source_node_id}', but it is in '{actual_location}'."
|
| 247 |
+
)
|
| 248 |
+
if action.source_node_id == state.current_room_id:
|
| 249 |
+
command = f"take {entity_names[item.id]}"
|
| 250 |
+
else:
|
| 251 |
+
source = node_by_id.get(action.source_node_id)
|
| 252 |
+
if source is None or not isinstance(source, ContainerNode):
|
| 253 |
+
raise DMCompileError(f"Quest cannot take '{item.id}' from '{action.source_node_id}'.")
|
| 254 |
+
_require_parent_room(source.parent_id, source.id, state.current_room_id)
|
| 255 |
+
if source.id not in state.open_nodes:
|
| 256 |
+
raise DMCompileError(f"Quest takes from closed container '{source.id}'.")
|
| 257 |
+
command = f"take {entity_names[item.id]} from {entity_names[source.id]}"
|
| 258 |
+
state.inventory.add(item.id)
|
| 259 |
+
state.item_locations[item.id] = INVENTORY_ID
|
| 260 |
+
state.visited_nodes.add(item.id)
|
| 261 |
+
commands.append(command)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def _apply_read(
|
| 265 |
+
action: ReadAction,
|
| 266 |
+
node_by_id: dict[str, object],
|
| 267 |
+
state: SimulationState,
|
| 268 |
+
entity_names: dict[str, str],
|
| 269 |
+
commands: list[str],
|
| 270 |
+
) -> None:
|
| 271 |
+
node = _typed_node(node_by_id, action.target_node_id, ReadableNode, "read")
|
| 272 |
+
_require_parent_room(node.parent_id, node.id, state.current_room_id)
|
| 273 |
+
if node.id not in state.revealed_readables:
|
| 274 |
+
raise DMCompileError(f"Readable '{node.id}' has not been revealed yet.")
|
| 275 |
+
if node.requires_item_id and node.id not in state.prepared_readables:
|
| 276 |
+
raise DMCompileError(f"Readable '{node.id}' still requires item '{node.requires_item_id}'.")
|
| 277 |
+
state.discovered_clues.add(node.clue_id)
|
| 278 |
+
state.visited_nodes.add(node.id)
|
| 279 |
+
commands.append(f"read {entity_names[node.id]}")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def _apply_use(
|
| 283 |
+
action: UseAction,
|
| 284 |
+
node_by_id: dict[str, object],
|
| 285 |
+
state: SimulationState,
|
| 286 |
+
entity_names: dict[str, str],
|
| 287 |
+
commands: list[str],
|
| 288 |
+
use_effects: dict[str, object],
|
| 289 |
+
) -> None:
|
| 290 |
+
effect = use_effects.get(action.target_node_id)
|
| 291 |
+
if effect is None:
|
| 292 |
+
raise DMCompileError(f"use(...) targets unknown use-effect node '{action.target_node_id}'.")
|
| 293 |
+
if effect.required_item_id != action.item_id:
|
| 294 |
+
raise DMCompileError(f"'{action.target_node_id}' does not accept item '{action.item_id}'.")
|
| 295 |
+
if action.item_id not in state.inventory:
|
| 296 |
+
raise DMCompileError(f"Quest uses item '{action.item_id}' before taking it.")
|
| 297 |
+
node = node_by_id.get(action.target_node_id)
|
| 298 |
+
if node is None or node.type not in {"readable", "fixture"}:
|
| 299 |
+
raise DMCompileError(f"use(...) targets unsupported node '{action.target_node_id}'.")
|
| 300 |
+
_require_parent_room(node.parent_id, node.id, state.current_room_id)
|
| 301 |
+
if isinstance(node, ReadableNode) and node.id not in state.revealed_readables:
|
| 302 |
+
raise DMCompileError(f"Readable '{node.id}' has not been revealed yet.")
|
| 303 |
+
|
| 304 |
+
if effect.consumes_item:
|
| 305 |
+
state.inventory.remove(action.item_id)
|
| 306 |
+
state.item_locations[action.item_id] = None
|
| 307 |
+
if effect.clue_id:
|
| 308 |
+
state.prepared_readables.add(node.id)
|
| 309 |
+
state.discovered_clues.add(effect.clue_id)
|
| 310 |
+
if effect.reveals_item_id:
|
| 311 |
+
state.item_locations[effect.reveals_item_id] = state.current_room_id
|
| 312 |
+
if effect.reveals_readable_id:
|
| 313 |
+
state.revealed_readables.add(effect.reveals_readable_id)
|
| 314 |
+
if isinstance(node, FixtureNode):
|
| 315 |
+
state.used_fixtures.add(node.id)
|
| 316 |
+
state.visited_nodes.add(node.id)
|
| 317 |
+
commands.append(f"use {entity_names[action.item_id]} on {entity_names[node.id]}")
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def _apply_combine(
|
| 321 |
+
action: CombineAction,
|
| 322 |
+
state: SimulationState,
|
| 323 |
+
entity_names: dict[str, str],
|
| 324 |
+
commands: list[str],
|
| 325 |
+
recipes: dict[frozenset[str], str],
|
| 326 |
+
) -> None:
|
| 327 |
+
recipe_key = frozenset({action.item_a_id, action.item_b_id})
|
| 328 |
+
output_item_id = recipes.get(recipe_key)
|
| 329 |
+
if output_item_id is None:
|
| 330 |
+
raise DMCompileError(f"No recipe combines '{action.item_a_id}' with '{action.item_b_id}'.")
|
| 331 |
+
if action.item_a_id not in state.inventory or action.item_b_id not in state.inventory:
|
| 332 |
+
raise DMCompileError("Quest combines items before both are in inventory.")
|
| 333 |
+
state.inventory.remove(action.item_a_id)
|
| 334 |
+
state.inventory.remove(action.item_b_id)
|
| 335 |
+
state.item_locations[action.item_a_id] = None
|
| 336 |
+
state.item_locations[action.item_b_id] = None
|
| 337 |
+
state.inventory.add(output_item_id)
|
| 338 |
+
state.item_locations[output_item_id] = INVENTORY_ID
|
| 339 |
+
state.produced_items.add(output_item_id)
|
| 340 |
+
state.visited_nodes.add(output_item_id)
|
| 341 |
+
commands.append(f"combine {entity_names[action.item_a_id]} with {entity_names[action.item_b_id]}")
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def _apply_give(
|
| 345 |
+
action: GiveAction,
|
| 346 |
+
node_by_id: dict[str, object],
|
| 347 |
+
state: SimulationState,
|
| 348 |
+
entity_names: dict[str, str],
|
| 349 |
+
commands: list[str],
|
| 350 |
+
) -> None:
|
| 351 |
+
npc = _typed_node(node_by_id, action.npc_id, NpcNode, "give")
|
| 352 |
+
_require_parent_room(npc.parent_id, npc.id, state.current_room_id)
|
| 353 |
+
if action.item_id not in state.inventory:
|
| 354 |
+
raise DMCompileError(f"Quest gives '{action.item_id}' before taking it.")
|
| 355 |
+
if npc.requires_item_id != action.item_id:
|
| 356 |
+
raise DMCompileError(f"NPC '{npc.id}' does not want '{action.item_id}'.")
|
| 357 |
+
if npc.id in state.satisfied_npcs:
|
| 358 |
+
raise DMCompileError(f"Quest trades with NPC '{npc.id}' more than once.")
|
| 359 |
+
state.inventory.remove(action.item_id)
|
| 360 |
+
state.item_locations[action.item_id] = None
|
| 361 |
+
if npc.gives_item_id:
|
| 362 |
+
state.inventory.add(npc.gives_item_id)
|
| 363 |
+
state.item_locations[npc.gives_item_id] = INVENTORY_ID
|
| 364 |
+
state.produced_items.add(npc.gives_item_id)
|
| 365 |
+
if npc.gives_clue_id:
|
| 366 |
+
state.discovered_clues.add(npc.gives_clue_id)
|
| 367 |
+
state.satisfied_npcs.add(npc.id)
|
| 368 |
+
state.visited_nodes.add(npc.id)
|
| 369 |
+
commands.append(f"give {entity_names[action.item_id]} to {entity_names[npc.id]}")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _apply_talk(
|
| 373 |
+
action: TalkAction,
|
| 374 |
+
node_by_id: dict[str, object],
|
| 375 |
+
state: SimulationState,
|
| 376 |
+
entity_names: dict[str, str],
|
| 377 |
+
commands: list[str],
|
| 378 |
+
) -> None:
|
| 379 |
+
npc = _typed_node(node_by_id, action.target_node_id, NpcNode, "talk")
|
| 380 |
+
_require_parent_room(npc.parent_id, npc.id, state.current_room_id)
|
| 381 |
+
state.consulted_npcs.add(npc.id)
|
| 382 |
+
state.visited_nodes.add(npc.id)
|
| 383 |
+
commands.append(f"talk {entity_names[npc.id]}")
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _apply_submit(
|
| 387 |
+
action: SubmitAction,
|
| 388 |
+
world: WorldDefinition,
|
| 389 |
+
node_by_id: dict[str, object],
|
| 390 |
+
state: SimulationState,
|
| 391 |
+
commands: list[str],
|
| 392 |
+
clue_ids: set[str],
|
| 393 |
+
) -> None:
|
| 394 |
+
guardian_id = world.meta.win_condition.target_npc_id
|
| 395 |
+
guardian = _typed_node(node_by_id, guardian_id, NpcNode, "submit")
|
| 396 |
+
_require_parent_room(guardian.parent_id, guardian.id, state.current_room_id)
|
| 397 |
+
if guardian.id not in state.consulted_npcs:
|
| 398 |
+
raise DMCompileError("Quest submits before talking to the guardian.")
|
| 399 |
+
if state.discovered_clues != clue_ids:
|
| 400 |
+
missing = sorted(clue_ids - state.discovered_clues)
|
| 401 |
+
raise DMCompileError(f"Quest submits before all clues are discovered: {missing}")
|
| 402 |
+
if normalize_answer_text(action.answer_text) != normalize_answer_text(world.meta.win_condition.answer_string):
|
| 403 |
+
raise DMCompileError("The final submit step must match win_condition.answer_string.")
|
| 404 |
+
commands.append("submit " + normalize_answer_text(action.answer_text))
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _typed_node(node_by_id: dict[str, object], node_id: str, expected: type, label: str):
|
| 408 |
+
node = node_by_id.get(node_id)
|
| 409 |
+
if node is None or not isinstance(node, expected):
|
| 410 |
+
raise DMCompileError(f"{label}(...) targets unknown {expected.__name__.lower()} '{node_id}'.")
|
| 411 |
+
return node
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def _require_parent_room(parent_id: str, node_id: str, current_room_id: str) -> None:
|
| 415 |
+
if parent_id != current_room_id:
|
| 416 |
+
raise DMCompileError(
|
| 417 |
+
f"Quest interacts with '{node_id}' from room '{current_room_id}', but it lives in '{parent_id}'."
|
| 418 |
+
)
|
agents/master/sample.py
ADDED
|
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass(frozen=True)
|
| 11 |
+
class WorldTheme:
|
| 12 |
+
title: str
|
| 13 |
+
answer: str
|
| 14 |
+
foyer_label: str
|
| 15 |
+
foyer_description: str
|
| 16 |
+
shrine_label: str
|
| 17 |
+
shrine_description: str
|
| 18 |
+
workshop_label: str
|
| 19 |
+
workshop_description: str
|
| 20 |
+
courtyard_label: str
|
| 21 |
+
courtyard_description: str
|
| 22 |
+
gallery_label: str
|
| 23 |
+
gallery_description: str
|
| 24 |
+
entry_chest_label: str
|
| 25 |
+
entry_chest_description: str
|
| 26 |
+
iron_door_label: str
|
| 27 |
+
iron_door_description: str
|
| 28 |
+
ash_mural_label: str
|
| 29 |
+
ash_mural_description: str
|
| 30 |
+
ash_mural_text: str
|
| 31 |
+
iron_chest_label: str
|
| 32 |
+
iron_chest_description: str
|
| 33 |
+
stone_well_label: str
|
| 34 |
+
stone_well_description: str
|
| 35 |
+
water_plaque_label: str
|
| 36 |
+
water_plaque_description: str
|
| 37 |
+
water_plaque_text: str
|
| 38 |
+
cartographer_label: str
|
| 39 |
+
cartographer_description: str
|
| 40 |
+
faded_letter_label: str
|
| 41 |
+
faded_letter_description: str
|
| 42 |
+
faded_letter_text: str
|
| 43 |
+
stone_guardian_label: str
|
| 44 |
+
stone_guardian_description: str
|
| 45 |
+
brass_key_label: str
|
| 46 |
+
brass_key_description: str
|
| 47 |
+
torch_label: str
|
| 48 |
+
torch_description: str
|
| 49 |
+
torn_map_left_label: str
|
| 50 |
+
torn_map_left_description: str
|
| 51 |
+
torn_map_right_label: str
|
| 52 |
+
torn_map_right_description: str
|
| 53 |
+
full_map_label: str
|
| 54 |
+
full_map_description: str
|
| 55 |
+
lens_label: str
|
| 56 |
+
lens_description: str
|
| 57 |
+
initial_clue_text: str
|
| 58 |
+
river_clue_text: str
|
| 59 |
+
waterwarden_clue_text: str
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
_WORLD_THEMES: tuple[WorldTheme, ...] = (
|
| 63 |
+
WorldTheme(
|
| 64 |
+
title="The River Ward",
|
| 65 |
+
answer="mira",
|
| 66 |
+
foyer_label="Foyer",
|
| 67 |
+
foyer_description="A drafty entry hall with passages north, south, east, and west.",
|
| 68 |
+
shrine_label="Shrine",
|
| 69 |
+
shrine_description="An open shrine watched by a silent stone guardian.",
|
| 70 |
+
workshop_label="Workshop",
|
| 71 |
+
workshop_description="An ash-streaked workshop lit by a guttering lamp.",
|
| 72 |
+
courtyard_label="Courtyard",
|
| 73 |
+
courtyard_description="Rainwater gathers around a cracked stone well.",
|
| 74 |
+
gallery_label="Gallery",
|
| 75 |
+
gallery_description="Portraits of the wardens hang above a long dust-covered table.",
|
| 76 |
+
entry_chest_label="Entry Chest",
|
| 77 |
+
entry_chest_description="A squat travel chest sits beside the door.",
|
| 78 |
+
iron_door_label="Iron Door",
|
| 79 |
+
iron_door_description="A blackened iron door seals the workshop.",
|
| 80 |
+
ash_mural_label="Ash Mural",
|
| 81 |
+
ash_mural_description="An ash-dark mural is impossible to make out with the naked eye.",
|
| 82 |
+
ash_mural_text="The mural preserves one line: the betrayer's name begins with M.",
|
| 83 |
+
iron_chest_label="Iron Chest",
|
| 84 |
+
iron_chest_description="A soot-stained iron chest is tucked under a bench.",
|
| 85 |
+
stone_well_label="Stone Well",
|
| 86 |
+
stone_well_description="Etchings circle the well's rim, but they only align from the proper vantage.",
|
| 87 |
+
water_plaque_label="Water Plaque",
|
| 88 |
+
water_plaque_description="A bronze plaque slides out from the well masonry.",
|
| 89 |
+
water_plaque_text="The betrayer lived closest to the river gate.",
|
| 90 |
+
cartographer_label="Cartographer",
|
| 91 |
+
cartographer_description="The cartographer studies the walls and waits for a completed survey.",
|
| 92 |
+
faded_letter_label="Faded Letter",
|
| 93 |
+
faded_letter_description="A faded letter is still too blurred to decipher.",
|
| 94 |
+
faded_letter_text="Of the wardens, only Mira kept quarters beside the water.",
|
| 95 |
+
stone_guardian_label="Stone Guardian",
|
| 96 |
+
stone_guardian_description="The guardian asks for the betrayer's name once you are ready.",
|
| 97 |
+
brass_key_label="Brass Key",
|
| 98 |
+
brass_key_description="A brass key with soot in its teeth.",
|
| 99 |
+
torch_label="Torch",
|
| 100 |
+
torch_description="A pitch torch with a steady flame.",
|
| 101 |
+
torn_map_left_label="Torn Map Left",
|
| 102 |
+
torn_map_left_description="The left half of a survey map.",
|
| 103 |
+
torn_map_right_label="Torn Map Right",
|
| 104 |
+
torn_map_right_description="The right half of a survey map.",
|
| 105 |
+
full_map_label="Full Map",
|
| 106 |
+
full_map_description="A restored map of the ward.",
|
| 107 |
+
lens_label="Lens",
|
| 108 |
+
lens_description="A polished lens in a brass frame.",
|
| 109 |
+
initial_clue_text="The betrayer's name begins with M.",
|
| 110 |
+
river_clue_text="The betrayer lived closest to the river gate.",
|
| 111 |
+
waterwarden_clue_text="Of the wardens, only Mira kept quarters beside the water.",
|
| 112 |
+
),
|
| 113 |
+
WorldTheme(
|
| 114 |
+
title="The Ember Vault",
|
| 115 |
+
answer="vesna",
|
| 116 |
+
foyer_label="Receiving Hall",
|
| 117 |
+
foyer_description="A warm stone hall lined with soot and copper hooks.",
|
| 118 |
+
shrine_label="Crucible Shrine",
|
| 119 |
+
shrine_description="A brass sentinel stands before a furnace-bright altar.",
|
| 120 |
+
workshop_label="Forge Annex",
|
| 121 |
+
workshop_description="Bellows creak above benches powdered with black ash.",
|
| 122 |
+
courtyard_label="Quench Yard",
|
| 123 |
+
courtyard_description="A cracked basin gathers rain beside the old quench line.",
|
| 124 |
+
gallery_label="Ledger Hall",
|
| 125 |
+
gallery_description="Burned account books rest beneath portraits of furnace wardens.",
|
| 126 |
+
entry_chest_label="Courier Trunk",
|
| 127 |
+
entry_chest_description="A courier trunk waits under a soot-marked peg rail.",
|
| 128 |
+
iron_door_label="Furnace Door",
|
| 129 |
+
iron_door_description="A scorched iron door blocks the annex.",
|
| 130 |
+
ash_mural_label="Cinder Frieze",
|
| 131 |
+
ash_mural_description="A smoke-dark frieze only sharpens under moving flame.",
|
| 132 |
+
ash_mural_text="A surviving line says the betrayer's name begins with V.",
|
| 133 |
+
iron_chest_label="Coal Locker",
|
| 134 |
+
iron_chest_description="A riveted locker is wedged beneath a slagged bench.",
|
| 135 |
+
stone_well_label="Quench Basin",
|
| 136 |
+
stone_well_description="Marks on the basin align only when seen with the full survey.",
|
| 137 |
+
water_plaque_label="Cooling Plaque",
|
| 138 |
+
water_plaque_description="A brass plate rises from a seam in the basin stone.",
|
| 139 |
+
water_plaque_text="The betrayer worked closest to the quench trench.",
|
| 140 |
+
cartographer_label="Quartermaster",
|
| 141 |
+
cartographer_description="The quartermaster trades only for a complete furnace survey.",
|
| 142 |
+
faded_letter_label="Scorched Ledger",
|
| 143 |
+
faded_letter_description="Heat has blurred the ink into copper-colored streaks.",
|
| 144 |
+
faded_letter_text="Only Vesna kept the cooling ledgers beside the trench.",
|
| 145 |
+
stone_guardian_label="Brass Sentinel",
|
| 146 |
+
stone_guardian_description="The sentinel requests the betrayer's name when the case is ready.",
|
| 147 |
+
brass_key_label="Copper Key",
|
| 148 |
+
brass_key_description="A copper key with furnace grit packed in the cuts.",
|
| 149 |
+
torch_label="Coal Torch",
|
| 150 |
+
torch_description="A coal torch that burns with a steady orange core.",
|
| 151 |
+
torn_map_left_label="Smelter Map Left",
|
| 152 |
+
torn_map_left_description="The left half of a furnace survey.",
|
| 153 |
+
torn_map_right_label="Smelter Map Right",
|
| 154 |
+
torn_map_right_description="The right half of a furnace survey.",
|
| 155 |
+
full_map_label="Furnace Survey",
|
| 156 |
+
full_map_description="A restored survey of the ember vault.",
|
| 157 |
+
lens_label="Gauge Lens",
|
| 158 |
+
lens_description="A thick gauge lens set in a brass ring.",
|
| 159 |
+
initial_clue_text="The betrayer's name begins with V.",
|
| 160 |
+
river_clue_text="The betrayer worked closest to the quench trench.",
|
| 161 |
+
waterwarden_clue_text="Only Vesna kept the cooling ledgers beside the trench.",
|
| 162 |
+
),
|
| 163 |
+
WorldTheme(
|
| 164 |
+
title="The Astral Archive",
|
| 165 |
+
answer="selene",
|
| 166 |
+
foyer_label="Entry Rotunda",
|
| 167 |
+
foyer_description="A quiet rotunda opens toward stacked corridors and a dim observatory stair.",
|
| 168 |
+
shrine_label="Moon Chapel",
|
| 169 |
+
shrine_description="A silver warden stands beneath a ceiling of cold stars.",
|
| 170 |
+
workshop_label="Chart Room",
|
| 171 |
+
workshop_description="Tables of brass instruments glint in powdery moon dust.",
|
| 172 |
+
courtyard_label="Star Court",
|
| 173 |
+
courtyard_description="A dry fountain mirrors the constellations in chipped stone.",
|
| 174 |
+
gallery_label="Catalog Hall",
|
| 175 |
+
gallery_description="Glass cases hold the names of long-dead archivists.",
|
| 176 |
+
entry_chest_label="Porter's Case",
|
| 177 |
+
entry_chest_description="A leather case rests under the chart hooks.",
|
| 178 |
+
iron_door_label="Star Door",
|
| 179 |
+
iron_door_description="A ribbed iron door seals the chart room.",
|
| 180 |
+
ash_mural_label="Night Chart",
|
| 181 |
+
ash_mural_description="The chart is unreadable until lit from the proper angle.",
|
| 182 |
+
ash_mural_text="One surviving note says the betrayer's name begins with S.",
|
| 183 |
+
iron_chest_label="Index Chest",
|
| 184 |
+
iron_chest_description="A narrow chest sits below a shelf of cracked lenses.",
|
| 185 |
+
stone_well_label="Dry Fountain",
|
| 186 |
+
stone_well_description="Its star marks align only when the full survey is restored.",
|
| 187 |
+
water_plaque_label="Star Plaque",
|
| 188 |
+
water_plaque_description="A silver plaque slides free from the fountain rim.",
|
| 189 |
+
water_plaque_text="The betrayer slept nearest the eastern telescope.",
|
| 190 |
+
cartographer_label="Archivist",
|
| 191 |
+
cartographer_description="The archivist will trade for a complete celestial survey.",
|
| 192 |
+
faded_letter_label="Blurred Index",
|
| 193 |
+
faded_letter_description="The index script is too faint without magnification.",
|
| 194 |
+
faded_letter_text="Among the archivists, only Selene kept quarters by the east telescope.",
|
| 195 |
+
stone_guardian_label="Silver Warden",
|
| 196 |
+
stone_guardian_description="The warden will hear the accusation once you have evidence.",
|
| 197 |
+
brass_key_label="Star Key",
|
| 198 |
+
brass_key_description="A slim key engraved with a crescent notch.",
|
| 199 |
+
torch_label="Lamp Wand",
|
| 200 |
+
torch_description="A narrow lamp wand with a clean blue flame.",
|
| 201 |
+
torn_map_left_label="Celestial Map Left",
|
| 202 |
+
torn_map_left_description="The left half of a star survey.",
|
| 203 |
+
torn_map_right_label="Celestial Map Right",
|
| 204 |
+
torn_map_right_description="The right half of a star survey.",
|
| 205 |
+
full_map_label="Celestial Survey",
|
| 206 |
+
full_map_description="A restored survey of the astral archive.",
|
| 207 |
+
lens_label="Astrolabe Lens",
|
| 208 |
+
lens_description="A polished lens mounted in silver wire.",
|
| 209 |
+
initial_clue_text="The betrayer's name begins with S.",
|
| 210 |
+
river_clue_text="The betrayer slept nearest the eastern telescope.",
|
| 211 |
+
waterwarden_clue_text="Among the archivists, only Selene kept quarters by the east telescope.",
|
| 212 |
+
),
|
| 213 |
+
WorldTheme(
|
| 214 |
+
title="The Glass Conservatory",
|
| 215 |
+
answer="liora",
|
| 216 |
+
foyer_label="Gate House",
|
| 217 |
+
foyer_description="A humid gate house opens onto vine-choked passages.",
|
| 218 |
+
shrine_label="Bloom Shrine",
|
| 219 |
+
shrine_description="A mossy guardian waits among chipped planters.",
|
| 220 |
+
workshop_label="Potting Room",
|
| 221 |
+
workshop_description="Clay dust and root knives cover the worktables.",
|
| 222 |
+
courtyard_label="Glass Court",
|
| 223 |
+
courtyard_description="A cracked basin sits beneath panes webbed with ivy.",
|
| 224 |
+
gallery_label="Seed Gallery",
|
| 225 |
+
gallery_description="Pressed flowers hang beside records of vanished caretakers.",
|
| 226 |
+
entry_chest_label="Garden Chest",
|
| 227 |
+
entry_chest_description="A cedar chest is tucked beside the rain cloaks.",
|
| 228 |
+
iron_door_label="Greenhouse Door",
|
| 229 |
+
iron_door_description="A warped iron door blocks the potting room.",
|
| 230 |
+
ash_mural_label="Vine Panel",
|
| 231 |
+
ash_mural_description="The panel's scratches only read clearly under a steady flame.",
|
| 232 |
+
ash_mural_text="A scratched line says the betrayer's name begins with L.",
|
| 233 |
+
iron_chest_label="Tool Locker",
|
| 234 |
+
iron_chest_description="A damp locker crouches under a potting bench.",
|
| 235 |
+
stone_well_label="Ivy Basin",
|
| 236 |
+
stone_well_description="The etched rings align only when the full garden survey is in hand.",
|
| 237 |
+
water_plaque_label="Root Plaque",
|
| 238 |
+
water_plaque_description="A greened plaque slides from the basin wall.",
|
| 239 |
+
water_plaque_text="The betrayer tended the beds nearest the rain cistern.",
|
| 240 |
+
cartographer_label="Head Gardener",
|
| 241 |
+
cartographer_description="The gardener will barter only for a complete bed map.",
|
| 242 |
+
faded_letter_label="Watered Note",
|
| 243 |
+
faded_letter_description="The note is blurred by old rain and fertilizer.",
|
| 244 |
+
faded_letter_text="Only Liora kept the cistern ledgers beside the rain beds.",
|
| 245 |
+
stone_guardian_label="Moss Guardian",
|
| 246 |
+
stone_guardian_description="The guardian listens when you are ready to name the betrayer.",
|
| 247 |
+
brass_key_label="Trellis Key",
|
| 248 |
+
brass_key_description="A greened key shaped like a curling vine.",
|
| 249 |
+
torch_label="Glass Lantern",
|
| 250 |
+
torch_description="A glass-sided lantern with a bright white flame.",
|
| 251 |
+
torn_map_left_label="Bed Map Left",
|
| 252 |
+
torn_map_left_description="The left half of a conservatory plan.",
|
| 253 |
+
torn_map_right_label="Bed Map Right",
|
| 254 |
+
torn_map_right_description="The right half of a conservatory plan.",
|
| 255 |
+
full_map_label="Bed Survey",
|
| 256 |
+
full_map_description="A restored survey of the conservatory beds.",
|
| 257 |
+
lens_label="Prism Lens",
|
| 258 |
+
lens_description="A prism lens wrapped in tarnished copper.",
|
| 259 |
+
initial_clue_text="The betrayer's name begins with L.",
|
| 260 |
+
river_clue_text="The betrayer tended the beds nearest the rain cistern.",
|
| 261 |
+
waterwarden_clue_text="Only Liora kept the cistern ledgers beside the rain beds.",
|
| 262 |
+
),
|
| 263 |
+
WorldTheme(
|
| 264 |
+
title="The Salt Bastion",
|
| 265 |
+
answer="corin",
|
| 266 |
+
foyer_label="Watch Hall",
|
| 267 |
+
foyer_description="A salt-stung hall opens toward barracks, chapel, and the sea court.",
|
| 268 |
+
shrine_label="Tide Chapel",
|
| 269 |
+
shrine_description="A stone warden keeps watch over a shrine of ropes and shells.",
|
| 270 |
+
workshop_label="Signal Room",
|
| 271 |
+
workshop_description="Lantern hooks sway above benches dusted with salt ash.",
|
| 272 |
+
courtyard_label="Sea Court",
|
| 273 |
+
courtyard_description="A dry cistern sits beneath walls pitted by ocean wind.",
|
| 274 |
+
gallery_label="Roll Hall",
|
| 275 |
+
gallery_description="Roster boards hang beneath portraits of old coast captains.",
|
| 276 |
+
entry_chest_label="Harbor Chest",
|
| 277 |
+
entry_chest_description="A travel chest sits beside a rack of oilskins.",
|
| 278 |
+
iron_door_label="Beacon Door",
|
| 279 |
+
iron_door_description="A rusted iron door bars the signal room.",
|
| 280 |
+
ash_mural_label="Signal Board",
|
| 281 |
+
ash_mural_description="Salt haze hides the markings until a lamp is raised close.",
|
| 282 |
+
ash_mural_text="A surviving mark says the betrayer's name begins with C.",
|
| 283 |
+
iron_chest_label="Tar Locker",
|
| 284 |
+
iron_chest_description="A tar-black locker hides below a signal bench.",
|
| 285 |
+
stone_well_label="Dry Cistern",
|
| 286 |
+
stone_well_description="Its carved rings make sense only with the restored coast survey.",
|
| 287 |
+
water_plaque_label="Harbor Plaque",
|
| 288 |
+
water_plaque_description="A plaque rises from a crack in the cistern lip.",
|
| 289 |
+
water_plaque_text="The betrayer bunked nearest the harbor chain.",
|
| 290 |
+
cartographer_label="Harbor Clerk",
|
| 291 |
+
cartographer_description="The clerk trades only for a complete bastion survey.",
|
| 292 |
+
faded_letter_label="Salted Roll",
|
| 293 |
+
faded_letter_description="Salt has crusted over the roster names.",
|
| 294 |
+
faded_letter_text="Only Corin kept the harbor ledgers beside the chain gate.",
|
| 295 |
+
stone_guardian_label="Stone Warden",
|
| 296 |
+
stone_guardian_description="The warden asks for the betrayer's name when the proof is ready.",
|
| 297 |
+
brass_key_label="Anchor Key",
|
| 298 |
+
brass_key_description="A heavy key stamped with a worn anchor.",
|
| 299 |
+
torch_label="Signal Lamp",
|
| 300 |
+
torch_description="A shuttered lamp with a disciplined yellow flame.",
|
| 301 |
+
torn_map_left_label="Coast Map Left",
|
| 302 |
+
torn_map_left_description="The left half of a bastion survey.",
|
| 303 |
+
torn_map_right_label="Coast Map Right",
|
| 304 |
+
torn_map_right_description="The right half of a bastion survey.",
|
| 305 |
+
full_map_label="Coast Survey",
|
| 306 |
+
full_map_description="A restored survey of the salt bastion.",
|
| 307 |
+
lens_label="Captain's Lens",
|
| 308 |
+
lens_description="A salt-clear lens held in a bronze ring.",
|
| 309 |
+
initial_clue_text="The betrayer's name begins with C.",
|
| 310 |
+
river_clue_text="The betrayer bunked nearest the harbor chain.",
|
| 311 |
+
waterwarden_clue_text="Only Corin kept the harbor ledgers beside the chain gate.",
|
| 312 |
+
),
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def sample_world_definition(seed: int | None = None, difficulty_target: float = 1.5) -> dict[str, Any]:
|
| 317 |
+
theme = _select_theme(seed)
|
| 318 |
+
return _build_world(theme, difficulty_target=difficulty_target)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def load_world(path: str) -> dict[str, Any]:
|
| 322 |
+
return json.loads(Path(path).read_text(encoding="utf-8"))
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _select_theme(seed: int | None) -> WorldTheme:
|
| 326 |
+
if seed is None:
|
| 327 |
+
return _WORLD_THEMES[0]
|
| 328 |
+
rng = random.Random(seed)
|
| 329 |
+
return _WORLD_THEMES[rng.randrange(len(_WORLD_THEMES))]
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def _build_world(theme: WorldTheme, *, difficulty_target: float) -> dict[str, Any]:
|
| 333 |
+
return {
|
| 334 |
+
"meta": {
|
| 335 |
+
"title": theme.title,
|
| 336 |
+
"difficulty_target": difficulty_target,
|
| 337 |
+
"start_node_id": "foyer",
|
| 338 |
+
"win_condition": {
|
| 339 |
+
"type": "deduce",
|
| 340 |
+
"target_npc_id": "stone_guardian",
|
| 341 |
+
"answer_string": theme.answer,
|
| 342 |
+
},
|
| 343 |
+
},
|
| 344 |
+
"nodes": [
|
| 345 |
+
{"id": "foyer", "type": "location", "label": theme.foyer_label, "description": theme.foyer_description},
|
| 346 |
+
{"id": "shrine", "type": "location", "label": theme.shrine_label, "description": theme.shrine_description},
|
| 347 |
+
{"id": "workshop", "type": "location", "label": theme.workshop_label, "description": theme.workshop_description},
|
| 348 |
+
{"id": "courtyard", "type": "location", "label": theme.courtyard_label, "description": theme.courtyard_description},
|
| 349 |
+
{"id": "gallery", "type": "location", "label": theme.gallery_label, "description": theme.gallery_description},
|
| 350 |
+
{
|
| 351 |
+
"id": "entry_chest",
|
| 352 |
+
"type": "container",
|
| 353 |
+
"label": theme.entry_chest_label,
|
| 354 |
+
"description": theme.entry_chest_description,
|
| 355 |
+
"parent_id": "foyer",
|
| 356 |
+
"open": False,
|
| 357 |
+
"locked": False,
|
| 358 |
+
"lock_key_id": None,
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"id": "iron_door",
|
| 362 |
+
"type": "door",
|
| 363 |
+
"label": theme.iron_door_label,
|
| 364 |
+
"description": theme.iron_door_description,
|
| 365 |
+
"open": False,
|
| 366 |
+
"locked": True,
|
| 367 |
+
"lock_key_id": "brass_key",
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"id": "ash_mural",
|
| 371 |
+
"type": "readable",
|
| 372 |
+
"label": theme.ash_mural_label,
|
| 373 |
+
"description": theme.ash_mural_description,
|
| 374 |
+
"parent_id": "workshop",
|
| 375 |
+
"clue_id": "initial_clue",
|
| 376 |
+
"requires_item_id": "torch",
|
| 377 |
+
"consumes_item": False,
|
| 378 |
+
"text_content": theme.ash_mural_text,
|
| 379 |
+
},
|
| 380 |
+
{
|
| 381 |
+
"id": "iron_chest",
|
| 382 |
+
"type": "container",
|
| 383 |
+
"label": theme.iron_chest_label,
|
| 384 |
+
"description": theme.iron_chest_description,
|
| 385 |
+
"parent_id": "workshop",
|
| 386 |
+
"open": False,
|
| 387 |
+
"locked": False,
|
| 388 |
+
"lock_key_id": None,
|
| 389 |
+
},
|
| 390 |
+
{
|
| 391 |
+
"id": "stone_well",
|
| 392 |
+
"type": "fixture",
|
| 393 |
+
"label": theme.stone_well_label,
|
| 394 |
+
"description": theme.stone_well_description,
|
| 395 |
+
"parent_id": "courtyard",
|
| 396 |
+
"requires_item_id": "full_map",
|
| 397 |
+
"reveals_item_id": None,
|
| 398 |
+
"reveals_readable_id": "water_plaque",
|
| 399 |
+
"consumes_item": False,
|
| 400 |
+
},
|
| 401 |
+
{
|
| 402 |
+
"id": "water_plaque",
|
| 403 |
+
"type": "readable",
|
| 404 |
+
"label": theme.water_plaque_label,
|
| 405 |
+
"description": theme.water_plaque_description,
|
| 406 |
+
"parent_id": "courtyard",
|
| 407 |
+
"clue_id": "river_clue",
|
| 408 |
+
"requires_item_id": None,
|
| 409 |
+
"consumes_item": False,
|
| 410 |
+
"text_content": theme.water_plaque_text,
|
| 411 |
+
},
|
| 412 |
+
{
|
| 413 |
+
"id": "cartographer",
|
| 414 |
+
"type": "npc",
|
| 415 |
+
"label": theme.cartographer_label,
|
| 416 |
+
"description": theme.cartographer_description,
|
| 417 |
+
"parent_id": "gallery",
|
| 418 |
+
"requires_item_id": "full_map",
|
| 419 |
+
"gives_item_id": "lens",
|
| 420 |
+
"gives_clue_id": None,
|
| 421 |
+
},
|
| 422 |
+
{
|
| 423 |
+
"id": "faded_letter",
|
| 424 |
+
"type": "readable",
|
| 425 |
+
"label": theme.faded_letter_label,
|
| 426 |
+
"description": theme.faded_letter_description,
|
| 427 |
+
"parent_id": "gallery",
|
| 428 |
+
"clue_id": "waterwarden_clue",
|
| 429 |
+
"requires_item_id": "lens",
|
| 430 |
+
"consumes_item": False,
|
| 431 |
+
"text_content": theme.faded_letter_text,
|
| 432 |
+
},
|
| 433 |
+
{
|
| 434 |
+
"id": "stone_guardian",
|
| 435 |
+
"type": "npc",
|
| 436 |
+
"label": theme.stone_guardian_label,
|
| 437 |
+
"description": theme.stone_guardian_description,
|
| 438 |
+
"parent_id": "shrine",
|
| 439 |
+
"requires_item_id": None,
|
| 440 |
+
"gives_item_id": None,
|
| 441 |
+
"gives_clue_id": None,
|
| 442 |
+
},
|
| 443 |
+
],
|
| 444 |
+
"edges": [
|
| 445 |
+
{"id": "foyer_north", "from_node_id": "foyer", "to_node_id": "shrine", "direction": "north", "type": "passage", "required_item_id": None, "door_node_id": None},
|
| 446 |
+
{"id": "shrine_south", "from_node_id": "shrine", "to_node_id": "foyer", "direction": "south", "type": "passage", "required_item_id": None, "door_node_id": None},
|
| 447 |
+
{"id": "foyer_east", "from_node_id": "foyer", "to_node_id": "workshop", "direction": "east", "type": "locked_passage", "required_item_id": "brass_key", "door_node_id": "iron_door"},
|
| 448 |
+
{"id": "workshop_west", "from_node_id": "workshop", "to_node_id": "foyer", "direction": "west", "type": "locked_passage", "required_item_id": "brass_key", "door_node_id": "iron_door"},
|
| 449 |
+
{"id": "foyer_west", "from_node_id": "foyer", "to_node_id": "courtyard", "direction": "west", "type": "passage", "required_item_id": None, "door_node_id": None},
|
| 450 |
+
{"id": "courtyard_east", "from_node_id": "courtyard", "to_node_id": "foyer", "direction": "east", "type": "passage", "required_item_id": None, "door_node_id": None},
|
| 451 |
+
{"id": "foyer_south", "from_node_id": "foyer", "to_node_id": "gallery", "direction": "south", "type": "passage", "required_item_id": None, "door_node_id": None},
|
| 452 |
+
{"id": "gallery_north", "from_node_id": "gallery", "to_node_id": "foyer", "direction": "north", "type": "passage", "required_item_id": None, "door_node_id": None},
|
| 453 |
+
],
|
| 454 |
+
"items": [
|
| 455 |
+
{"id": "brass_key", "label": theme.brass_key_label, "description": theme.brass_key_description, "subtype": "key", "start_node_id": "entry_chest"},
|
| 456 |
+
{"id": "torch", "label": theme.torch_label, "description": theme.torch_description, "subtype": "puzzle", "start_node_id": "workshop"},
|
| 457 |
+
{"id": "torn_map_left", "label": theme.torn_map_left_label, "description": theme.torn_map_left_description, "subtype": "puzzle", "start_node_id": "iron_chest"},
|
| 458 |
+
{"id": "torn_map_right", "label": theme.torn_map_right_label, "description": theme.torn_map_right_description, "subtype": "puzzle", "start_node_id": "courtyard"},
|
| 459 |
+
{"id": "full_map", "label": theme.full_map_label, "description": theme.full_map_description, "subtype": "puzzle", "start_node_id": None},
|
| 460 |
+
{"id": "lens", "label": theme.lens_label, "description": theme.lens_description, "subtype": "puzzle", "start_node_id": None},
|
| 461 |
+
],
|
| 462 |
+
"clues": [
|
| 463 |
+
{"id": "initial_clue", "text": theme.initial_clue_text},
|
| 464 |
+
{"id": "river_clue", "text": theme.river_clue_text},
|
| 465 |
+
{"id": "waterwarden_clue", "text": theme.waterwarden_clue_text},
|
| 466 |
+
],
|
| 467 |
+
"recipes": [
|
| 468 |
+
{
|
| 469 |
+
"id": "restore_map",
|
| 470 |
+
"input_item_ids": ["torn_map_left", "torn_map_right"],
|
| 471 |
+
"output_item_id": "full_map",
|
| 472 |
+
}
|
| 473 |
+
],
|
| 474 |
+
"quest_chain": [
|
| 475 |
+
{"step_id": "open_entry_chest", "description": f"Open the {theme.entry_chest_label.lower()}.", "requires_step_ids": [], "action": "open(entry_chest)"},
|
| 476 |
+
{"step_id": "take_brass_key", "description": f"Take the {theme.brass_key_label.lower()}.", "requires_step_ids": ["open_entry_chest"], "action": "take(brass_key,entry_chest)"},
|
| 477 |
+
{"step_id": "unlock_workshop", "description": f"Unlock the {theme.iron_door_label.lower()}.", "requires_step_ids": ["take_brass_key"], "action": "unlock(iron_door,brass_key)"},
|
| 478 |
+
{"step_id": "open_workshop", "description": f"Open the {theme.iron_door_label.lower()}.", "requires_step_ids": ["unlock_workshop"], "action": "open(iron_door)"},
|
| 479 |
+
{"step_id": "go_workshop", "description": f"Enter the {theme.workshop_label.lower()}.", "requires_step_ids": ["open_workshop"], "action": "go(workshop)"},
|
| 480 |
+
{"step_id": "take_torch", "description": f"Take the {theme.torch_label.lower()}.", "requires_step_ids": ["go_workshop"], "action": "take(torch,workshop)"},
|
| 481 |
+
{"step_id": "use_torch_on_mural", "description": f"Use the {theme.torch_label.lower()} on the {theme.ash_mural_label.lower()}.", "requires_step_ids": ["take_torch"], "action": "use(torch,ash_mural)"},
|
| 482 |
+
{"step_id": "open_iron_chest", "description": f"Open the {theme.iron_chest_label.lower()}.", "requires_step_ids": ["go_workshop"], "action": "open(iron_chest)"},
|
| 483 |
+
{"step_id": "take_left_map", "description": f"Take the {theme.torn_map_left_label.lower()}.", "requires_step_ids": ["open_iron_chest"], "action": "take(torn_map_left,iron_chest)"},
|
| 484 |
+
{"step_id": "return_foyer", "description": f"Return to the {theme.foyer_label.lower()}.", "requires_step_ids": ["take_left_map"], "action": "go(foyer)"},
|
| 485 |
+
{"step_id": "go_courtyard", "description": f"Head to the {theme.courtyard_label.lower()}.", "requires_step_ids": ["return_foyer"], "action": "go(courtyard)"},
|
| 486 |
+
{"step_id": "take_right_map", "description": f"Take the {theme.torn_map_right_label.lower()}.", "requires_step_ids": ["go_courtyard"], "action": "take(torn_map_right,courtyard)"},
|
| 487 |
+
{"step_id": "combine_map", "description": f"Restore the {theme.full_map_label.lower()}.", "requires_step_ids": ["take_right_map"], "action": "combine(torn_map_left,torn_map_right)"},
|
| 488 |
+
{"step_id": "use_map_on_well", "description": f"Use the {theme.full_map_label.lower()} on the {theme.stone_well_label.lower()}.", "requires_step_ids": ["combine_map"], "action": "use(full_map,stone_well)"},
|
| 489 |
+
{"step_id": "read_plaque", "description": f"Read the {theme.water_plaque_label.lower()}.", "requires_step_ids": ["use_map_on_well"], "action": "read(water_plaque)"},
|
| 490 |
+
{"step_id": "go_foyer_again", "description": f"Go back to the {theme.foyer_label.lower()}.", "requires_step_ids": ["read_plaque"], "action": "go(foyer)"},
|
| 491 |
+
{"step_id": "go_gallery", "description": f"Head to the {theme.gallery_label.lower()}.", "requires_step_ids": ["go_foyer_again"], "action": "go(gallery)"},
|
| 492 |
+
{"step_id": "give_map", "description": f"Give the map to the {theme.cartographer_label.lower()}.", "requires_step_ids": ["go_gallery"], "action": "give(full_map,cartographer)"},
|
| 493 |
+
{"step_id": "use_lens_on_letter", "description": f"Use the {theme.lens_label.lower()} on the {theme.faded_letter_label.lower()}.", "requires_step_ids": ["give_map"], "action": "use(lens,faded_letter)"},
|
| 494 |
+
{"step_id": "return_foyer_final", "description": f"Return to the {theme.foyer_label.lower()} again.", "requires_step_ids": ["use_lens_on_letter"], "action": "go(foyer)"},
|
| 495 |
+
{"step_id": "go_shrine", "description": f"Go to the {theme.shrine_label.lower()}.", "requires_step_ids": ["return_foyer_final"], "action": "go(shrine)"},
|
| 496 |
+
{"step_id": "talk_guardian", "description": f"Speak to the {theme.stone_guardian_label.lower()}.", "requires_step_ids": ["go_shrine"], "action": "talk(stone_guardian)"},
|
| 497 |
+
{"step_id": "submit_answer", "description": "Submit the betrayer's name.", "requires_step_ids": ["talk_guardian"], "action": f'submit("{theme.answer}")'},
|
| 498 |
+
],
|
| 499 |
+
}
|
agents/master/schema.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Annotated, Literal, TypeAlias
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 8 |
+
|
| 9 |
+
from agents.shared.openenv_compat import Action, Observation, State
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StrictModel(BaseModel):
|
| 13 |
+
model_config = ConfigDict(extra="forbid")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class WorldMeta(StrictModel):
|
| 17 |
+
title: str
|
| 18 |
+
difficulty_target: float
|
| 19 |
+
start_node_id: str
|
| 20 |
+
win_condition: "WinCondition"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class WinCondition(StrictModel):
|
| 24 |
+
type: Literal["deduce"]
|
| 25 |
+
target_npc_id: str
|
| 26 |
+
answer_string: str
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class BaseNode(StrictModel):
|
| 30 |
+
id: str
|
| 31 |
+
label: str
|
| 32 |
+
description: str
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class LocationNode(BaseNode):
|
| 36 |
+
type: Literal["location"]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class JunctionNode(BaseNode):
|
| 40 |
+
type: Literal["junction"]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ContainerNode(BaseNode):
|
| 44 |
+
type: Literal["container"]
|
| 45 |
+
parent_id: str
|
| 46 |
+
open: bool = False
|
| 47 |
+
locked: bool = False
|
| 48 |
+
lock_key_id: str | None = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DoorNode(BaseNode):
|
| 52 |
+
type: Literal["door"]
|
| 53 |
+
open: bool = False
|
| 54 |
+
locked: bool = False
|
| 55 |
+
lock_key_id: str | None = None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ReadableNode(BaseNode):
|
| 59 |
+
type: Literal["readable"]
|
| 60 |
+
parent_id: str
|
| 61 |
+
clue_id: str
|
| 62 |
+
requires_item_id: str | None = None
|
| 63 |
+
consumes_item: bool = False
|
| 64 |
+
text_content: str
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class FixtureNode(BaseNode):
|
| 68 |
+
type: Literal["fixture"]
|
| 69 |
+
parent_id: str
|
| 70 |
+
requires_item_id: str
|
| 71 |
+
reveals_item_id: str | None = None
|
| 72 |
+
reveals_readable_id: str | None = None
|
| 73 |
+
consumes_item: bool = False
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class NpcNode(BaseNode):
|
| 77 |
+
type: Literal["npc"]
|
| 78 |
+
parent_id: str
|
| 79 |
+
requires_item_id: str | None = None
|
| 80 |
+
gives_item_id: str | None = None
|
| 81 |
+
gives_clue_id: str | None = None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
WorldNode: TypeAlias = Annotated[
|
| 85 |
+
LocationNode | JunctionNode | ContainerNode | DoorNode | ReadableNode | FixtureNode | NpcNode,
|
| 86 |
+
Field(discriminator="type"),
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Edge(StrictModel):
|
| 91 |
+
id: str
|
| 92 |
+
from_node_id: str
|
| 93 |
+
to_node_id: str
|
| 94 |
+
direction: Literal["north", "south", "east", "west", "up", "down", "in", "out"]
|
| 95 |
+
type: Literal["passage", "locked_passage"]
|
| 96 |
+
required_item_id: str | None = None
|
| 97 |
+
door_node_id: str | None = None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class Item(StrictModel):
|
| 101 |
+
id: str
|
| 102 |
+
label: str
|
| 103 |
+
description: str
|
| 104 |
+
subtype: Literal["key", "puzzle"]
|
| 105 |
+
start_node_id: str | None = None
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class Clue(StrictModel):
|
| 109 |
+
id: str
|
| 110 |
+
text: str
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Recipe(StrictModel):
|
| 114 |
+
id: str
|
| 115 |
+
input_item_ids: list[str] = Field(min_length=2, max_length=2)
|
| 116 |
+
output_item_id: str
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class QuestStep(StrictModel):
|
| 120 |
+
step_id: str
|
| 121 |
+
description: str
|
| 122 |
+
requires_step_ids: list[str] = Field(default_factory=list)
|
| 123 |
+
action: str
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class WorldDefinition(StrictModel):
|
| 127 |
+
meta: WorldMeta
|
| 128 |
+
nodes: list[WorldNode]
|
| 129 |
+
edges: list[Edge]
|
| 130 |
+
items: list[Item]
|
| 131 |
+
clues: list[Clue]
|
| 132 |
+
recipes: list[Recipe] = Field(default_factory=list)
|
| 133 |
+
quest_chain: list[QuestStep]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class DMAction(Action):
|
| 137 |
+
world_definition: WorldDefinition
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class Turn(StrictModel):
|
| 141 |
+
step: int
|
| 142 |
+
player_action: str
|
| 143 |
+
textworld_command: str
|
| 144 |
+
observation: str
|
| 145 |
+
game_state_delta: dict[str, object]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class DMFeedback(StrictModel):
|
| 149 |
+
unreachable_nodes: list[str]
|
| 150 |
+
unused_items: list[str]
|
| 151 |
+
clues_missed: list[str]
|
| 152 |
+
mean_steps_per_room: float
|
| 153 |
+
invalid_command_count: int = 0
|
| 154 |
+
wrong_submit_count: int = 0
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class DMRewardBreakdown(StrictModel):
|
| 158 |
+
reward_mode: Literal["gaussian_target_ratio", "compile_failure_penalty"] = "gaussian_target_ratio"
|
| 159 |
+
player_won: bool
|
| 160 |
+
raw_ratio: float | None = None
|
| 161 |
+
clamped_ratio: float | None = None
|
| 162 |
+
target_ratio: float
|
| 163 |
+
target_ratio_delta: float | None = None
|
| 164 |
+
efficiency_score: float | None = None
|
| 165 |
+
quality_score: float = 0.0
|
| 166 |
+
reward: float
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class DMObservation(Observation):
|
| 170 |
+
episode_transcript: list[Turn] = Field(default_factory=list)
|
| 171 |
+
player_won: bool | None = None
|
| 172 |
+
steps_taken: int | None = None
|
| 173 |
+
min_steps: int | None = None
|
| 174 |
+
ratio: float | None = None
|
| 175 |
+
compile_error: str | None = None
|
| 176 |
+
feedback: DMFeedback | None = None
|
| 177 |
+
reward_breakdown: DMRewardBreakdown | None = None
|
| 178 |
+
target_ratio_used: float | None = None
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class DMState(State):
|
| 182 |
+
current_world: WorldDefinition | None = None
|
| 183 |
+
compile_status: Literal["valid", "invalid", "pending"] = "pending"
|
| 184 |
+
episode_status: Literal["running", "complete", "failed"] = "running"
|
| 185 |
+
cumulative_success_rate: float = 0.0
|
| 186 |
+
target_ratio: float = 0.0
|
| 187 |
+
difficulty_hint: float | None = None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@dataclass(frozen=True)
|
| 191 |
+
class GoAction:
|
| 192 |
+
target_node_id: str
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@dataclass(frozen=True)
|
| 196 |
+
class OpenAction:
|
| 197 |
+
target_node_id: str
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@dataclass(frozen=True)
|
| 201 |
+
class UnlockAction:
|
| 202 |
+
door_id: str
|
| 203 |
+
key_id: str
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@dataclass(frozen=True)
|
| 207 |
+
class TakeAction:
|
| 208 |
+
item_id: str
|
| 209 |
+
source_node_id: str
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@dataclass(frozen=True)
|
| 213 |
+
class ReadAction:
|
| 214 |
+
target_node_id: str
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
@dataclass(frozen=True)
|
| 218 |
+
class UseAction:
|
| 219 |
+
item_id: str
|
| 220 |
+
target_node_id: str
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@dataclass(frozen=True)
|
| 224 |
+
class CombineAction:
|
| 225 |
+
item_a_id: str
|
| 226 |
+
item_b_id: str
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@dataclass(frozen=True)
|
| 230 |
+
class GiveAction:
|
| 231 |
+
item_id: str
|
| 232 |
+
npc_id: str
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@dataclass(frozen=True)
|
| 236 |
+
class TalkAction:
|
| 237 |
+
target_node_id: str
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@dataclass(frozen=True)
|
| 241 |
+
class SubmitAction:
|
| 242 |
+
answer_text: str
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
QuestAction = (
|
| 246 |
+
GoAction
|
| 247 |
+
| OpenAction
|
| 248 |
+
| UnlockAction
|
| 249 |
+
| TakeAction
|
| 250 |
+
| ReadAction
|
| 251 |
+
| UseAction
|
| 252 |
+
| CombineAction
|
| 253 |
+
| GiveAction
|
| 254 |
+
| TalkAction
|
| 255 |
+
| SubmitAction
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
@dataclass(frozen=True)
|
| 260 |
+
class NpcTrade:
|
| 261 |
+
required_item_id: str
|
| 262 |
+
gives_item_id: str | None
|
| 263 |
+
gives_clue_id: str | None
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
@dataclass(frozen=True)
|
| 267 |
+
class UseEffect:
|
| 268 |
+
required_item_id: str
|
| 269 |
+
clue_id: str | None = None
|
| 270 |
+
reveals_item_id: str | None = None
|
| 271 |
+
reveals_readable_id: str | None = None
|
| 272 |
+
consumes_item: bool = False
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@dataclass
|
| 276 |
+
class CompiledWorld:
|
| 277 |
+
episode_id: str
|
| 278 |
+
world: WorldDefinition
|
| 279 |
+
artifacts_dir: Path
|
| 280 |
+
game_file: Path
|
| 281 |
+
walkthrough_commands: list[str]
|
| 282 |
+
solver_policy: list[str]
|
| 283 |
+
correct_answer_normalized: str
|
| 284 |
+
correct_submit_command: str
|
| 285 |
+
guardian_id: str
|
| 286 |
+
guardian_room_id: str
|
| 287 |
+
room_name_to_id: dict[str, str]
|
| 288 |
+
node_command_names: dict[str, str]
|
| 289 |
+
item_command_names: dict[str, str]
|
| 290 |
+
item_start_locations: dict[str, str | None]
|
| 291 |
+
clue_text_by_id: dict[str, str]
|
| 292 |
+
readable_clue_by_id: dict[str, str]
|
| 293 |
+
npc_trade_map: dict[str, NpcTrade]
|
| 294 |
+
recipe_map: dict[frozenset[str], str]
|
| 295 |
+
use_effects: dict[str, UseEffect]
|
| 296 |
+
produced_item_ids: set[str]
|
| 297 |
+
room_edges_by_target: dict[tuple[str, str], Edge]
|
| 298 |
+
room_edges_by_direction: dict[tuple[str, str], Edge]
|
| 299 |
+
door_rooms: dict[str, frozenset[str]]
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@dataclass
|
| 303 |
+
class SimulationState:
|
| 304 |
+
current_room_id: str
|
| 305 |
+
inventory: set[str] = field(default_factory=set)
|
| 306 |
+
item_locations: dict[str, str | None] = field(default_factory=dict)
|
| 307 |
+
open_nodes: set[str] = field(default_factory=set)
|
| 308 |
+
locked_nodes: set[str] = field(default_factory=set)
|
| 309 |
+
discovered_clues: set[str] = field(default_factory=set)
|
| 310 |
+
consulted_npcs: set[str] = field(default_factory=set)
|
| 311 |
+
satisfied_npcs: set[str] = field(default_factory=set)
|
| 312 |
+
revealed_readables: set[str] = field(default_factory=set)
|
| 313 |
+
prepared_readables: set[str] = field(default_factory=set)
|
| 314 |
+
used_fixtures: set[str] = field(default_factory=set)
|
| 315 |
+
produced_items: set[str] = field(default_factory=set)
|
| 316 |
+
visited_nodes: set[str] = field(default_factory=set)
|
agents/master/server.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import mimetypes
|
| 5 |
+
import threading
|
| 6 |
+
from http import HTTPStatus
|
| 7 |
+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
|
| 12 |
+
from .base import DMCompileError, DMInterfaceError
|
| 13 |
+
from .build import WorldCompiler
|
| 14 |
+
from .interface import GeminiInterfaceAdapter, SimpleInterfaceAdapter
|
| 15 |
+
from .schema import CompiledWorld, WorldDefinition
|
| 16 |
+
from .session import EpisodeSession
|
| 17 |
+
from .snapshots import (
|
| 18 |
+
DEFAULT_LIVE_DIR,
|
| 19 |
+
STATE_FILENAME,
|
| 20 |
+
WORLD_FILENAME,
|
| 21 |
+
LiveCurrentRoom,
|
| 22 |
+
LiveMetrics,
|
| 23 |
+
LiveRuntime,
|
| 24 |
+
LiveStateSnapshot,
|
| 25 |
+
load_live_payload,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
WEB_DIST_DIR = Path(__file__).resolve().parents[2] / "www" / "dist"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class GameSessionManager:
|
| 33 |
+
"""Thread-safe container for an interactive play session."""
|
| 34 |
+
|
| 35 |
+
def __init__(self, live_dir: Path, use_gemini: bool = False) -> None:
|
| 36 |
+
self._lock = threading.Lock()
|
| 37 |
+
self._session: EpisodeSession | None = None
|
| 38 |
+
self._compiled: CompiledWorld | None = None
|
| 39 |
+
self._compiler = WorldCompiler()
|
| 40 |
+
self._live_dir = live_dir
|
| 41 |
+
self._use_gemini = use_gemini
|
| 42 |
+
self._clear_stale_files()
|
| 43 |
+
|
| 44 |
+
def _clear_stale_files(self) -> None:
|
| 45 |
+
"""Remove leftover state/world JSON from a previous session."""
|
| 46 |
+
for fname in (STATE_FILENAME, WORLD_FILENAME):
|
| 47 |
+
path = self._live_dir / fname
|
| 48 |
+
path.unlink(missing_ok=True)
|
| 49 |
+
|
| 50 |
+
def start(self, world_input: WorldDefinition | dict[str, Any]) -> dict[str, Any]:
|
| 51 |
+
with self._lock:
|
| 52 |
+
if self._session is not None:
|
| 53 |
+
self._session.close()
|
| 54 |
+
compiled = self._compiler.compile(world_input)
|
| 55 |
+
adapter = self._make_adapter()
|
| 56 |
+
session = EpisodeSession(compiled, interface_adapter=adapter)
|
| 57 |
+
self._compiled = compiled
|
| 58 |
+
self._session = session
|
| 59 |
+
self._write_world(compiled.world)
|
| 60 |
+
self._write_state("running")
|
| 61 |
+
return {
|
| 62 |
+
"ok": True,
|
| 63 |
+
"episode_id": compiled.episode_id,
|
| 64 |
+
"observation": session.current_feedback(),
|
| 65 |
+
"available_commands": session.available_commands(),
|
| 66 |
+
"room": self._room_info(session),
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
def reset(self) -> dict[str, Any]:
|
| 70 |
+
with self._lock:
|
| 71 |
+
if self._session is not None:
|
| 72 |
+
self._session.close()
|
| 73 |
+
self._session = None
|
| 74 |
+
self._compiled = None
|
| 75 |
+
self._clear_stale_files()
|
| 76 |
+
return {"ok": True}
|
| 77 |
+
|
| 78 |
+
def command(self, raw_command: str) -> dict[str, Any]:
|
| 79 |
+
with self._lock:
|
| 80 |
+
session = self._session
|
| 81 |
+
if session is None:
|
| 82 |
+
return {"ok": False, "error": "No active session. POST /api/start first."}
|
| 83 |
+
if session.done:
|
| 84 |
+
return {
|
| 85 |
+
"ok": False,
|
| 86 |
+
"error": "Episode is complete.",
|
| 87 |
+
"done": True,
|
| 88 |
+
"player_won": session.player_won,
|
| 89 |
+
}
|
| 90 |
+
try:
|
| 91 |
+
turn = session.step(raw_command)
|
| 92 |
+
except (DMInterfaceError, RuntimeError) as exc:
|
| 93 |
+
return {"ok": False, "error": str(exc)}
|
| 94 |
+
|
| 95 |
+
status = "complete" if session.done and session.player_won else (
|
| 96 |
+
"failed" if session.done else "running"
|
| 97 |
+
)
|
| 98 |
+
self._write_state(status)
|
| 99 |
+
return {
|
| 100 |
+
"ok": True,
|
| 101 |
+
"step": turn.step,
|
| 102 |
+
"command": turn.textworld_command,
|
| 103 |
+
"observation": turn.observation,
|
| 104 |
+
"done": session.done,
|
| 105 |
+
"player_won": session.player_won,
|
| 106 |
+
"available_commands": [] if session.done else session.available_commands(),
|
| 107 |
+
"room": self._room_info(session),
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
def get_state_payload(self) -> dict[str, Any] | None:
|
| 111 |
+
with self._lock:
|
| 112 |
+
session = self._session
|
| 113 |
+
compiled = self._compiled
|
| 114 |
+
if session is None or compiled is None:
|
| 115 |
+
return None
|
| 116 |
+
return self._snapshot(session, compiled).model_dump()
|
| 117 |
+
|
| 118 |
+
def _make_adapter(self) -> SimpleInterfaceAdapter | GeminiInterfaceAdapter:
|
| 119 |
+
if self._use_gemini:
|
| 120 |
+
try:
|
| 121 |
+
return GeminiInterfaceAdapter(narrate_observations=True)
|
| 122 |
+
except DMInterfaceError:
|
| 123 |
+
pass
|
| 124 |
+
return SimpleInterfaceAdapter()
|
| 125 |
+
|
| 126 |
+
def _write_world(self, world: WorldDefinition) -> None:
|
| 127 |
+
self._write_json(WORLD_FILENAME, world.model_dump_json(indent=2))
|
| 128 |
+
|
| 129 |
+
def _write_state(self, status: str) -> None:
|
| 130 |
+
session = self._session
|
| 131 |
+
compiled = self._compiled
|
| 132 |
+
if session is None or compiled is None:
|
| 133 |
+
return
|
| 134 |
+
snapshot = self._snapshot(session, compiled, status=status)
|
| 135 |
+
self._write_json(STATE_FILENAME, snapshot.model_dump_json(indent=2))
|
| 136 |
+
|
| 137 |
+
def _snapshot(
|
| 138 |
+
self,
|
| 139 |
+
session: EpisodeSession,
|
| 140 |
+
compiled: CompiledWorld,
|
| 141 |
+
status: str | None = None,
|
| 142 |
+
) -> LiveStateSnapshot:
|
| 143 |
+
from datetime import datetime, timezone
|
| 144 |
+
|
| 145 |
+
room_ids = {
|
| 146 |
+
node.id for node in compiled.world.nodes if node.type in {"location", "junction"}
|
| 147 |
+
}
|
| 148 |
+
commands = [] if session.done else session.available_commands()
|
| 149 |
+
|
| 150 |
+
if status is None:
|
| 151 |
+
if session.done:
|
| 152 |
+
status = "complete" if session.player_won else "failed"
|
| 153 |
+
else:
|
| 154 |
+
status = "running"
|
| 155 |
+
|
| 156 |
+
return LiveStateSnapshot(
|
| 157 |
+
episode_id=compiled.episode_id,
|
| 158 |
+
status=status,
|
| 159 |
+
updated_at=datetime.now(timezone.utc).isoformat(),
|
| 160 |
+
title=compiled.world.meta.title,
|
| 161 |
+
transcript=list(session.transcript),
|
| 162 |
+
metrics=LiveMetrics(
|
| 163 |
+
steps_taken=session.steps_taken,
|
| 164 |
+
min_steps=len(compiled.solver_policy),
|
| 165 |
+
ratio=session.steps_taken / len(compiled.solver_policy) if compiled.solver_policy else None,
|
| 166 |
+
player_won=session.player_won if session.done else None,
|
| 167 |
+
),
|
| 168 |
+
runtime=LiveRuntime(
|
| 169 |
+
current_room_id=session.current_room_id,
|
| 170 |
+
inventory_item_ids=sorted(session.inventory),
|
| 171 |
+
discovered_clue_ids=sorted(session.discovered_clues),
|
| 172 |
+
traded_npc_ids=sorted(session.traded_npcs),
|
| 173 |
+
visited_room_ids=sorted(room_ids & session.visited_nodes),
|
| 174 |
+
available_commands=commands,
|
| 175 |
+
invalid_command_count=session.invalid_command_count,
|
| 176 |
+
wrong_submit_count=session.wrong_submit_count,
|
| 177 |
+
open_node_ids=sorted(session.open_nodes),
|
| 178 |
+
locked_node_ids=sorted(session.locked_nodes),
|
| 179 |
+
),
|
| 180 |
+
current_room=self._current_room_snapshot(session),
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
@staticmethod
|
| 184 |
+
def _current_room_snapshot(session: EpisodeSession) -> LiveCurrentRoom | None:
|
| 185 |
+
node_by_id = {node.id: node for node in session.compiled.world.nodes}
|
| 186 |
+
room = node_by_id.get(session.current_room_id)
|
| 187 |
+
if room is None:
|
| 188 |
+
return None
|
| 189 |
+
visible_nodes = [
|
| 190 |
+
node.id
|
| 191 |
+
for node in session.compiled.world.nodes
|
| 192 |
+
if getattr(node, "parent_id", None) == session.current_room_id
|
| 193 |
+
and (node.type != "readable" or node.id in session.revealed_readables)
|
| 194 |
+
]
|
| 195 |
+
visible_nodes.extend(
|
| 196 |
+
sorted(
|
| 197 |
+
door_id
|
| 198 |
+
for door_id, rooms in session.compiled.door_rooms.items()
|
| 199 |
+
if session.current_room_id in rooms
|
| 200 |
+
)
|
| 201 |
+
)
|
| 202 |
+
visible_items = sorted(
|
| 203 |
+
item_id
|
| 204 |
+
for item_id, location in session.item_locations.items()
|
| 205 |
+
if location == session.current_room_id
|
| 206 |
+
)
|
| 207 |
+
return LiveCurrentRoom(
|
| 208 |
+
id=room.id,
|
| 209 |
+
label=room.label,
|
| 210 |
+
description=room.description,
|
| 211 |
+
visible_node_ids=sorted(set(visible_nodes)),
|
| 212 |
+
visible_item_ids=visible_items,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
@staticmethod
|
| 216 |
+
def _room_info(session: EpisodeSession) -> dict[str, Any]:
|
| 217 |
+
node_by_id = {node.id: node for node in session.compiled.world.nodes}
|
| 218 |
+
room = node_by_id.get(session.current_room_id)
|
| 219 |
+
return {
|
| 220 |
+
"id": session.current_room_id,
|
| 221 |
+
"label": room.label if room else session.current_room_id,
|
| 222 |
+
"description": room.description if room else "",
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
def _write_json(self, filename: str, payload: str) -> None:
|
| 226 |
+
self._live_dir.mkdir(parents=True, exist_ok=True)
|
| 227 |
+
path = self._live_dir / filename
|
| 228 |
+
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
| 229 |
+
tmp_path.write_text(payload + "\n", encoding="utf-8")
|
| 230 |
+
tmp_path.replace(path)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def create_server(
|
| 234 |
+
*,
|
| 235 |
+
live_dir: Path | None = None,
|
| 236 |
+
host: str = "127.0.0.1",
|
| 237 |
+
port: int = 8000,
|
| 238 |
+
use_gemini: bool = False,
|
| 239 |
+
) -> ThreadingHTTPServer:
|
| 240 |
+
resolved_live_dir = live_dir or DEFAULT_LIVE_DIR
|
| 241 |
+
game = GameSessionManager(resolved_live_dir, use_gemini=use_gemini)
|
| 242 |
+
|
| 243 |
+
class LiveViewerHandler(BaseHTTPRequestHandler):
|
| 244 |
+
server_version = "AgentsMasterLive/1.0"
|
| 245 |
+
|
| 246 |
+
def do_GET(self) -> None: # noqa: N802
|
| 247 |
+
path = urlparse(self.path).path
|
| 248 |
+
if path == "/api/state":
|
| 249 |
+
self._serve_live_file(STATE_FILENAME)
|
| 250 |
+
return
|
| 251 |
+
if path == "/api/world":
|
| 252 |
+
self._serve_live_file(WORLD_FILENAME)
|
| 253 |
+
return
|
| 254 |
+
if path == "/":
|
| 255 |
+
self._serve_index()
|
| 256 |
+
return
|
| 257 |
+
if path == "/favicon.ico":
|
| 258 |
+
self.send_response(HTTPStatus.NO_CONTENT)
|
| 259 |
+
self.end_headers()
|
| 260 |
+
return
|
| 261 |
+
if self._serve_web_file(path):
|
| 262 |
+
return
|
| 263 |
+
if WEB_DIST_DIR.exists() and Path(path).suffix == "":
|
| 264 |
+
self._serve_index()
|
| 265 |
+
return
|
| 266 |
+
self._respond(HTTPStatus.NOT_FOUND, b"Not found\n", "text/plain; charset=utf-8")
|
| 267 |
+
|
| 268 |
+
def do_POST(self) -> None: # noqa: N802
|
| 269 |
+
path = urlparse(self.path).path
|
| 270 |
+
body = self._read_body()
|
| 271 |
+
|
| 272 |
+
if path == "/api/reset":
|
| 273 |
+
result = game.reset()
|
| 274 |
+
self._json_respond(HTTPStatus.OK, result)
|
| 275 |
+
return
|
| 276 |
+
|
| 277 |
+
if path == "/api/start":
|
| 278 |
+
try:
|
| 279 |
+
world_input = json.loads(body) if body else None
|
| 280 |
+
if world_input is None:
|
| 281 |
+
self._json_respond(HTTPStatus.BAD_REQUEST, {"ok": False, "error": "Missing JSON body."})
|
| 282 |
+
return
|
| 283 |
+
result = game.start(world_input)
|
| 284 |
+
self._json_respond(HTTPStatus.OK, result)
|
| 285 |
+
except (DMCompileError, ValueError, json.JSONDecodeError) as exc:
|
| 286 |
+
self._json_respond(HTTPStatus.BAD_REQUEST, {"ok": False, "error": str(exc)})
|
| 287 |
+
return
|
| 288 |
+
|
| 289 |
+
if path == "/api/command":
|
| 290 |
+
try:
|
| 291 |
+
data = json.loads(body) if body else {}
|
| 292 |
+
command = data.get("command", "").strip()
|
| 293 |
+
if not command:
|
| 294 |
+
self._json_respond(HTTPStatus.BAD_REQUEST, {"ok": False, "error": "Missing 'command' field."})
|
| 295 |
+
return
|
| 296 |
+
result = game.command(command)
|
| 297 |
+
self._json_respond(HTTPStatus.OK, result)
|
| 298 |
+
except json.JSONDecodeError as exc:
|
| 299 |
+
self._json_respond(HTTPStatus.BAD_REQUEST, {"ok": False, "error": str(exc)})
|
| 300 |
+
return
|
| 301 |
+
|
| 302 |
+
self._respond(HTTPStatus.NOT_FOUND, b"Not found\n", "text/plain; charset=utf-8")
|
| 303 |
+
|
| 304 |
+
def log_message(self, format: str, *args: object) -> None: # noqa: A003
|
| 305 |
+
del format, args
|
| 306 |
+
|
| 307 |
+
def _read_body(self) -> bytes:
|
| 308 |
+
length = int(self.headers.get("Content-Length", 0))
|
| 309 |
+
return self.rfile.read(length) if length > 0 else b""
|
| 310 |
+
|
| 311 |
+
def _serve_index(self) -> None:
|
| 312 |
+
index_path = WEB_DIST_DIR / "index.html"
|
| 313 |
+
if index_path.is_file():
|
| 314 |
+
self._respond(HTTPStatus.OK, index_path.read_bytes(), "text/html; charset=utf-8")
|
| 315 |
+
else:
|
| 316 |
+
from .templates import render_index
|
| 317 |
+
self._respond(HTTPStatus.OK, render_index().encode("utf-8"), "text/html; charset=utf-8")
|
| 318 |
+
|
| 319 |
+
def _serve_live_file(self, filename: str) -> None:
|
| 320 |
+
payload = load_live_payload(resolved_live_dir, filename)
|
| 321 |
+
if payload is None:
|
| 322 |
+
self.send_response(HTTPStatus.NO_CONTENT)
|
| 323 |
+
self.send_header("Cache-Control", "no-store")
|
| 324 |
+
self.end_headers()
|
| 325 |
+
return
|
| 326 |
+
self._respond(
|
| 327 |
+
HTTPStatus.OK, payload, "application/json; charset=utf-8",
|
| 328 |
+
extra_headers={"Cache-Control": "no-store"},
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
def _serve_web_file(self, path: str) -> bool:
|
| 332 |
+
candidate = (WEB_DIST_DIR / path.lstrip("/")).resolve()
|
| 333 |
+
try:
|
| 334 |
+
candidate.relative_to(WEB_DIST_DIR.resolve())
|
| 335 |
+
except ValueError:
|
| 336 |
+
return False
|
| 337 |
+
if not candidate.is_file():
|
| 338 |
+
return False
|
| 339 |
+
content_type = mimetypes.guess_type(candidate.name)[0] or "application/octet-stream"
|
| 340 |
+
self._respond(HTTPStatus.OK, candidate.read_bytes(), content_type)
|
| 341 |
+
return True
|
| 342 |
+
|
| 343 |
+
def _json_respond(self, status: HTTPStatus, data: dict[str, Any]) -> None:
|
| 344 |
+
payload = json.dumps(data).encode("utf-8")
|
| 345 |
+
self._respond(status, payload, "application/json; charset=utf-8",
|
| 346 |
+
extra_headers={"Cache-Control": "no-store"})
|
| 347 |
+
|
| 348 |
+
def _respond(
|
| 349 |
+
self, status: HTTPStatus, payload: bytes, content_type: str,
|
| 350 |
+
*, extra_headers: dict[str, str] | None = None,
|
| 351 |
+
) -> None:
|
| 352 |
+
self.send_response(status)
|
| 353 |
+
self.send_header("Content-Type", content_type)
|
| 354 |
+
self.send_header("Content-Length", str(len(payload)))
|
| 355 |
+
if extra_headers:
|
| 356 |
+
for key, value in extra_headers.items():
|
| 357 |
+
self.send_header(key, value)
|
| 358 |
+
self.end_headers()
|
| 359 |
+
self.wfile.write(payload)
|
| 360 |
+
|
| 361 |
+
return ThreadingHTTPServer((host, port), LiveViewerHandler)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def run_server(*, port: int = 8000, live_dir: Path | None = None, host: str = "127.0.0.1", use_gemini: bool = False) -> None:
|
| 365 |
+
server = create_server(live_dir=live_dir, host=host, port=port, use_gemini=use_gemini)
|
| 366 |
+
print(f"Serving live viewer on http://{host}:{server.server_address[1]}")
|
| 367 |
+
try:
|
| 368 |
+
server.serve_forever()
|
| 369 |
+
finally:
|
| 370 |
+
server.server_close()
|
agents/master/session.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import textwrap
|
| 5 |
+
from collections import deque
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Callable
|
| 7 |
+
|
| 8 |
+
import textworld
|
| 9 |
+
from textworld.core import EnvInfos, GameState
|
| 10 |
+
|
| 11 |
+
from .base import INVENTORY_ID, normalize_answer_text, suppress_unsupported_game_warning
|
| 12 |
+
from .interface import InterfaceAdapter, SimpleInterfaceAdapter
|
| 13 |
+
from .schema import CompiledWorld, Turn
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
TurnListener = Callable[["EpisodeSession", Turn], None]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class EpisodeSession:
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
compiled: CompiledWorld,
|
| 23 |
+
interface_adapter: InterfaceAdapter = SimpleInterfaceAdapter(),
|
| 24 |
+
turn_listener: "TurnListener | None" = None,
|
| 25 |
+
) -> None:
|
| 26 |
+
if interface_adapter is None:
|
| 27 |
+
raise ValueError("interface_adapter must not be None.")
|
| 28 |
+
self.compiled = compiled
|
| 29 |
+
self.interface_adapter = interface_adapter
|
| 30 |
+
self.turn_listener = turn_listener
|
| 31 |
+
with suppress_unsupported_game_warning():
|
| 32 |
+
self.env = textworld.start(str(compiled.game_file), request_infos=self._requested_infos())
|
| 33 |
+
self.state = self.env.reset()
|
| 34 |
+
self._closed = False
|
| 35 |
+
self.done = False
|
| 36 |
+
self.player_won = False
|
| 37 |
+
self.steps_taken = 0
|
| 38 |
+
self.invalid_command_count = 0
|
| 39 |
+
self.wrong_submit_count = 0
|
| 40 |
+
self.used_items: set[str] = set()
|
| 41 |
+
self.discovered_clues: set[str] = set()
|
| 42 |
+
self.consulted_npcs: set[str] = set()
|
| 43 |
+
self.traded_npcs: set[str] = set()
|
| 44 |
+
self.prepared_readables: set[str] = set()
|
| 45 |
+
self.completed_recipe_outputs: set[str] = set()
|
| 46 |
+
self.completed_use_targets: set[str] = set()
|
| 47 |
+
self.unlocked_doors: set[str] = set()
|
| 48 |
+
self.consulted_guardian = False
|
| 49 |
+
self.hidden_readables = {
|
| 50 |
+
effect.reveals_readable_id for effect in compiled.use_effects.values() if effect.reveals_readable_id
|
| 51 |
+
}
|
| 52 |
+
self.revealed_readables = {
|
| 53 |
+
node.id for node in compiled.world.nodes if node.type == "readable" and node.id not in self.hidden_readables
|
| 54 |
+
}
|
| 55 |
+
self.item_locations = dict(compiled.item_start_locations)
|
| 56 |
+
self.inventory = {item_id for item_id, location in self.item_locations.items() if location == INVENTORY_ID}
|
| 57 |
+
self.open_nodes = {
|
| 58 |
+
node.id for node in compiled.world.nodes if node.type in {"container", "door"} and getattr(node, "open", False)
|
| 59 |
+
}
|
| 60 |
+
self.locked_nodes = {
|
| 61 |
+
node.id for node in compiled.world.nodes if node.type in {"container", "door"} and getattr(node, "locked", False)
|
| 62 |
+
}
|
| 63 |
+
self.current_room_id = compiled.world.meta.start_node_id
|
| 64 |
+
self.visited_nodes: set[str] = {self.current_room_id}
|
| 65 |
+
self.transcript: list[Turn] = []
|
| 66 |
+
self.recent_normalized_commands: deque[str] = deque(maxlen=3)
|
| 67 |
+
self._node_by_id = {node.id: node for node in compiled.world.nodes}
|
| 68 |
+
self._label_by_id = {node.id: node.label for node in compiled.world.nodes}
|
| 69 |
+
self._label_by_id.update({item.id: item.label for item in compiled.world.items})
|
| 70 |
+
self._item_name_to_id = {name: item_id for item_id, name in compiled.item_command_names.items()}
|
| 71 |
+
self.last_state_fingerprint = self.state_fingerprint()
|
| 72 |
+
|
| 73 |
+
@staticmethod
|
| 74 |
+
def _requested_infos() -> EnvInfos:
|
| 75 |
+
return EnvInfos(
|
| 76 |
+
feedback=True,
|
| 77 |
+
description=True,
|
| 78 |
+
inventory=True,
|
| 79 |
+
location=True,
|
| 80 |
+
facts=False,
|
| 81 |
+
won=True,
|
| 82 |
+
lost=True,
|
| 83 |
+
score=True,
|
| 84 |
+
moves=True,
|
| 85 |
+
last_action=True,
|
| 86 |
+
last_command=True,
|
| 87 |
+
admissible_commands=True,
|
| 88 |
+
policy_commands=True,
|
| 89 |
+
extras=["walkthrough"],
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def available_commands(self) -> list[str]:
|
| 93 |
+
commands = set(self.state.admissible_commands or [])
|
| 94 |
+
commands.update(self._custom_commands())
|
| 95 |
+
return sorted(commands)
|
| 96 |
+
|
| 97 |
+
def current_feedback(self) -> str:
|
| 98 |
+
return self.interface_adapter.render_observation(self.state.feedback or "", self.state, self)
|
| 99 |
+
|
| 100 |
+
def state_fingerprint(self) -> str:
|
| 101 |
+
return json.dumps(
|
| 102 |
+
{
|
| 103 |
+
"room": self.current_room_id,
|
| 104 |
+
"inventory": sorted(self.inventory),
|
| 105 |
+
"clues": sorted(self.discovered_clues),
|
| 106 |
+
"opened": sorted(self.open_nodes),
|
| 107 |
+
"traded": sorted(self.traded_npcs),
|
| 108 |
+
"use_targets": sorted(self.completed_use_targets),
|
| 109 |
+
"recipe_outputs": sorted(self.completed_recipe_outputs),
|
| 110 |
+
},
|
| 111 |
+
sort_keys=True,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def node_id_for_command_name(self, command_name: str, node_types: set[str] | None = None) -> str | None:
|
| 115 |
+
for node in self.compiled.world.nodes:
|
| 116 |
+
safe_name = self.compiled.node_command_names.get(node.id)
|
| 117 |
+
if safe_name != command_name:
|
| 118 |
+
continue
|
| 119 |
+
if node_types is None or node.type in node_types:
|
| 120 |
+
return node.id
|
| 121 |
+
return None
|
| 122 |
+
|
| 123 |
+
def step(self, raw_command: str) -> Turn:
|
| 124 |
+
if self.done:
|
| 125 |
+
raise RuntimeError("Episode is already complete.")
|
| 126 |
+
|
| 127 |
+
lowered = self.interface_adapter.translate_command(raw_command, self).lower().strip()
|
| 128 |
+
if turn := self._handle_submit(raw_command, lowered):
|
| 129 |
+
return turn
|
| 130 |
+
if self._is_wrapper_command(lowered):
|
| 131 |
+
return self._step_wrapper(raw_command, lowered)
|
| 132 |
+
return self._step_env(raw_command, lowered)
|
| 133 |
+
|
| 134 |
+
def _handle_submit(self, raw_command: str, lowered: str) -> Turn | None:
|
| 135 |
+
if not lowered.startswith("submit "):
|
| 136 |
+
return None
|
| 137 |
+
answer = normalize_answer_text(lowered[7:])
|
| 138 |
+
if self.current_room_id != self.compiled.guardian_room_id or self.compiled.guardian_id not in self.consulted_npcs:
|
| 139 |
+
return self._wrapper_only_turn(
|
| 140 |
+
raw_command,
|
| 141 |
+
lowered,
|
| 142 |
+
"The guardian has not asked for your answer yet.",
|
| 143 |
+
{"wrapper": "submit_rejected", "reason": "guardian_not_ready"},
|
| 144 |
+
)
|
| 145 |
+
required_clues = set(self.compiled.clue_text_by_id)
|
| 146 |
+
if self.discovered_clues != required_clues:
|
| 147 |
+
return self._wrapper_only_turn(
|
| 148 |
+
raw_command,
|
| 149 |
+
lowered,
|
| 150 |
+
"The guardian waits. You have not gathered enough evidence yet.",
|
| 151 |
+
{
|
| 152 |
+
"wrapper": "submit_rejected",
|
| 153 |
+
"reason": "missing_clues",
|
| 154 |
+
"missing_clues": sorted(required_clues - self.discovered_clues),
|
| 155 |
+
},
|
| 156 |
+
)
|
| 157 |
+
if answer != self.compiled.correct_answer_normalized:
|
| 158 |
+
self.wrong_submit_count += 1
|
| 159 |
+
return self._wrapper_only_turn(
|
| 160 |
+
raw_command,
|
| 161 |
+
lowered,
|
| 162 |
+
"The guardian shakes their head. That answer is wrong.",
|
| 163 |
+
{"wrapper": "submit_rejected", "reason": "wrong_answer", "submitted": answer},
|
| 164 |
+
)
|
| 165 |
+
self.steps_taken += 1
|
| 166 |
+
self.done = True
|
| 167 |
+
self.player_won = True
|
| 168 |
+
turn = Turn(
|
| 169 |
+
step=self.steps_taken,
|
| 170 |
+
player_action=raw_command,
|
| 171 |
+
textworld_command=self.compiled.correct_submit_command,
|
| 172 |
+
observation="The guardian weighs your answer, then nods.\n\nThe dungeon yields. You solved it.",
|
| 173 |
+
game_state_delta={"wrapper": "submit_forwarded", "won": True, "location": self.current_room_id},
|
| 174 |
+
)
|
| 175 |
+
return self._record_turn(turn)
|
| 176 |
+
|
| 177 |
+
def _step_env(self, raw_command: str, lowered: str) -> Turn:
|
| 178 |
+
previous = self.state
|
| 179 |
+
admissible = set(previous.admissible_commands or [])
|
| 180 |
+
self.state, _, env_done = self.env.step(lowered)
|
| 181 |
+
self.steps_taken += 1
|
| 182 |
+
succeeded = lowered in admissible
|
| 183 |
+
if not succeeded:
|
| 184 |
+
self.invalid_command_count += 1
|
| 185 |
+
else:
|
| 186 |
+
self._apply_env_side_effects(lowered)
|
| 187 |
+
self.done = bool(env_done or self.state.won)
|
| 188 |
+
observation = self.interface_adapter.render_observation(self.state.feedback or "", self.state, self)
|
| 189 |
+
turn = Turn(
|
| 190 |
+
step=self.steps_taken,
|
| 191 |
+
player_action=raw_command,
|
| 192 |
+
textworld_command=lowered,
|
| 193 |
+
observation=observation,
|
| 194 |
+
game_state_delta=self._compute_delta(previous, self.state, succeeded, self.current_room_id),
|
| 195 |
+
)
|
| 196 |
+
return self._record_turn(turn)
|
| 197 |
+
|
| 198 |
+
def _step_wrapper(self, raw_command: str, lowered: str) -> Turn:
|
| 199 |
+
observation, delta = self._apply_wrapper_command(lowered)
|
| 200 |
+
self.steps_taken += 1
|
| 201 |
+
if delta.get("succeeded") is False:
|
| 202 |
+
self.invalid_command_count += 1
|
| 203 |
+
delta.setdefault("location", self.current_room_id)
|
| 204 |
+
rendered = self.interface_adapter.render_observation(observation, self.state, self)
|
| 205 |
+
turn = Turn(
|
| 206 |
+
step=self.steps_taken,
|
| 207 |
+
player_action=raw_command,
|
| 208 |
+
textworld_command=lowered,
|
| 209 |
+
observation=rendered,
|
| 210 |
+
game_state_delta=delta,
|
| 211 |
+
)
|
| 212 |
+
return self._record_turn(turn)
|
| 213 |
+
|
| 214 |
+
def _apply_env_side_effects(self, command: str) -> None:
|
| 215 |
+
if command.startswith("go "):
|
| 216 |
+
direction = command[3:].strip()
|
| 217 |
+
edge = self.compiled.room_edges_by_direction.get((self.current_room_id, direction))
|
| 218 |
+
if edge is not None:
|
| 219 |
+
self.current_room_id = edge.to_node_id
|
| 220 |
+
self.visited_nodes.add(edge.to_node_id)
|
| 221 |
+
return
|
| 222 |
+
if command.startswith("open "):
|
| 223 |
+
node_id = self.node_id_for_command_name(command[5:].strip(), node_types={"container", "door"})
|
| 224 |
+
if node_id:
|
| 225 |
+
self.open_nodes.add(node_id)
|
| 226 |
+
self.visited_nodes.add(node_id)
|
| 227 |
+
return
|
| 228 |
+
if command.startswith("unlock ") and " with " in command:
|
| 229 |
+
target_name, key_name = command[7:].split(" with ", 1)
|
| 230 |
+
target_id = self.node_id_for_command_name(target_name.strip(), node_types={"container", "door"})
|
| 231 |
+
if target_id:
|
| 232 |
+
self.locked_nodes.discard(target_id)
|
| 233 |
+
if self._node_by_id[target_id].type == "door":
|
| 234 |
+
self.unlocked_doors.add(target_id)
|
| 235 |
+
self.visited_nodes.add(target_id)
|
| 236 |
+
self._mark_item_by_name(key_name.strip())
|
| 237 |
+
return
|
| 238 |
+
if command.startswith("take "):
|
| 239 |
+
item_name = command[5:].split(" from ", 1)[0].strip()
|
| 240 |
+
item_id = self._item_name_to_id.get(item_name)
|
| 241 |
+
if item_id:
|
| 242 |
+
self.inventory.add(item_id)
|
| 243 |
+
self.item_locations[item_id] = INVENTORY_ID
|
| 244 |
+
self.used_items.add(item_id)
|
| 245 |
+
self.visited_nodes.add(item_id)
|
| 246 |
+
|
| 247 |
+
def _apply_wrapper_command(self, command: str) -> tuple[str, dict[str, Any]]:
|
| 248 |
+
if command.startswith("read "):
|
| 249 |
+
return self._apply_read(command)
|
| 250 |
+
if command.startswith("talk "):
|
| 251 |
+
return self._apply_talk(command)
|
| 252 |
+
if command.startswith("use ") and " on " in command:
|
| 253 |
+
return self._apply_use(command)
|
| 254 |
+
if command.startswith("combine ") and " with " in command:
|
| 255 |
+
return self._apply_combine(command)
|
| 256 |
+
if command.startswith("give ") and " to " in command:
|
| 257 |
+
return self._apply_give(command)
|
| 258 |
+
raise RuntimeError(f"Unsupported wrapper command '{command}'.")
|
| 259 |
+
|
| 260 |
+
def _apply_read(self, command: str) -> tuple[str, dict[str, Any]]:
|
| 261 |
+
readable_id = self.node_id_for_command_name(command[5:].strip(), node_types={"readable"})
|
| 262 |
+
if not readable_id or readable_id not in self.revealed_readables:
|
| 263 |
+
return self._fail("You can't read that right now.", command)
|
| 264 |
+
node = self._node_by_id[readable_id]
|
| 265 |
+
if node.parent_id != self.current_room_id:
|
| 266 |
+
return self._fail("You are too far away to read that.", command)
|
| 267 |
+
if node.requires_item_id and readable_id not in self.prepared_readables:
|
| 268 |
+
return self._fail("You still need the right tool before the text becomes legible.", command)
|
| 269 |
+
clue_id = self.compiled.readable_clue_by_id[readable_id]
|
| 270 |
+
self.discovered_clues.add(clue_id)
|
| 271 |
+
self.visited_nodes.add(readable_id)
|
| 272 |
+
return self._success(
|
| 273 |
+
textwrap.dedent(
|
| 274 |
+
f"""
|
| 275 |
+
{node.description}
|
| 276 |
+
|
| 277 |
+
"{self.compiled.clue_text_by_id[clue_id]}"
|
| 278 |
+
"""
|
| 279 |
+
).strip(),
|
| 280 |
+
command,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
def _apply_talk(self, command: str) -> tuple[str, dict[str, Any]]:
|
| 284 |
+
npc_id = self.node_id_for_command_name(command[5:].strip(), node_types={"npc"})
|
| 285 |
+
if not npc_id:
|
| 286 |
+
return self._fail("You can't talk to that right now.", command)
|
| 287 |
+
node = self._node_by_id[npc_id]
|
| 288 |
+
if node.parent_id != self.current_room_id:
|
| 289 |
+
return self._fail("You are too far away to talk to that.", command)
|
| 290 |
+
self.consulted_npcs.add(npc_id)
|
| 291 |
+
if npc_id == self.compiled.guardian_id:
|
| 292 |
+
self.consulted_guardian = True
|
| 293 |
+
self.visited_nodes.add(npc_id)
|
| 294 |
+
return self._success(node.description, command)
|
| 295 |
+
|
| 296 |
+
def _apply_use(self, command: str) -> tuple[str, dict[str, Any]]:
|
| 297 |
+
item_name, target_name = command[4:].split(" on ", 1)
|
| 298 |
+
item_id = self._item_name_to_id.get(item_name.strip())
|
| 299 |
+
target_id = self.node_id_for_command_name(target_name.strip(), node_types={"readable", "fixture"})
|
| 300 |
+
if not item_id or item_id not in self.inventory:
|
| 301 |
+
return self._fail("You don't have the item needed for that.", command)
|
| 302 |
+
if not target_id:
|
| 303 |
+
return self._fail("You can't use that here.", command)
|
| 304 |
+
target = self._node_by_id[target_id]
|
| 305 |
+
if target.parent_id != self.current_room_id:
|
| 306 |
+
return self._fail("That target is not within reach.", command)
|
| 307 |
+
effect = self.compiled.use_effects.get(target_id)
|
| 308 |
+
if effect is None or effect.required_item_id != item_id:
|
| 309 |
+
return self._fail("That item doesn't seem to work there.", command)
|
| 310 |
+
if effect.consumes_item:
|
| 311 |
+
self.inventory.discard(item_id)
|
| 312 |
+
self.item_locations[item_id] = None
|
| 313 |
+
self.used_items.add(item_id)
|
| 314 |
+
self.visited_nodes.add(target_id)
|
| 315 |
+
self.completed_use_targets.add(target_id)
|
| 316 |
+
if effect.clue_id:
|
| 317 |
+
self.prepared_readables.add(target_id)
|
| 318 |
+
self.discovered_clues.add(effect.clue_id)
|
| 319 |
+
return self._success(
|
| 320 |
+
textwrap.dedent(
|
| 321 |
+
f"""
|
| 322 |
+
{target.description}
|
| 323 |
+
|
| 324 |
+
"{self.compiled.clue_text_by_id[effect.clue_id]}"
|
| 325 |
+
"""
|
| 326 |
+
).strip(),
|
| 327 |
+
command,
|
| 328 |
+
)
|
| 329 |
+
if effect.reveals_readable_id:
|
| 330 |
+
self.revealed_readables.add(effect.reveals_readable_id)
|
| 331 |
+
return self._success(f"The {self._label_by_id[effect.reveals_readable_id]} is revealed.", command)
|
| 332 |
+
if effect.reveals_item_id:
|
| 333 |
+
self.item_locations[effect.reveals_item_id] = self.current_room_id
|
| 334 |
+
return self._success(f"The {self._label_by_id[effect.reveals_item_id]} is revealed.", command)
|
| 335 |
+
return self._fail("Nothing happens.", command)
|
| 336 |
+
|
| 337 |
+
def _apply_combine(self, command: str) -> tuple[str, dict[str, Any]]:
|
| 338 |
+
item_a_name, item_b_name = command[8:].split(" with ", 1)
|
| 339 |
+
item_a_id = self._item_name_to_id.get(item_a_name.strip())
|
| 340 |
+
item_b_id = self._item_name_to_id.get(item_b_name.strip())
|
| 341 |
+
if not item_a_id or not item_b_id or item_a_id not in self.inventory or item_b_id not in self.inventory:
|
| 342 |
+
return self._fail("You do not have both pieces required to combine those.", command)
|
| 343 |
+
output_id = self.compiled.recipe_map.get(frozenset({item_a_id, item_b_id}))
|
| 344 |
+
if not output_id:
|
| 345 |
+
return self._fail("Those items do not fit together.", command)
|
| 346 |
+
self.inventory.discard(item_a_id)
|
| 347 |
+
self.inventory.discard(item_b_id)
|
| 348 |
+
self.item_locations[item_a_id] = None
|
| 349 |
+
self.item_locations[item_b_id] = None
|
| 350 |
+
self.inventory.add(output_id)
|
| 351 |
+
self.item_locations[output_id] = INVENTORY_ID
|
| 352 |
+
self.used_items.update({item_a_id, item_b_id, output_id})
|
| 353 |
+
self.completed_recipe_outputs.add(output_id)
|
| 354 |
+
self.visited_nodes.add(output_id)
|
| 355 |
+
return self._success(f"You assemble the {self._label_by_id[output_id]}.", command)
|
| 356 |
+
|
| 357 |
+
def _apply_give(self, command: str) -> tuple[str, dict[str, Any]]:
|
| 358 |
+
item_name, npc_name = command[5:].split(" to ", 1)
|
| 359 |
+
item_id = self._item_name_to_id.get(item_name.strip())
|
| 360 |
+
npc_id = self.node_id_for_command_name(npc_name.strip(), node_types={"npc"})
|
| 361 |
+
if not item_id or item_id not in self.inventory:
|
| 362 |
+
return self._fail("You do not have that item to give.", command)
|
| 363 |
+
if not npc_id:
|
| 364 |
+
return self._fail("There is no one here by that name.", command)
|
| 365 |
+
npc = self._node_by_id[npc_id]
|
| 366 |
+
if npc.parent_id != self.current_room_id:
|
| 367 |
+
return self._fail("That person is not here.", command)
|
| 368 |
+
trade = self.compiled.npc_trade_map.get(npc_id)
|
| 369 |
+
if trade is None or trade.required_item_id != item_id:
|
| 370 |
+
return self._fail("They are not interested in that item.", command)
|
| 371 |
+
if npc_id in self.traded_npcs:
|
| 372 |
+
return self._fail("That trade has already been completed.", command)
|
| 373 |
+
self.inventory.discard(item_id)
|
| 374 |
+
self.item_locations[item_id] = None
|
| 375 |
+
self.used_items.add(item_id)
|
| 376 |
+
self.traded_npcs.add(npc_id)
|
| 377 |
+
if trade.gives_item_id:
|
| 378 |
+
self.inventory.add(trade.gives_item_id)
|
| 379 |
+
self.item_locations[trade.gives_item_id] = INVENTORY_ID
|
| 380 |
+
self.used_items.add(trade.gives_item_id)
|
| 381 |
+
return self._success(f"You receive the {self._label_by_id[trade.gives_item_id]}.", command)
|
| 382 |
+
if trade.gives_clue_id:
|
| 383 |
+
self.discovered_clues.add(trade.gives_clue_id)
|
| 384 |
+
return self._success(f'"{self.compiled.clue_text_by_id[trade.gives_clue_id]}"', command)
|
| 385 |
+
return self._fail("Nothing comes of the trade.", command)
|
| 386 |
+
|
| 387 |
+
def _custom_commands(self) -> set[str]:
|
| 388 |
+
commands: set[str] = set()
|
| 389 |
+
for node in self.compiled.world.nodes:
|
| 390 |
+
if node.type == "npc" and node.parent_id == self.current_room_id:
|
| 391 |
+
commands.add(f"talk {self.compiled.node_command_names[node.id]}")
|
| 392 |
+
trade = self.compiled.npc_trade_map.get(node.id)
|
| 393 |
+
if trade and node.id not in self.traded_npcs and trade.required_item_id in self.inventory:
|
| 394 |
+
commands.add(
|
| 395 |
+
f"give {self.compiled.item_command_names[trade.required_item_id]} to {self.compiled.node_command_names[node.id]}"
|
| 396 |
+
)
|
| 397 |
+
elif node.type == "readable" and node.parent_id == self.current_room_id and node.id in self.revealed_readables:
|
| 398 |
+
if not node.requires_item_id or node.id in self.prepared_readables:
|
| 399 |
+
commands.add(f"read {self.compiled.node_command_names[node.id]}")
|
| 400 |
+
elif node.type == "fixture" and node.parent_id == self.current_room_id:
|
| 401 |
+
effect = self.compiled.use_effects.get(node.id)
|
| 402 |
+
if effect and effect.required_item_id in self.inventory:
|
| 403 |
+
commands.add(
|
| 404 |
+
f"use {self.compiled.item_command_names[effect.required_item_id]} on {self.compiled.node_command_names[node.id]}"
|
| 405 |
+
)
|
| 406 |
+
for readable_id, effect in self.compiled.use_effects.items():
|
| 407 |
+
node = self._node_by_id.get(readable_id)
|
| 408 |
+
if node and node.type == "readable" and node.parent_id == self.current_room_id and effect.required_item_id in self.inventory:
|
| 409 |
+
commands.add(
|
| 410 |
+
f"use {self.compiled.item_command_names[effect.required_item_id]} on {self.compiled.node_command_names[readable_id]}"
|
| 411 |
+
)
|
| 412 |
+
for recipe_inputs, output_id in self.compiled.recipe_map.items():
|
| 413 |
+
del output_id
|
| 414 |
+
item_ids = sorted(recipe_inputs)
|
| 415 |
+
if all(item_id in self.inventory for item_id in item_ids):
|
| 416 |
+
commands.add(
|
| 417 |
+
f"combine {self.compiled.item_command_names[item_ids[0]]} with {self.compiled.item_command_names[item_ids[1]]}"
|
| 418 |
+
)
|
| 419 |
+
commands.add(
|
| 420 |
+
f"combine {self.compiled.item_command_names[item_ids[1]]} with {self.compiled.item_command_names[item_ids[0]]}"
|
| 421 |
+
)
|
| 422 |
+
return commands
|
| 423 |
+
|
| 424 |
+
def _is_wrapper_command(self, command: str) -> bool:
|
| 425 |
+
return any(
|
| 426 |
+
command.startswith(prefix)
|
| 427 |
+
for prefix in ("read ", "talk ", "use ", "combine ", "give ")
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
def _mark_item_by_name(self, name: str) -> None:
|
| 431 |
+
item_id = self._item_name_to_id.get(name)
|
| 432 |
+
if item_id:
|
| 433 |
+
self.used_items.add(item_id)
|
| 434 |
+
|
| 435 |
+
def _success(self, observation: str, command: str) -> tuple[str, dict[str, Any]]:
|
| 436 |
+
return observation, {"wrapper": "custom", "command": command, "succeeded": True, "location": self.current_room_id}
|
| 437 |
+
|
| 438 |
+
def _fail(self, observation: str, command: str) -> tuple[str, dict[str, Any]]:
|
| 439 |
+
return observation, {"wrapper": "custom", "command": command, "succeeded": False, "location": self.current_room_id}
|
| 440 |
+
|
| 441 |
+
@staticmethod
|
| 442 |
+
def _compute_delta(previous: GameState, current: GameState, succeeded: bool, fallback_location: str | None) -> dict[str, Any]:
|
| 443 |
+
return {
|
| 444 |
+
"added_facts": [],
|
| 445 |
+
"removed_facts": [],
|
| 446 |
+
"location": current.location or fallback_location,
|
| 447 |
+
"score": current.score,
|
| 448 |
+
"won": current.won,
|
| 449 |
+
"lost": current.lost,
|
| 450 |
+
"succeeded": succeeded,
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
def _wrapper_only_turn(
|
| 454 |
+
self,
|
| 455 |
+
raw_command: str,
|
| 456 |
+
translated: str,
|
| 457 |
+
observation: str,
|
| 458 |
+
delta: dict[str, Any],
|
| 459 |
+
) -> Turn:
|
| 460 |
+
self.steps_taken += 1
|
| 461 |
+
delta.setdefault("location", self.current_room_id)
|
| 462 |
+
turn = Turn(
|
| 463 |
+
step=self.steps_taken,
|
| 464 |
+
player_action=raw_command,
|
| 465 |
+
textworld_command=translated,
|
| 466 |
+
observation=observation,
|
| 467 |
+
game_state_delta=delta,
|
| 468 |
+
)
|
| 469 |
+
return self._record_turn(turn)
|
| 470 |
+
|
| 471 |
+
def _record_turn(self, turn: Turn) -> Turn:
|
| 472 |
+
self.transcript.append(turn)
|
| 473 |
+
self.last_state_fingerprint = self.state_fingerprint()
|
| 474 |
+
if self.turn_listener is not None:
|
| 475 |
+
self.turn_listener(self, turn)
|
| 476 |
+
return turn
|
| 477 |
+
|
| 478 |
+
def close(self) -> None:
|
| 479 |
+
if self._closed:
|
| 480 |
+
return
|
| 481 |
+
close = getattr(self.env, "close", None)
|
| 482 |
+
if callable(close):
|
| 483 |
+
close()
|
| 484 |
+
self._closed = True
|
agents/master/snapshots.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from datetime import datetime, timezone
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Protocol
|
| 7 |
+
|
| 8 |
+
from pydantic import Field
|
| 9 |
+
|
| 10 |
+
from .base import ARTIFACTS_ROOT
|
| 11 |
+
from .schema import CompiledWorld, DMFeedback, DMObservation, StrictModel, Turn, WorldDefinition
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from .session import EpisodeSession
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
STATE_FILENAME = "state.json"
|
| 18 |
+
WORLD_FILENAME = "world.json"
|
| 19 |
+
LIVE_SCHEMA_VERSION = 1
|
| 20 |
+
DEFAULT_LIVE_DIR = ARTIFACTS_ROOT / "live"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LiveMetrics(StrictModel):
|
| 24 |
+
steps_taken: int = 0
|
| 25 |
+
min_steps: int | None = None
|
| 26 |
+
ratio: float | None = None
|
| 27 |
+
reward: float | None = None
|
| 28 |
+
player_won: bool | None = None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class LiveRuntime(StrictModel):
|
| 32 |
+
current_room_id: str | None = None
|
| 33 |
+
inventory_item_ids: list[str] = Field(default_factory=list)
|
| 34 |
+
discovered_clue_ids: list[str] = Field(default_factory=list)
|
| 35 |
+
traded_npc_ids: list[str] = Field(default_factory=list)
|
| 36 |
+
visited_room_ids: list[str] = Field(default_factory=list)
|
| 37 |
+
available_commands: list[str] = Field(default_factory=list)
|
| 38 |
+
invalid_command_count: int = 0
|
| 39 |
+
wrong_submit_count: int = 0
|
| 40 |
+
open_node_ids: list[str] = Field(default_factory=list)
|
| 41 |
+
locked_node_ids: list[str] = Field(default_factory=list)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class LiveCurrentRoom(StrictModel):
|
| 45 |
+
id: str | None = None
|
| 46 |
+
label: str | None = None
|
| 47 |
+
description: str | None = None
|
| 48 |
+
visible_node_ids: list[str] = Field(default_factory=list)
|
| 49 |
+
visible_item_ids: list[str] = Field(default_factory=list)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class LiveStateSnapshot(StrictModel):
|
| 53 |
+
schema_version: int = LIVE_SCHEMA_VERSION
|
| 54 |
+
episode_id: str
|
| 55 |
+
status: str
|
| 56 |
+
updated_at: str
|
| 57 |
+
title: str | None = None
|
| 58 |
+
runner: str | None = None
|
| 59 |
+
error: str | None = None
|
| 60 |
+
transcript: list[Turn] = Field(default_factory=list)
|
| 61 |
+
metrics: LiveMetrics = Field(default_factory=LiveMetrics)
|
| 62 |
+
feedback: DMFeedback | None = None
|
| 63 |
+
runtime: LiveRuntime = Field(default_factory=LiveRuntime)
|
| 64 |
+
current_room: LiveCurrentRoom | None = None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class LiveObserver(Protocol):
|
| 68 |
+
def on_run_start(self, episode_id: str, world_input: WorldDefinition | dict[str, Any]) -> None:
|
| 69 |
+
...
|
| 70 |
+
|
| 71 |
+
def on_compile_success(self, compiled: CompiledWorld, session: EpisodeSession) -> None:
|
| 72 |
+
...
|
| 73 |
+
|
| 74 |
+
def on_turn(self, session: EpisodeSession, turn: Turn) -> None:
|
| 75 |
+
...
|
| 76 |
+
|
| 77 |
+
def on_complete(self, compiled: CompiledWorld, session: EpisodeSession, observation: DMObservation) -> None:
|
| 78 |
+
...
|
| 79 |
+
|
| 80 |
+
def on_error(
|
| 81 |
+
self,
|
| 82 |
+
*,
|
| 83 |
+
episode_id: str,
|
| 84 |
+
error: str,
|
| 85 |
+
world_input: WorldDefinition | dict[str, Any],
|
| 86 |
+
compiled: CompiledWorld | None = None,
|
| 87 |
+
session: EpisodeSession | None = None,
|
| 88 |
+
) -> None:
|
| 89 |
+
...
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class LiveSnapshotWriter:
|
| 93 |
+
def __init__(self, live_dir: Path | None = None, runner_name: str | None = None) -> None:
|
| 94 |
+
self.live_dir = live_dir or DEFAULT_LIVE_DIR
|
| 95 |
+
self.runner_name = runner_name
|
| 96 |
+
self.live_dir.mkdir(parents=True, exist_ok=True)
|
| 97 |
+
|
| 98 |
+
def on_run_start(self, episode_id: str, world_input: WorldDefinition | dict[str, Any]) -> None:
|
| 99 |
+
self._remove_world()
|
| 100 |
+
snapshot = LiveStateSnapshot(
|
| 101 |
+
episode_id=episode_id,
|
| 102 |
+
status="compiling",
|
| 103 |
+
updated_at=self._timestamp(),
|
| 104 |
+
title=self._extract_title(world_input),
|
| 105 |
+
runner=self.runner_name,
|
| 106 |
+
)
|
| 107 |
+
self._write_state_snapshot(snapshot)
|
| 108 |
+
|
| 109 |
+
def on_compile_success(self, compiled: CompiledWorld, session: EpisodeSession) -> None:
|
| 110 |
+
self._write_world(compiled.world)
|
| 111 |
+
snapshot = LiveStateSnapshot(
|
| 112 |
+
episode_id=compiled.episode_id,
|
| 113 |
+
status="running",
|
| 114 |
+
updated_at=self._timestamp(),
|
| 115 |
+
title=compiled.world.meta.title,
|
| 116 |
+
runner=self.runner_name,
|
| 117 |
+
metrics=self._metrics(min_steps=len(compiled.solver_policy), steps_taken=session.steps_taken),
|
| 118 |
+
runtime=self._runtime(session),
|
| 119 |
+
current_room=self._current_room(session),
|
| 120 |
+
)
|
| 121 |
+
self._write_state_snapshot(snapshot)
|
| 122 |
+
|
| 123 |
+
def on_turn(self, session: EpisodeSession, turn: Turn) -> None:
|
| 124 |
+
del turn
|
| 125 |
+
snapshot = LiveStateSnapshot(
|
| 126 |
+
episode_id=session.compiled.episode_id,
|
| 127 |
+
status="running",
|
| 128 |
+
updated_at=self._timestamp(),
|
| 129 |
+
title=session.compiled.world.meta.title,
|
| 130 |
+
runner=self.runner_name,
|
| 131 |
+
transcript=list(session.transcript),
|
| 132 |
+
metrics=self._metrics(
|
| 133 |
+
min_steps=len(session.compiled.solver_policy),
|
| 134 |
+
steps_taken=session.steps_taken,
|
| 135 |
+
),
|
| 136 |
+
runtime=self._runtime(session),
|
| 137 |
+
current_room=self._current_room(session),
|
| 138 |
+
)
|
| 139 |
+
self._write_state_snapshot(snapshot)
|
| 140 |
+
|
| 141 |
+
def on_complete(self, compiled: CompiledWorld, session: EpisodeSession, observation: DMObservation) -> None:
|
| 142 |
+
status = "complete" if observation.player_won else "failed"
|
| 143 |
+
snapshot = LiveStateSnapshot(
|
| 144 |
+
episode_id=compiled.episode_id,
|
| 145 |
+
status=status,
|
| 146 |
+
updated_at=self._timestamp(),
|
| 147 |
+
title=compiled.world.meta.title,
|
| 148 |
+
runner=self.runner_name,
|
| 149 |
+
transcript=list(session.transcript),
|
| 150 |
+
metrics=self._metrics(
|
| 151 |
+
min_steps=observation.min_steps,
|
| 152 |
+
steps_taken=observation.steps_taken or session.steps_taken,
|
| 153 |
+
ratio=observation.ratio,
|
| 154 |
+
reward=observation.reward,
|
| 155 |
+
player_won=observation.player_won,
|
| 156 |
+
),
|
| 157 |
+
feedback=observation.feedback,
|
| 158 |
+
runtime=self._runtime(session),
|
| 159 |
+
current_room=self._current_room(session),
|
| 160 |
+
)
|
| 161 |
+
self._write_state_snapshot(snapshot)
|
| 162 |
+
|
| 163 |
+
def on_error(
|
| 164 |
+
self,
|
| 165 |
+
*,
|
| 166 |
+
episode_id: str,
|
| 167 |
+
error: str,
|
| 168 |
+
world_input: WorldDefinition | dict[str, Any],
|
| 169 |
+
compiled: CompiledWorld | None = None,
|
| 170 |
+
session: EpisodeSession | None = None,
|
| 171 |
+
) -> None:
|
| 172 |
+
title = compiled.world.meta.title if compiled is not None else self._extract_title(world_input)
|
| 173 |
+
snapshot = LiveStateSnapshot(
|
| 174 |
+
episode_id=episode_id,
|
| 175 |
+
status="compile_error",
|
| 176 |
+
updated_at=self._timestamp(),
|
| 177 |
+
title=title,
|
| 178 |
+
runner=self.runner_name,
|
| 179 |
+
error=error,
|
| 180 |
+
transcript=list(session.transcript) if session is not None else [],
|
| 181 |
+
metrics=self._metrics(
|
| 182 |
+
min_steps=len(compiled.solver_policy) if compiled is not None else None,
|
| 183 |
+
steps_taken=session.steps_taken if session is not None else 0,
|
| 184 |
+
),
|
| 185 |
+
runtime=self._runtime(session),
|
| 186 |
+
current_room=self._current_room(session),
|
| 187 |
+
)
|
| 188 |
+
self._write_state_snapshot(snapshot)
|
| 189 |
+
|
| 190 |
+
def _write_world(self, world: WorldDefinition) -> None:
|
| 191 |
+
self._write_json(self.live_dir / WORLD_FILENAME, world.model_dump_json(indent=2))
|
| 192 |
+
|
| 193 |
+
def _write_state_snapshot(self, snapshot: LiveStateSnapshot) -> None:
|
| 194 |
+
self._write_json(self.live_dir / STATE_FILENAME, snapshot.model_dump_json(indent=2))
|
| 195 |
+
|
| 196 |
+
def _write_json(self, path: Path, payload: str) -> None:
|
| 197 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 198 |
+
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
| 199 |
+
tmp_path.write_text(payload + "\n", encoding="utf-8")
|
| 200 |
+
tmp_path.replace(path)
|
| 201 |
+
|
| 202 |
+
def _remove_world(self) -> None:
|
| 203 |
+
world_path = self.live_dir / WORLD_FILENAME
|
| 204 |
+
if world_path.exists():
|
| 205 |
+
world_path.unlink()
|
| 206 |
+
|
| 207 |
+
@staticmethod
|
| 208 |
+
def _timestamp() -> str:
|
| 209 |
+
return datetime.now(timezone.utc).isoformat()
|
| 210 |
+
|
| 211 |
+
@staticmethod
|
| 212 |
+
def _extract_title(world_input: WorldDefinition | dict[str, Any]) -> str | None:
|
| 213 |
+
if isinstance(world_input, WorldDefinition):
|
| 214 |
+
return world_input.meta.title
|
| 215 |
+
meta = world_input.get("meta") if isinstance(world_input, dict) else None
|
| 216 |
+
title = meta.get("title") if isinstance(meta, dict) else None
|
| 217 |
+
return title if isinstance(title, str) else None
|
| 218 |
+
|
| 219 |
+
@staticmethod
|
| 220 |
+
def _metrics(
|
| 221 |
+
*,
|
| 222 |
+
min_steps: int | None,
|
| 223 |
+
steps_taken: int,
|
| 224 |
+
ratio: float | None = None,
|
| 225 |
+
reward: float | None = None,
|
| 226 |
+
player_won: bool | None = None,
|
| 227 |
+
) -> LiveMetrics:
|
| 228 |
+
computed_ratio = ratio
|
| 229 |
+
if computed_ratio is None and min_steps:
|
| 230 |
+
computed_ratio = steps_taken / min_steps
|
| 231 |
+
return LiveMetrics(
|
| 232 |
+
steps_taken=steps_taken,
|
| 233 |
+
min_steps=min_steps,
|
| 234 |
+
ratio=computed_ratio,
|
| 235 |
+
reward=reward,
|
| 236 |
+
player_won=player_won,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
@staticmethod
|
| 240 |
+
def _runtime(session: EpisodeSession | None) -> LiveRuntime:
|
| 241 |
+
if session is None:
|
| 242 |
+
return LiveRuntime()
|
| 243 |
+
room_ids = {
|
| 244 |
+
node.id
|
| 245 |
+
for node in session.compiled.world.nodes
|
| 246 |
+
if node.type in {"location", "junction"}
|
| 247 |
+
}
|
| 248 |
+
commands = [] if session.done else session.available_commands()
|
| 249 |
+
return LiveRuntime(
|
| 250 |
+
current_room_id=session.current_room_id,
|
| 251 |
+
inventory_item_ids=sorted(session.inventory),
|
| 252 |
+
discovered_clue_ids=sorted(session.discovered_clues),
|
| 253 |
+
traded_npc_ids=sorted(session.traded_npcs),
|
| 254 |
+
visited_room_ids=sorted(room_ids & session.visited_nodes),
|
| 255 |
+
available_commands=commands,
|
| 256 |
+
invalid_command_count=session.invalid_command_count,
|
| 257 |
+
wrong_submit_count=session.wrong_submit_count,
|
| 258 |
+
open_node_ids=sorted(session.open_nodes),
|
| 259 |
+
locked_node_ids=sorted(session.locked_nodes),
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
@staticmethod
|
| 263 |
+
def _current_room(session: EpisodeSession | None) -> LiveCurrentRoom | None:
|
| 264 |
+
if session is None:
|
| 265 |
+
return None
|
| 266 |
+
node_by_id = {node.id: node for node in session.compiled.world.nodes}
|
| 267 |
+
room = node_by_id.get(session.current_room_id)
|
| 268 |
+
if room is None:
|
| 269 |
+
return None
|
| 270 |
+
visible_nodes = [
|
| 271 |
+
node.id
|
| 272 |
+
for node in session.compiled.world.nodes
|
| 273 |
+
if getattr(node, "parent_id", None) == session.current_room_id
|
| 274 |
+
and (node.type != "readable" or node.id in session.revealed_readables)
|
| 275 |
+
]
|
| 276 |
+
visible_nodes.extend(
|
| 277 |
+
sorted(
|
| 278 |
+
door_id
|
| 279 |
+
for door_id, rooms in session.compiled.door_rooms.items()
|
| 280 |
+
if session.current_room_id in rooms
|
| 281 |
+
)
|
| 282 |
+
)
|
| 283 |
+
visible_items = sorted(
|
| 284 |
+
item_id
|
| 285 |
+
for item_id, location in session.item_locations.items()
|
| 286 |
+
if location == session.current_room_id
|
| 287 |
+
)
|
| 288 |
+
return LiveCurrentRoom(
|
| 289 |
+
id=room.id,
|
| 290 |
+
label=room.label,
|
| 291 |
+
description=room.description,
|
| 292 |
+
visible_node_ids=sorted(set(visible_nodes)),
|
| 293 |
+
visible_item_ids=visible_items,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def load_live_payload(live_dir: Path, filename: str) -> bytes | None:
|
| 298 |
+
path = live_dir / filename
|
| 299 |
+
if not path.exists():
|
| 300 |
+
return None
|
| 301 |
+
return path.read_bytes()
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def load_live_state(live_dir: Path) -> dict[str, Any] | None:
|
| 305 |
+
payload = load_live_payload(live_dir, STATE_FILENAME)
|
| 306 |
+
if payload is None:
|
| 307 |
+
return None
|
| 308 |
+
return json.loads(payload)
|
agents/master/templates.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
DIST_INDEX = Path(__file__).resolve().parents[2] / "www" / "dist" / "index.html"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def render_index() -> str:
|
| 10 |
+
if DIST_INDEX.is_file():
|
| 11 |
+
return DIST_INDEX.read_text(encoding="utf-8")
|
| 12 |
+
return """<!doctype html>
|
| 13 |
+
<html lang="en">
|
| 14 |
+
<head>
|
| 15 |
+
<meta charset="utf-8">
|
| 16 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
| 17 |
+
<title>Viewer Not Built</title>
|
| 18 |
+
<style>
|
| 19 |
+
body {
|
| 20 |
+
margin: 0;
|
| 21 |
+
min-height: 100vh;
|
| 22 |
+
display: grid;
|
| 23 |
+
place-items: center;
|
| 24 |
+
font-family: ui-monospace, Menlo, Monaco, monospace;
|
| 25 |
+
background: #0a0d14;
|
| 26 |
+
color: #e7ebf2;
|
| 27 |
+
}
|
| 28 |
+
main {
|
| 29 |
+
max-width: 48rem;
|
| 30 |
+
padding: 2rem;
|
| 31 |
+
}
|
| 32 |
+
code {
|
| 33 |
+
color: #ffd86b;
|
| 34 |
+
}
|
| 35 |
+
</style>
|
| 36 |
+
</head>
|
| 37 |
+
<body>
|
| 38 |
+
<main>
|
| 39 |
+
<h1>Frontend build not found.</h1>
|
| 40 |
+
<p>Run <code>npm run dev</code> for the Vite app or <code>npm run build</code> to let the Python server serve the built site from <code>www/dist</code>.</p>
|
| 41 |
+
</main>
|
| 42 |
+
</body>
|
| 43 |
+
</html>
|
| 44 |
+
"""
|
agents/openenv_server/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenEnv HTTP server entrypoints for dungeon environments."""
|
| 2 |
+
|
agents/openenv_server/__main__.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import uvicorn
|
| 7 |
+
from openenv.core.env_server import create_fastapi_app
|
| 8 |
+
|
| 9 |
+
from agents.hero.env import HeroEnvironment
|
| 10 |
+
from agents.hero.schema import HeroObservation, HeroServerAction
|
| 11 |
+
from agents.master.env import DMEnvironment
|
| 12 |
+
from agents.master.sample import load_world
|
| 13 |
+
from agents.master.schema import DMAction, DMObservation
|
| 14 |
+
from agents.shared.runtime import build_interface_adapter, resolve_interface_config
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main(argv: list[str] | None = None) -> int:
|
| 18 |
+
parser = argparse.ArgumentParser(description="Serve dungeon environments over OpenEnv HTTP/WebSocket APIs.")
|
| 19 |
+
parser.add_argument("role", choices=["dm", "hero"])
|
| 20 |
+
parser.add_argument("--host", default="127.0.0.1")
|
| 21 |
+
parser.add_argument("--port", type=int)
|
| 22 |
+
parser.add_argument("--world", type=Path, help="Optional world definition JSON for hero serving.")
|
| 23 |
+
parser.add_argument("--artifacts-root", type=Path)
|
| 24 |
+
parser.add_argument("--max-concurrent-envs", type=int, default=1)
|
| 25 |
+
parser.add_argument("--interface-provider", choices=["strict", "simple", "gemini"])
|
| 26 |
+
parser.add_argument("--interface-model")
|
| 27 |
+
parser.add_argument("--interface-narrate", action="store_true")
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--translate-corporate-env",
|
| 30 |
+
action="store_true",
|
| 31 |
+
help="Rewrite hero-facing observations into a corporate app metaphor and map translated commands back through Gemini.",
|
| 32 |
+
)
|
| 33 |
+
args = parser.parse_args(argv)
|
| 34 |
+
|
| 35 |
+
interface_config = resolve_interface_config(
|
| 36 |
+
provider=args.interface_provider,
|
| 37 |
+
model_name=args.interface_model,
|
| 38 |
+
narrate_observations=args.interface_narrate,
|
| 39 |
+
translation_mode="corporate_app" if args.translate_corporate_env else None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
if args.role == "dm":
|
| 43 |
+
env_factory = lambda: DMEnvironment(
|
| 44 |
+
artifacts_root=args.artifacts_root,
|
| 45 |
+
interface_adapter=build_interface_adapter(interface_config),
|
| 46 |
+
)
|
| 47 |
+
action_cls = DMAction
|
| 48 |
+
observation_cls = DMObservation
|
| 49 |
+
default_port = 8001
|
| 50 |
+
else:
|
| 51 |
+
world_input = load_world(str(args.world)) if args.world is not None else None
|
| 52 |
+
env_factory = lambda: HeroEnvironment(
|
| 53 |
+
artifacts_root=args.artifacts_root,
|
| 54 |
+
world_input=world_input,
|
| 55 |
+
interface_adapter=build_interface_adapter(interface_config),
|
| 56 |
+
)
|
| 57 |
+
action_cls = HeroServerAction
|
| 58 |
+
observation_cls = HeroObservation
|
| 59 |
+
default_port = 8002
|
| 60 |
+
|
| 61 |
+
app = create_fastapi_app(
|
| 62 |
+
env=env_factory,
|
| 63 |
+
action_cls=action_cls,
|
| 64 |
+
observation_cls=observation_cls,
|
| 65 |
+
max_concurrent_envs=args.max_concurrent_envs,
|
| 66 |
+
)
|
| 67 |
+
uvicorn.run(app, host=args.host, port=args.port or default_port)
|
| 68 |
+
return 0
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == "__main__":
|
| 72 |
+
raise SystemExit(main())
|
agents/shared/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared helpers for agent environments and model adapters."""
|
| 2 |
+
|
| 3 |
+
from .llm_client import (
|
| 4 |
+
DEFAULT_HF_DM_MODEL,
|
| 5 |
+
DEFAULT_HF_HERO_MODEL,
|
| 6 |
+
GeminiStructuredClient,
|
| 7 |
+
HuggingFaceStructuredClient,
|
| 8 |
+
StructuredModelClient,
|
| 9 |
+
)
|
| 10 |
+
from .model_schema import ModelMessage
|
| 11 |
+
from .openenv_compat import OPENENV_AVAILABLE
|
| 12 |
+
from .runtime import (
|
| 13 |
+
DEFAULT_INTERFACE_MODEL,
|
| 14 |
+
DEFAULT_INTERFACE_PROVIDER,
|
| 15 |
+
DEFAULT_INTERFACE_TRANSLATION_MODE,
|
| 16 |
+
InterfaceConfig,
|
| 17 |
+
InterfaceTranslationMode,
|
| 18 |
+
StructuredClientConfig,
|
| 19 |
+
build_interface_adapter,
|
| 20 |
+
create_structured_client,
|
| 21 |
+
resolve_interface_config,
|
| 22 |
+
resolve_structured_client_config,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
"build_interface_adapter",
|
| 27 |
+
"create_structured_client",
|
| 28 |
+
"DEFAULT_HF_DM_MODEL",
|
| 29 |
+
"DEFAULT_HF_HERO_MODEL",
|
| 30 |
+
"DEFAULT_INTERFACE_MODEL",
|
| 31 |
+
"DEFAULT_INTERFACE_PROVIDER",
|
| 32 |
+
"DEFAULT_INTERFACE_TRANSLATION_MODE",
|
| 33 |
+
"GeminiStructuredClient",
|
| 34 |
+
"HuggingFaceStructuredClient",
|
| 35 |
+
"InterfaceConfig",
|
| 36 |
+
"InterfaceTranslationMode",
|
| 37 |
+
"ModelMessage",
|
| 38 |
+
"OPENENV_AVAILABLE",
|
| 39 |
+
"resolve_interface_config",
|
| 40 |
+
"resolve_structured_client_config",
|
| 41 |
+
"StructuredModelClient",
|
| 42 |
+
"StructuredClientConfig",
|
| 43 |
+
]
|
agents/shared/llm_client.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Protocol, TypeVar
|
| 7 |
+
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
from google import genai
|
| 10 |
+
from google.genai import types
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
|
| 13 |
+
from .model_schema import ModelMessage
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from trl.chat_template_utils import qwen3_chat_template
|
| 17 |
+
except Exception: # pragma: no cover - optional runtime dependency
|
| 18 |
+
qwen3_chat_template = None # type: ignore[assignment]
|
| 19 |
+
|
| 20 |
+
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel)
|
| 21 |
+
|
| 22 |
+
DEFAULT_GEMINI_DM_MODEL = "gemini-2.5-flash"
|
| 23 |
+
DEFAULT_GEMINI_HERO_MODEL = "gemini-2.5-flash"
|
| 24 |
+
DEFAULT_HF_DM_MODEL = "Qwen/Qwen3-32B"
|
| 25 |
+
DEFAULT_HF_HERO_MODEL = "Qwen/Qwen3-32B"
|
| 26 |
+
PROVIDER_GEMINI = "gemini"
|
| 27 |
+
PROVIDER_HF_LOCAL = "hf_local"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class StructuredModelClient(Protocol):
|
| 31 |
+
def generate_structured(
|
| 32 |
+
self,
|
| 33 |
+
messages: list[ModelMessage],
|
| 34 |
+
response_model: type[ResponseModelT],
|
| 35 |
+
*,
|
| 36 |
+
model_name: str,
|
| 37 |
+
temperature: float,
|
| 38 |
+
max_output_tokens: int,
|
| 39 |
+
) -> ResponseModelT:
|
| 40 |
+
...
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class GeminiStructuredClient:
|
| 44 |
+
def __init__(self, api_key: str | None = None) -> None:
|
| 45 |
+
self._client = self._create_client(api_key)
|
| 46 |
+
|
| 47 |
+
def generate_structured(
|
| 48 |
+
self,
|
| 49 |
+
messages: list[ModelMessage],
|
| 50 |
+
response_model: type[ResponseModelT],
|
| 51 |
+
*,
|
| 52 |
+
model_name: str,
|
| 53 |
+
temperature: float,
|
| 54 |
+
max_output_tokens: int,
|
| 55 |
+
) -> ResponseModelT:
|
| 56 |
+
failures: list[str] = []
|
| 57 |
+
strategies = (
|
| 58 |
+
self._generate_with_response_schema,
|
| 59 |
+
self._generate_with_json_mode,
|
| 60 |
+
self._generate_with_prompt_only,
|
| 61 |
+
)
|
| 62 |
+
for strategy in strategies:
|
| 63 |
+
try:
|
| 64 |
+
return strategy(
|
| 65 |
+
messages,
|
| 66 |
+
response_model,
|
| 67 |
+
model_name=model_name,
|
| 68 |
+
temperature=temperature,
|
| 69 |
+
max_output_tokens=max_output_tokens,
|
| 70 |
+
)
|
| 71 |
+
except Exception as exc:
|
| 72 |
+
failures.append(f"{strategy.__name__}: {self._normalize_error(exc)}")
|
| 73 |
+
raise RuntimeError("Gemini structured generation failed. " + " | ".join(failures))
|
| 74 |
+
|
| 75 |
+
def _generate_with_response_schema(
|
| 76 |
+
self,
|
| 77 |
+
messages: list[ModelMessage],
|
| 78 |
+
response_model: type[ResponseModelT],
|
| 79 |
+
*,
|
| 80 |
+
model_name: str,
|
| 81 |
+
temperature: float,
|
| 82 |
+
max_output_tokens: int,
|
| 83 |
+
) -> ResponseModelT:
|
| 84 |
+
system_instruction, contents = self._split_messages(messages)
|
| 85 |
+
response = self._client.models.generate_content(
|
| 86 |
+
model=model_name,
|
| 87 |
+
contents=contents,
|
| 88 |
+
config=types.GenerateContentConfig(
|
| 89 |
+
system_instruction=system_instruction,
|
| 90 |
+
temperature=temperature,
|
| 91 |
+
max_output_tokens=max_output_tokens,
|
| 92 |
+
response_mime_type="application/json",
|
| 93 |
+
response_schema=response_model,
|
| 94 |
+
candidate_count=1,
|
| 95 |
+
),
|
| 96 |
+
)
|
| 97 |
+
parsed = getattr(response, "parsed", None)
|
| 98 |
+
if parsed is not None:
|
| 99 |
+
return response_model.model_validate(parsed)
|
| 100 |
+
text = getattr(response, "text", None)
|
| 101 |
+
if isinstance(text, str) and text.strip():
|
| 102 |
+
return response_model.model_validate_json(text)
|
| 103 |
+
raise RuntimeError("Gemini returned an empty structured response.")
|
| 104 |
+
|
| 105 |
+
def _generate_with_json_mode(
|
| 106 |
+
self,
|
| 107 |
+
messages: list[ModelMessage],
|
| 108 |
+
response_model: type[ResponseModelT],
|
| 109 |
+
*,
|
| 110 |
+
model_name: str,
|
| 111 |
+
temperature: float,
|
| 112 |
+
max_output_tokens: int,
|
| 113 |
+
) -> ResponseModelT:
|
| 114 |
+
prompt = self._json_prompt(messages, response_model)
|
| 115 |
+
response = self._client.models.generate_content(
|
| 116 |
+
model=model_name,
|
| 117 |
+
contents=prompt,
|
| 118 |
+
config=types.GenerateContentConfig(
|
| 119 |
+
temperature=temperature,
|
| 120 |
+
max_output_tokens=max_output_tokens,
|
| 121 |
+
response_mime_type="application/json",
|
| 122 |
+
candidate_count=1,
|
| 123 |
+
),
|
| 124 |
+
)
|
| 125 |
+
text = getattr(response, "text", None)
|
| 126 |
+
if not isinstance(text, str) or not text.strip():
|
| 127 |
+
raise RuntimeError("Gemini returned an empty JSON-mode response.")
|
| 128 |
+
return response_model.model_validate_json(text)
|
| 129 |
+
|
| 130 |
+
def _generate_with_prompt_only(
|
| 131 |
+
self,
|
| 132 |
+
messages: list[ModelMessage],
|
| 133 |
+
response_model: type[ResponseModelT],
|
| 134 |
+
*,
|
| 135 |
+
model_name: str,
|
| 136 |
+
temperature: float,
|
| 137 |
+
max_output_tokens: int,
|
| 138 |
+
) -> ResponseModelT:
|
| 139 |
+
prompt = self._json_prompt(messages, response_model)
|
| 140 |
+
response = self._client.models.generate_content(
|
| 141 |
+
model=model_name,
|
| 142 |
+
contents=prompt,
|
| 143 |
+
config=types.GenerateContentConfig(
|
| 144 |
+
temperature=temperature,
|
| 145 |
+
max_output_tokens=max_output_tokens,
|
| 146 |
+
candidate_count=1,
|
| 147 |
+
),
|
| 148 |
+
)
|
| 149 |
+
text = getattr(response, "text", None)
|
| 150 |
+
if not isinstance(text, str) or not text.strip():
|
| 151 |
+
raise RuntimeError("Gemini returned an empty prompt-only response.")
|
| 152 |
+
return response_model.model_validate_json(self._extract_json_object(text))
|
| 153 |
+
|
| 154 |
+
def _create_client(self, api_key: str | None) -> genai.Client:
|
| 155 |
+
load_dotenv(self._repo_root() / ".env", override=False)
|
| 156 |
+
key = api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
|
| 157 |
+
if not key:
|
| 158 |
+
raise RuntimeError("Missing GEMINI_API_KEY or GOOGLE_API_KEY.")
|
| 159 |
+
return genai.Client(api_key=key)
|
| 160 |
+
|
| 161 |
+
@staticmethod
|
| 162 |
+
def _repo_root() -> Path:
|
| 163 |
+
return Path(__file__).resolve().parents[2]
|
| 164 |
+
|
| 165 |
+
@staticmethod
|
| 166 |
+
def _split_messages(messages: list[ModelMessage]) -> tuple[str | None, list[str]]:
|
| 167 |
+
system_parts: list[str] = []
|
| 168 |
+
content_parts: list[str] = []
|
| 169 |
+
for message in messages:
|
| 170 |
+
if message.role == "system":
|
| 171 |
+
system_parts.append(message.content)
|
| 172 |
+
continue
|
| 173 |
+
content_parts.append(f"{message.role.upper()}:\n{message.content}")
|
| 174 |
+
system_instruction = "\n\n".join(system_parts) if system_parts else None
|
| 175 |
+
contents = ["\n\n".join(content_parts)] if content_parts else [""]
|
| 176 |
+
return system_instruction, contents
|
| 177 |
+
|
| 178 |
+
@staticmethod
|
| 179 |
+
def _json_prompt(
|
| 180 |
+
messages: list[ModelMessage],
|
| 181 |
+
response_model: type[ResponseModelT],
|
| 182 |
+
) -> str:
|
| 183 |
+
message_blocks = [f"{message.role.upper()}:\n{message.content}" for message in messages]
|
| 184 |
+
schema = _schema_prompt_snippet(response_model)
|
| 185 |
+
conversation = "\n\n".join(message_blocks)
|
| 186 |
+
return (
|
| 187 |
+
"Return exactly one valid JSON object and nothing else.\n"
|
| 188 |
+
"Do not use markdown fences.\n"
|
| 189 |
+
"Use compact JSON with no commentary.\n"
|
| 190 |
+
f"JSON Schema:\n{schema}\n\n"
|
| 191 |
+
f"Conversation:\n{conversation}\n"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
@staticmethod
|
| 195 |
+
def _extract_json_object(text: str) -> str:
|
| 196 |
+
cleaned = text.strip()
|
| 197 |
+
if cleaned.startswith("```"):
|
| 198 |
+
cleaned = cleaned.strip("`")
|
| 199 |
+
if cleaned.startswith("json"):
|
| 200 |
+
cleaned = cleaned[4:].lstrip()
|
| 201 |
+
start = cleaned.find("{")
|
| 202 |
+
end = cleaned.rfind("}")
|
| 203 |
+
if start == -1 or end == -1 or end < start:
|
| 204 |
+
raise RuntimeError("Gemini response did not contain a JSON object.")
|
| 205 |
+
return cleaned[start : end + 1]
|
| 206 |
+
|
| 207 |
+
@staticmethod
|
| 208 |
+
def _normalize_error(exc: Exception) -> str:
|
| 209 |
+
return " ".join(str(exc).split()) or exc.__class__.__name__
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class HuggingFaceStructuredClient:
|
| 213 |
+
def __init__(
|
| 214 |
+
self,
|
| 215 |
+
*,
|
| 216 |
+
adapter_path: str | None = None,
|
| 217 |
+
cache_dir: str | None = None,
|
| 218 |
+
load_in_4bit: bool = True,
|
| 219 |
+
trust_remote_code: bool = False,
|
| 220 |
+
device_map: str | None = "auto",
|
| 221 |
+
) -> None:
|
| 222 |
+
self.adapter_path = adapter_path
|
| 223 |
+
self.cache_dir = cache_dir
|
| 224 |
+
self.load_in_4bit = load_in_4bit
|
| 225 |
+
self.trust_remote_code = trust_remote_code
|
| 226 |
+
self.device_map = device_map
|
| 227 |
+
self._loaded_model_name: str | None = None
|
| 228 |
+
self._model: Any | None = None
|
| 229 |
+
self._tokenizer: Any | None = None
|
| 230 |
+
|
| 231 |
+
def generate_structured(
|
| 232 |
+
self,
|
| 233 |
+
messages: list[ModelMessage],
|
| 234 |
+
response_model: type[ResponseModelT],
|
| 235 |
+
*,
|
| 236 |
+
model_name: str,
|
| 237 |
+
temperature: float,
|
| 238 |
+
max_output_tokens: int,
|
| 239 |
+
) -> ResponseModelT:
|
| 240 |
+
tokenizer, model = self._ensure_model(model_name)
|
| 241 |
+
prompt = self._hf_prompt(messages, response_model)
|
| 242 |
+
rendered = self._render_prompt(tokenizer, prompt)
|
| 243 |
+
tokenized = tokenizer(rendered, return_tensors="pt")
|
| 244 |
+
tokenized = {key: value.to(model.device) for key, value in tokenized.items()}
|
| 245 |
+
generate_kwargs: dict[str, Any] = {
|
| 246 |
+
"max_new_tokens": max_output_tokens,
|
| 247 |
+
"do_sample": temperature > 0.0,
|
| 248 |
+
"temperature": max(temperature, 1e-5) if temperature > 0.0 else None,
|
| 249 |
+
"pad_token_id": getattr(tokenizer, "pad_token_id", None) or getattr(tokenizer, "eos_token_id", None),
|
| 250 |
+
"eos_token_id": getattr(tokenizer, "eos_token_id", None),
|
| 251 |
+
}
|
| 252 |
+
generate_kwargs = {key: value for key, value in generate_kwargs.items() if value is not None}
|
| 253 |
+
|
| 254 |
+
import torch
|
| 255 |
+
|
| 256 |
+
with torch.inference_mode():
|
| 257 |
+
output_ids = model.generate(**tokenized, **generate_kwargs)
|
| 258 |
+
prompt_length = tokenized["input_ids"].shape[1]
|
| 259 |
+
completion_ids = output_ids[0][prompt_length:]
|
| 260 |
+
text = tokenizer.decode(completion_ids, skip_special_tokens=True)
|
| 261 |
+
if not text.strip():
|
| 262 |
+
raise RuntimeError("Hugging Face model returned an empty response.")
|
| 263 |
+
return response_model.model_validate_json(self._extract_json_object(text))
|
| 264 |
+
|
| 265 |
+
def _ensure_model(self, model_name: str) -> tuple[Any, Any]:
|
| 266 |
+
if self._model is not None and self._tokenizer is not None and self._loaded_model_name == model_name:
|
| 267 |
+
return self._tokenizer, self._model
|
| 268 |
+
|
| 269 |
+
load_dotenv(self._repo_root() / ".env", override=False)
|
| 270 |
+
|
| 271 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 272 |
+
|
| 273 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 274 |
+
model_name,
|
| 275 |
+
cache_dir=self.cache_dir,
|
| 276 |
+
trust_remote_code=self.trust_remote_code,
|
| 277 |
+
token=_hf_token(),
|
| 278 |
+
)
|
| 279 |
+
if tokenizer.pad_token is None:
|
| 280 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 281 |
+
tokenizer = self._canonicalize_chat_template(tokenizer)
|
| 282 |
+
|
| 283 |
+
model_kwargs: dict[str, Any] = {
|
| 284 |
+
"cache_dir": self.cache_dir,
|
| 285 |
+
"trust_remote_code": self.trust_remote_code,
|
| 286 |
+
"token": _hf_token(),
|
| 287 |
+
}
|
| 288 |
+
model_kwargs.update(_hf_model_init_kwargs(load_in_4bit=self.load_in_4bit, device_map=self.device_map))
|
| 289 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
|
| 290 |
+
if self.adapter_path:
|
| 291 |
+
from peft import PeftModel
|
| 292 |
+
|
| 293 |
+
model = PeftModel.from_pretrained(model, self.adapter_path, is_trainable=False)
|
| 294 |
+
model.eval()
|
| 295 |
+
self._loaded_model_name = model_name
|
| 296 |
+
self._model = model
|
| 297 |
+
self._tokenizer = tokenizer
|
| 298 |
+
return tokenizer, model
|
| 299 |
+
|
| 300 |
+
@staticmethod
|
| 301 |
+
def _repo_root() -> Path:
|
| 302 |
+
return Path(__file__).resolve().parents[2]
|
| 303 |
+
|
| 304 |
+
@staticmethod
|
| 305 |
+
def _render_prompt(tokenizer: Any, prompt: str) -> str:
|
| 306 |
+
if hasattr(tokenizer, "apply_chat_template"):
|
| 307 |
+
chat_template_kwargs = HuggingFaceStructuredClient._chat_template_kwargs(tokenizer)
|
| 308 |
+
return tokenizer.apply_chat_template(
|
| 309 |
+
[
|
| 310 |
+
{"role": "system", "content": "Return exactly one valid JSON object and nothing else."},
|
| 311 |
+
{"role": "user", "content": prompt},
|
| 312 |
+
],
|
| 313 |
+
tokenize=False,
|
| 314 |
+
add_generation_prompt=True,
|
| 315 |
+
**chat_template_kwargs,
|
| 316 |
+
)
|
| 317 |
+
return prompt
|
| 318 |
+
|
| 319 |
+
@staticmethod
|
| 320 |
+
def _canonicalize_chat_template(tokenizer: Any) -> Any:
|
| 321 |
+
chat_template = getattr(tokenizer, "chat_template", "") or ""
|
| 322 |
+
if qwen3_chat_template is None:
|
| 323 |
+
return tokenizer
|
| 324 |
+
if "<|im_start|>" not in chat_template or "<|im_end|>" not in chat_template:
|
| 325 |
+
return tokenizer
|
| 326 |
+
tokenizer.chat_template = qwen3_chat_template
|
| 327 |
+
return tokenizer
|
| 328 |
+
|
| 329 |
+
@staticmethod
|
| 330 |
+
def _chat_template_kwargs(tokenizer: Any) -> dict[str, Any]:
|
| 331 |
+
if not hasattr(tokenizer, "apply_chat_template"):
|
| 332 |
+
return {}
|
| 333 |
+
try:
|
| 334 |
+
tokenizer.apply_chat_template(
|
| 335 |
+
[{"role": "user", "content": "ping"}],
|
| 336 |
+
tokenize=False,
|
| 337 |
+
add_generation_prompt=True,
|
| 338 |
+
enable_thinking=False,
|
| 339 |
+
)
|
| 340 |
+
except Exception:
|
| 341 |
+
return {}
|
| 342 |
+
return {"enable_thinking": False}
|
| 343 |
+
|
| 344 |
+
@staticmethod
|
| 345 |
+
def _hf_prompt(
|
| 346 |
+
messages: list[ModelMessage],
|
| 347 |
+
response_model: type[ResponseModelT],
|
| 348 |
+
) -> str:
|
| 349 |
+
schema = _schema_prompt_snippet(response_model)
|
| 350 |
+
conversation = "\n\n".join(f"{message.role.upper()}:\n{message.content}" for message in messages)
|
| 351 |
+
return (
|
| 352 |
+
"Respond with exactly one compact JSON object and no other text.\n"
|
| 353 |
+
"Do not use markdown fences.\n"
|
| 354 |
+
f"JSON Schema:\n{schema}\n\n"
|
| 355 |
+
f"Conversation:\n{conversation}\n"
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
@staticmethod
|
| 359 |
+
def _extract_json_object(text: str) -> str:
|
| 360 |
+
cleaned = text.strip()
|
| 361 |
+
if cleaned.startswith("```"):
|
| 362 |
+
cleaned = cleaned.strip("`")
|
| 363 |
+
if cleaned.startswith("json"):
|
| 364 |
+
cleaned = cleaned[4:].lstrip()
|
| 365 |
+
start = cleaned.find("{")
|
| 366 |
+
end = cleaned.rfind("}")
|
| 367 |
+
if start == -1 or end == -1 or end < start:
|
| 368 |
+
raise RuntimeError("Hugging Face response did not contain a JSON object.")
|
| 369 |
+
return cleaned[start : end + 1]
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _schema_prompt_snippet(response_model: type[ResponseModelT]) -> str:
|
| 373 |
+
schema = response_model.model_json_schema()
|
| 374 |
+
serialized = json.dumps(schema, separators=(",", ":"))
|
| 375 |
+
if len(serialized) <= 4000:
|
| 376 |
+
return serialized
|
| 377 |
+
summarized = {
|
| 378 |
+
"title": schema.get("title", response_model.__name__),
|
| 379 |
+
"type": schema.get("type", "object"),
|
| 380 |
+
"required": schema.get("required", []),
|
| 381 |
+
"properties": {
|
| 382 |
+
key: {
|
| 383 |
+
field_name: value
|
| 384 |
+
for field_name, value in property_schema.items()
|
| 385 |
+
if field_name in {"type", "title", "enum", "items", "required", "$ref", "description"}
|
| 386 |
+
}
|
| 387 |
+
for key, property_schema in schema.get("properties", {}).items()
|
| 388 |
+
},
|
| 389 |
+
"defs": sorted(schema.get("$defs", {}).keys()),
|
| 390 |
+
}
|
| 391 |
+
return json.dumps(summarized, separators=(",", ":"))
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def _hf_model_init_kwargs(*, load_in_4bit: bool, device_map: str | None) -> dict[str, Any]:
|
| 395 |
+
import torch
|
| 396 |
+
|
| 397 |
+
kwargs: dict[str, Any] = {
|
| 398 |
+
"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 399 |
+
}
|
| 400 |
+
if device_map is not None and torch.cuda.is_available():
|
| 401 |
+
kwargs["device_map"] = device_map
|
| 402 |
+
if load_in_4bit and torch.cuda.is_available():
|
| 403 |
+
from transformers import BitsAndBytesConfig
|
| 404 |
+
|
| 405 |
+
kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 406 |
+
load_in_4bit=True,
|
| 407 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 408 |
+
bnb_4bit_quant_type="nf4",
|
| 409 |
+
bnb_4bit_use_double_quant=True,
|
| 410 |
+
)
|
| 411 |
+
return kwargs
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def _hf_token() -> str | None:
|
| 415 |
+
return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
|
agents/shared/model_schema.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Literal
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, ConfigDict
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class StrictModel(BaseModel):
|
| 9 |
+
model_config = ConfigDict(extra="forbid")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ModelMessage(StrictModel):
|
| 13 |
+
role: Literal["system", "user", "assistant"]
|
| 14 |
+
content: str
|
agents/shared/openenv_compat.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Generic, Optional, TypeVar
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 7 |
+
|
| 8 |
+
ObsT = TypeVar("ObsT")
|
| 9 |
+
ActT = TypeVar("ActT")
|
| 10 |
+
StateT = TypeVar("StateT")
|
| 11 |
+
|
| 12 |
+
try: # pragma: no cover - exercised when openenv-core is installed
|
| 13 |
+
from openenv.core.client_types import StepResult as OpenEnvStepResult
|
| 14 |
+
from openenv.core.env_server.interfaces import Environment as OpenEnvEnvironment
|
| 15 |
+
from openenv.core.env_server.types import (
|
| 16 |
+
Action as OpenEnvAction,
|
| 17 |
+
EnvironmentMetadata as OpenEnvEnvironmentMetadata,
|
| 18 |
+
Observation as OpenEnvObservation,
|
| 19 |
+
State as OpenEnvState,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
OPENENV_AVAILABLE = True
|
| 23 |
+
except ImportError: # pragma: no cover - lightweight fallback for local imports/tests
|
| 24 |
+
OPENENV_AVAILABLE = False
|
| 25 |
+
|
| 26 |
+
class Action(BaseModel):
|
| 27 |
+
model_config = ConfigDict(
|
| 28 |
+
extra="forbid",
|
| 29 |
+
validate_assignment=True,
|
| 30 |
+
arbitrary_types_allowed=True,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
metadata: dict[str, Any] = Field(default_factory=dict)
|
| 34 |
+
|
| 35 |
+
class Observation(BaseModel):
|
| 36 |
+
model_config = ConfigDict(
|
| 37 |
+
extra="forbid",
|
| 38 |
+
validate_assignment=True,
|
| 39 |
+
arbitrary_types_allowed=True,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
done: bool = False
|
| 43 |
+
reward: bool | int | float | None = None
|
| 44 |
+
metadata: dict[str, Any] = Field(default_factory=dict)
|
| 45 |
+
|
| 46 |
+
class State(BaseModel):
|
| 47 |
+
model_config = ConfigDict(
|
| 48 |
+
extra="allow",
|
| 49 |
+
validate_assignment=True,
|
| 50 |
+
arbitrary_types_allowed=True,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
episode_id: str | None = None
|
| 54 |
+
step_count: int = 0
|
| 55 |
+
|
| 56 |
+
class EnvironmentMetadata(BaseModel):
|
| 57 |
+
model_config = ConfigDict(extra="forbid")
|
| 58 |
+
|
| 59 |
+
name: str
|
| 60 |
+
description: str
|
| 61 |
+
version: str | None = None
|
| 62 |
+
|
| 63 |
+
@dataclass
|
| 64 |
+
class StepResult(Generic[ObsT]):
|
| 65 |
+
observation: ObsT
|
| 66 |
+
reward: Optional[float] = None
|
| 67 |
+
done: bool = False
|
| 68 |
+
|
| 69 |
+
class Environment(Generic[ActT, ObsT, StateT]):
|
| 70 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = False
|
| 71 |
+
|
| 72 |
+
def __init__(self, transform: Any | None = None) -> None:
|
| 73 |
+
self.transform = transform
|
| 74 |
+
|
| 75 |
+
def reset(
|
| 76 |
+
self,
|
| 77 |
+
seed: Optional[int] = None,
|
| 78 |
+
episode_id: Optional[str] = None,
|
| 79 |
+
**kwargs: Any,
|
| 80 |
+
) -> ObsT:
|
| 81 |
+
raise NotImplementedError
|
| 82 |
+
|
| 83 |
+
def step(
|
| 84 |
+
self,
|
| 85 |
+
action: ActT,
|
| 86 |
+
timeout_s: Optional[float] = None,
|
| 87 |
+
**kwargs: Any,
|
| 88 |
+
) -> ObsT:
|
| 89 |
+
raise NotImplementedError
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def state(self) -> StateT:
|
| 93 |
+
raise NotImplementedError
|
| 94 |
+
|
| 95 |
+
def get_metadata(self) -> EnvironmentMetadata:
|
| 96 |
+
return EnvironmentMetadata(
|
| 97 |
+
name=self.__class__.__name__,
|
| 98 |
+
description=f"{self.__class__.__name__} environment",
|
| 99 |
+
version="1.0.0",
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def _apply_transform(self, observation: ObsT) -> ObsT:
|
| 103 |
+
return observation if self.transform is None else self.transform(observation)
|
| 104 |
+
|
| 105 |
+
def close(self) -> None:
|
| 106 |
+
return None
|
| 107 |
+
|
| 108 |
+
else:
|
| 109 |
+
Action = OpenEnvAction
|
| 110 |
+
Observation = OpenEnvObservation
|
| 111 |
+
State = OpenEnvState
|
| 112 |
+
Environment = OpenEnvEnvironment
|
| 113 |
+
EnvironmentMetadata = OpenEnvEnvironmentMetadata
|
| 114 |
+
StepResult = OpenEnvStepResult
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def build_step_result(observation: ObsT) -> StepResult[ObsT]:
|
| 118 |
+
reward = getattr(observation, "reward", None)
|
| 119 |
+
if reward is not None:
|
| 120 |
+
reward = float(reward)
|
| 121 |
+
return StepResult(
|
| 122 |
+
observation=observation,
|
| 123 |
+
reward=reward,
|
| 124 |
+
done=bool(getattr(observation, "done", False)),
|
| 125 |
+
)
|
agents/shared/runtime.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Literal
|
| 6 |
+
|
| 7 |
+
from .llm_client import (
|
| 8 |
+
DEFAULT_GEMINI_DM_MODEL,
|
| 9 |
+
DEFAULT_GEMINI_HERO_MODEL,
|
| 10 |
+
DEFAULT_HF_DM_MODEL,
|
| 11 |
+
DEFAULT_HF_HERO_MODEL,
|
| 12 |
+
GeminiStructuredClient,
|
| 13 |
+
HuggingFaceStructuredClient,
|
| 14 |
+
PROVIDER_GEMINI,
|
| 15 |
+
PROVIDER_HF_LOCAL,
|
| 16 |
+
StructuredModelClient,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
StructuredProvider = Literal["gemini", "hf_local"]
|
| 20 |
+
InterfaceProvider = Literal["strict", "simple", "gemini"]
|
| 21 |
+
InterfaceTranslationMode = Literal["none", "corporate_app"]
|
| 22 |
+
RoleName = Literal["dm", "hero"]
|
| 23 |
+
|
| 24 |
+
DEFAULT_INTERFACE_PROVIDER: InterfaceProvider = "strict"
|
| 25 |
+
DEFAULT_INTERFACE_MODEL = "gemini-2.5-flash-lite"
|
| 26 |
+
DEFAULT_INTERFACE_TRANSLATION_MODE: InterfaceTranslationMode = "none"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True)
|
| 30 |
+
class StructuredClientConfig:
|
| 31 |
+
role: RoleName
|
| 32 |
+
provider: StructuredProvider
|
| 33 |
+
model_name: str
|
| 34 |
+
adapter_path: str | None = None
|
| 35 |
+
cache_dir: str | None = None
|
| 36 |
+
load_in_4bit: bool = True
|
| 37 |
+
trust_remote_code: bool = False
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass(frozen=True)
|
| 41 |
+
class InterfaceConfig:
|
| 42 |
+
provider: InterfaceProvider
|
| 43 |
+
model_name: str = DEFAULT_INTERFACE_MODEL
|
| 44 |
+
narrate_observations: bool = False
|
| 45 |
+
translation_mode: InterfaceTranslationMode = DEFAULT_INTERFACE_TRANSLATION_MODE
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def resolve_structured_client_config(
|
| 49 |
+
role: RoleName,
|
| 50 |
+
*,
|
| 51 |
+
provider: StructuredProvider | None = None,
|
| 52 |
+
model_name: str | None = None,
|
| 53 |
+
adapter_path: str | None = None,
|
| 54 |
+
) -> StructuredClientConfig:
|
| 55 |
+
env_prefix = f"DND_{role.upper()}"
|
| 56 |
+
resolved_provider = provider or _structured_provider_from_env(os.getenv(f"{env_prefix}_PROVIDER")) or PROVIDER_GEMINI
|
| 57 |
+
if resolved_provider == PROVIDER_HF_LOCAL:
|
| 58 |
+
default_model = DEFAULT_HF_DM_MODEL if role == "dm" else DEFAULT_HF_HERO_MODEL
|
| 59 |
+
else:
|
| 60 |
+
default_model = DEFAULT_GEMINI_DM_MODEL if role == "dm" else DEFAULT_GEMINI_HERO_MODEL
|
| 61 |
+
return StructuredClientConfig(
|
| 62 |
+
role=role,
|
| 63 |
+
provider=resolved_provider,
|
| 64 |
+
model_name=model_name or os.getenv(f"{env_prefix}_MODEL") or default_model,
|
| 65 |
+
adapter_path=adapter_path or os.getenv(f"{env_prefix}_ADAPTER_PATH"),
|
| 66 |
+
cache_dir=os.getenv("HF_HOME"),
|
| 67 |
+
load_in_4bit=_env_bool("DND_LOAD_IN_4BIT", default=True),
|
| 68 |
+
trust_remote_code=_env_bool("DND_TRUST_REMOTE_CODE", default=False),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def create_structured_client(config: StructuredClientConfig) -> StructuredModelClient:
|
| 73 |
+
if config.provider == PROVIDER_GEMINI:
|
| 74 |
+
return GeminiStructuredClient()
|
| 75 |
+
if config.provider == PROVIDER_HF_LOCAL:
|
| 76 |
+
return HuggingFaceStructuredClient(
|
| 77 |
+
adapter_path=config.adapter_path,
|
| 78 |
+
cache_dir=config.cache_dir,
|
| 79 |
+
load_in_4bit=config.load_in_4bit,
|
| 80 |
+
trust_remote_code=config.trust_remote_code,
|
| 81 |
+
)
|
| 82 |
+
raise ValueError(f"Unsupported structured provider: {config.provider}")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def resolve_interface_config(
|
| 86 |
+
*,
|
| 87 |
+
provider: InterfaceProvider | None = None,
|
| 88 |
+
model_name: str | None = None,
|
| 89 |
+
narrate_observations: bool | None = None,
|
| 90 |
+
translation_mode: InterfaceTranslationMode | None = None,
|
| 91 |
+
) -> InterfaceConfig:
|
| 92 |
+
resolved_translation = (
|
| 93 |
+
translation_mode
|
| 94 |
+
or _interface_translation_mode_from_env(os.getenv("DND_INTERFACE_TRANSLATION_MODE"))
|
| 95 |
+
or DEFAULT_INTERFACE_TRANSLATION_MODE
|
| 96 |
+
)
|
| 97 |
+
resolved_provider = provider or _interface_provider_from_env(os.getenv("DND_INTERFACE_PROVIDER"))
|
| 98 |
+
if resolved_provider is None:
|
| 99 |
+
resolved_provider = "gemini" if resolved_translation != "none" else DEFAULT_INTERFACE_PROVIDER
|
| 100 |
+
resolved_narrate = narrate_observations
|
| 101 |
+
if resolved_narrate is None:
|
| 102 |
+
resolved_narrate = _env_bool("DND_INTERFACE_NARRATE", default=False)
|
| 103 |
+
if resolved_translation != "none" and resolved_provider != "gemini":
|
| 104 |
+
raise ValueError("Interface translation mode requires the Gemini interface provider.")
|
| 105 |
+
return InterfaceConfig(
|
| 106 |
+
provider=resolved_provider,
|
| 107 |
+
model_name=model_name or os.getenv("DND_INTERFACE_MODEL") or DEFAULT_INTERFACE_MODEL,
|
| 108 |
+
narrate_observations=resolved_narrate,
|
| 109 |
+
translation_mode=resolved_translation,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def build_interface_adapter(config: InterfaceConfig):
|
| 114 |
+
from agents.master.interface import GeminiInterfaceAdapter, SimpleInterfaceAdapter, StrictCliInterfaceAdapter
|
| 115 |
+
|
| 116 |
+
if config.provider == "strict":
|
| 117 |
+
return StrictCliInterfaceAdapter()
|
| 118 |
+
if config.provider == "simple":
|
| 119 |
+
return SimpleInterfaceAdapter()
|
| 120 |
+
if config.provider == "gemini":
|
| 121 |
+
return GeminiInterfaceAdapter(
|
| 122 |
+
model=config.model_name,
|
| 123 |
+
narrate_observations=config.narrate_observations,
|
| 124 |
+
translation_mode=config.translation_mode,
|
| 125 |
+
)
|
| 126 |
+
raise ValueError(f"Unsupported interface provider: {config.provider}")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _structured_provider_from_env(value: str | None) -> StructuredProvider | None:
|
| 130 |
+
if value is None:
|
| 131 |
+
return None
|
| 132 |
+
normalized = value.strip().lower()
|
| 133 |
+
if normalized not in {PROVIDER_GEMINI, PROVIDER_HF_LOCAL}:
|
| 134 |
+
raise ValueError(f"Unsupported structured provider value: {value}")
|
| 135 |
+
return normalized # type: ignore[return-value]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _interface_provider_from_env(value: str | None) -> InterfaceProvider | None:
|
| 139 |
+
if value is None:
|
| 140 |
+
return None
|
| 141 |
+
normalized = value.strip().lower()
|
| 142 |
+
if normalized not in {"strict", "simple", "gemini"}:
|
| 143 |
+
raise ValueError(f"Unsupported interface provider value: {value}")
|
| 144 |
+
return normalized # type: ignore[return-value]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _interface_translation_mode_from_env(value: str | None) -> InterfaceTranslationMode | None:
|
| 148 |
+
if value is None:
|
| 149 |
+
return None
|
| 150 |
+
normalized = value.strip().lower()
|
| 151 |
+
if normalized not in {"none", "corporate_app"}:
|
| 152 |
+
raise ValueError(f"Unsupported interface translation mode value: {value}")
|
| 153 |
+
return normalized # type: ignore[return-value]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _env_bool(name: str, *, default: bool) -> bool:
|
| 157 |
+
raw = os.getenv(name)
|
| 158 |
+
if raw is None:
|
| 159 |
+
return default
|
| 160 |
+
normalized = raw.strip().lower()
|
| 161 |
+
if normalized in {"1", "true", "yes", "on"}:
|
| 162 |
+
return True
|
| 163 |
+
if normalized in {"0", "false", "no", "off"}:
|
| 164 |
+
return False
|
| 165 |
+
raise ValueError(f"Environment variable {name} must be a boolean value, got {raw!r}")
|
agents/spaces/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Space wrapper apps for the dungeon environments."""
|
| 2 |
+
|
| 3 |
+
from .dm_space import LatestWorldOutputStore, SpaceDMEnvironment, create_app as create_dm_space_app
|
| 4 |
+
from .hero_space import SpaceHeroEnvironment, UploadedWorldStore, create_app as create_hero_space_app
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"LatestWorldOutputStore",
|
| 8 |
+
"SpaceDMEnvironment",
|
| 9 |
+
"SpaceHeroEnvironment",
|
| 10 |
+
"UploadedWorldStore",
|
| 11 |
+
"create_dm_space_app",
|
| 12 |
+
"create_hero_space_app",
|
| 13 |
+
]
|
agents/spaces/dm_space.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from datetime import datetime, timezone
|
| 5 |
+
from html import escape
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from threading import Lock
|
| 8 |
+
from typing import Any, Callable
|
| 9 |
+
|
| 10 |
+
from fastapi import FastAPI, HTTPException
|
| 11 |
+
from fastapi.responses import FileResponse, HTMLResponse
|
| 12 |
+
import uvicorn
|
| 13 |
+
|
| 14 |
+
from agents.master.env import DMEnvironment
|
| 15 |
+
from agents.master.schema import CompiledWorld, DMAction, DMObservation, WorldDefinition
|
| 16 |
+
from agents.shared.openenv_compat import StepResult
|
| 17 |
+
from agents.shared.runtime import build_interface_adapter, resolve_interface_config
|
| 18 |
+
|
| 19 |
+
DEFAULT_ARTIFACTS_ROOT = Path("/tmp/dnd_dm_artifacts")
|
| 20 |
+
DEFAULT_HOST = "0.0.0.0"
|
| 21 |
+
DEFAULT_PORT = 8000
|
| 22 |
+
DEFAULT_MAX_CONCURRENT_ENVS = 1
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass(frozen=True)
|
| 26 |
+
class LatestWorldSnapshot:
|
| 27 |
+
episode_id: str
|
| 28 |
+
title: str
|
| 29 |
+
path: Path
|
| 30 |
+
updated_at: str
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LatestWorldOutputStore:
|
| 34 |
+
def __init__(self) -> None:
|
| 35 |
+
self._lock = Lock()
|
| 36 |
+
self._snapshot: LatestWorldSnapshot | None = None
|
| 37 |
+
|
| 38 |
+
def record(self, compiled: CompiledWorld) -> None:
|
| 39 |
+
path = compiled.artifacts_dir / "world_definition.normalized.json"
|
| 40 |
+
if not path.is_file():
|
| 41 |
+
return
|
| 42 |
+
snapshot = LatestWorldSnapshot(
|
| 43 |
+
episode_id=compiled.episode_id,
|
| 44 |
+
title=compiled.world.meta.title,
|
| 45 |
+
path=path,
|
| 46 |
+
updated_at=datetime.now(timezone.utc).isoformat(),
|
| 47 |
+
)
|
| 48 |
+
with self._lock:
|
| 49 |
+
self._snapshot = snapshot
|
| 50 |
+
|
| 51 |
+
def latest_path(self) -> Path | None:
|
| 52 |
+
snapshot = self.snapshot()
|
| 53 |
+
return None if snapshot is None else snapshot.path
|
| 54 |
+
|
| 55 |
+
def snapshot(self) -> LatestWorldSnapshot | None:
|
| 56 |
+
with self._lock:
|
| 57 |
+
return self._snapshot
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class SpaceDMEnvironment(DMEnvironment):
|
| 61 |
+
def __init__(self, *, world_output_store: LatestWorldOutputStore, **kwargs: Any) -> None:
|
| 62 |
+
super().__init__(**kwargs)
|
| 63 |
+
self._world_output_store = world_output_store
|
| 64 |
+
|
| 65 |
+
def step( # type: ignore[override]
|
| 66 |
+
self,
|
| 67 |
+
action: DMAction | WorldDefinition | dict[str, Any],
|
| 68 |
+
runner: Any | None = None,
|
| 69 |
+
observer: Any | None = None,
|
| 70 |
+
timeout_s: float | None = None,
|
| 71 |
+
) -> StepResult[DMObservation]:
|
| 72 |
+
result = super().step(action, runner=runner, observer=observer, timeout_s=timeout_s)
|
| 73 |
+
observation = result.observation
|
| 74 |
+
if observation.compile_error is None and self.last_compiled_world is not None:
|
| 75 |
+
self._world_output_store.record(self.last_compiled_world)
|
| 76 |
+
return result
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def create_app(
|
| 80 |
+
*,
|
| 81 |
+
openenv_app_factory: Callable[..., Any] | None = None,
|
| 82 |
+
world_output_store: LatestWorldOutputStore | None = None,
|
| 83 |
+
artifacts_root: Path = DEFAULT_ARTIFACTS_ROOT,
|
| 84 |
+
max_concurrent_envs: int = DEFAULT_MAX_CONCURRENT_ENVS,
|
| 85 |
+
) -> FastAPI:
|
| 86 |
+
if openenv_app_factory is None:
|
| 87 |
+
from openenv.core.env_server import create_fastapi_app as openenv_app_factory
|
| 88 |
+
|
| 89 |
+
store = world_output_store or LatestWorldOutputStore()
|
| 90 |
+
interface_adapter = build_interface_adapter(resolve_interface_config(provider="strict"))
|
| 91 |
+
|
| 92 |
+
env_app = openenv_app_factory(
|
| 93 |
+
env=lambda: SpaceDMEnvironment(
|
| 94 |
+
artifacts_root=artifacts_root,
|
| 95 |
+
interface_adapter=interface_adapter,
|
| 96 |
+
world_output_store=store,
|
| 97 |
+
),
|
| 98 |
+
action_cls=DMAction,
|
| 99 |
+
observation_cls=DMObservation,
|
| 100 |
+
max_concurrent_envs=max_concurrent_envs,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
app = FastAPI(title="DND-DM")
|
| 104 |
+
app.state.world_output_store = store
|
| 105 |
+
app.mount("/env", env_app)
|
| 106 |
+
|
| 107 |
+
@app.get("/", response_class=HTMLResponse)
|
| 108 |
+
def index() -> str:
|
| 109 |
+
return _render_index(store.snapshot())
|
| 110 |
+
|
| 111 |
+
@app.get("/healthz")
|
| 112 |
+
def healthz() -> dict[str, bool]:
|
| 113 |
+
return {"ok": True}
|
| 114 |
+
|
| 115 |
+
@app.get("/world-output/latest")
|
| 116 |
+
def latest_world_output() -> FileResponse:
|
| 117 |
+
path = store.latest_path()
|
| 118 |
+
if path is None or not path.is_file():
|
| 119 |
+
raise HTTPException(status_code=404, detail="No successful normalized world output is available yet.")
|
| 120 |
+
return FileResponse(
|
| 121 |
+
path,
|
| 122 |
+
media_type="application/json",
|
| 123 |
+
filename="world_definition.normalized.json",
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
return app
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _render_index(snapshot: LatestWorldSnapshot | None) -> str:
|
| 130 |
+
latest_html = (
|
| 131 |
+
"<p>No successful normalized world output has been recorded yet.</p>"
|
| 132 |
+
if snapshot is None
|
| 133 |
+
else (
|
| 134 |
+
"<p>"
|
| 135 |
+
f"Latest world: <strong>{escape(snapshot.title)}</strong> "
|
| 136 |
+
f"(episode <code>{escape(snapshot.episode_id)}</code>, updated {escape(snapshot.updated_at)}). "
|
| 137 |
+
'<a href="/world-output/latest">Download normalized world JSON</a>.'
|
| 138 |
+
"</p>"
|
| 139 |
+
)
|
| 140 |
+
)
|
| 141 |
+
return f"""<!doctype html>
|
| 142 |
+
<html lang="en">
|
| 143 |
+
<head>
|
| 144 |
+
<meta charset="utf-8">
|
| 145 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
| 146 |
+
<title>DND-DM</title>
|
| 147 |
+
<style>
|
| 148 |
+
body {{
|
| 149 |
+
font-family: "IBM Plex Sans", "Helvetica Neue", sans-serif;
|
| 150 |
+
margin: 0;
|
| 151 |
+
background: #f4efe5;
|
| 152 |
+
color: #1b1a17;
|
| 153 |
+
}}
|
| 154 |
+
main {{
|
| 155 |
+
max-width: 760px;
|
| 156 |
+
margin: 0 auto;
|
| 157 |
+
padding: 48px 24px 64px;
|
| 158 |
+
}}
|
| 159 |
+
a {{ color: #0b5c78; }}
|
| 160 |
+
code {{
|
| 161 |
+
background: rgba(11, 92, 120, 0.08);
|
| 162 |
+
padding: 0.15rem 0.35rem;
|
| 163 |
+
border-radius: 0.3rem;
|
| 164 |
+
}}
|
| 165 |
+
.panel {{
|
| 166 |
+
border: 1px solid rgba(27, 26, 23, 0.12);
|
| 167 |
+
background: rgba(255, 255, 255, 0.72);
|
| 168 |
+
border-radius: 18px;
|
| 169 |
+
padding: 20px 22px;
|
| 170 |
+
margin-top: 18px;
|
| 171 |
+
}}
|
| 172 |
+
</style>
|
| 173 |
+
</head>
|
| 174 |
+
<body>
|
| 175 |
+
<main>
|
| 176 |
+
<h1>DND-DM</h1>
|
| 177 |
+
<p>This Space hosts the dungeon DM OpenEnv environment as a CPU-only evaluator.</p>
|
| 178 |
+
<div class="panel">
|
| 179 |
+
<p>The OpenEnv API is mounted at <a href="/env"><code>/env</code></a>.</p>
|
| 180 |
+
<p>The DM evaluates submitted world definitions and writes the latest normalized JSON artifact for manual handoff to <code>DND-Hero</code>.</p>
|
| 181 |
+
{latest_html}
|
| 182 |
+
</div>
|
| 183 |
+
</main>
|
| 184 |
+
</body>
|
| 185 |
+
</html>"""
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def main() -> int:
|
| 189 |
+
uvicorn.run("agents.spaces.dm_space:create_app", factory=True, host=DEFAULT_HOST, port=DEFAULT_PORT)
|
| 190 |
+
return 0
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
if __name__ == "__main__":
|
| 194 |
+
raise SystemExit(main())
|
agents/spaces/hero_space.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from datetime import datetime, timezone
|
| 6 |
+
from html import escape
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from threading import Lock
|
| 10 |
+
from typing import Any, Callable
|
| 11 |
+
|
| 12 |
+
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
|
| 13 |
+
from fastapi.responses import HTMLResponse, JSONResponse, Response
|
| 14 |
+
import uvicorn
|
| 15 |
+
|
| 16 |
+
from agents.hero.env import HeroEnvironment
|
| 17 |
+
from agents.hero.schema import HeroObservation, HeroServerAction
|
| 18 |
+
from agents.master.check import DMCompileError, validate_and_normalize
|
| 19 |
+
from agents.master.schema import WorldDefinition
|
| 20 |
+
from agents.shared.runtime import build_interface_adapter, resolve_interface_config
|
| 21 |
+
|
| 22 |
+
DEFAULT_ARTIFACTS_ROOT = Path("/tmp/dnd_hero_artifacts")
|
| 23 |
+
DEFAULT_HOST = "0.0.0.0"
|
| 24 |
+
DEFAULT_PORT = 8000
|
| 25 |
+
DEFAULT_MAX_CONCURRENT_ENVS = 1
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass(frozen=True)
|
| 29 |
+
class UploadedWorldSnapshot:
|
| 30 |
+
world_input: dict[str, Any]
|
| 31 |
+
title: str
|
| 32 |
+
size_bytes: int
|
| 33 |
+
updated_at: str
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class UploadedWorldStore:
|
| 37 |
+
def __init__(self) -> None:
|
| 38 |
+
self._lock = Lock()
|
| 39 |
+
self._snapshot: UploadedWorldSnapshot | None = None
|
| 40 |
+
|
| 41 |
+
def set_world(self, world: WorldDefinition | dict[str, Any]) -> UploadedWorldSnapshot:
|
| 42 |
+
if isinstance(world, dict):
|
| 43 |
+
world = validate_and_normalize(world)
|
| 44 |
+
world_input = world.model_dump(mode="json")
|
| 45 |
+
snapshot = UploadedWorldSnapshot(
|
| 46 |
+
world_input=world_input,
|
| 47 |
+
title=world.meta.title,
|
| 48 |
+
size_bytes=len(json.dumps(world_input).encode("utf-8")),
|
| 49 |
+
updated_at=datetime.now(timezone.utc).isoformat(),
|
| 50 |
+
)
|
| 51 |
+
with self._lock:
|
| 52 |
+
self._snapshot = snapshot
|
| 53 |
+
return snapshot
|
| 54 |
+
|
| 55 |
+
def clear(self) -> None:
|
| 56 |
+
with self._lock:
|
| 57 |
+
self._snapshot = None
|
| 58 |
+
|
| 59 |
+
def current_world(self) -> dict[str, Any] | None:
|
| 60 |
+
snapshot = self.snapshot()
|
| 61 |
+
return None if snapshot is None else deepcopy(snapshot.world_input)
|
| 62 |
+
|
| 63 |
+
def snapshot(self) -> UploadedWorldSnapshot | None:
|
| 64 |
+
with self._lock:
|
| 65 |
+
return self._snapshot
|
| 66 |
+
|
| 67 |
+
def metadata(self) -> dict[str, Any]:
|
| 68 |
+
snapshot = self.snapshot()
|
| 69 |
+
if snapshot is None:
|
| 70 |
+
return {"configured": False}
|
| 71 |
+
return {
|
| 72 |
+
"configured": True,
|
| 73 |
+
"title": snapshot.title,
|
| 74 |
+
"size_bytes": snapshot.size_bytes,
|
| 75 |
+
"updated_at": snapshot.updated_at,
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class SpaceHeroEnvironment(HeroEnvironment):
|
| 80 |
+
def __init__(self, *, uploaded_world_store: UploadedWorldStore, **kwargs: Any) -> None:
|
| 81 |
+
super().__init__(**kwargs)
|
| 82 |
+
self._uploaded_world_store = uploaded_world_store
|
| 83 |
+
|
| 84 |
+
def reset( # type: ignore[override]
|
| 85 |
+
self,
|
| 86 |
+
world_input: Any | None = None,
|
| 87 |
+
*,
|
| 88 |
+
seed: int | None = None,
|
| 89 |
+
episode_id: str | None = None,
|
| 90 |
+
max_game_steps: int | None = None,
|
| 91 |
+
max_tool_calls: int | None = None,
|
| 92 |
+
scratchpad_max_chars: int | None = None,
|
| 93 |
+
debug: bool | None = None,
|
| 94 |
+
) -> HeroObservation:
|
| 95 |
+
selected_world_input = world_input
|
| 96 |
+
if selected_world_input is None:
|
| 97 |
+
selected_world_input = self._uploaded_world_store.current_world()
|
| 98 |
+
if selected_world_input is None:
|
| 99 |
+
raise ValueError(
|
| 100 |
+
"Upload a world JSON to /world-input or pass world_input explicitly before resetting DND-Hero."
|
| 101 |
+
)
|
| 102 |
+
return super().reset(
|
| 103 |
+
selected_world_input,
|
| 104 |
+
seed=seed,
|
| 105 |
+
episode_id=episode_id,
|
| 106 |
+
max_game_steps=max_game_steps,
|
| 107 |
+
max_tool_calls=max_tool_calls,
|
| 108 |
+
scratchpad_max_chars=scratchpad_max_chars,
|
| 109 |
+
debug=debug,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def create_app(
|
| 114 |
+
*,
|
| 115 |
+
openenv_app_factory: Callable[..., Any] | None = None,
|
| 116 |
+
uploaded_world_store: UploadedWorldStore | None = None,
|
| 117 |
+
artifacts_root: Path = DEFAULT_ARTIFACTS_ROOT,
|
| 118 |
+
max_concurrent_envs: int = DEFAULT_MAX_CONCURRENT_ENVS,
|
| 119 |
+
) -> FastAPI:
|
| 120 |
+
if openenv_app_factory is None:
|
| 121 |
+
from openenv.core.env_server import create_fastapi_app as openenv_app_factory
|
| 122 |
+
|
| 123 |
+
store = uploaded_world_store or UploadedWorldStore()
|
| 124 |
+
interface_adapter = build_interface_adapter(resolve_interface_config(provider="strict"))
|
| 125 |
+
|
| 126 |
+
env_app = openenv_app_factory(
|
| 127 |
+
env=lambda: SpaceHeroEnvironment(
|
| 128 |
+
artifacts_root=artifacts_root,
|
| 129 |
+
uploaded_world_store=store,
|
| 130 |
+
interface_adapter=interface_adapter,
|
| 131 |
+
),
|
| 132 |
+
action_cls=HeroServerAction,
|
| 133 |
+
observation_cls=HeroObservation,
|
| 134 |
+
max_concurrent_envs=max_concurrent_envs,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
app = FastAPI(title="DND-Hero")
|
| 138 |
+
app.state.uploaded_world_store = store
|
| 139 |
+
app.mount("/env", env_app)
|
| 140 |
+
|
| 141 |
+
@app.get("/", response_class=HTMLResponse)
|
| 142 |
+
def index() -> str:
|
| 143 |
+
return _render_index(store.metadata())
|
| 144 |
+
|
| 145 |
+
@app.get("/healthz")
|
| 146 |
+
def healthz() -> dict[str, bool]:
|
| 147 |
+
return {"ok": True}
|
| 148 |
+
|
| 149 |
+
@app.post("/world-input")
|
| 150 |
+
async def upload_world_input(
|
| 151 |
+
request: Request,
|
| 152 |
+
file: UploadFile | None = File(default=None),
|
| 153 |
+
) -> JSONResponse:
|
| 154 |
+
payload = await file.read() if file is not None else await request.body()
|
| 155 |
+
if not payload:
|
| 156 |
+
raise HTTPException(status_code=400, detail="Provide a world JSON file upload or a raw JSON request body.")
|
| 157 |
+
try:
|
| 158 |
+
raw_world = json.loads(payload.decode("utf-8"))
|
| 159 |
+
except UnicodeDecodeError as exc:
|
| 160 |
+
raise HTTPException(status_code=400, detail="World input must be UTF-8 JSON.") from exc
|
| 161 |
+
except json.JSONDecodeError as exc:
|
| 162 |
+
raise HTTPException(status_code=400, detail=f"Invalid JSON: {exc.msg}") from exc
|
| 163 |
+
if not isinstance(raw_world, dict):
|
| 164 |
+
raise HTTPException(status_code=400, detail="World input JSON must be an object.")
|
| 165 |
+
try:
|
| 166 |
+
world = validate_and_normalize(raw_world)
|
| 167 |
+
except DMCompileError as exc:
|
| 168 |
+
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 169 |
+
snapshot = store.set_world(world)
|
| 170 |
+
return JSONResponse(
|
| 171 |
+
{
|
| 172 |
+
"configured": True,
|
| 173 |
+
"title": snapshot.title,
|
| 174 |
+
"size_bytes": snapshot.size_bytes,
|
| 175 |
+
"updated_at": snapshot.updated_at,
|
| 176 |
+
}
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
@app.get("/world-input")
|
| 180 |
+
def world_input_metadata() -> JSONResponse:
|
| 181 |
+
return JSONResponse(store.metadata())
|
| 182 |
+
|
| 183 |
+
@app.delete("/world-input", status_code=204)
|
| 184 |
+
def clear_world_input() -> Response:
|
| 185 |
+
store.clear()
|
| 186 |
+
return Response(status_code=204)
|
| 187 |
+
|
| 188 |
+
return app
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _render_index(metadata: dict[str, Any]) -> str:
|
| 192 |
+
current_world_html = (
|
| 193 |
+
"<p>No default world is uploaded yet.</p>"
|
| 194 |
+
if not metadata.get("configured")
|
| 195 |
+
else (
|
| 196 |
+
"<p>"
|
| 197 |
+
f"Current uploaded world: <strong>{escape(str(metadata['title']))}</strong> "
|
| 198 |
+
f"({escape(str(metadata['size_bytes']))} bytes, updated {escape(str(metadata['updated_at']))})."
|
| 199 |
+
"</p>"
|
| 200 |
+
)
|
| 201 |
+
)
|
| 202 |
+
return f"""<!doctype html>
|
| 203 |
+
<html lang="en">
|
| 204 |
+
<head>
|
| 205 |
+
<meta charset="utf-8">
|
| 206 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
| 207 |
+
<title>DND-Hero</title>
|
| 208 |
+
<style>
|
| 209 |
+
body {{
|
| 210 |
+
font-family: "IBM Plex Sans", "Helvetica Neue", sans-serif;
|
| 211 |
+
margin: 0;
|
| 212 |
+
background: #eef4eb;
|
| 213 |
+
color: #182118;
|
| 214 |
+
}}
|
| 215 |
+
main {{
|
| 216 |
+
max-width: 760px;
|
| 217 |
+
margin: 0 auto;
|
| 218 |
+
padding: 48px 24px 64px;
|
| 219 |
+
}}
|
| 220 |
+
a {{ color: #146042; }}
|
| 221 |
+
code {{
|
| 222 |
+
background: rgba(20, 96, 66, 0.08);
|
| 223 |
+
padding: 0.15rem 0.35rem;
|
| 224 |
+
border-radius: 0.3rem;
|
| 225 |
+
}}
|
| 226 |
+
.panel {{
|
| 227 |
+
border: 1px solid rgba(24, 33, 24, 0.12);
|
| 228 |
+
background: rgba(255, 255, 255, 0.76);
|
| 229 |
+
border-radius: 18px;
|
| 230 |
+
padding: 20px 22px;
|
| 231 |
+
margin-top: 18px;
|
| 232 |
+
}}
|
| 233 |
+
input[type="file"] {{
|
| 234 |
+
display: block;
|
| 235 |
+
margin-bottom: 12px;
|
| 236 |
+
}}
|
| 237 |
+
button {{
|
| 238 |
+
background: #146042;
|
| 239 |
+
color: white;
|
| 240 |
+
border: 0;
|
| 241 |
+
border-radius: 999px;
|
| 242 |
+
padding: 0.7rem 1rem;
|
| 243 |
+
cursor: pointer;
|
| 244 |
+
}}
|
| 245 |
+
</style>
|
| 246 |
+
</head>
|
| 247 |
+
<body>
|
| 248 |
+
<main>
|
| 249 |
+
<h1>DND-Hero</h1>
|
| 250 |
+
<p>This Space hosts the dungeon Hero OpenEnv environment as a CPU-only evaluator.</p>
|
| 251 |
+
<div class="panel">
|
| 252 |
+
<p>The OpenEnv API is mounted at <a href="/env"><code>/env</code></a>.</p>
|
| 253 |
+
<p>Upload a normalized world-definition JSON file from <code>DND-DM</code> to make it the default world for future hero resets.</p>
|
| 254 |
+
{current_world_html}
|
| 255 |
+
<form action="/world-input" method="post" enctype="multipart/form-data">
|
| 256 |
+
<input type="file" name="file" accept="application/json,.json" required>
|
| 257 |
+
<button type="submit">Upload World JSON</button>
|
| 258 |
+
</form>
|
| 259 |
+
</div>
|
| 260 |
+
</main>
|
| 261 |
+
</body>
|
| 262 |
+
</html>"""
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def main() -> int:
|
| 266 |
+
uvicorn.run("agents.spaces.hero_space:create_app", factory=True, host=DEFAULT_HOST, port=DEFAULT_PORT)
|
| 267 |
+
return 0
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
if __name__ == "__main__":
|
| 271 |
+
raise SystemExit(main())
|
agents/train/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training entrypoints for GRPO-based experiments."""
|
| 2 |
+
|
agents/train/__main__.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
from agents.shared.llm_client import DEFAULT_HF_DM_MODEL, DEFAULT_HF_HERO_MODEL
|
| 11 |
+
|
| 12 |
+
from .grpo import (
|
| 13 |
+
DMClosedLoopConfig,
|
| 14 |
+
GRPOLaunchConfig,
|
| 15 |
+
SUPPORTED_GRPO_LOSS_TYPES,
|
| 16 |
+
SUPPORTED_IMPORTANCE_SAMPLING_LEVELS,
|
| 17 |
+
build_dm_grpo_dataset,
|
| 18 |
+
build_hero_grpo_dataset,
|
| 19 |
+
run_dm_grpo,
|
| 20 |
+
run_hero_grpo,
|
| 21 |
+
)
|
| 22 |
+
from .joint import JointTrainingConfig, run_joint_training_loop
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def main(argv: list[str] | None = None) -> int:
|
| 26 |
+
_load_repo_dotenv()
|
| 27 |
+
parser = argparse.ArgumentParser(description="GRPO training harnesses for dungeon agents.")
|
| 28 |
+
subparsers = parser.add_subparsers(dest="command", required=True)
|
| 29 |
+
|
| 30 |
+
dm_parser = subparsers.add_parser("dm-grpo", help="Run GRPO for the dungeon-master generator.")
|
| 31 |
+
_add_common_args(dm_parser, default_model=DEFAULT_HF_DM_MODEL, default_output_dir="artifacts/grpo/dm")
|
| 32 |
+
dm_parser.add_argument("--target-ratio", type=float, action="append")
|
| 33 |
+
dm_parser.add_argument("--artifacts-root", type=Path)
|
| 34 |
+
dm_parser.add_argument("--hero-provider", choices=["gemini", "hf_local"])
|
| 35 |
+
dm_parser.add_argument("--hero-model")
|
| 36 |
+
dm_parser.add_argument("--hero-adapter-path")
|
| 37 |
+
dm_parser.add_argument("--interface-provider", choices=["strict", "simple", "gemini"])
|
| 38 |
+
dm_parser.add_argument("--interface-model")
|
| 39 |
+
dm_parser.add_argument("--interface-narrate", action="store_true")
|
| 40 |
+
dm_parser.add_argument(
|
| 41 |
+
"--translate-corporate-env",
|
| 42 |
+
action="store_true",
|
| 43 |
+
help="Rewrite hero-facing observations into a corporate app metaphor and map translated commands back through Gemini.",
|
| 44 |
+
)
|
| 45 |
+
dm_parser.add_argument("--hero-max-game-steps", type=int, default=40)
|
| 46 |
+
dm_parser.add_argument("--hero-max-tool-calls", type=int, default=80)
|
| 47 |
+
|
| 48 |
+
hero_parser = subparsers.add_parser("hero-grpo", help="Run GRPO for the hero tool-calling policy.")
|
| 49 |
+
_add_common_args(hero_parser, default_model=DEFAULT_HF_HERO_MODEL, default_output_dir="artifacts/grpo/hero")
|
| 50 |
+
hero_parser.add_argument("--world", type=Path)
|
| 51 |
+
hero_parser.add_argument("--artifacts-root", type=Path)
|
| 52 |
+
hero_parser.add_argument("--max-game-steps", type=int, default=40)
|
| 53 |
+
hero_parser.add_argument("--max-tool-calls", type=int, default=80)
|
| 54 |
+
hero_parser.add_argument("--max-tool-calling-iterations", type=int, default=32)
|
| 55 |
+
hero_parser.add_argument("--interface-provider", choices=["strict", "simple", "gemini"])
|
| 56 |
+
hero_parser.add_argument("--interface-model")
|
| 57 |
+
hero_parser.add_argument("--interface-narrate", action="store_true")
|
| 58 |
+
hero_parser.add_argument(
|
| 59 |
+
"--translate-corporate-env",
|
| 60 |
+
action="store_true",
|
| 61 |
+
help="Rewrite hero-facing observations into a corporate app metaphor and map translated commands back through Gemini.",
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
joint_parser = subparsers.add_parser("joint-loop", help="Alternate hero and DM GRPO phases with adapter carry-over.")
|
| 65 |
+
joint_parser.add_argument("--root-dir", type=Path, required=True)
|
| 66 |
+
joint_parser.add_argument("--cycles", type=int, default=1)
|
| 67 |
+
joint_parser.add_argument("--target-ratio", type=float, action="append")
|
| 68 |
+
joint_parser.add_argument("--hero-world", type=Path)
|
| 69 |
+
joint_parser.add_argument("--interface-provider", choices=["strict", "simple", "gemini"])
|
| 70 |
+
joint_parser.add_argument("--interface-model")
|
| 71 |
+
joint_parser.add_argument("--interface-narrate", action="store_true")
|
| 72 |
+
joint_parser.add_argument(
|
| 73 |
+
"--translate-corporate-env",
|
| 74 |
+
action="store_true",
|
| 75 |
+
help="Rewrite hero-facing observations into a corporate app metaphor and map translated commands back through Gemini.",
|
| 76 |
+
)
|
| 77 |
+
joint_parser.add_argument("--hero-max-game-steps", type=int, default=40)
|
| 78 |
+
joint_parser.add_argument("--hero-max-tool-calls", type=int, default=80)
|
| 79 |
+
joint_parser.add_argument("--hero-max-tool-calling-iterations", type=int, default=32)
|
| 80 |
+
_add_prefixed_common_args(
|
| 81 |
+
joint_parser,
|
| 82 |
+
prefix="hero",
|
| 83 |
+
default_model=DEFAULT_HF_HERO_MODEL,
|
| 84 |
+
default_max_steps=24,
|
| 85 |
+
default_num_prompts=16,
|
| 86 |
+
default_max_completion_length=512,
|
| 87 |
+
)
|
| 88 |
+
_add_prefixed_common_args(
|
| 89 |
+
joint_parser,
|
| 90 |
+
prefix="dm",
|
| 91 |
+
default_model=DEFAULT_HF_DM_MODEL,
|
| 92 |
+
default_max_steps=8,
|
| 93 |
+
default_num_prompts=16,
|
| 94 |
+
default_max_completion_length=2048,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
dataset_parser = subparsers.add_parser("smoke-dataset", help="Print smoke dataset rows for inspection.")
|
| 98 |
+
dataset_parser.add_argument("role", choices=["dm", "hero"])
|
| 99 |
+
dataset_parser.add_argument("--num-prompts", type=int, default=2)
|
| 100 |
+
dataset_parser.add_argument("--target-ratio", type=float, action="append")
|
| 101 |
+
dataset_parser.add_argument("--world", type=Path)
|
| 102 |
+
dataset_parser.add_argument("--max-game-steps", type=int, default=40)
|
| 103 |
+
dataset_parser.add_argument("--max-tool-calls", type=int, default=80)
|
| 104 |
+
|
| 105 |
+
args = parser.parse_args(argv)
|
| 106 |
+
|
| 107 |
+
if args.command == "smoke-dataset":
|
| 108 |
+
if args.role == "dm":
|
| 109 |
+
rows = build_dm_grpo_dataset(num_prompts=args.num_prompts, target_ratios=args.target_ratio)
|
| 110 |
+
else:
|
| 111 |
+
world_input = None if args.world is None else json.loads(args.world.read_text(encoding="utf-8"))
|
| 112 |
+
rows = build_hero_grpo_dataset(
|
| 113 |
+
num_prompts=args.num_prompts,
|
| 114 |
+
world_input=world_input,
|
| 115 |
+
max_game_steps=args.max_game_steps,
|
| 116 |
+
max_tool_calls=args.max_tool_calls,
|
| 117 |
+
)
|
| 118 |
+
print(json.dumps(rows, indent=2))
|
| 119 |
+
return 0
|
| 120 |
+
|
| 121 |
+
if args.command == "joint-loop":
|
| 122 |
+
hero_config = _build_prefixed_grpo_config(args, "hero", default_output_dir=args.root_dir / "hero")
|
| 123 |
+
dm_config = _build_prefixed_grpo_config(args, "dm", default_output_dir=args.root_dir / "dm")
|
| 124 |
+
run_joint_training_loop(
|
| 125 |
+
JointTrainingConfig(
|
| 126 |
+
root_dir=args.root_dir,
|
| 127 |
+
cycles=args.cycles,
|
| 128 |
+
hero_config=hero_config,
|
| 129 |
+
dm_config=dm_config,
|
| 130 |
+
target_ratios=args.target_ratio,
|
| 131 |
+
hero_world_path=args.hero_world,
|
| 132 |
+
interface_provider=args.interface_provider,
|
| 133 |
+
interface_model=args.interface_model,
|
| 134 |
+
interface_narrate=args.interface_narrate,
|
| 135 |
+
interface_translation_mode="corporate_app" if args.translate_corporate_env else None,
|
| 136 |
+
hero_max_game_steps=args.hero_max_game_steps,
|
| 137 |
+
hero_max_tool_calls=args.hero_max_tool_calls,
|
| 138 |
+
hero_max_tool_calling_iterations=args.hero_max_tool_calling_iterations,
|
| 139 |
+
)
|
| 140 |
+
)
|
| 141 |
+
return 0
|
| 142 |
+
|
| 143 |
+
config = GRPOLaunchConfig(
|
| 144 |
+
model_name=args.model,
|
| 145 |
+
output_dir=args.output_dir,
|
| 146 |
+
resume_adapter_path=args.resume_adapter_path,
|
| 147 |
+
max_steps=args.max_steps,
|
| 148 |
+
num_prompts=args.num_prompts,
|
| 149 |
+
learning_rate=args.learning_rate,
|
| 150 |
+
per_device_train_batch_size=args.per_device_train_batch_size,
|
| 151 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 152 |
+
num_generations=args.num_generations,
|
| 153 |
+
max_completion_length=args.max_completion_length,
|
| 154 |
+
logging_steps=args.logging_steps,
|
| 155 |
+
save_steps=args.save_steps,
|
| 156 |
+
seed=args.seed,
|
| 157 |
+
rank=args.rank,
|
| 158 |
+
alpha=args.alpha,
|
| 159 |
+
dropout=args.dropout,
|
| 160 |
+
temperature=args.temperature,
|
| 161 |
+
top_p=args.top_p,
|
| 162 |
+
top_k=args.top_k,
|
| 163 |
+
min_p=args.min_p,
|
| 164 |
+
repetition_penalty=args.repetition_penalty,
|
| 165 |
+
use_wandb=not args.no_wandb,
|
| 166 |
+
run_name=args.run_name,
|
| 167 |
+
trust_remote_code=args.trust_remote_code,
|
| 168 |
+
load_in_4bit=not args.no_4bit,
|
| 169 |
+
loss_type=args.loss_type,
|
| 170 |
+
importance_sampling_level=args.importance_sampling_level,
|
| 171 |
+
use_transformers_paged=args.use_transformers_paged,
|
| 172 |
+
cache_implementation=args.cache_implementation,
|
| 173 |
+
use_vllm=args.use_vllm,
|
| 174 |
+
vllm_mode=args.vllm_mode,
|
| 175 |
+
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
|
| 176 |
+
vllm_enable_sleep_mode=not args.no_vllm_sleep_mode,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
if args.command == "dm-grpo":
|
| 180 |
+
run_dm_grpo(
|
| 181 |
+
config,
|
| 182 |
+
target_ratios=args.target_ratio,
|
| 183 |
+
artifacts_root=args.artifacts_root,
|
| 184 |
+
closed_loop=DMClosedLoopConfig(
|
| 185 |
+
hero_provider=args.hero_provider,
|
| 186 |
+
hero_model=args.hero_model,
|
| 187 |
+
hero_adapter_path=args.hero_adapter_path,
|
| 188 |
+
interface_provider=args.interface_provider,
|
| 189 |
+
interface_model=args.interface_model,
|
| 190 |
+
interface_narrate=args.interface_narrate,
|
| 191 |
+
interface_translation_mode="corporate_app" if args.translate_corporate_env else None,
|
| 192 |
+
hero_max_game_steps=args.hero_max_game_steps,
|
| 193 |
+
hero_max_tool_calls=args.hero_max_tool_calls,
|
| 194 |
+
),
|
| 195 |
+
)
|
| 196 |
+
return 0
|
| 197 |
+
|
| 198 |
+
run_hero_grpo(
|
| 199 |
+
config,
|
| 200 |
+
world_path=args.world,
|
| 201 |
+
artifacts_root=args.artifacts_root,
|
| 202 |
+
interface_provider=args.interface_provider,
|
| 203 |
+
interface_model=args.interface_model,
|
| 204 |
+
interface_narrate=args.interface_narrate,
|
| 205 |
+
interface_translation_mode="corporate_app" if args.translate_corporate_env else None,
|
| 206 |
+
max_game_steps=args.max_game_steps,
|
| 207 |
+
max_tool_calls=args.max_tool_calls,
|
| 208 |
+
max_tool_calling_iterations=args.max_tool_calling_iterations,
|
| 209 |
+
)
|
| 210 |
+
return 0
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _load_repo_dotenv() -> None:
|
| 214 |
+
load_dotenv(Path(__file__).resolve().parents[2] / ".env", override=False)
|
| 215 |
+
_normalize_wandb_env()
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def _normalize_wandb_env() -> None:
|
| 219 |
+
project = os.getenv("WANDB_PROJECT")
|
| 220 |
+
entity = os.getenv("WANDB_ENTITY")
|
| 221 |
+
if entity or not project or "/" not in project:
|
| 222 |
+
return
|
| 223 |
+
|
| 224 |
+
maybe_entity, maybe_project = project.split("/", 1)
|
| 225 |
+
if not maybe_entity or not maybe_project:
|
| 226 |
+
return
|
| 227 |
+
|
| 228 |
+
os.environ["WANDB_ENTITY"] = maybe_entity
|
| 229 |
+
os.environ["WANDB_PROJECT"] = maybe_project
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _add_common_args(parser: argparse.ArgumentParser, *, default_model: str, default_output_dir: str) -> None:
|
| 233 |
+
parser.add_argument("--model", default=default_model)
|
| 234 |
+
parser.add_argument("--output-dir", type=Path, default=Path(default_output_dir))
|
| 235 |
+
parser.add_argument("--resume-adapter-path")
|
| 236 |
+
parser.add_argument("--run-name")
|
| 237 |
+
parser.add_argument("--max-steps", type=int, default=10)
|
| 238 |
+
parser.add_argument("--num-prompts", type=int, default=16)
|
| 239 |
+
parser.add_argument("--learning-rate", type=float, default=1e-5)
|
| 240 |
+
parser.add_argument("--per-device-train-batch-size", type=int, default=2)
|
| 241 |
+
parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
|
| 242 |
+
parser.add_argument("--num-generations", type=int, default=2)
|
| 243 |
+
parser.add_argument("--max-completion-length", type=int, default=512)
|
| 244 |
+
parser.add_argument("--logging-steps", type=int, default=1)
|
| 245 |
+
parser.add_argument("--save-steps", type=int, default=10)
|
| 246 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 247 |
+
parser.add_argument("--rank", type=int, default=16)
|
| 248 |
+
parser.add_argument("--alpha", type=int, default=32)
|
| 249 |
+
parser.add_argument("--dropout", type=float, default=0.05)
|
| 250 |
+
parser.add_argument("--temperature", type=float, default=0.6)
|
| 251 |
+
parser.add_argument("--top-p", type=float, default=0.95)
|
| 252 |
+
parser.add_argument("--top-k", type=int, default=20)
|
| 253 |
+
parser.add_argument("--min-p", type=float)
|
| 254 |
+
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
| 255 |
+
parser.add_argument("--loss-type", choices=SUPPORTED_GRPO_LOSS_TYPES, default="dapo")
|
| 256 |
+
parser.add_argument(
|
| 257 |
+
"--importance-sampling-level",
|
| 258 |
+
choices=SUPPORTED_IMPORTANCE_SAMPLING_LEVELS,
|
| 259 |
+
default="token",
|
| 260 |
+
)
|
| 261 |
+
parser.add_argument("--use-transformers-paged", action="store_true")
|
| 262 |
+
parser.add_argument("--cache-implementation")
|
| 263 |
+
parser.add_argument("--use-vllm", action="store_true")
|
| 264 |
+
parser.add_argument("--vllm-mode", choices=["server", "colocate"], default="colocate")
|
| 265 |
+
parser.add_argument("--vllm-gpu-memory-utilization", type=float, default=0.2)
|
| 266 |
+
parser.add_argument("--no-vllm-sleep-mode", action="store_true")
|
| 267 |
+
parser.add_argument("--trust-remote-code", action="store_true")
|
| 268 |
+
parser.add_argument("--no-4bit", action="store_true")
|
| 269 |
+
parser.add_argument("--no-wandb", action="store_true")
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def _add_prefixed_common_args(
|
| 273 |
+
parser: argparse.ArgumentParser,
|
| 274 |
+
*,
|
| 275 |
+
prefix: str,
|
| 276 |
+
default_model: str,
|
| 277 |
+
default_max_steps: int,
|
| 278 |
+
default_num_prompts: int,
|
| 279 |
+
default_max_completion_length: int,
|
| 280 |
+
) -> None:
|
| 281 |
+
parser.add_argument(f"--{prefix}-model", default=default_model)
|
| 282 |
+
parser.add_argument(f"--{prefix}-resume-adapter-path")
|
| 283 |
+
parser.add_argument(f"--{prefix}-run-name")
|
| 284 |
+
parser.add_argument(f"--{prefix}-max-steps", type=int, default=default_max_steps)
|
| 285 |
+
parser.add_argument(f"--{prefix}-num-prompts", type=int, default=default_num_prompts)
|
| 286 |
+
parser.add_argument(f"--{prefix}-learning-rate", type=float, default=1e-5)
|
| 287 |
+
parser.add_argument(f"--{prefix}-per-device-train-batch-size", type=int, default=2)
|
| 288 |
+
parser.add_argument(f"--{prefix}-gradient-accumulation-steps", type=int, default=8)
|
| 289 |
+
parser.add_argument(f"--{prefix}-num-generations", type=int, default=2)
|
| 290 |
+
parser.add_argument(f"--{prefix}-max-completion-length", type=int, default=default_max_completion_length)
|
| 291 |
+
parser.add_argument(f"--{prefix}-logging-steps", type=int, default=1)
|
| 292 |
+
parser.add_argument(f"--{prefix}-save-steps", type=int, default=4)
|
| 293 |
+
parser.add_argument(f"--{prefix}-seed", type=int, default=42)
|
| 294 |
+
parser.add_argument(f"--{prefix}-rank", type=int, default=16)
|
| 295 |
+
parser.add_argument(f"--{prefix}-alpha", type=int, default=32)
|
| 296 |
+
parser.add_argument(f"--{prefix}-dropout", type=float, default=0.05)
|
| 297 |
+
parser.add_argument(f"--{prefix}-temperature", type=float, default=0.6)
|
| 298 |
+
parser.add_argument(f"--{prefix}-top-p", type=float, default=0.95)
|
| 299 |
+
parser.add_argument(f"--{prefix}-top-k", type=int, default=20)
|
| 300 |
+
parser.add_argument(f"--{prefix}-min-p", type=float)
|
| 301 |
+
parser.add_argument(f"--{prefix}-repetition-penalty", type=float, default=1.0)
|
| 302 |
+
parser.add_argument(f"--{prefix}-loss-type", choices=SUPPORTED_GRPO_LOSS_TYPES, default="dapo")
|
| 303 |
+
parser.add_argument(
|
| 304 |
+
f"--{prefix}-importance-sampling-level",
|
| 305 |
+
choices=SUPPORTED_IMPORTANCE_SAMPLING_LEVELS,
|
| 306 |
+
default="token",
|
| 307 |
+
)
|
| 308 |
+
parser.add_argument(f"--{prefix}-use-transformers-paged", action="store_true")
|
| 309 |
+
parser.add_argument(f"--{prefix}-cache-implementation")
|
| 310 |
+
parser.add_argument(f"--{prefix}-use-vllm", action="store_true")
|
| 311 |
+
parser.add_argument(f"--{prefix}-vllm-mode", choices=["server", "colocate"], default="colocate")
|
| 312 |
+
parser.add_argument(f"--{prefix}-vllm-gpu-memory-utilization", type=float, default=0.2)
|
| 313 |
+
parser.add_argument(f"--{prefix}-no-vllm-sleep-mode", action="store_true")
|
| 314 |
+
parser.add_argument(f"--{prefix}-trust-remote-code", action="store_true")
|
| 315 |
+
parser.add_argument(f"--{prefix}-no-4bit", action="store_true")
|
| 316 |
+
parser.add_argument(f"--{prefix}-no-wandb", action="store_true")
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def _build_prefixed_grpo_config(args: argparse.Namespace, prefix: str, *, default_output_dir: Path) -> GRPOLaunchConfig:
|
| 320 |
+
def value(name: str):
|
| 321 |
+
return getattr(args, f"{prefix}_{name}")
|
| 322 |
+
|
| 323 |
+
return GRPOLaunchConfig(
|
| 324 |
+
model_name=value("model"),
|
| 325 |
+
output_dir=default_output_dir,
|
| 326 |
+
resume_adapter_path=value("resume_adapter_path"),
|
| 327 |
+
max_steps=value("max_steps"),
|
| 328 |
+
num_prompts=value("num_prompts"),
|
| 329 |
+
learning_rate=value("learning_rate"),
|
| 330 |
+
per_device_train_batch_size=value("per_device_train_batch_size"),
|
| 331 |
+
gradient_accumulation_steps=value("gradient_accumulation_steps"),
|
| 332 |
+
num_generations=value("num_generations"),
|
| 333 |
+
max_completion_length=value("max_completion_length"),
|
| 334 |
+
logging_steps=value("logging_steps"),
|
| 335 |
+
save_steps=value("save_steps"),
|
| 336 |
+
seed=value("seed"),
|
| 337 |
+
rank=value("rank"),
|
| 338 |
+
alpha=value("alpha"),
|
| 339 |
+
dropout=value("dropout"),
|
| 340 |
+
temperature=value("temperature"),
|
| 341 |
+
top_p=value("top_p"),
|
| 342 |
+
top_k=value("top_k"),
|
| 343 |
+
min_p=value("min_p"),
|
| 344 |
+
repetition_penalty=value("repetition_penalty"),
|
| 345 |
+
use_wandb=not value("no_wandb"),
|
| 346 |
+
run_name=value("run_name"),
|
| 347 |
+
trust_remote_code=value("trust_remote_code"),
|
| 348 |
+
load_in_4bit=not value("no_4bit"),
|
| 349 |
+
loss_type=value("loss_type"),
|
| 350 |
+
importance_sampling_level=value("importance_sampling_level"),
|
| 351 |
+
use_transformers_paged=value("use_transformers_paged"),
|
| 352 |
+
cache_implementation=value("cache_implementation"),
|
| 353 |
+
use_vllm=value("use_vllm"),
|
| 354 |
+
vllm_mode=value("vllm_mode"),
|
| 355 |
+
vllm_gpu_memory_utilization=value("vllm_gpu_memory_utilization"),
|
| 356 |
+
vllm_enable_sleep_mode=not value("no_vllm_sleep_mode"),
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
if __name__ == "__main__":
|
| 361 |
+
raise SystemExit(main())
|
agents/train/grpo.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
agents/train/joint.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
from dataclasses import asdict, dataclass, replace
|
| 7 |
+
from datetime import UTC, datetime
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Callable, Iterator
|
| 10 |
+
|
| 11 |
+
from .grpo import DMClosedLoopConfig, GRPOLaunchConfig, run_dm_grpo, run_hero_grpo
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
|
| 15 |
+
class JointTrainingConfig:
|
| 16 |
+
root_dir: Path
|
| 17 |
+
cycles: int
|
| 18 |
+
hero_config: GRPOLaunchConfig
|
| 19 |
+
dm_config: GRPOLaunchConfig
|
| 20 |
+
target_ratios: list[float] | None = None
|
| 21 |
+
hero_world_path: Path | None = None
|
| 22 |
+
interface_provider: str | None = None
|
| 23 |
+
interface_model: str | None = None
|
| 24 |
+
interface_narrate: bool = False
|
| 25 |
+
interface_translation_mode: str | None = None
|
| 26 |
+
hero_max_game_steps: int = 40
|
| 27 |
+
hero_max_tool_calls: int = 80
|
| 28 |
+
hero_max_tool_calling_iterations: int = 32
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def run_joint_training_loop(config: JointTrainingConfig) -> Path:
|
| 32 |
+
if config.cycles < 1:
|
| 33 |
+
raise ValueError("cycles must be at least 1.")
|
| 34 |
+
|
| 35 |
+
config.root_dir.mkdir(parents=True, exist_ok=True)
|
| 36 |
+
latest_hero_adapter = _initial_adapter_path(config.hero_config.resume_adapter_path)
|
| 37 |
+
latest_dm_adapter = _initial_adapter_path(config.dm_config.resume_adapter_path)
|
| 38 |
+
phases: list[dict[str, Any]] = []
|
| 39 |
+
_write_manifest(config, phases, status="running")
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
for cycle_index in range(config.cycles):
|
| 43 |
+
cycle_number = cycle_index + 1
|
| 44 |
+
cycle_dir = config.root_dir / f"cycle_{cycle_number:02d}"
|
| 45 |
+
hero_dir = cycle_dir / "hero"
|
| 46 |
+
dm_dir = cycle_dir / "dm"
|
| 47 |
+
|
| 48 |
+
hero_result = _run_or_resume_hero_phase(
|
| 49 |
+
config=config,
|
| 50 |
+
cycle_number=cycle_number,
|
| 51 |
+
output_dir=hero_dir,
|
| 52 |
+
resume_adapter_path=latest_hero_adapter,
|
| 53 |
+
phases=phases,
|
| 54 |
+
on_phase_state_change=lambda: _write_manifest(config, phases, status="running"),
|
| 55 |
+
)
|
| 56 |
+
latest_hero_adapter = hero_result
|
| 57 |
+
_write_manifest(config, phases, status="running")
|
| 58 |
+
|
| 59 |
+
dm_result = _run_or_resume_dm_phase(
|
| 60 |
+
config=config,
|
| 61 |
+
cycle_number=cycle_number,
|
| 62 |
+
output_dir=dm_dir,
|
| 63 |
+
resume_adapter_path=latest_dm_adapter,
|
| 64 |
+
hero_adapter_path=latest_hero_adapter,
|
| 65 |
+
phases=phases,
|
| 66 |
+
on_phase_state_change=lambda: _write_manifest(config, phases, status="running"),
|
| 67 |
+
)
|
| 68 |
+
latest_dm_adapter = dm_result
|
| 69 |
+
_write_manifest(config, phases, status="running")
|
| 70 |
+
except Exception as exc:
|
| 71 |
+
_write_manifest(config, phases, status="failed", error=str(exc))
|
| 72 |
+
raise
|
| 73 |
+
|
| 74 |
+
_write_manifest(
|
| 75 |
+
config,
|
| 76 |
+
phases,
|
| 77 |
+
status="completed",
|
| 78 |
+
latest_hero_adapter_path=str(latest_hero_adapter) if latest_hero_adapter is not None else None,
|
| 79 |
+
latest_dm_adapter_path=str(latest_dm_adapter) if latest_dm_adapter is not None else None,
|
| 80 |
+
)
|
| 81 |
+
return config.root_dir
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _run_or_resume_hero_phase(
|
| 85 |
+
*,
|
| 86 |
+
config: JointTrainingConfig,
|
| 87 |
+
cycle_number: int,
|
| 88 |
+
output_dir: Path,
|
| 89 |
+
resume_adapter_path: Path | None,
|
| 90 |
+
phases: list[dict[str, Any]],
|
| 91 |
+
on_phase_state_change: Callable[[], None] | None = None,
|
| 92 |
+
) -> Path:
|
| 93 |
+
state_path = output_dir / "phase_state.json"
|
| 94 |
+
existing_state = _load_phase_state(state_path)
|
| 95 |
+
if existing_state is not None and existing_state.get("status") == "completed":
|
| 96 |
+
phases.append(existing_state)
|
| 97 |
+
return output_dir
|
| 98 |
+
|
| 99 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 100 |
+
run_name = config.hero_config.run_name or f"{config.root_dir.name}-hero-cycle-{cycle_number:02d}"
|
| 101 |
+
phase_state = {
|
| 102 |
+
"phase": "hero",
|
| 103 |
+
"cycle": cycle_number,
|
| 104 |
+
"status": "running",
|
| 105 |
+
"run_name": run_name,
|
| 106 |
+
"output_dir": str(output_dir),
|
| 107 |
+
"resume_adapter_path": None if resume_adapter_path is None else str(resume_adapter_path),
|
| 108 |
+
"started_at": _utc_now(),
|
| 109 |
+
}
|
| 110 |
+
phases.append(phase_state)
|
| 111 |
+
_write_json(state_path, phase_state)
|
| 112 |
+
if on_phase_state_change is not None:
|
| 113 |
+
on_phase_state_change()
|
| 114 |
+
|
| 115 |
+
phase_config = replace(
|
| 116 |
+
config.hero_config,
|
| 117 |
+
output_dir=output_dir,
|
| 118 |
+
run_name=run_name,
|
| 119 |
+
resume_adapter_path=None if resume_adapter_path is None else str(resume_adapter_path),
|
| 120 |
+
)
|
| 121 |
+
with _wandb_phase_env(group=config.root_dir.name, job_type="hero"):
|
| 122 |
+
run_hero_grpo(
|
| 123 |
+
phase_config,
|
| 124 |
+
world_path=config.hero_world_path,
|
| 125 |
+
artifacts_root=output_dir / "artifacts",
|
| 126 |
+
interface_provider=config.interface_provider,
|
| 127 |
+
interface_model=config.interface_model,
|
| 128 |
+
interface_narrate=config.interface_narrate,
|
| 129 |
+
interface_translation_mode=config.interface_translation_mode,
|
| 130 |
+
max_game_steps=config.hero_max_game_steps,
|
| 131 |
+
max_tool_calls=config.hero_max_tool_calls,
|
| 132 |
+
max_tool_calling_iterations=config.hero_max_tool_calling_iterations,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
phase_state["status"] = "completed"
|
| 136 |
+
phase_state["completed_at"] = _utc_now()
|
| 137 |
+
_write_json(state_path, phase_state)
|
| 138 |
+
return output_dir
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _run_or_resume_dm_phase(
|
| 142 |
+
*,
|
| 143 |
+
config: JointTrainingConfig,
|
| 144 |
+
cycle_number: int,
|
| 145 |
+
output_dir: Path,
|
| 146 |
+
resume_adapter_path: Path | None,
|
| 147 |
+
hero_adapter_path: Path | None,
|
| 148 |
+
phases: list[dict[str, Any]],
|
| 149 |
+
on_phase_state_change: Callable[[], None] | None = None,
|
| 150 |
+
) -> Path:
|
| 151 |
+
if hero_adapter_path is None:
|
| 152 |
+
raise RuntimeError("DM phase requires a hero adapter path from a completed hero phase.")
|
| 153 |
+
|
| 154 |
+
state_path = output_dir / "phase_state.json"
|
| 155 |
+
existing_state = _load_phase_state(state_path)
|
| 156 |
+
if existing_state is not None and existing_state.get("status") == "completed":
|
| 157 |
+
phases.append(existing_state)
|
| 158 |
+
return output_dir
|
| 159 |
+
|
| 160 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 161 |
+
run_name = config.dm_config.run_name or f"{config.root_dir.name}-dm-cycle-{cycle_number:02d}"
|
| 162 |
+
phase_state = {
|
| 163 |
+
"phase": "dm",
|
| 164 |
+
"cycle": cycle_number,
|
| 165 |
+
"status": "running",
|
| 166 |
+
"run_name": run_name,
|
| 167 |
+
"output_dir": str(output_dir),
|
| 168 |
+
"resume_adapter_path": None if resume_adapter_path is None else str(resume_adapter_path),
|
| 169 |
+
"hero_adapter_path": str(hero_adapter_path),
|
| 170 |
+
"started_at": _utc_now(),
|
| 171 |
+
}
|
| 172 |
+
phases.append(phase_state)
|
| 173 |
+
_write_json(state_path, phase_state)
|
| 174 |
+
if on_phase_state_change is not None:
|
| 175 |
+
on_phase_state_change()
|
| 176 |
+
|
| 177 |
+
phase_config = replace(
|
| 178 |
+
config.dm_config,
|
| 179 |
+
output_dir=output_dir,
|
| 180 |
+
run_name=run_name,
|
| 181 |
+
resume_adapter_path=None if resume_adapter_path is None else str(resume_adapter_path),
|
| 182 |
+
)
|
| 183 |
+
closed_loop = DMClosedLoopConfig(
|
| 184 |
+
hero_provider="hf_local",
|
| 185 |
+
hero_model=config.hero_config.model_name,
|
| 186 |
+
hero_adapter_path=str(hero_adapter_path),
|
| 187 |
+
interface_provider=config.interface_provider,
|
| 188 |
+
interface_model=config.interface_model,
|
| 189 |
+
interface_narrate=config.interface_narrate,
|
| 190 |
+
interface_translation_mode=config.interface_translation_mode,
|
| 191 |
+
hero_max_game_steps=config.hero_max_game_steps,
|
| 192 |
+
hero_max_tool_calls=config.hero_max_tool_calls,
|
| 193 |
+
)
|
| 194 |
+
with _wandb_phase_env(group=config.root_dir.name, job_type="dm"):
|
| 195 |
+
run_dm_grpo(
|
| 196 |
+
phase_config,
|
| 197 |
+
target_ratios=config.target_ratios,
|
| 198 |
+
artifacts_root=output_dir / "artifacts",
|
| 199 |
+
closed_loop=closed_loop,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
phase_state["status"] = "completed"
|
| 203 |
+
phase_state["completed_at"] = _utc_now()
|
| 204 |
+
_write_json(state_path, phase_state)
|
| 205 |
+
return output_dir
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _write_manifest(
|
| 209 |
+
config: JointTrainingConfig,
|
| 210 |
+
phases: list[dict[str, Any]],
|
| 211 |
+
*,
|
| 212 |
+
status: str,
|
| 213 |
+
error: str | None = None,
|
| 214 |
+
latest_hero_adapter_path: str | None = None,
|
| 215 |
+
latest_dm_adapter_path: str | None = None,
|
| 216 |
+
) -> None:
|
| 217 |
+
payload = {
|
| 218 |
+
"status": status,
|
| 219 |
+
"updated_at": _utc_now(),
|
| 220 |
+
"error": error,
|
| 221 |
+
"latest_hero_adapter_path": latest_hero_adapter_path,
|
| 222 |
+
"latest_dm_adapter_path": latest_dm_adapter_path,
|
| 223 |
+
"config": _to_jsonable(asdict(config)),
|
| 224 |
+
"phases": phases,
|
| 225 |
+
}
|
| 226 |
+
_write_json(config.root_dir / "joint_state.json", payload)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@contextmanager
|
| 230 |
+
def _wandb_phase_env(*, group: str, job_type: str) -> Iterator[None]:
|
| 231 |
+
previous_group = os.getenv("WANDB_RUN_GROUP")
|
| 232 |
+
previous_job_type = os.getenv("WANDB_JOB_TYPE")
|
| 233 |
+
os.environ["WANDB_RUN_GROUP"] = group
|
| 234 |
+
os.environ["WANDB_JOB_TYPE"] = job_type
|
| 235 |
+
try:
|
| 236 |
+
yield
|
| 237 |
+
finally:
|
| 238 |
+
_restore_env("WANDB_RUN_GROUP", previous_group)
|
| 239 |
+
_restore_env("WANDB_JOB_TYPE", previous_job_type)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def _restore_env(name: str, value: str | None) -> None:
|
| 243 |
+
if value is None:
|
| 244 |
+
os.environ.pop(name, None)
|
| 245 |
+
else:
|
| 246 |
+
os.environ[name] = value
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _load_phase_state(path: Path) -> dict[str, Any] | None:
|
| 250 |
+
if not path.exists():
|
| 251 |
+
return None
|
| 252 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _write_json(path: Path, payload: dict[str, Any]) -> None:
|
| 256 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 257 |
+
path.write_text(json.dumps(_to_jsonable(payload), indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _to_jsonable(value: Any) -> Any:
|
| 261 |
+
if isinstance(value, Path):
|
| 262 |
+
return str(value)
|
| 263 |
+
if isinstance(value, dict):
|
| 264 |
+
return {str(key): _to_jsonable(item) for key, item in value.items()}
|
| 265 |
+
if isinstance(value, list):
|
| 266 |
+
return [_to_jsonable(item) for item in value]
|
| 267 |
+
return value
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _initial_adapter_path(raw_path: str | None) -> Path | None:
|
| 271 |
+
if raw_path is None:
|
| 272 |
+
return None
|
| 273 |
+
path = Path(raw_path)
|
| 274 |
+
return path if path.exists() else None
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _utc_now() -> str:
|
| 278 |
+
return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
pyproject.toml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=69", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "dnd-agents"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Dungeon master and hero agent environments built on TextWorld and OpenEnv."
|
| 9 |
+
readme = "SPEC.md"
|
| 10 |
+
requires-python = ">=3.11,<3.12"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"openenv-core==0.2.1",
|
| 13 |
+
"textworld==1.7.0",
|
| 14 |
+
"fastapi>=0.115,<1",
|
| 15 |
+
"uvicorn>=0.30,<1",
|
| 16 |
+
"pydantic>=2.12,<3",
|
| 17 |
+
"python-dotenv>=1.0,<2",
|
| 18 |
+
"python-multipart>=0.0.9,<1",
|
| 19 |
+
"google-genai>=1.0,<2",
|
| 20 |
+
"huggingface-hub>=1.6,<2",
|
| 21 |
+
"pytest>=8.0,<9",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[project.scripts]
|
| 25 |
+
dnd-master = "agents.master.main:main"
|
| 26 |
+
dnd-hero = "agents.hero.__main__:main"
|
| 27 |
+
dnd-loop = "agents.loop.__main__:main"
|
| 28 |
+
dnd-train = "agents.train.__main__:main"
|
| 29 |
+
dnd-openenv = "agents.openenv_server.__main__:main"
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
local-llm = [
|
| 33 |
+
"accelerate==1.13.0",
|
| 34 |
+
"bitsandbytes==0.49.2",
|
| 35 |
+
"huggingface-hub>=1.6,<2",
|
| 36 |
+
"peft==0.18.1",
|
| 37 |
+
"transformers==5.3.0",
|
| 38 |
+
"vllm==0.12.0; platform_system == 'Linux'",
|
| 39 |
+
]
|
| 40 |
+
train = [
|
| 41 |
+
"accelerate==1.13.0",
|
| 42 |
+
"bitsandbytes==0.49.2",
|
| 43 |
+
"datasets==4.6.1",
|
| 44 |
+
"huggingface-hub>=1.6,<2",
|
| 45 |
+
"jmespath>=1.0,<2",
|
| 46 |
+
"peft==0.18.1",
|
| 47 |
+
"transformers==5.3.0",
|
| 48 |
+
"trl==0.29.0",
|
| 49 |
+
"vllm==0.12.0; platform_system == 'Linux'",
|
| 50 |
+
"wandb==0.25.0",
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
[tool.setuptools.packages.find]
|
| 54 |
+
include = ["agents*"]
|
| 55 |
+
|
| 56 |
+
[tool.pytest.ini_options]
|
| 57 |
+
testpaths = ["tests"]
|
| 58 |
+
markers = [
|
| 59 |
+
"live: tests that call live external model APIs",
|
| 60 |
+
]
|
| 61 |
+
filterwarnings = [
|
| 62 |
+
"ignore:Game '.*' is not fully supported\\..*",
|
| 63 |
+
]
|