Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import re | |
| import difflib | |
| import random | |
| from collections import defaultdict, deque | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| from dotenv import load_dotenv | |
| from huggingface_hub import InferenceClient | |
| # Load environment variables | |
| load_dotenv() | |
# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"  # Hugging Face model id used for every call
_hf_token = os.getenv("HF_TOKEN")
# Fail fast at import time if the token is missing, instead of an opaque
# authentication error on the first inference call.
if not _hf_token:
    raise ValueError("HF_TOKEN not found. Set it in your .env file.")
LLM_CLIENT = InferenceClient(token=_hf_token)  # shared client for all requests
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
    """Send one system+user exchange to the configured model and return its text.

    Temperature is pinned to 0.0 and the caller-supplied ``seed`` is forwarded
    so repeated runs produce reproducible completions.
    """
    completion = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )
    return completion.choices[0].message.content
@dataclass
class RunResult:
    """Structure to hold game execution results.

    Without the ``@dataclass`` decorator the annotated fields are mere class
    annotations, ``field(default_factory=list)`` is inert, and the keyword
    construction ``RunResult(final_score=..., ...)`` used by the agent raises
    ``TypeError`` — the decorator is required for this class to work at all.
    """
    final_score: int                 # points earned during the run
    max_score: int                   # maximum achievable score for the game
    moves: int                       # number of turns actually executed
    locations_visited: set[str]      # upper-cased room names seen
    game_completed: bool             # True only on a winning end state
    error: Optional[str] = None      # error message if the run aborted
    # One (thought, action, observation) triple per step; default_factory
    # gives each instance its own fresh list.
    history: list[tuple[str, str, str]] = field(default_factory=list)
# =============================================================================
# System Prompt
# =============================================================================
# Instructs the model to reply in a rigid THOUGHT/TOOL/ARGS layout that
# StudentAgent._parse_response() extracts with simple regexes. The literal
# text below is sent to the model verbatim — do not reflow it.
SYSTEM_PROMPT = """You are playing a classic text adventure game.
GOAL: Explore the world, solve puzzles, and maximize your score.
RESPOND IN THIS EXACT FORMAT (no markdown):
THOUGHT: <your reasoning about what to do next>
TOOL: play_action
ARGS: {"action": "<verb noun>"}
Available MCP Tools: play_action, memory, get_map, get_valid_actions
"""
# =============================================================================
# Student Agent Implementation
# =============================================================================
class StudentAgent:
    def __init__(self):
        """Set up exploration state, route planning, and item-drop priorities."""
        # Exploration / mapping state.
        self.current_location = "START"
        self.visited_locations = set()
        self.world_map = defaultdict(dict)
        # Inventory and planning state.
        self.inventory = set()
        self.pending_path = []
        self.pending_containers = defaultdict(set)
        self.goal_stack = deque()
        # Loop-avoidance bookkeeping.
        self.recent_actions = deque(maxlen=20)
        self.bad_actions_by_loc = defaultdict(lambda: defaultdict(int))
        self.last_obs = ""
        # Item priority: Lower values are dropped first if overweight
        self.item_priority = {
            "leaves": 0,
            "pile of leaves": 0,
            "leaflet": 1,
            "garlic": 2,
            "map": 5,
            "lantern": 10,
            "lamp": 10,
            "sword": 10,
            "key": 10,
        }
| async def run(self, client, game: str, max_steps: int, seed: int, verbose: bool = False) -> RunResult: | |
| # Initial room check | |
| init_res = await client.call_tool("play_action", {"action": "look"}) | |
| observation = init_res.content[0].text if init_res and init_res.content else "" | |
| self._extract_location(observation) | |
| if verbose: print(f"\n[INITIAL OBSERVATION]\n{observation}\n") | |
| history = [] | |
| final_score = 0 | |
| last_score = 0 | |
| for i in range(max_steps): | |
| old_loc = self.current_location | |
| self.visited_locations.add(old_loc) | |
| # Sync world state and available actions | |
| map_data = await client.call_tool("get_map", {}) | |
| try: self.world_map = json.loads(map_data.content[0].text) | |
| except: pass | |
| valid_data = await client.call_tool("get_valid_actions", {}) | |
| try: valid_actions = json.loads(valid_data.content[0].text) | |
| except: valid_actions = [] | |
| self._update_containers(observation) | |
| # --- DECISION PHASE --- | |
| if not self.pending_path: | |
| prompt = self._build_prompt(observation, valid_actions) | |
| raw_response = self._call_llm(prompt, SYSTEM_PROMPT, seed) | |
| thought, tool, args = self._parse_response(raw_response) | |
| action = args.get("action", "look") | |
| # Filter redundant 'take' actions if already in inventory | |
| if action.startswith(("take ", "get ")): | |
| item = action.replace("take ","").replace("get ","").lower() | |
| if any(item in inv_item.lower() for inv_item in self.inventory): | |
| action = "look" | |
| else: | |
| self.goal_stack.append(action) | |
| # BFS Pathfinding for 'go to' commands | |
| m = re.match(r"go to (.+)", action, re.I) | |
| if m: | |
| target = m.group(1).strip().upper() | |
| path = self._bfs_path(self.current_location, target) | |
| if path: | |
| self.pending_path = path[1:] | |
| action = path[0] | |
| else: | |
| action = self.pending_path.pop(0) | |
| thought = f"Following planned path. Target: {action}" | |
| if verbose: | |
| print(f"\n{'-'*10} STEP {i+1} {'-'*10}") | |
| print(f"THOUGHT: {thought}") | |
| print(f"ACTION: {action}") | |
| # --- EXECUTION PHASE --- | |
| result = await client.call_tool("play_action", {"action": action}) | |
| new_obs = result.content[0].text if result and result.content else "" | |
| if verbose: print(f"OBSERVATION: {new_obs.strip()}") | |
| # --- REACTIVE STATE UPDATES --- | |
| observation = new_obs | |
| self._extract_location(observation) | |
| # 1. Handle Overweight Feedback | |
| heavy_msg = ["too heavy", "can't carry any more", "heavy enough", "full"] | |
| if any(p in observation.lower() for p in heavy_msg) and self.inventory: | |
| to_drop = min(list(self.inventory), key=lambda x: self.item_priority.get(x.lower(), 5)) | |
| if verbose: print(f"⚖️ [REACTIVE] Overweight detected. Dropping: {to_drop}") | |
| await client.call_tool("play_action", {"action": f"drop {to_drop}"}) | |
| self.inventory.discard(to_drop) | |
| self.pending_path = [] # Reset plan to reassess after drop | |
| # 2. Update Inventory and Precise Goal Clearing | |
| if any(p in observation for p in ["Taken", "You take", "You now have"]): | |
| item_match = re.search(r"(?:Taken|take|have) (?:the )?([\w\s-]+)\.?", observation, re.I) | |
| if item_match: | |
| item_name = item_match.group(1).strip().lower() | |
| self.inventory.add(item_name) | |
| # Only clear 'take' goals, keep 'use' or 'unlock' goals | |
| self.goal_stack = deque([ | |
| g for g in self.goal_stack | |
| if not (g.lower().startswith(("take ", "get ")) and item_name in g.lower()) | |
| ]) | |
| # Clear path only if it was intended to get this specific item | |
| if self.pending_path and item_name in self.pending_path[-1].lower(): | |
| self.pending_path = [] | |
| # 3. Junk Filter: If we accidentally took leaves, drop them immediately | |
| if "leaves" in observation.lower() and ("Taken" in observation or "take" in action): | |
| await client.call_tool("play_action", {"action": "drop leaves"}) | |
| self.inventory.discard("leaves") | |
| # 4. Error Correction: Reset on "already have" hallucination | |
| if "already have" in observation.lower(): | |
| self.goal_stack.clear() | |
| self.pending_path = [] | |
| # 5. Goal Maintenance | |
| if not self.pending_path and self.goal_stack: | |
| if self._check_goal_complete(self.goal_stack[-1]): | |
| self.goal_stack.pop() | |
| # 6. Score Tracking | |
| mem_res = await client.call_tool("memory", {}) | |
| mem_text = mem_res.content[0].text if mem_res and mem_res.content else "" | |
| score_match = re.search(r"SCORE: (\d+)", mem_text) | |
| if score_match: | |
| current_score = int(score_match.group(1)) | |
| if current_score > last_score: | |
| print(f"\n[SCORE UPDATED] {last_score} -> {current_score}") | |
| last_score = current_score | |
| final_score = current_score | |
| history.append((thought, action, observation)) | |
| if "game over" in observation.lower() or "you have died" in observation.lower(): | |
| break | |
| return RunResult(final_score=final_score, max_score=350, moves=i+1, | |
| locations_visited=self.visited_locations, game_completed=False, history=history) | |
| # --- HELPER METHODS --- | |
| def _extract_location(self, obs: str): | |
| match = re.search(r"\[([^\]]+)\]", obs) | |
| if match: self.current_location = match.group(1).upper() | |
| return self.current_location | |
| def _check_goal_complete(self, goal: str) -> bool: | |
| goal = goal.lower() | |
| if goal.startswith("go to "): | |
| return self.current_location == goal[6:].strip().upper() | |
| if goal.startswith(("take ", "get ")): | |
| items = re.findall(r"(?:take|get)\s+([\w-]+)", goal) | |
| return items[0] in self.inventory if items else False | |
| return False | |
| def _update_containers(self, obs: str): | |
| loc = self.current_location | |
| containers = re.findall(r"(?:a|the)\s+([\w-]+)\s+(?:case|cupboard|chest|drawer|box)", obs.lower()) | |
| for c in containers: | |
| if c not in self.pending_containers[loc]: | |
| self.pending_path.insert(0, f"look inside {c}") | |
| self.pending_containers[loc].add(c) | |
| def _bfs_path(self, start: str, target: str) -> list: | |
| candidates = self.world_map.keys() | |
| match = difflib.get_close_matches(target.upper(), candidates, n=1, cutoff=0.6) | |
| target = match[0] if match else target | |
| if target not in self.world_map: return [] | |
| queue = deque([(start, [])]) | |
| visited = set() | |
| while queue: | |
| node, path = queue.popleft() | |
| if node == target: return path | |
| visited.add(node) | |
| for move, dest in self.world_map.get(node, {}).items(): | |
| if dest and dest not in visited: | |
| queue.append((dest, path + [move])) | |
| return [] | |
    def _build_prompt(self, observation: str, valid_actions: list) -> str:
        """Compose the per-step user prompt for the LLM.

        Includes current location, inventory summary, standing rules, the
        latest observation, and the engine-reported valid actions (rendered
        via the list's repr).
        """
        inv_str = ", ".join(self.inventory) if self.inventory else "Empty"
        # The f-string body is kept flush-left: its literal layout is exactly
        # what the model receives, so do not re-indent it.
        return f"""
[STATUS]
Location: {self.current_location}
Inventory: {inv_str}
[RULES]
- NEVER take useless junk like 'leaves'.
- If you 'take' or 'open' something, DO NOT try to 'take' or 'open' it again.
- Move to new areas if you are stuck in a loop.
[OBSERVATION]
{observation}
[VALID ACTIONS]
{valid_actions}
"""
| def _parse_response(self, response: str) -> tuple[str, str, dict]: | |
| thought, tool, args = "Thinking...", "play_action", {"action": "look"} | |
| t_match = re.search(r"THOUGHT:\s*(.*)", response, re.I) | |
| if t_match: thought = t_match.group(1).split("TOOL:")[0].strip() | |
| tool_match = re.search(r"TOOL:\s*(\w+)", response, re.I) | |
| if tool_match: tool = tool_match.group(1).strip() | |
| args_match = re.search(r"ARGS:\s*({.*})", response, re.DOTALL) | |
| if args_match: | |
| try: args = json.loads(args_match.group(1)) | |
| except: pass | |
| return thought, tool, args | |
    def _call_llm(self, prompt: str, system_prompt: str, seed: int) -> str:
        """Thin wrapper delegating to the module-level ``call_llm`` helper
        (uses its default ``max_tokens``)."""
        return call_llm(prompt, system_prompt, seed)