# Author: Willy Vo — "Implement my agent" (commit 0b17df7)
"""
Example: MCP ReAct Agent
A complete ReAct agent that uses MCP tools to play text adventure games.
This is a working example students can learn from.
"""
import json
import os
import re
from dataclasses import dataclass, field
from typing import Optional
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
import hashlib
from collections import defaultdict
load_dotenv()
# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================
# Model pinned by the assignment; every reasoning and judging call goes through it.
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
# API token is expected in .env, loaded by load_dotenv() above.
_hf_token = os.getenv("HF_TOKEN")
if not _hf_token:
    raise ValueError("HF_TOKEN not found. Set it in your .env file.")
# Shared client reused by call_llm() for all completion requests.
LLM_CLIENT = InferenceClient(token=_hf_token)
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
    """Send one system+user exchange to the configured LLM and return the reply text.

    Args:
        prompt: User-turn content.
        system_prompt: System-turn content.
        seed: Sampling seed (temperature is fixed at 0.0 for determinism).
        max_tokens: Completion length cap.

    Returns:
        The assistant message content as a plain string.
    """
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    completion = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=chat_messages,
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )
    return completion.choices[0].message.content
@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class."""
    final_score: int                    # score when the run ended
    max_score: int                      # maximum achievable score for the game
    moves: int                          # number of game actions executed
    locations_visited: set[str]         # distinct first-line location names seen
    game_completed: bool                # True if the final observation signalled game over
    error: Optional[str] = None         # populated by callers on fatal failure
    history: list[tuple[str, str, str]] = field(default_factory=list)  # (thought, tool_call, result) per step
# =============================================================================
# System Prompt
# =============================================================================
# Static system prompt sent with every reasoning call.  Documents the MCP
# tools, the valid/forbidden game verbs, and the exact THOUGHT/TOOL/ARGS
# reply format that StudentAgent._parse_response() expects.
SYSTEM_PROMPT = """You are an expert text adventure game player. Your goal is to explore, collect treasures, and maximize your score.
AVAILABLE TOOLS (use these via MCP):
1. play_action - Execute game commands (north, take lamp, open mailbox, etc.)
2. memory - Get current game state, score, and recent history
3. get_map - See explored locations and connections
4. inventory - Check what you're carrying
VALID GAME COMMANDS for play_action:
- Movement: north, south, east, west, up, down, enter, exit
- Objects: take <item>, drop <item>, open <thing>, close <thing>, examine <thing>
- Light: turn on lamp, turn off lamp
- Combat: attack <enemy> with <weapon>
- Other: inventory, look, read <thing>, wait
FORBIDDEN (will NOT work): check, inspect, search, grab, use, help
RESPOND IN THIS EXACT FORMAT (no markdown):
THOUGHT: <brief reasoning about what to do next>
TOOL: <tool_name>
ARGS: <JSON arguments>
Examples:
THOUGHT: I need to see what's around me.
TOOL: play_action
ARGS: {"action": "look"}
THOUGHT: Let me check my current state and score.
TOOL: memory
ARGS: {}
THOUGHT: The mailbox might contain something useful.
TOOL: play_action
ARGS: {"action": "open mailbox"}
STRATEGY:
1. Start by looking around and checking memory
2. Explore systematically - try all directions
3. Pick up useful items (lamp, sword, etc.)
4. Open containers (mailbox, window, etc.)
5. Use get_map to avoid getting lost
6. Turn on lamp before dark areas!
DO NOT repeat the same action multiple times in a row."""
# =============================================================================
# Student Agent Implementation
# =============================================================================
class StudentAgent:
    """
    MCP ReAct Agent - A complete working example.
    This agent demonstrates:
    - ReAct loop (Thought -> Tool -> Observation)
    - Loop detection
    - Action validation
    - Score tracking via memory tool
    """
    def __init__(self):
        """Initialize the agent state."""
        self.history: list[dict] = []        # last 10 executed steps (step/thought/tool/args/result)
        self.recent_actions: list[str] = []  # sliding window used for loop detection
        self.score: int = 0                  # best score parsed from game text so far
        # --- Context management memory ---
        # Keyed by (state_id, inv_sig)
        self.failed_strong = defaultdict(set)   # actions that are nonsense here
        self.failed_soft = defaultdict(dict)    # action -> last_step tried (cooldown)
        self.state_last_obs = {}                # (state_id, inv_sig) -> normalized obs
        self.inv_sig: str = ""       # current inventory signature
        self.prev_inv_sig: str = ""  # previous signature to detect changes
        self.step: int = 0           # current step counter
        self.debug_context: bool = True  # Whether to include context management info in the prompt (for transparency)
        self.soft_cooldown_steps = 30    # steps before a soft-failed action may be retried
        # -- LLM judge
        self.judge_cache: dict[tuple[str, str, str], str] = {}  # (prev, action, new) -> verdict
        self.use_llm_judge: bool = True   # disable to rely only on the identical-observation rule
        self._last_llm_raw = ""           # raw text of the most recent judge reply
        self._last_llm_label = ""         # parsed label of the most recent judge reply
        self._last_llm_cached = False     # True when the last verdict came from judge_cache
async def run(
    self,
    client,
    game: str,
    max_steps: int,
    seed: int,
    verbose: bool = False,
) -> RunResult:
    """Run the agent for a game session.

    Drives the ReAct loop: build prompt -> LLM -> parse -> validate/gate ->
    execute MCP tool -> update failure memory, score and history.

    Args:
        client: Connected MCP client (play_action/memory/get_map/inventory).
        game: Game identifier (kept for interface compatibility; unused here).
        max_steps: Maximum number of loop iterations.
        seed: Base LLM seed; each step uses seed + step for variety.
        verbose: Print per-step thought/tool/result when True.

    Returns:
        RunResult with final score, move count, visited locations and history.
    """
    locations_visited = set()
    history = []
    moves = 0
    # Get list of available tools
    tools = await client.list_tools()
    tool_names = [t.name for t in tools]
    # Get initial observation
    result = await client.call_tool("play_action", {"action": "look"})
    observation = self._extract_result(result)
    # Initialize inventory signature
    inv_res = await client.call_tool("inventory", {})
    inv_text = self._extract_result(inv_res)
    self.inv_sig = self._inventory_signature(inv_text)
    self.prev_inv_sig = self.inv_sig
    # Track initial location (first line of an observation is the room name)
    location = observation.split("\n")[0] if observation else "Unknown"
    locations_visited.add(location)
    if verbose:
        print(f"\n{observation}")
    # Main ReAct loop
    for step in range(1, max_steps + 1):
        self.step = step
        # Refresh inventory periodically (cheap and very useful for gating)
        if step == 1 or step % 7 == 0:
            inv_res = await client.call_tool("inventory", {})
            inv_text = self._extract_result(inv_res)
            self.inv_sig = self._inventory_signature(inv_text)
            # If inventory changed, we want to allow retry of gated actions everywhere
            if self.inv_sig != self.prev_inv_sig:
                # Clear ALL soft failures (gated actions may now be valid)
                self.failed_soft.clear()
                self.prev_inv_sig = self.inv_sig
        # Build prompt with context
        prompt = self._build_prompt(observation)
        # Call LLM for reasoning (use step-based seed for variety)
        response = call_llm(prompt, SYSTEM_PROMPT, seed + step)
        # Parse the response
        thought, tool_name, tool_args = self._parse_response(response, tool_names)
        if verbose:
            print(f"\n--- Step {step} ---")
            print(f"[THOUGHT] {thought}")
            print(f"[TOOL] {tool_name}({tool_args})")
        # Validate and fix common issues
        tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names)
        tool_name, tool_args = self._apply_context_management(tool_name, tool_args, observation)
        # Loop detection
        if tool_name == "play_action":
            action = tool_args.get("action", "look")
            self.recent_actions.append(action)
            if len(self.recent_actions) > 5:
                self.recent_actions = self.recent_actions[-5:]
            # Detect loops - if same action 3 times, force "look"
            if len(self.recent_actions) >= 3 and len(set(self.recent_actions[-3:])) == 1:
                if verbose:
                    print(f"[WARNING] Loop detected - forcing 'look'")
                tool_args = {"action": "look"}
                self.recent_actions.append("look")
            # NOTE(review): indentation reconstructed — here moves counts only
            # game actions, not memory/map/inventory calls; confirm intent.
            moves += 1
        # Execute the tool
        try:
            prev_observation = observation  # keep previous for failure detection
            result = await client.call_tool(tool_name, tool_args)
            observation = self._extract_result(result)
            # Update failure memory only for play_action
            if tool_name == "play_action":
                action = tool_args.get("action", "look")
                self._update_failure_memory(prev_observation, action, observation)
                self._log_context_state(prev_observation, action, new_observation=observation)
            if verbose:
                print(f"[RESULT] {observation[:200]}...")
        except Exception as e:
            observation = f"Error: {e}"
            if verbose:
                print(f"[ERROR] {e}")
        # Track location
        location = observation.split("\n")[0] if observation else "Unknown"
        locations_visited.add(location)
        # Update history
        self.history.append({
            "step": step,
            "thought": thought,
            "tool": tool_name,
            "args": tool_args,
            "result": observation[:200]
        })
        if len(self.history) > 10:
            self.history = self.history[-10:]
        # Track score from observation
        self._update_score(observation)
        # Record in result history
        history.append((thought, f"{tool_name}({tool_args})", observation[:100]))
        # Check for game over
        if self._is_game_over(observation):
            if verbose:
                print("\n*** GAME OVER ***")
            break
    return RunResult(
        final_score=self.score,
        max_score=350,
        moves=moves,
        locations_visited=locations_visited,
        game_completed=self._is_game_over(observation),
        history=history,
    )
def _build_prompt(self, observation: str) -> str:
"""Build the prompt for the LLM with context."""
parts = []
parts.append(f"Current Score: {self.score}")
# Recent history
if self.history:
parts.append("\nRecent actions:")
for entry in self.history[-8:]:
action = entry.get("args", {}).get("action", entry["tool"])
result_short = entry["result"][:80] + "..." if len(entry["result"]) > 80 else entry["result"]
parts.append(f" > {action} -> {result_short}")
# Warn about repeated actions
if self.recent_actions and len(set(self.recent_actions[-3:])) == 1:
parts.append(f"\n[WARNING: You've been doing '{self.recent_actions[-1]}' repeatedly. TRY SOMETHING DIFFERENT!]")
# Add context constraints to reduce repetition
sid = self._state_id_from_observation(observation)
key = (sid, self.inv_sig)
forbidden = list(self.failed_strong[key])[:8]
soft_forbidden = list(self.failed_soft[key].keys())[:8]
if forbidden or soft_forbidden:
parts.append("\nContext restrictions (DO NOT choose these actions here):")
if forbidden:
parts.append(" Strong banned: " + ", ".join(forbidden))
if soft_forbidden:
parts.append(" Recently failed (wait before retry): " + ", ".join(soft_forbidden))
parts.append(f"\nCurrent situation:\n{observation}")
parts.append("\nWhat do you do next?")
return "\n".join(parts)
def _parse_response(self, response: str, valid_tools: list[str]) -> tuple[str, str, dict]:
"""Parse the LLM response to extract thought, tool, and arguments."""
thought = "No reasoning provided"
tool_name = "play_action"
tool_args = {"action": "look"}
lines = response.strip().split("\n")
for line in lines:
line_clean = line.strip()
line_upper = line_clean.upper()
if line_upper.startswith("THOUGHT:"):
thought = line_clean.split(":", 1)[1].strip()
elif line_upper.startswith("TOOL:"):
raw_tool = line_clean.split(":", 1)[1].strip().lower()
raw_tool = raw_tool.replace("**", "").replace("*", "").replace("`", "")
raw_tool = raw_tool.split()[0] if raw_tool else "play_action"
tool_name = raw_tool
elif line_upper.startswith("ARGS:"):
args_part = line_clean.split(":", 1)[1].strip()
try:
args_part = args_part.replace("'", '"')
tool_args = json.loads(args_part)
except json.JSONDecodeError:
match = re.search(r'"action"\s*:\s*"([^"]+)"', args_part)
if match:
tool_args = {"action": match.group(1)}
else:
tool_args = {"action": "look"}
return thought, tool_name, tool_args
def _validate_tool_call(self, tool_name: str, tool_args: dict, valid_tools: list[str]) -> tuple[str, dict]:
"""Validate and fix common tool call issues."""
# Fix tool name
if tool_name not in valid_tools:
if tool_name in ["action", "do", "command"]:
tool_name = "play_action"
elif tool_name in ["map", "location"]:
tool_name = "get_map"
elif tool_name in ["mem", "state", "status"]:
tool_name = "memory"
elif tool_name in ["inv", "items"]:
tool_name = "inventory"
else:
tool_name = "play_action"
# Fix action verbs
if tool_name == "play_action":
action = tool_args.get("action", "look")
invalid_verb_map = {
"check": "examine",
"inspect": "examine",
"search": "look",
"grab": "take",
"pick": "take",
"use": "examine",
"investigate": "examine",
}
words = action.lower().split()
if words and words[0] in invalid_verb_map:
words[0] = invalid_verb_map[words[0]]
action = " ".join(words)
action = action.lower().strip()
action = action.replace("**", "").replace("*", "").replace("`", "")
action = " ".join(action.split())
tool_args["action"] = action
return tool_name, tool_args
def _extract_result(self, result) -> str:
"""Extract text from MCP tool result."""
if hasattr(result, 'content') and result.content:
return result.content[0].text
if isinstance(result, list) and result:
return result[0].text if hasattr(result[0], 'text') else str(result[0])
return str(result)
def _update_score(self, text: str) -> None:
"""Update score from game text."""
patterns = [
r'Score:\s*(\d+)',
r'score[:\s]+(\d+)',
r'\[Score:\s*(\d+)',
]
for pattern in patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
self.score = max(self.score, int(match.group(1)))
def _is_game_over(self, text: str) -> bool:
"""Check if the game is over."""
game_over_phrases = [
"game over",
"you have died",
"you are dead",
"*** you have died ***",
]
text_lower = text.lower()
return any(phrase in text_lower for phrase in game_over_phrases)
def _normalize_text(self, text: str) -> str:
"""Normalize text for comparison (remove numbers/punct, collapse spaces)."""
s = (text or "").lower()
s = re.sub(r"\d+", "0", s)
s = re.sub(r"[^a-z0\s]", " ", s)
s = re.sub(r"\s+", " ", s).strip()
return s[:500]
def _state_id_from_observation(self, observation: str) -> str:
"""Compute a stable-ish state id for the current room/context."""
lines = [l.strip() for l in (observation or "").splitlines() if l.strip()]
head = lines[0].lower() if lines else "unknown"
# hash helps reduce tiny variations
norm = self._normalize_text(head)
return hashlib.md5(norm.encode()).hexdigest()
def _inventory_signature(self, inv_text: str) -> str:
"""Signature of inventory text; robust enough for gating retries."""
norm = self._normalize_text(inv_text)
return hashlib.md5(norm.encode()).hexdigest()
def _classify_failure(self, prev_obs: str, action: str, new_obs: str) -> str:
"""
Return: "none" | "strong" | "soft"
strong = will never become valid (parser/no object)
soft = gated (locked/dark/tool) -> retry when inventory/state changes
"""
# Optional fast rule: identical observation => likely failed (soft)
if self._normalize_text(new_obs) == self._normalize_text(prev_obs):
return "soft"
if not getattr(self, "use_llm_judge", True):
return "none"
return self._llm_judge_failure(prev_obs, action, new_obs)
def _llm_judge_failure(self, prev_obs: str, action: str, new_obs: str) -> str:
    """Ask the LLM to label an action outcome as "none" | "soft" | "strong".

    Verdicts are cached on normalized (prev, action, new) triples so an
    identical transition never triggers a second LLM call.  The raw reply,
    parsed label and cache-hit flag are stashed on the instance so
    _log_context_state() can report them.
    """
    # Normalize for caching (prevents repeated LLM calls)
    prev_n = self._normalize_text(prev_obs)
    new_n = self._normalize_text(new_obs)
    act_n = " ".join((action or "").lower().split())
    key = (prev_n, act_n, new_n)
    if key in self.judge_cache:
        self._last_llm_raw = "(cached)"
        self._last_llm_label = self.judge_cache[key]
        self._last_llm_cached = True
        return self.judge_cache[key]
    system = "You are a strict classifier for text-adventure command outcomes."
    # NOTE: the prompt body is deliberately flush-left; its content is runtime data.
    prompt = f"""
Classify whether the player's action FAILED, based on the before/after observations.
Return EXACTLY one label: none | soft | strong
Definitions:
- strong: The command is invalid/unknown OR refers to something not present/visible OR impossible in principle.
It will NOT become valid later just by having a different item.
- soft: The command was understood but is blocked by a condition (locked, closed, too dark, need an item, must do something first, not possible yet).
It COULD become valid later.
- none: The action had an effect OR gave new useful information (state changed, moved, item changed, new description).
BE STRICT: If the new observation is basically identical to the previous one AND no progress happened, prefer "soft".
PREVIOUS_OBSERVATION:
{prev_obs}
ACTION:
{action}
NEW_OBSERVATION:
{new_obs}
""".strip()
    # max_tokens=8: only a single-word label is expected back.
    out = call_llm(prompt, system, seed=100000 + self.step, max_tokens=8)
    raw_out = out.strip().lower()
    label = "none"
    # Substring checks tolerate decorated replies such as "label: soft".
    if "strong" in raw_out:
        label = "strong"
    elif "soft" in raw_out:
        label = "soft"
    elif raw_out in {"none", "soft", "strong"}:
        label = raw_out
    # Store for logging
    self._last_llm_raw = raw_out
    self._last_llm_label = label
    self._last_llm_cached = False
    self.judge_cache[key] = label
    return label
def _apply_context_management(self,tool_name: str,tool_args: dict,observation: str,) -> tuple[str, dict]:
"""Prevent repeating failed actions in the same context (state + inv)."""
if tool_name != "play_action":
return tool_name, tool_args
action = (tool_args.get("action") or "look").strip().lower()
sid = self._state_id_from_observation(observation)
key = (sid, self.inv_sig)
# Strong blacklist
if action in self.failed_strong[key]:
return "play_action", {"action": self._fallback_action(observation)}
# Soft blacklist with cooldown
if action in self.failed_soft[key]:
last = self.failed_soft[key][action]
# cooldown: avoid retrying too soon
if self.step - last < self.soft_cooldown_steps:
return "play_action", {"action": self._fallback_action(observation)}
# Prevent immediate repetition in same context
if self.recent_actions and action == self.recent_actions[-1]:
return "play_action", {"action": self._fallback_action(observation)}
return tool_name, {"action": action}
def _fallback_action(self, observation: str) -> str:
"""
Deterministic fallback when the chosen action is banned.
Prefer exploration moves; otherwise look/inventory.
"""
# Prefer moves that haven't been tried recently
move_candidates = ["north","south","east","west","up","down","n","s","e","w","u","d"]
recent_set = set(self.recent_actions[-8:]) if self.recent_actions else set()
for m in move_candidates:
if m not in recent_set:
return m
# If stuck, refresh
if "dark" in (observation or "").lower():
# lamp heuristic: often useful in Zork
return "turn on lamp"
return "look"
def _update_failure_memory(self, prev_obs: str, action: str, new_obs: str) -> None:
"""Update strong/soft failed actions for this (state, inv) context."""
sid = self._state_id_from_observation(prev_obs)
key = (sid, self.inv_sig)
verdict = self._classify_failure(prev_obs, action, new_obs)
if verdict == "strong":
self.failed_strong[key].add(action)
# also remove from soft if present
if action in self.failed_soft[key]:
del self.failed_soft[key][action]
elif verdict == "soft":
self.failed_soft[key][action] = self.step
def _log_context_state(self, prev_observation: str, chosen_action: str = "", new_observation: str = ""):
"""Print debug info for context management (before-state bucket)."""
if not self.debug_context:
return
sid_before = self._state_id_from_observation(prev_observation)
sid_after = self._state_id_from_observation(new_observation) if new_observation else ""
key_before = (sid_before, self.inv_sig)
print("\n" + "=" * 60)
print(f"[STEP] {self.step}")
print(f"[STATE_ID_BEFORE] {sid_before}")
if sid_after:
print(f"[STATE_ID_AFTER] {sid_after}")
print(f"[INV_SIG] {self.inv_sig[:8]}...")
if chosen_action:
print(f"[CHOSEN ACTION] {chosen_action}")
# ---- LLM Judge Info ----
if getattr(self, "_last_llm_label", ""):
print(f"[LLM_VERDICT] {self._last_llm_label}")
print(f"[LLM_RAW_OUTPUT] {getattr(self, '_last_llm_raw', '')}")
print(f"[LLM_FROM_CACHE] {getattr(self, '_last_llm_cached', False)}")
# Strong failures (BEFORE bucket)
strong = list(self.failed_strong[key_before])
print(f"[FAILED_STRONG_BEFORE] ({len(strong)})")
for a in strong[:10]:
print(f" - {a}")
# Soft failures (BEFORE bucket)
soft = self.failed_soft[key_before]
print(f"[FAILED_SOFT_BEFORE] ({len(soft)})")
for a, last_step in list(soft.items())[:10]:
cooldown_left = max(0, self.soft_cooldown_steps - (self.step - last_step))
print(f" - {a} (retry in {cooldown_left} steps)")
print(f"[RECENT_ACTIONS] {self.recent_actions[-10:]}")
print("=" * 60 + "\n")
# =============================================================================
# Local Testing
# =============================================================================
async def test_agent():
    """Smoke-test the agent locally against the MCP game server.

    Requires the `fastmcp` package and a runnable `mcp_server.py` next to
    this file; prints a short summary of the run.
    """
    from fastmcp import Client
    agent = StudentAgent()
    async with Client("mcp_server.py") as client:
        result = await agent.run(
            client=client,
            game="zork1",
            max_steps=20,
            seed=42,
            verbose=True,
        )
        # NOTE(review): original indentation was lost; summary printed while
        # the client session is still open — confirm against the original.
        print(f"\n{'=' * 50}")
        print(f"Final Score: {result.final_score}")
        print(f"Moves: {result.moves}")
        print(f"Locations: {len(result.locations_visited)}")
if __name__ == "__main__":
    # Allow `python <this file>` to run the local smoke test.
    import asyncio
    asyncio.run(test_agent())