# Source provenance (Hugging Face page header, kept as comments so the file parses):
# Sunxt25's picture
# Update agent.py
# 43d3f35 verified
import json
import os
import re
import difflib
import random
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Optional
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
# Load environment variables
load_dotenv()
# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"

# Fail fast at import time if the Hugging Face token is missing; every LLM
# call below depends on it.  (Fix: the `raise` had lost its indentation.)
_hf_token = os.getenv("HF_TOKEN")
if not _hf_token:
    raise ValueError("HF_TOKEN not found. Set it in your .env file.")
LLM_CLIENT = InferenceClient(token=_hf_token)
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
    """Standard wrapper for LLM calls with fixed temperature for reproducibility.

    Args:
        prompt: User-turn content for this call.
        system_prompt: System-turn content (format instructions).
        seed: Fixed seed forwarded to the API for reproducible sampling.
        max_tokens: Completion-length cap (default 300).

    Returns:
        The assistant message text from the first choice.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    # temperature=0.0 + explicit seed keep repeated runs deterministic.
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )
    return response.choices[0].message.content
@dataclass
class RunResult:
    """Structure to hold game execution results."""
    final_score: int                  # score when the run ended
    max_score: int                    # maximum attainable score (350 hard-coded by run())
    moves: int                        # number of steps actually executed
    locations_visited: set[str]       # upper-cased location names seen during the run
    game_completed: bool              # True only if the game reported completion
    error: Optional[str] = None       # error description, if the run aborted
    # (thought, action, observation) per step; fresh list per instance.
    history: list[tuple[str, str, str]] = field(default_factory=list)
# =============================================================================
# System Prompt
# =============================================================================
# Runtime string sent verbatim as the system turn on every LLM call.  The
# THOUGHT/TOOL/ARGS layout is load-bearing: StudentAgent._parse_response
# extracts each field with regexes keyed on these exact labels.
SYSTEM_PROMPT = """You are playing a classic text adventure game.
GOAL: Explore the world, solve puzzles, and maximize your score.
RESPOND IN THIS EXACT FORMAT (no markdown):
THOUGHT: <your reasoning about what to do next>
TOOL: play_action
ARGS: {"action": "<verb noun>"}
Available MCP Tools: play_action, memory, get_map, get_valid_actions
"""
# =============================================================================
# Student Agent Implementation
# =============================================================================
class StudentAgent:
def __init__(self):
"""Initialize state tracking and item priority for decision making."""
self.visited_locations = set()
self.inventory = set()
self.pending_path = []
self.pending_containers = defaultdict(set)
self.current_location = "START"
self.recent_actions = deque(maxlen=20)
self.world_map = defaultdict(dict)
self.bad_actions_by_loc = defaultdict(lambda: defaultdict(int))
self.last_obs = ""
self.goal_stack = deque()
# Item priority: Lower values are dropped first if overweight
self.item_priority = {
"leaves": 0, "pile of leaves": 0,
"leaflet": 1, "garlic": 2, "map": 5,
"lantern": 10, "lamp": 10, "sword": 10, "key": 10
}
async def run(self, client, game: str, max_steps: int, seed: int, verbose: bool = False) -> RunResult:
# Initial room check
init_res = await client.call_tool("play_action", {"action": "look"})
observation = init_res.content[0].text if init_res and init_res.content else ""
self._extract_location(observation)
if verbose: print(f"\n[INITIAL OBSERVATION]\n{observation}\n")
history = []
final_score = 0
last_score = 0
for i in range(max_steps):
old_loc = self.current_location
self.visited_locations.add(old_loc)
# Sync world state and available actions
map_data = await client.call_tool("get_map", {})
try: self.world_map = json.loads(map_data.content[0].text)
except: pass
valid_data = await client.call_tool("get_valid_actions", {})
try: valid_actions = json.loads(valid_data.content[0].text)
except: valid_actions = []
self._update_containers(observation)
# --- DECISION PHASE ---
if not self.pending_path:
prompt = self._build_prompt(observation, valid_actions)
raw_response = self._call_llm(prompt, SYSTEM_PROMPT, seed)
thought, tool, args = self._parse_response(raw_response)
action = args.get("action", "look")
# Filter redundant 'take' actions if already in inventory
if action.startswith(("take ", "get ")):
item = action.replace("take ","").replace("get ","").lower()
if any(item in inv_item.lower() for inv_item in self.inventory):
action = "look"
else:
self.goal_stack.append(action)
# BFS Pathfinding for 'go to' commands
m = re.match(r"go to (.+)", action, re.I)
if m:
target = m.group(1).strip().upper()
path = self._bfs_path(self.current_location, target)
if path:
self.pending_path = path[1:]
action = path[0]
else:
action = self.pending_path.pop(0)
thought = f"Following planned path. Target: {action}"
if verbose:
print(f"\n{'-'*10} STEP {i+1} {'-'*10}")
print(f"THOUGHT: {thought}")
print(f"ACTION: {action}")
# --- EXECUTION PHASE ---
result = await client.call_tool("play_action", {"action": action})
new_obs = result.content[0].text if result and result.content else ""
if verbose: print(f"OBSERVATION: {new_obs.strip()}")
# --- REACTIVE STATE UPDATES ---
observation = new_obs
self._extract_location(observation)
# 1. Handle Overweight Feedback
heavy_msg = ["too heavy", "can't carry any more", "heavy enough", "full"]
if any(p in observation.lower() for p in heavy_msg) and self.inventory:
to_drop = min(list(self.inventory), key=lambda x: self.item_priority.get(x.lower(), 5))
if verbose: print(f"⚖️ [REACTIVE] Overweight detected. Dropping: {to_drop}")
await client.call_tool("play_action", {"action": f"drop {to_drop}"})
self.inventory.discard(to_drop)
self.pending_path = [] # Reset plan to reassess after drop
# 2. Update Inventory and Precise Goal Clearing
if any(p in observation for p in ["Taken", "You take", "You now have"]):
item_match = re.search(r"(?:Taken|take|have) (?:the )?([\w\s-]+)\.?", observation, re.I)
if item_match:
item_name = item_match.group(1).strip().lower()
self.inventory.add(item_name)
# Only clear 'take' goals, keep 'use' or 'unlock' goals
self.goal_stack = deque([
g for g in self.goal_stack
if not (g.lower().startswith(("take ", "get ")) and item_name in g.lower())
])
# Clear path only if it was intended to get this specific item
if self.pending_path and item_name in self.pending_path[-1].lower():
self.pending_path = []
# 3. Junk Filter: If we accidentally took leaves, drop them immediately
if "leaves" in observation.lower() and ("Taken" in observation or "take" in action):
await client.call_tool("play_action", {"action": "drop leaves"})
self.inventory.discard("leaves")
# 4. Error Correction: Reset on "already have" hallucination
if "already have" in observation.lower():
self.goal_stack.clear()
self.pending_path = []
# 5. Goal Maintenance
if not self.pending_path and self.goal_stack:
if self._check_goal_complete(self.goal_stack[-1]):
self.goal_stack.pop()
# 6. Score Tracking
mem_res = await client.call_tool("memory", {})
mem_text = mem_res.content[0].text if mem_res and mem_res.content else ""
score_match = re.search(r"SCORE: (\d+)", mem_text)
if score_match:
current_score = int(score_match.group(1))
if current_score > last_score:
print(f"\n[SCORE UPDATED] {last_score} -> {current_score}")
last_score = current_score
final_score = current_score
history.append((thought, action, observation))
if "game over" in observation.lower() or "you have died" in observation.lower():
break
return RunResult(final_score=final_score, max_score=350, moves=i+1,
locations_visited=self.visited_locations, game_completed=False, history=history)
# --- HELPER METHODS ---
def _extract_location(self, obs: str):
match = re.search(r"\[([^\]]+)\]", obs)
if match: self.current_location = match.group(1).upper()
return self.current_location
def _check_goal_complete(self, goal: str) -> bool:
goal = goal.lower()
if goal.startswith("go to "):
return self.current_location == goal[6:].strip().upper()
if goal.startswith(("take ", "get ")):
items = re.findall(r"(?:take|get)\s+([\w-]+)", goal)
return items[0] in self.inventory if items else False
return False
def _update_containers(self, obs: str):
loc = self.current_location
containers = re.findall(r"(?:a|the)\s+([\w-]+)\s+(?:case|cupboard|chest|drawer|box)", obs.lower())
for c in containers:
if c not in self.pending_containers[loc]:
self.pending_path.insert(0, f"look inside {c}")
self.pending_containers[loc].add(c)
def _bfs_path(self, start: str, target: str) -> list:
candidates = self.world_map.keys()
match = difflib.get_close_matches(target.upper(), candidates, n=1, cutoff=0.6)
target = match[0] if match else target
if target not in self.world_map: return []
queue = deque([(start, [])])
visited = set()
while queue:
node, path = queue.popleft()
if node == target: return path
visited.add(node)
for move, dest in self.world_map.get(node, {}).items():
if dest and dest not in visited:
queue.append((dest, path + [move]))
return []
def _build_prompt(self, observation: str, valid_actions: list) -> str:
inv_str = ", ".join(self.inventory) if self.inventory else "Empty"
return f"""
[STATUS]
Location: {self.current_location}
Inventory: {inv_str}
[RULES]
- NEVER take useless junk like 'leaves'.
- If you 'take' or 'open' something, DO NOT try to 'take' or 'open' it again.
- Move to new areas if you are stuck in a loop.
[OBSERVATION]
{observation}
[VALID ACTIONS]
{valid_actions}
"""
def _parse_response(self, response: str) -> tuple[str, str, dict]:
thought, tool, args = "Thinking...", "play_action", {"action": "look"}
t_match = re.search(r"THOUGHT:\s*(.*)", response, re.I)
if t_match: thought = t_match.group(1).split("TOOL:")[0].strip()
tool_match = re.search(r"TOOL:\s*(\w+)", response, re.I)
if tool_match: tool = tool_match.group(1).strip()
args_match = re.search(r"ARGS:\s*({.*})", response, re.DOTALL)
if args_match:
try: args = json.loads(args_match.group(1))
except: pass
return thought, tool, args
def _call_llm(self, prompt: str, system_prompt: str, seed: int) -> str:
return call_llm(prompt, system_prompt, seed)