text-adventure-template / mcp_server.py
j0eyd's picture
Refine general policy, exploration memory, and README approach
aea1d92
import concurrent.futures
import hashlib
import json
import os
import re
import sys
from collections import defaultdict, deque
from pathlib import Path
# Make sure we can import games.zork_env across local/eval layouts.
# The roots below are probed in order; the first one that contains a
# ``games`` directory is put at the front of sys.path and the search stops.
_THIS = Path(__file__).resolve()
_CANDIDATE_ROOTS = (
    _THIS.parent,
    _THIS.parent.parent,
    Path.cwd(),
    Path("/workspace"),
    Path("/home/joey.david/Agentic-zork"),
)
_games_root = next(
    (root for root in _CANDIDATE_ROOTS if (root / "games").is_dir()),
    None,
)
if _games_root is not None:
    sys.path.insert(0, str(_games_root))
from fastmcp import FastMCP
from games.zork_env import TextAdventureEnv
# FastMCP server instance: the tool functions below register themselves on it
# via @mcp.tool(), and the script entry point starts it with mcp.run().
mcp = FastMCP("Student Text Adventure Server")
# Commands treated as movement attempts. GameManager.step() only records a
# map edge, and only applies its failed-move handling, for actions found here.
# Diagonal directions and in/out are included so that those moves are tracked
# like the cardinal ones instead of being silently left out of the map.
_MOVE_ACTIONS = {
    # Cardinal and vertical directions.
    "north",
    "south",
    "east",
    "west",
    "up",
    "down",
    # Diagonal directions.
    "northeast",
    "northwest",
    "southeast",
    "southwest",
    # Non-compass movement verbs.
    "enter",
    "exit",
    "in",
    "out",
    # Abbreviations (kept even though single-letter forms are usually
    # expanded by action normalization before the membership test).
    "n",
    "s",
    "e",
    "w",
    "u",
    "d",
    "ne",
    "nw",
    "se",
    "sw",
}
# Lower-case substrings whose presence in an observation marks the action as
# refused or ineffective (matched against the lower-cased observation text).
_NEGATIVE_PATTERNS = (
    # Generic parser refusals.
    "you can't", "you cannot", "that's not", "i don't",
    # Nothing changed / target missing or in the wrong state.
    "nothing happens", "there is no", "not here", "not open", "already",
    # Movement refusals.
    "not see any place to go", "not know which way to go",
    # Game-specific refusal messages.
    "not allowed in field", "get in big trouble", "not in any thing",
)
class GameManager:
"""Tracks game state plus lightweight analytics for planning tools."""
def __init__(self) -> None:
self.env: TextAdventureEnv | None = None
self.state = None
self.game_name: str = ""
self.last_location: str = "Unknown"
self.last_location_id: int | None = None
self.history: deque[dict] = deque(maxlen=250)
self.action_history: deque[str] = deque(maxlen=320)
self.location_graph: dict[str, dict[str, str]] = defaultdict(dict)
self.location_actions: dict[str, dict[str, dict[str, int | str]]] = defaultdict(dict)
self.promising_actions: dict[str, set[str]] = defaultdict(set)
def initialize(self, game: str = "zork1") -> str:
self.game_name = game
self.env = TextAdventureEnv(game)
self.state = self.env.reset()
self.history.clear()
self.action_history.clear()
self.location_graph.clear()
self.location_actions.clear()
self.promising_actions.clear()
loc, loc_id = self._derive_location(
observation=self.state.observation if self.state else "",
previous_location="Unknown",
state_location=getattr(self.state, "location", ""),
)
self.last_location = loc
self.last_location_id = loc_id
return self._format_observation(
self.state.observation,
action="__reset__",
prev_location="Unknown",
changed_location=True,
progress="init",
)
def step(self, action: str) -> str:
if self.env is None:
self.initialize(os.environ.get("GAME", "zork1"))
normalized = self._normalize_action(action)
prev_location = self.last_location or self._current_location()
prev_location_id = self.last_location_id
prev_score = int(self.state.score if self.state else 0)
prev_obs = self.state.observation if self.state else ""
prev_sig = self._obs_signature(prev_obs)
self.state = self.env.step(normalized)
self.action_history.append(normalized)
current_location, current_location_id = self._derive_location(
observation=self.state.observation if self.state else "",
previous_location=prev_location,
state_location=getattr(self.state, "location", ""),
)
changed_location = False
if prev_location_id is not None and current_location_id is not None:
changed_location = (current_location_id != prev_location_id) or (current_location != prev_location)
else:
changed_location = current_location != prev_location
score_delta = int(self.state.score) - prev_score
obs_sig = self._obs_signature(self.state.observation)
unchanged_obs = obs_sig == prev_sig
negative = self._looks_negative(self.state.observation)
if normalized in _MOVE_ACTIONS and negative:
changed_location = False
elif normalized in _MOVE_ACTIONS and not negative and current_location == prev_location:
prev_anchor = self._observation_anchor(prev_obs)
curr_anchor = self._observation_anchor(self.state.observation)
if curr_anchor and curr_anchor != prev_anchor:
current_location = curr_anchor
changed_location = True
progress = "none"
if score_delta > 0:
progress = "score"
elif changed_location:
progress = "move"
elif not unchanged_obs and not negative:
progress = "state"
self._update_stats(
action=normalized,
prev_location=prev_location,
current_location=current_location,
changed_location=changed_location,
score_delta=score_delta,
progress=progress,
)
if normalized in _MOVE_ACTIONS and changed_location:
self.location_graph[prev_location][normalized] = current_location
self.last_location = current_location
self.last_location_id = current_location_id
self.history.append(
{
"action": normalized,
"from": prev_location,
"to": current_location,
"score": int(self.state.score),
"reward": int(self.state.reward),
"moves": int(self.state.moves),
"progress": progress,
"obs": self._short(self.state.observation, 160),
}
)
return self._format_observation(
self.state.observation,
action=normalized,
prev_location=prev_location,
changed_location=changed_location,
progress=progress,
)
def get_memory(self) -> str:
if not self.state:
return "Game not initialized."
location = self._current_location()
recent = list(self.history)[-8:]
loc_actions = self.location_actions.get(location, {})
action_lines = []
for action, stats in sorted(
loc_actions.items(),
key=lambda kv: (
int(kv[1].get("success", 0)),
int(kv[1].get("score_gain", 0)),
-int(kv[1].get("tries", 0)),
),
reverse=True,
)[:10]:
action_lines.append(
f"- {action}: tries={stats.get('tries', 0)}, success={stats.get('success', 0)}, "
f"score_gain={stats.get('score_gain', 0)}, stagnant={stats.get('stagnant', 0)}"
)
recent_lines = [
f"- {item['action']} -> {item['progress']} (score={item['score']}, moves={item['moves']})"
for item in recent
]
memory = {
"game": self.game_name,
"location": location,
"score": int(self.state.score),
"max_score": int(self.state.max_score),
"moves": int(self.state.moves),
"promising_here": sorted(self.promising_actions.get(location, set())),
"known_exits_here": self.location_graph.get(location, {}),
}
lines = [
json.dumps(memory, ensure_ascii=True),
"",
"Recent actions:",
*(recent_lines if recent_lines else ["- (none)"]),
"",
"Location action stats:",
*(action_lines if action_lines else ["- (none)"]),
"",
"Observation:",
self.state.observation,
]
return "\n".join(lines)
def get_inventory(self) -> str:
if not self.state:
return "Inventory: game not initialized."
items = self._clean_inventory(self.state.inventory)
if not items:
return "Inventory: empty"
return f"Inventory: {', '.join(items)}"
def get_map(self) -> str:
if not self.state:
return "Map unavailable: game not initialized."
if not self.location_graph:
return f"Map: no transitions recorded yet. Current location: {self._current_location()}"
lines = ["Exploration map:"]
for src, exits in sorted(self.location_graph.items()):
exit_str = ", ".join(f"{k}->{v}" for k, v in sorted(exits.items()))
lines.append(f"- {src}: {exit_str}")
lines.append(f"Current location: {self._current_location()}")
return "\n".join(lines)
def get_valid_actions(self, timeout_s: float = 4.0) -> list[str]:
if self.env is None:
self.initialize(os.environ.get("GAME", "zork1"))
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
fut = pool.submit(self._probe_valid_actions)
try:
actions = fut.result(timeout=max(0.5, min(float(timeout_s), 8.0)))
except Exception:
actions = self._fallback_actions()
finally:
pool.shutdown(wait=False, cancel_futures=True)
if not actions:
return self._fallback_actions()
deduped: list[str] = []
seen = set()
for act in actions:
norm = self._normalize_action(act)
if not norm or norm in seen:
continue
seen.add(norm)
deduped.append(norm)
if len(deduped) >= 40:
break
return deduped or self._fallback_actions()
def _probe_valid_actions(self) -> list[str]:
if self.env is None:
return []
snapshot = None
try:
snapshot = self.env.save_state()
except Exception:
snapshot = None
try:
actions = self.env.get_valid_actions()
return list(actions) if isinstance(actions, list) else []
finally:
if snapshot is not None:
try:
self.env.load_state(snapshot)
except Exception:
pass
def _update_stats(
self,
action: str,
prev_location: str,
current_location: str,
changed_location: bool,
score_delta: int,
progress: str,
) -> None:
loc_stats = self.location_actions[prev_location].setdefault(
action,
{
"tries": 0,
"success": 0,
"score_gain": 0,
"moves_out": 0,
"stagnant": 0,
"last_to": "",
},
)
loc_stats["tries"] = int(loc_stats["tries"]) + 1
if progress in {"score", "move", "state"}:
loc_stats["success"] = int(loc_stats["success"]) + 1
if score_delta > 0:
loc_stats["score_gain"] = int(loc_stats["score_gain"]) + score_delta
self.promising_actions[prev_location].add(action)
if changed_location:
loc_stats["moves_out"] = int(loc_stats["moves_out"]) + 1
loc_stats["last_to"] = current_location
if progress == "none":
loc_stats["stagnant"] = int(loc_stats["stagnant"]) + 1
def _current_location(self) -> str:
if not self.state:
return "Unknown"
if self.last_location:
return self.last_location
loc, loc_id = self._derive_location(
observation=self.state.observation,
previous_location="Unknown",
state_location=getattr(self.state, "location", ""),
)
self.last_location = loc
self.last_location_id = loc_id
return loc
def _derive_location(self, observation: str, previous_location: str, state_location: str) -> tuple[str, int | None]:
state_name, loc_id = self._parse_state_location(state_location)
obs_name = self._extract_location(observation, previous_location=previous_location)
if state_name and obs_name:
state_base = state_name.split("#", 1)[0].strip().lower()
obs_base = obs_name.split("#", 1)[0].strip().lower()
if obs_base and obs_base != state_base:
return (f"{obs_name}#{loc_id}", loc_id) if loc_id is not None else (obs_name, None)
return state_name, loc_id
if state_name and state_name.lower().startswith("unknown"):
anchor = self._observation_anchor(observation)
if anchor:
return (f"{anchor}#{loc_id}", loc_id) if loc_id is not None else (anchor, None)
if state_name:
return state_name, loc_id
if obs_name:
return obs_name, None
return previous_location or "Unknown", None
@staticmethod
def _parse_state_location(raw_location: str) -> tuple[str, int | None]:
text = str(raw_location or "").strip()
if not text:
return "", None
obj_match = re.search(r"Obj(\d+):\s*([^\n]+?)\s+Parent", text)
if obj_match:
loc_id = int(obj_match.group(1))
label = obj_match.group(2).strip()
label = re.sub(r"\s+", " ", label)
label_map = {"West House": "West of House", "North House": "North of House"}
label = label_map.get(label, label)
return f"{label}#{loc_id}", loc_id
line = text.splitlines()[0].strip()
if line:
return line[:80], None
return "", None
@staticmethod
def _extract_location(observation: str, previous_location: str = "Unknown") -> str:
lines = [ln.strip() for ln in observation.splitlines() if ln.strip()]
if not lines:
return previous_location or "Unknown"
first = lines[0]
if first.endswith(("?", ".", "!", "...")):
return previous_location or "Unknown"
# Prefer title-like lines (room headings) over imperative/status text.
if len(first) <= 70:
words = re.findall(r"[A-Za-z][A-Za-z'-]*", first)
if 1 <= len(words) <= 6:
capitalized = sum(1 for w in words if w[0].isupper())
if capitalized >= max(1, len(words) - 1):
return first
# Some games produce a heading on line 2 after a short status line.
if len(lines) >= 2:
second = lines[1]
if len(second) <= 70 and not second.endswith(("?", ".", "!", "...")):
words = re.findall(r"[A-Za-z][A-Za-z'-]*", second)
if 1 <= len(words) <= 6:
capitalized = sum(1 for w in words if w[0].isupper())
if capitalized >= max(1, len(words) - 1):
return second
if len(first) <= 70 and first and first[0].isupper():
if not re.search(r"\b(is|are|look|looks|seem|seems|have|has)\b", first, flags=re.IGNORECASE):
return first
return previous_location or "Unknown"
@staticmethod
def _observation_anchor(observation: str) -> str:
lines = [ln.strip() for ln in observation.splitlines() if ln.strip()]
if not lines:
return ""
first = re.split(r"[.!?]", lines[0], maxsplit=1)[0].strip().lower()
first = re.sub(r"[^a-z0-9 ]+", " ", first)
first = re.sub(r"\s+", " ", first).strip()
if not first:
return ""
digest = hashlib.sha1(first.encode("utf-8")).hexdigest()[:8]
return f"ctx-{digest}"
@staticmethod
def _normalize_action(action: str) -> str:
action = " ".join((action or "look").strip().lower().split())
if not action:
return "look"
alias = {
"n": "north",
"s": "south",
"e": "east",
"w": "west",
"u": "up",
"d": "down",
"l": "look",
"i": "inventory",
"x": "examine",
}
if action in alias:
return alias[action]
if action.startswith("go "):
direction = action.split(" ", 1)[1].strip()
return alias.get(direction, direction)
return action
@staticmethod
def _obs_signature(text: str) -> str:
clean = re.sub(r"\s+", " ", (text or "").strip().lower())
clean = re.sub(r"\[meta\].*", "", clean)
return clean[:300]
@staticmethod
def _looks_negative(text: str) -> bool:
t = (text or "").lower()
return any(p in t for p in _NEGATIVE_PATTERNS)
@staticmethod
def _short(text: str, n: int) -> str:
t = " ".join((text or "").split())
return t if len(t) <= n else t[: n - 3] + "..."
def _format_observation(
self,
observation: str,
action: str,
prev_location: str,
changed_location: bool,
progress: str,
) -> str:
meta = {
"game": self.game_name,
"action": action,
"location": self.last_location or self._current_location(),
"prev_location": prev_location,
"changed_location": changed_location,
"score": int(self.state.score if self.state else 0),
"max_score": int(self.state.max_score if self.state else 0),
"moves": int(self.state.moves if self.state else 0),
"reward": int(self.state.reward if self.state else 0),
"done": bool(self.state.done if self.state else False),
"progress": progress,
}
return f"{observation}\n\n[META]{json.dumps(meta, ensure_ascii=True)}"
def _fallback_actions(self) -> list[str]:
base = ["look", "north", "south", "east", "west", "up", "down", "enter", "exit", "wait", "listen"]
if not self.state:
return base
verbs = ["examine", "search", "open", "take", "read", "use"]
objects = self._extract_objects(self.state.observation)[:8]
inv_items = self._clean_inventory(self.state.inventory)[:4]
extra: list[str] = []
for obj in objects:
short_obj = obj.split()[-1] if obj.split() else obj
for v in verbs[:4]:
extra.append(f"{v} {obj}")
if short_obj and short_obj != obj:
extra.append(f"{v} {short_obj}")
for obj in inv_items:
extra.append(f"examine {obj}")
extra.append(f"use {obj}")
out: list[str] = []
seen = set()
for act in base + extra:
norm = self._normalize_action(act)
if not norm or norm in seen or len(norm) > 48:
continue
seen.add(norm)
out.append(norm)
if len(out) >= 40:
break
return out
@staticmethod
def _extract_objects(observation: str) -> list[str]:
low = (observation or "").lower()
out = []
patterns = (
r"(?:there is|there are|you see|you can see)\s+(?:a|an|the|some)?\s*([a-z][a-z -]{1,28})",
r"\b(?:a|an|the)\s+([a-z][a-z -]{1,20})\s+(?:is|are|lies|sits|stands)\b",
r"\bin\s+([a-z]{4,20})\b",
)
for pat in patterns:
for m in re.finditer(pat, low):
cand = " ".join(m.group(1).split())
cand = re.split(r"\b(?:here|there|that|which|with|near|on|in|from|to)\b", cand, maxsplit=1)[0].strip()
if cand and len(cand) <= 30:
out.append(cand)
dedup = []
seen = set()
for item in out:
if item in seen:
continue
seen.add(item)
dedup.append(item)
return dedup
@staticmethod
def _clean_inventory(items: list[str]) -> list[str]:
cleaned: list[str] = []
for item in items or []:
s = str(item).strip()
if not s:
continue
s = re.sub(r"\s*\(.*?\)\s*", "", s)
if ":" in s:
s = s.split(":", 1)[1].strip()
cleaned.append(s)
deduped: list[str] = []
seen = set()
for item in cleaned:
key = item.lower()
if key in seen:
continue
seen.add(key)
deduped.append(item)
return deduped
# Module-wide manager shared by every tool call; its environment is only
# created on first use (see get_game).
_game = GameManager()


def get_game() -> GameManager:
    """Return the shared GameManager, starting the game on first access.

    The game name is read from the ``GAME`` environment variable and
    defaults to ``"zork1"``.
    """
    manager = _game
    if manager.env is not None:
        return manager
    manager.initialize(os.environ.get("GAME", "zork1"))
    return manager
@mcp.tool()
def play_action(action: str) -> str:
    """Execute one game action and return observation + machine-readable metadata."""
    # Delegates to GameManager.step, which normalizes the command, plays it
    # and appends the "[META]" JSON trailer to the observation.
    game = get_game()
    return game.step(action)
@mcp.tool()
def memory() -> str:
    """Return compact state and action-memory summary."""
    # Read-only: built from the manager's recorded history and statistics.
    game = get_game()
    return game.get_memory()
@mcp.tool()
def inventory() -> str:
    """Return current inventory without spending an in-game action."""
    # Served from the inventory stored on the last state; no command is
    # sent to the game.
    game = get_game()
    return game.get_inventory()
@mcp.tool()
def get_map() -> str:
    """Return discovered location transitions."""
    # Renders the manager's location graph as text.
    game = get_game()
    return game.get_map()
@mcp.tool()
def get_valid_actions(timeout_s: float = 4.0) -> str:
    """Return likely valid actions with timeout protection."""
    # The manager returns a list of action strings; wrap it in a JSON object
    # under the "valid_actions" key for the caller.
    game = get_game()
    payload = {"valid_actions": game.get_valid_actions(timeout_s=timeout_s)}
    return json.dumps(payload, ensure_ascii=True)
# Script entry point: start the MCP server only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    mcp.run()