Spaces:
Sleeping
Sleeping
| """ | |
| Student Agent for Text Adventure Games | |
| """ | |
| import json | |
| import os | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| from dotenv import load_dotenv | |
| from huggingface_hub import InferenceClient | |
# Load variables from a local .env file into the process environment
# (python-dotenv), so HF_TOKEN can be read below.
load_dotenv()
# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================
# Model id sent with every chat completion request made by call_llm.
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
_hf_token = os.getenv("HF_TOKEN")
# Fail fast at import time: without a token no LLM call can succeed.
if not _hf_token:
    raise ValueError("HF_TOKEN not found. Set it in your .env file.")
# Single shared Hugging Face inference client used by call_llm.
LLM_CLIENT = InferenceClient(token=_hf_token)
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
    """Send one system + user chat turn to LLM_MODEL and return the reply text.

    Decoding uses temperature 0.0 and a caller-supplied seed, so repeated
    calls with identical inputs are intended to be reproducible. No error
    handling is done here: exceptions from the inference client propagate
    to the caller.

    Args:
        prompt: Content of the user message.
        system_prompt: Content of the system message.
        seed: Seed forwarded to the inference endpoint.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The message content of the first returned choice.
        NOTE(review): the client's ``message.content`` may be None for some
        responses despite the ``str`` annotation — confirm; callers .strip() it.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )
    return response.choices[0].message.content
# @dataclass is required: StudentAgent.run builds RunResult with keyword
# arguments, and ``field(default_factory=list)`` only takes effect on a
# dataclass (without it, ``history`` would be a bare Field object).
@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class."""

    final_score: int  # score reported at the end of the run
    max_score: int  # maximum score the run is measured against
    moves: int  # number of game actions issued
    locations_visited: set[str]  # distinct location names recorded
    game_completed: bool  # whether the run ended on a game-over message
    error: Optional[str] = None  # error description, if the run failed
    # Per-step (thought, tool call, observation snippet) records; a fresh
    # list is created for every instance.
    history: list[tuple[str, str, str]] = field(default_factory=list)
# =============================================================================
# System Prompts
# =============================================================================

# Main prompt for the LLM fallback policy. The THOUGHT / TOOL / ARGS reply
# format described here is the one the agent's response parser looks for.
SYSTEM_PROMPT = "\n".join([
    "You are an expert text adventure game player. Your goal is to explore as many NEW locations as possible and maximize your score.",
    "AVAILABLE TOOLS:",
    "1. play_action - Execute game commands",
    "2. memory - Get current state, score, and FAILED ACTIONS to avoid",
    "3. get_valid_actions - Get valid commands for current location",
    "4. get_map - See explored locations",
    "5. inventory - Check items",
    "VALID COMMANDS: north, south, east, west, up, down, enter, exit, n, s, e, w, u, d, take <item>, drop <item>, open <thing>, examine <thing>, look, read <thing>",
    "RESPOND IN THIS EXACT FORMAT:",
    "THOUGHT: <reasoning>",
    "TOOL: <tool_name>",
    "ARGS: <JSON arguments>",
    "CRITICAL RULES:",
    "1. NEVER repeat a failed action",
    "2. Always try promising actions before moving on",
    "3. Prioritize moving to NEW unvisited locations",
    "4. If you have been in the same location for 5+ steps, MOVE to a new location immediately",
    "5. Pick up useful items (lamp, torch, sword, coin, key)",
    "6. Do NOT damage useful items",
])

# str.format template with {observation} and {valid_actions} placeholders;
# the model is asked to answer with a bare JSON array of action strings.
PROMISING_ACTIONS_PROMPT = "\n".join([
    "You are analyzing a text adventure game observation to find promising actions.",
    "Given this observation from a text adventure game, list the most promising actions to try.",
    "Focus on: items to pick up, containers to open, exits to explore, puzzles to solve.",
    "Return ONLY a JSON array of action strings, nothing else.",
    'Example: ["take lamp", "open mailbox", "go north", "examine sword"]',
    "Observation:",
    "{observation}",
    "Valid actions available:",
    "{valid_actions}",
    "Return only the JSON array:",
])

# str.format template with {action} and {result} placeholders, asking for a
# one-sentence outcome summary.
SUMMARIZE_PROMPT = "\n".join([
    "Summarize this text adventure game action result in one short sentence (max 10 words).",
    "Action: {action}",
    "Result: {result}",
    "Summary:",
])
# =============================================================================
# Student Agent
# =============================================================================
class StudentAgent:
    """Exploration-focused text adventure agent.

    Each step picks an action in three tiers:
      1. promising actions the LLM extracted when the location was first
         entered, skipping ones that already failed there;
      2. a forced movement command once the agent has spent
         ``STUCK_THRESHOLD`` steps in one location;
      3. otherwise a full LLM tool call, prompted with per-location history.
    """

    # Steps spent in one location after which movement is forced and the
    # LLM prompt carries a "stuck" warning.
    STUCK_THRESHOLD = 5

    # Forced-movement commands in priority order, each as (command,
    # *abbreviations). Only the full command is issued. It is skipped if ANY
    # of its spellings already failed here, because the system prompt lists
    # both forms as the same command and retrying an alias wastes a move.
    MOVEMENT_COMMANDS = (
        ("north", "n"),
        ("south", "s"),
        ("east", "e"),
        ("west", "w"),
        ("up", "u"),
        ("down", "d"),
        ("enter",),
        ("exit",),
    )

    # Tool names the LLM tends to invent, mapped to real MCP tool names.
    # Unknown names fall back to play_action.
    TOOL_ALIASES = {
        "action": "play_action", "do": "play_action", "command": "play_action",
        "map": "get_map", "location": "get_map",
        "mem": "memory", "state": "memory", "status": "memory",
        "inv": "inventory", "items": "inventory",
        "valid": "get_valid_actions", "actions": "get_valid_actions",
    }

    # Verbs missing from the VALID COMMANDS list, rewritten to listed verbs.
    # "pick up X" additionally drops the "up" so it becomes "take X".
    VERB_REWRITES = {
        "check": "examine", "inspect": "examine", "search": "look",
        "grab": "take", "pick": "take", "use": "examine", "investigate": "examine",
    }

    def __init__(self):
        self.score: int = 0
        # Per-location tracking
        self.location_actions: dict[str, list[tuple[str, str]]] = {}  # location -> [(action, summary)]
        self.location_promising: dict[str, list[str]] = {}  # location -> [promising actions]
        self.location_failed: dict[str, set[str]] = {}  # location -> {failed actions}
        self.steps_in_location: int = 0
        self.current_location: str = ""
        self.all_locations_visited: set[str] = set()

    def _is_new_location(self, location: str) -> bool:
        """Return True if *location* has never been set up by _on_enter_location."""
        return location not in self.all_locations_visited

    def _on_enter_location(self, location: str, observation: str, valid_actions: list[str], seed: int):
        """Set up tracking for a first visit to *location* and seed its promising actions.

        This resets the location's action log and failed set and the step
        counter, then asks the LLM for promising actions (hint 5).
        """
        self.all_locations_visited.add(location)
        self.location_actions[location] = []
        self.location_failed[location] = set()
        self.steps_in_location = 0
        self.location_promising[location] = self._extract_promising_actions(observation, valid_actions, seed)

    def _extract_promising_actions(self, observation: str, valid_actions: list[str], seed: int) -> list[str]:
        """Ask the LLM which actions look promising in *observation* (hint 5).

        Returns lower-cased, de-duplicated action strings in the LLM's order.
        This is best effort: any LLM, regex or JSON failure yields ``[]`` so
        exploration simply continues with the other tiers.
        """
        try:
            prompt = PROMISING_ACTIONS_PROMPT.format(
                observation=observation[:500],
                valid_actions=", ".join(valid_actions[:20]),
            )
            response = call_llm(
                prompt,
                "You extract promising actions from text adventure observations. Return only valid JSON arrays.",
                seed,
            )
            # The model may wrap the array in prose; take the first [...] span.
            match = re.search(r'\[.*?\]', response.strip(), re.DOTALL)
            if not match:
                return []
            actions = json.loads(match.group())
        except Exception:
            # Deliberate best effort: a missing hint list is not fatal.
            return []
        if not isinstance(actions, list):
            return []
        cleaned = (a.lower().strip() for a in actions if isinstance(a, str))
        # Drop blanks and duplicates while keeping the LLM's ordering.
        return list(dict.fromkeys(a for a in cleaned if a))

    def _summarize_outcome(self, action: str, result: str, seed: int) -> str:
        """Summarize an action's outcome in one short LLM sentence (hint 3).

        Falls back to the first 60 characters of *result* when the LLM call
        fails or returns an empty summary.
        """
        try:
            prompt = SUMMARIZE_PROMPT.format(action=action, result=result[:200])
            summary = call_llm(prompt, "You summarize text adventure outcomes briefly.", seed, max_tokens=30)
            summary = summary.strip()
        except Exception:
            return result[:60]
        return summary or result[:60]

    def _get_next_action(self, location: str, observation: str, seed: int) -> Optional[str]:
        """Choose an action without the LLM, or return None to defer to it (hints 3, 4, 6).

        1. Pop the first promising action for *location* that has not failed.
        2. If stuck for ``STUCK_THRESHOLD`` steps, return the first movement
           command none of whose spellings has failed here.
        3. Otherwise return None.

        *observation* and *seed* are accepted for interface compatibility and
        are not used.
        """
        failed = self.location_failed.get(location, set())
        queue = self.location_promising.get(location, [])
        for idx, action in enumerate(queue):
            if action not in failed:
                del queue[idx]  # consumed: never re-issued from this queue
                return action

        if self.steps_in_location >= self.STUCK_THRESHOLD:
            for spellings in self.MOVEMENT_COMMANDS:
                if failed.isdisjoint(spellings):
                    return spellings[0]

        return None  # Fall back to LLM

    def _build_prompt(self, observation: str, location: str) -> str:
        """Build the LLM user prompt: score, location, local history, warnings, observation.

        Failed actions are listed in sorted order so the prompt, and therefore
        the seeded temperature-0 LLM call, does not depend on string-hash
        randomization of set iteration order.
        """
        parts = [
            f"Current Score: {self.score}",
            f"Current Location: {location}",
            f"Steps in this location: {self.steps_in_location}",
        ]

        # Recent action history for this location (hint 3).
        loc_history = self.location_actions.get(location, [])
        if loc_history:
            parts.append("\nActions tried here:")
            parts.extend(f" > {action} -> {summary}" for action, summary in loc_history[-5:])

        failed = self.location_failed.get(location, set())
        if failed:
            parts.append(f"\nFailed actions here (DO NOT REPEAT): {', '.join(sorted(failed)[:15])}")

        # Only advertise promising actions that have not already failed, so
        # the prompt does not contradict the "DO NOT REPEAT" list.
        promising = [a for a in self.location_promising.get(location, []) if a not in failed]
        if promising:
            parts.append(f"\nPromising actions to try: {', '.join(promising[:5])}")

        # Exploration bias warning (hint 6).
        if self.steps_in_location >= self.STUCK_THRESHOLD:
            parts.append(f"\n[WARNING: Stuck in {location} for {self.steps_in_location} steps. MOVE TO A NEW LOCATION NOW!]")

        parts.append(f"\nCurrent observation:\n{observation}")
        parts.append("\nWhat do you do next?")
        return "\n".join(parts)

    @staticmethod
    def _parse_args(args_part: str) -> dict:
        """Decode the text after ``ARGS:`` into a tool-argument dict.

        The raw JSON is tried first, so apostrophes inside actions survive.
        Next comes a single-to-double-quote variant for Python-style dicts,
        then a regex for an ``"action"`` value. A bare JSON string is treated
        as the action itself. Anything else yields ``{"action": "look"}``.
        """
        for candidate in (args_part, args_part.replace("'", '"')):
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
            if isinstance(parsed, str) and parsed.strip():
                return {"action": parsed}
            break  # valid JSON, but not a usable shape
        match = re.search(r'"action"\s*:\s*"([^"]+)"', args_part)
        return {"action": match.group(1)} if match else {"action": "look"}

    def _parse_response(self, response: str, valid_tools: list[str]) -> tuple[str, str, dict]:
        """Parse THOUGHT / TOOL / ARGS lines from an LLM reply.

        Missing or malformed fields keep their defaults:
        ``("No reasoning provided", "play_action", {"action": "look"})``.
        A None or empty reply yields exactly those defaults. *valid_tools* is
        accepted for interface compatibility and is not used here.
        """
        thought = "No reasoning provided"
        tool_name = "play_action"
        tool_args: dict = {"action": "look"}

        for line in (response or "").strip().split("\n"):
            line_clean = line.strip()
            line_upper = line_clean.upper()
            if line_upper.startswith("THOUGHT:"):
                thought = line_clean.split(":", 1)[1].strip()
            elif line_upper.startswith("TOOL:"):
                # Strip markdown emphasis/code markup *before* checking for a
                # token, so "TOOL: **" cannot produce an empty split.
                raw = line_clean.split(":", 1)[1].replace("*", "").replace("`", "").lower()
                tokens = raw.split()
                tool_name = tokens[0] if tokens else "play_action"
            elif line_upper.startswith("ARGS:"):
                tool_args = self._parse_args(line_clean.split(":", 1)[1].strip())

        return thought, tool_name, tool_args

    def _validate_tool_call(self, tool_name: str, tool_args: dict, valid_tools: list[str]) -> tuple[str, dict]:
        """Map the tool name onto a real tool and normalize play_action arguments.

        Unknown tool names go through ``TOOL_ALIASES``, defaulting to
        play_action. Non-dict arguments are replaced with ``{}``. For
        play_action, the action is lower-cased, whitespace-collapsed and
        verb-rewritten; a missing or blank action becomes "look". A new args
        dict is returned and the caller's dict is left unchanged.
        """
        if tool_name not in valid_tools:
            tool_name = self.TOOL_ALIASES.get(tool_name, "play_action")

        if not isinstance(tool_args, dict):
            tool_args = {}

        if tool_name == "play_action":
            tool_args = dict(tool_args)
            words = str(tool_args.get("action") or "look").lower().split()
            if words and words[0] in self.VERB_REWRITES:
                # "pick up X" -> "take X" (not "take up X").
                if words[0] == "pick" and words[1:2] == ["up"]:
                    del words[1]
                words[0] = self.VERB_REWRITES[words[0]]
            tool_args["action"] = " ".join(words) or "look"

        return tool_name, tool_args

    def _extract_result(self, result) -> str:
        """Return the text of a tool result.

        This reads ``.content[0].text`` when present, else the first list item's
        ``.text``, else ``str(result)``.
        NOTE(review): assumes the first content item carries ``.text`` — confirm.
        """
        if hasattr(result, 'content') and result.content:
            return result.content[0].text
        if isinstance(result, list) and result:
            return result[0].text if hasattr(result[0], 'text') else str(result[0])
        return str(result)

    def _update_score(self, text: str) -> None:
        """Raise ``self.score`` to any larger score number found in *text*.

        The score never decreases.
        """
        for pattern in [r'Score:\s*(\d+)', r'score[:\s]+(\d+)', r'\[Score:\s*(\d+)']:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                self.score = max(self.score, int(match.group(1)))

    def _is_game_over(self, text: str) -> bool:
        """Return True if *text* contains a game-over or death phrase (case-insensitive)."""
        return any(p in text.lower() for p in ["game over", "you have died", "you are dead", "*** you have died ***"])

    @staticmethod
    def _first_line(text: str) -> str:
        """Return the first non-blank line of *text*, stripped ("" if there is none)."""
        for line in text.splitlines():
            if line.strip():
                return line.strip()
        return ""

    async def _fetch_location(self, client, observation: str) -> str:
        """Return the current location name.

        The name comes from the ``Location:`` line of the memory tool's output.
        If that tool fails or has no such line, the first non-blank line of
        *observation* is used instead.
        """
        try:
            mem_text = self._extract_result(await client.call_tool("memory", {}))
            match = re.search(r'Location:\s*(.+)', mem_text)
            if match:
                return match.group(1).strip()
        except Exception:
            # Best effort: the memory tool is optional for location tracking.
            pass
        return self._first_line(observation)

    async def _fetch_valid_actions(self, client) -> list[str]:
        """Return the comma-separated get_valid_actions output as a list ([] on failure)."""
        try:
            text = self._extract_result(await client.call_tool("get_valid_actions", {}))
        except Exception:
            return []
        cleaned = text.replace("Valid actions:", "")
        return [a.strip() for a in cleaned.split(",") if a.strip()]

    async def run(self, client, game: str, max_steps: int, seed: int, verbose: bool = False) -> "RunResult":
        """Play for up to *max_steps* steps and return a RunResult.

        Args:
            client: MCP client providing ``list_tools()`` and ``call_tool()``.
                play_action is required. Failures of memory and
                get_valid_actions fall back to observation-based defaults.
            game: Accepted for interface compatibility; not read here.
            max_steps: Maximum number of agent steps (tool calls).
            seed: Base seed; step ``i`` uses ``seed + i`` for its LLM calls.
            verbose: Print per-step diagnostics.

        NOTE(review): ``max_score`` is hard-coded to 350 regardless of
        *game* — confirm this is intended.
        """
        locations_visited: set[str] = set()
        history: list[tuple[str, str, str]] = []
        moves = 0

        tools = await client.list_tools()
        tool_names = [t.name for t in tools]

        # Initial observation, location and promising actions.
        result = await client.call_tool("play_action", {"action": "look"})
        observation = self._extract_result(result)
        current_location = await self._fetch_location(client, observation)
        self.current_location = current_location
        locations_visited.add(current_location)
        valid_actions = await self._fetch_valid_actions(client)
        self._on_enter_location(current_location, observation, valid_actions, seed)

        if verbose:
            print(f"\n{observation}")
            print(f"Location: {current_location}")
            print(f"Promising: {self.location_promising.get(current_location, [])}")

        for step in range(1, max_steps + 1):
            step_seed = seed + step
            self.steps_in_location += 1

            # Try structured action selection first (hints 3, 4, 6).
            structured_action = self._get_next_action(current_location, observation, step_seed)
            if structured_action:
                thought = f"Structured: trying '{structured_action}'"
                tool_name = "play_action"
                tool_args = {"action": structured_action}
                if verbose:
                    print(f"\n--- Step {step} ---")
                    print(f"[STRUCTURED] {structured_action}")
            else:
                prompt = self._build_prompt(observation, current_location)
                try:
                    response = call_llm(prompt, SYSTEM_PROMPT, step_seed)
                except Exception as e:
                    # A transient inference failure should not abort the
                    # whole run; an empty reply parses to the default "look".
                    response = ""
                    if verbose:
                        print(f"[LLM ERROR] {e}")
                thought, tool_name, tool_args = self._parse_response(response, tool_names)
                tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names)
                if verbose:
                    print(f"\n--- Step {step} ---")
                    print(f"[THOUGHT] {thought}")
                    print(f"[TOOL] {tool_name}({tool_args})")

            # Execute the chosen tool call.
            try:
                result = await client.call_tool(tool_name, tool_args)
                observation = self._extract_result(result)
                if verbose:
                    print(f"[RESULT] {observation[:150]}")
            except Exception as e:
                observation = f"Error: {e}"
                if verbose:
                    print(f"[ERROR] {e}")

            if tool_name == "play_action":
                action = tool_args.get("action", "look")
                moves += 1

                # Detect a location change (hint 1).
                new_location = await self._fetch_location(client, observation)
                locations_visited.add(new_location)

                if new_location != current_location:
                    current_location = new_location
                    self.current_location = current_location
                    self.steps_in_location = 0
                    # Valid actions are only needed to seed a first visit (hint 2).
                    if self._is_new_location(current_location):
                        valid_actions = await self._fetch_valid_actions(client)
                        self._on_enter_location(current_location, observation, valid_actions, step_seed)
                        if verbose:
                            print(f"[NEW LOCATION] {current_location}")
                            print(f"[PROMISING] {self.location_promising.get(current_location, [])}")
                else:
                    # Same location: an action that neither moved us nor raised
                    # the score is recorded as failed here.
                    prev_score = self.score
                    self._update_score(observation)
                    if self.score == prev_score:
                        self.location_failed.setdefault(current_location, set()).add(action)
                    # Summarize and log the outcome (hint 3).
                    summary = self._summarize_outcome(action, observation, step_seed)
                    self.location_actions.setdefault(current_location, []).append((action, summary))

            self._update_score(observation)
            history.append((thought, f"{tool_name}({tool_args})", observation[:100]))

            if self._is_game_over(observation):
                if verbose:
                    print("\n*** GAME OVER ***")
                break

        return RunResult(
            final_score=self.score,
            max_score=350,
            moves=moves,
            locations_visited=locations_visited,
            game_completed=self._is_game_over(observation),
            history=history,
        )
async def test_agent():
    """Smoke-run StudentAgent for 20 verbose steps against the local MCP server and print a summary."""
    from fastmcp import Client

    student = StudentAgent()
    async with Client("mcp_server.py") as mcp_client:
        outcome = await student.run(
            client=mcp_client,
            game="lostpig",
            max_steps=20,
            seed=42,
            verbose=True,
        )
        report_lines = (
            f"\nFinal Score: {outcome.final_score}",
            f"Moves: {outcome.moves}",
            f"Locations: {len(outcome.locations_visited)}",
        )
        for report_line in report_lines:
            print(report_line)


if __name__ == "__main__":
    # Script entry point: drive the async smoke test on a fresh event loop.
    import asyncio

    asyncio.run(test_agent())