"""LLM-driven agent for playing classic text adventure games over MCP tools.

The agent combines an LLM decision loop with deterministic reactive rules
(overweight handling, inventory tracking, junk filtering) and BFS pathfinding
over the map reported by the `get_map` tool.
"""

import json
import os
import re
import difflib
import random
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Optional

from dotenv import load_dotenv
from huggingface_hub import InferenceClient

# Load environment variables
load_dotenv()

# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"

_hf_token = os.getenv("HF_TOKEN")
if not _hf_token:
    raise ValueError("HF_TOKEN not found. Set it in your .env file.")

LLM_CLIENT = InferenceClient(token=_hf_token)


def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
    """Standard wrapper for LLM calls with fixed temperature for reproducibility.

    Args:
        prompt: User-turn content.
        system_prompt: System-turn content.
        seed: Sampling seed forwarded to the API for reproducibility.
        max_tokens: Completion-length cap.

    Returns:
        The assistant message content as a string.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )
    return response.choices[0].message.content


@dataclass
class RunResult:
    """Structure to hold game execution results."""

    final_score: int
    max_score: int
    moves: int
    locations_visited: set[str]
    game_completed: bool
    error: Optional[str] = None
    history: list[tuple[str, str, str]] = field(default_factory=list)


# =============================================================================
# System Prompt
# =============================================================================
# NOTE(review): the placeholder text inside THOUGHT/ARGS was garbled in the
# original source (angle-bracket placeholders appear to have been stripped);
# reconstructed here — confirm against the intended prompt.
SYSTEM_PROMPT = """You are playing a classic text adventure game.

GOAL: Explore the world, solve puzzles, and maximize your score.

RESPOND IN THIS EXACT FORMAT (no markdown):
THOUGHT: <your reasoning>
TOOL: play_action
ARGS: {"action": "<game command>"}

Available MCP Tools: play_action, memory, get_map, get_valid_actions
"""


# =============================================================================
# Student Agent Implementation
# =============================================================================
class StudentAgent:
    def __init__(self):
        """Initialize state tracking and item priority for decision making."""
        self.visited_locations = set()          # location names seen so far
        self.inventory = set()                  # lowercased item names we believe we hold
        self.pending_path = []                  # queued actions from BFS / container plans
        self.pending_containers = defaultdict(set)  # loc -> containers already scheduled
        self.current_location = "START"
        self.recent_actions = deque(maxlen=20)
        self.world_map = defaultdict(dict)      # loc -> {move: destination}
        self.bad_actions_by_loc = defaultdict(lambda: defaultdict(int))
        self.last_obs = ""
        self.goal_stack = deque()               # outstanding 'take'/'go to' goals
        # Item priority: Lower values are dropped first if overweight
        self.item_priority = {
            "leaves": 0, "pile of leaves": 0, "leaflet": 1, "garlic": 2,
            "map": 5, "lantern": 10, "lamp": 10, "sword": 10, "key": 10,
        }

    async def run(self, client, game: str, max_steps: int, seed: int,
                  verbose: bool = False) -> RunResult:
        """Run the agent loop: observe, decide (LLM or planned path), act, react.

        Args:
            client: MCP client exposing call_tool(name, args).
            game: Game identifier (unused here; kept for interface compatibility).
            max_steps: Maximum number of decision steps.
            seed: Seed forwarded to the LLM for reproducibility.
            verbose: Print per-step traces when True.

        Returns:
            RunResult summarizing score, moves, and history.
        """
        # Initial room check
        init_res = await client.call_tool("play_action", {"action": "look"})
        observation = init_res.content[0].text if init_res and init_res.content else ""
        self._extract_location(observation)
        if verbose:
            print(f"\n[INITIAL OBSERVATION]\n{observation}\n")

        history = []
        final_score = 0
        last_score = 0
        # FIX: pre-bind the loop index so `moves=i + 1` below cannot raise
        # NameError when max_steps <= 0 and the loop body never runs.
        i = -1

        for i in range(max_steps):
            old_loc = self.current_location
            self.visited_locations.add(old_loc)

            # Sync world state and available actions (best effort; a malformed
            # tool payload must not kill the run, so only decode/shape errors
            # are swallowed — FIX: previously a bare `except:`).
            map_data = await client.call_tool("get_map", {})
            try:
                self.world_map = json.loads(map_data.content[0].text)
            except (json.JSONDecodeError, AttributeError, IndexError, TypeError):
                pass
            valid_data = await client.call_tool("get_valid_actions", {})
            try:
                valid_actions = json.loads(valid_data.content[0].text)
            except (json.JSONDecodeError, AttributeError, IndexError, TypeError):
                valid_actions = []

            self._update_containers(observation)

            # --- DECISION PHASE ---
            if not self.pending_path:
                prompt = self._build_prompt(observation, valid_actions)
                raw_response = self._call_llm(prompt, SYSTEM_PROMPT, seed)
                thought, _tool, args = self._parse_response(raw_response)
                action = args.get("action", "look")

                # Filter redundant 'take' actions if already in inventory.
                if action.startswith(("take ", "get ")):
                    # FIX: strip only the leading verb; str.replace removed the
                    # words everywhere, corrupting names like "take-out box".
                    item = re.sub(r"^(?:take|get)\s+", "", action).lower()
                    if any(item in inv_item.lower() for inv_item in self.inventory):
                        action = "look"
                    else:
                        self.goal_stack.append(action)

                # BFS Pathfinding for 'go to' commands
                m = re.match(r"go to (.+)", action, re.I)
                if m:
                    target = m.group(1).strip().upper()
                    path = self._bfs_path(self.current_location, target)
                    if path:
                        self.pending_path = path[1:]
                        action = path[0]
            else:
                action = self.pending_path.pop(0)
                thought = f"Following planned path. Target: {action}"

            if verbose:
                print(f"\n{'-'*10} STEP {i+1} {'-'*10}")
                print(f"THOUGHT: {thought}")
                print(f"ACTION: {action}")

            # --- EXECUTION PHASE ---
            result = await client.call_tool("play_action", {"action": action})
            new_obs = result.content[0].text if result and result.content else ""
            if verbose:
                print(f"OBSERVATION: {new_obs.strip()}")

            # --- REACTIVE STATE UPDATES ---
            observation = new_obs
            self._extract_location(observation)

            # 1. Handle Overweight Feedback: drop the lowest-priority item.
            heavy_msg = ["too heavy", "can't carry any more", "heavy enough", "full"]
            if any(p in observation.lower() for p in heavy_msg) and self.inventory:
                to_drop = min(
                    list(self.inventory),
                    key=lambda x: self.item_priority.get(x.lower(), 5),
                )
                if verbose:
                    print(f"⚖️ [REACTIVE] Overweight detected. Dropping: {to_drop}")
                await client.call_tool("play_action", {"action": f"drop {to_drop}"})
                self.inventory.discard(to_drop)
                self.pending_path = []  # Reset plan to reassess after drop

            # 2. Update Inventory and Precise Goal Clearing
            if any(p in observation for p in ["Taken", "You take", "You now have"]):
                item_match = re.search(
                    r"(?:Taken|take|have) (?:the )?([\w\s-]+)\.?", observation, re.I
                )
                if item_match:
                    item_name = item_match.group(1).strip().lower()
                    self.inventory.add(item_name)
                    # Only clear 'take' goals, keep 'use' or 'unlock' goals
                    self.goal_stack = deque([
                        g for g in self.goal_stack
                        if not (g.lower().startswith(("take ", "get "))
                                and item_name in g.lower())
                    ])
                    # Clear path only if it was intended to get this specific item
                    if self.pending_path and item_name in self.pending_path[-1].lower():
                        self.pending_path = []

            # 3. Junk Filter: If we accidentally took leaves, drop them immediately
            if "leaves" in observation.lower() and ("Taken" in observation or "take" in action):
                await client.call_tool("play_action", {"action": "drop leaves"})
                self.inventory.discard("leaves")

            # 4. Error Correction: Reset on "already have" hallucination
            if "already have" in observation.lower():
                self.goal_stack.clear()
                self.pending_path = []

            # 5. Goal Maintenance: retire the top goal once satisfied.
            if not self.pending_path and self.goal_stack:
                if self._check_goal_complete(self.goal_stack[-1]):
                    self.goal_stack.pop()

            # 6. Score Tracking via the memory tool.
            mem_res = await client.call_tool("memory", {})
            mem_text = mem_res.content[0].text if mem_res and mem_res.content else ""
            score_match = re.search(r"SCORE: (\d+)", mem_text)
            if score_match:
                current_score = int(score_match.group(1))
                if current_score > last_score:
                    print(f"\n[SCORE UPDATED] {last_score} -> {current_score}")
                    last_score = current_score
                final_score = current_score

            history.append((thought, action, observation))

            if "game over" in observation.lower() or "you have died" in observation.lower():
                break

        return RunResult(
            final_score=final_score,
            max_score=350,
            moves=i + 1,
            locations_visited=self.visited_locations,
            game_completed=False,
            history=history,
        )

    # --- HELPER METHODS ---

    def _extract_location(self, obs: str) -> str:
        """Update current_location from a '[LOCATION]' tag in the observation."""
        match = re.search(r"\[([^\]]+)\]", obs)
        if match:
            self.current_location = match.group(1).upper()
        return self.current_location

    def _check_goal_complete(self, goal: str) -> bool:
        """Return True if a 'go to'/'take'/'get' goal is already satisfied."""
        goal = goal.lower()
        if goal.startswith("go to "):
            return self.current_location == goal[6:].strip().upper()
        if goal.startswith(("take ", "get ")):
            items = re.findall(r"(?:take|get)\s+([\w-]+)", goal)
            return items[0] in self.inventory if items else False
        return False

    def _update_containers(self, obs: str) -> None:
        """Schedule a one-time 'look inside' for each container mentioned here."""
        loc = self.current_location
        containers = re.findall(
            r"(?:a|the)\s+([\w-]+)\s+(?:case|cupboard|chest|drawer|box)", obs.lower()
        )
        for c in containers:
            if c not in self.pending_containers[loc]:
                self.pending_path.insert(0, f"look inside {c}")
                self.pending_containers[loc].add(c)

    def _bfs_path(self, start: str, target: str) -> list:
        """Return the list of moves from start to target, or [] if unreachable.

        The target is fuzzy-matched against known map locations first.
        """
        candidates = self.world_map.keys()
        match = difflib.get_close_matches(target.upper(), candidates, n=1, cutoff=0.6)
        target = match[0] if match else target
        if target not in self.world_map:
            return []
        # FIX: mark nodes visited on enqueue (not dequeue) so each node is
        # queued at most once; the original could enqueue duplicates.
        queue = deque([(start, [])])
        visited = {start}
        while queue:
            node, path = queue.popleft()
            if node == target:
                return path
            for move, dest in self.world_map.get(node, {}).items():
                if dest and dest not in visited:
                    visited.add(dest)
                    queue.append((dest, path + [move]))
        return []

    def _build_prompt(self, observation: str, valid_actions: list) -> str:
        """Assemble the user-turn prompt from status, rules, and observation."""
        inv_str = ", ".join(self.inventory) if self.inventory else "Empty"
        return f"""
[STATUS]
Location: {self.current_location}
Inventory: {inv_str}

[RULES]
- NEVER take useless junk like 'leaves'.
- If you 'take' or 'open' something, DO NOT try to 'take' or 'open' it again.
- Move to new areas if you are stuck in a loop.

[OBSERVATION]
{observation}

[VALID ACTIONS]
{valid_actions}
"""

    def _parse_response(self, response: str) -> tuple[str, str, dict]:
        """Parse THOUGHT/TOOL/ARGS out of the LLM reply; fall back to 'look'."""
        thought, tool, args = "Thinking...", "play_action", {"action": "look"}
        t_match = re.search(r"THOUGHT:\s*(.*)", response, re.I)
        if t_match:
            thought = t_match.group(1).split("TOOL:")[0].strip()
        tool_match = re.search(r"TOOL:\s*(\w+)", response, re.I)
        if tool_match:
            tool = tool_match.group(1).strip()
        args_match = re.search(r"ARGS:\s*({.*})", response, re.DOTALL)
        if args_match:
            # Malformed JSON from the model falls back to the 'look' default.
            try:
                args = json.loads(args_match.group(1))
            except json.JSONDecodeError:
                pass
        return thought, tool, args

    def _call_llm(self, prompt: str, system_prompt: str, seed: int) -> str:
        """Indirection over call_llm so tests can stub LLM access per-agent."""
        return call_llm(prompt, system_prompt, seed)