# Source: dspy-zmachine / agent.py (Hugging Face Space, user janisaiad)
# Commit 59f6ee5: "Submission: refs, run proof (no binary PNG for HF), abstract, short blog in README"
"""
ZorkGPT-Lite: Full orchestrator with Agent, Critic, Extractor, StrategyGen.
Uses Z-machine data (memory, inventory, get_valid_actions) + LLM for reasoning.
"""
import asyncio
import json
import os
import sys
import re
from dataclasses import dataclass, field
from typing import Optional, Tuple, Any
from pathlib import Path
from dotenv import load_dotenv
# we cache learned rules from refs/learned/ for _learned_heuristic
_LEARNED_CACHE: dict[str, Any] = {}
load_dotenv()
# Local inference is optional: transformers/torch may be absent in API-only deployments.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch
    _LOCAL_INFERENCE_AVAILABLE = True
except ImportError:
    _LOCAL_INFERENCE_AVAILABLE = False
from huggingface_hub import InferenceClient
# =============================================================================
# LLM Configuration
# =============================================================================
# Remote model used through the HF Inference API.
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
# USE_LOCAL_MODEL=true switches to in-process transformers inference.
_USE_LOCAL = os.getenv("USE_LOCAL_MODEL", "false").lower() in ("true", "1", "yes")
# Smaller default model for local inference; override with HF_MODEL.
_HF_MODEL_LOCAL = os.getenv("HF_MODEL", "Qwen/Qwen2.5-7B-Instruct")
_hf_token = os.getenv("HF_TOKEN")
# API mode requires a token; local mode can proceed without one (gated models may fail).
if not _USE_LOCAL or not _LOCAL_INFERENCE_AVAILABLE:
    if not _hf_token:
        raise ValueError("HF_TOKEN not found. Set it in your .env file (or use USE_LOCAL_MODEL=true with transformers).")
    LLM_CLIENT: Optional[InferenceClient] = InferenceClient(token=_hf_token)
else:
    LLM_CLIENT = None
# Lazily-loaded local model state; populated by _ensure_local_model().
_local_tokenizer = None
_local_model = None
def _ensure_local_model() -> None:
    """Load the local tokenizer/model exactly once (no-op unless local mode is on)."""
    global _local_tokenizer, _local_model
    # Already loaded, or local inference not requested/available: nothing to do.
    if _local_model is not None:
        return
    if not (_LOCAL_INFERENCE_AVAILABLE and _USE_LOCAL):
        return
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Pass the HF token only when we actually have one.
    auth_kwargs = {"token": _hf_token} if _hf_token else {}
    if not _hf_token:
        print("[INFO] No HF_TOKEN; gated models may fail. Set HF_TOKEN in .env for e.g. Gemma.")
    _local_tokenizer = AutoTokenizer.from_pretrained(_HF_MODEL_LOCAL, **auth_kwargs)
    _local_model = AutoModelForCausalLM.from_pretrained(
        _HF_MODEL_LOCAL,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        **auth_kwargs,
    )
    # device_map="auto" already places the model on GPU; only CPU needs an explicit move.
    if not use_cuda:
        _local_model = _local_model.to(device)
    print(f"[INFO] Local model loaded: {_HF_MODEL_LOCAL} on {device}")
def call_llm(
    prompt: str,
    system_prompt: str,
    seed: int,
    max_tokens: int = 400,
    simple_output: bool = False,
    temperature: float = 0.0,
) -> str:
    """Call the LLM (API or local). simple_output=True skips THOUGHT priming. temperature>0 enables sampling.

    Args:
        prompt: User message content.
        system_prompt: System message content.
        seed: Passed to the API for reproducibility (ignored by the local path).
        max_tokens: Generation budget.
        simple_output: When False, the local path primes the completion with "THOUGHT:".
        temperature: 0.0 means greedy decoding; >0 enables sampling (top_p=0.95 locally).

    Returns:
        The model's text response ("" when the API returns no content).
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    if _USE_LOCAL and _LOCAL_INFERENCE_AVAILABLE:
        _ensure_local_model()
        if _local_tokenizer is None or _local_model is None:
            raise RuntimeError("Local model failed to load.")
        # Prefer the model's chat template; fall back to a plain transcript format.
        if hasattr(_local_tokenizer, "apply_chat_template"):
            formatted = _local_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            formatted = f"{system_prompt}\n\nUser: {prompt}\n\nAssistant:"
        if not simple_output:
            # Prime the completion so small models start with the required THOUGHT line.
            formatted = formatted.rstrip() + "\nTHOUGHT:"
        inputs = _local_tokenizer(formatted, return_tensors="pt")
        # Move inputs to wherever the (possibly device-mapped) model lives.
        model_device = next(_local_model.parameters()).device
        inputs = {k: (v.to(model_device) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}
        do_sample = temperature > 0
        gen_kwargs = dict(
            max_new_tokens=max_tokens,
            pad_token_id=_local_tokenizer.eos_token_id,
            do_sample=do_sample,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 0.95
        with torch.no_grad():
            gen_out = _local_model.generate(**inputs, **gen_kwargs)
        # Decode only the newly generated tokens (skip the echoed prompt).
        out_slice = gen_out[0][inputs["input_ids"].shape[1]:]
        if out_slice.is_cuda:
            out_slice = out_slice.cpu()
        raw = _local_tokenizer.decode(out_slice, skip_special_tokens=True).strip()
        # Re-attach the primed prefix so callers can parse "THOUGHT: ..." uniformly.
        if not simple_output and formatted.rstrip().endswith("THOUGHT:") and raw and not raw.upper().startswith("THOUGHT:"):
            raw = "THOUGHT: " + raw
        return raw
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        seed=seed,
    )
    # BUGFIX: the API may return None content; normalize so the -> str contract holds.
    return response.choices[0].message.content or ""
@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class."""
    final_score: int  # score reported by the game at the end of the run
    max_score: int  # maximum achievable score for the game (350 default)
    moves: int  # number of successfully executed tool calls
    locations_visited: set[str]  # first lines of observations, used as location names
    game_completed: bool  # True when a game-over condition was detected
    error: Optional[str] = None  # populated by callers on fatal errors
    history: list[tuple[str, str, str]] = field(default_factory=list)  # (thought, tool call, truncated result)
# =============================================================================
# Prompts: Agent, Critic, StrategyGen
# =============================================================================
# System prompt for the action-proposing Agent LM. Emphasizes loop avoidance,
# exploration, and the strict THOUGHT/TOOL/ARGS output format parsed downstream.
AGENT_PROMPT = """You are an expert text adventure player. MAXIMIZE YOUR SCORE and EXPLORE NEW LOCATIONS.
Evaluation (lexicographic): (score, locations). Both matter. Passing: score >= 2 and >= 40 locations in 100 steps. Strong: score 3-4 and many locations.
CONTEXT/MEMORY MATTERS MOST – The main differentiator is avoiding loops and deadlocks:
- NEVER repeat an action that failed at the current location. If "Recently failed", "Avoid", or "DO NOT repeat at current location" is shown, do NOT propose those actions.
- Track visited locations: prioritize exploring NEW rooms. More locations = stronger result.
- Loop detector: if the same or very similar observation appears again within a few steps, you are in a loop – try a DIFFERENT direction or action immediately.
- Dead ends: if an action led to no progress (no score, no movement), do not retry it at this location.
AVAILABLE MCP TOOLS:
- play_action: Execute game commands (north, take lamp, open mailbox, get up, etc.)
- memory: Get current state from Z-machine
- inventory: Get items from Z-machine
- get_map: Explored locations
- get_context_for_agent: Call this! Returns "Locations visited: N" and "DO NOT repeat at current location: [list]". Use it to avoid loops.
When the prompt shows:
- "Locations visited: N" – aim to increase N; explore unmapped directions.
- "Avoid (recently failed): [list]" or "NEVER repeat at current location: [list]" – do NOT propose any action from that list.
- Same room/observation as a recent step – you are looping; try a different direction.
CRITICAL: You MUST respond in this exact format (no markdown, no extra text):
THOUGHT: <one sentence about what to do next>
TOOL: play_action
ARGS: {"action": "<command>"}
Universal rules (apply to any text adventure):
- If game says "get out of bed first" or "have to get up": try get up, stand
- If "too dark" or "can't see": light lamp, take lamp
- If "can't go that way": try different direction
- If "don't understand": try simpler verb (look, examine, take X)
- Explore directions (north, south, east, west). Take items. Do NOT repeat same action in a loop."""
# Appended by _build_agent_prompt when _is_combat_situation() is True.
AGENT_COMBAT_PROMPT_APPEND = """
COMBAT PRIORITY: If sword/lamp is GLOWING or enemy is attacking, use ONLY combat actions:
attack <enemy> with sword, kill <enemy>, fight <enemy>. Do NOT explore or take items until combat ends."""
# Appended (as an [EMERGENCY] banner) when the stagnation countdown is low.
AGENT_EMERGENCY_PROMPT_APPEND = """
EMERGENCY: Score stagnation - prioritize known score sources (trophy case, window entry, paintings).
Try unmapped directions. Drop heavy items to take valuable ones. Avoid wasting turns on non-existent enemies."""
# Used when the critic rejects an action and we ask for multiple candidates
# (parsed by _parse_multiple_actions).
AGENT_MULTI_CANDIDATES_PROMPT = """Propose 2 different actions as ALTERNATIVE1 and ALTERNATIVE2.
Format:
THOUGHT: <brief reasoning>
ALTERNATIVE1: <first action>
ALTERNATIVE2: <second action>
Use TOOL/ARGS for the first one only."""
# System prompt for the Critic LM; its "SCORE: x" line is parsed by _parse_critic_score.
CRITIC_PROMPT = """You evaluate whether a proposed game action is good.
Given: current observation, valid actions from Z-machine, proposed action.
Score -1.0 to 1.0: -1=bad (invalid, repeated, no progress, wrong object), 1=good (valid, progresses).
Use negative scores for actions that repeat failed commands or reference non-existent objects.
Respond in one line: SCORE: <-1.0 to 1.0> REASON: <brief reason>
If action is in valid_actions or is a common command (look, north, take X), score >= 0.6.
If object in action is NOT in valid_actions, score <= -0.5."""
# System prompt for StrategyGen; its output is appended to the knowledge base.
STRATEGY_PROMPT = """Analyze this gameplay history and extract 3-5 strategic insights.
Format each as a short rule. Example: "In dark games, get lamp before exploring."
Output only the insights, one per line."""
# =============================================================================
# StudentAgent: Full ZorkGPT-Lite Orchestrator
# =============================================================================
class StudentAgent:
    """
    Full orchestrator: Extractor (Z-machine) -> Agent -> Critic (Z-machine + LLM) -> Execute.
    StrategyGen updates knowledge_base every 12 turns.
    Override policy, score stagnation emergency, combat heuristic, object-tree validation.
    """
    # Minimum critic score for an action to be accepted outright.
    CRITIC_THRESHOLD = 0.5
    CRITIC_OVERRIDE_MIN = -0.5  # we override rejection if score >= this and action is exploration
    # Maximum critic evaluation rounds per step before accepting anyway.
    MAX_CRITIC_RETRIES = 3
    # Run StrategyGen (knowledge-base refresh) every this many steps.
    STRATEGY_UPDATE_INTERVAL = 12
    # Seconds to wait for get_valid_actions before giving up (spacy can block).
    VALID_ACTIONS_TIMEOUT = 0.8
    STAGNATION_DEATH_THRESHOLD = 35  # Zork-like: ~30-40 turns without score -> death
    EMERGENCY_TURNS_LEFT = 10  # enter emergency when turns_until_death <= this
def __init__(self):
self.history: list[dict] = []
self.recent_actions: list[str] = []
self.failed_actions: set[str] = set() # we avoid repeating actions that failed
self.score: int = 0
self.max_score: int = 350
self.steps_without_score: int = 0
self.knowledge_base: str = "General: Explore, take items, use lamp before dark. Try get up if stuck. Try east/north when south fails."
self.seen_state_hashes: dict[str, int] = {} # state_hash -> last step seen (for loop detection)
self.unmapped_directions: set[str] = set()
self.failed_actions_at_state: dict[str, set[str]] = {} # state_hash -> failed actions (never repeat)
async def run(
    self,
    client,
    game: str,
    max_steps: int,
    seed: int,
    verbose: bool = False,
    trace_path: Optional[str] = None,
) -> RunResult:
    """Run the full orchestrator loop. If trace_path is set, write step-level JSON for analysis.

    Args:
        client: MCP client exposing play_action/memory/inventory/... tools.
        game: Game identifier (used for learned-rule lookup).
        max_steps: Maximum number of agent steps.
        seed: Base seed; per-step offsets keep LLM calls reproducible.
        verbose: Print step-by-step reasoning to stdout.
        trace_path: Optional JSON trace output path.

    Returns:
        RunResult with final score, moves, visited locations and history.

    Bugfixes vs. previous revision:
    - `accepted`/`override_reason`/`critic_score` are initialized before the
      play_action branch, so the trace writer can no longer hit a NameError
      when the chosen tool is not play_action.
    - State-revisit loop detection now compares against the step at which the
      state was PREVIOUSLY seen; the old code overwrote the timestamp with the
      current step first, which made the revisit check unreachable.
    """
    self._current_game = game
    locations_visited = set()
    history = []
    moves = 0
    trace_steps: list[dict] = []
    tool_names = [t.name for t in await client.list_tools()]
    # we reset per-run failure/loop-tracking state
    self.failed_actions = set()
    self.steps_without_score = 0
    self.seen_state_hashes = {}
    self.failed_actions_at_state = {}
    # we get initial observation
    result = await client.call_tool("play_action", {"action": "look"})
    observation = self._extract_result(result)
    loc = observation.split("\n")[0] if observation else "Unknown"
    locations_visited.add(loc)
    if verbose:
        print(f"\n{observation}")
    context = {}
    for step in range(1, max_steps + 1):
        # we print progress to stderr so batch runs show activity (every step)
        print(f" step {step}/{max_steps} score={self.score}", file=sys.stderr, flush=True)
        # we extract context from Z-machine (no LLM)
        context = await self._extract_context(client)
        context["game"] = game
        state_hash = context.get("state_hash", "")
        context["state_hash"] = state_hash
        # BUGFIX: capture the previous sighting of this state BEFORE overwriting
        # it, otherwise the revisit check below always compares against `step`
        prev_seen_step = self.seen_state_hashes.get(state_hash[:64])
        self.seen_state_hashes[state_hash[:64]] = step
        # we build agent prompt (with combat/emergency appendices)
        prompt = self._build_agent_prompt(observation, context)
        thought, tool_name, tool_args = "No reasoning", "play_action", {"action": "look"}
        action = "look"
        # we get action from Agent LM (max_tokens 250 for small models)
        print(f" [LLM] step {step}/{max_steps} action...", file=sys.stderr, flush=True)
        response = call_llm(prompt, AGENT_PROMPT, seed + step, max_tokens=250)
        if not response.strip():
            response = self._heuristic_action(observation)
            if verbose:
                print(f"[DEBUG] LLM empty, heuristic: {response[:80]}")
        thought, tool_name, tool_args = self._parse_response(response, tool_names)
        tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names)
        # BUGFIX: initialize critic outputs here so the trace writer below always
        # has them bound, even when the tool is not play_action
        accepted = False
        override_reason: Optional[str] = None
        critic_score = 0.5
        if tool_name == "play_action":
            action = tool_args.get("action", "look")
            print(f" [LLM] -> {action[:40]}", file=sys.stderr, flush=True)
            # we apply combat heuristic: if sword glowing, bias toward attack
            if self._is_combat_situation(observation) and not any(
                a.lower().startswith(("attack", "kill", "fight")) for a in self.recent_actions[-2:]
            ):
                combat_action = self._combat_heuristic_action(observation)
                if combat_action:
                    action = combat_action
                    tool_args = {"action": action}
                    if verbose:
                        print(f"[COMBAT] Biased toward attack: {action}")
            # we run Critic: object-tree check first, then fast check, then LLM with override policy
            obj_valid, obj_msg = self._object_tree_validation(action, context.get("valid_actions", ""))
            if not obj_valid and obj_msg:
                if verbose:
                    print(f"[CRITIC] Object tree: {obj_msg}")
            fast_ok = self._critic_fast_check(action, context.get("valid_actions", ""))
            if fast_ok and obj_valid:
                accepted = True
            for attempt in range(self.MAX_CRITIC_RETRIES):
                if accepted:
                    break
                critic_prompt = f"""Observation: {observation[:300]}
Valid actions: {context.get('valid_actions', 'unknown')}
Proposed: {action}
Score and reason?"""
                print(f" [LLM] critic attempt {attempt+1}...", file=sys.stderr, flush=True)
                critic_resp = call_llm(critic_prompt, CRITIC_PROMPT, seed + step + attempt, max_tokens=80)
                critic_score = self._parse_critic_score(critic_resp)
                if critic_score >= self.CRITIC_THRESHOLD:
                    accepted = True
                    break
                # we override policy: when critic uncertain but action is exploration-related, override
                should_override, override_reason = self._should_override_critic_rejection(
                    action, critic_score, context, observation
                )
                if should_override:
                    accepted = True
                    if verbose:
                        print(f"[OVERRIDE] {override_reason} (score {critic_score:.2f})")
                    break
                if attempt < self.MAX_CRITIC_RETRIES - 1:
                    # we get multiple candidates when rejected (alternative actions)
                    feedback = f"Action '{action}' rejected (score {critic_score:.1f}). Propose 2 different actions."
                    prompt = self._build_agent_prompt(observation, context, feedback)
                    response = call_llm(prompt, AGENT_PROMPT, seed + step + attempt)
                    candidates = self._parse_multiple_actions(response)
                    if len(candidates) >= 2:
                        # we score each candidate with the critic and keep the best
                        best_action, best_score = action, critic_score
                        for cand in candidates[:3]:
                            if cand == action:
                                continue
                            cp = f"Observation: {observation[:250]}\nValid: {context.get('valid_actions','')[:150]}\nProposed: {cand}\nScore?"
                            cr = call_llm(cp, CRITIC_PROMPT, seed + step + attempt + 1, max_tokens=60)
                            cs = self._parse_critic_score(cr)
                            if cs > best_score:
                                best_action, best_score = cand, cs
                        action = best_action
                        tool_args = {"action": action}
                    else:
                        # single-candidate fallback: re-parse and re-apply combat bias
                        thought, tool_name, tool_args = self._parse_response(response, tool_names)
                        tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names)
                        if tool_name == "play_action":
                            action = tool_args.get("action", "look")
                        if self._is_combat_situation(observation):
                            combat_action = self._combat_heuristic_action(observation)
                            if combat_action:
                                action = combat_action
                                tool_args = {"action": action}
        else:
            # non-play_action tools (memory, inventory, ...) are executed as-is
            accepted = True
        # we loop detection: try result-based heuristic first, then generic verb cycle
        if len(self.recent_actions) >= 3 and len(set(self.recent_actions[-3:])) == 1:
            res = self._result_based_heuristic(observation)
            if res is not None:
                action = res
            else:
                action = self._generic_verb_cycle()
            tool_args = {"action": action}
            if verbose:
                print(f"[WARNING] Loop detected - trying '{action}' instead")
        # we skip actions that failed at current location (never repeat)
        key = state_hash[:64] if state_hash else ""
        failed_at_loc = self.failed_actions_at_state.get(key, set())
        if action.lower() in failed_at_loc or action.lower() in self.failed_actions:
            action = self._generic_verb_cycle(extra_skip=failed_at_loc)
            tool_args = {"action": action}
            if verbose:
                print(f"[CONTEXT] Skipping failed action, trying {action}")
        # we loop detection via state hash: if we've seen this state before, force different action
        # (uses prev_seen_step captured before the timestamp was overwritten above)
        if key and prev_seen_step is not None and prev_seen_step < step - 2:
            if action.lower() in [a.lower() for a in self.recent_actions[-4:]]:
                action = self._generic_verb_cycle(extra_skip=failed_at_loc)
                tool_args = {"action": action}
                if verbose:
                    print(f"[LOOP] State revisited, trying {action}")
        # we prefer valid_actions when stuck (no score for many steps)
        if self.steps_without_score >= 5 and context.get("valid_actions"):
            va = context["valid_actions"].lower()
            for cand in ["take all", "take lamp", "take keys", "open", "examine", "north", "east"]:
                if cand in va and cand not in self.failed_actions:
                    if cand not in [a.lower() for a in self.recent_actions[-3:]]:
                        action = cand
                        tool_args = {"action": action}
                        break
        # we emergency mode: when turns_until_death is low, prioritize score sources and unmapped exits
        # skip stagnation logic for scoreless games (max_score 0)
        game_max = self._parse_max_score_from_context(context)
        turns_until_death = self.STAGNATION_DEATH_THRESHOLD - self.steps_without_score
        if game_max == 0:
            turns_until_death = 999  # disable emergency for scoreless games
        if turns_until_death <= self.EMERGENCY_TURNS_LEFT and turns_until_death > 0:
            em_action = self._emergency_heuristic(observation, context)
            if em_action:
                action = em_action
                tool_args = {"action": action}
                if verbose:
                    print(f"[EMERGENCY] {turns_until_death} turns left - trying {action}")
        self.recent_actions.append(action)
        if len(self.recent_actions) > 10:
            self.recent_actions = self.recent_actions[-10:]
        # we track failed actions (rejection, no movement, no score)
        # NOTE(review): this inspects the PRE-action observation, so it effectively
        # attributes the previous step's outcome to the newly chosen action — confirm intent
        if self._is_failure_result(observation, action):
            self.failed_actions.add(action.lower())
            if state_hash:
                key = state_hash[:64]
                if key not in self.failed_actions_at_state:
                    self.failed_actions_at_state[key] = set()
                self.failed_actions_at_state[key].add(action.lower())
            try:
                # best-effort: persist the failure server-side for other components
                await client.call_tool("record_failed_action", {"state_hash": state_hash, "action": action})
            except Exception:
                pass
        else:
            self.failed_actions.discard(action.lower())
        # we track score progress and reinforce what worked
        old_score = self.score
        self._update_score(observation)
        if self.score > old_score:
            self.steps_without_score = 0
            if len(self.knowledge_base) < 800:
                self.knowledge_base = self.knowledge_base + f"\nScore: {action} worked."
        else:
            self.steps_without_score += 1
        if verbose:
            print(f"\n--- Step {step} ---")
            print(f"[THOUGHT] {thought}")
            print(f"[TOOL] {tool_name}({tool_args})")
        score_before = self.score
        obs_pre = observation[:400]  # we save before overwriting
        print(f" [exec] step {step} -> {action[:30]}...", file=sys.stderr, flush=True)
        try:
            result = await client.call_tool(tool_name, tool_args)
            observation = self._extract_result(result)
            moves += 1
        except Exception as e:
            observation = f"Error: {e}"
            if verbose:
                print(f"[ERROR] {e}")
        print(f" [done] step {step} score={self.score}", file=sys.stderr, flush=True)
        loc = observation.split("\n")[0] if observation else "Unknown"
        locations_visited.add(loc)
        self._update_score(observation)
        score_after = self.score
        if trace_path:
            trace_steps.append({
                "step": step,
                "observation_pre": obs_pre,
                "action": action,
                "observation_post": observation[:400],
                "score_before": score_before,
                "score_after": score_after,
                "reward": score_after - score_before,
                "location": loc,
                "valid_actions": (context.get("valid_actions") or "")[:200],
                "critic_score": critic_score,
                "override": override_reason or False,
            })
        history.append((thought, f"{tool_name}({tool_args})", observation[:100]))
        self.history.append({"step": step, "thought": thought, "action": action, "result": observation[:200]})
        if len(self.history) > 20:
            self.history = self.history[-20:]
        if verbose:
            print(f"[RESULT] {observation[:200]}...")
        # we update knowledge_base every N turns (StrategyGen)
        if step % self.STRATEGY_UPDATE_INTERVAL == 0 and self.history:
            strategy_hist = "\n".join([f"Step {h['step']}: {h['action']} -> {h['result'][:80]}" for h in self.history[-15:]])
            strat_prompt = f"History:\n{strategy_hist}\n\nCurrent score: {self.score}\nExtract insights:"
            try:
                insights = call_llm(strat_prompt, STRATEGY_PROMPT, seed + step, max_tokens=150)
                if insights.strip():
                    self.knowledge_base = self.knowledge_base + "\n" + insights.strip()[:300]
            except Exception:
                pass
        if self._is_game_over(observation):
            if verbose:
                print("\n*** GAME OVER ***")
            break
    if trace_path:
        trace_obj = {
            "game": game,
            "seed": seed,
            "max_steps": max_steps,
            "final_score": self.score,
            "final_moves": moves,
            "game_over": self._is_game_over(observation),
            "steps": trace_steps,
        }
        Path(trace_path).parent.mkdir(parents=True, exist_ok=True)
        with open(trace_path, "w") as f:
            json.dump(trace_obj, f, indent=None)
    return RunResult(
        final_score=self.score,
        max_score=self.max_score,
        moves=moves,
        locations_visited=locations_visited,
        game_completed=self._is_game_over(observation),
        history=history,
    )
async def _extract_context(self, client) -> dict:
    """Extractor: Z-machine data via MCP tools (no LLM).

    Queries memory/inventory/map/state-hash/context tools best-effort (failures
    yield ""), fetches get_valid_actions under a timeout, and derives a
    normalized "state_hash" entry used for loop detection.
    """
    ctx = {}
    tools_to_try = [
        ("memory", "memory"),
        ("inventory", "inventory"),
        ("get_map", "map"),
        ("get_state_hash", "state_hash_raw"),
        ("get_context_for_agent", "context_summary"),
    ]
    for tool_name, key in tools_to_try:
        try:
            r = await client.call_tool(tool_name, {})
            ctx[key] = self._extract_result(r)
        except Exception:
            # best-effort: a missing/broken tool just yields an empty field
            ctx[key] = ""
    # we enable get_valid_actions by default for critic; set USE_VALID_ACTIONS=false to disable (spacy can block)
    if os.getenv("USE_VALID_ACTIONS", "true").lower() not in ("false", "0", "no"):
        try:
            r = await asyncio.wait_for(
                client.call_tool("get_valid_actions", {}),
                timeout=self.VALID_ACTIONS_TIMEOUT,
            )
            ctx["valid_actions"] = self._extract_result(r)
        except Exception:
            # covers asyncio.TimeoutError too (it is an Exception subclass);
            # the previous `(asyncio.TimeoutError, Exception)` tuple was redundant
            ctx["valid_actions"] = ""
    else:
        ctx["valid_actions"] = ""
    # we parse raw state hash for loop detection (strip "State hash: " prefix)
    raw = ctx.get("state_hash_raw", "")
    if raw and ":" in raw:
        ctx["state_hash"] = raw.split(":", 1)[1].strip().split("...")[0].strip()
    else:
        ctx["state_hash"] = raw[:64] if raw else ""
    return ctx
def _build_agent_prompt(self, observation: str, context: dict, feedback: str = "") -> str:
    """Build agent prompt with context, combat and emergency appendices.

    Layers, in order: knowledge base, score + emergency banner, combat appendix,
    valid actions, learned room hints, Z-machine state/map/inventory, failure
    lists, recent history, optional critic feedback, and the observation.
    """
    parts = [f"Knowledge base:\n{self.knowledge_base[:500]}\n"]
    parts.append(f"Current score: {self.score}")
    # we compute stagnation countdown (mirrors run()); scoreless games never go emergency
    game_max = self._parse_max_score_from_context(context)
    turns_until_death = self.STAGNATION_DEATH_THRESHOLD - self.steps_without_score
    if game_max == 0:
        turns_until_death = 999
    if turns_until_death <= self.EMERGENCY_TURNS_LEFT and turns_until_death > 0:
        parts.append(f"\n[EMERGENCY] {turns_until_death} turns until score-stagnation death! Prioritize known score sources.")
    if self._is_combat_situation(observation):
        parts.append(AGENT_COMBAT_PROMPT_APPEND)
    if context.get("valid_actions"):
        parts.append(f"\nValid actions (prefer these): {context['valid_actions'][:200]}")
    # we add learned hints for current room when available
    room = (observation or "").strip().split("\n")[0][:80]
    learned_hint = self._learned_heuristic(room, context)
    if learned_hint:
        parts.append(f"\n[Learned] In this room try: {learned_hint}")
    if context.get("memory"):
        parts.append(f"\nZ-machine state:\n{context['memory'][:350]}")
    if context.get("map"):
        parts.append(f"\nMap:\n{context['map'][:250]}")
    if context.get("inventory"):
        parts.append(f"\n{context['inventory']}")
    if context.get("context_summary"):
        parts.append(f"\n{context['context_summary']}")
    if self.failed_actions:
        parts.append(f"\nAvoid (recently failed): {', '.join(list(self.failed_actions)[:8])}")
    # we surface per-location failures so the LLM never repeats them here
    key = (context.get("state_hash") or "")[:64]
    failed_at_loc = self.failed_actions_at_state.get(key, set())
    if failed_at_loc:
        parts.append(f"\nNEVER repeat at current location: {', '.join(list(failed_at_loc)[:8])}")
    if self.history:
        parts.append("\nRecent:")
        for h in self.history[-4:]:
            parts.append(f" > {h.get('action','?')} -> {h.get('result','')[:55]}...")
    if feedback:
        parts.append(f"\n[FEEDBACK] {feedback}")
    parts.append(f"\nCurrent observation:\n{observation}")
    parts.append("\nWhat do you do next?")
    return "\n".join(parts)
def _critic_fast_check(self, action: str, valid_actions_str: str) -> bool:
"""Fast validation: is action likely valid?"""
action_lower = action.lower().strip()
if valid_actions_str and "valid actions:" in valid_actions_str.lower():
va = valid_actions_str.lower()
if action_lower in va or any(action_lower.startswith(a.strip()) for a in va.split(",")[:20] if a.strip()):
return True
verb = action_lower.split()[0] if action_lower.split() else ""
if verb in ["look", "inventory", "north", "south", "east", "west", "take", "open", "examine"]:
return True
common = ["look", "inventory", "north", "south", "east", "west", "up", "down", "take", "drop", "open", "examine", "read", "get"]
if any(action_lower.startswith(c) for c in common):
return True
return True
def _parse_critic_score(self, resp: str) -> float:
"""Parse critic score from response; supports negative scores (-1 to 1)."""
m = re.search(r"SCORE:\s*([-]?[\d.]+)", resp, re.IGNORECASE)
if m:
try:
return max(-1.0, min(1.0, float(m.group(1))))
except ValueError:
pass
return 0.5
def _is_combat_situation(self, observation: str) -> bool:
"""Check if sword/weapon is glowing or combat is imminent (bias toward attack)."""
r = observation.lower()
if "sword" in r and ("glow" in r or "glowing" in r or "brightly" in r):
return True
if "attack" in r and ("troll" in r or "enemy" in r or "block" in r):
return True
if "brandishing" in r or "blocks all passages" in r:
return True
return False
def _combat_heuristic_action(self, observation: str) -> Optional[str]:
"""Return attack action when in combat situation (sword glowing)."""
r = observation.lower()
if "troll" in r:
return "attack troll with sword"
for enemy in ["enemy", "creature", "monster", "grue", "thief"]:
if enemy in r:
return f"attack {enemy} with sword"
return "attack with sword"
def _parse_max_score_from_context(self, context: dict) -> int:
"""Parse max_score from memory string (Score: X / Y points); return 350 if not found."""
mem = context.get("memory", "") or ""
m = re.search(r"Score:\s*\d+\s*/\s*(\d+)\s*points", mem, re.IGNORECASE)
if m:
return int(m.group(1))
return self.max_score
def _object_tree_validation(self, action: str, valid_actions_str: str) -> Tuple[bool, str]:
"""Check if action's object is in valid_actions; return (valid, msg)."""
if not valid_actions_str or "valid actions:" not in valid_actions_str.lower():
return True, ""
va = valid_actions_str.lower()
action_lower = action.lower().strip()
words = action_lower.split()
if len(words) < 2:
return True, ""
verb, obj = words[0], " ".join(words[1:])
if verb in ["look", "inventory", "north", "south", "east", "west", "up", "down", "i"]:
return True, ""
if obj in va or action_lower in va:
return True, ""
for part in va.split(","):
part = part.strip()
if part and obj in part:
return True, ""
return False, f"[Object Tree Validation] Object '{obj}' is not present"
def _should_override_critic_rejection(
self, action: str, score: float, context: dict, observation: str
) -> Tuple[bool, str]:
"""Override when critic uncertain but action is exploration-related."""
if score < self.CRITIC_OVERRIDE_MIN:
return False, ""
action_lower = action.lower().strip()
exploration_verbs = ["north", "south", "east", "west", "up", "down", "look", "examine", "take", "open", "enter", "go"]
verb = action_lower.split()[0] if action_lower.split() else ""
is_exploration = any(action_lower.startswith(v) for v in exploration_verbs)
if is_exploration:
return True, "low_critic_confidence"
if "unmapped" in str(context.get("map", "")).lower() or "exploring" in observation.lower()[:100]:
return True, "exploring_new_locations"
return False, ""
def _load_learned_rules(self) -> dict[str, Any]:
    """Lazily load learned_*.json from refs/learned/ (or env LEARNED_DIR).

    Results are memoized in the module-level _LEARNED_CACHE; missing or
    unreadable files yield empty entries.
    """
    global _LEARNED_CACHE
    if _LEARNED_CACHE:
        return _LEARNED_CACHE
    # BUGFIX: Path("") is truthy (it equals Path(".")), so the previous
    # `Path(os.getenv("LEARNED_DIR", "")) or <default>` never fell back to the
    # refs/learned default. Test the env string itself before building a Path.
    env_dir = os.getenv("LEARNED_DIR", "")
    learned_dir = Path(env_dir) if env_dir else Path(__file__).resolve().parent.parent / "refs" / "learned"
    if not learned_dir.exists():
        return _LEARNED_CACHE
    for name in ["learned_location_actions", "learned_score_sources"]:
        p = learned_dir / f"{name}.json"
        if p.exists():
            try:
                with open(p) as f:
                    _LEARNED_CACHE[name] = json.load(f)
            except Exception:
                # unreadable/corrupt file: treat as no learned rules
                _LEARNED_CACHE[name] = []
    return _LEARNED_CACHE
def _learned_heuristic(self, room: str, context: dict) -> Optional[str]:
    """Suggest action from learned (game, room, action) data if available.

    Filters learned rules by game, fuzzy room match, prior failures, recent
    repetition, and (when available) the valid-actions list; returns the first
    surviving action, or None.
    """
    game = context.get("game", getattr(self, "_current_game", ""))
    if not game:
        return None
    rules = self._load_learned_rules()
    va = (context.get("valid_actions") or "").lower()
    room_lower = (room or "").lower()[:80]
    # score sources come first: they are the higher-value rules
    for key in ["learned_score_sources", "learned_location_actions"]:
        items = rules.get(key, [])
        for item in items:
            if item.get("game") != game:
                continue
            # fuzzy room match: either string may contain the other
            ctx = (item.get("context") or "").lower()[:80]
            if ctx and room_lower and ctx not in room_lower and room_lower not in ctx:
                continue
            action = (item.get("action") or "").strip().lower()
            if not action or action in self.failed_actions:
                continue
            # we skip actions tried in the last 5 steps to avoid loops
            if action in [a.lower() for a in self.recent_actions[-5:]]:
                continue
            # when a valid-actions list exists, the action must appear somewhere in it
            if va and action not in va and not any(action in a for a in va.split(",")[:25]):
                continue
            return action
    return None
def _emergency_heuristic(self, observation: str, context: dict) -> Optional[str]:
    """When score stagnation critical, try known score sources and unmapped directions.

    Order: learned per-room rules, Zork-specific score patterns (painting, trap
    door, rug, troll), then any direction not recently tried. Returns None when
    nothing applies.
    """
    room = (observation or "").strip().split("\n")[0][:80]
    learned = self._learned_heuristic(room, context)
    if learned:
        return learned
    r = observation.lower()
    # Zork: drop ballast so the painting can be carried
    if "your load is too heavy" in r and "painting" in r:
        if "drop leaflet" in [a.lower() for a in self.recent_actions[-3:]]:
            return "drop all except lantern"
        return "drop leaflet"
    # Zork: the painting is a known score source
    if "painting" in r and "unparalleled" in r:
        if "take painting" not in self.failed_actions:
            return "take painting"
    if "trap door" in r and "revealing" in r:
        return "open trap door"
    if "dusty cover" in r and "trap door" in r:
        return "open trap door"
    # moving the rug reveals the trap door
    if "rug" in r and "center" in r and "move rug" not in [a.lower() for a in self.recent_actions[-3:]]:
        return "move rug"
    # fight the troll unless we just attacked
    if "troll" in r and "attack" not in [a.lower()[:6] for a in self.recent_actions[-2:]]:
        return "attack troll with sword"
    # the game told us there is no troll: stop suggesting anything here
    if "can't see any troll" in r:
        return None
    # fall back to any direction not recently tried and not known-failed
    for d in ["north", "south", "east", "west", "up", "down"]:
        if d not in [a.lower() for a in self.recent_actions[-5:]] and d not in self.failed_actions:
            return d
    return None
# =========================================================================
# Universal verb vocabulary (game-agnostic) per common_structure.md
# =========================================================================
# we cycle through these when no result-based pattern matches
UNIVERSAL_VERB_CYCLE = [
    # observation / meta
    "look", "examine", "inventory",
    # movement
    "north", "south", "east", "west", "up", "down", "in", "out",
    # item pickup (common objects across games)
    "take all", "take lamp", "take keys", "take wallet", "take phone", "take sword",
    # container / door interaction
    "open mailbox", "open door", "open", "open chest",
    # posture prerequisites (e.g. "you have to get up first")
    "get up", "stand", "rise", "wake",
    # light sources and misc interaction
    "light lamp", "turn on lamp", "wear", "use", "read",
]
def _result_based_heuristic(self, result_text: str) -> Optional[str]:
    """Game-agnostic heuristic from result text per common_structure.md.

    Matches the latest game output against an ORDERED chain of patterns
    (takeable objects, posture prerequisites, darkness, parser errors,
    containers, forced directions, and two lostpig-specific escapes).
    Each suggestion is suppressed when it already failed or was tried in
    the last few turns. Returns None when nothing matches, letting the
    caller fall back to the generic verb cycle.
    """
    r = result_text.lower()
    # we prioritize taking visible objects when room lists them (905, etc)
    # NOTE: "telephone" alone is enough to suggest "take phone"; the
    # recent-actions check only applies to the bare "phone" match.
    if "telephone" in r or ("phone" in r and "take phone" not in [a.lower() for a in self.recent_actions[-3:]]):
        if "take phone" not in self.failed_actions:
            return "take phone"
    if "wallet" in r and "take wallet" not in self.failed_actions and "take wallet" not in [a.lower() for a in self.recent_actions[-3:]]:
        return "take wallet"
    if "keys" in r and "take keys" not in self.failed_actions and "take keys" not in [a.lower() for a in self.recent_actions[-3:]]:
        return "take keys"
    # prerequisite: get out of bed, have to get up
    if "get out of bed" in r or "out of bed" in r or "have to get up" in r:
        return "get up"
    if "get up" in r and "have to" in r:
        return "get up"
    if "stand" in r and ("have to" in r or "must" in r):
        return "stand"
    # light: too dark, can't see — try lighting strategies not used recently
    if "too dark" in r or "can't see" in r or "too dark to" in r:
        for cmd in ["light lamp", "turn on lamp", "take lamp"]:
            if cmd not in [a.lower() for a in self.recent_actions[-3:]]:
                return cmd
        return "light lamp"
    # movement block: wall, can't go that way
    if "can't go" in r or "wall" in r or "can't go that way" in r or "too narrow" in r:
        return None # we let generic cycle pick next direction
    # parser rejection: don't understand, can't
    if "don't understand" in r or "i don't understand" in r:
        return "look"
    if "you can't" in r or "can't do that" in r:
        return "examine"
    # object: take X when objects mentioned (keys, wallet, lamp, etc)
    common_objects = ["telephone", "phone", "keys", "wallet", "lamp", "sword", "treasure", "book", "rope", "knife", "chest", "dresser"]
    for word in common_objects:
        if word in r:
            action_try = f"take {word}"
            if word == "telephone":
                action_try = "take phone"
            if action_try in self.failed_actions:
                continue
            recent_lower = [a.lower() for a in self.recent_actions[-5:]]
            if action_try not in recent_lower:
                return action_try
    if "dresser" in r:
        if "open dresser" not in [a.lower() for a in self.recent_actions[-3:]]:
            return "open dresser"
    # generic pass: nouns extracted from the room description itself
    for obj in self._extract_objects_from_room(result_text):
        action_try = f"take {obj}"
        if action_try in self.failed_actions:
            continue
        recent_lower = [a.lower() for a in self.recent_actions[-5:]]
        if action_try not in recent_lower:
            return action_try
    if "mailbox" in r:
        recent_lower = [a.lower() for a in self.recent_actions[-3:]]
        if "open mailbox" not in recent_lower:
            return "open mailbox"
    if "open" in r and "closed" in r:
        for word in ["door", "mailbox", "chest", "box"]:
            if word in r:
                return f"open {word}"
    if "open" in r and "door" in r:
        return "open door"
    # no such thing, I don't see
    if "don't see" in r or "no such" in r or "can't see any" in r:
        return "look"
    # only go X (extract direction)
    for d in ["north", "south", "east", "west"]:
        if f"only go {d}" in r or f"only {d}" in r or f"can only go {d}" in r:
            return d
    # lostpig / general: south fails with "trouble" -> try east (forest)
    if "get in big trouble" in r or "big trouble" in r:
        south_count = sum(1 for a in self.recent_actions[-5:] if a.lower() == "south")
        if south_count >= 2:
            return "east"
        return "north"
    # forest dark / pig somewhere: try forest first, then try west/south when stuck
    if "forest" in r and "dark" in r:
        east_count = sum(1 for a in self.recent_actions[-6:] if a.lower() == "east")
        north_count = sum(1 for a in self.recent_actions[-6:] if a.lower() == "north")
        if east_count + north_count >= 4:
            return "west"
        if east_count < 2:
            return "east"
        return "north"
    return None
def _extract_objects_from_room(self, text: str) -> list[str]:
"""Extract object names from room description for take/examine."""
r = text.lower()
objects = []
# patterns: "there is a X", "you see X", "X and Y", "on the X are Y", "X, Y and Z"
for m in re.finditer(r"\b(there is|you see|are|on the \w+ are)\s+[a ]+(\w+)", r):
objects.append(m.group(2))
for m in re.finditer(r"\b(telephone|phone|wallet|keys|lamp|sword|book|rope|knife|chest|mailbox)\b", r):
objects.append(m.group(1))
return list(dict.fromkeys(objects))[:5]
def _generic_verb_cycle(self, extra_skip: set[str] | None = None) -> str:
"""Return next action from universal cycle, skipping failed actions."""
skip = self.failed_actions | (extra_skip or set())
cycle = self.UNIVERSAL_VERB_CYCLE
start = 0
if self.recent_actions:
last = self.recent_actions[-1].lower()
idx = next((i for i, a in enumerate(cycle) if a == last), -1)
start = (idx + 1) % len(cycle)
for i in range(len(cycle)):
cand = cycle[(start + i) % len(cycle)]
if cand not in skip:
return cand
return "look"
def _heuristic_action(self, observation: str) -> str:
"""Heuristic when LLM empty: result-based first, then generic verb cycle."""
action = self._result_based_heuristic(observation)
if action is None:
action = self._generic_verb_cycle()
return f"THOUGHT: Try {action}.\nTOOL: play_action\nARGS: {{\"action\": \"{action}\"}}"
def _parse_multiple_actions(self, response: str) -> list[str]:
"""Extract multiple action candidates from response (ALTERNATIVE1, ALTERNATIVE2, or ARGS blocks)."""
actions: list[str] = []
for line in response.strip().split("\n"):
lc = line.strip()
for prefix in ["ALTERNATIVE1:", "ALTERNATIVE2:", "ALTERNATIVE3:", "ACTION1:", "ACTION2:"]:
if lc.upper().startswith(prefix.upper()):
rest = lc.split(":", 1)[1].strip().strip('"\'')
if rest and len(rest) < 80:
actions.append(rest)
break
m = re.findall(r'"action"\s*:\s*"([^"]+)"', response)
for a in m:
if a not in actions:
actions.append(a)
return actions[:5]
def _parse_response(self, response: str, valid_tools: list[str]) -> tuple[str, str, dict]:
"""Parse LLM response; fallback to extracting action from raw text."""
thought = "No reasoning provided"
tool_name = "play_action"
tool_args = {"action": "look"}
for line in response.strip().split("\n"):
lc = line.strip()
lu = lc.upper()
if lu.startswith("THOUGHT:"):
thought = lc.split(":", 1)[1].strip() or thought
elif lu.startswith("TOOL:"):
raw = lc.split(":", 1)[1].strip().lower().replace("**", "").replace("*", "")
raw = raw.split()[0] if raw else "play_action"
tool_name = raw
elif lu.startswith("ARGS:"):
s = lc.split(":", 1)[1].strip().replace("'", '"')
try:
tool_args = json.loads(s)
except json.JSONDecodeError:
m = re.search(r'"action"\s*:\s*"([^"]+)"', s)
if m:
tool_args = {"action": m.group(1)}
# we fallback: if still "look", try to extract action from raw response
if tool_args.get("action", "look") == "look" and response.strip():
r = response.lower()
for cmd in ["east", "north", "south", "west", "inventory", "take all", "take lamp"]:
if cmd in r:
tool_args = {"action": cmd}
break
return thought, tool_name, tool_args
def _validate_tool_call(self, tool_name: str, tool_args: dict, valid_tools: list[str]) -> tuple[str, dict]:
"""Validate and fix tool call."""
if tool_name not in valid_tools:
tool_name = "play_action"
if tool_name == "play_action":
action = tool_args.get("action", "look")
invalid = {"check": "examine", "inspect": "examine", "search": "look", "grab": "take", "pick": "take"}
words = action.lower().split()
if words and words[0] in invalid:
words[0] = invalid[words[0]]
action = " ".join(words)
action = action.lower().strip().replace("**", "").replace("*", "")
action = " ".join(action.split())
tool_args["action"] = action
return tool_name, tool_args
def _extract_result(self, result) -> str:
"""Extract text from MCP result."""
if hasattr(result, "content") and result.content:
return result.content[0].text
if isinstance(result, list) and result:
return result[0].text if hasattr(result[0], "text") else str(result[0])
return str(result)
def _update_score(self, text: str) -> None:
"""Update score from text."""
for pat in [r"Score:\s*(\d+)", r"\[Score:\s*(\d+)", r"Total:\s*(\d+)"]:
m = re.search(pat, text, re.IGNORECASE)
if m:
self.score = max(self.score, int(m.group(1)))
break
def _is_game_over(self, text: str) -> bool:
"""Check game over."""
t = text.lower()
return any(p in t for p in ["game over", "you have died", "you are dead", "*** you have died ***"])
def _is_failure_result(self, result: str, action: str) -> bool:
"""Check if result indicates action failed (rejection, no progress)."""
r = result.lower()
failure_phrases = [
"don't understand", "you can't", "can't do that", "can't go that way",
"there is no", "no such", "you'll have to", "have to get", "get out of bed first",
"verb error", "not recognized", "i don't see", "can't see any",
]
if any(p in r for p in failure_phrases):
return True
if "get in big trouble" in r or "grunk get in big trouble" in r:
return True
return False
async def test_agent():
    """Run a 10-step local smoke test of the agent against the MCP game server.

    Spawns mcp_server.py as a stdio subprocess (configured for the
    "lostpig" game via the GAME env var) and prints the final score and
    move count.
    """
    # local imports: fastmcp is only needed for this manual smoke test
    from fastmcp import Client
    from fastmcp.client.transports import StdioTransport
    import sys
    from pathlib import Path
    server_path = Path(__file__).parent / "mcp_server.py"
    env = os.environ.copy()
    # the server reads GAME from its environment to pick the story file
    env["GAME"] = "lostpig"
    transport = StdioTransport(command=sys.executable, args=[str(server_path)], env=env)
    agent = StudentAgent()
    async with Client(transport) as client:
        result = await agent.run(client=client, game="lostpig", max_steps=10, seed=42, verbose=True)
    print(f"\nFinal: score={result.final_score}, moves={result.moves}")
if __name__ == "__main__":
    # asyncio is already imported at module top; the redundant local
    # import was removed. Runs the manual smoke test.
    asyncio.run(test_agent())