"""
Example: MCP ReAct Agent

A complete ReAct agent that uses MCP tools to play text adventure games.
This is a working example students can learn from.
"""

import json
import os
import re
from dataclasses import dataclass, field
from typing import Optional

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
import hashlib
from collections import defaultdict

load_dotenv()

# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================

LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"

_hf_token = os.getenv("HF_TOKEN")
if not _hf_token:
    raise ValueError("HF_TOKEN not found. Set it in your .env file.")

LLM_CLIENT = InferenceClient(token=_hf_token)


def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
    """Call the LLM with the given prompt.

    Args:
        prompt: The user-turn content.
        system_prompt: The system-turn content.
        seed: Sampling seed passed to the API (temperature is fixed at 0.0,
            so the seed mainly guards against provider-side nondeterminism).
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The assistant message content as a string.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )
    return response.choices[0].message.content


@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class."""
    final_score: int                    # best score observed in game text
    max_score: int                      # maximum achievable score for the game
    moves: int                          # number of play_action calls issued
    locations_visited: set[str]         # first line of each distinct observation
    game_completed: bool                # True if a game-over phrase was seen
    error: Optional[str] = None
    history: list[tuple[str, str, str]] = field(default_factory=list)


# =============================================================================
# System Prompt
# =============================================================================

# NOTE(review): the `<...>` placeholders below were stripped by whitespace/markup
# mangling in the recovered source; they are reconstructed here — confirm against
# the original prompt text.
SYSTEM_PROMPT = """You are an expert text adventure game player.
Your goal is to explore, collect treasures, and maximize your score.

AVAILABLE TOOLS (use these via MCP):
1. play_action - Execute game commands (north, take lamp, open mailbox, etc.)
2. memory - Get current game state, score, and recent history
3. get_map - See explored locations and connections
4. inventory - Check what you're carrying

VALID GAME COMMANDS for play_action:
- Movement: north, south, east, west, up, down, enter, exit
- Objects: take <item>, drop <item>, open <container>, close <container>, examine <object>
- Light: turn on lamp, turn off lamp
- Combat: attack <enemy> with <weapon>
- Other: inventory, look, read <item>, wait

FORBIDDEN (will NOT work): check, inspect, search, grab, use, help

RESPOND IN THIS EXACT FORMAT (no markdown):
THOUGHT: <your reasoning>
TOOL: <tool name>
ARGS: <json arguments>

Examples:
THOUGHT: I need to see what's around me.
TOOL: play_action
ARGS: {"action": "look"}

THOUGHT: Let me check my current state and score.
TOOL: memory
ARGS: {}

THOUGHT: The mailbox might contain something useful.
TOOL: play_action
ARGS: {"action": "open mailbox"}

STRATEGY:
1. Start by looking around and checking memory
2. Explore systematically - try all directions
3. Pick up useful items (lamp, sword, etc.)
4. Open containers (mailbox, window, etc.)
5. Use get_map to avoid getting lost
6. Turn on lamp before dark areas!

DO NOT repeat the same action multiple times in a row."""


# =============================================================================
# Student Agent Implementation
# =============================================================================

class StudentAgent:
    """
    MCP ReAct Agent - A complete working example.

    This agent demonstrates:
    - ReAct loop (Thought -> Tool -> Observation)
    - Loop detection
    - Action validation
    - Score tracking via memory tool
    """

    def __init__(self):
        """Initialize the agent state."""
        self.history: list[dict] = []           # rolling window of step records (max 10)
        self.recent_actions: list[str] = []     # last few play_action commands (max 5 + forced looks)
        self.score: int = 0                     # best score parsed from game text so far

        # --- Context management memory ---
        # Keyed by (state_id, inv_sig): a room fingerprint plus an inventory
        # fingerprint, so bans are scoped to "this place with these items".
        self.failed_strong = defaultdict(set)   # actions that are nonsense here
        self.failed_soft = defaultdict(dict)    # action -> last_step tried (cooldown)
        self.state_last_obs = {}                # (state_id, inv_sig) -> normalized obs
        self.inv_sig: str = ""                  # current inventory signature
        self.prev_inv_sig: str = ""             # previous signature to detect changes
        self.step: int = 0                      # current step counter
        # Whether to include context management info in the prompt (for transparency)
        self.debug_context: bool = True
        self.soft_cooldown_steps: int = 30      # steps to wait before retrying a soft-failed action

        # -- LLM judge
        self.judge_cache: dict[tuple[str, str, str], str] = {}  # (prev, action, new) -> verdict
        self.use_llm_judge: bool = True
        # Last-judge bookkeeping, used only for debug logging.
        self._last_llm_raw: str = ""
        self._last_llm_label: str = ""
        self._last_llm_cached: bool = False

    async def run(
        self,
        client,
        game: str,
        max_steps: int,
        seed: int,
        verbose: bool = False,
    ) -> RunResult:
        """Run the agent for a game session.

        Args:
            client: An MCP client exposing ``list_tools`` and ``call_tool``.
            game: Game identifier (not used directly here; the MCP server
                decides which game is running).
            max_steps: Maximum number of ReAct iterations.
            seed: Base seed; ``seed + step`` is passed to the LLM each step.
            verbose: If True, print the full thought/tool/result trace.

        Returns:
            A RunResult summarizing score, moves, and visited locations.
        """
        locations_visited = set()
        history = []
        moves = 0

        # Get list of available tools
        tools = await client.list_tools()
        tool_names = [t.name for t in tools]

        # Get initial observation
        result = await client.call_tool("play_action", {"action": "look"})
        observation = self._extract_result(result)

        # Initialize inventory signature
        inv_res = await client.call_tool("inventory", {})
        inv_text = self._extract_result(inv_res)
        self.inv_sig = self._inventory_signature(inv_text)
        self.prev_inv_sig = self.inv_sig

        # Track initial location (first line of the observation is assumed to
        # be the room name -- TODO confirm against the MCP server output format)
        location = observation.split("\n")[0] if observation else "Unknown"
        locations_visited.add(location)

        if verbose:
            print(f"\n{observation}")

        # Main ReAct loop
        for step in range(1, max_steps + 1):
            self.step = step

            # Refresh inventory periodically (cheap and very useful for gating)
            if step == 1 or step % 7 == 0:
                inv_res = await client.call_tool("inventory", {})
                inv_text = self._extract_result(inv_res)
                self.inv_sig = self._inventory_signature(inv_text)
                # If inventory changed, we want to allow retry of gated actions everywhere
                if self.inv_sig != self.prev_inv_sig:
                    # Clear ALL soft failures (gated actions may now be valid)
                    self.failed_soft.clear()
                    self.prev_inv_sig = self.inv_sig

            # Build prompt with context
            prompt = self._build_prompt(observation)

            # Call LLM for reasoning (use step-based seed for variety)
            response = call_llm(prompt, SYSTEM_PROMPT, seed + step)

            # Parse the response
            thought, tool_name, tool_args = self._parse_response(response, tool_names)

            if verbose:
                print(f"\n--- Step {step} ---")
                print(f"[THOUGHT] {thought}")
                print(f"[TOOL] {tool_name}({tool_args})")

            # Validate and fix common issues, then apply context bans.
            tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names)
            tool_name, tool_args = self._apply_context_management(tool_name, tool_args, observation)

            # Loop detection
            if tool_name == "play_action":
                action = tool_args.get("action", "look")
                self.recent_actions.append(action)
                if len(self.recent_actions) > 5:
                    self.recent_actions = self.recent_actions[-5:]

                # Detect loops - if same action 3 times, force "look"
                if len(self.recent_actions) >= 3 and len(set(self.recent_actions[-3:])) == 1:
                    if verbose:
                        print(f"[WARNING] Loop detected - forcing 'look'")
                    tool_args = {"action": "look"}
                    self.recent_actions.append("look")

                moves += 1

            # Execute the tool
            try:
                prev_observation = observation  # keep previous for failure detection
                result = await client.call_tool(tool_name, tool_args)
                observation = self._extract_result(result)

                # Update failure memory only for play_action
                if tool_name == "play_action":
                    action = tool_args.get("action", "look")
                    self._update_failure_memory(prev_observation, action, observation)
                    self._log_context_state(prev_observation, action, new_observation=observation)

                if verbose:
                    print(f"[RESULT] {observation[:200]}...")
            except Exception as e:
                # Feed the error text back as the observation so the LLM can react.
                observation = f"Error: {e}"
                if verbose:
                    print(f"[ERROR] {e}")

            # Track location
            location = observation.split("\n")[0] if observation else "Unknown"
            locations_visited.add(location)

            # Update history (rolling window of 10 entries)
            self.history.append({
                "step": step,
                "thought": thought,
                "tool": tool_name,
                "args": tool_args,
                "result": observation[:200]
            })
            if len(self.history) > 10:
                self.history = self.history[-10:]

            # Track score from observation
            self._update_score(observation)

            # Record in result history
            history.append((thought, f"{tool_name}({tool_args})", observation[:100]))

            # Check for game over
            if self._is_game_over(observation):
                if verbose:
                    print("\n*** GAME OVER ***")
                break

        return RunResult(
            final_score=self.score,
            max_score=350,
            moves=moves,
            locations_visited=locations_visited,
            game_completed=self._is_game_over(observation),
            history=history,
        )

    def _build_prompt(self, observation: str) -> str:
        """Build the user prompt for the LLM: score, recent history,
        repetition warnings, per-context banned actions, and the current
        observation."""
        parts = []
        parts.append(f"Current Score: {self.score}")

        # Recent history
        if self.history:
            parts.append("\nRecent actions:")
            for entry in self.history[-8:]:
                action = entry.get("args", {}).get("action", entry["tool"])
                result_short = entry["result"][:80] + "..." if len(entry["result"]) > 80 else entry["result"]
                parts.append(f" > {action} -> {result_short}")

        # Warn about repeated actions
        if self.recent_actions and len(set(self.recent_actions[-3:])) == 1:
            parts.append(f"\n[WARNING: You've been doing '{self.recent_actions[-1]}' repeatedly. TRY SOMETHING DIFFERENT!]")

        # Add context constraints to reduce repetition
        sid = self._state_id_from_observation(observation)
        key = (sid, self.inv_sig)
        forbidden = list(self.failed_strong[key])[:8]
        soft_forbidden = list(self.failed_soft[key].keys())[:8]
        if forbidden or soft_forbidden:
            parts.append("\nContext restrictions (DO NOT choose these actions here):")
            if forbidden:
                parts.append(" Strong banned: " + ", ".join(forbidden))
            if soft_forbidden:
                parts.append(" Recently failed (wait before retry): " + ", ".join(soft_forbidden))

        parts.append(f"\nCurrent situation:\n{observation}")
        parts.append("\nWhat do you do next?")

        return "\n".join(parts)

    def _parse_response(self, response: str, valid_tools: list[str]) -> tuple[str, str, dict]:
        """Parse the LLM response to extract thought, tool, and arguments.

        Falls back to ``play_action``/``look`` for any field that cannot be
        parsed, so the loop never stalls on a malformed response.
        """
        thought = "No reasoning provided"
        tool_name = "play_action"
        tool_args = {"action": "look"}

        lines = response.strip().split("\n")
        for line in lines:
            line_clean = line.strip()
            line_upper = line_clean.upper()

            if line_upper.startswith("THOUGHT:"):
                thought = line_clean.split(":", 1)[1].strip()
            elif line_upper.startswith("TOOL:"):
                raw_tool = line_clean.split(":", 1)[1].strip().lower()
                # Strip markdown emphasis the model sometimes adds.
                raw_tool = raw_tool.replace("**", "").replace("*", "").replace("`", "")
                raw_tool = raw_tool.split()[0] if raw_tool else "play_action"
                tool_name = raw_tool
            elif line_upper.startswith("ARGS:"):
                args_part = line_clean.split(":", 1)[1].strip()
                try:
                    # NOTE(review): naive quote swap; breaks JSON values that
                    # contain apostrophes, though the regex fallback below
                    # usually recovers the action.
                    args_part = args_part.replace("'", '"')
                    tool_args = json.loads(args_part)
                except json.JSONDecodeError:
                    match = re.search(r'"action"\s*:\s*"([^"]+)"', args_part)
                    if match:
                        tool_args = {"action": match.group(1)}
                    else:
                        tool_args = {"action": "look"}

        return thought, tool_name, tool_args

    def _validate_tool_call(self, tool_name: str, tool_args: dict, valid_tools: list[str]) -> tuple[str, dict]:
        """Validate and fix common tool call issues (wrong tool aliases,
        forbidden verbs, markdown residue, extra whitespace)."""
        # Fix tool name
        if tool_name not in valid_tools:
            if tool_name in ["action", "do", "command"]:
                tool_name = "play_action"
            elif tool_name in ["map", "location"]:
                tool_name = "get_map"
            elif tool_name in ["mem", "state", "status"]:
                tool_name = "memory"
            elif tool_name in ["inv", "items"]:
                tool_name = "inventory"
            else:
                tool_name = "play_action"

        # Fix action verbs
        if tool_name == "play_action":
            action = tool_args.get("action", "look")
            # Map verbs the game parser rejects onto accepted equivalents.
            invalid_verb_map = {
                "check": "examine",
                "inspect": "examine",
                "search": "look",
                "grab": "take",
                "pick": "take",
                "use": "examine",
                "investigate": "examine",
            }
            words = action.lower().split()
            if words and words[0] in invalid_verb_map:
                words[0] = invalid_verb_map[words[0]]
                action = " ".join(words)
            action = action.lower().strip()
            action = action.replace("**", "").replace("*", "").replace("`", "")
            action = " ".join(action.split())
            tool_args["action"] = action

        return tool_name, tool_args

    def _extract_result(self, result) -> str:
        """Extract text from an MCP tool result (object with ``.content``,
        list of content items, or anything else stringified)."""
        if hasattr(result, 'content') and result.content:
            return result.content[0].text
        if isinstance(result, list) and result:
            return result[0].text if hasattr(result[0], 'text') else str(result[0])
        return str(result)

    def _update_score(self, text: str) -> None:
        """Update ``self.score`` from game text; keeps the maximum ever seen
        so transient screens without a score line cannot lower it."""
        patterns = [
            r'Score:\s*(\d+)',
            r'score[:\s]+(\d+)',
            r'\[Score:\s*(\d+)',
        ]
        for pattern in patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                self.score = max(self.score, int(match.group(1)))

    def _is_game_over(self, text: str) -> bool:
        """Check if the game is over (death/game-over phrases only; victory
        phrasing is not detected here)."""
        game_over_phrases = [
            "game over",
            "you have died",
            "you are dead",
            "*** you have died ***",
        ]
        text_lower = text.lower()
        return any(phrase in text_lower for phrase in game_over_phrases)

    def _normalize_text(self, text: str) -> str:
        """Normalize text for comparison (remove numbers/punct, collapse spaces).

        Digits are all mapped to '0' so differing counters (score, turns)
        don't make otherwise-identical observations look different.
        """
        s = (text or "").lower()
        s = re.sub(r"\d+", "0", s)
        s = re.sub(r"[^a-z0\s]", " ", s)
        s = re.sub(r"\s+", " ", s).strip()
        return s[:500]

    def _state_id_from_observation(self, observation: str) -> str:
        """Compute a stable-ish state id for the current room/context.

        Only the first non-empty line (the room title) is hashed, so minor
        description changes do not create a new state bucket.
        """
        lines = [l.strip() for l in (observation or "").splitlines() if l.strip()]
        head = lines[0].lower() if lines else "unknown"
        # hash helps reduce tiny variations
        norm = self._normalize_text(head)
        return hashlib.md5(norm.encode()).hexdigest()

    def _inventory_signature(self, inv_text: str) -> str:
        """Signature of inventory text; robust enough for gating retries."""
        norm = self._normalize_text(inv_text)
        return hashlib.md5(norm.encode()).hexdigest()

    def _classify_failure(self, prev_obs: str, action: str, new_obs: str) -> str:
        """Classify whether an action failed.

        Return: "none" | "strong" | "soft"
        strong = will never become valid (parser/no object)
        soft   = gated (locked/dark/tool) -> retry when inventory/state changes
        """
        # Optional fast rule: identical observation => likely failed (soft)
        if self._normalize_text(new_obs) == self._normalize_text(prev_obs):
            return "soft"
        if not getattr(self, "use_llm_judge", True):
            return "none"
        return self._llm_judge_failure(prev_obs, action, new_obs)

    def _llm_judge_failure(self, prev_obs: str, action: str, new_obs: str) -> str:
        """
        LLM judge: returns "none" | "soft" | "strong"

        Results are cached on normalized (prev, action, new) triples so the
        same transition is never judged twice.
        """
        # Normalize for caching (prevents repeated LLM calls)
        prev_n = self._normalize_text(prev_obs)
        new_n = self._normalize_text(new_obs)
        act_n = " ".join((action or "").lower().split())
        key = (prev_n, act_n, new_n)
        if key in self.judge_cache:
            self._last_llm_raw = "(cached)"
            self._last_llm_label = self.judge_cache[key]
            self._last_llm_cached = True
            return self.judge_cache[key]

        system = "You are a strict classifier for text-adventure command outcomes."
        prompt = f"""
Classify whether the player's action FAILED, based on the before/after observations.
Return EXACTLY one label: none | soft | strong

Definitions:
- strong: The command is invalid/unknown OR refers to something not present/visible OR impossible in principle. It will NOT become valid later just by having a different item.
- soft: The command was understood but is blocked by a condition (locked, closed, too dark, need an item, must do something first, not possible yet). It COULD become valid later.
- none: The action had an effect OR gave new useful information (state changed, moved, item changed, new description).

BE STRICT: If the new observation is basically identical to the previous one AND no progress happened, prefer "soft".

PREVIOUS_OBSERVATION:
{prev_obs}

ACTION:
{action}

NEW_OBSERVATION:
{new_obs}
""".strip()

        out = call_llm(prompt, system, seed=100000 + self.step, max_tokens=8)
        raw_out = out.strip().lower()

        # Substring checks first, since the model may wrap the label in
        # extra tokens; "strong" before "soft" so a combined answer errs strict.
        label = "none"
        if "strong" in raw_out:
            label = "strong"
        elif "soft" in raw_out:
            label = "soft"
        elif raw_out in {"none", "soft", "strong"}:
            label = raw_out

        # Store for logging
        self._last_llm_raw = raw_out
        self._last_llm_label = label
        self._last_llm_cached = False

        self.judge_cache[key] = label
        return label

    def _apply_context_management(
        self,
        tool_name: str,
        tool_args: dict,
        observation: str,
    ) -> tuple[str, dict]:
        """Prevent repeating failed actions in the same context (state + inv).

        Strong-failed actions are replaced by a fallback; soft-failed actions
        are replaced only while their cooldown is active. Immediate repeats
        of the last action are also redirected.
        """
        if tool_name != "play_action":
            return tool_name, tool_args

        action = (tool_args.get("action") or "look").strip().lower()
        sid = self._state_id_from_observation(observation)
        key = (sid, self.inv_sig)

        # Strong blacklist
        if action in self.failed_strong[key]:
            return "play_action", {"action": self._fallback_action(observation)}

        # Soft blacklist with cooldown
        if action in self.failed_soft[key]:
            last = self.failed_soft[key][action]
            # cooldown: avoid retrying too soon
            if self.step - last < self.soft_cooldown_steps:
                return "play_action", {"action": self._fallback_action(observation)}

        # Prevent immediate repetition in same context
        if self.recent_actions and action == self.recent_actions[-1]:
            return "play_action", {"action": self._fallback_action(observation)}

        return tool_name, {"action": action}

    def _fallback_action(self, observation: str) -> str:
        """
        Deterministic fallback when the chosen action is banned.
        Prefer exploration moves; otherwise look/inventory.
        """
        # Prefer moves that haven't been tried recently
        move_candidates = ["north", "south", "east", "west", "up", "down", "n", "s", "e", "w", "u", "d"]
        recent_set = set(self.recent_actions[-8:]) if self.recent_actions else set()
        for m in move_candidates:
            if m not in recent_set:
                return m
        # If stuck, refresh
        if "dark" in (observation or "").lower():
            # lamp heuristic: often useful in Zork
            return "turn on lamp"
        return "look"

    def _update_failure_memory(self, prev_obs: str, action: str, new_obs: str) -> None:
        """Update strong/soft failed actions for this (state, inv) context.

        The bucket is keyed on the *previous* observation's state id, i.e.
        the room where the action was attempted.
        """
        sid = self._state_id_from_observation(prev_obs)
        key = (sid, self.inv_sig)
        verdict = self._classify_failure(prev_obs, action, new_obs)
        if verdict == "strong":
            self.failed_strong[key].add(action)
            # also remove from soft if present
            if action in self.failed_soft[key]:
                del self.failed_soft[key][action]
        elif verdict == "soft":
            self.failed_soft[key][action] = self.step

    def _log_context_state(self, prev_observation: str, chosen_action: str = "", new_observation: str = ""):
        """Print debug info for context management (before-state bucket).

        No-op unless ``self.debug_context`` is True.
        """
        if not self.debug_context:
            return

        sid_before = self._state_id_from_observation(prev_observation)
        sid_after = self._state_id_from_observation(new_observation) if new_observation else ""
        key_before = (sid_before, self.inv_sig)

        print("\n" + "=" * 60)
        print(f"[STEP] {self.step}")
        print(f"[STATE_ID_BEFORE] {sid_before}")
        if sid_after:
            print(f"[STATE_ID_AFTER] {sid_after}")
        print(f"[INV_SIG] {self.inv_sig[:8]}...")
        if chosen_action:
            print(f"[CHOSEN ACTION] {chosen_action}")

        # ---- LLM Judge Info ----
        if getattr(self, "_last_llm_label", ""):
            print(f"[LLM_VERDICT] {self._last_llm_label}")
            print(f"[LLM_RAW_OUTPUT] {getattr(self, '_last_llm_raw', '')}")
            print(f"[LLM_FROM_CACHE] {getattr(self, '_last_llm_cached', False)}")

        # Strong failures (BEFORE bucket)
        strong = list(self.failed_strong[key_before])
        print(f"[FAILED_STRONG_BEFORE] ({len(strong)})")
        for a in strong[:10]:
            print(f" - {a}")

        # Soft failures (BEFORE bucket)
        soft = self.failed_soft[key_before]
        print(f"[FAILED_SOFT_BEFORE] ({len(soft)})")
        for a, last_step in list(soft.items())[:10]:
            cooldown_left = max(0, self.soft_cooldown_steps - (self.step - last_step))
            print(f" - {a} (retry in {cooldown_left} steps)")

        print(f"[RECENT_ACTIONS] {self.recent_actions[-10:]}")
        print("=" * 60 + "\n")


# =============================================================================
# Local Testing
# =============================================================================

async def test_agent():
    """Test the agent locally against the bundled MCP server."""
    from fastmcp import Client

    agent = StudentAgent()

    async with Client("mcp_server.py") as client:
        result = await agent.run(
            client=client,
            game="zork1",
            max_steps=20,
            seed=42,
            verbose=True,
        )

    print(f"\n{'=' * 50}")
    print(f"Final Score: {result.final_score}")
    print(f"Moves: {result.moves}")
    print(f"Locations: {len(result.locations_visited)}")


if __name__ == "__main__":
    import asyncio
    asyncio.run(test_agent())