# NOTE(review): removed a Hugging Face Spaces status banner ("Spaces / Sleeping")
# that was accidentally captured into this file; it is not Python code.
"""
Example: MCP ReAct Agent

A complete ReAct agent that uses MCP tools to play text adventure games.
This is a working example students can learn from.
"""
| import json | |
| import os | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| from dotenv import load_dotenv | |
| from huggingface_hub import InferenceClient | |
| import hashlib | |
| from collections import defaultdict | |
| load_dotenv() | |
| # ============================================================================= | |
| # LLM Configuration - DO NOT MODIFY | |
| # ============================================================================= | |
| LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct" | |
| _hf_token = os.getenv("HF_TOKEN") | |
| if not _hf_token: | |
| raise ValueError("HF_TOKEN not found. Set it in your .env file.") | |
| LLM_CLIENT = InferenceClient(token=_hf_token) | |
| def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str: | |
| """Call the LLM with the given prompt.""" | |
| messages = [ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": prompt}, | |
| ] | |
| response = LLM_CLIENT.chat.completions.create( | |
| model=LLM_MODEL, | |
| messages=messages, | |
| temperature=0.0, | |
| max_tokens=max_tokens, | |
| seed=seed, | |
| ) | |
| return response.choices[0].message.content | |
| class RunResult: | |
| """Result of running the agent. Do not modify this class.""" | |
| final_score: int | |
| max_score: int | |
| moves: int | |
| locations_visited: set[str] | |
| game_completed: bool | |
| error: Optional[str] = None | |
| history: list[tuple[str, str, str]] = field(default_factory=list) | |
| # ============================================================================= | |
| # System Prompt | |
| # ============================================================================= | |
| SYSTEM_PROMPT = """You are an expert text adventure game player. Your goal is to explore, collect treasures, and maximize your score. | |
| AVAILABLE TOOLS (use these via MCP): | |
| 1. play_action - Execute game commands (north, take lamp, open mailbox, etc.) | |
| 2. memory - Get current game state, score, and recent history | |
| 3. get_map - See explored locations and connections | |
| 4. inventory - Check what you're carrying | |
| VALID GAME COMMANDS for play_action: | |
| - Movement: north, south, east, west, up, down, enter, exit | |
| - Objects: take <item>, drop <item>, open <thing>, close <thing>, examine <thing> | |
| - Light: turn on lamp, turn off lamp | |
| - Combat: attack <enemy> with <weapon> | |
| - Other: inventory, look, read <thing>, wait | |
| FORBIDDEN (will NOT work): check, inspect, search, grab, use, help | |
| RESPOND IN THIS EXACT FORMAT (no markdown): | |
| THOUGHT: <brief reasoning about what to do next> | |
| TOOL: <tool_name> | |
| ARGS: <JSON arguments> | |
| Examples: | |
| THOUGHT: I need to see what's around me. | |
| TOOL: play_action | |
| ARGS: {"action": "look"} | |
| THOUGHT: Let me check my current state and score. | |
| TOOL: memory | |
| ARGS: {} | |
| THOUGHT: The mailbox might contain something useful. | |
| TOOL: play_action | |
| ARGS: {"action": "open mailbox"} | |
| STRATEGY: | |
| 1. Start by looking around and checking memory | |
| 2. Explore systematically - try all directions | |
| 3. Pick up useful items (lamp, sword, etc.) | |
| 4. Open containers (mailbox, window, etc.) | |
| 5. Use get_map to avoid getting lost | |
| 6. Turn on lamp before dark areas! | |
| DO NOT repeat the same action multiple times in a row.""" | |
| # ============================================================================= | |
| # Student Agent Implementation | |
| # ============================================================================= | |
| class StudentAgent: | |
| """ | |
| MCP ReAct Agent - A complete working example. | |
| This agent demonstrates: | |
| - ReAct loop (Thought -> Tool -> Observation) | |
| - Loop detection | |
| - Action validation | |
| - Score tracking via memory tool | |
| """ | |
| def __init__(self): | |
| """Initialize the agent state.""" | |
| self.history: list[dict] = [] | |
| self.recent_actions: list[str] = [] | |
| self.score: int = 0 | |
| # --- Context management memory --- | |
| # Keyed by (state_id, inv_sig) | |
| self.failed_strong = defaultdict(set) # actions that are nonsense here | |
| self.failed_soft = defaultdict(dict) # action -> last_step tried (cooldown) | |
| self.state_last_obs = {} # (state_id, inv_sig) -> normalized obs | |
| self.inv_sig: str = "" # current inventory signature | |
| self.prev_inv_sig: str = "" # previous signature to detect changes | |
| self.step: int = 0 # current step counter | |
| self.debug_context: bool = True # Whether to include context management info in the prompt (for transparency) | |
| self.soft_cooldown_steps = 30 | |
| # -- LLM judge | |
| self.judge_cache: dict[tuple[str, str, str], str] = {} | |
| self.use_llm_judge: bool = True | |
| self._last_llm_raw = "" | |
| self._last_llm_label = "" | |
| self._last_llm_cached = False | |
| async def run( | |
| self, | |
| client, | |
| game: str, | |
| max_steps: int, | |
| seed: int, | |
| verbose: bool = False, | |
| ) -> RunResult: | |
| """Run the agent for a game session.""" | |
| locations_visited = set() | |
| history = [] | |
| moves = 0 | |
| # Get list of available tools | |
| tools = await client.list_tools() | |
| tool_names = [t.name for t in tools] | |
| # Get initial observation | |
| result = await client.call_tool("play_action", {"action": "look"}) | |
| observation = self._extract_result(result) | |
| # Initialize inventory signature | |
| inv_res = await client.call_tool("inventory", {}) | |
| inv_text = self._extract_result(inv_res) | |
| self.inv_sig = self._inventory_signature(inv_text) | |
| self.prev_inv_sig = self.inv_sig | |
| # Track initial location | |
| location = observation.split("\n")[0] if observation else "Unknown" | |
| locations_visited.add(location) | |
| if verbose: | |
| print(f"\n{observation}") | |
| # Main ReAct loop | |
| for step in range(1, max_steps + 1): | |
| self.step = step | |
| # Refresh inventory periodically (cheap and very useful for gating) | |
| if step == 1 or step % 7 == 0: | |
| inv_res = await client.call_tool("inventory", {}) | |
| inv_text = self._extract_result(inv_res) | |
| self.inv_sig = self._inventory_signature(inv_text) | |
| # If inventory changed, we want to allow retry of gated actions everywhere | |
| if self.inv_sig != self.prev_inv_sig: | |
| # Clear ALL soft failures (gated actions may now be valid) | |
| self.failed_soft.clear() | |
| self.prev_inv_sig = self.inv_sig | |
| # Build prompt with context | |
| prompt = self._build_prompt(observation) | |
| # Call LLM for reasoning (use step-based seed for variety) | |
| response = call_llm(prompt, SYSTEM_PROMPT, seed + step) | |
| # Parse the response | |
| thought, tool_name, tool_args = self._parse_response(response, tool_names) | |
| if verbose: | |
| print(f"\n--- Step {step} ---") | |
| print(f"[THOUGHT] {thought}") | |
| print(f"[TOOL] {tool_name}({tool_args})") | |
| # Validate and fix common issues | |
| tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names) | |
| tool_name, tool_args = self._apply_context_management(tool_name, tool_args, observation) | |
| # Loop detection | |
| if tool_name == "play_action": | |
| action = tool_args.get("action", "look") | |
| self.recent_actions.append(action) | |
| if len(self.recent_actions) > 5: | |
| self.recent_actions = self.recent_actions[-5:] | |
| # Detect loops - if same action 3 times, force "look" | |
| if len(self.recent_actions) >= 3 and len(set(self.recent_actions[-3:])) == 1: | |
| if verbose: | |
| print(f"[WARNING] Loop detected - forcing 'look'") | |
| tool_args = {"action": "look"} | |
| self.recent_actions.append("look") | |
| moves += 1 | |
| # Execute the tool | |
| try: | |
| prev_observation = observation # keep previous for failure detection | |
| result = await client.call_tool(tool_name, tool_args) | |
| observation = self._extract_result(result) | |
| # Update failure memory only for play_action | |
| if tool_name == "play_action": | |
| action = tool_args.get("action", "look") | |
| self._update_failure_memory(prev_observation, action, observation) | |
| self._log_context_state(prev_observation, action, new_observation=observation) | |
| if verbose: | |
| print(f"[RESULT] {observation[:200]}...") | |
| except Exception as e: | |
| observation = f"Error: {e}" | |
| if verbose: | |
| print(f"[ERROR] {e}") | |
| # Track location | |
| location = observation.split("\n")[0] if observation else "Unknown" | |
| locations_visited.add(location) | |
| # Update history | |
| self.history.append({ | |
| "step": step, | |
| "thought": thought, | |
| "tool": tool_name, | |
| "args": tool_args, | |
| "result": observation[:200] | |
| }) | |
| if len(self.history) > 10: | |
| self.history = self.history[-10:] | |
| # Track score from observation | |
| self._update_score(observation) | |
| # Record in result history | |
| history.append((thought, f"{tool_name}({tool_args})", observation[:100])) | |
| # Check for game over | |
| if self._is_game_over(observation): | |
| if verbose: | |
| print("\n*** GAME OVER ***") | |
| break | |
| return RunResult( | |
| final_score=self.score, | |
| max_score=350, | |
| moves=moves, | |
| locations_visited=locations_visited, | |
| game_completed=self._is_game_over(observation), | |
| history=history, | |
| ) | |
| def _build_prompt(self, observation: str) -> str: | |
| """Build the prompt for the LLM with context.""" | |
| parts = [] | |
| parts.append(f"Current Score: {self.score}") | |
| # Recent history | |
| if self.history: | |
| parts.append("\nRecent actions:") | |
| for entry in self.history[-8:]: | |
| action = entry.get("args", {}).get("action", entry["tool"]) | |
| result_short = entry["result"][:80] + "..." if len(entry["result"]) > 80 else entry["result"] | |
| parts.append(f" > {action} -> {result_short}") | |
| # Warn about repeated actions | |
| if self.recent_actions and len(set(self.recent_actions[-3:])) == 1: | |
| parts.append(f"\n[WARNING: You've been doing '{self.recent_actions[-1]}' repeatedly. TRY SOMETHING DIFFERENT!]") | |
| # Add context constraints to reduce repetition | |
| sid = self._state_id_from_observation(observation) | |
| key = (sid, self.inv_sig) | |
| forbidden = list(self.failed_strong[key])[:8] | |
| soft_forbidden = list(self.failed_soft[key].keys())[:8] | |
| if forbidden or soft_forbidden: | |
| parts.append("\nContext restrictions (DO NOT choose these actions here):") | |
| if forbidden: | |
| parts.append(" Strong banned: " + ", ".join(forbidden)) | |
| if soft_forbidden: | |
| parts.append(" Recently failed (wait before retry): " + ", ".join(soft_forbidden)) | |
| parts.append(f"\nCurrent situation:\n{observation}") | |
| parts.append("\nWhat do you do next?") | |
| return "\n".join(parts) | |
| def _parse_response(self, response: str, valid_tools: list[str]) -> tuple[str, str, dict]: | |
| """Parse the LLM response to extract thought, tool, and arguments.""" | |
| thought = "No reasoning provided" | |
| tool_name = "play_action" | |
| tool_args = {"action": "look"} | |
| lines = response.strip().split("\n") | |
| for line in lines: | |
| line_clean = line.strip() | |
| line_upper = line_clean.upper() | |
| if line_upper.startswith("THOUGHT:"): | |
| thought = line_clean.split(":", 1)[1].strip() | |
| elif line_upper.startswith("TOOL:"): | |
| raw_tool = line_clean.split(":", 1)[1].strip().lower() | |
| raw_tool = raw_tool.replace("**", "").replace("*", "").replace("`", "") | |
| raw_tool = raw_tool.split()[0] if raw_tool else "play_action" | |
| tool_name = raw_tool | |
| elif line_upper.startswith("ARGS:"): | |
| args_part = line_clean.split(":", 1)[1].strip() | |
| try: | |
| args_part = args_part.replace("'", '"') | |
| tool_args = json.loads(args_part) | |
| except json.JSONDecodeError: | |
| match = re.search(r'"action"\s*:\s*"([^"]+)"', args_part) | |
| if match: | |
| tool_args = {"action": match.group(1)} | |
| else: | |
| tool_args = {"action": "look"} | |
| return thought, tool_name, tool_args | |
| def _validate_tool_call(self, tool_name: str, tool_args: dict, valid_tools: list[str]) -> tuple[str, dict]: | |
| """Validate and fix common tool call issues.""" | |
| # Fix tool name | |
| if tool_name not in valid_tools: | |
| if tool_name in ["action", "do", "command"]: | |
| tool_name = "play_action" | |
| elif tool_name in ["map", "location"]: | |
| tool_name = "get_map" | |
| elif tool_name in ["mem", "state", "status"]: | |
| tool_name = "memory" | |
| elif tool_name in ["inv", "items"]: | |
| tool_name = "inventory" | |
| else: | |
| tool_name = "play_action" | |
| # Fix action verbs | |
| if tool_name == "play_action": | |
| action = tool_args.get("action", "look") | |
| invalid_verb_map = { | |
| "check": "examine", | |
| "inspect": "examine", | |
| "search": "look", | |
| "grab": "take", | |
| "pick": "take", | |
| "use": "examine", | |
| "investigate": "examine", | |
| } | |
| words = action.lower().split() | |
| if words and words[0] in invalid_verb_map: | |
| words[0] = invalid_verb_map[words[0]] | |
| action = " ".join(words) | |
| action = action.lower().strip() | |
| action = action.replace("**", "").replace("*", "").replace("`", "") | |
| action = " ".join(action.split()) | |
| tool_args["action"] = action | |
| return tool_name, tool_args | |
| def _extract_result(self, result) -> str: | |
| """Extract text from MCP tool result.""" | |
| if hasattr(result, 'content') and result.content: | |
| return result.content[0].text | |
| if isinstance(result, list) and result: | |
| return result[0].text if hasattr(result[0], 'text') else str(result[0]) | |
| return str(result) | |
| def _update_score(self, text: str) -> None: | |
| """Update score from game text.""" | |
| patterns = [ | |
| r'Score:\s*(\d+)', | |
| r'score[:\s]+(\d+)', | |
| r'\[Score:\s*(\d+)', | |
| ] | |
| for pattern in patterns: | |
| match = re.search(pattern, text, re.IGNORECASE) | |
| if match: | |
| self.score = max(self.score, int(match.group(1))) | |
| def _is_game_over(self, text: str) -> bool: | |
| """Check if the game is over.""" | |
| game_over_phrases = [ | |
| "game over", | |
| "you have died", | |
| "you are dead", | |
| "*** you have died ***", | |
| ] | |
| text_lower = text.lower() | |
| return any(phrase in text_lower for phrase in game_over_phrases) | |
| def _normalize_text(self, text: str) -> str: | |
| """Normalize text for comparison (remove numbers/punct, collapse spaces).""" | |
| s = (text or "").lower() | |
| s = re.sub(r"\d+", "0", s) | |
| s = re.sub(r"[^a-z0\s]", " ", s) | |
| s = re.sub(r"\s+", " ", s).strip() | |
| return s[:500] | |
| def _state_id_from_observation(self, observation: str) -> str: | |
| """Compute a stable-ish state id for the current room/context.""" | |
| lines = [l.strip() for l in (observation or "").splitlines() if l.strip()] | |
| head = lines[0].lower() if lines else "unknown" | |
| # hash helps reduce tiny variations | |
| norm = self._normalize_text(head) | |
| return hashlib.md5(norm.encode()).hexdigest() | |
| def _inventory_signature(self, inv_text: str) -> str: | |
| """Signature of inventory text; robust enough for gating retries.""" | |
| norm = self._normalize_text(inv_text) | |
| return hashlib.md5(norm.encode()).hexdigest() | |
| def _classify_failure(self, prev_obs: str, action: str, new_obs: str) -> str: | |
| """ | |
| Return: "none" | "strong" | "soft" | |
| strong = will never become valid (parser/no object) | |
| soft = gated (locked/dark/tool) -> retry when inventory/state changes | |
| """ | |
| # Optional fast rule: identical observation => likely failed (soft) | |
| if self._normalize_text(new_obs) == self._normalize_text(prev_obs): | |
| return "soft" | |
| if not getattr(self, "use_llm_judge", True): | |
| return "none" | |
| return self._llm_judge_failure(prev_obs, action, new_obs) | |
| def _llm_judge_failure(self, prev_obs: str, action: str, new_obs: str) -> str: | |
| """ | |
| LLM judge: returns "none" | "soft" | "strong" | |
| """ | |
| # Normalize for caching (prevents repeated LLM calls) | |
| prev_n = self._normalize_text(prev_obs) | |
| new_n = self._normalize_text(new_obs) | |
| act_n = " ".join((action or "").lower().split()) | |
| key = (prev_n, act_n, new_n) | |
| if key in self.judge_cache: | |
| self._last_llm_raw = "(cached)" | |
| self._last_llm_label = self.judge_cache[key] | |
| self._last_llm_cached = True | |
| return self.judge_cache[key] | |
| system = "You are a strict classifier for text-adventure command outcomes." | |
| prompt = f""" | |
| Classify whether the player's action FAILED, based on the before/after observations. | |
| Return EXACTLY one label: none | soft | strong | |
| Definitions: | |
| - strong: The command is invalid/unknown OR refers to something not present/visible OR impossible in principle. | |
| It will NOT become valid later just by having a different item. | |
| - soft: The command was understood but is blocked by a condition (locked, closed, too dark, need an item, must do something first, not possible yet). | |
| It COULD become valid later. | |
| - none: The action had an effect OR gave new useful information (state changed, moved, item changed, new description). | |
| BE STRICT: If the new observation is basically identical to the previous one AND no progress happened, prefer "soft". | |
| PREVIOUS_OBSERVATION: | |
| {prev_obs} | |
| ACTION: | |
| {action} | |
| NEW_OBSERVATION: | |
| {new_obs} | |
| """.strip() | |
| out = call_llm(prompt, system, seed=100000 + self.step, max_tokens=8) | |
| raw_out = out.strip().lower() | |
| label = "none" | |
| if "strong" in raw_out: | |
| label = "strong" | |
| elif "soft" in raw_out: | |
| label = "soft" | |
| elif raw_out in {"none", "soft", "strong"}: | |
| label = raw_out | |
| # Store for logging | |
| self._last_llm_raw = raw_out | |
| self._last_llm_label = label | |
| self._last_llm_cached = False | |
| self.judge_cache[key] = label | |
| return label | |
| def _apply_context_management(self,tool_name: str,tool_args: dict,observation: str,) -> tuple[str, dict]: | |
| """Prevent repeating failed actions in the same context (state + inv).""" | |
| if tool_name != "play_action": | |
| return tool_name, tool_args | |
| action = (tool_args.get("action") or "look").strip().lower() | |
| sid = self._state_id_from_observation(observation) | |
| key = (sid, self.inv_sig) | |
| # Strong blacklist | |
| if action in self.failed_strong[key]: | |
| return "play_action", {"action": self._fallback_action(observation)} | |
| # Soft blacklist with cooldown | |
| if action in self.failed_soft[key]: | |
| last = self.failed_soft[key][action] | |
| # cooldown: avoid retrying too soon | |
| if self.step - last < self.soft_cooldown_steps: | |
| return "play_action", {"action": self._fallback_action(observation)} | |
| # Prevent immediate repetition in same context | |
| if self.recent_actions and action == self.recent_actions[-1]: | |
| return "play_action", {"action": self._fallback_action(observation)} | |
| return tool_name, {"action": action} | |
| def _fallback_action(self, observation: str) -> str: | |
| """ | |
| Deterministic fallback when the chosen action is banned. | |
| Prefer exploration moves; otherwise look/inventory. | |
| """ | |
| # Prefer moves that haven't been tried recently | |
| move_candidates = ["north","south","east","west","up","down","n","s","e","w","u","d"] | |
| recent_set = set(self.recent_actions[-8:]) if self.recent_actions else set() | |
| for m in move_candidates: | |
| if m not in recent_set: | |
| return m | |
| # If stuck, refresh | |
| if "dark" in (observation or "").lower(): | |
| # lamp heuristic: often useful in Zork | |
| return "turn on lamp" | |
| return "look" | |
| def _update_failure_memory(self, prev_obs: str, action: str, new_obs: str) -> None: | |
| """Update strong/soft failed actions for this (state, inv) context.""" | |
| sid = self._state_id_from_observation(prev_obs) | |
| key = (sid, self.inv_sig) | |
| verdict = self._classify_failure(prev_obs, action, new_obs) | |
| if verdict == "strong": | |
| self.failed_strong[key].add(action) | |
| # also remove from soft if present | |
| if action in self.failed_soft[key]: | |
| del self.failed_soft[key][action] | |
| elif verdict == "soft": | |
| self.failed_soft[key][action] = self.step | |
| def _log_context_state(self, prev_observation: str, chosen_action: str = "", new_observation: str = ""): | |
| """Print debug info for context management (before-state bucket).""" | |
| if not self.debug_context: | |
| return | |
| sid_before = self._state_id_from_observation(prev_observation) | |
| sid_after = self._state_id_from_observation(new_observation) if new_observation else "" | |
| key_before = (sid_before, self.inv_sig) | |
| print("\n" + "=" * 60) | |
| print(f"[STEP] {self.step}") | |
| print(f"[STATE_ID_BEFORE] {sid_before}") | |
| if sid_after: | |
| print(f"[STATE_ID_AFTER] {sid_after}") | |
| print(f"[INV_SIG] {self.inv_sig[:8]}...") | |
| if chosen_action: | |
| print(f"[CHOSEN ACTION] {chosen_action}") | |
| # ---- LLM Judge Info ---- | |
| if getattr(self, "_last_llm_label", ""): | |
| print(f"[LLM_VERDICT] {self._last_llm_label}") | |
| print(f"[LLM_RAW_OUTPUT] {getattr(self, '_last_llm_raw', '')}") | |
| print(f"[LLM_FROM_CACHE] {getattr(self, '_last_llm_cached', False)}") | |
| # Strong failures (BEFORE bucket) | |
| strong = list(self.failed_strong[key_before]) | |
| print(f"[FAILED_STRONG_BEFORE] ({len(strong)})") | |
| for a in strong[:10]: | |
| print(f" - {a}") | |
| # Soft failures (BEFORE bucket) | |
| soft = self.failed_soft[key_before] | |
| print(f"[FAILED_SOFT_BEFORE] ({len(soft)})") | |
| for a, last_step in list(soft.items())[:10]: | |
| cooldown_left = max(0, self.soft_cooldown_steps - (self.step - last_step)) | |
| print(f" - {a} (retry in {cooldown_left} steps)") | |
| print(f"[RECENT_ACTIONS] {self.recent_actions[-10:]}") | |
| print("=" * 60 + "\n") | |
| # ============================================================================= | |
| # Local Testing | |
| # ============================================================================= | |
| async def test_agent(): | |
| """Test the agent locally.""" | |
| from fastmcp import Client | |
| agent = StudentAgent() | |
| async with Client("mcp_server.py") as client: | |
| result = await agent.run( | |
| client=client, | |
| game="zork1", | |
| max_steps=20, | |
| seed=42, | |
| verbose=True, | |
| ) | |
| print(f"\n{'=' * 50}") | |
| print(f"Final Score: {result.final_score}") | |
| print(f"Moves: {result.moves}") | |
| print(f"Locations: {len(result.locations_visited)}") | |
| if __name__ == "__main__": | |
| import asyncio | |
| asyncio.run(test_agent()) | |