# NOTE: non-Python artifacts from the hosted file viewer converted to a comment
# (author: Sammy972; commit: "Update agent for less times running", 68b71bb).
"""
Improved MCP Agent for Text Adventures
This agent moves beyond a pure ReAct loop by keeping structured, location-aware
state:
- New-location detection (Jericho-aware with LLM fallback)
- Valid action retrieval per location
- Per-location action logs with summarized outcomes
- Exploration bias to avoid stalling in one place
"""
import json
import os
import re
from dataclasses import dataclass, field
from typing import Any, Optional
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
load_dotenv()
# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
_hf_token = os.getenv("HF_TOKEN")
if not _hf_token:
raise ValueError("HF_TOKEN not found. Set it in your .env file.")
LLM_CLIENT = InferenceClient(token=_hf_token)
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 220) -> str:
    """Call the LLM with the given prompt and return the reply text.

    Args:
        prompt: User-turn content.
        system_prompt: System-turn content.
        seed: Sampling seed so repeated calls are reproducible.
        max_tokens: Cap on the completion length.

    Returns:
        The assistant message content, or "" when the API returns None
        content (the client types `message.content` as Optional, so the
        bare return previously leaked None to str-typed callers).
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.0,  # greedy decoding for determinism
        max_tokens=max_tokens,
        seed=seed,
    )
    # Coalesce Optional content to an empty string to honour the -> str contract.
    return response.choices[0].message.content or ""
@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class."""
    final_score: int  # score when the run ended
    max_score: int  # best-known maximum score for the game
    moves: int  # number of actions actually sent to the game
    locations_visited: set[str]  # normalized location keys seen during the run
    game_completed: bool  # True when terminal game text was detected
    error: Optional[str] = None  # set by callers on fatal failure
    history: list[tuple[str, str, str]] = field(default_factory=list)  # (thought, action, truncated observation)
# Movement verbs used for exploration bias (full words plus Z-machine abbreviations).
MOVEMENT_ACTIONS = {
    "north", "south", "east", "west", "up", "down", "enter", "exit",
    "n", "s", "e", "w", "u", "d",
}
# Minimal fallback action list used when the server provides none.
DEFAULT_ACTIONS = [
    "look", "inventory", "north", "south", "east", "west", "up", "down",
]
# Commands blocked by prefix match ("jump", "jump off the cliff", ...).
DANGEROUS_ACTION_PREFIXES = (
    "jump",
    "leap",
    "dive",
    "suicide",
)
# Commands blocked by exact match.
DANGEROUS_ACTIONS = {
    "kill self",
    "attack self",
    "hit self",
    "hurt self",
    "stab self",
    "cut self",
}
# Steps spent in one room before the agent counts as "stuck".
STUCK_THRESHOLD = 4
# While stuck, re-poll valid actions every N additional steps.
VALID_ACTION_REFRESH_STUCK_INTERVAL = 3
# Cap on per-room action-log entries.
LOCATION_LOG_LIMIT = 20
# Bounds on how many valid actions are surfaced in the decision prompt.
PROMPT_ACTIONS_MIN = 20
PROMPT_ACTIONS_MAX = 40
PROMPT_ACTIONS_TARGET = 30
# System prompt for the per-step action decision.
ACTION_SYSTEM_PROMPT = """You are a strong text-adventure planner.
You receive structured context (valid actions, location action history, and stagnation level).
Your task is to choose ONE next game command.
Rules:
1. Prefer untried actions supported by current observation and local history.
2. If STUCK_LEVEL >= 4, strongly prioritize exploration (movement or a new untried action).
3. Avoid repeating an action that already failed in the same location.
4. Avoid obviously dangerous commands such as jump, leap, dive, or self-harm actions.
5. Use concise text-adventure commands only.
6. No markdown.
Respond exactly as:
THOUGHT: <short reasoning>
ACTION: <single game command>
"""
# System prompt for the LLM new-location fallback (expects a JSON answer).
LOCATION_CHANGE_SYSTEM_PROMPT = """Decide if the player entered a new location.
Return JSON only:
{"new_location": true/false, "location_label": "<best label>"}
"""
# System prompt used to compress an action outcome into one log line.
SUMMARY_SYSTEM_PROMPT = """Summarize an action outcome in one short line.
Focus on progress, obstacle, or reward.
No markdown. Max 16 words.
"""
class StudentAgent:
    """
    Location-aware MCP agent.
    Main ideas:
    - Uses Jericho location when available to detect room transitions
    - Pulls valid actions when entering a room
    - Maintains per-location action logs
    - Uses exploration bias to avoid stalling
    """
    def __init__(self):
        # All mutable state lives in _reset_state so one instance can be
        # reused across multiple run() calls.
        self._reset_state()
def _reset_state(self) -> None:
self.history: list[dict] = []
self.recent_actions: list[str] = []
self.score: int = 0
self.max_score: int = 0
self.current_location_key: str = "unknown"
self.current_location_label: str = "Unknown"
self.current_display_location: str = "Unknown"
self.current_jericho_location: str = "Unknown"
self.steps_in_current_location: int = 0
self.location_visit_counts: dict[str, int] = {}
self.location_action_log: dict[str, list[dict]] = {}
self.location_valid_actions: dict[str, list[str]] = {}
self.latest_context: dict[str, Any] = {}
self.latest_lookahead: dict[str, Any] = {}
self.refresh_valid_actions_next_step: bool = True
    async def run(
        self,
        client,
        game: str,
        max_steps: int,
        seed: int,
        verbose: bool = False,
    ) -> RunResult:
        """Run the agent for a game session.

        Args:
            client: MCP client exposing at least the ``play_action`` tool.
            game: Game identifier (not read by the loop itself).
            max_steps: Maximum number of actions to attempt.
            seed: Base seed; ``seed + step`` keeps each LLM call deterministic.
            verbose: Print step-by-step tracing when True.

        Returns:
            RunResult with final score, move count, and visited locations.
        """
        self._reset_state()
        locations_visited: set[str] = set()
        result_history: list[tuple[str, str, str]] = []
        moves = 0
        tools = await client.list_tools()
        tool_names = [t.name for t in tools]
        # Initial observation
        initial_result = await client.call_tool("play_action", {"action": "look"})
        observation = self._extract_result(initial_result)
        self._update_score(observation)
        await self._refresh_context(client, tool_names, observation)
        memory_text = ""
        if "memory" in tool_names:
            memory_text = await self._safe_tool_call(client, "memory", {})
            self._update_score(memory_text)
            self._update_max_score(memory_text)
        # Establish the starting room before the main loop.
        location_info = await self._get_location_info(
            client=client,
            observation=observation,
            memory_text=memory_text,
            tool_names=tool_names,
        )
        self._set_current_location(location_info, entered_new_location=True)
        locations_visited.add(self.current_location_key)
        valid_actions = await self._fetch_valid_actions(
            client,
            self.current_location_key,
            tool_names,
            force_refresh=True,
        )
        self.refresh_valid_actions_next_step = False
        map_snapshot = ""
        if "get_map" in tool_names:
            map_snapshot = await self._safe_tool_call(client, "get_map", {})
        if verbose:
            print(f"\n{observation}")
            print(
                f"[LOCATION] {self.current_location_label} "
                f"(key={self.current_location_key}, display={self.current_display_location}, "
                f"jericho={self.current_jericho_location})"
            )
        for step in range(1, max_steps + 1):
            local_log = self.location_action_log.get(self.current_location_key, [])
            # Only re-poll valid actions when the cache is missing or stale.
            refresh_valid_actions = self._should_refresh_valid_actions(
                location_key=self.current_location_key,
                tool_names=tool_names,
            )
            valid_actions = await self._fetch_valid_actions(
                client,
                self.current_location_key,
                tool_names,
                force_refresh=refresh_valid_actions,
            )
            self.refresh_valid_actions_next_step = False
            # Lookahead is intentionally disabled in this variant.
            lookahead = self._disabled_lookahead()
            prompt = self._build_prompt(
                observation=observation,
                memory_text=memory_text,
                map_snapshot=map_snapshot,
                valid_actions=valid_actions,
                local_log=local_log,
                step=step,
                max_steps=max_steps,
            )
            llm_response = call_llm(
                prompt=prompt,
                system_prompt=ACTION_SYSTEM_PROMPT,
                seed=seed + step,
                max_tokens=220,
            )
            thought, action = self._parse_action_response(llm_response)
            # Repair/validate the command, then apply anti-stall heuristics.
            action = self._validate_action(action, valid_actions)
            action = self._apply_exploration_bias(
                action=action,
                location_key=self.current_location_key,
                valid_actions=valid_actions,
            )
            # If the same action is repeated too much, force an untried option.
            if len(self.recent_actions) >= 2 and all(a == action for a in self.recent_actions[-2:]):
                alternatives = self._get_untried_actions(
                    self.current_location_key,
                    valid_actions,
                )
                if alternatives:
                    action = alternatives[0]
            action = self._avoid_risky_action(
                action=action,
                valid_actions=valid_actions,
                lookahead=lookahead,
            )
            # Snapshot pre-action state so outcomes are logged against the
            # room where the action was issued, not the one entered.
            prev_score = self.score
            prev_observation = observation
            prev_location_info = {
                "location_key": self.current_location_key,
                "display_location": self.current_display_location,
                "jericho_location": self.current_jericho_location,
            }
            if verbose:
                print(f"\n--- Step {step} ---")
                print(f"[THOUGHT] {thought}")
                print(f"[ACTION] {action}")
            try:
                step_result = await client.call_tool("play_action", {"action": action})
                observation = self._extract_result(step_result)
            except Exception as e:
                observation = f"Error: {e}"
            self.recent_actions.append(action)
            if len(self.recent_actions) > 6:
                self.recent_actions = self.recent_actions[-6:]
            moves += 1
            self._update_score(observation)
            await self._refresh_context(client, tool_names, observation)
            if "memory" in tool_names:
                memory_text = await self._safe_tool_call(client, "memory", {})
                self._update_score(memory_text)
                self._update_max_score(memory_text)
            location_info = await self._get_location_info(
                client=client,
                observation=observation,
                memory_text=memory_text,
                tool_names=tool_names,
            )
            entered_new_location, resolved_location = self._did_enter_new_location(
                previous_location=prev_location_info["location_key"],
                current_location=location_info["location_key"],
                previous_observation=prev_observation,
                current_observation=observation,
                seed=seed + step,
            )
            location_info["location_key"] = resolved_location
            self._set_current_location(location_info, entered_new_location=entered_new_location)
            locations_visited.add(self.current_location_key)
            if entered_new_location:
                self.refresh_valid_actions_next_step = True
            # Refresh the map on room change, or periodically while stuck.
            if entered_new_location:
                if "get_map" in tool_names:
                    map_snapshot = await self._safe_tool_call(client, "get_map", {})
            else:
                if self.steps_in_current_location >= STUCK_THRESHOLD and "get_map" in tool_names:
                    map_snapshot = await self._safe_tool_call(client, "get_map", {})
            outcome_summary = self._summarize_outcome(
                action=action,
                previous_observation=prev_observation,
                current_observation=observation,
                seed=seed + step,
            )
            score_delta = self.score - prev_score
            self._append_location_log(
                location_key=prev_location_info["location_key"],
                action=action,
                summary=outcome_summary,
                score_delta=score_delta,
            )
            self.history.append(
                {
                    "step": step,
                    "location": prev_location_info["location_key"],
                    "thought": thought,
                    "action": action,
                    "summary": outcome_summary,
                }
            )
            if len(self.history) > 15:
                self.history = self.history[-15:]
            result_history.append((thought, action, observation[:100]))
            if verbose:
                print(f"[RESULT] {observation[:180]}...")
                print(
                    f"[LOCATION] {self.current_location_label} "
                    f"(key={self.current_location_key}) | Stuck steps: {self.steps_in_current_location}"
                )
                print(f"[SCORE] {self.score}/{self.max_score}")
            if self._is_game_over(observation):
                if verbose:
                    print("\n*** GAME OVER ***")
                break
        return RunResult(
            final_score=self.score,
            max_score=self.max_score,
            moves=moves,
            locations_visited=locations_visited,
            game_completed=self._is_game_over(observation),
            history=result_history,
        )
    async def _safe_tool_call(self, client, tool_name: str, tool_args: dict) -> str:
        """Call an MCP tool and always return text (never raise)."""
        try:
            result = await client.call_tool(tool_name, tool_args)
            return self._extract_result(result)
        except Exception as e:
            # Failures are folded into the text channel so the loop keeps going.
            return f"Error: {e}"
    async def _refresh_context(self, client, tool_names: list[str], observation: str) -> None:
        """Refresh structured context from MCP if available.

        Updates ``self.latest_context`` and adopts score/max_score when the
        payload carries them as integers.
        """
        if "get_context" not in tool_names:
            return
        raw = await self._safe_tool_call(client, "get_context", {"limit": 80})
        payload = self._parse_json_payload(raw)
        if not isinstance(payload, dict):
            return
        self.latest_context = payload
        score = payload.get("score")
        max_score = payload.get("max_score")
        if isinstance(score, int):
            self.score = score
        if isinstance(max_score, int) and max_score > 0:
            self.max_score = max_score
    async def _get_location_info(
        self,
        client,
        observation: str,
        memory_text: str,
        tool_names: list[str],
    ) -> dict:
        """Get current location using get_location tool, memory, and observation fallback.

        Resolution order: structured context, get_location tool, memory()
        text, then the first observation line.

        Returns:
            Dict with display_location, jericho_location, location_key, and
            location_label (all normalized).
        """
        display_location = "Unknown"
        jericho_location = "Unknown"
        # Prefer structured context when available.
        if self.latest_context:
            display_location = str(
                self.latest_context.get("display_location")
                or self.latest_context.get("location")
                or display_location
            )
            jericho_location = str(self.latest_context.get("jericho_location", jericho_location))
        # Direct location tool fallback (when get_context is unavailable/partial).
        if "get_location" in tool_names:
            raw = await self._safe_tool_call(client, "get_location", {})
            payload = self._parse_json_payload(raw)
            if isinstance(payload, dict):
                display_location = str(payload.get("display_location", display_location) or display_location)
                jericho_location = str(payload.get("jericho_location", jericho_location) or jericho_location)
        parsed_memory = self._parse_memory(memory_text)
        if display_location == "Unknown":
            display_location = parsed_memory.get("display_location", "Unknown")
        if jericho_location == "Unknown":
            jericho_location = parsed_memory.get("jericho_location", "Unknown")
        if display_location == "Unknown":
            # Last resort: text adventures usually print the room name first.
            first_line = observation.strip().split("\n")[0] if observation else "Unknown"
            display_location = first_line if first_line else "Unknown"
        display_location = self._normalize_location(display_location)
        jericho_location = self._normalize_location(jericho_location)
        location_key = self._location_key(display_location, jericho_location)
        location_label = self._humanize_location(
            jericho_location if jericho_location != "unknown" else display_location
        )
        return {
            "display_location": display_location,
            "jericho_location": jericho_location,
            "location_key": location_key,
            "location_label": location_label,
        }
def _set_current_location(self, location_info: dict, entered_new_location: bool) -> None:
"""Update current location state and stagnation counters."""
self.current_display_location = location_info.get("display_location", "Unknown")
self.current_jericho_location = location_info.get("jericho_location", "Unknown")
self.current_location_key = location_info.get("location_key", "unknown")
self.current_location_label = location_info.get(
"location_label",
self._humanize_location(self.current_location_key),
)
if entered_new_location:
self.steps_in_current_location = 0
self.location_visit_counts[self.current_location_key] = (
self.location_visit_counts.get(self.current_location_key, 0) + 1
)
else:
self.steps_in_current_location += 1
    async def _fetch_valid_actions(
        self,
        client,
        location_key: str,
        tool_names: list[str],
        force_refresh: bool = False,
    ) -> list[str]:
        """Fetch, normalize, and cache valid actions for a location.

        Context-provided actions win; the get_valid_actions tool is only
        queried when context had none; DEFAULT_ACTIONS is the final fallback.
        """
        if not force_refresh and location_key in self.location_valid_actions:
            return self.location_valid_actions[location_key]
        actions: list[str] = []
        context_actions = self.latest_context.get("valid_actions")
        if isinstance(context_actions, list):
            actions = [str(v) for v in context_actions]
        if "get_valid_actions" in tool_names:
            # Keep this as fallback if context is unavailable/incomplete.
            if not actions:
                raw = await self._safe_tool_call(client, "get_valid_actions", {"limit": 80})
                payload = self._parse_json_payload(raw)
                if isinstance(payload, dict):
                    values = payload.get("valid_actions", [])
                    if isinstance(values, list):
                        actions = [str(v) for v in values]
        if not actions:
            actions = DEFAULT_ACTIONS.copy()
        # Normalize and deduplicate while preserving order.
        normalized: list[str] = []
        for action in actions:
            norm = self._normalize_action(action)
            if norm and norm not in normalized:
                normalized.append(norm)
        if not normalized:
            normalized = DEFAULT_ACTIONS.copy()
        self.location_valid_actions[location_key] = normalized
        return normalized
def _should_refresh_valid_actions(self, location_key: str, tool_names: list[str]) -> bool:
"""Refresh valid actions only when the cache is missing or likely stale."""
if "get_valid_actions" not in tool_names:
return False
if self.refresh_valid_actions_next_step:
return True
cached = self.location_valid_actions.get(location_key, [])
if not cached:
return True
if not self._get_untried_actions(location_key, cached):
return True
if self.steps_in_current_location >= STUCK_THRESHOLD:
stuck_offset = self.steps_in_current_location - STUCK_THRESHOLD
if stuck_offset % VALID_ACTION_REFRESH_STUCK_INTERVAL == 0:
return True
return False
def _disabled_lookahead(self) -> dict[str, Any]:
"""This variant disables lookahead entirely for Frotz stability."""
payload = {
"enabled": False,
"risky_actions": [],
"entries": [],
}
self.latest_lookahead = payload
return payload
    def _did_enter_new_location(
        self,
        previous_location: str,
        current_location: str,
        previous_observation: str,
        current_observation: str,
        seed: int,
    ) -> tuple[bool, str]:
        """
        Determine if the agent entered a new location.
        Priority:
        1) Jericho/location-key comparison (deterministic)
        2) LLM fallback when locations are unknown/ambiguous
        3) First-observation-line comparison as a last resort

        Returns:
            (entered_new_location, resolved_location_key)
        """
        prev_norm = self._normalize_location(previous_location)
        curr_norm = self._normalize_location(current_location)
        # Deterministic path: both sides have usable keys.
        if prev_norm != "unknown" and curr_norm != "unknown":
            return prev_norm != curr_norm, curr_norm
        # LLM fallback for ambiguous cases
        prompt = (
            f"Previous location key: {previous_location}\n"
            f"Current location key: {current_location}\n\n"
            f"Previous observation:\n{previous_observation[:1000]}\n\n"
            f"Current observation:\n{current_observation[:1000]}\n"
        )
        try:
            response = call_llm(
                prompt=prompt,
                system_prompt=LOCATION_CHANGE_SYSTEM_PROMPT,
                seed=seed,
                max_tokens=120,
            )
            payload = self._parse_json_payload(response)
            if isinstance(payload, dict):
                is_new = bool(payload.get("new_location", False))
                label = str(payload.get("location_label", current_location)).strip()
                normalized_label = self._normalize_location(label)
                # Only trust the LLM answer when it produced a usable label.
                if normalized_label != "unknown":
                    return is_new, normalized_label
        except Exception:
            pass  # best-effort: fall through to the textual heuristic
        # Last fallback: compare first non-empty line.
        prev_line = previous_observation.strip().split("\n")[0].strip().lower() if previous_observation else ""
        curr_line = current_observation.strip().split("\n")[0].strip().lower() if current_observation else ""
        if prev_line and curr_line and prev_line != curr_line:
            fallback_key = curr_norm if curr_norm != "unknown" else self._normalize_location(curr_line)
            return True, fallback_key
        fallback_key = curr_norm if curr_norm != "unknown" else self._normalize_location(curr_line)
        return False, fallback_key
    def _summarize_outcome(
        self,
        action: str,
        previous_observation: str,
        current_observation: str,
        seed: int,
    ) -> str:
        """Summarize action outcomes for local logs.

        Falls back to a truncated, whitespace-collapsed observation when the
        LLM call fails for any reason.
        """
        prompt = (
            f"Action: {action}\n\n"
            f"Before:\n{previous_observation[:500]}\n\n"
            f"After:\n{current_observation[:700]}\n"
        )
        try:
            summary = call_llm(
                prompt=prompt,
                system_prompt=SUMMARY_SYSTEM_PROMPT,
                seed=seed,
                max_tokens=64,
            )
            summary = summary.strip().replace("\n", " ")
            return summary[:180] if summary else "No clear effect."
        except Exception:
            # Deliberate best-effort: a degraded summary beats a crashed run.
            cleaned = re.sub(r"\s+", " ", current_observation).strip()
            if not cleaned:
                return "No clear effect."
            return cleaned[:120]
    def _append_location_log(self, location_key: str, action: str, summary: str, score_delta: int) -> None:
        """Append an action summary to the per-location log (bounded buffer)."""
        if location_key not in self.location_action_log:
            self.location_action_log[location_key] = []
        self.location_action_log[location_key].append(
            {
                "action": self._normalize_action(action),
                "summary": summary,
                "score_delta": score_delta,
            }
        )
        # Keep only the most recent LOCATION_LOG_LIMIT entries per room.
        if len(self.location_action_log[location_key]) > LOCATION_LOG_LIMIT:
            self.location_action_log[location_key] = self.location_action_log[location_key][-LOCATION_LOG_LIMIT:]
    def _build_prompt(
        self,
        observation: str,
        memory_text: str,
        map_snapshot: str,
        valid_actions: list[str],
        local_log: list[dict],
        step: int,
        max_steps: int,
    ) -> str:
        """Build a location-aware decision prompt.

        Combines score/location headers, a curated action subset, untried
        candidates, recent per-room outcomes, the observation, memory, and an
        optional map snapshot into one text prompt for the action LLM.
        """
        untried_actions = self._get_untried_actions(self.current_location_key, valid_actions)
        prompt_actions = self._select_actions_for_llm(self.current_location_key, valid_actions)
        log_lines = []
        for entry in local_log[-6:]:
            # Only positive deltas get an explicit '+'; negatives carry their own sign.
            delta_prefix = "+" if entry.get("score_delta", 0) > 0 else ""
            log_lines.append(
                f"- {entry.get('action')} -> {entry.get('summary')} "
                f"(score {delta_prefix}{entry.get('score_delta', 0)})"
            )
        exploration_hint = (
            "HIGH_EXPLORATION_BIAS: yes"
            if self.steps_in_current_location >= STUCK_THRESHOLD
            else "HIGH_EXPLORATION_BIAS: no"
        )
        parts = [
            f"Step: {step}/{max_steps}",
            f"Score: {self.score}/{self.max_score}",
            f"Location label: {self.current_location_label}",
            f"Location key: {self.current_location_key}",
            f"Display location: {self.current_display_location}",
            f"Jericho location: {self.current_jericho_location}",
            f"STUCK_LEVEL: {self.steps_in_current_location}",
            exploration_hint,
            "",
            "Valid actions (selected):",
            ", ".join(prompt_actions) if prompt_actions else "(none)",
            f"(showing {len(prompt_actions)} of {len(valid_actions)} valid actions)",
            "",
            "Untried candidates:",
            ", ".join(untried_actions[:12]) if untried_actions else "(none)",
            "",
            "Recent outcomes in this location:",
            "\n".join(log_lines) if log_lines else "- none yet",
            "",
            "Current observation:",
            observation[:1500],
            "",
            "Memory snapshot:",
            memory_text[:900] if memory_text else "(memory unavailable)",
        ]
        if map_snapshot:
            parts.extend(["", "Map snapshot:", map_snapshot[:900]])
        parts.extend(
            [
                "",
                "Pick exactly one next command.",
                "If HIGH_EXPLORATION_BIAS is yes, avoid staying in place.",
            ]
        )
        return "\n".join(parts)
    def _select_actions_for_llm(self, location_key: str, valid_actions: list[str]) -> list[str]:
        """Select an informative subset of valid actions instead of first-N truncation.

        Buckets actions by usefulness (untried object interactions, movement,
        other, generic look/inventory, repeatedly-failed), orders each bucket
        by try/fail counts, then picks with one-per-verb diversity first.
        """
        normalized: list[str] = []
        seen: set[str] = set()
        for raw in valid_actions:
            action = self._normalize_action(raw)
            if not action or action in seen:
                continue
            seen.add(action)
            normalized.append(action)
        if not normalized:
            return DEFAULT_ACTIONS.copy()
        # Small lists fit the prompt as-is; only larger ones need ranking.
        if len(normalized) <= PROMPT_ACTIONS_MAX:
            return normalized
        target = min(PROMPT_ACTIONS_TARGET, len(normalized))
        target = max(min(target, PROMPT_ACTIONS_MAX), PROMPT_ACTIONS_MIN)
        local_log = self.location_action_log.get(location_key, [])
        tried_counts: dict[str, int] = {}
        failed_counts: dict[str, int] = {}
        # Phrases in outcome summaries that indicate the action achieved nothing.
        fail_markers = [
            "can't",
            "cannot",
            "nothing happens",
            "no effect",
            "not possible",
            "not here",
            "unknown",
            "blocked",
            "empty",
        ]
        for entry in local_log:
            action = self._normalize_action(str(entry.get("action", "")))
            if not action:
                continue
            tried_counts[action] = tried_counts.get(action, 0) + 1
            summary = str(entry.get("summary", "")).lower()
            delta = int(entry.get("score_delta", 0))
            # An action "failed" only when it scored nothing AND its summary
            # reads like a parser rejection.
            if delta <= 0 and any(marker in summary for marker in fail_markers):
                failed_counts[action] = failed_counts.get(action, 0) + 1
        interaction_verbs = {"examine", "open", "read", "take", "push", "pull", "put", "turn", "attack"}
        generic_actions = {"look", "inventory"}
        stuck = self.steps_in_current_location >= STUCK_THRESHOLD
        specific_bucket: list[str] = []
        movement_bucket: list[str] = []
        remaining_bucket: list[str] = []
        generic_bucket: list[str] = []
        blocked_bucket: list[str] = []
        for action in normalized:
            verb = action.split()[0]
            # Two recorded failures demote an action to the last bucket.
            blocked = failed_counts.get(action, 0) >= 2
            if blocked:
                blocked_bucket.append(action)
                continue
            # Untried multi-word object interactions are the most promising.
            if tried_counts.get(action, 0) == 0 and len(action.split()) >= 2 and verb in interaction_verbs:
                specific_bucket.append(action)
                continue
            if verb in MOVEMENT_ACTIONS:
                movement_bucket.append(action)
                continue
            if action in generic_actions:
                generic_bucket.append(action)
                continue
            remaining_bucket.append(action)
        def ordered(items: list[str]) -> list[str]:
            # Least-tried, least-failed first; alphabetical tie-break for determinism.
            return sorted(items, key=lambda a: (tried_counts.get(a, 0), failed_counts.get(a, 0), a))
        specific_bucket = ordered(specific_bucket)
        movement_bucket = ordered(movement_bucket)
        remaining_bucket = ordered(remaining_bucket)
        generic_bucket = ordered(generic_bucket)
        blocked_bucket = ordered(blocked_bucket)
        # When stuck, movement jumps ahead of object interaction.
        if stuck:
            ranked = movement_bucket + specific_bucket + remaining_bucket + generic_bucket + blocked_bucket
        else:
            ranked = specific_bucket + movement_bucket + remaining_bucket + generic_bucket + blocked_bucket
        selected: list[str] = []
        used_verbs: set[str] = set()
        def add(action: str) -> None:
            if action and action not in selected:
                selected.append(action)
        # Diversity pass: one action per verb first.
        for action in ranked:
            verb = action.split()[0]
            if verb in used_verbs:
                continue
            add(action)
            used_verbs.add(verb)
            if len(selected) >= target:
                return selected[:target]
        # Fill remaining slots.
        for action in ranked:
            add(action)
            if len(selected) >= target:
                return selected[:target]
        return selected[:target]
def _parse_action_response(self, response: str) -> tuple[str, str]:
"""Parse LLM response for THOUGHT and ACTION."""
thought = "No reasoning provided."
action = "look"
for line in response.strip().split("\n"):
clean = line.strip()
upper = clean.upper()
if upper.startswith("THOUGHT:"):
thought = clean.split(":", 1)[1].strip()
elif upper.startswith("ACTION:"):
action = clean.split(":", 1)[1].strip()
if not action or action == "look":
payload = self._parse_json_payload(response)
if isinstance(payload, dict):
if "thought" in payload:
thought = str(payload.get("thought", thought))
if "action" in payload:
action = str(payload.get("action", action))
return thought, action
def _validate_action(self, action: str, valid_actions: list[str]) -> str:
"""Normalize action text and repair common invalid verbs."""
action = self._normalize_action(action)
if not action:
return "look"
invalid_verb_map = {
"check": "examine",
"inspect": "examine",
"search": "look",
"grab": "take",
"pick": "take",
"use": "examine",
"investigate": "examine",
}
words = action.split()
if words and words[0] in invalid_verb_map:
words[0] = invalid_verb_map[words[0]]
action = " ".join(words)
if self._is_dangerous_action(action):
return self._fallback_safe_action(valid_actions)
if valid_actions:
normalized_valid = [self._normalize_action(v) for v in valid_actions]
if action in normalized_valid:
return action
# Keep creative commands if they have a common valid verb.
verb = action.split()[0]
if any(v.startswith(verb + " ") or v == verb for v in normalized_valid):
return action
# If nothing matches and the command is too noisy, fallback.
noisy = len(action.split()) > 8 or "?" in action
if noisy:
untried = self._get_untried_actions(self.current_location_key, normalized_valid)
return untried[0] if untried else normalized_valid[0]
return action
def _is_dangerous_action(self, action: str) -> bool:
"""Block a small set of obviously risky commands."""
action = self._normalize_action(action)
if not action:
return False
if action in DANGEROUS_ACTIONS:
return True
return any(action == prefix or action.startswith(prefix + " ") for prefix in DANGEROUS_ACTION_PREFIXES)
def _fallback_safe_action(self, valid_actions: list[str]) -> str:
"""Choose a conservative fallback when an action is blocked."""
normalized_valid = [
self._normalize_action(v)
for v in valid_actions
if self._normalize_action(v) and not self._is_dangerous_action(self._normalize_action(v))
]
untried = self._get_untried_actions(self.current_location_key, normalized_valid)
if untried:
return untried[0]
if normalized_valid:
return normalized_valid[0]
return "look"
def _apply_exploration_bias(
self,
action: str,
location_key: str,
valid_actions: list[str],
) -> str:
"""
Force exploration when stagnating in the same room.
"""
if self.steps_in_current_location < STUCK_THRESHOLD:
return action
tried_counts: dict[str, int] = {}
for entry in self.location_action_log.get(location_key, []):
tried_action = entry.get("action", "")
tried_counts[tried_action] = tried_counts.get(tried_action, 0) + 1
if tried_counts.get(action, 0) >= 2:
untried_movement = [
a for a in self._get_untried_actions(location_key, valid_actions)
if a.split()[0] in MOVEMENT_ACTIONS
]
if untried_movement:
return untried_movement[0]
untried_any = self._get_untried_actions(location_key, valid_actions)
if untried_any:
return untried_any[0]
return action
def _avoid_risky_action(
self,
action: str,
valid_actions: list[str],
lookahead: Optional[dict[str, Any]],
) -> str:
"""Replace only actions flagged as lethal/game-over by lookahead."""
if not lookahead or not lookahead.get("enabled"):
return action
risky_set = {
self._normalize_action(a) for a in lookahead.get("risky_actions", [])
if self._normalize_action(a)
}
action = self._normalize_action(action)
if not action or action not in risky_set:
return action
normalized_valid: list[str] = []
seen: set[str] = set()
for raw in valid_actions:
candidate = self._normalize_action(raw)
if candidate and candidate not in seen and candidate not in risky_set:
seen.add(candidate)
normalized_valid.append(candidate)
untried_valid = self._get_untried_actions(self.current_location_key, normalized_valid)
if untried_valid:
return untried_valid[0]
if normalized_valid:
return normalized_valid[0]
return "look"
def _get_untried_actions(self, location_key: str, candidates: list[str]) -> list[str]:
"""Return candidate actions not yet tried in this location."""
tried = {entry.get("action", "") for entry in self.location_action_log.get(location_key, [])}
untried: list[str] = []
for action in candidates:
norm = self._normalize_action(action)
if norm and norm not in tried and norm not in untried:
untried.append(norm)
return untried
def _parse_json_payload(self, text: str):
"""Parse JSON from plain text or fenced block."""
if not text:
return None
cleaned = text.strip()
if cleaned.startswith("```"):
cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", cleaned, flags=re.DOTALL).strip()
try:
return json.loads(cleaned)
except json.JSONDecodeError:
pass
obj_match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
if obj_match:
try:
return json.loads(obj_match.group(0))
except json.JSONDecodeError:
pass
arr_match = re.search(r"\[.*\]", cleaned, flags=re.DOTALL)
if arr_match:
try:
return json.loads(arr_match.group(0))
except json.JSONDecodeError:
pass
return None
def _parse_memory(self, memory_text: str) -> dict:
"""Parse key fields from memory() output."""
parsed = {
"display_location": "Unknown",
"jericho_location": "Unknown",
"score": None,
"max_score": None,
"moves": None,
}
if not memory_text:
return parsed
display_match = re.search(r"-\s*(?:Display\s+)?Location:\s*(.+)", memory_text)
jericho_match = re.search(r"-\s*Jericho\s+Location:\s*(.+)", memory_text)
score_match = re.search(r"-\s*Score:\s*(\d+)", memory_text)
max_score_match = re.search(r"-\s*Max\s+Score:\s*(\d+)", memory_text)
moves_match = re.search(r"-\s*Moves:\s*(\d+)", memory_text)
if display_match:
parsed["display_location"] = display_match.group(1).strip()
if jericho_match:
parsed["jericho_location"] = jericho_match.group(1).strip()
if score_match:
parsed["score"] = int(score_match.group(1))
if max_score_match:
parsed["max_score"] = int(max_score_match.group(1))
if moves_match:
parsed["moves"] = int(moves_match.group(1))
return parsed
def _extract_result(self, result) -> str:
"""Extract text from MCP tool result."""
if hasattr(result, "content") and result.content:
return result.content[0].text
if isinstance(result, list) and result:
return result[0].text if hasattr(result[0], "text") else str(result[0])
return str(result)
def _normalize_action(self, action: str) -> str:
"""Normalize action formatting."""
if not action:
return ""
action = action.lower().strip()
action = action.replace("**", "").replace("*", "").replace("`", "")
action = re.sub(r"\s+", " ", action)
return action
def _normalize_location(self, location: str) -> str:
"""Normalize location keys."""
if not location:
return "unknown"
value = location.strip().lower()
value = value.replace("`", "")
# Jericho can return verbose object dumps such as:
# "obj128: statue room parent0 sibling0 child90 attributes [...]"
if re.match(r"^obj\d+\s*:", value):
value = value.split(":", 1)[1].strip()
value = re.sub(r"^obj\d+\b[:\s-]*", "", value).strip()
value = re.split(r"\s+parent\d+\b", value)[0].strip()
value = re.split(r"\s+sibling\d+\b", value)[0].strip()
value = re.split(r"\s+child\d+\b", value)[0].strip()
value = re.split(r"\s+attributes?\b", value)[0].strip()
value = re.split(r"\s+properties?\b", value)[0].strip()
value = re.sub(r"[\[\]\(\)\{\}]", " ", value)
value = re.sub(r"\s+", " ", value).strip()
if value in {"unknown", "none", "null", "n/a", ""}:
return "unknown"
return value
def _location_key(self, display_location: str, jericho_location: str) -> str:
"""Prefer Jericho location as stable room identity."""
jericho = self._normalize_location(jericho_location)
if jericho != "unknown":
return jericho
return self._normalize_location(display_location)
def _humanize_location(self, location: str) -> str:
"""Convert normalized location keys to readable labels for logs."""
normalized = self._normalize_location(location)
if normalized == "unknown":
return "Unknown"
return " ".join(word.capitalize() for word in normalized.split())
def _extract_score_from_text(self, text: str) -> Optional[int]:
"""Extract score from various text formats."""
patterns = [
r"\[Score:\s*(\d+)\s*\|",
r"\(Total:\s*(\d+)\)",
r"(?:^|\n)\s*-\s*Score:\s*(\d+)\b",
r"(?:^|\n)\s*Score:\s*(\d+)\b",
r"\bscore\s+is\s+(\d+)\b",
]
for pattern in patterns:
match = re.search(pattern, text, flags=re.IGNORECASE)
if match:
return int(match.group(1))
return None
def _update_score(self, text: str) -> None:
"""Update current score from text when available."""
score = self._extract_score_from_text(text)
if score is not None:
self.score = score
def _update_max_score(self, text: str) -> None:
"""Update max score from text when available."""
match = re.search(r"-\s*Max\s+Score:\s*(\d+)", text, flags=re.IGNORECASE)
if match:
self.max_score = int(match.group(1))
def _is_game_over(self, text: str) -> bool:
"""Check if the game is over."""
game_over_phrases = [
"game over",
"you have died",
"you are dead",
"*** you have died ***",
"*** the end ***",
]
lowered = text.lower()
return any(phrase in lowered for phrase in game_over_phrases)
async def test_agent():
    """Test the agent locally against a local mcp_server.py process."""
    # Imported lazily so the module stays importable without fastmcp installed.
    from fastmcp import Client
    agent = StudentAgent()
    async with Client("mcp_server.py") as client:
        result = await agent.run(
            client=client,
            game="zork1",
            max_steps=20,
            seed=42,
            verbose=True,
        )
    print(f"\n{'=' * 50}")
    print(f"Final Score: {result.final_score}/{result.max_score}")
    print(f"Moves: {result.moves}")
    print(f"Locations: {len(result.locations_visited)}")
if __name__ == "__main__":
    import asyncio
    asyncio.run(test_agent())