# cleaned and added README (VZ22, commit 748ada7)
"""
Example: MCP ReAct Agent
A complete ReAct agent that uses MCP tools to play text adventure games.
This is a working example students can learn from.
"""
import json
import os
import re
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
# Load environment variables (HF_TOKEN, USE_LOCAL_MODEL, ...) from a .env file.
load_dotenv()
# Set USE_LOCAL_MODEL=1 in your .env to use a locally downloaded model
USE_LOCAL_MODEL = os.getenv("USE_LOCAL_MODEL", "0").strip() in ("1", "true", "yes")
LOCAL_MODEL_ID = os.getenv("LOCAL_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")
# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================
LLM_MODEL ="Qwen/Qwen2.5-72B-Instruct"
# Initialize the LLM client based on mode
_local_pipeline = None
if USE_LOCAL_MODEL:
    # Local mode: import the heavy dependencies only when actually needed and
    # build a transformers text-generation pipeline from the local model.
    import torch
    from transformers import pipeline as _hf_pipeline
    _local_pipeline = _hf_pipeline(
        "text-generation",
        model=LOCAL_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    LLM_CLIENT = None
else:
    # Remote mode: use the Hugging Face Inference API (requires HF_TOKEN).
    _hf_token = os.getenv("HF_TOKEN")
    if not _hf_token:
        raise ValueError("HF_TOKEN not found. Set it in your .env file.")
    LLM_CLIENT = InferenceClient(token=_hf_token)
llm_call_count = 0  # For tracking number of LLM calls (optional)
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
    """Send one chat turn to the configured LLM and return its reply text.

    Args:
        prompt: The user prompt (current game state, history, etc.)
        system_prompt: The system prompt (instructions for the agent)
        seed: Random seed for reproducibility (used by the remote backend)
        max_tokens: Maximum tokens in response (default: 300)

    Returns:
        The LLM's response text

    Example:
        response = call_llm(
            prompt="You are in a forest. What do you do?",
            system_prompt=SYSTEM_PROMPT,
            seed=42,
        )
    """
    global llm_call_count
    llm_call_count += 1
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    # Local path: run the downloaded transformers pipeline.
    if USE_LOCAL_MODEL and _local_pipeline is not None:
        generation = _local_pipeline(
            conversation,
            max_new_tokens=max_tokens,
            temperature=0.0001,  # Near-deterministic (0.0 unsupported by some backends)
            do_sample=True,
            max_length=None,
        )
        last_turn = generation[0]["generated_text"][-1]
        return last_turn["content"]
    # Remote path: Hugging Face Inference API.
    api_response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=conversation,
        temperature=0.0,  # Deterministic for reproducibility
        max_tokens=max_tokens,
        seed=seed,
    )
    return api_response.choices[0].message.content
def levenshtein(a, b, ratio=False, print_matrix=False, lowercase=False):
    """Compute the Levenshtein edit distance between two strings.

    Adapted from https://github.com/jamfromouterspace/levenshtein/blob/master/levenshtein.py

    Args:
        a: First string.
        b: Second string.
        ratio: If True, return a normalized similarity in [0, 1]
            ((len(a)+len(b)-distance) / (len(a)+len(b))) instead of the distance.
        print_matrix: If True, print the full DP matrix (debugging aid).
        lowercase: If True, compare case-insensitively.

    Returns:
        The edit distance as an int, or the similarity ratio as a float.

    Raises:
        TypeError: If either argument is not a string.
    """
    # Idiomatic type checks (was `type(a) != type('')`).
    if not isinstance(a, str):
        raise TypeError('First argument is not a string!')
    if not isinstance(b, str):
        raise TypeError('Second argument is not a string!')
    # Trivial cases: distance to the empty string is the other string's length.
    if a == '':
        return len(b)
    if b == '':
        return len(a)
    if lowercase:
        a = a.lower()
        b = b.lower()
    n = len(a)
    m = len(b)
    # Standard dynamic-programming matrix: lev[i, j] is the distance between
    # the first i characters of `a` and the first j characters of `b`.
    lev = np.zeros((n + 1, m + 1))
    for i in range(0, n + 1):
        lev[i, 0] = i
    for j in range(0, m + 1):
        lev[0, j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 1 if a[i - 1] != b[j - 1] else 0
            lev[i, j] = min(lev[i - 1, j] + 1,          # deletion
                            lev[i, j - 1] + 1,          # insertion
                            lev[i - 1, j - 1] + cost)   # substitution
    if print_matrix:
        print(lev)
    if ratio:
        return (n + m - lev[n, m]) / (n + m)
    # Cast to a plain int: the distance is integral, and callers should not
    # receive a numpy scalar from a scalar computation.
    return int(lev[n, m])
@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class."""
    final_score: int                # score at the end of the run
    max_score: int                  # maximum achievable score for the game
    moves: int                      # number of game moves executed
    locations_visited: set[str]     # distinct location names seen during the run
    game_completed: bool            # True if a game-over condition was detected
    error: Optional[str] = None     # error message if the run aborted, else None
    history: list[tuple[str, str, str]] = field(default_factory=list)  # (thought, tool_call, observation) per step
# =============================================================================
# System Prompt
# =============================================================================
# NOTE: this string is sent verbatim to the LLM on every call; any edit here
# directly changes agent behavior.
SYSTEM_PROMPT = """You are an expert text adventure game player. Your goal is to explore, collect treasures, and maximize your score.
AVAILABLE TOOLS (use these via MCP):
1. play_action - Execute game commands (north, take lamp, open mailbox, etc.)
2. memory - Get current game state, score, and recent history
3. get_map - See explored locations and connections
4. inventory - Check what you're carrying
VALID GAME COMMANDS for play_action:
- Movement: north, south, east, west, up, down, enter, exit
- Objects: take <item>, drop <item>, open <thing>, close <thing>, examine <thing>
- Light: turn on lamp, turn off lamp
- Combat: attack <enemy> with <weapon>
- Other: inventory, look, read <thing>, wait, listen, look inside <container>, blow <object>, follow <creature>, climb <object>, drink <liquid>, eat <food>
FORBIDDEN (will NOT work): check, inspect, search, grab, use, help
RESPOND IN THIS EXACT FORMAT (no markdown):
THOUGHT: <brief reasoning about what to do next>
TOOL: <tool_name>
ARGS: <JSON arguments>
Examples:
THOUGHT: I need to see what's around me.
TOOL: play_action
ARGS: {"action": "look"}
THOUGHT: Let me check my current state and score.
TOOL: memory
ARGS: {}
THOUGHT: The mailbox might contain something useful.
TOOL: play_action
ARGS: {"action": "open mailbox"}
STRATEGY:
1. Start by looking around and checking memory
2. Examine everything - look at items, containers, and surroundings
3. Explore systematically - try all directions
4. Pick up useful items (lamp, sword, etc.)
5. Open containers (mailbox, window, etc.)
6. Use get_map to avoid getting lost
"""
# =============================================================================
# Student Agent Implementation
# =============================================================================
class StudentAgent:
    """
    MCP ReAct Agent - A complete working example.
    This agent demonstrates:
    - ReAct loop (Thought -> Tool -> Observation)
    - Loop detection
    - Action validation
    - Score tracking via memory tool
    """
    def __init__(self):
        """Initialize the agent state."""
        self.history: list[dict] = []
        self.score: int = 0
        self.history_state_tried_action = {}
        self.location_state = {}  # to each location, we have a set of every observation made here
        self.idle_actions = ["listen", "wait", "diagnose", "yell", "pray", "launch", "take all"]  # Actions that don't change location
        self.map_size = 20   # x/y extent of the internal map
        self.map_height = 5  # z extent (up/down levels) of the internal map
        self.internal_map = [[["Unknown" for i in range(self.map_height)] for j in range(self.map_size)] for k in range(self.map_size)]  # Internal map representation
        self.position = (self.map_size // 2, self.map_size // 2, 2)
        # Start at the middle in the internal map, we suppose the map is in 3D (taking into account up and down movements)
        self.directions = {"north": (0, -1, 0), "south": (0, 1, 0), "east": (1, 0, 0), "west": (-1, 0, 0), "up": (0, 0, 1), "down": (0, 0, -1),
                           "northeast": (1, -1, 0), "northwest": (-1, -1, 0), "southeast": (1, 1, 0), "southwest": (-1, 1, 0)}

    async def run(
        self,
        client,
        game: str,
        max_steps: int,
        seed: int,
        verbose: bool = False,
    ) -> "RunResult":  # string annotation: forward reference to the module-level dataclass
        """Run the agent for a game session.

        Args:
            client: MCP client exposing play_action / memory / get_map /
                inventory / current_location / last_observation tools.
            game: Name of the game being played (logging only).
            max_steps: Maximum number of ReAct steps.
            seed: Base seed; offset per step for LLM calls.
            verbose: If True, print step-by-step progress.

        Returns:
            A RunResult summarizing score, moves and visited locations.
        """
        global llm_call_count
        locations_visited = set()
        history = []
        moves = 0
        # Get list of available tools
        tools = await client.list_tools()
        tool_names = [t.name for t in tools]
        # Get initial observation
        result = await client.call_tool("play_action", {"action": "look"})
        observation = self._extract_result(result)
        observation = observation.strip() if observation else "No observation"
        # Track initial location
        location = await client.call_tool("current_location", {})
        location = self._extract_result(location)
        locations_visited.add(location)
        if verbose:
            print(f"Starting game: {game}")
            print(f"\n{observation}")
            print(f"\nAvailable tools: {tool_names}")
        last_location = location
        current_location = last_location
        self.internal_map[self.position[0]][self.position[1]][self.position[2]] = current_location
        old_state = await client.call_tool("last_observation", {})
        old_state = self._extract_result(old_state)
        current_state = old_state
        tried_action_in_same_state = [("play_action", {"action": "look"})]
        self.location_state[current_location] = set()
        self.location_state[current_location].add(current_state)
        look_observation = observation.lower().strip()
        # Fallback so comparisons below are defined even if the very first tool
        # call inside the loop raises before current_obs is assigned.
        current_obs = current_state
        # Main ReAct loop
        for step in range(1, max_steps + 1):
            # Budget guard: stop if the LLM is called far more often than
            # 1.5x the step budget (re-prompting can multiply calls).
            if llm_call_count > 1.5 * max_steps:
                if verbose:
                    print(f"[WARNING] You've made {llm_call_count} LLM calls, which is quite high for {step} steps.")
                break
            old_state = current_state
            if current_location != last_location:
                print(f"[DEBUG] Moved to new location: {current_location}. Resetting tried actions for this state.")
                observation += f"\n[INFO] You have moved from {last_location} to a new location: {current_location}."
                if current_location in locations_visited and current_state in self.location_state.get(current_location, set()):
                    observation += " You've been here before, read the observation carefully, is it new? If not return where you came."
                else:
                    observation += " Be thorough, examine everything around you and try to find all treasures and points of interest! Also remember your objective"
            locations_visited.add(current_location)
            # Build prompt with context
            prompt = self._build_prompt(observation)
            prompt += self._look_for_neighboring_locations(prompt)
            prompt = self._add_useless_actions_to_prompt(prompt, tried_action_in_same_state)
            # Call LLM for reasoning (use step-based seed for variety)
            response = call_llm(prompt, SYSTEM_PROMPT, seed + step)
            # Parse the response
            thought, tool_name, tool_args = self._parse_response(response, tool_names, verbose)
            # Validate and fix common issues
            tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names)
            loop_count = 0
            # Re-prompt when the chosen action was already tried in this state,
            # or when the model checks memory too early in the game.
            # NOTE: the whole `or` is parenthesized so that loop_count bounds
            # BOTH cases — without the parentheses, `and` bound tighter than
            # `or` and the repeat-action case could loop without limit.
            while (((tool_name, tool_args) in tried_action_in_same_state
                    or (tool_name == "memory" and step < 5))
                   and loop_count < 5):
                loop_count += 1
                if (tool_name, tool_args) in tried_action_in_same_state:
                    if verbose:
                        print(f"[WARNING] You've been trying the same action {tool_name} with args {tool_args} in the same state without success.")
                    new_prompt = prompt + response + "\n[WARNING: You've been trying the same action without success. Try a different approach!]"
                    response = call_llm(new_prompt, SYSTEM_PROMPT, seed + step + 100)
                elif tool_name == "memory" and step < 5:
                    if verbose:
                        print("[INFO] Early in the game, it's better to explore than to check memory. Forcing an idle action to encourage exploration.")
                    new_prompt = prompt + response + "\n[INFO: Early in the game, it's better to explore. Try something else!]"
                    response = call_llm(new_prompt, SYSTEM_PROMPT, seed + step + 100)
                # Parse the response
                thought, tool_name, tool_args = self._parse_response(response, tool_names, verbose)
                # Validate and fix common issues
                tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names)
            if verbose:
                print(f"[FINAL DECISION] {tool_name}({tool_args}) after {loop_count} loops to find a new action.")
            # Loop detection
            if tool_name == "play_action":
                action = tool_args.get("action", "look")
                # Detect loops: in a well-explored state, periodically force an
                # untried idle action or direction to break the cycle.
                if (len(tried_action_in_same_state) >= 5) and step % 3 == 0:
                    actions_to_cycle = [a for a in self.idle_actions if ("play_action", {"action": a}) not in tried_action_in_same_state]
                    actions_to_cycle = actions_to_cycle + [direction for direction in self.directions.keys() if ("play_action", {"action": direction}) not in tried_action_in_same_state]
                    if actions_to_cycle:  # guard: everything may already have been tried
                        idx_random = (seed + 571 * step) % len(actions_to_cycle)
                        action_forced = actions_to_cycle[idx_random]
                        if verbose:
                            print("[WARNING] Loop detected - forcing a random action to break the cycle.")
                        tool_args = {"action": action_forced}  # Force an idle action
                moves += 1
            if verbose:
                print(f"\n--- Step {step} ---")
                print(f"[THOUGHT] {thought}")
                print(f"[TOOL] {tool_name}({tool_args})")
            not_new_state = False
            # Execute the tool
            try:
                result = await client.call_tool(tool_name, tool_args)
                observation = self._extract_result(result)
                # Look if we got the same observation as for a "look"
                current_obs = await client.call_tool("last_observation", {})  # observation also has the score
                current_obs = self._extract_result(current_obs)
                tried_action_in_same_state.append((tool_name, tool_args))
                if verbose:
                    print(f"[RESULT] {observation}...")
            except Exception as e:
                observation = f"Error: {e}"
                current_obs = observation  # keep current_obs defined for the checks below
                if verbose:
                    print(f"[ERROR] {e}")
            if tool_args.get("action", "").lower() == "look":
                look_observation = current_obs.lower()
            # Compare lowercase against lowercase (look_observation is stored
            # lowercased; the raw current_obs would artificially lower the ratio).
            elif levenshtein(look_observation, current_obs.lower(), ratio=True) > 0.8:
                not_new_state = True
            # Track location
            location = await client.call_tool("current_location", {})
            location = self._extract_result(location)
            print(f"[DEBUG] Current location: {location}")
            last_location = current_location
            current_location = location
            if current_location != last_location:
                tried_action_in_same_state.pop()  # If we moved, the action is not useless
                # Otherwise we might get stuck
                tried_action_in_same_state, current_state = self._update_history_state(tried_action_in_same_state, current_state, current_obs, verbose)
            # Update position
            action = tool_args.get("action", "").lower()
            direction_curr = ""
            directions_abreviations = {"n": "north", "s": "south", "e": "east", "w": "west", "u": "up", "d": "down",
                                       "ne": "northeast", "nw": "northwest", "se": "southeast", "sw": "southwest"}
            dx, dy, dz = 0, 0, 0
            if action in self.directions:
                dx, dy, dz = self.directions[action]
                direction_curr = action
            elif action in directions_abreviations:
                direction_curr = directions_abreviations[action]
                # Index by the expanded name: the abbreviation itself is not a
                # key of self.directions (indexing with `action` raised KeyError).
                dx, dy, dz = self.directions[direction_curr]
            if direction_curr != "down" and direction_curr != "" and "fall down" in observation.lower():
                dz -= 1
            new_position = (self.position[0] + dx, self.position[1] + dy, self.position[2] + dz)
            if 0 <= new_position[0] < self.map_size and 0 <= new_position[1] < self.map_size and 0 <= new_position[2] < self.map_height:
                if current_location != last_location:
                    if verbose:
                        print(f"[DEBUG] Moving {direction_curr} to new location on new position ({new_position}): {current_location}. Updating internal map.")
                    self.internal_map[new_position[0]][new_position[1]][new_position[2]] = current_location
                    # Only advance our tracked position when the game actually
                    # moved us; previously the position advanced even into a
                    # cell we had just marked "Inaccessible".
                    self.position = new_position
                elif new_position != self.position:
                    # Tried to move but the location did not change: the cell is blocked.
                    self.internal_map[new_position[0]][new_position[1]][new_position[2]] = "Inaccessible"
            else:
                print(f"[DEBUG] New position {new_position} is out of bounds. Not updating position.")
            # Update history
            self.history.append({
                "step": step,
                "thought": thought,
                "tool": tool_name,
                "args": tool_args,
                "result": observation[:200]
            })
            if len(self.history) > 10:
                self.history = self.history[-10:]
            # Track score from observation
            self._update_score(observation)
            # Record in result history
            history.append((thought, f"{tool_name}({tool_args})", observation))
            if "!" in observation.lower() and current_obs not in self.location_state.get(current_location, set()) and not not_new_state:
                # first time seeing this observation in this location and it has an exclamation mark, it might be important
                if verbose:
                    print("[EXCLAMATION] The observation contains an exclamation mark, which might indicate an important event!")
                observation += " Something important just happened! Pay attention to this! If you are unsure of the action just do an idle action (look, listen, wait). "
            tried_action_in_same_state, current_state = self._update_history_state(tried_action_in_same_state, current_state, current_obs, verbose)
            if len(tried_action_in_same_state) > 5:
                observation += f"\n[INFO] You've tried {len(tried_action_in_same_state)} different actions in this state. Consider finding new locations to explore!"
            # Check for game over
            if self._is_game_over(observation):
                if verbose:
                    print("\n*** GAME OVER ***")
                break
            if current_location in self.location_state:
                self.location_state[current_location].add(current_obs)
            else:
                self.location_state[current_location] = set([current_obs])
        print(f"\n[FINAL SCORE] {self.score} after {moves} moves and visiting {len(locations_visited)} locations.")
        print(f"The locations are: {', '.join(locations_visited)}")
        print(f"Have visited states: {len(self.history_state_tried_action)}")
        # Join outside the f-string: a backslash inside an f-string expression
        # is a syntax error before Python 3.12.
        recent_states = "\nState:\n".join(list(self.history_state_tried_action.keys())[-5:])
        print(f"The states are: \n {recent_states}")
        return RunResult(
            final_score=self.score,
            max_score=350,
            moves=moves,
            locations_visited=locations_visited,
            game_completed=self._is_game_over(observation),
            history=history,
        )

    def _build_prompt(self, observation: str) -> str:
        """Build the prompt for the LLM with context."""
        parts = []
        parts.append(f"Current Score: {self.score}")
        # Recent history
        if self.history:
            parts.append("\nRecent actions:")
            for entry in self.history[-3:]:
                action = entry.get("args", {}).get("action", entry["tool"])
                result_short = entry["result"][:80] + "..." if len(entry["result"]) > 80 else entry["result"]
                parts.append(f"  > {action} -> {result_short}")
        parts.append(f"\nCurrent situation:\n{observation}")
        parts.append("\nWhat do you do next?")
        return "\n".join(parts)

    def _update_history_state(self, current_action_state: list, current_state: str, new_state: str, verbose: bool) -> tuple[list, str]:
        """Persist the tried actions of the old state and switch to the new one.

        Returns the tried-action list for the new state (fresh for an unseen
        state, restored from history otherwise) and the new state string.
        """
        if verbose:
            print("[DEBUG] Updating history state.")
        self.history_state_tried_action[current_state] = current_action_state.copy()
        current_state = new_state
        if current_state not in self.history_state_tried_action:
            current_action_state = []
            neigh_coord = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1), (1, 1, 0), (1, -1, 0), (-1, 1, 0), (-1, -1, 0)]
            for dx, dy, dz in neigh_coord:
                neighbor_pos = (self.position[0] + dx, self.position[1] + dy, self.position[2] + dz)
                # The z axis is bounded by map_height (5), not map_size (20);
                # the old bound allowed an out-of-range z index (IndexError).
                if 0 <= neighbor_pos[0] < self.map_size and 0 <= neighbor_pos[1] < self.map_size and 0 <= neighbor_pos[2] < self.map_height:
                    neighbor_location = self.internal_map[neighbor_pos[0]][neighbor_pos[1]][neighbor_pos[2]]
                    if neighbor_location != "Inaccessible":
                        # After a new state, the map might be updated with new information
                        self.internal_map[neighbor_pos[0]][neighbor_pos[1]][neighbor_pos[2]] = "Unknown"
        else:
            # Known state: restore the actions already tried in it.
            current_action_state = self.history_state_tried_action[current_state].copy()
        return current_action_state, current_state

    def _find_location(self, observation: str, default: str) -> str:
        """Extract location from observation."""
        # A location name is a non-empty line with no sentence punctuation.
        paragraphs = observation.split("\n")
        for para in paragraphs:
            if not ("." in para or "!" in para or "?" in para or "[" in para) and para.strip() != "":
                return para.strip()
        return default

    def _add_useless_actions_to_prompt(self, prompt: str, useless_actions: list) -> str:
        """Append the list of already-tried actions to the prompt as a warning."""
        s = "You have tried these actions in the same state, DO NOT REPEAT THESE ACTIONS:"
        for t, a in useless_actions:
            s += f"> {t}({a}) "
        new_prompt = prompt + f"\n[INFO: Recent tried actions in this state:{s}] \n You've tried these actions multiple times. BE CREATIVE and consider trying something different!]"
        return new_prompt

    def _look_for_neighboring_locations(self, prompt: str) -> str:
        """Describe what the internal map knows about each adjacent cell."""
        s = "[INFO] Our neighbors are: "
        for direction, (dx, dy, dz) in self.directions.items():
            neighbor_pos = (self.position[0] + dx, self.position[1] + dy, self.position[2] + dz)
            # z axis bounded by map_height (5); map_size (20) allowed bad indices.
            if 0 <= neighbor_pos[0] < self.map_size and 0 <= neighbor_pos[1] < self.map_size and 0 <= neighbor_pos[2] < self.map_height:
                if self.internal_map[neighbor_pos[0]][neighbor_pos[1]][neighbor_pos[2]] != "Unknown":
                    s += f"<{direction}> ({self.internal_map[neighbor_pos[0]][neighbor_pos[1]][neighbor_pos[2]]}), "
                else:
                    s += f"<{direction}> (Unknown), "
        return s

    def _parse_response(self, response: str, valid_tools: list[str], verbose: bool) -> tuple[str, str, dict]:
        """Parse the LLM response to extract thought, tool, and arguments."""
        thought = "No reasoning provided"
        tool_name = "play_action"
        tool_args = {"action": "look"}
        lines = response.strip().split("\n")
        for line in lines:
            line_clean = line.strip()
            line_upper = line_clean.upper()
            if line_upper.startswith("THOUGHT:"):
                thought = line_clean.split(":", 1)[1].strip()
            elif line_upper.startswith("TOOL:"):
                raw_tool = line_clean.split(":", 1)[1].strip().lower()
                raw_tool = raw_tool.replace("**", "").replace("*", "").replace("`", "")
                raw_tool = raw_tool.split()[0] if raw_tool else "play_action"
                tool_name = raw_tool
            elif line_upper.startswith("ARGS:"):
                args_part = line_clean.split(":", 1)[1].strip()
                try:
                    # Tolerate single-quoted pseudo-JSON from the model.
                    args_part = args_part.replace("'", '"')
                    tool_args = json.loads(args_part)
                except json.JSONDecodeError:
                    match = re.search(r'"action"\s*:\s*"([^"]+)"', args_part)
                    if match:
                        tool_args = {"action": match.group(1)}
                    else:
                        tool_args = {"action": "look"}
        return thought, tool_name, tool_args

    def _validate_tool_call(self, tool_name: str, tool_args: dict, valid_tools: list[str]) -> tuple[str, dict]:
        """Validate and fix common tool call issues."""
        # Fix tool name
        if tool_name not in valid_tools:
            if tool_name in ["action", "do", "command"]:
                tool_name = "play_action"
            elif tool_name in ["map", "location"]:
                tool_name = "get_map"
            elif tool_name in ["mem", "state", "status"]:
                tool_name = "memory"
            elif tool_name in ["inv", "items"]:
                tool_name = "inventory"
            else:
                tool_name = "play_action"
        # Fix action verbs
        if tool_name == "play_action":
            action = tool_args.get("action", "look")
            invalid_verb_map = {
                "check": "examine",
                "inspect": "examine",
                "search": "look",
                "grab": "take",
                "pick": "take",
                "use": "examine",
                "investigate": "examine",
                "look around": "look",
            }
            words = action.lower().split()
            if words and words[0] in invalid_verb_map:
                words[0] = invalid_verb_map[words[0]]
            action = " ".join(words)
            if action.startswith("go "):
                # Strip a leading "go" ("go north" -> "north"). The previous
                # substring test ('"go" in action') also mangled unrelated
                # actions such as "take gold" -> "gold".
                action = action.split(" ", 1)[-1]
            action = action.lower().strip()
            action = action.replace("**", "").replace("*", "").replace("`", "")
            action = " ".join(action.split())
            tool_args["action"] = action
        return tool_name, tool_args

    def _extract_result(self, result) -> str:
        """Extract text from MCP tool result."""
        if hasattr(result, 'content') and result.content:
            return result.content[0].text
        if isinstance(result, list) and result:
            return result[0].text if hasattr(result[0], 'text') else str(result[0])
        return str(result)

    def _update_score(self, text: str) -> None:
        """Update score from game text."""
        patterns = [
            r'Score:\s*(\d+)',
            r'score[:\s]+(\d+)',
            r'\[Score:\s*(\d+)',
        ]
        for pattern in patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                # Keep the best score seen so far (score text may lag).
                self.score = max(self.score, int(match.group(1)))

    def _is_game_over(self, text: str) -> bool:
        """Check if the game is over."""
        game_over_phrases = [
            "game over",
            "you have died",
            "you are dead",
            "*** you have died ***",
        ]
        text_lower = text.lower()
        return any(phrase in text_lower for phrase in game_over_phrases)
# =============================================================================
# Local Testing
# =============================================================================
async def test_agent():
    """Run the agent against the local MCP game server as a smoke test."""
    from fastmcp import Client
    agent = StudentAgent()
    async with Client("mcp_server.py") as client:
        run_result = await agent.run(
            client=client,
            game="zork1",
            max_steps=20,
            seed=42,
            verbose=True,
        )
        # Summarize the run.
        print(f"\n{'=' * 50}")
        print(f"Final Score: {run_result.final_score}")
        print(f"Moves: {run_result.moves}")
        print(f"Locations: {len(run_result.locations_visited)}")
# Script entry point: run the local smoke test.
if __name__ == "__main__":
    import asyncio
    asyncio.run(test_agent())