# test1 / agent.py
# Exported from the Hugging Face Space web UI
# (commit d9d90c6 "Update agent.py" by bouhss)
"""
Exploration-first hybrid ReAct agent (score + locations) for text adventures.
Key points:
- Deterministic policy driven by server status() JSON.
- ReAct loop explicit each step: THOUGHT -> TOOL(play_action) -> OBSERVATION
- Priority:
A) Valid untried exits (Jericho-validated) + obs-boosted directions
B) Bounded suggested_interactions (game-validated)
C) BFS backtrack to nearest frontier (room with untried exits)
D) Stuck recovery (look/inventory/examine noun)
E) Optional single LLM fallback if HF_TOKEN is present (never required)
- Uses peek_action (if available) to score a small candidate set quickly.
- All verbose/debug output goes to stderr only.
"""
import json
import os
import re
import sys
from collections import deque
from dataclasses import dataclass, field
from typing import Optional
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
load_dotenv()  # pick up HF_TOKEN from a local .env file, if one exists
# =============================================================================
# LLM Configuration (fixed model for fairness)
# =============================================================================
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
_hf_token = os.getenv("HF_TOKEN")
# None when no token is configured; the agent then skips the LLM fallback entirely.
LLM_CLIENT = InferenceClient(token=_hf_token) if _hf_token else None
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 120) -> str:
    """Run one deterministic chat completion; used only as a last-resort fallback.

    Raises RuntimeError when no HF_TOKEN was configured (LLM_CLIENT is None).
    Returns the model's reply text, or "" when the API returns no content.
    """
    if LLM_CLIENT is None:
        raise RuntimeError("HF_TOKEN missing => LLM unavailable")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    # temperature=0 plus an explicit seed keeps the fallback reproducible.
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.0,
        max_tokens=max_tokens,
        seed=seed,
    )
    content = response.choices[0].message.content
    return content if content is not None else ""
@dataclass
class RunResult:
    """Summary of one agent run: final telemetry plus the full ReAct transcript."""
    final_score: int        # score reported by the game at the end of the run
    max_score: int          # maximum achievable score (0 when the game never reports it)
    moves: int              # number of play_action calls consumed
    locations_visited: set[int]   # distinct loc_ids entered during the run
    game_completed: bool    # True when the game signalled done / win / death
    error: Optional[str] = None   # reserved for error reporting (not set in this file)
    # One (THOUGHT, TOOL, OBSERVATION) triple per executed step.
    history: list[tuple[str, str, str]] = field(default_factory=list)
# =============================================================================
# Tunables
# =============================================================================
MAX_INTERACTIONS = 4   # max suggested_interactions attempted per location
STUCK_THRESHOLD = 10   # no-progress steps before stuck-recovery actions kick in
MEMORY_LEN = 20        # size of the recent-action memory used for loop avoidance
PEEK_K = 6  # lower if too slow; higher can improve decisions but costs time
# Action prefixes that are never attempted (destructive / risky verbs).
UNSAFE_STARTS = (
    "burn ", "set fire", "ignite ",
    "attack ", "kill ", "hit ", "stab ", "shoot ", "punch ", "fight ",
    "destroy ", "break ", "smash ",
    "eat ",
)
# Compass / vertical direction words as they appear in game observations.
DIR_WORD_RE = re.compile(
    r"\b(north(?:east|west)?|south(?:east|west)?|east|west|"
    r"northeast|northwest|southeast|southwest|up|down|in|out)\b",
    re.IGNORECASE,
)
# Phrases signalling that the parser is asking a disambiguation question.
DISAMBIG_RE = re.compile(
    r"which do you mean|do you mean|be more specific|what do you want",
    re.IGNORECASE,
)
# Grabs the first "the <noun phrase>" option out of a disambiguation prompt.
OPTION_RE = re.compile(r"\bthe\s+([a-z]+(?:\s+[a-z]+)?)", re.IGNORECASE)
# System prompt for the optional LLM fallback: one short "ACTION: ..." line.
LLM_SYSTEM = (
    "You play a text adventure game. Propose ONE action (<= 5 words) that helps "
    "explore a new location or gain points. Reply with exactly one line:\n"
    "ACTION: <command>"
)
class StudentAgent:
    """Exploration-first agent: deterministic heuristics plus optional LLM fallback.

    Learns a location graph from the server's status() JSON and picks actions
    by priority: untried exits, bounded safe interactions, BFS backtrack to the
    nearest frontier, stuck recovery, then (optionally) a single LLM call.
    """

    def __init__(self) -> None:
        self.visited: set[int] = set()                # loc_ids seen so far
        self.graph: dict[int, dict[str, int]] = {}    # loc_id -> {direction: neighbor loc_id}
        self.loc_untried: dict[int, list[str]] = {}   # loc_id -> untried directions (from status)
        self.interactions_done: dict[int, int] = {}   # loc_id -> interactions attempted there
        # recent_memory: (action, loc_id, score_before, obs_snip_after)
        self.recent_memory = deque(maxlen=MEMORY_LEN)
        self.no_progress_steps = 0  # consecutive steps with unchanged loc and score
        self.llm_calls = 0          # also perturbs the LLM seed per call
        self.last_action = ""

    # =============================================================================
    # ReAct loop
    # =============================================================================
    async def run(self, client, game: str, max_steps: int, seed: int, verbose: bool = False) -> RunResult:
        """Play for up to `max_steps` moves and return a RunResult summary.

        Each step: read status() (free telemetry), decide an action with the
        deterministic heuristics, optionally refine it via peek_action, then
        execute it with play_action and record a (THOUGHT, TOOL, OBSERVATION)
        triple in `history`. Verbose output goes to stderr only.
        """
        history: list[tuple[str, str, str]] = []
        moves_taken = 0
        final_score = 0
        max_score = 0
        game_completed = False
        last_status = {}
        tools = await client.list_tools()
        tool_names = {t.name for t in tools}
        has_peek = "peek_action" in tool_names
        # Initial observation
        init_obs = await client.call_tool("play_action", {"action": "look"})
        moves_taken += 1
        self.last_action = "look"
        history.append((
            "THOUGHT: Start by looking around to ground the state.",
            "TOOL: play_action ARGS: {'action': 'look'}",
            self._text(init_obs)[:160],
        ))
        prev_score = 0
        prev_loc = -1
        while moves_taken < max_steps:
            # Observation/telemetry (does not consume moves)
            try:
                raw = await client.call_tool("status", {})
                status = json.loads(self._text(raw))
                last_status = status
            except Exception:
                # Reuse the last good status; falls through to recovery if none yet.
                status = last_status
            if not status:
                # Emergency fallback
                thought = "THOUGHT: Status unavailable; use a safe action to recover."
                tool_call = "TOOL: play_action ARGS: {'action': 'look'}"
                res = await client.call_tool("play_action", {"action": "look"})
                moves_taken += 1
                obs_txt = self._text(res)
                history.append((thought, tool_call, obs_txt[:160]))
                continue
            loc_id = int(status["loc_id"])
            score = int(status.get("score", 0))
            final_score = score
            # Keep the previous max_score if status omits it or reports 0.
            max_score = int(status.get("max_score", max_score) or max_score)
            done = bool(status.get("done", False))
            self.visited.add(loc_id)
            self._merge_edges(loc_id, status.get("edges_here", {}) or {})
            self.loc_untried[loc_id] = list(status.get("untried_directions", []) or [])
            if score == prev_score and loc_id == prev_loc:
                self.no_progress_steps += 1
            else:
                self.no_progress_steps = 0
            prev_score, prev_loc = score, loc_id
            if done:
                game_completed = True
                break
            # Decide next action (deterministic heuristics + optional LLM fallback)
            thought_reason, action = self._decide(status, seed)
            # Optional look-ahead improvement
            if has_peek:
                action = await self._peek_pick(client, status, action)
            action = self._sanitize_action(action)
            # ReAct record (explicit)
            thought = f"THOUGHT: {thought_reason}"
            tool_call = f"TOOL: play_action ARGS: {{'action': '{action}'}}"
            # Execute action
            res = await client.call_tool("play_action", {"action": action})
            moves_taken += 1
            obs2 = self._text(res)
            # Update recent memory for loop avoidance
            self.recent_memory.append((action.lower().strip(), loc_id, score, obs2[:60]))
            self.last_action = action
            if verbose:
                print(
                    f"[step] loc={loc_id} score={score} stuck={self.no_progress_steps} -> {action!r}",
                    file=sys.stderr,
                )
            history.append((thought, tool_call, obs2[:160]))
            if self._is_game_over(obs2):
                game_completed = True
                break
        # final status (best effort)
        try:
            raw = await client.call_tool("status", {})
            st2 = json.loads(self._text(raw))
            final_score = max(final_score, int(st2.get("score", 0)))
            max_score = max_score or int(st2.get("max_score", 0))
            self.visited.add(int(st2["loc_id"]))
        except Exception:
            pass
        return RunResult(
            final_score=final_score,
            max_score=max_score,
            moves=moves_taken,
            locations_visited=self.visited,
            game_completed=game_completed,
            history=history,
        )

    # =============================================================================
    # Decision logic
    # =============================================================================
    def _decide(self, status: dict, seed: int) -> tuple[str, str]:
        """Return (reason, action) using the priority ladder 0/A/B/C/D/E.

        Every candidate is filtered through the banned list and
        `_repeat_noop` so the agent never re-issues a known no-op here.
        """
        loc_id = int(status["loc_id"])
        obs = status.get("last_observation", "") or ""
        outcomes = status.get("outcomes_here", {}) or {}
        banned = {str(x).lower().strip() for x in (status.get("banned_actions_here", []) or [])}
        untried = status.get("untried_directions", []) or []
        valid_exits = status.get("valid_exits", []) or []
        suggested = status.get("suggested_interactions", []) or []
        # 0) Disambiguation
        if DISAMBIG_RE.search(obs):
            opt = self._extract_option(obs)
            if opt and not self._repeat_noop(opt, loc_id):
                return "Disambiguation requested by the game; answer with the first plausible option.", opt
        # A1) Jericho-validated untried exits
        untried_set = set(untried)
        obs_dirs = self._mentioned_dirs(obs)
        for d in valid_exits:
            dl = d.lower().strip()
            if d in untried_set and dl not in banned and not self._repeat_noop(d, loc_id):
                return f"Take a valid untried exit to explore: {d}.", d
        # A2) Observation-boosted untried dirs
        for d in obs_dirs:
            if d in untried_set and d.lower() not in banned and not self._repeat_noop(d, loc_id):
                return f"Direction mentioned in observation and untried; explore: {d}.", d
        # A3) Any untried direction
        for d in untried:
            if d.lower() not in banned and not self._repeat_noop(d, loc_id):
                return f"No strong cue; systematically try untried direction: {d}.", d
        # B) Bounded safe interactions
        n = self.interactions_done.get(loc_id, 0)
        if n < MAX_INTERACTIONS:
            for a in suggested:
                al = a.lower().strip()
                if al in banned:
                    continue
                if any(al.startswith(x) for x in UNSAFE_STARTS):
                    continue
                if a in outcomes:
                    continue  # already tried here per the server's record
                if self._repeat_noop(a, loc_id):
                    continue
                self.interactions_done[loc_id] = n + 1
                return f"Try a game-validated interaction in this room (#{n+1}): {a}.", a
        # C) BFS backtrack to frontier
        avoid = self._oscillation_avoid()
        step_dir = self._bfs_step(loc_id, avoid)
        if step_dir:
            return "No local frontier; backtrack via BFS to nearest unexplored frontier.", step_dir
        # D) Stuck recovery
        if self.no_progress_steps >= STUCK_THRESHOLD:
            for a in ("look", "inventory"):
                if not self._repeat_noop(a, loc_id):
                    return "Stuck for many steps; run a safe recovery action.", a
            noun = self._extract_noun(obs)
            if noun and not self._repeat_noop(f"examine {noun}", loc_id):
                return "Stuck; examine a likely noun from the observation.", f"examine {noun}"
        # E) Optional LLM fallback
        if LLM_CLIENT is not None:
            try:
                self.llm_calls += 1
                prompt = self._llm_prompt(status)
                resp = call_llm(prompt, LLM_SYSTEM, seed + self.llm_calls)
                act = self._parse_llm(resp)
                if act and act.lower().strip() not in banned and not self._repeat_noop(act, loc_id):
                    return "Heuristics exhausted; use one short LLM suggestion (optional fallback).", act
            except Exception:
                pass  # LLM failures must never break the deterministic loop
        return "Fallback to a safe neutral action.", "look"

    async def _peek_pick(self, client, status: dict, current_action: str) -> str:
        """Use peek_action to score a small candidate set and pick best."""
        loc_id = int(status["loc_id"])
        score = int(status.get("score", 0))
        candidates = []
        if current_action:
            candidates.append(current_action)
        for d in (status.get("untried_directions", []) or [])[:4]:
            if d not in candidates:
                candidates.append(d)
        for a in (status.get("suggested_interactions", []) or [])[:4]:
            if a not in candidates:
                candidates.append(a)
        candidates = candidates[:PEEK_K]
        best = current_action
        best_u = -10**18
        for a in candidates:
            try:
                raw = await client.call_tool("peek_action", {"action": a})
                st = json.loads(self._text(raw))
                new_score = int(st.get("score", score))
                new_loc = int(st.get("loc_id", loc_id))
                delta = max(0, new_score - score)
                # Score gain dominates; novel rooms beat revisits; repeats are penalized.
                if new_loc != loc_id:
                    moved_bonus = 600 if (new_loc not in self.visited) else 80
                else:
                    moved_bonus = 0
                repeat_pen = 120 if self._repeat_noop(a, loc_id) else 0
                u = delta * 900 + moved_bonus - repeat_pen
                if u > best_u:
                    best_u = u
                    best = a
            except Exception:
                continue  # an unpeekable candidate is simply skipped
        return best

    # =============================================================================
    # Graph / BFS
    # =============================================================================
    def _merge_edges(self, loc_id: int, edges_here: dict) -> None:
        """Fold the server's edges_here mapping into the learned graph."""
        if not edges_here:
            return
        node = self.graph.setdefault(loc_id, {})
        for d, nid in edges_here.items():
            try:
                node[str(d)] = int(nid)
            except Exception:
                pass  # ignore malformed neighbor ids

    def _oscillation_avoid(self) -> Optional[int]:
        """Return the loc_id to avoid when recent moves show an A-B-A-B bounce."""
        locs = [x[1] for x in self.recent_memory]
        if len(locs) >= 4 and locs[-1] == locs[-3] and locs[-2] == locs[-4]:
            return locs[-2]
        return None

    def _bfs_step(self, from_loc: int, avoid_loc: Optional[int]) -> Optional[str]:
        """BFS over the learned graph; return the FIRST direction of the
        shortest path from `from_loc` to the nearest room with untried exits,
        or None when no such frontier is reachable."""
        frontier = {lid for lid, u in self.loc_untried.items() if u and lid != from_loc}
        if not frontier:
            return None
        q = deque()
        seen = {from_loc}
        # Seed the queue with immediate neighbors, remembering the initial direction.
        for d, nid in self.graph.get(from_loc, {}).items():
            if nid not in seen and nid != avoid_loc:
                q.append((nid, d))
                seen.add(nid)
        while q:
            cur, first_dir = q.popleft()
            if cur in frontier:
                return first_dir
            for d, nid in self.graph.get(cur, {}).items():
                if nid not in seen:
                    seen.add(nid)
                    q.append((nid, first_dir))
        return None

    # =============================================================================
    # Parsing / loop helpers
    # =============================================================================
    def _repeat_noop(self, action: str, loc_id: int) -> bool:
        """True when this exact action was already issued in this location recently."""
        a = (action or "").lower().strip()
        return any(prev_a == a and prev_loc == loc_id for (prev_a, prev_loc, _sc, _o) in self.recent_memory)

    def _mentioned_dirs(self, obs: str) -> list[str]:
        """Direction words found in the observation, deduplicated, in order."""
        out = []
        for m in DIR_WORD_RE.finditer(obs or ""):
            d = m.group(1).lower()
            if d not in out:
                out.append(d)
        return out

    def _extract_option(self, obs: str) -> Optional[str]:
        """First 'the <noun phrase>' option in a disambiguation prompt, or None."""
        m = OPTION_RE.search(obs or "")
        if m:
            return m.group(1).strip().lower()
        return None

    def _extract_noun(self, obs: str) -> Optional[str]:
        """A likely examine-target noun from the observation, or None."""
        m = re.search(r"\bthe\s+([a-z]{3,})\b", (obs or "").lower())
        if m:
            noun = m.group(1)
            # avoid directions being interpreted as nouns
            if noun not in {"north", "south", "east", "west", "up", "down", "in", "out"}:
                return noun
        return None

    def _sanitize_action(self, a: str) -> str:
        """Strip quotes/backticks, collapse whitespace, cap at 6 words; 'look' if empty."""
        a = (a or "").strip()
        a = re.sub(r"[`\"']", "", a)
        a = re.sub(r"\s+", " ", a).strip()
        words = a.split()[:6]
        return " ".join(words) if words else "look"

    def _llm_prompt(self, status: dict) -> str:
        """Compact single-string state summary for the LLM fallback."""
        inv = ", ".join(status.get("inventory", [])) or "empty"
        tried = ", ".join(list((status.get("outcomes_here") or {}).keys())[:20]) or "none"
        banned = ", ".join(status.get("banned_actions_here", [])) or "none"
        return (
            f"Location: {status.get('loc_name')} (id={status.get('loc_id')})\n"
            f"Score: {status.get('score')}/{status.get('max_score')} Moves: {status.get('moves')}\n"
            f"Inventory: {inv}\n"
            f"Untried dirs: {', '.join((status.get('untried_directions') or [])[:12])}\n"
            f"Tried here: {tried}\n"
            f"BANNED: {banned}\n\n"
            f"Observation:\n{(status.get('last_observation') or '')[:500]}\n"
        )

    def _parse_llm(self, resp: str) -> str:
        """Extract one short action from the LLM reply (first non-empty line)."""
        for line in (resp or "").splitlines():
            line = line.strip()
            if not line:
                continue
            if line.upper().startswith("ACTION:"):
                line = line.split(":", 1)[1].strip()
            line = line.lower()
            # Normalize "go north" etc. down to the bare direction word.
            m = re.match(
                r"^(?:go\s+)?(north(?:east|west)?|south(?:east|west)?|east|west|up|down|in|out)\b",
                line,
            )
            if m:
                return m.group(1)
            return " ".join(line.split()[:5])
        return "look"

    def _is_game_over(self, text: str) -> bool:
        """Heuristic end-of-game detection from the observation text."""
        t = (text or "").lower()
        return any(x in t for x in ("game over", "you have died", "you are dead", "you have won"))

    def _text(self, result) -> str:
        """Extract plain text from an MCP tool result (content list or object)."""
        try:
            if hasattr(result, "content") and result.content:
                return result.content[0].text
            if isinstance(result, list) and result:
                return result[0].text
        except Exception:
            pass
        return str(result)
# Optional smoke-test (does not run during evaluation import)
async def _test() -> None:
    """Launch mcp_server.py over stdio and run the agent for a short session."""
    from fastmcp import Client
    from fastmcp.client.transports import StdioTransport
    import sys as _sys
    import os as _os
    transport = StdioTransport(
        command=_sys.executable,
        args=[_os.path.join(_os.path.dirname(__file__), "mcp_server.py")],
        env={**_os.environ, "GAME": "lostpig"},
    )
    agent = StudentAgent()
    async with Client(transport) as client:
        res = await agent.run(client, game="lostpig", max_steps=30, seed=42, verbose=True)
        print(
            f"Score: {res.final_score}/{res.max_score} | Moves: {res.moves} | Locations: {len(res.locations_visited)}",
            file=sys.stderr,
        )


if __name__ == "__main__":
    import asyncio
    asyncio.run(_test())