| """ |
| Improved MCP Agent for Text Adventures |
| |
| This agent moves beyond a pure ReAct loop by keeping structured, location-aware |
| state: |
| - New-location detection (Jericho-aware with LLM fallback) |
| - Valid action retrieval per location |
| - Per-location action logs with summarized outcomes |
| - Exploration bias to avoid stalling in one place |
| """ |
|
|
| import json |
| import os |
| import re |
| from dataclasses import dataclass, field |
| from typing import Any, Optional |
|
|
| from dotenv import load_dotenv |
| from huggingface_hub import InferenceClient |
|
|
# Load environment variables (e.g. HF_TOKEN) from a local .env file.
load_dotenv()


# Hugging Face model used for every LLM call in this module.
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"


# Fail fast at import time if the API token is missing.
_hf_token = os.getenv("HF_TOKEN")
if not _hf_token:
    raise ValueError("HF_TOKEN not found. Set it in your .env file.")


# Shared inference client reused by every call_llm() invocation.
LLM_CLIENT = InferenceClient(token=_hf_token)
|
|
|
|
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 220) -> str:
    """Call the LLM and return the assistant message text.

    Args:
        prompt: User-role content for the request.
        system_prompt: System-role instructions.
        seed: Sampling seed for reproducibility (temperature is 0.0 anyway).
        max_tokens: Cap on the length of the completion.

    Returns:
        The assistant's reply as a string. An empty string is returned when
        the provider yields a null message, so callers can safely call
        .strip() / .split() on the result.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )

    # Some providers return content=None for empty completions; normalize to ""
    # so downstream parsing (.strip(), .split("\n")) never sees None.
    content = response.choices[0].message.content
    return content or ""
|
|
|
|
@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class."""
    final_score: int  # score at the end of the run
    max_score: int  # maximum achievable score (0 if never discovered)
    moves: int  # number of play_action steps executed
    locations_visited: set[str]  # normalized location keys seen during the run
    game_completed: bool  # True if a game-over condition was detected
    error: Optional[str] = None  # error message if the run aborted, else None
    history: list[tuple[str, str, str]] = field(default_factory=list)  # (thought, action, truncated observation) per step
|
|
|
|
# Single-word movement verbs (plus abbreviations) used to bias exploration.
MOVEMENT_ACTIONS = {
    "north", "south", "east", "west", "up", "down", "enter", "exit",
    "n", "s", "e", "w", "u", "d",
}
# Fallback command list when no valid actions can be fetched.
DEFAULT_ACTIONS = [
    "look", "inventory", "north", "south", "east", "west", "up", "down",
]
# Commands starting with these verbs are blocked (frequently lethal in IF games).
DANGEROUS_ACTION_PREFIXES = (
    "jump",
    "leap",
    "dive",
    "suicide",
)
# Exact commands that are always blocked.
DANGEROUS_ACTIONS = {
    "kill self",
    "attack self",
    "hit self",
    "hurt self",
    "stab self",
    "cut self",
}
# Consecutive steps in the same room before the agent is considered "stuck".
STUCK_THRESHOLD = 4
# While stuck, re-fetch valid actions every N additional steps.
VALID_ACTION_REFRESH_STUCK_INTERVAL = 3
# Maximum number of outcome entries kept per location.
LOCATION_LOG_LIMIT = 20
# Bounds for how many valid actions are surfaced to the LLM per prompt.
PROMPT_ACTIONS_MIN = 20
PROMPT_ACTIONS_MAX = 40
PROMPT_ACTIONS_TARGET = 30
# System prompt for the per-step action decision.
ACTION_SYSTEM_PROMPT = """You are a strong text-adventure planner.
You receive structured context (valid actions, location action history, and stagnation level).

Your task is to choose ONE next game command.

Rules:
1. Prefer untried actions supported by current observation and local history.
2. If STUCK_LEVEL >= 4, strongly prioritize exploration (movement or a new untried action).
3. Avoid repeating an action that already failed in the same location.
4. Avoid obviously dangerous commands such as jump, leap, dive, or self-harm actions.
5. Use concise text-adventure commands only.
6. No markdown.

Respond exactly as:
THOUGHT: <short reasoning>
ACTION: <single game command>
"""


# System prompt for the LLM fallback that detects room transitions.
LOCATION_CHANGE_SYSTEM_PROMPT = """Decide if the player entered a new location.
Return JSON only:
{"new_location": true/false, "location_label": "<best label>"}
"""


# System prompt for one-line action outcome summaries.
SUMMARY_SYSTEM_PROMPT = """Summarize an action outcome in one short line.
Focus on progress, obstacle, or reward.
No markdown. Max 16 words.
"""
|
|
|
|
class StudentAgent:
    """
    Location-aware MCP agent.

    Main ideas:
    - Uses Jericho location when available to detect room transitions
    - Pulls valid actions when entering a room
    - Maintains per-location action logs
    - Uses exploration bias to avoid stalling
    """

    def __init__(self):
        """Initialize the agent with fresh per-run state."""
        self._reset_state()
|
|
    def _reset_state(self) -> None:
        """Reset all per-run mutable state to its initial values."""
        # Rolling step log (dicts with step/location/thought/action/summary).
        self.history: list[dict] = []
        # Last few executed actions, used for repetition detection.
        self.recent_actions: list[str] = []

        self.score: int = 0
        self.max_score: int = 0

        # Current-location bookkeeping; "unknown"/"Unknown" until resolved.
        self.current_location_key: str = "unknown"
        self.current_location_label: str = "Unknown"
        self.current_display_location: str = "Unknown"
        self.current_jericho_location: str = "Unknown"
        # Consecutive steps without a location change (stagnation counter).
        self.steps_in_current_location: int = 0

        # Per-location caches keyed by normalized location key.
        self.location_visit_counts: dict[str, int] = {}
        self.location_action_log: dict[str, list[dict]] = {}
        self.location_valid_actions: dict[str, list[str]] = {}
        # Most recent structured payloads returned by MCP tools.
        self.latest_context: dict[str, Any] = {}
        self.latest_lookahead: dict[str, Any] = {}
        # Forces a valid-action refresh on the next loop iteration.
        self.refresh_valid_actions_next_step: bool = True
|
|
    async def run(
        self,
        client,
        game: str,
        max_steps: int,
        seed: int,
        verbose: bool = False,
    ) -> RunResult:
        """Run the agent for a game session.

        Args:
            client: MCP client exposing at least the "play_action" tool;
                optional tools (memory, get_map, get_context, get_location,
                get_valid_actions) are used when advertised.
            game: Game identifier (not referenced in this body; the session
                is presumably already bound to it — TODO confirm with caller).
            max_steps: Maximum number of actions to attempt.
            seed: Base seed; per-step seeds are derived as seed + step.
            verbose: When True, print step-by-step progress to stdout.

        Returns:
            RunResult with final score, moves, visited locations, and history.
        """
        self._reset_state()
        locations_visited: set[str] = set()
        result_history: list[tuple[str, str, str]] = []
        moves = 0

        # Discover which optional tools are available on this server.
        tools = await client.list_tools()
        tool_names = [t.name for t in tools]

        # Initial "look" establishes the starting observation and score.
        initial_result = await client.call_tool("play_action", {"action": "look"})
        observation = self._extract_result(initial_result)
        self._update_score(observation)
        await self._refresh_context(client, tool_names, observation)

        memory_text = ""
        if "memory" in tool_names:
            memory_text = await self._safe_tool_call(client, "memory", {})
            self._update_score(memory_text)
            self._update_max_score(memory_text)

        # Resolve and record the starting location.
        location_info = await self._get_location_info(
            client=client,
            observation=observation,
            memory_text=memory_text,
            tool_names=tool_names,
        )
        self._set_current_location(location_info, entered_new_location=True)
        locations_visited.add(self.current_location_key)

        # Prime the valid-action cache for the starting room.
        valid_actions = await self._fetch_valid_actions(
            client,
            self.current_location_key,
            tool_names,
            force_refresh=True,
        )
        self.refresh_valid_actions_next_step = False

        map_snapshot = ""
        if "get_map" in tool_names:
            map_snapshot = await self._safe_tool_call(client, "get_map", {})

        if verbose:
            print(f"\n{observation}")
            print(
                f"[LOCATION] {self.current_location_label} "
                f"(key={self.current_location_key}, display={self.current_display_location}, "
                f"jericho={self.current_jericho_location})"
            )

        for step in range(1, max_steps + 1):
            # Gather per-location context for the decision prompt.
            local_log = self.location_action_log.get(self.current_location_key, [])
            refresh_valid_actions = self._should_refresh_valid_actions(
                location_key=self.current_location_key,
                tool_names=tool_names,
            )
            valid_actions = await self._fetch_valid_actions(
                client,
                self.current_location_key,
                tool_names,
                force_refresh=refresh_valid_actions,
            )
            self.refresh_valid_actions_next_step = False
            # Lookahead is intentionally disabled in this variant (Frotz stability).
            lookahead = self._disabled_lookahead()

            prompt = self._build_prompt(
                observation=observation,
                memory_text=memory_text,
                map_snapshot=map_snapshot,
                valid_actions=valid_actions,
                local_log=local_log,
                step=step,
                max_steps=max_steps,
            )

            # Per-step seed keeps runs reproducible while varying per decision.
            llm_response = call_llm(
                prompt=prompt,
                system_prompt=ACTION_SYSTEM_PROMPT,
                seed=seed + step,
                max_tokens=220,
            )
            thought, action = self._parse_action_response(llm_response)

            # Repair/replace the proposed action, then bias toward exploration.
            action = self._validate_action(action, valid_actions)
            action = self._apply_exploration_bias(
                action=action,
                location_key=self.current_location_key,
                valid_actions=valid_actions,
            )

            # Anti-loop guard: break out of an exact three-peat of one action.
            if len(self.recent_actions) >= 2 and all(a == action for a in self.recent_actions[-2:]):
                alternatives = self._get_untried_actions(
                    self.current_location_key,
                    valid_actions,
                )
                if alternatives:
                    action = alternatives[0]

            action = self._avoid_risky_action(
                action=action,
                valid_actions=valid_actions,
                lookahead=lookahead,
            )

            # Snapshot pre-step state so the outcome can be attributed later.
            prev_score = self.score
            prev_observation = observation
            prev_location_info = {
                "location_key": self.current_location_key,
                "display_location": self.current_display_location,
                "jericho_location": self.current_jericho_location,
            }

            if verbose:
                print(f"\n--- Step {step} ---")
                print(f"[THOUGHT] {thought}")
                print(f"[ACTION] {action}")

            try:
                step_result = await client.call_tool("play_action", {"action": action})
                observation = self._extract_result(step_result)
            except Exception as e:
                # Keep playing; the error text becomes the observation.
                observation = f"Error: {e}"

            # Track only the last 6 actions for repetition detection.
            self.recent_actions.append(action)
            if len(self.recent_actions) > 6:
                self.recent_actions = self.recent_actions[-6:]
            moves += 1

            self._update_score(observation)
            await self._refresh_context(client, tool_names, observation)

            if "memory" in tool_names:
                memory_text = await self._safe_tool_call(client, "memory", {})
                self._update_score(memory_text)
                self._update_max_score(memory_text)

            # Re-resolve the location and decide whether the agent moved rooms.
            location_info = await self._get_location_info(
                client=client,
                observation=observation,
                memory_text=memory_text,
                tool_names=tool_names,
            )
            entered_new_location, resolved_location = self._did_enter_new_location(
                previous_location=prev_location_info["location_key"],
                current_location=location_info["location_key"],
                previous_observation=prev_observation,
                current_observation=observation,
                seed=seed + step,
            )

            location_info["location_key"] = resolved_location
            self._set_current_location(location_info, entered_new_location=entered_new_location)
            locations_visited.add(self.current_location_key)
            if entered_new_location:
                self.refresh_valid_actions_next_step = True

            # Refresh the map on a room change, or periodically while stuck.
            if entered_new_location:
                if "get_map" in tool_names:
                    map_snapshot = await self._safe_tool_call(client, "get_map", {})
            else:
                if self.steps_in_current_location >= STUCK_THRESHOLD and "get_map" in tool_names:
                    map_snapshot = await self._safe_tool_call(client, "get_map", {})

            # Log the outcome against the room where the action was taken.
            outcome_summary = self._summarize_outcome(
                action=action,
                previous_observation=prev_observation,
                current_observation=observation,
                seed=seed + step,
            )
            score_delta = self.score - prev_score
            self._append_location_log(
                location_key=prev_location_info["location_key"],
                action=action,
                summary=outcome_summary,
                score_delta=score_delta,
            )

            self.history.append(
                {
                    "step": step,
                    "location": prev_location_info["location_key"],
                    "thought": thought,
                    "action": action,
                    "summary": outcome_summary,
                }
            )
            # Cap the rolling history to the most recent 15 steps.
            if len(self.history) > 15:
                self.history = self.history[-15:]

            result_history.append((thought, action, observation[:100]))

            if verbose:
                print(f"[RESULT] {observation[:180]}...")
                print(
                    f"[LOCATION] {self.current_location_label} "
                    f"(key={self.current_location_key}) | Stuck steps: {self.steps_in_current_location}"
                )
                print(f"[SCORE] {self.score}/{self.max_score}")

            if self._is_game_over(observation):
                if verbose:
                    print("\n*** GAME OVER ***")
                break

        return RunResult(
            final_score=self.score,
            max_score=self.max_score,
            moves=moves,
            locations_visited=locations_visited,
            game_completed=self._is_game_over(observation),
            history=result_history,
        )
|
|
| async def _safe_tool_call(self, client, tool_name: str, tool_args: dict) -> str: |
| """Call an MCP tool and always return text.""" |
| try: |
| result = await client.call_tool(tool_name, tool_args) |
| return self._extract_result(result) |
| except Exception as e: |
| return f"Error: {e}" |
|
|
| async def _refresh_context(self, client, tool_names: list[str], observation: str) -> None: |
| """Refresh structured context from MCP if available.""" |
| if "get_context" not in tool_names: |
| return |
|
|
| raw = await self._safe_tool_call(client, "get_context", {"limit": 80}) |
| payload = self._parse_json_payload(raw) |
| if not isinstance(payload, dict): |
| return |
|
|
| self.latest_context = payload |
| score = payload.get("score") |
| max_score = payload.get("max_score") |
| if isinstance(score, int): |
| self.score = score |
| if isinstance(max_score, int) and max_score > 0: |
| self.max_score = max_score |
|
|
    async def _get_location_info(
        self,
        client,
        observation: str,
        memory_text: str,
        tool_names: list[str],
    ) -> dict:
        """Get current location using get_location tool, memory, and observation fallback.

        Sources are consulted in priority order: cached get_context payload,
        the get_location tool, parsed memory() fields, then the first line of
        the observation as a last resort.

        Returns:
            Dict with display_location, jericho_location, location_key, and a
            human-readable location_label.
        """
        display_location = "Unknown"
        jericho_location = "Unknown"

        # 1) Structured context cached by _refresh_context, if any.
        if self.latest_context:
            display_location = str(
                self.latest_context.get("display_location")
                or self.latest_context.get("location")
                or display_location
            )
            jericho_location = str(self.latest_context.get("jericho_location", jericho_location))

        # 2) A dedicated get_location tool overrides the cached context.
        if "get_location" in tool_names:
            raw = await self._safe_tool_call(client, "get_location", {})
            payload = self._parse_json_payload(raw)
            if isinstance(payload, dict):
                display_location = str(payload.get("display_location", display_location) or display_location)
                jericho_location = str(payload.get("jericho_location", jericho_location) or jericho_location)

        # 3) Fields parsed out of the memory() tool text.
        parsed_memory = self._parse_memory(memory_text)

        if display_location == "Unknown":
            display_location = parsed_memory.get("display_location", "Unknown")
        if jericho_location == "Unknown":
            jericho_location = parsed_memory.get("jericho_location", "Unknown")

        # 4) Fall back to the observation's first line (often the room name).
        if display_location == "Unknown":
            first_line = observation.strip().split("\n")[0] if observation else "Unknown"
            display_location = first_line if first_line else "Unknown"

        display_location = self._normalize_location(display_location)
        jericho_location = self._normalize_location(jericho_location)
        location_key = self._location_key(display_location, jericho_location)
        # Prefer the Jericho name for the label when it is known.
        location_label = self._humanize_location(
            jericho_location if jericho_location != "unknown" else display_location
        )

        return {
            "display_location": display_location,
            "jericho_location": jericho_location,
            "location_key": location_key,
            "location_label": location_label,
        }
|
|
| def _set_current_location(self, location_info: dict, entered_new_location: bool) -> None: |
| """Update current location state and stagnation counters.""" |
| self.current_display_location = location_info.get("display_location", "Unknown") |
| self.current_jericho_location = location_info.get("jericho_location", "Unknown") |
| self.current_location_key = location_info.get("location_key", "unknown") |
| self.current_location_label = location_info.get( |
| "location_label", |
| self._humanize_location(self.current_location_key), |
| ) |
|
|
| if entered_new_location: |
| self.steps_in_current_location = 0 |
| self.location_visit_counts[self.current_location_key] = ( |
| self.location_visit_counts.get(self.current_location_key, 0) + 1 |
| ) |
| else: |
| self.steps_in_current_location += 1 |
|
|
| async def _fetch_valid_actions( |
| self, |
| client, |
| location_key: str, |
| tool_names: list[str], |
| force_refresh: bool = False, |
| ) -> list[str]: |
| """Fetch and cache valid actions for a location.""" |
| if not force_refresh and location_key in self.location_valid_actions: |
| return self.location_valid_actions[location_key] |
|
|
| actions: list[str] = [] |
| context_actions = self.latest_context.get("valid_actions") |
| if isinstance(context_actions, list): |
| actions = [str(v) for v in context_actions] |
|
|
| if "get_valid_actions" in tool_names: |
| |
| if not actions: |
| raw = await self._safe_tool_call(client, "get_valid_actions", {"limit": 80}) |
| payload = self._parse_json_payload(raw) |
| if isinstance(payload, dict): |
| values = payload.get("valid_actions", []) |
| if isinstance(values, list): |
| actions = [str(v) for v in values] |
|
|
| if not actions: |
| actions = DEFAULT_ACTIONS.copy() |
|
|
| normalized: list[str] = [] |
| for action in actions: |
| norm = self._normalize_action(action) |
| if norm and norm not in normalized: |
| normalized.append(norm) |
|
|
| if not normalized: |
| normalized = DEFAULT_ACTIONS.copy() |
|
|
| self.location_valid_actions[location_key] = normalized |
| return normalized |
|
|
| def _should_refresh_valid_actions(self, location_key: str, tool_names: list[str]) -> bool: |
| """Refresh valid actions only when the cache is missing or likely stale.""" |
| if "get_valid_actions" not in tool_names: |
| return False |
| if self.refresh_valid_actions_next_step: |
| return True |
|
|
| cached = self.location_valid_actions.get(location_key, []) |
| if not cached: |
| return True |
|
|
| if not self._get_untried_actions(location_key, cached): |
| return True |
|
|
| if self.steps_in_current_location >= STUCK_THRESHOLD: |
| stuck_offset = self.steps_in_current_location - STUCK_THRESHOLD |
| if stuck_offset % VALID_ACTION_REFRESH_STUCK_INTERVAL == 0: |
| return True |
|
|
| return False |
|
|
| def _disabled_lookahead(self) -> dict[str, Any]: |
| """This variant disables lookahead entirely for Frotz stability.""" |
| payload = { |
| "enabled": False, |
| "risky_actions": [], |
| "entries": [], |
| } |
| self.latest_lookahead = payload |
| return payload |
|
|
    def _did_enter_new_location(
        self,
        previous_location: str,
        current_location: str,
        previous_observation: str,
        current_observation: str,
        seed: int,
    ) -> tuple[bool, str]:
        """
        Determine if the agent entered a new location.

        Priority:
        1) Jericho/location-key comparison (deterministic)
        2) LLM fallback when locations are unknown/ambiguous
        3) First-line-of-observation heuristic when the LLM also fails

        Returns:
            (entered_new_location, resolved_location_key)
        """
        prev_norm = self._normalize_location(previous_location)
        curr_norm = self._normalize_location(current_location)

        # Deterministic path: both keys known -> plain inequality check.
        if prev_norm != "unknown" and curr_norm != "unknown":
            return prev_norm != curr_norm, curr_norm

        # Ambiguous path: ask the LLM to compare the two observations.
        prompt = (
            f"Previous location key: {previous_location}\n"
            f"Current location key: {current_location}\n\n"
            f"Previous observation:\n{previous_observation[:1000]}\n\n"
            f"Current observation:\n{current_observation[:1000]}\n"
        )
        try:
            response = call_llm(
                prompt=prompt,
                system_prompt=LOCATION_CHANGE_SYSTEM_PROMPT,
                seed=seed,
                max_tokens=120,
            )
            payload = self._parse_json_payload(response)
            if isinstance(payload, dict):
                is_new = bool(payload.get("new_location", False))
                label = str(payload.get("location_label", current_location)).strip()
                normalized_label = self._normalize_location(label)
                if normalized_label != "unknown":
                    return is_new, normalized_label
        except Exception:
            # Best-effort: fall through to the textual heuristic below.
            pass

        # Heuristic: a changed first line usually means a changed room.
        prev_line = previous_observation.strip().split("\n")[0].strip().lower() if previous_observation else ""
        curr_line = current_observation.strip().split("\n")[0].strip().lower() if current_observation else ""
        if prev_line and curr_line and prev_line != curr_line:
            fallback_key = curr_norm if curr_norm != "unknown" else self._normalize_location(curr_line)
            return True, fallback_key

        fallback_key = curr_norm if curr_norm != "unknown" else self._normalize_location(curr_line)
        return False, fallback_key
|
|
| def _summarize_outcome( |
| self, |
| action: str, |
| previous_observation: str, |
| current_observation: str, |
| seed: int, |
| ) -> str: |
| """Summarize action outcomes for local logs.""" |
| prompt = ( |
| f"Action: {action}\n\n" |
| f"Before:\n{previous_observation[:500]}\n\n" |
| f"After:\n{current_observation[:700]}\n" |
| ) |
| try: |
| summary = call_llm( |
| prompt=prompt, |
| system_prompt=SUMMARY_SYSTEM_PROMPT, |
| seed=seed, |
| max_tokens=64, |
| ) |
| summary = summary.strip().replace("\n", " ") |
| return summary[:180] if summary else "No clear effect." |
| except Exception: |
| cleaned = re.sub(r"\s+", " ", current_observation).strip() |
| if not cleaned: |
| return "No clear effect." |
| return cleaned[:120] |
|
|
| def _append_location_log(self, location_key: str, action: str, summary: str, score_delta: int) -> None: |
| """Append an action summary to the per-location log.""" |
| if location_key not in self.location_action_log: |
| self.location_action_log[location_key] = [] |
|
|
| self.location_action_log[location_key].append( |
| { |
| "action": self._normalize_action(action), |
| "summary": summary, |
| "score_delta": score_delta, |
| } |
| ) |
| if len(self.location_action_log[location_key]) > LOCATION_LOG_LIMIT: |
| self.location_action_log[location_key] = self.location_action_log[location_key][-LOCATION_LOG_LIMIT:] |
|
|
    def _build_prompt(
        self,
        observation: str,
        memory_text: str,
        map_snapshot: str,
        valid_actions: list[str],
        local_log: list[dict],
        step: int,
        max_steps: int,
    ) -> str:
        """Build a location-aware decision prompt.

        Combines progress counters, a curated subset of valid actions,
        untried candidates, recent per-location outcomes, the current
        observation, and optional memory/map snapshots.
        """
        untried_actions = self._get_untried_actions(self.current_location_key, valid_actions)
        prompt_actions = self._select_actions_for_llm(self.current_location_key, valid_actions)
        # Render the last few outcomes in this room as bullet lines.
        log_lines = []
        for entry in local_log[-6:]:
            delta_prefix = "+" if entry.get("score_delta", 0) > 0 else ""
            log_lines.append(
                f"- {entry.get('action')} -> {entry.get('summary')} "
                f"(score {delta_prefix}{entry.get('score_delta', 0)})"
            )

        # Tell the model explicitly when it should prioritize moving on.
        exploration_hint = (
            "HIGH_EXPLORATION_BIAS: yes"
            if self.steps_in_current_location >= STUCK_THRESHOLD
            else "HIGH_EXPLORATION_BIAS: no"
        )

        parts = [
            f"Step: {step}/{max_steps}",
            f"Score: {self.score}/{self.max_score}",
            f"Location label: {self.current_location_label}",
            f"Location key: {self.current_location_key}",
            f"Display location: {self.current_display_location}",
            f"Jericho location: {self.current_jericho_location}",
            f"STUCK_LEVEL: {self.steps_in_current_location}",
            exploration_hint,
            "",
            "Valid actions (selected):",
            ", ".join(prompt_actions) if prompt_actions else "(none)",
            f"(showing {len(prompt_actions)} of {len(valid_actions)} valid actions)",
            "",
            "Untried candidates:",
            ", ".join(untried_actions[:12]) if untried_actions else "(none)",
            "",
            "Recent outcomes in this location:",
            "\n".join(log_lines) if log_lines else "- none yet",
            "",
            "Current observation:",
            observation[:1500],
            "",
            "Memory snapshot:",
            memory_text[:900] if memory_text else "(memory unavailable)",
        ]

        if map_snapshot:
            parts.extend(["", "Map snapshot:", map_snapshot[:900]])

        parts.extend(
            [
                "",
                "Pick exactly one next command.",
                "If HIGH_EXPLORATION_BIAS is yes, avoid staying in place.",
            ]
        )
        return "\n".join(parts)
|
|
    def _select_actions_for_llm(self, location_key: str, valid_actions: list[str]) -> list[str]:
        """Select an informative subset of valid actions instead of first-N truncation.

        Buckets actions by kind (untried specific interactions, movement,
        generic, repeatedly-failed), ranks the buckets according to the
        current stuck level, then picks up to the target count while first
        maximizing verb diversity.
        """
        # Normalize and deduplicate while preserving order.
        normalized: list[str] = []
        seen: set[str] = set()
        for raw in valid_actions:
            action = self._normalize_action(raw)
            if not action or action in seen:
                continue
            seen.add(action)
            normalized.append(action)

        if not normalized:
            return DEFAULT_ACTIONS.copy()
        if len(normalized) <= PROMPT_ACTIONS_MAX:
            return normalized

        # Clamp the selection size to [PROMPT_ACTIONS_MIN, PROMPT_ACTIONS_MAX].
        target = min(PROMPT_ACTIONS_TARGET, len(normalized))
        target = max(min(target, PROMPT_ACTIONS_MAX), PROMPT_ACTIONS_MIN)

        # Count attempts and textual failures per action from the local log.
        local_log = self.location_action_log.get(location_key, [])
        tried_counts: dict[str, int] = {}
        failed_counts: dict[str, int] = {}
        fail_markers = [
            "can't",
            "cannot",
            "nothing happens",
            "no effect",
            "not possible",
            "not here",
            "unknown",
            "blocked",
            "empty",
        ]
        for entry in local_log:
            action = self._normalize_action(str(entry.get("action", "")))
            if not action:
                continue
            tried_counts[action] = tried_counts.get(action, 0) + 1

            # A non-scoring step whose summary sounds like a failure counts as one.
            summary = str(entry.get("summary", "")).lower()
            delta = int(entry.get("score_delta", 0))
            if delta <= 0 and any(marker in summary for marker in fail_markers):
                failed_counts[action] = failed_counts.get(action, 0) + 1

        interaction_verbs = {"examine", "open", "read", "take", "push", "pull", "put", "turn", "attack"}
        generic_actions = {"look", "inventory"}
        stuck = self.steps_in_current_location >= STUCK_THRESHOLD

        # Partition actions into priority buckets.
        specific_bucket: list[str] = []
        movement_bucket: list[str] = []
        remaining_bucket: list[str] = []
        generic_bucket: list[str] = []
        blocked_bucket: list[str] = []

        for action in normalized:
            verb = action.split()[0]
            # Actions that failed twice or more here go to the back of the line.
            blocked = failed_counts.get(action, 0) >= 2
            if blocked:
                blocked_bucket.append(action)
                continue

            # Untried multi-word interactions are the most promising leads.
            if tried_counts.get(action, 0) == 0 and len(action.split()) >= 2 and verb in interaction_verbs:
                specific_bucket.append(action)
                continue

            if verb in MOVEMENT_ACTIONS:
                movement_bucket.append(action)
                continue

            if action in generic_actions:
                generic_bucket.append(action)
                continue

            remaining_bucket.append(action)

        # Within each bucket: least-tried, least-failed, then alphabetical.
        def ordered(items: list[str]) -> list[str]:
            return sorted(items, key=lambda a: (tried_counts.get(a, 0), failed_counts.get(a, 0), a))

        specific_bucket = ordered(specific_bucket)
        movement_bucket = ordered(movement_bucket)
        remaining_bucket = ordered(remaining_bucket)
        generic_bucket = ordered(generic_bucket)
        blocked_bucket = ordered(blocked_bucket)

        # When stuck, movement options rank first; otherwise interactions do.
        if stuck:
            ranked = movement_bucket + specific_bucket + remaining_bucket + generic_bucket + blocked_bucket
        else:
            ranked = specific_bucket + movement_bucket + remaining_bucket + generic_bucket + blocked_bucket

        selected: list[str] = []
        used_verbs: set[str] = set()

        def add(action: str) -> None:
            if action and action not in selected:
                selected.append(action)

        # First pass: at most one action per verb, for breadth.
        for action in ranked:
            verb = action.split()[0]
            if verb in used_verbs:
                continue
            add(action)
            used_verbs.add(verb)
            if len(selected) >= target:
                return selected[:target]

        # Second pass: fill the remaining slots in ranked order.
        for action in ranked:
            add(action)
            if len(selected) >= target:
                return selected[:target]

        return selected[:target]
|
|
| def _parse_action_response(self, response: str) -> tuple[str, str]: |
| """Parse LLM response for THOUGHT and ACTION.""" |
| thought = "No reasoning provided." |
| action = "look" |
|
|
| for line in response.strip().split("\n"): |
| clean = line.strip() |
| upper = clean.upper() |
| if upper.startswith("THOUGHT:"): |
| thought = clean.split(":", 1)[1].strip() |
| elif upper.startswith("ACTION:"): |
| action = clean.split(":", 1)[1].strip() |
|
|
| if not action or action == "look": |
| payload = self._parse_json_payload(response) |
| if isinstance(payload, dict): |
| if "thought" in payload: |
| thought = str(payload.get("thought", thought)) |
| if "action" in payload: |
| action = str(payload.get("action", action)) |
|
|
| return thought, action |
|
|
| def _validate_action(self, action: str, valid_actions: list[str]) -> str: |
| """Normalize action text and repair common invalid verbs.""" |
| action = self._normalize_action(action) |
| if not action: |
| return "look" |
|
|
| invalid_verb_map = { |
| "check": "examine", |
| "inspect": "examine", |
| "search": "look", |
| "grab": "take", |
| "pick": "take", |
| "use": "examine", |
| "investigate": "examine", |
| } |
|
|
| words = action.split() |
| if words and words[0] in invalid_verb_map: |
| words[0] = invalid_verb_map[words[0]] |
| action = " ".join(words) |
|
|
| if self._is_dangerous_action(action): |
| return self._fallback_safe_action(valid_actions) |
|
|
| if valid_actions: |
| normalized_valid = [self._normalize_action(v) for v in valid_actions] |
| if action in normalized_valid: |
| return action |
|
|
| |
| verb = action.split()[0] |
| if any(v.startswith(verb + " ") or v == verb for v in normalized_valid): |
| return action |
|
|
| |
| noisy = len(action.split()) > 8 or "?" in action |
| if noisy: |
| untried = self._get_untried_actions(self.current_location_key, normalized_valid) |
| return untried[0] if untried else normalized_valid[0] |
|
|
| return action |
|
|
| def _is_dangerous_action(self, action: str) -> bool: |
| """Block a small set of obviously risky commands.""" |
| action = self._normalize_action(action) |
| if not action: |
| return False |
| if action in DANGEROUS_ACTIONS: |
| return True |
| return any(action == prefix or action.startswith(prefix + " ") for prefix in DANGEROUS_ACTION_PREFIXES) |
|
|
| def _fallback_safe_action(self, valid_actions: list[str]) -> str: |
| """Choose a conservative fallback when an action is blocked.""" |
| normalized_valid = [ |
| self._normalize_action(v) |
| for v in valid_actions |
| if self._normalize_action(v) and not self._is_dangerous_action(self._normalize_action(v)) |
| ] |
| untried = self._get_untried_actions(self.current_location_key, normalized_valid) |
| if untried: |
| return untried[0] |
| if normalized_valid: |
| return normalized_valid[0] |
| return "look" |
|
|
| def _apply_exploration_bias( |
| self, |
| action: str, |
| location_key: str, |
| valid_actions: list[str], |
| ) -> str: |
| """ |
| Force exploration when stagnating in the same room. |
| """ |
| if self.steps_in_current_location < STUCK_THRESHOLD: |
| return action |
|
|
| tried_counts: dict[str, int] = {} |
| for entry in self.location_action_log.get(location_key, []): |
| tried_action = entry.get("action", "") |
| tried_counts[tried_action] = tried_counts.get(tried_action, 0) + 1 |
|
|
| if tried_counts.get(action, 0) >= 2: |
| untried_movement = [ |
| a for a in self._get_untried_actions(location_key, valid_actions) |
| if a.split()[0] in MOVEMENT_ACTIONS |
| ] |
| if untried_movement: |
| return untried_movement[0] |
|
|
| untried_any = self._get_untried_actions(location_key, valid_actions) |
| if untried_any: |
| return untried_any[0] |
|
|
| return action |
|
|
| def _avoid_risky_action( |
| self, |
| action: str, |
| valid_actions: list[str], |
| lookahead: Optional[dict[str, Any]], |
| ) -> str: |
| """Replace only actions flagged as lethal/game-over by lookahead.""" |
| if not lookahead or not lookahead.get("enabled"): |
| return action |
|
|
| risky_set = { |
| self._normalize_action(a) for a in lookahead.get("risky_actions", []) |
| if self._normalize_action(a) |
| } |
| action = self._normalize_action(action) |
| if not action or action not in risky_set: |
| return action |
|
|
| normalized_valid: list[str] = [] |
| seen: set[str] = set() |
| for raw in valid_actions: |
| candidate = self._normalize_action(raw) |
| if candidate and candidate not in seen and candidate not in risky_set: |
| seen.add(candidate) |
| normalized_valid.append(candidate) |
|
|
| untried_valid = self._get_untried_actions(self.current_location_key, normalized_valid) |
| if untried_valid: |
| return untried_valid[0] |
| if normalized_valid: |
| return normalized_valid[0] |
|
|
| return "look" |
|
|
| def _get_untried_actions(self, location_key: str, candidates: list[str]) -> list[str]: |
| """Return candidate actions not yet tried in this location.""" |
| tried = {entry.get("action", "") for entry in self.location_action_log.get(location_key, [])} |
| untried: list[str] = [] |
| for action in candidates: |
| norm = self._normalize_action(action) |
| if norm and norm not in tried and norm not in untried: |
| untried.append(norm) |
| return untried |
|
|
| def _parse_json_payload(self, text: str): |
| """Parse JSON from plain text or fenced block.""" |
| if not text: |
| return None |
|
|
| cleaned = text.strip() |
| if cleaned.startswith("```"): |
| cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", cleaned, flags=re.DOTALL).strip() |
|
|
| try: |
| return json.loads(cleaned) |
| except json.JSONDecodeError: |
| pass |
|
|
| obj_match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL) |
| if obj_match: |
| try: |
| return json.loads(obj_match.group(0)) |
| except json.JSONDecodeError: |
| pass |
|
|
| arr_match = re.search(r"\[.*\]", cleaned, flags=re.DOTALL) |
| if arr_match: |
| try: |
| return json.loads(arr_match.group(0)) |
| except json.JSONDecodeError: |
| pass |
|
|
| return None |
|
|
| def _parse_memory(self, memory_text: str) -> dict: |
| """Parse key fields from memory() output.""" |
| parsed = { |
| "display_location": "Unknown", |
| "jericho_location": "Unknown", |
| "score": None, |
| "max_score": None, |
| "moves": None, |
| } |
| if not memory_text: |
| return parsed |
|
|
| display_match = re.search(r"-\s*(?:Display\s+)?Location:\s*(.+)", memory_text) |
| jericho_match = re.search(r"-\s*Jericho\s+Location:\s*(.+)", memory_text) |
| score_match = re.search(r"-\s*Score:\s*(\d+)", memory_text) |
| max_score_match = re.search(r"-\s*Max\s+Score:\s*(\d+)", memory_text) |
| moves_match = re.search(r"-\s*Moves:\s*(\d+)", memory_text) |
|
|
| if display_match: |
| parsed["display_location"] = display_match.group(1).strip() |
| if jericho_match: |
| parsed["jericho_location"] = jericho_match.group(1).strip() |
| if score_match: |
| parsed["score"] = int(score_match.group(1)) |
| if max_score_match: |
| parsed["max_score"] = int(max_score_match.group(1)) |
| if moves_match: |
| parsed["moves"] = int(moves_match.group(1)) |
|
|
| return parsed |
|
|
| def _extract_result(self, result) -> str: |
| """Extract text from MCP tool result.""" |
| if hasattr(result, "content") and result.content: |
| return result.content[0].text |
| if isinstance(result, list) and result: |
| return result[0].text if hasattr(result[0], "text") else str(result[0]) |
| return str(result) |
|
|
| def _normalize_action(self, action: str) -> str: |
| """Normalize action formatting.""" |
| if not action: |
| return "" |
| action = action.lower().strip() |
| action = action.replace("**", "").replace("*", "").replace("`", "") |
| action = re.sub(r"\s+", " ", action) |
| return action |
|
|
| def _normalize_location(self, location: str) -> str: |
| """Normalize location keys.""" |
| if not location: |
| return "unknown" |
| value = location.strip().lower() |
| value = value.replace("`", "") |
|
|
| |
| |
| if re.match(r"^obj\d+\s*:", value): |
| value = value.split(":", 1)[1].strip() |
| value = re.sub(r"^obj\d+\b[:\s-]*", "", value).strip() |
| value = re.split(r"\s+parent\d+\b", value)[0].strip() |
| value = re.split(r"\s+sibling\d+\b", value)[0].strip() |
| value = re.split(r"\s+child\d+\b", value)[0].strip() |
| value = re.split(r"\s+attributes?\b", value)[0].strip() |
| value = re.split(r"\s+properties?\b", value)[0].strip() |
| value = re.sub(r"[\[\]\(\)\{\}]", " ", value) |
| value = re.sub(r"\s+", " ", value).strip() |
|
|
| if value in {"unknown", "none", "null", "n/a", ""}: |
| return "unknown" |
| return value |
|
|
| def _location_key(self, display_location: str, jericho_location: str) -> str: |
| """Prefer Jericho location as stable room identity.""" |
| jericho = self._normalize_location(jericho_location) |
| if jericho != "unknown": |
| return jericho |
| return self._normalize_location(display_location) |
|
|
| def _humanize_location(self, location: str) -> str: |
| """Convert normalized location keys to readable labels for logs.""" |
| normalized = self._normalize_location(location) |
| if normalized == "unknown": |
| return "Unknown" |
| return " ".join(word.capitalize() for word in normalized.split()) |
|
|
| def _extract_score_from_text(self, text: str) -> Optional[int]: |
| """Extract score from various text formats.""" |
| patterns = [ |
| r"\[Score:\s*(\d+)\s*\|", |
| r"\(Total:\s*(\d+)\)", |
| r"(?:^|\n)\s*-\s*Score:\s*(\d+)\b", |
| r"(?:^|\n)\s*Score:\s*(\d+)\b", |
| r"\bscore\s+is\s+(\d+)\b", |
| ] |
| for pattern in patterns: |
| match = re.search(pattern, text, flags=re.IGNORECASE) |
| if match: |
| return int(match.group(1)) |
| return None |
|
|
| def _update_score(self, text: str) -> None: |
| """Update current score from text when available.""" |
| score = self._extract_score_from_text(text) |
| if score is not None: |
| self.score = score |
|
|
| def _update_max_score(self, text: str) -> None: |
| """Update max score from text when available.""" |
| match = re.search(r"-\s*Max\s+Score:\s*(\d+)", text, flags=re.IGNORECASE) |
| if match: |
| self.max_score = int(match.group(1)) |
|
|
| def _is_game_over(self, text: str) -> bool: |
| """Check if the game is over.""" |
| game_over_phrases = [ |
| "game over", |
| "you have died", |
| "you are dead", |
| "*** you have died ***", |
| "*** the end ***", |
| ] |
| lowered = text.lower() |
| return any(phrase in lowered for phrase in game_over_phrases) |
|
|
|
|
async def test_agent():
    """Smoke-test the agent locally against the MCP server script."""
    from fastmcp import Client

    agent = StudentAgent()

    # Connect, run a short 20-step episode, then report the outcome.
    async with Client("mcp_server.py") as client:
        result = await agent.run(
            client=client,
            game="zork1",
            max_steps=20,
            seed=42,
            verbose=True,
        )

    print(f"\n{'=' * 50}")
    print(f"Final Score: {result.final_score}/{result.max_score}")
    print(f"Moves: {result.moves}")
    print(f"Locations: {len(result.locations_visited)}")
|
|
|
|
if __name__ == "__main__":
    # Local entry point: run the async smoke test defined above.
    import asyncio


    asyncio.run(test_agent())
|
|