Spaces:
Sleeping
Sleeping
| """ | |
| MCP ReAct Agent (adapted for your MCP server) | |
| Key upgrades: | |
| - Actually calls memory/get_map/inventory periodically (doesn't cost "moves") | |
| - Injects those outputs into the LLM prompt (LLM-friendly context) | |
| - Updates score from BOTH play_action output and memory output | |
| - Keeps loop detection + action normalization | |
| """ | |
| import json | |
| import os | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| from dotenv import load_dotenv | |
| from huggingface_hub import InferenceClient | |
# Load variables from a local .env file (python-dotenv) so HF_TOKEN can be read below.
load_dotenv()
# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================
# Model id passed to every chat-completion request made by call_llm().
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
# Hugging Face API token taken from the environment.
_hf_token = os.getenv("HF_TOKEN")
# Fail fast at import time: without a token no LLM call can succeed.
if not _hf_token:
    raise ValueError("HF_TOKEN not found. Set it in your .env file.")
# Module-level client shared by every call_llm() invocation.
LLM_CLIENT = InferenceClient(token=_hf_token)
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
    """Call the LLM with the given prompt.

    Sends a two-message chat (system turn, then user turn) to ``LLM_MODEL``
    through the shared ``LLM_CLIENT`` and returns the first choice's text.

    Args:
        prompt: Content of the user turn.
        system_prompt: Content of the system turn.
        seed: Sampling seed forwarded to the inference API.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        ``response.choices[0].message.content``. NOTE(review): this may be
        None if the server returns no text content; confirm callers cope.

    Raises:
        Whatever the inference client raises (network / HTTP errors); this
        function does not catch anything.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    # temperature=0.0 plus an explicit seed: presumably chosen so that runs
    # are as reproducible as the backend allows.
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )
    return response.choices[0].message.content
@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class.

    ``StudentAgent.run`` builds this with keyword arguments, which requires
    the ``__init__`` generated by ``@dataclass``. Without the decorator that
    call raises ``TypeError`` and ``history`` is a bare ``Field`` object.
    """

    final_score: int  # best score observed during the run
    max_score: int  # maximum achievable score reported by the agent
    moves: int  # number of play_action calls made
    locations_visited: set[str]  # distinct location names recorded
    game_completed: bool  # whether the final observation looked like game over
    error: Optional[str] = None  # optional error description
    # One (thought, tool call, truncated result) tuple per step; a fresh
    # list per instance thanks to default_factory.
    history: list[tuple[str, str, str]] = field(default_factory=list)
# =============================================================================
# System Prompt
# =============================================================================
# Sent as the system message on every LLM call. The OUTPUT FORMAT section must
# stay in sync with the agent's response parser: it reads one "THOUGHT:",
# "TOOL:" and "ARGS:" line each, ARGS must be a JSON object on that single
# line, and play_action only looks at the "action" key.
SYSTEM_PROMPT = """You are an intelligent text adventure game agent.
Your goal is to solve the main problem of the game efficiently and maximize score within 100 moves.
This game is small and objective-focused. Avoid unnecessary wandering.

AVAILABLE TOOLS (use via MCP):
1. play_action - Execute valid game commands.
2. memory - Get structured summary of current state and recent actions.
3. get_map - See explored locations.
4. inventory - Check carried items.

VALID ACTION STYLE:
Movement:
- north, south, east, west, up, down
- n, s, e, w, u, d
Core actions:
- look
- examine <thing>
- take <item>, drop <item>
- open <thing>, close <thing>
- talk to <character>
- give <item> to <character>
- use specific verbs mentioned in observation

AVOID:
- generic verbs like "use"
- random movement without purpose
- repeating failed actions

--------------------------------------------------
CORE STRATEGY (IMPORTANT)
--------------------------------------------------
1) DOMINANT OBJECT RULE (VERY IMPORTANT):
If a specific object or character is repeatedly mentioned in the observation,
treat it as the main objective.
Do NOT leave the area until you:
- examine it
- try multiple meaningful interactions
- or confirm no new interaction is possible
Stay focused before exploring elsewhere.

2) PROBLEM-SOLVING PRIORITY:
If the game clearly revolves around one main goal,
prioritize actions that directly affect that goal instead of exploring new rooms.

3) CONTROLLED MOVEMENT:
Only move if:
- you have exhausted interactions in the current room
- or memory/map suggests a new unexplored path is necessary

4) LIMITED RETRIES:
If an action fails once, try a different verb.
Do NOT repeat the same failed action more than once.

5) OBJECT TRANSFORMATION FOCUS:
If an object seems central, try actions that might change its state:
- examine
- open
- give something
- use appropriate verbs mentioned in text
- interact from different angles

--------------------------------------------------
TOOL USAGE RULES
--------------------------------------------------
- Use memory() when uncertain or before repeating behavior.
- Use get_map() only if navigation becomes necessary.
- Use inventory() after obtaining items.

--------------------------------------------------
OUTPUT FORMAT (STRICT)
--------------------------------------------------
THOUGHT: <brief reasoning>
TOOL: <tool_name>
ARGS: <JSON arguments>

ARGS must be a JSON object written on a single line:
- play_action: {"action": "<game command>"}
- memory, get_map, inventory: {}

Example:
THOUGHT: The door is mentioned repeatedly, so I should look at it closely first.
TOOL: play_action
ARGS: {"action": "examine door"}

Keep THOUGHT short (1-2 sentences).
Do not repeat the same action multiple times.
Prefer solving over wandering.
"""
# =============================================================================
# Student Agent Implementation
# =============================================================================
class StudentAgent:
    """
    ReAct-style agent that plays a text adventure through an MCP server.

    Each step it refreshes cached context from the tools that do not spend
    game moves (memory / get_map / inventory), asks the LLM for one tool call,
    passes that call through anti-repeat, anti-loop and failed-move guards,
    then executes it and updates score / location bookkeeping.

    Server output formats the prompt builder was written against:
    - memory() returns STATE / RECENT / OBSERVATION
    - get_map() returns MAP ...
    - inventory() returns INVENTORY ...
    """

    # Single-word movement commands; "go <dir>" is normalised to "<dir>"
    # before lookup (see _move_key).
    MOVE_CMDS = frozenset({
        "north", "south", "east", "west", "up", "down", "enter", "exit",
        "n", "s", "e", "w", "u", "d",
    })

    # Lower-case fragments which, on the first line of the reply to a
    # movement command, mean the move did not happen. Heuristic aimed at
    # Infocom-style parsers; extend for other games.
    _MOVE_FAILURE_PHRASES = (
        "can't go that way",
        "cannot go that way",
        "there is a wall",
        "there's a wall",
        "the door is closed",
        "is locked",
        "you can't",
        "you cannot",
    )

    # Tool names the LLM sometimes invents, mapped to real MCP tools.
    _TOOL_ALIASES = {
        "action": "play_action", "do": "play_action", "command": "play_action",
        "map": "get_map", "location": "get_map",
        "mem": "memory", "state": "memory", "status": "memory",
        "inv": "inventory", "items": "inventory",
    }

    # Leading verbs the game parser is unlikely to accept, rewritten to the
    # closest standard verb ("pick up X" is handled separately -> "take X").
    _VERB_REWRITES = {
        "check": "examine",
        "inspect": "examine",
        "search": "look",
        "grab": "take",
        "pick": "take",
        "use": "examine",
        "investigate": "examine",
    }

    # Score formats found in tool output: "[Score: N" / "Score: N" from the
    # server and "Your score is N" from the in-game `score` command.
    _SCORE_PATTERNS = (
        re.compile(r"Score:\s*(\d+)", re.IGNORECASE),
        re.compile(r"\byour score is\s+(\d+)", re.IGNORECASE),
    )

    def __init__(self):
        # Rolling log of the last 10 steps, fed back into the prompt.
        self.history: list[dict] = []
        # Last 5 play_action commands proposed (loop detection).
        self.recent_actions: list[str] = []
        # Highest score seen in any tool output so far.
        self.score: int = 0
        # Cached tool outputs
        self.last_memory: str = ""
        self.last_map: str = ""
        self.last_inventory: str = ""
        self.last_observation: str = ""
        # Best-known current room; changed only by "look" or a successful move,
        # so messages like "Taken." are never mistaken for a location.
        self.current_location: str = "Unknown"
        # Exploration / anti-loop state
        self.visit_counts: dict[str, int] = {}
        self.loc_move_failures: dict[tuple[str, str], int] = {}
        self.pending_move: Optional[tuple[str, str]] = None
        # Prevent repeating the same thought+action at the same location.
        self.loc_action_thought_counts: dict[tuple[str, str, str], int] = {}

    # ------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------
    def _thought_sig(self, thought: str) -> str:
        """Normalise a thought to its first 12 lower-case alphanumeric words."""
        t = (thought or "").lower()
        t = re.sub(r"[^a-z0-9\s]", " ", t)
        t = re.sub(r"\s+", " ", t).strip()
        return " ".join(t.split()[:12])

    @staticmethod
    def _location_of(text: str) -> str:
        """Return the stripped first line of *text* (the room title in
        Zork-style output), or "Unknown" when it is empty."""
        first = (text or "").split("\n", 1)[0].strip()
        return first or "Unknown"

    def _move_key(self, action: str) -> Optional[str]:
        """Return the bare direction if *action* is a movement command, else None."""
        normalized = (action or "").strip().lower()
        if normalized.startswith("go "):
            normalized = normalized[3:].strip()
        return normalized if normalized in self.MOVE_CMDS else None

    def _move_failed(self, observation: str, prev_loc: str) -> bool:
        """Heuristically decide whether a movement command left us in place."""
        if not observation or observation.startswith("Error:"):
            return True
        first = self._location_of(observation)
        if first == prev_loc:
            return True
        first_lower = first.lower()
        return any(phrase in first_lower for phrase in self._MOVE_FAILURE_PHRASES)

    def _set_location(self, location: str, locations_visited: set) -> None:
        """Record *location* as current and count the visit."""
        self.current_location = location
        locations_visited.add(location)
        self.visit_counts[location] = self.visit_counts.get(location, 0) + 1

    @staticmethod
    def _fallback_call(tool_names: list[str]) -> tuple[str, dict]:
        """Pick a replacement for a blocked action, preferring no-move tools."""
        if "get_map" in tool_names:
            return "get_map", {}
        if "memory" in tool_names:
            return "memory", {}
        return "play_action", {"action": "look"}

    async def _fetch_context(self, client, tool_name: str, verbose: bool) -> Optional[str]:
        """Best-effort call of an argument-less tool; returns None on failure."""
        try:
            return self._extract_result(await client.call_tool(tool_name, {}))
        except Exception as e:  # context is optional; never abort the run for it
            if verbose:
                print(f"[CONTEXT] {tool_name}() failed: {e}")
            return None

    # ------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------
    async def run(
        self,
        client,
        game: str,
        max_steps: int,
        seed: int,
        verbose: bool = False,
    ) -> "RunResult":
        """
        Play up to ``max_steps`` ReAct steps and summarise the run.

        Args:
            client: Connected MCP client exposing ``list_tools`` and ``call_tool``.
            game: Game identifier (unused by this agent; part of the runner interface).
            max_steps: Maximum number of LLM decisions. Context-only tools use a
                step but not a move.
            seed: Base seed; step ``i`` calls the LLM with ``seed + i``.
            verbose: Print per-step reasoning, tool calls and results.

        Returns:
            RunResult with the best score seen, the number of play_action calls
            made inside the loop, the locations recorded and a compact history.
        """
        locations_visited: set[str] = set()
        history: list[tuple[str, str, str]] = []
        moves = 0

        tools = await client.list_tools()
        tool_names = [t.name for t in tools]

        # Opening "look" establishes the starting room; not counted as a move.
        result = await client.call_tool("play_action", {"action": "look"})
        self.last_observation = self._extract_result(result)
        self._set_location(self._location_of(self.last_observation), locations_visited)
        self._update_score(self.last_observation)
        if verbose:
            print(f"\n{self.last_observation}")

        for step in range(1, max_steps + 1):
            # Step 1 also performs the initial memory / inventory fetch.
            await self._refresh_context_tools(client, tool_names, step, verbose)

            try:
                response = call_llm(self._build_prompt(), SYSTEM_PROMPT, seed + step)
            except Exception as e:
                # A transient inference failure should not abort the whole run;
                # skip this step without spending a game move.
                if verbose:
                    print(f"\n--- Step {step} ---")
                    print(f"[LLM ERROR] {e}")
                continue

            thought, tool_name, tool_args = self._parse_response(response, tool_names)
            if verbose:
                print(f"\n--- Step {step} ---")
                print(f"[THOUGHT] {thought}")
                print(f"[TOOL] {tool_name}({tool_args})")

            tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names)
            tool_name, tool_args = self._apply_guards(
                thought, tool_name, tool_args, tool_names, verbose
            )

            if tool_name == "play_action":
                moves += 1
            out_text = await self._execute_tool(client, tool_name, tool_args, verbose)
            if tool_name == "play_action":
                await self._after_play_action(
                    client, tool_names, tool_args.get("action", "look"), locations_visited, verbose
                )

            self.history.append({
                "step": step,
                "thought": thought,
                "tool": tool_name,
                "args": tool_args,
                "result": out_text[:200],
            })
            self.history = self.history[-10:]
            history.append((thought, f"{tool_name}({tool_args})", out_text[:100]))

            if self._is_game_over(self.last_observation):
                if verbose:
                    print("\n*** GAME OVER ***")
                break

        return RunResult(
            final_score=self.score,
            max_score=350,  # Zork I maximum; other games may differ
            moves=moves,
            locations_visited=locations_visited,
            game_completed=self._is_game_over(self.last_observation),
            history=history,
        )

    def _apply_guards(
        self, thought: str, tool_name: str, tool_args: dict, tool_names: list[str], verbose: bool
    ) -> tuple[str, dict]:
        """
        Rewrite a proposed play_action that is likely to waste a move.

        In order: (1) the same action with the same thought at the same
        location a second time -> a context tool; (2) the same command three
        times running -> "look"; (3) a move that already failed twice from
        this location -> a context tool. When a movement command is let
        through, ``self.pending_move`` is set so its outcome can be checked.
        """
        self.pending_move = None
        if tool_name != "play_action":
            return tool_name, tool_args
        here = self.current_location

        # 1) Block SAME (location + action + thought)
        action_norm = tool_args.get("action", "look").strip().lower()
        triple = (here, action_norm, self._thought_sig(thought))
        self.loc_action_thought_counts[triple] = self.loc_action_thought_counts.get(triple, 0) + 1
        if self.loc_action_thought_counts[triple] >= 2:
            if verbose:
                print(f"[ANTI-REPEAT] Blocking repeated thought+action at '{here}'")
            tool_name, tool_args = self._fallback_call(tool_names)
            if tool_name != "play_action":
                return tool_name, tool_args

        # 2) Loop detection (same action spam)
        self.recent_actions.append(tool_args.get("action", "look"))
        self.recent_actions = self.recent_actions[-5:]
        last_three = self.recent_actions[-3:]
        if len(last_three) == 3 and len(set(last_three)) == 1:
            if verbose:
                print("[WARNING] Loop detected - forcing 'look'")
            tool_args = {"action": "look"}

        # 3) Anti-backtracking: block only moves that already FAILED twice here
        move = self._move_key(tool_args.get("action", "look"))
        if move is not None:
            key = (here, move)
            if self.loc_move_failures.get(key, 0) >= 2:
                if verbose:
                    print(f"[GUARD] Blocking failed move '{move}' from '{here}'")
                return self._fallback_call(tool_names)
            self.pending_move = key
        return tool_name, tool_args

    async def _execute_tool(self, client, tool_name: str, tool_args: dict, verbose: bool) -> str:
        """Call the tool, cache its output by kind and return the text.

        A failing play_action surfaces its error as the new observation; a
        failing context tool leaves every cache (including the observation)
        untouched.
        """
        try:
            out_text = self._extract_result(await client.call_tool(tool_name, tool_args))
        except Exception as e:
            out_text = f"Error: {e}"
            if verbose:
                print(f"[ERROR] {e}")
            if tool_name == "play_action":
                self.last_observation = out_text
            return out_text

        if tool_name == "play_action":
            self.last_observation = out_text
        elif tool_name == "memory":
            self.last_memory = out_text
            self._update_score(out_text)
        elif tool_name == "get_map":
            self.last_map = out_text
        elif tool_name == "inventory":
            self.last_inventory = out_text
        if verbose:
            print(f"[RESULT] {out_text[:200]}...")
        return out_text

    async def _after_play_action(
        self, client, tool_names: list[str], action: str, locations_visited: set, verbose: bool
    ) -> None:
        """Update move-failure counts, location, score and inventory after a game command."""
        observation = self.last_observation
        if self.pending_move is not None:
            prev_loc, _ = self.pending_move
            if self._move_failed(observation, prev_loc):
                self.loc_move_failures[self.pending_move] = (
                    self.loc_move_failures.get(self.pending_move, 0) + 1
                )
            else:
                self.loc_move_failures[self.pending_move] = 0
                self._set_location(self._location_of(observation), locations_visited)
            self.pending_move = None
        elif action.strip().lower() in ("look", "l") and not observation.startswith("Error:"):
            # "look" re-prints the room title on its first line.
            self._set_location(self._location_of(observation), locations_visited)

        self._update_score(observation)
        if "inventory" in tool_names and re.search(
            r"\bTaken\b|\byou are now carrying\b", observation, re.IGNORECASE
        ):
            inv = await self._fetch_context(client, "inventory", verbose)
            if inv is not None:
                self.last_inventory = inv

    async def _refresh_context_tools(self, client, tool_names: list[str], step: int, verbose: bool) -> None:
        """
        Pull structured context from the MCP server without spending moves.

        - memory(): step 1 and every 4 steps (best single summary; also scored)
        - get_map(): every 6 steps
        - inventory(): step 1 and every 10 steps
        Failures are logged (when verbose) and the previous cache is kept.
        """
        if "memory" in tool_names and (step == 1 or step % 4 == 0):
            mem = await self._fetch_context(client, "memory", verbose)
            if mem is not None:
                self.last_memory = mem
                self._update_score(mem)
        if "get_map" in tool_names and step % 6 == 0:
            mp = await self._fetch_context(client, "get_map", verbose)
            if mp is not None:
                self.last_map = mp
        if "inventory" in tool_names and (step == 1 or step % 10 == 0):
            inv = await self._fetch_context(client, "inventory", verbose)
            if inv is not None:
                self.last_inventory = inv

    # ------------------------------------------------------------
    # Prompt construction
    # ------------------------------------------------------------
    def _build_prompt(self) -> str:
        """
        Build the user prompt from the cached server context:
        - memory() has STATE/RECENT/OBSERVATION
        - get_map() starts with MAP
        - inventory() starts with INVENTORY
        followed by the last 3 local steps and the latest raw observation.
        """
        parts = [f"Current best-known score: {self.score}"]

        # (section title, cached text, character budget)
        sections = (
            ("MEMORY", self.last_memory, 1200),
            ("INVENTORY", self.last_inventory, 400),
            ("MAP", self.last_map, 700),
        )
        for title, text, limit in sections:
            if text:
                parts.append(f"\n=== {title} (from MCP server) ===\n" + self._truncate(text, limit))

        if self.history:
            parts.append("\n=== RECENT LOCAL ACTIONS (agent) ===")
            for entry in self.history[-3:]:
                action = entry.get("args", {}).get("action", entry["tool"])
                result = entry["result"]
                result_short = result[:100] + "..." if len(result) > 100 else result
                parts.append(f" > {action} -> {result_short}")

        # Only warn once three identical commands were actually proposed.
        last_three = self.recent_actions[-3:]
        if len(last_three) == 3 and len(set(last_three)) == 1:
            parts.append(f"\n[WARNING: repeated '{last_three[-1]}'. Choose a different action.]")

        parts.append(
            "\n=== LATEST OBSERVATION (play_action) ===\n" + self._truncate(self.last_observation, 900)
        )
        parts.append("\nWhat do you do next?")
        return "\n".join(parts)

    def _truncate(self, text: str, limit: int) -> str:
        """Cut *text* to *limit* characters, marking the cut."""
        text = text or ""
        if len(text) <= limit:
            return text
        return text[:limit] + "\n...[truncated]"

    # ------------------------------------------------------------
    # Response parsing / validation
    # ------------------------------------------------------------
    def _parse_response(self, response: str, valid_tools: list[str]) -> tuple[str, str, dict]:
        """
        Extract (thought, tool_name, tool_args) from a THOUGHT/TOOL/ARGS reply.

        Tolerates markdown emphasis around labels ("**TOOL:** memory"), call
        syntax in the tool name ("memory()") and a None response. Missing
        parts default to ("No reasoning provided", "play_action", {"action": "look"}).
        """
        thought = "No reasoning provided"
        tool_name = "play_action"
        tool_args: dict = {"action": "look"}
        for raw_line in (response or "").strip().split("\n"):
            line = raw_line.strip().lstrip("*").strip()
            label, sep, value = line.partition(":")
            if not sep:
                continue
            label = label.strip("* ").upper()
            value = value.strip().strip("*").strip()
            if label == "THOUGHT":
                thought = value
            elif label == "TOOL":
                cleaned = value.lower().replace("**", "").replace("*", "").replace("`", "")
                words = cleaned.replace("(", " ").split()
                tool_name = words[0] if words else "play_action"
            elif label == "ARGS":
                tool_args = self._parse_args(value)
        return thought, tool_name, tool_args

    @staticmethod
    def _parse_args(args_text: str) -> dict:
        """Parse the ARGS value into a dict.

        Strict JSON is tried first so apostrophes inside commands survive;
        only then single quotes are swapped for double quotes. A bare JSON
        string is treated as the action. Falls back to a regex on "action",
        then to {"action": "look"}.
        """
        text = args_text.strip().strip("`").strip()
        if not text:
            return {}
        candidates = (text, text.replace("'", '"'))
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
            if isinstance(parsed, str) and parsed.strip():
                return {"action": parsed}
            break  # valid JSON but unusable shape (list, number, null)
        for candidate in candidates:
            match = re.search(r'"action"\s*:\s*"([^"]+)"', candidate)
            if match:
                return {"action": match.group(1)}
        return {"action": "look"}

    def _validate_tool_call(self, tool_name: str, tool_args: dict, valid_tools: list[str]) -> tuple[str, dict]:
        """
        Map the tool name onto a tool the server provides and normalise
        play_action commands (markdown stripped, verbs rewritten, never empty).
        """
        if tool_name not in valid_tools:
            tool_name = self._TOOL_ALIASES.get(tool_name, "play_action")
            if tool_name not in valid_tools:
                # Alias target not exposed by this server.
                tool_name = "play_action"
        if not isinstance(tool_args, dict):
            tool_args = {}

        if tool_name == "play_action":
            raw = tool_args.get("action")
            if raw is None:
                raw = tool_args.get("command", "look")
            # Strip markdown BEFORE looking at the verb so "**check** x" is rewritten.
            action = str(raw).lower().replace("**", "").replace("*", "").replace("`", "")
            words = action.split()
            if words[:2] == ["pick", "up"]:
                words = ["take"] + words[2:]
            elif words and words[0] in self._VERB_REWRITES:
                words[0] = self._VERB_REWRITES[words[0]]
            tool_args["action"] = " ".join(words) or "look"
        return tool_name, tool_args

    # ------------------------------------------------------------
    # Output interpretation
    # ------------------------------------------------------------
    def _extract_result(self, result) -> str:
        """Return the first text payload of an MCP tool result."""
        content = getattr(result, "content", None)
        if content:
            first = content[0]
            return getattr(first, "text", str(first))
        if isinstance(result, list) and result:
            first = result[0]
            return first.text if hasattr(first, "text") else str(first)
        return str(result)

    def _update_score(self, text: str) -> None:
        """Raise ``self.score`` to any larger score mentioned in *text*."""
        if not text:
            return
        for pattern in self._SCORE_PATTERNS:
            match = pattern.search(text)
            if match:
                self.score = max(self.score, int(match.group(1)))

    def _is_game_over(self, text: str) -> bool:
        """True when *text* contains a known game-over / death phrase."""
        game_over_phrases = (
            "game over",
            "you have died",
            "you are dead",
            "*** you have died ***",
        )
        text_lower = (text or "").lower()
        return any(phrase in text_lower for phrase in game_over_phrases)
# =============================================================================
# Local Testing
# =============================================================================
async def test_agent():
    """Smoke-test StudentAgent locally: 20 verbose steps of zork1 via mcp_server.py."""
    from fastmcp import Client

    agent = StudentAgent()
    async with Client("mcp_server.py") as mcp_client:
        run_kwargs = dict(game="zork1", max_steps=20, seed=42, verbose=True)
        outcome = await agent.run(client=mcp_client, **run_kwargs)

        # Summary banner.
        print(f"\n{'=' * 50}")
        print(f"Final Score: {outcome.final_score}")
        print(f"Moves: {outcome.moves}")
        print(f"Locations: {len(outcome.locations_visited)}")
if __name__ == "__main__":
    # Script entry point: run the local smoke test on a fresh event loop.
    import asyncio as _asyncio

    _asyncio.run(test_agent())