"""
Exploration-first hybrid ReAct agent (score + locations) for text adventures.

Key points:
- Deterministic policy driven by server status() JSON.
- ReAct loop explicit each step: THOUGHT -> TOOL(play_action) -> OBSERVATION
- Priority:
  A) Valid untried exits (Jericho-validated) + obs-boosted directions
  B) Bounded suggested_interactions (game-validated)
  C) BFS backtrack to nearest frontier (room with untried exits)
  D) Stuck recovery (look/inventory/examine noun)
  E) Optional single LLM fallback if HF_TOKEN is present (never required)

- Uses peek_action (if available) to score a small candidate set quickly.
- All verbose/debug output goes to stderr only.
"""
| |
|
| | import json |
| | import os |
| | import re |
| | import sys |
| | from collections import deque |
| | from dataclasses import dataclass, field |
| | from typing import Optional |
| |
|
| | from dotenv import load_dotenv |
| | from huggingface_hub import InferenceClient |
| |
|
# Load a local .env file so an optional HF_TOKEN can be picked up without
# requiring it in the process environment.
load_dotenv()


# Optional LLM fallback configuration.  The client is only constructed when a
# token is present, so the agent never hard-requires network credentials.
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
_hf_token = os.getenv("HF_TOKEN")
LLM_CLIENT = InferenceClient(token=_hf_token) if _hf_token else None
| |
|
| |
|
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 120) -> str:
    """Last-resort LLM fallback (optional); raises if no HF token is configured."""
    if LLM_CLIENT is None:
        raise RuntimeError("HF_TOKEN missing => LLM unavailable")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )
    return response.choices[0].message.content or ""
| |
|
| |
|
@dataclass
class RunResult:
    """Summary of a single StudentAgent.run() playthrough."""

    final_score: int  # best score observed during the run
    max_score: int  # game-reported maximum score (0 if never reported)
    moves: int  # number of play_action calls issued
    locations_visited: set[int]  # distinct loc_ids seen at least once
    game_completed: bool  # True if done flag or a game-over message was seen
    error: Optional[str] = None  # unused by run(); reserved for callers
    # (THOUGHT, TOOL, OBSERVATION) triples, one per executed action.
    history: list[tuple[str, str, str]] = field(default_factory=list)
| |
|
| |
|
| | |
| | |
| | |
# ---------------------------------------------------------------------------
# Tunables
# ---------------------------------------------------------------------------
MAX_INTERACTIONS = 4  # max suggested interactions tried per room
STUCK_THRESHOLD = 10  # no-progress steps before stuck recovery kicks in
MEMORY_LEN = 20  # length of the recent (action, loc, score, obs) ring buffer
PEEK_K = 6  # max candidates scored via peek_action per step


# Action prefixes never worth trying blindly (destructive or risky verbs).
UNSAFE_STARTS = (
    "burn ", "set fire", "ignite ",
    "attack ", "kill ", "hit ", "stab ", "shoot ", "punch ", "fight ",
    "destroy ", "break ", "smash ",
    "eat ",
)


# Direction words mentioned in observation text.  FIX: the original pattern
# also listed northeast|northwest|southeast|southwest as separate alternatives,
# but those were unreachable dead branches -- north(?:east|west)? and
# south(?:east|west)? already match them -- so the duplicates are removed.
# Match behavior is identical.
DIR_WORD_RE = re.compile(
    r"\b(north(?:east|west)?|south(?:east|west)?|east|west|up|down|in|out)\b",
    re.IGNORECASE,
)


# Detects the game asking a clarifying/disambiguation question.
DISAMBIG_RE = re.compile(
    r"which do you mean|do you mean|be more specific|what do you want",
    re.IGNORECASE,
)
# Extracts the first "the <one-or-two-word noun phrase>" as a candidate answer.
OPTION_RE = re.compile(r"\bthe\s+([a-z]+(?:\s+[a-z]+)?)", re.IGNORECASE)


# System prompt for the optional single-shot LLM fallback.
LLM_SYSTEM = (
    "You play a text adventure game. Propose ONE action (<= 5 words) that helps "
    "explore a new location or gain points. Reply with exactly one line:\n"
    "ACTION: <command>"
)
| |
|
| |
|
class StudentAgent:
    """Deterministic, exploration-first agent for text-adventure MCP servers.

    Each step is an explicit ReAct cycle: decide (THOUGHT), execute via the
    play_action tool (TOOL), record the result (OBSERVATION).  Decisions are
    driven by the server's status() JSON; an LLM is consulted only as an
    optional last resort when HF_TOKEN is configured.
    """

    def __init__(self) -> None:
        # Location ids visited at least once (reported in RunResult).
        self.visited: set[int] = set()
        # Known map edges: loc_id -> {direction: neighbor loc_id}.
        self.graph: dict[int, dict[str, int]] = {}
        # Latest untried directions per location, as reported by status().
        self.loc_untried: dict[int, list[str]] = {}
        # How many suggested interactions were already spent per location.
        self.interactions_done: dict[int, int] = {}

        # Ring buffer of (action, loc_id, score, obs prefix); used for
        # repeated-noop detection and two-room oscillation detection.
        self.recent_memory = deque(maxlen=MEMORY_LEN)
        self.no_progress_steps = 0  # consecutive steps with unchanged loc+score
        self.llm_calls = 0  # count of LLM fallback calls (also varies the seed)
        self.last_action = ""  # most recently executed action string

    async def run(self, client, game: str, max_steps: int, seed: int, verbose: bool = False) -> RunResult:
        """Play up to max_steps moves via the MCP `client`; return a summary.

        Each iteration reads status(), picks an action with _decide(),
        optionally re-ranks it via peek_action, executes it, and appends a
        (THOUGHT, TOOL, OBSERVATION) triple to the returned history.
        NOTE(review): `game` is unused here — presumably the server is already
        configured for a specific game; confirm against the launcher.
        """
        history: list[tuple[str, str, str]] = []
        moves_taken = 0
        final_score = 0
        max_score = 0
        game_completed = False
        last_status = {}

        # Discover available tools; peek_action is optional on the server.
        tools = await client.list_tools()
        tool_names = {t.name for t in tools}
        has_peek = "peek_action" in tool_names

        # Ground the initial state with a harmless "look" (counts as a move).
        init_obs = await client.call_tool("play_action", {"action": "look"})
        moves_taken += 1
        self.last_action = "look"
        history.append((
            "THOUGHT: Start by looking around to ground the state.",
            "TOOL: play_action ARGS: {'action': 'look'}",
            self._text(init_obs)[:160],
        ))

        prev_score = 0
        prev_loc = -1

        while moves_taken < max_steps:
            # Refresh status; on failure fall back to the last good snapshot.
            try:
                raw = await client.call_tool("status", {})
                status = json.loads(self._text(raw))
                last_status = status
            except Exception:
                status = last_status

            if not status:
                # No status available at all yet: burn one safe move, retry.
                thought = "THOUGHT: Status unavailable; use a safe action to recover."
                tool_call = "TOOL: play_action ARGS: {'action': 'look'}"
                res = await client.call_tool("play_action", {"action": "look"})
                moves_taken += 1
                obs_txt = self._text(res)
                history.append((thought, tool_call, obs_txt[:160]))
                continue

            loc_id = int(status["loc_id"])
            score = int(status.get("score", 0))
            final_score = score
            max_score = int(status.get("max_score", max_score) or max_score)
            done = bool(status.get("done", False))

            # Fold the fresh status into the agent's map/memory.
            self.visited.add(loc_id)
            self._merge_edges(loc_id, status.get("edges_here", {}) or {})
            self.loc_untried[loc_id] = list(status.get("untried_directions", []) or [])

            # Track stagnation (same room, same score) for stuck recovery.
            if score == prev_score and loc_id == prev_loc:
                self.no_progress_steps += 1
            else:
                self.no_progress_steps = 0
            prev_score, prev_loc = score, loc_id

            if done:
                game_completed = True
                break

            # THOUGHT: pick an action from the heuristic priority ladder.
            thought_reason, action = self._decide(status, seed)

            # Optionally re-rank the choice against a few peeked alternatives.
            if has_peek:
                action = await self._peek_pick(client, status, action)

            action = self._sanitize_action(action)

            thought = f"THOUGHT: {thought_reason}"
            tool_call = f"TOOL: play_action ARGS: {{'action': '{action}'}}"

            # TOOL: execute the chosen action.
            res = await client.call_tool("play_action", {"action": action})
            moves_taken += 1
            obs2 = self._text(res)

            # OBSERVATION: remember the step for noop/oscillation detection.
            self.recent_memory.append((action.lower().strip(), loc_id, score, obs2[:60]))
            self.last_action = action

            if verbose:
                # Debug trace goes to stderr only.
                print(
                    f"[step] loc={loc_id} score={score} stuck={self.no_progress_steps} -> {action!r}",
                    file=sys.stderr,
                )

            history.append((thought, tool_call, obs2[:160]))

            if self._is_game_over(obs2):
                game_completed = True
                break

        # Best-effort final status read to capture the closing score/location.
        try:
            raw = await client.call_tool("status", {})
            st2 = json.loads(self._text(raw))
            final_score = max(final_score, int(st2.get("score", 0)))
            max_score = max_score or int(st2.get("max_score", 0))
            self.visited.add(int(st2["loc_id"]))
        except Exception:
            pass

        return RunResult(
            final_score=final_score,
            max_score=max_score,
            moves=moves_taken,
            locations_visited=self.visited,
            game_completed=game_completed,
            history=history,
        )

    def _decide(self, status: dict, seed: int) -> tuple[str, str]:
        """Return (reason, action) from a fixed priority ladder.

        Order: disambiguation answer, valid untried exits, observation-cued
        directions, remaining untried directions, bounded suggested
        interactions, BFS backtrack to a frontier, stuck recovery, optional
        LLM fallback, and finally a plain "look".
        """
        loc_id = int(status["loc_id"])
        obs = status.get("last_observation", "") or ""
        outcomes = status.get("outcomes_here", {}) or {}

        banned = {str(x).lower().strip() for x in (status.get("banned_actions_here", []) or [])}
        untried = status.get("untried_directions", []) or []
        valid_exits = status.get("valid_exits", []) or []
        suggested = status.get("suggested_interactions", []) or []

        # Answer disambiguation prompts with the first plausible option.
        if DISAMBIG_RE.search(obs):
            opt = self._extract_option(obs)
            if opt and not self._repeat_noop(opt, loc_id):
                return "Disambiguation requested by the game; answer with the first plausible option.", opt

        untried_set = set(untried)
        obs_dirs = self._mentioned_dirs(obs)

        # A) Engine-confirmed exits that are still untried.
        for d in valid_exits:
            dl = d.lower().strip()
            if d in untried_set and dl not in banned and not self._repeat_noop(d, loc_id):
                return f"Take a valid untried exit to explore: {d}.", d

        # A') Directions named in the observation text and still untried.
        for d in obs_dirs:
            if d in untried_set and d.lower() not in banned and not self._repeat_noop(d, loc_id):
                return f"Direction mentioned in observation and untried; explore: {d}.", d

        # A'') Any remaining untried direction, in listed order.
        for d in untried:
            if d.lower() not in banned and not self._repeat_noop(d, loc_id):
                return f"No strong cue; systematically try untried direction: {d}.", d

        # B) A bounded number of game-suggested interactions per room,
        # skipping banned/unsafe/already-tried/recently-noop actions.
        n = self.interactions_done.get(loc_id, 0)
        if n < MAX_INTERACTIONS:
            for a in suggested:
                al = a.lower().strip()
                if al in banned:
                    continue
                if any(al.startswith(x) for x in UNSAFE_STARTS):
                    continue
                if a in outcomes:
                    continue
                if self._repeat_noop(a, loc_id):
                    continue
                self.interactions_done[loc_id] = n + 1
                return f"Try a game-validated interaction in this room (#{n+1}): {a}.", a

        # C) Backtrack toward the nearest room that still has untried exits.
        avoid = self._oscillation_avoid()
        step_dir = self._bfs_step(loc_id, avoid)
        if step_dir:
            return "No local frontier; backtrack via BFS to nearest unexplored frontier.", step_dir

        # D) Stuck recovery: cheap safe actions, then examine a salient noun.
        if self.no_progress_steps >= STUCK_THRESHOLD:
            for a in ("look", "inventory"):
                if not self._repeat_noop(a, loc_id):
                    return "Stuck for many steps; run a safe recovery action.", a
            noun = self._extract_noun(obs)
            if noun and not self._repeat_noop(f"examine {noun}", loc_id):
                return "Stuck; examine a likely noun from the observation.", f"examine {noun}"

        # E) Optional single LLM suggestion (never required; failures ignored).
        if LLM_CLIENT is not None:
            try:
                self.llm_calls += 1
                prompt = self._llm_prompt(status)
                resp = call_llm(prompt, LLM_SYSTEM, seed + self.llm_calls)
                act = self._parse_llm(resp)
                if act and act.lower().strip() not in banned and not self._repeat_noop(act, loc_id):
                    return "Heuristics exhausted; use one short LLM suggestion (optional fallback).", act
            except Exception:
                pass

        return "Fallback to a safe neutral action.", "look"

    async def _peek_pick(self, client, status: dict, current_action: str) -> str:
        """Use peek_action to score a small candidate set and pick best."""
        loc_id = int(status["loc_id"])
        score = int(status.get("score", 0))

        # Candidate set: current choice first, then a few untried directions
        # and suggested interactions, capped at PEEK_K total.
        candidates = []
        if current_action:
            candidates.append(current_action)

        for d in (status.get("untried_directions", []) or [])[:4]:
            if d not in candidates:
                candidates.append(d)
        for a in (status.get("suggested_interactions", []) or [])[:4]:
            if a not in candidates:
                candidates.append(a)

        candidates = candidates[:PEEK_K]
        best = current_action
        best_u = -10**18

        for a in candidates:
            try:
                raw = await client.call_tool("peek_action", {"action": a})
                st = json.loads(self._text(raw))
                new_score = int(st.get("score", score))
                new_loc = int(st.get("loc_id", loc_id))
                delta = max(0, new_score - score)

                # Favor moves into unseen rooms over re-entering known ones.
                if new_loc != loc_id:
                    moved_bonus = 600 if (new_loc not in self.visited) else 80
                else:
                    moved_bonus = 0

                # Score gains dominate; repeating a recent noop is penalized.
                repeat_pen = 120 if self._repeat_noop(a, loc_id) else 0
                u = delta * 900 + moved_bonus - repeat_pen

                if u > best_u:
                    best_u = u
                    best = a
            except Exception:
                # A failed peek just drops that candidate from consideration.
                continue

        return best

    def _merge_edges(self, loc_id: int, edges_here: dict) -> None:
        """Merge status()-reported edges for loc_id into the map graph."""
        if not edges_here:
            return
        node = self.graph.setdefault(loc_id, {})
        for d, nid in edges_here.items():
            try:
                node[str(d)] = int(nid)
            except Exception:
                # Skip malformed neighbor ids rather than abort the merge.
                pass

    def _oscillation_avoid(self) -> Optional[int]:
        """Return the room to avoid when recent steps show A-B-A-B bouncing."""
        locs = [x[1] for x in self.recent_memory]
        if len(locs) >= 4 and locs[-1] == locs[-3] and locs[-2] == locs[-4]:
            return locs[-2]
        return None

    def _bfs_step(self, from_loc: int, avoid_loc: Optional[int]) -> Optional[str]:
        """BFS over the known graph; return the first move of the shortest
        path to any room that still has untried exits (None if no frontier)."""
        frontier = {lid for lid, u in self.loc_untried.items() if u and lid != from_loc}
        if not frontier:
            return None

        q = deque()
        seen = {from_loc}

        # Seed with immediate neighbors, remembering the first move taken.
        for d, nid in self.graph.get(from_loc, {}).items():
            if nid not in seen and nid != avoid_loc:
                q.append((nid, d))
                seen.add(nid)

        while q:
            cur, first_dir = q.popleft()
            if cur in frontier:
                return first_dir
            for d, nid in self.graph.get(cur, {}).items():
                if nid not in seen:
                    seen.add(nid)
                    q.append((nid, first_dir))
        return None

    def _repeat_noop(self, action: str, loc_id: int) -> bool:
        """True if this exact action was already tried in this room recently."""
        a = (action or "").lower().strip()
        return any(prev_a == a and prev_loc == loc_id for (prev_a, prev_loc, _sc, _o) in self.recent_memory)

    def _mentioned_dirs(self, obs: str) -> list[str]:
        """Distinct direction words found in obs, in order of appearance."""
        out = []
        for m in DIR_WORD_RE.finditer(obs or ""):
            d = m.group(1).lower()
            if d not in out:
                out.append(d)
        return out

    def _extract_option(self, obs: str) -> Optional[str]:
        """First "the <noun phrase>" in a disambiguation prompt, lowercased."""
        m = OPTION_RE.search(obs or "")
        if m:
            return m.group(1).strip().lower()
        return None

    def _extract_noun(self, obs: str) -> Optional[str]:
        """A likely examinable noun ("the <word>"), excluding direction words."""
        m = re.search(r"\bthe\s+([a-z]{3,})\b", (obs or "").lower())
        if m:
            noun = m.group(1)
            # Direction words are not useful examine targets.
            if noun not in {"north", "south", "east", "west", "up", "down", "in", "out"}:
                return noun
        return None

    def _sanitize_action(self, a: str) -> str:
        """Normalize an action: strip quotes/backticks, collapse whitespace,
        cap at six words; default to "look" when nothing remains."""
        a = (a or "").strip()
        a = re.sub(r"[`\"']", "", a)
        a = re.sub(r"\s+", " ", a).strip()
        words = a.split()[:6]
        return " ".join(words) if words else "look"

    def _llm_prompt(self, status: dict) -> str:
        """Compact status summary handed to the fallback LLM."""
        inv = ", ".join(status.get("inventory", [])) or "empty"
        tried = ", ".join(list((status.get("outcomes_here") or {}).keys())[:20]) or "none"
        banned = ", ".join(status.get("banned_actions_here", [])) or "none"
        return (
            f"Location: {status.get('loc_name')} (id={status.get('loc_id')})\n"
            f"Score: {status.get('score')}/{status.get('max_score')} Moves: {status.get('moves')}\n"
            f"Inventory: {inv}\n"
            f"Untried dirs: {', '.join((status.get('untried_directions') or [])[:12])}\n"
            f"Tried here: {tried}\n"
            f"BANNED: {banned}\n\n"
            f"Observation:\n{(status.get('last_observation') or '')[:500]}\n"
        )

    def _parse_llm(self, resp: str) -> str:
        """Extract a short action from the LLM reply's first non-empty line,
        preferring an "ACTION: ..." prefix and bare direction words."""
        for line in (resp or "").splitlines():
            line = line.strip()
            if not line:
                continue
            if line.upper().startswith("ACTION:"):
                line = line.split(":", 1)[1].strip()
            line = line.lower()
            # Normalize "go north"-style answers to the bare direction word.
            m = re.match(
                r"^(?:go\s+)?(north(?:east|west)?|south(?:east|west)?|east|west|up|down|in|out)\b",
                line,
            )
            if m:
                return m.group(1)
            return " ".join(line.split()[:5])
        return "look"

    def _is_game_over(self, text: str) -> bool:
        """Heuristic terminal-state detection from observation text."""
        t = (text or "").lower()
        return any(x in t for x in ("game over", "you have died", "you are dead", "you have won"))

    def _text(self, result) -> str:
        """Best-effort extraction of plain text from an MCP tool result."""
        try:
            if hasattr(result, "content") and result.content:
                return result.content[0].text
            if isinstance(result, list) and result:
                return result[0].text
        except Exception:
            pass
        return str(result)
| |
|
| |
|
| | |
async def _test() -> None:
    """Manual smoke test: run the agent against a local MCP server via stdio."""
    import os as _os
    import sys as _sys

    from fastmcp import Client
    from fastmcp.client.transports import StdioTransport

    server_path = _os.path.join(_os.path.dirname(__file__), "mcp_server.py")
    transport = StdioTransport(
        command=_sys.executable,
        args=[server_path],
        env={**_os.environ, "GAME": "lostpig"},
    )
    agent = StudentAgent()
    async with Client(transport) as client:
        result = await agent.run(client, game="lostpig", max_steps=30, seed=42, verbose=True)
        summary = (
            f"Score: {result.final_score}/{result.max_score} | "
            f"Moves: {result.moves} | Locations: {len(result.locations_visited)}"
        )
        print(summary, file=sys.stderr)
| |
|
| |
|
if __name__ == "__main__":
    # Entry point for the manual smoke test above.
    import asyncio
    asyncio.run(_test())