Spaces:
Sleeping
Sleeping
| """ | |
| Student Agent for Text Adventure Games | |
| This is your submission file. Implement the StudentAgent class to play | |
| text adventure games using the MCP server you also implement. | |
| Your agent should: | |
| 1. Connect to the MCP server via the provided client | |
| 2. Use the ReAct pattern (Thought -> Action -> Observation) | |
| 3. Call MCP tools to interact with the game | |
| 4. Maximize the game score within the step limit | |
| Required method: | |
| async def run(self, client, game, max_steps, seed, verbose) -> RunResult | |
| The 'client' is a FastMCP Client already connected to your MCP server. | |
| Use it to call tools like: await client.call_tool("play_action", {"action": "look"}) | |
| Tips: | |
| - Start by looking around and understanding your environment | |
| - Keep track of visited locations to avoid loops | |
| - Pick up useful items (lamp, sword, etc.) | |
| - The seed parameter should be used to set your LLM's seed for reproducibility | |
| """ | |
| import json | |
| import os | |
| import re | |
| import random | |
| from dataclasses import dataclass, field | |
| from collections import defaultdict | |
| from typing import Optional | |
| from dotenv import load_dotenv | |
| from huggingface_hub import InferenceClient | |
| # Load environment variables | |
| load_dotenv() | |
# Set USE_LOCAL_MODEL=1 (or "true" / "yes", in any letter case) in your .env to
# use a locally downloaded model instead of the hosted inference API.
USE_LOCAL_MODEL = os.getenv("USE_LOCAL_MODEL", "0").strip().lower() in ("1", "true", "yes")
# Model id used for the local pipeline; an unset or blank variable falls back
# to the default so the pipeline is never asked to load an empty model name.
LOCAL_MODEL_ID = os.getenv("LOCAL_MODEL_ID", "").strip() or "Qwen/Qwen2.5-3B-Instruct"
# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================
# Model to use (fixed for fair evaluation)
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"

# Exactly one of the two backends below ends up non-None, depending on
# USE_LOCAL_MODEL.
_local_pipeline = None
LLM_CLIENT = None

if not USE_LOCAL_MODEL:
    # Hosted mode: a Hugging Face token is mandatory, so fail at import time
    # rather than on the first LLM call.
    _hf_token = os.getenv("HF_TOKEN")
    if not _hf_token:
        raise ValueError("HF_TOKEN not found. Set it in your .env file.")
    LLM_CLIENT = InferenceClient(token=_hf_token)
else:
    # Local mode: the heavy imports are deferred so hosted runs do not need
    # torch / transformers installed.
    import torch
    from transformers import pipeline as _hf_pipeline

    _local_pipeline = _hf_pipeline(
        "text-generation",
        model=LOCAL_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
    """
    Send one system + user exchange to the configured LLM and return its reply.

    Args:
        prompt: The user prompt (current game state, history, etc.)
        system_prompt: The system prompt (instructions for the agent)
        seed: Random seed for reproducibility. It is forwarded to the hosted
            API only; the local pipeline call does not receive it.
        max_tokens: Maximum tokens in response (default: 300)

    Returns:
        The LLM's response text

    Example:
        response = call_llm(
            prompt="You are in a forest. What do you do?",
            system_prompt=SYSTEM_PROMPT,
            seed=42,
        )
    """
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    run_locally = USE_LOCAL_MODEL and _local_pipeline is not None
    if not run_locally:
        completion = LLM_CLIENT.chat.completions.create(
            model=LLM_MODEL,
            messages=chat,
            temperature=0.0,  # Deterministic for reproducibility
            max_tokens=max_tokens,
            seed=seed,
        )
        return completion.choices[0].message.content

    generations = _local_pipeline(
        chat,
        max_new_tokens=max_tokens,
        temperature=0.0001,  # Near-deterministic (0.0 unsupported by some backends)
        do_sample=True,
    )
    # The pipeline returns the whole conversation; the reply is the last message.
    last_message = generations[0]["generated_text"][-1]
    return last_message["content"]
@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class."""

    # Score reached when the run stopped.
    final_score: int
    # Highest score attainable in the game.
    max_score: int
    # Number of moves counted for the run.
    moves: int
    # Names of the distinct locations seen during the run.
    locations_visited: set[str]
    # True when the run ended with the game won / fully scored.
    game_completed: bool
    # Optional error description; None when nothing was reported.
    error: Optional[str] = None
    # (thought, action, observation) triples, one per step. default_factory
    # gives every result its own list instead of a shared mutable default.
    history: list[tuple[str, str, str]] = field(default_factory=list)
# =============================================================================
# System Prompt - Customize this for your agent
# =============================================================================
# Passed as the system message on every LLM call made by StudentAgent.run.
# The THOUGHT / TOOL / ARGS layout requested below is what
# StudentAgent._parse_response looks for, so keep the two in sync when editing.
SYSTEM_PROMPT = """You are playing a classic text adventure game.
GOAL: Explore the world, solve puzzles, and maximize your score.
AVAILABLE TOOLS (use via MCP):
- play_action: Execute a game command (north, take lamp, open mailbox, etc.)
- memory: Get current game state and history (if implemented)
- inventory: Check what you're carrying (if implemented)
VALID GAME COMMANDS for play_action:
- Movement: north, south, east, west, up, down, enter, exit
- Objects: take <item>, drop <item>, open <thing>, close <thing>, examine <thing>
- Other: look, inventory, read <thing>, turn on lamp
RESPOND IN THIS EXACT FORMAT (no markdown):
THOUGHT: <your reasoning about what to do next>
TOOL: <tool_name>
ARGS: <JSON arguments, e.g., {"action": "look"}>
Example:
THOUGHT: I should look around to see where I am.
TOOL: play_action
ARGS: {"action": "look"}
"""
# =============================================================================
# Student Agent - IMPLEMENT THIS CLASS
# =============================================================================
class StudentAgent:
    """A lean ReAct agent with a dash of personal taste.

    Each step asks the LLM for a THOUGHT / TOOL / ARGS reply, calls the chosen
    MCP tool, and folds the observation back into the cached score, move
    count and location. Commands repeated too often in one location are
    swapped for an untried fallback so the agent does not spin in place.
    """

    # Tools the agent will call; any other tool name becomes a plain "look".
    ALLOWED_TOOLS = frozenset(
        {"play_action", "memory", "inventory", "get_map", "get_valid_actions"}
    )

    # Commands tried, in order, when the model keeps repeating itself.
    FALLBACK_PALETTE = (
        "look",
        "inventory",
        "north",
        "south",
        "east",
        "west",
        "up",
        "down",
        "enter",
        "exit",
        "take all",
        "open door",
        "examine room",
    )

    # Lower-cased phrases that mark the end of a game / a win.
    TERMINAL_PHRASES = (
        "you have died",
        "you are dead",
        "game over",
        "you have won",
        "congratulations",
        "*** the end",
    )
    WIN_PHRASES = ("you have won", "congratulations")

    DEFAULT_MAX_SCORE = 350
    # How many times the same command may run in one location before a
    # fallback replaces it.
    MAX_REPEATS = 2
    # A first line longer than this is not treated as a room title.
    MAX_TITLE_LENGTH = 80

    _THOUGHT_RE = re.compile(r"THOUGHT:\s*(.*)", re.IGNORECASE)
    _TOOL_RE = re.compile(r"TOOL:\s*([A-Za-z0-9_]+)", re.IGNORECASE)
    _ARGS_RE = re.compile(r"ARGS:\s*(\{[\s\S]*\})", re.IGNORECASE)
    _LOOSE_ACTION_RE = re.compile(
        r"action[\"']?\s*:\s*[\"']?([^\"'}\n]+)", re.IGNORECASE
    )
    _SCORE_RE = re.compile(r"Score:\s*(\d+)")
    _MOVES_RE = re.compile(r"Moves?:\s*(\d+)")
    _MAX_SCORE_RE = re.compile(r"Score:\s*\d+\s*/\s*(\d+)")
    _LOCATION_RE = re.compile(r"Location:\s*([^\]\n]+)")

    def __init__(self):
        """Initialize run-local state."""
        self._reset()

    def _reset(self, game: str = "") -> None:
        """Put every piece of run-local state back to its starting value."""
        self.history: list[tuple[str, str, str]] = []
        self.visited_locations: set[str] = set()
        # location -> action -> number of times it was executed there
        self.actions_tried = defaultdict(lambda: defaultdict(int))
        self.current_score = 0
        self.max_score = self.DEFAULT_MAX_SCORE
        self.moves = 0
        self.game = game
        self.last_location = "Unknown"

    async def run(
        self,
        client,  # FastMCP Client connected to your MCP server
        game: str,
        max_steps: int,
        seed: int,
        verbose: bool = False,
    ) -> "RunResult":
        """Run the ReAct loop.

        Args:
            client: Object exposing ``await client.call_tool(name, args)``.
            game: Game name; only used to label the prompt.
            max_steps: Upper bound on both loop iterations and counted moves.
            seed: Seeds ``random`` and is forwarded to every LLM call.
            verbose: When true, print each tool call and its observation.

        Returns:
            A RunResult with the final score, move count, visited locations
            and the (thought, action, observation) history.
        """
        random.seed(seed)
        self._reset(game)

        observation = await self._safe_tool(client, "play_action", {"action": "look"})
        moves_before = self.moves
        self._ingest_observation(observation)
        # Count the opening "look" ourselves if the game reported no move total.
        if self.moves == moves_before:
            self.moves += 1

        mem_text = await self._safe_tool(client, "memory", {"limit": 3})
        self.max_score = self._parse_max_score(mem_text) or self.max_score
        self.current_score, self.moves = self._parse_score_moves(
            mem_text, self.current_score, self.moves
        )

        for _ in range(max_steps):
            prompt = self._build_prompt(observation, self.history)
            llm_response = self._call_llm(prompt, SYSTEM_PROMPT, seed)
            thought, tool, args = self._parse_response(llm_response)
            if tool not in self.ALLOWED_TOOLS:
                tool, args = "play_action", {"action": "look"}

            moves_before = self.moves
            is_game_move = tool == "play_action"
            if is_game_move:
                # Rebuild args from the command actually sent, so history and
                # the next prompt show what really produced the observation.
                args = {"action": self._choose_action(args)}
            observation = await self._safe_tool(client, tool, args)
            self._ingest_observation(observation, from_game=is_game_move)
            if is_game_move and self.moves == moves_before:
                self.moves += 1

            self.history.append((thought, f"{tool} {json.dumps(args)}", observation))
            if verbose:
                print(f"\n> {tool} {args}\n{observation}")
            if self._is_terminal(observation):
                break
            if self.moves >= max_steps:
                break

        clean_locations = {loc for loc in self.visited_locations if loc != "Unknown"}
        game_completed = self.current_score >= self.max_score or self._is_win(observation)
        return RunResult(
            final_score=self.current_score,
            max_score=self.max_score,
            moves=self.moves,
            locations_visited=clean_locations,
            game_completed=game_completed,
            history=self.history,
        )

    def _choose_action(self, args: dict) -> str:
        """Return the command to send for a play_action call and record it.

        A missing, empty or non-string ``action`` becomes "look". A command
        already executed MAX_REPEATS times at the current location is replaced
        by the first untried entry of FALLBACK_PALETTE.
        """
        requested = args.get("action") if isinstance(args, dict) else None
        action = requested.strip() if isinstance(requested, str) else ""
        if not action:
            action = "look"
        location = self.last_location
        if self._should_switch(location, action):
            action = self._fallback_action(self.actions_tried[location])
        self.actions_tried[location][action] += 1
        return action

    def _build_prompt(self, observation: str, history: list) -> str:
        """
        Build the prompt for the LLM.

        Mix a little personality with concise context so the model
        keeps commands short and avoids spinning in circles. Only the last
        five steps are included, each observation flattened to one line and
        cut to 120 characters.
        """
        recent = history[-5:]
        lines = [
            f"Game: {self.game}",
            "You are me playing a parser game. Be decisive, keep commands under four words.",
            "If something failed twice in this room, try a different verb or direction.",
            "",
            "Current observation:",
            observation.strip(),
            "",
            "Recent steps:",
        ]
        if not recent:
            lines.append("- none yet")
        else:
            for _thought, action, obs in recent:
                snippet = obs.replace("\n", " ")
                if len(snippet) > 120:
                    snippet = snippet[:117] + "..."
                lines.append(f"- {action}: {snippet}")
        lines.append("\nNext command?")
        return "\n".join(lines)

    def _parse_response(self, response: str) -> tuple[str, str, dict]:
        """
        Parse LLM response to extract thought, tool name, and arguments.

        An empty response yields ``("", "play_action", {"action": "look"})``.
        When no ARGS object is present, play_action defaults to "look" and any
        other tool gets an empty argument dict.

        Returns:
            Tuple of (thought, tool_name, args_dict)
        """
        thought = ""
        tool = "play_action"
        if not response:
            return thought, tool, {"action": "look"}

        cleaned = response.strip().replace("```", "")

        thought_match = self._THOUGHT_RE.search(cleaned)
        if thought_match:
            thought = thought_match.group(1).strip()

        tool_match = self._TOOL_RE.search(cleaned)
        if tool_match:
            tool = tool_match.group(1)

        args_match = self._ARGS_RE.search(cleaned)
        args: dict = self._decode_args(args_match.group(1)) if args_match else {}

        if tool == "play_action" and "action" not in args:
            args["action"] = "look"
        return thought, tool, args

    def _decode_args(self, raw: str) -> dict:
        """Turn the text after ``ARGS:`` (starting at ``{``) into a dict.

        Tries strict JSON first, then JSON with single quotes swapped for
        double quotes. ``raw_decode`` reads only the first object, so trailing
        text that happens to contain braces does not break parsing. As a last
        resort the value after an ``action:`` key, or the brace-stripped text,
        is used as the action.
        """
        decoder = json.JSONDecoder()
        for candidate in (raw, raw.replace("'", '"')):
            try:
                parsed, _end = decoder.raw_decode(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        loose = self._LOOSE_ACTION_RE.search(raw)
        if loose:
            return {"action": loose.group(1).strip()}
        return {"action": raw.strip("{} ").strip()}

    async def _safe_tool(self, client, tool: str, args: dict) -> str:
        """Call a tool and always return a string.

        A failed call is reported as a ``[tool-error:<tool>] ...`` string
        rather than raised, so one bad call does not end the run.
        """
        try:
            result = await client.call_tool(tool, args)
        except Exception as exc:
            return f"[tool-error:{tool}] {exc}"
        return self._extract_text(result)

    def _extract_text(self, result) -> str:
        """Normalize FastMCP tool responses into plain text.

        Handles None, strings, lists, objects with a ``text`` or ``content``
        attribute, and dicts with a text-like key; anything else goes through
        ``str()``.
        """
        if result is None:
            return ""
        if isinstance(result, str):
            return result
        if isinstance(result, list):
            return self._join_texts(result)
        text = getattr(result, "text", None)
        if isinstance(text, str):
            return text
        content = getattr(result, "content", None)
        if isinstance(content, list):
            return self._join_texts(content)
        if isinstance(content, str):
            return content
        if isinstance(result, dict):
            for key in ("text", "content", "data", "result", "output"):
                if key in result:
                    return self._extract_text(result[key])
        return str(result)

    def _join_texts(self, items) -> str:
        """Extract text from every item and join the non-empty pieces."""
        texts = (self._extract_text(item) for item in items)
        return "\n".join(t for t in texts if t)

    def _ingest_observation(self, observation: str, from_game: bool = True):
        """Update cached score, move count, and location tracking.

        Args:
            observation: Text returned by a tool call.
            from_game: True for play_action output. Only then may the first
                line be read as a room title; output from other tools is
                trusted only when it carries an explicit ``Location:`` tag.
        """
        self.current_score, self.moves = self._parse_score_moves(
            observation, self.current_score, self.moves
        )
        location = self._extract_location(observation, allow_title=from_game)
        # Keep the last known room when this observation names none (e.g.
        # "Taken."), so the per-location repeat counter stays meaningful.
        if location and location != "Unknown":
            self.last_location = location
            self.visited_locations.add(location)

    def _parse_score_moves(
        self, text: str, current_score: int, current_moves: int
    ) -> tuple[int, int]:
        """Return (score, moves) read from ``text``, keeping the given values
        for whichever of the two is not present."""
        if not text:
            return current_score, current_moves
        score_match = self._SCORE_RE.search(text)
        move_match = self._MOVES_RE.search(text)
        if score_match:
            current_score = int(score_match.group(1))
        if move_match:
            current_moves = int(move_match.group(1))
        return current_score, current_moves

    def _parse_max_score(self, text: str) -> Optional[int]:
        """Return N from a ``Score: x/N`` fragment, or None when absent."""
        if not text:
            return None
        max_match = self._MAX_SCORE_RE.search(text)
        return int(max_match.group(1)) if max_match else None

    def _extract_location(self, observation: str, allow_title: bool = True) -> str:
        """Return the location named by an observation, or "Unknown".

        An explicit ``Location: <name>`` tag always wins. Otherwise, when
        ``allow_title`` is true, the first line is used if it looks like a
        room title: non-empty, at most MAX_TITLE_LENGTH characters, not a
        bracketed status/error tag, and not ending in sentence punctuation
        (which marks ordinary responses such as "Taken.").
        """
        if not observation:
            return "Unknown"
        tagged = self._LOCATION_RE.search(observation)
        if tagged:
            return tagged.group(1).strip() or "Unknown"
        if not allow_title:
            return "Unknown"
        stripped = observation.strip()
        if not stripped:
            # Whitespace-only text has no first line to inspect.
            return "Unknown"
        first_line = stripped.splitlines()[0].strip()
        if not first_line or len(first_line) > self.MAX_TITLE_LENGTH:
            return "Unknown"
        if first_line.startswith("[") or first_line.endswith((".", "!", "?", ":")):
            return "Unknown"
        return first_line

    def _should_switch(self, location: str, action: str) -> bool:
        """Tell whether ``action`` already ran MAX_REPEATS times at ``location``."""
        tried_here = self.actions_tried[location]
        return tried_here.get(action, 0) >= self.MAX_REPEATS

    def _fallback_action(self, tried_actions: dict[str, int]) -> str:
        """Return the first palette command not yet tried, else "look"."""
        for candidate in self.FALLBACK_PALETTE:
            if tried_actions.get(candidate, 0) == 0:
                return candidate
        return "look"

    def _is_terminal(self, observation: str) -> bool:
        """Tell whether the observation contains a game-ending phrase."""
        if not observation:
            return False
        lower = observation.lower()
        return any(phrase in lower for phrase in self.TERMINAL_PHRASES)

    def _is_win(self, observation: str) -> bool:
        """Tell whether the observation contains a winning phrase."""
        if not observation:
            return False
        lower = observation.lower()
        return any(phrase in lower for phrase in self.WIN_PHRASES)

    def _call_llm(self, prompt: str, system_prompt: str, seed: int) -> str:
        """
        Call the LLM with the given prompt.

        This is a convenience wrapper - you can also use call_llm() directly.
        """
        return call_llm(prompt, system_prompt, seed)
# =============================================================================
# For local testing
# =============================================================================
async def test_agent():
    """Play a short zork1 session against the local MCP server and print a summary."""
    # Imported here so the module can be loaded without fastmcp installed.
    from fastmcp import Client

    # Path to your MCP server
    server_path = "mcp_server.py"
    player = StudentAgent()

    async with Client(server_path) as client:
        outcome = await player.run(
            client=client,
            game="zork1",
            max_steps=10,
            seed=42,
            verbose=True,
        )
        summary = (
            f"\nFinal Score: {outcome.final_score}",
            f"Moves: {outcome.moves}",
            f"Locations: {outcome.locations_visited}",
        )
        for line in summary:
            print(line)


if __name__ == "__main__":
    import asyncio

    asyncio.run(test_agent())