# NOTE: stray capture residue removed here ("Spaces:" / "Runtime error" lines were
# build-log output, not part of this module).
| """ | |
| ZorkGPT-Lite: Full orchestrator with Agent, Critic, Extractor, StrategyGen. | |
| Uses Z-machine data (memory, inventory, get_valid_actions) + LLM for reasoning. | |
| """ | |
| import asyncio | |
| import json | |
| import os | |
| import sys | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Tuple, Any | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
| # we cache learned rules from refs/learned/ for _learned_heuristic | |
| _LEARNED_CACHE: dict[str, Any] = {} | |
| load_dotenv() | |
| try: | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| _LOCAL_INFERENCE_AVAILABLE = True | |
| except ImportError: | |
| _LOCAL_INFERENCE_AVAILABLE = False | |
| from huggingface_hub import InferenceClient | |
# =============================================================================
# LLM Configuration
# =============================================================================
# Remote model used through the Hugging Face Inference API.
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
# USE_LOCAL_MODEL=true selects in-process transformers inference (when installed).
_USE_LOCAL = os.getenv("USE_LOCAL_MODEL", "false").lower() in ("true", "1", "yes")
# Smaller default checkpoint for local inference; override via HF_MODEL.
_HF_MODEL_LOCAL = os.getenv("HF_MODEL", "Qwen/Qwen2.5-7B-Instruct")
_hf_token = os.getenv("HF_TOKEN")
# Fall back to the hosted API whenever local inference is disabled or unavailable;
# in that case a token is mandatory.
if not _USE_LOCAL or not _LOCAL_INFERENCE_AVAILABLE:
    if not _hf_token:
        raise ValueError("HF_TOKEN not found. Set it in your .env file (or use USE_LOCAL_MODEL=true with transformers).")
    LLM_CLIENT: Optional[InferenceClient] = InferenceClient(token=_hf_token)
else:
    LLM_CLIENT = None
# Populated lazily by _ensure_local_model() when local inference is active.
_local_tokenizer = None
_local_model = None
def _ensure_local_model() -> None:
    """Load the local tokenizer/model exactly once.

    No-op when the model is already loaded, or when local inference is
    disabled (USE_LOCAL_MODEL unset) or unavailable (transformers missing).
    """
    global _local_tokenizer, _local_model
    if _local_model is not None or not (_LOCAL_INFERENCE_AVAILABLE and _USE_LOCAL):
        return
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    auth_kwargs = {"token": _hf_token} if _hf_token else {}
    if not _hf_token:
        print("[INFO] No HF_TOKEN; gated models may fail. Set HF_TOKEN in .env for e.g. Gemma.")
    _local_tokenizer = AutoTokenizer.from_pretrained(_HF_MODEL_LOCAL, **auth_kwargs)
    # fp16 + automatic device placement on GPU; fp32 on CPU.
    _local_model = AutoModelForCausalLM.from_pretrained(
        _HF_MODEL_LOCAL,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        **auth_kwargs,
    )
    if not use_cuda:
        _local_model = _local_model.to(device)
    print(f"[INFO] Local model loaded: {_HF_MODEL_LOCAL} on {device}")
def call_llm(
    prompt: str,
    system_prompt: str,
    seed: int,
    max_tokens: int = 400,
    simple_output: bool = False,
    temperature: float = 0.0,
) -> str:
    """Call the LLM (hosted API or local transformers model).

    Args:
        prompt: User-turn content.
        system_prompt: System-turn content.
        seed: Sampling seed (API path only; ignored locally).
        max_tokens: Generation budget.
        simple_output: True skips the "THOUGHT:" priming used by the agent
            response format (local path only).
        temperature: >0 enables sampling; 0 is greedy/deterministic.

    Returns:
        The raw completion text. Always a str: the API's ``message.content``
        may be None, which is normalized to "" so callers can safely .strip().
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    if _USE_LOCAL and _LOCAL_INFERENCE_AVAILABLE:
        _ensure_local_model()
        if _local_tokenizer is None or _local_model is None:
            raise RuntimeError("Local model failed to load.")
        # Prefer the model's own chat template; fall back to a generic layout.
        if hasattr(_local_tokenizer, "apply_chat_template"):
            formatted = _local_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            formatted = f"{system_prompt}\n\nUser: {prompt}\n\nAssistant:"
        if not simple_output:
            # Prime small models to emit the required THOUGHT/TOOL/ARGS format.
            formatted = formatted.rstrip() + "\nTHOUGHT:"
        inputs = _local_tokenizer(formatted, return_tensors="pt")
        model_device = next(_local_model.parameters()).device
        inputs = {k: (v.to(model_device) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}
        do_sample = temperature > 0
        gen_kwargs = dict(
            max_new_tokens=max_tokens,
            pad_token_id=_local_tokenizer.eos_token_id,
            do_sample=do_sample,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 0.95
        with torch.no_grad():
            gen_out = _local_model.generate(**inputs, **gen_kwargs)
        # Decode only the newly generated tokens (strip the prompt prefix).
        out_slice = gen_out[0][inputs["input_ids"].shape[1]:]
        if out_slice.is_cuda:
            out_slice = out_slice.cpu()
        raw = _local_tokenizer.decode(out_slice, skip_special_tokens=True).strip()
        # Re-attach the primed "THOUGHT:" prefix so the parser sees the full format.
        if not simple_output and formatted.rstrip().endswith("THOUGHT:") and raw and not raw.upper().startswith("THOUGHT:"):
            raw = "THOUGHT: " + raw
        return raw
    response = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        seed=seed,
    )
    # BUG FIX: message.content is Optional; callers immediately call .strip()
    # on the return value, so normalize None to "".
    return response.choices[0].message.content or ""
# BUG FIX: this class uses annotated fields with field(default_factory=list) and is
# instantiated with keyword arguments (see run()), which requires the dataclass-
# generated __init__; without @dataclass, RunResult(...) raises TypeError and
# `field(...)` is just a plain class attribute. The decorator restores the clearly
# intended behavior while leaving the declared fields untouched.
@dataclass
class RunResult:
    """Result of running the agent. Do not modify this class."""
    final_score: int
    max_score: int
    moves: int
    locations_visited: set[str]
    game_completed: bool
    error: Optional[str] = None
    history: list[tuple[str, str, str]] = field(default_factory=list)
# =============================================================================
# Prompts: Agent, Critic, StrategyGen
# =============================================================================
# System prompt for the Agent LLM. Emphasizes loop avoidance, exploration, and
# the strict THOUGHT/TOOL/ARGS response format the parser expects.
AGENT_PROMPT = """You are an expert text adventure player. MAXIMIZE YOUR SCORE and EXPLORE NEW LOCATIONS.
Evaluation (lexicographic): (score, locations). Both matter. Passing: score >= 2 and >= 40 locations in 100 steps. Strong: score 3-4 and many locations.
CONTEXT/MEMORY MATTERS MOST – The main differentiator is avoiding loops and deadlocks:
- NEVER repeat an action that failed at the current location. If "Recently failed", "Avoid", or "DO NOT repeat at current location" is shown, do NOT propose those actions.
- Track visited locations: prioritize exploring NEW rooms. More locations = stronger result.
- Loop detector: if the same or very similar observation appears again within a few steps, you are in a loop – try a DIFFERENT direction or action immediately.
- Dead ends: if an action led to no progress (no score, no movement), do not retry it at this location.
AVAILABLE MCP TOOLS:
- play_action: Execute game commands (north, take lamp, open mailbox, get up, etc.)
- memory: Get current state from Z-machine
- inventory: Get items from Z-machine
- get_map: Explored locations
- get_context_for_agent: Call this! Returns "Locations visited: N" and "DO NOT repeat at current location: [list]". Use it to avoid loops.
When the prompt shows:
- "Locations visited: N" – aim to increase N; explore unmapped directions.
- "Avoid (recently failed): [list]" or "NEVER repeat at current location: [list]" – do NOT propose any action from that list.
- Same room/observation as a recent step – you are looping; try a different direction.
CRITICAL: You MUST respond in this exact format (no markdown, no extra text):
THOUGHT: <one sentence about what to do next>
TOOL: play_action
ARGS: {"action": "<command>"}
Universal rules (apply to any text adventure):
- If game says "get out of bed first" or "have to get up": try get up, stand
- If "too dark" or "can't see": light lamp, take lamp
- If "can't go that way": try different direction
- If "don't understand": try simpler verb (look, examine, take X)
- Explore directions (north, south, east, west). Take items. Do NOT repeat same action in a loop."""
# Appended to the agent prompt when _is_combat_situation() detects combat.
AGENT_COMBAT_PROMPT_APPEND = """
COMBAT PRIORITY: If sword/lamp is GLOWING or enemy is attacking, use ONLY combat actions:
attack <enemy> with sword, kill <enemy>, fight <enemy>. Do NOT explore or take items until combat ends."""
# Appended when score stagnation approaches the emergency threshold.
AGENT_EMERGENCY_PROMPT_APPEND = """
EMERGENCY: Score stagnation - prioritize known score sources (trophy case, window entry, paintings).
Try unmapped directions. Drop heavy items to take valuable ones. Avoid wasting turns on non-existent enemies."""
# Re-prompt format used when the critic rejects an action and we want alternatives.
AGENT_MULTI_CANDIDATES_PROMPT = """Propose 2 different actions as ALTERNATIVE1 and ALTERNATIVE2.
Format:
THOUGHT: <brief reasoning>
ALTERNATIVE1: <first action>
ALTERNATIVE2: <second action>
Use TOOL/ARGS for the first one only."""
# System prompt for the Critic LLM; output parsed by _parse_critic_score().
CRITIC_PROMPT = """You evaluate whether a proposed game action is good.
Given: current observation, valid actions from Z-machine, proposed action.
Score -1.0 to 1.0: -1=bad (invalid, repeated, no progress, wrong object), 1=good (valid, progresses).
Use negative scores for actions that repeat failed commands or reference non-existent objects.
Respond in one line: SCORE: <-1.0 to 1.0> REASON: <brief reason>
If action is in valid_actions or is a common command (look, north, take X), score >= 0.6.
If object in action is NOT in valid_actions, score <= -0.5."""
# System prompt for StrategyGen; insights are appended to the knowledge base.
STRATEGY_PROMPT = """Analyze this gameplay history and extract 3-5 strategic insights.
Format each as a short rule. Example: "In dark games, get lamp before exploring."
Output only the insights, one per line."""
| # ============================================================================= | |
| # StudentAgent: Full ZorkGPT-Lite Orchestrator | |
| # ============================================================================= | |
class StudentAgent:
    """
    Full orchestrator: Extractor (Z-machine) -> Agent -> Critic (Z-machine + LLM) -> Execute.
    StrategyGen updates knowledge_base every 12 turns.
    Override policy, score stagnation emergency, combat heuristic, object-tree validation.
    """
    # Tunables shared by all instances.
    CRITIC_THRESHOLD = 0.5  # minimum critic score to accept an action outright
    CRITIC_OVERRIDE_MIN = -0.5  # we override rejection if score >= this and action is exploration
    MAX_CRITIC_RETRIES = 3  # LLM critic attempts (including alternative proposals)
    STRATEGY_UPDATE_INTERVAL = 12  # turns between StrategyGen knowledge-base updates
    VALID_ACTIONS_TIMEOUT = 0.8  # seconds to wait for get_valid_actions before skipping it
    STAGNATION_DEATH_THRESHOLD = 35  # Zork-like: ~30-40 turns without score -> death
    EMERGENCY_TURNS_LEFT = 10  # enter emergency when turns_until_death <= this
    def __init__(self):
        """Initialize per-run mutable state (most of it is reset again in run())."""
        self.history: list[dict] = []  # rolling window of recent step records (max 20)
        self.recent_actions: list[str] = []  # last ~10 actions, for loop detection
        self.failed_actions: set[str] = set()  # we avoid repeating actions that failed
        self.score: int = 0
        self.max_score: int = 350  # default; refined by _parse_max_score_from_context
        self.steps_without_score: int = 0  # stagnation counter driving emergency mode
        self.knowledge_base: str = "General: Explore, take items, use lamp before dark. Try get up if stuck. Try east/north when south fails."
        self.seen_state_hashes: dict[str, int] = {}  # state_hash -> last step seen (for loop detection)
        self.unmapped_directions: set[str] = set()  # untried exits (not populated in the code visible here)
        self.failed_actions_at_state: dict[str, set[str]] = {}  # state_hash -> failed actions (never repeat)
| async def run( | |
| self, | |
| client, | |
| game: str, | |
| max_steps: int, | |
| seed: int, | |
| verbose: bool = False, | |
| trace_path: Optional[str] = None, | |
| ) -> RunResult: | |
| """Run the full orchestrator loop. If trace_path is set, write step-level JSON for analysis.""" | |
| self._current_game = game | |
| locations_visited = set() | |
| history = [] | |
| moves = 0 | |
| trace_steps: list[dict] = [] | |
| tool_names = [t.name for t in await client.list_tools()] | |
| self.failed_actions = set() | |
| self.steps_without_score = 0 | |
| self.seen_state_hashes = {} | |
| self.failed_actions_at_state = {} | |
| # we get initial observation | |
| result = await client.call_tool("play_action", {"action": "look"}) | |
| observation = self._extract_result(result) | |
| loc = observation.split("\n")[0] if observation else "Unknown" | |
| locations_visited.add(loc) | |
| if verbose: | |
| print(f"\n{observation}") | |
| context = {} | |
| for step in range(1, max_steps + 1): | |
| # we print progress to stderr so batch runs show activity (every step) | |
| print(f" step {step}/{max_steps} score={self.score}", file=sys.stderr, flush=True) | |
| # we extract context from Z-machine (no LLM) | |
| context = await self._extract_context(client) | |
| context["game"] = game | |
| state_hash = context.get("state_hash", "") | |
| context["state_hash"] = state_hash | |
| self.seen_state_hashes[state_hash[:64]] = step | |
| # we build agent prompt (with combat/emergency appendices) | |
| prompt = self._build_agent_prompt(observation, context) | |
| thought, tool_name, tool_args = "No reasoning", "play_action", {"action": "look"} | |
| action = "look" | |
| # we get action from Agent LM (max_tokens 250 for small models) | |
| print(f" [LLM] step {step}/{max_steps} action...", file=sys.stderr, flush=True) | |
| response = call_llm(prompt, AGENT_PROMPT, seed + step, max_tokens=250) | |
| if not response.strip(): | |
| response = self._heuristic_action(observation) | |
| if verbose: | |
| print(f"[DEBUG] LLM empty, heuristic: {response[:80]}") | |
| thought, tool_name, tool_args = self._parse_response(response, tool_names) | |
| tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names) | |
| if tool_name == "play_action": | |
| action = tool_args.get("action", "look") | |
| print(f" [LLM] -> {action[:40]}", file=sys.stderr, flush=True) | |
| # we apply combat heuristic: if sword glowing, bias toward attack | |
| if self._is_combat_situation(observation) and not any( | |
| a.lower().startswith(("attack", "kill", "fight")) for a in self.recent_actions[-2:] | |
| ): | |
| combat_action = self._combat_heuristic_action(observation) | |
| if combat_action: | |
| action = combat_action | |
| tool_args = {"action": action} | |
| if verbose: | |
| print(f"[COMBAT] Biased toward attack: {action}") | |
| # we run Critic: object-tree check first, then fast check, then LLM with override policy | |
| accepted = False | |
| override_reason: Optional[str] = None | |
| obj_valid, obj_msg = self._object_tree_validation(action, context.get("valid_actions", "")) | |
| if not obj_valid and obj_msg: | |
| if verbose: | |
| print(f"[CRITIC] Object tree: {obj_msg}") | |
| fast_ok = self._critic_fast_check(action, context.get("valid_actions", "")) | |
| if fast_ok and obj_valid: | |
| accepted = True | |
| critic_score = 0.5 | |
| for attempt in range(self.MAX_CRITIC_RETRIES): | |
| if accepted: | |
| break | |
| critic_prompt = f"""Observation: {observation[:300]} | |
| Valid actions: {context.get('valid_actions', 'unknown')} | |
| Proposed: {action} | |
| Score and reason?""" | |
| print(f" [LLM] critic attempt {attempt+1}...", file=sys.stderr, flush=True) | |
| critic_resp = call_llm(critic_prompt, CRITIC_PROMPT, seed + step + attempt, max_tokens=80) | |
| critic_score = self._parse_critic_score(critic_resp) | |
| if critic_score >= self.CRITIC_THRESHOLD: | |
| accepted = True | |
| break | |
| # we override policy: when critic uncertain but action is exploration-related, override | |
| should_override, override_reason = self._should_override_critic_rejection( | |
| action, critic_score, context, observation | |
| ) | |
| if should_override: | |
| accepted = True | |
| if verbose: | |
| print(f"[OVERRIDE] {override_reason} (score {critic_score:.2f})") | |
| break | |
| if attempt < self.MAX_CRITIC_RETRIES - 1: | |
| # we get multiple candidates when rejected (alternative actions) | |
| feedback = f"Action '{action}' rejected (score {critic_score:.1f}). Propose 2 different actions." | |
| prompt = self._build_agent_prompt(observation, context, feedback) | |
| response = call_llm(prompt, AGENT_PROMPT, seed + step + attempt) | |
| candidates = self._parse_multiple_actions(response) | |
| if len(candidates) >= 2: | |
| best_action, best_score = action, critic_score | |
| for cand in candidates[:3]: | |
| if cand == action: | |
| continue | |
| cp = f"Observation: {observation[:250]}\nValid: {context.get('valid_actions','')[:150]}\nProposed: {cand}\nScore?" | |
| cr = call_llm(cp, CRITIC_PROMPT, seed + step + attempt + 1, max_tokens=60) | |
| cs = self._parse_critic_score(cr) | |
| if cs > best_score: | |
| best_action, best_score = cand, cs | |
| action = best_action | |
| tool_args = {"action": action} | |
| else: | |
| thought, tool_name, tool_args = self._parse_response(response, tool_names) | |
| tool_name, tool_args = self._validate_tool_call(tool_name, tool_args, tool_names) | |
| if tool_name == "play_action": | |
| action = tool_args.get("action", "look") | |
| if self._is_combat_situation(observation): | |
| combat_action = self._combat_heuristic_action(observation) | |
| if combat_action: | |
| action = combat_action | |
| tool_args = {"action": action} | |
| else: | |
| accepted = True | |
| # we loop detection: try result-based heuristic first, then generic verb cycle | |
| if len(self.recent_actions) >= 3 and len(set(self.recent_actions[-3:])) == 1: | |
| res = self._result_based_heuristic(observation) | |
| if res is not None: | |
| action = res | |
| else: | |
| action = self._generic_verb_cycle() | |
| tool_args = {"action": action} | |
| if verbose: | |
| print(f"[WARNING] Loop detected - trying '{action}' instead") | |
| # we skip actions that failed at current location (never repeat) | |
| key = state_hash[:64] if state_hash else "" | |
| failed_at_loc = self.failed_actions_at_state.get(key, set()) | |
| if action.lower() in failed_at_loc or action.lower() in self.failed_actions: | |
| action = self._generic_verb_cycle(extra_skip=failed_at_loc) | |
| tool_args = {"action": action} | |
| if verbose: | |
| print(f"[CONTEXT] Skipping failed action, trying {action}") | |
| # we loop detection via state hash: if we've seen this state before, force different action | |
| if key and key in self.seen_state_hashes and self.seen_state_hashes[key] < step - 2: | |
| if action.lower() in [a.lower() for a in self.recent_actions[-4:]]: | |
| action = self._generic_verb_cycle(extra_skip=failed_at_loc) | |
| tool_args = {"action": action} | |
| if verbose: | |
| print(f"[LOOP] State revisited, trying {action}") | |
| # we prefer valid_actions when stuck (no score for many steps) | |
| if self.steps_without_score >= 5 and context.get("valid_actions"): | |
| va = context["valid_actions"].lower() | |
| for cand in ["take all", "take lamp", "take keys", "open", "examine", "north", "east"]: | |
| if cand in va and cand not in self.failed_actions: | |
| if cand not in [a.lower() for a in self.recent_actions[-3:]]: | |
| action = cand | |
| tool_args = {"action": action} | |
| break | |
| # we emergency mode: when turns_until_death is low, prioritize score sources and unmapped exits | |
| # skip stagnation logic for scoreless games (max_score 0) | |
| game_max = self._parse_max_score_from_context(context) | |
| turns_until_death = self.STAGNATION_DEATH_THRESHOLD - self.steps_without_score | |
| if game_max == 0: | |
| turns_until_death = 999 # disable emergency for scoreless games | |
| if turns_until_death <= self.EMERGENCY_TURNS_LEFT and turns_until_death > 0: | |
| em_action = self._emergency_heuristic(observation, context) | |
| if em_action: | |
| action = em_action | |
| tool_args = {"action": action} | |
| if verbose: | |
| print(f"[EMERGENCY] {turns_until_death} turns left - trying {action}") | |
| self.recent_actions.append(action) | |
| if len(self.recent_actions) > 10: | |
| self.recent_actions = self.recent_actions[-10:] | |
| # we track failed actions (rejection, no movement, no score) | |
| if self._is_failure_result(observation, action): | |
| self.failed_actions.add(action.lower()) | |
| if state_hash: | |
| key = state_hash[:64] | |
| if key not in self.failed_actions_at_state: | |
| self.failed_actions_at_state[key] = set() | |
| self.failed_actions_at_state[key].add(action.lower()) | |
| try: | |
| await client.call_tool("record_failed_action", {"state_hash": state_hash, "action": action}) | |
| except Exception: | |
| pass | |
| else: | |
| self.failed_actions.discard(action.lower()) | |
| # we track score progress and reinforce what worked | |
| old_score = self.score | |
| self._update_score(observation) | |
| if self.score > old_score: | |
| self.steps_without_score = 0 | |
| if len(self.knowledge_base) < 800: | |
| self.knowledge_base = self.knowledge_base + f"\nScore: {action} worked." | |
| else: | |
| self.steps_without_score += 1 | |
| if verbose: | |
| print(f"\n--- Step {step} ---") | |
| print(f"[THOUGHT] {thought}") | |
| print(f"[TOOL] {tool_name}({tool_args})") | |
| score_before = self.score | |
| obs_pre = observation[:400] # we save before overwriting | |
| print(f" [exec] step {step} -> {action[:30]}...", file=sys.stderr, flush=True) | |
| try: | |
| result = await client.call_tool(tool_name, tool_args) | |
| observation = self._extract_result(result) | |
| moves += 1 | |
| except Exception as e: | |
| observation = f"Error: {e}" | |
| if verbose: | |
| print(f"[ERROR] {e}") | |
| print(f" [done] step {step} score={self.score}", file=sys.stderr, flush=True) | |
| loc = observation.split("\n")[0] if observation else "Unknown" | |
| locations_visited.add(loc) | |
| self._update_score(observation) | |
| score_after = self.score | |
| if trace_path: | |
| trace_steps.append({ | |
| "step": step, | |
| "observation_pre": obs_pre, | |
| "action": action, | |
| "observation_post": observation[:400], | |
| "score_before": score_before, | |
| "score_after": score_after, | |
| "reward": score_after - score_before, | |
| "location": loc, | |
| "valid_actions": (context.get("valid_actions") or "")[:200], | |
| "critic_score": critic_score, | |
| "override": override_reason or False, | |
| }) | |
| history.append((thought, f"{tool_name}({tool_args})", observation[:100])) | |
| self.history.append({"step": step, "thought": thought, "action": action, "result": observation[:200]}) | |
| if len(self.history) > 20: | |
| self.history = self.history[-20:] | |
| if verbose: | |
| print(f"[RESULT] {observation[:200]}...") | |
| # we update knowledge_base every N turns (StrategyGen) | |
| if step % self.STRATEGY_UPDATE_INTERVAL == 0 and self.history: | |
| strategy_hist = "\n".join([f"Step {h['step']}: {h['action']} -> {h['result'][:80]}" for h in self.history[-15:]]) | |
| strat_prompt = f"History:\n{strategy_hist}\n\nCurrent score: {self.score}\nExtract insights:" | |
| try: | |
| insights = call_llm(strat_prompt, STRATEGY_PROMPT, seed + step, max_tokens=150) | |
| if insights.strip(): | |
| self.knowledge_base = self.knowledge_base + "\n" + insights.strip()[:300] | |
| except Exception: | |
| pass | |
| if self._is_game_over(observation): | |
| if verbose: | |
| print("\n*** GAME OVER ***") | |
| break | |
| if trace_path: | |
| trace_obj = { | |
| "game": game, | |
| "seed": seed, | |
| "max_steps": max_steps, | |
| "final_score": self.score, | |
| "final_moves": moves, | |
| "game_over": self._is_game_over(observation), | |
| "steps": trace_steps, | |
| } | |
| Path(trace_path).parent.mkdir(parents=True, exist_ok=True) | |
| with open(trace_path, "w") as f: | |
| json.dump(trace_obj, f, indent=None) | |
| return RunResult( | |
| final_score=self.score, | |
| max_score=self.max_score, | |
| moves=moves, | |
| locations_visited=locations_visited, | |
| game_completed=self._is_game_over(observation), | |
| history=history, | |
| ) | |
| async def _extract_context(self, client) -> dict: | |
| """Extractor: Z-machine data via MCP tools (no LLM).""" | |
| ctx = {} | |
| tools_to_try = [ | |
| ("memory", "memory"), | |
| ("inventory", "inventory"), | |
| ("get_map", "map"), | |
| ("get_state_hash", "state_hash_raw"), | |
| ("get_context_for_agent", "context_summary"), | |
| ] | |
| for tool_name, key in tools_to_try: | |
| try: | |
| r = await client.call_tool(tool_name, {}) | |
| ctx[key] = self._extract_result(r) | |
| except Exception: | |
| ctx[key] = "" | |
| # we enable get_valid_actions by default for critic; set USE_VALID_ACTIONS=false to disable (spacy can block) | |
| if os.getenv("USE_VALID_ACTIONS", "true").lower() not in ("false", "0", "no"): | |
| try: | |
| r = await asyncio.wait_for( | |
| client.call_tool("get_valid_actions", {}), | |
| timeout=self.VALID_ACTIONS_TIMEOUT, | |
| ) | |
| ctx["valid_actions"] = self._extract_result(r) | |
| except (asyncio.TimeoutError, Exception): | |
| ctx["valid_actions"] = "" | |
| else: | |
| ctx["valid_actions"] = "" | |
| # we parse raw state hash for loop detection (strip "State hash: " prefix) | |
| raw = ctx.get("state_hash_raw", "") | |
| if raw and ":" in raw: | |
| ctx["state_hash"] = raw.split(":", 1)[1].strip().split("...")[0].strip() | |
| else: | |
| ctx["state_hash"] = raw[:64] if raw else "" | |
| return ctx | |
| def _build_agent_prompt(self, observation: str, context: dict, feedback: str = "") -> str: | |
| """Build agent prompt with context, combat and emergency appendices.""" | |
| parts = [f"Knowledge base:\n{self.knowledge_base[:500]}\n"] | |
| parts.append(f"Current score: {self.score}") | |
| game_max = self._parse_max_score_from_context(context) | |
| turns_until_death = self.STAGNATION_DEATH_THRESHOLD - self.steps_without_score | |
| if game_max == 0: | |
| turns_until_death = 999 | |
| if turns_until_death <= self.EMERGENCY_TURNS_LEFT and turns_until_death > 0: | |
| parts.append(f"\n[EMERGENCY] {turns_until_death} turns until score-stagnation death! Prioritize known score sources.") | |
| if self._is_combat_situation(observation): | |
| parts.append(AGENT_COMBAT_PROMPT_APPEND) | |
| if context.get("valid_actions"): | |
| parts.append(f"\nValid actions (prefer these): {context['valid_actions'][:200]}") | |
| # we add learned hints for current room when available | |
| room = (observation or "").strip().split("\n")[0][:80] | |
| learned_hint = self._learned_heuristic(room, context) | |
| if learned_hint: | |
| parts.append(f"\n[Learned] In this room try: {learned_hint}") | |
| if context.get("memory"): | |
| parts.append(f"\nZ-machine state:\n{context['memory'][:350]}") | |
| if context.get("map"): | |
| parts.append(f"\nMap:\n{context['map'][:250]}") | |
| if context.get("inventory"): | |
| parts.append(f"\n{context['inventory']}") | |
| if context.get("context_summary"): | |
| parts.append(f"\n{context['context_summary']}") | |
| if self.failed_actions: | |
| parts.append(f"\nAvoid (recently failed): {', '.join(list(self.failed_actions)[:8])}") | |
| key = (context.get("state_hash") or "")[:64] | |
| failed_at_loc = self.failed_actions_at_state.get(key, set()) | |
| if failed_at_loc: | |
| parts.append(f"\nNEVER repeat at current location: {', '.join(list(failed_at_loc)[:8])}") | |
| if self.history: | |
| parts.append("\nRecent:") | |
| for h in self.history[-4:]: | |
| parts.append(f" > {h.get('action','?')} -> {h.get('result','')[:55]}...") | |
| if feedback: | |
| parts.append(f"\n[FEEDBACK] {feedback}") | |
| parts.append(f"\nCurrent observation:\n{observation}") | |
| parts.append("\nWhat do you do next?") | |
| return "\n".join(parts) | |
| def _critic_fast_check(self, action: str, valid_actions_str: str) -> bool: | |
| """Fast validation: is action likely valid?""" | |
| action_lower = action.lower().strip() | |
| if valid_actions_str and "valid actions:" in valid_actions_str.lower(): | |
| va = valid_actions_str.lower() | |
| if action_lower in va or any(action_lower.startswith(a.strip()) for a in va.split(",")[:20] if a.strip()): | |
| return True | |
| verb = action_lower.split()[0] if action_lower.split() else "" | |
| if verb in ["look", "inventory", "north", "south", "east", "west", "take", "open", "examine"]: | |
| return True | |
| common = ["look", "inventory", "north", "south", "east", "west", "up", "down", "take", "drop", "open", "examine", "read", "get"] | |
| if any(action_lower.startswith(c) for c in common): | |
| return True | |
| return True | |
| def _parse_critic_score(self, resp: str) -> float: | |
| """Parse critic score from response; supports negative scores (-1 to 1).""" | |
| m = re.search(r"SCORE:\s*([-]?[\d.]+)", resp, re.IGNORECASE) | |
| if m: | |
| try: | |
| return max(-1.0, min(1.0, float(m.group(1)))) | |
| except ValueError: | |
| pass | |
| return 0.5 | |
| def _is_combat_situation(self, observation: str) -> bool: | |
| """Check if sword/weapon is glowing or combat is imminent (bias toward attack).""" | |
| r = observation.lower() | |
| if "sword" in r and ("glow" in r or "glowing" in r or "brightly" in r): | |
| return True | |
| if "attack" in r and ("troll" in r or "enemy" in r or "block" in r): | |
| return True | |
| if "brandishing" in r or "blocks all passages" in r: | |
| return True | |
| return False | |
| def _combat_heuristic_action(self, observation: str) -> Optional[str]: | |
| """Return attack action when in combat situation (sword glowing).""" | |
| r = observation.lower() | |
| if "troll" in r: | |
| return "attack troll with sword" | |
| for enemy in ["enemy", "creature", "monster", "grue", "thief"]: | |
| if enemy in r: | |
| return f"attack {enemy} with sword" | |
| return "attack with sword" | |
| def _parse_max_score_from_context(self, context: dict) -> int: | |
| """Parse max_score from memory string (Score: X / Y points); return 350 if not found.""" | |
| mem = context.get("memory", "") or "" | |
| m = re.search(r"Score:\s*\d+\s*/\s*(\d+)\s*points", mem, re.IGNORECASE) | |
| if m: | |
| return int(m.group(1)) | |
| return self.max_score | |
| def _object_tree_validation(self, action: str, valid_actions_str: str) -> Tuple[bool, str]: | |
| """Check if action's object is in valid_actions; return (valid, msg).""" | |
| if not valid_actions_str or "valid actions:" not in valid_actions_str.lower(): | |
| return True, "" | |
| va = valid_actions_str.lower() | |
| action_lower = action.lower().strip() | |
| words = action_lower.split() | |
| if len(words) < 2: | |
| return True, "" | |
| verb, obj = words[0], " ".join(words[1:]) | |
| if verb in ["look", "inventory", "north", "south", "east", "west", "up", "down", "i"]: | |
| return True, "" | |
| if obj in va or action_lower in va: | |
| return True, "" | |
| for part in va.split(","): | |
| part = part.strip() | |
| if part and obj in part: | |
| return True, "" | |
| return False, f"[Object Tree Validation] Object '{obj}' is not present" | |
| def _should_override_critic_rejection( | |
| self, action: str, score: float, context: dict, observation: str | |
| ) -> Tuple[bool, str]: | |
| """Override when critic uncertain but action is exploration-related.""" | |
| if score < self.CRITIC_OVERRIDE_MIN: | |
| return False, "" | |
| action_lower = action.lower().strip() | |
| exploration_verbs = ["north", "south", "east", "west", "up", "down", "look", "examine", "take", "open", "enter", "go"] | |
| verb = action_lower.split()[0] if action_lower.split() else "" | |
| is_exploration = any(action_lower.startswith(v) for v in exploration_verbs) | |
| if is_exploration: | |
| return True, "low_critic_confidence" | |
| if "unmapped" in str(context.get("map", "")).lower() or "exploring" in observation.lower()[:100]: | |
| return True, "exploring_new_locations" | |
| return False, "" | |
| def _load_learned_rules(self) -> dict[str, Any]: | |
| """Lazily load learned_*.json from refs/learned/ (or env LEARNED_DIR).""" | |
| global _LEARNED_CACHE | |
| if _LEARNED_CACHE: | |
| return _LEARNED_CACHE | |
| learned_dir = Path(os.getenv("LEARNED_DIR", "")) or (Path(__file__).resolve().parent.parent / "refs" / "learned") | |
| if not learned_dir.exists(): | |
| return _LEARNED_CACHE | |
| for name in ["learned_location_actions", "learned_score_sources"]: | |
| p = learned_dir / f"{name}.json" | |
| if p.exists(): | |
| try: | |
| with open(p) as f: | |
| _LEARNED_CACHE[name] = json.load(f) | |
| except Exception: | |
| _LEARNED_CACHE[name] = [] | |
| return _LEARNED_CACHE | |
| def _learned_heuristic(self, room: str, context: dict) -> Optional[str]: | |
| """Suggest action from learned (game, room, action) data if available.""" | |
| game = context.get("game", getattr(self, "_current_game", "")) | |
| if not game: | |
| return None | |
| rules = self._load_learned_rules() | |
| va = (context.get("valid_actions") or "").lower() | |
| room_lower = (room or "").lower()[:80] | |
| for key in ["learned_score_sources", "learned_location_actions"]: | |
| items = rules.get(key, []) | |
| for item in items: | |
| if item.get("game") != game: | |
| continue | |
| ctx = (item.get("context") or "").lower()[:80] | |
| if ctx and room_lower and ctx not in room_lower and room_lower not in ctx: | |
| continue | |
| action = (item.get("action") or "").strip().lower() | |
| if not action or action in self.failed_actions: | |
| continue | |
| if action in [a.lower() for a in self.recent_actions[-5:]]: | |
| continue | |
| if va and action not in va and not any(action in a for a in va.split(",")[:25]): | |
| continue | |
| return action | |
| return None | |
| def _emergency_heuristic(self, observation: str, context: dict) -> Optional[str]: | |
| """When score stagnation critical, try known score sources and unmapped directions.""" | |
| room = (observation or "").strip().split("\n")[0][:80] | |
| learned = self._learned_heuristic(room, context) | |
| if learned: | |
| return learned | |
| r = observation.lower() | |
| if "your load is too heavy" in r and "painting" in r: | |
| if "drop leaflet" in [a.lower() for a in self.recent_actions[-3:]]: | |
| return "drop all except lantern" | |
| return "drop leaflet" | |
| if "painting" in r and "unparalleled" in r: | |
| if "take painting" not in self.failed_actions: | |
| return "take painting" | |
| if "trap door" in r and "revealing" in r: | |
| return "open trap door" | |
| if "dusty cover" in r and "trap door" in r: | |
| return "open trap door" | |
| if "rug" in r and "center" in r and "move rug" not in [a.lower() for a in self.recent_actions[-3:]]: | |
| return "move rug" | |
| if "troll" in r and "attack" not in [a.lower()[:6] for a in self.recent_actions[-2:]]: | |
| return "attack troll with sword" | |
| if "can't see any troll" in r: | |
| return None | |
| for d in ["north", "south", "east", "west", "up", "down"]: | |
| if d not in [a.lower() for a in self.recent_actions[-5:]] and d not in self.failed_actions: | |
| return d | |
| return None | |
    # =========================================================================
    # Universal verb vocabulary (game-agnostic) per common_structure.md
    # =========================================================================
    # we cycle through these when no result-based pattern matches
    # (consumed, in order, by _generic_verb_cycle; _heuristic_action uses it
    # as the final fallback when the LLM produced nothing usable)
    UNIVERSAL_VERB_CYCLE = [
        "look", "examine", "inventory",
        # compass + vertical movement
        "north", "south", "east", "west", "up", "down", "in", "out",
        # common take targets across the supported games
        "take all", "take lamp", "take keys", "take wallet", "take phone", "take sword",
        "open mailbox", "open door", "open", "open chest",
        # prerequisite actions (e.g. games that start with the player in bed)
        "get up", "stand", "rise", "wake",
        "light lamp", "turn on lamp", "wear", "use", "read",
    ]
    def _result_based_heuristic(self, result_text: str) -> Optional[str]:
        """Game-agnostic heuristic from result text per common_structure.md.

        Scans the lowercased game reply through an ordered chain of pattern
        checks — take visible objects, satisfy prerequisites ("get up"),
        handle darkness, recover from parser rejections, follow forced
        directions — and returns the first matching command. Returns None
        when no pattern applies (callers fall back to _generic_verb_cycle),
        or deliberately on movement blocks so the cycle picks a direction.
        NOTE: the check order encodes priority; do not reorder casually.
        """
        r = result_text.lower()
        # we prioritize taking visible objects when room lists them (905, etc)
        if "telephone" in r or ("phone" in r and "take phone" not in [a.lower() for a in self.recent_actions[-3:]]):
            if "take phone" not in self.failed_actions:
                return "take phone"
        if "wallet" in r and "take wallet" not in self.failed_actions and "take wallet" not in [a.lower() for a in self.recent_actions[-3:]]:
            return "take wallet"
        if "keys" in r and "take keys" not in self.failed_actions and "take keys" not in [a.lower() for a in self.recent_actions[-3:]]:
            return "take keys"
        # prerequisite: get out of bed, have to get up
        if "get out of bed" in r or "out of bed" in r or "have to get up" in r:
            return "get up"
        if "get up" in r and "have to" in r:
            return "get up"
        if "stand" in r and ("have to" in r or "must" in r):
            return "stand"
        # light: too dark, can't see -- try light sources not attempted recently
        if "too dark" in r or "can't see" in r or "too dark to" in r:
            for cmd in ["light lamp", "turn on lamp", "take lamp"]:
                if cmd not in [a.lower() for a in self.recent_actions[-3:]]:
                    return cmd
            return "light lamp"
        # movement block: wall, can't go that way
        if "can't go" in r or "wall" in r or "can't go that way" in r or "too narrow" in r:
            return None  # we let generic cycle pick next direction
        # parser rejection: don't understand, can't
        if "don't understand" in r or "i don't understand" in r:
            return "look"
        if "you can't" in r or "can't do that" in r:
            return "examine"
        # object: take X when objects mentioned (keys, wallet, lamp, etc)
        common_objects = ["telephone", "phone", "keys", "wallet", "lamp", "sword", "treasure", "book", "rope", "knife", "chest", "dresser"]
        for word in common_objects:
            if word in r:
                action_try = f"take {word}"
                if word == "telephone":
                    # the parser verb is "take phone", not "take telephone"
                    action_try = "take phone"
                if action_try in self.failed_actions:
                    continue
                recent_lower = [a.lower() for a in self.recent_actions[-5:]]
                if action_try not in recent_lower:
                    return action_try
        if "dresser" in r:
            if "open dresser" not in [a.lower() for a in self.recent_actions[-3:]]:
                return "open dresser"
        # fall back to whatever objects the room description itself mentions
        for obj in self._extract_objects_from_room(result_text):
            action_try = f"take {obj}"
            if action_try in self.failed_actions:
                continue
            recent_lower = [a.lower() for a in self.recent_actions[-5:]]
            if action_try not in recent_lower:
                return action_try
        if "mailbox" in r:
            recent_lower = [a.lower() for a in self.recent_actions[-3:]]
            if "open mailbox" not in recent_lower:
                return "open mailbox"
        if "open" in r and "closed" in r:
            for word in ["door", "mailbox", "chest", "box"]:
                if word in r:
                    return f"open {word}"
        if "open" in r and "door" in r:
            return "open door"
        # no such thing, I don't see
        if "don't see" in r or "no such" in r or "can't see any" in r:
            return "look"
        # only go X (extract direction)
        for d in ["north", "south", "east", "west"]:
            if f"only go {d}" in r or f"only {d}" in r or f"can only go {d}" in r:
                return d
        # lostpig / general: south fails with "trouble" -> try east (forest)
        if "get in big trouble" in r or "big trouble" in r:
            south_count = sum(1 for a in self.recent_actions[-5:] if a.lower() == "south")
            if south_count >= 2:
                return "east"
            return "north"
        # forest dark / pig somewhere: try forest first, then try west/south when stuck
        if "forest" in r and "dark" in r:
            east_count = sum(1 for a in self.recent_actions[-6:] if a.lower() == "east")
            north_count = sum(1 for a in self.recent_actions[-6:] if a.lower() == "north")
            if east_count + north_count >= 4:
                return "west"
            if east_count < 2:
                return "east"
            return "north"
        return None
| def _extract_objects_from_room(self, text: str) -> list[str]: | |
| """Extract object names from room description for take/examine.""" | |
| r = text.lower() | |
| objects = [] | |
| # patterns: "there is a X", "you see X", "X and Y", "on the X are Y", "X, Y and Z" | |
| for m in re.finditer(r"\b(there is|you see|are|on the \w+ are)\s+[a ]+(\w+)", r): | |
| objects.append(m.group(2)) | |
| for m in re.finditer(r"\b(telephone|phone|wallet|keys|lamp|sword|book|rope|knife|chest|mailbox)\b", r): | |
| objects.append(m.group(1)) | |
| return list(dict.fromkeys(objects))[:5] | |
| def _generic_verb_cycle(self, extra_skip: set[str] | None = None) -> str: | |
| """Return next action from universal cycle, skipping failed actions.""" | |
| skip = self.failed_actions | (extra_skip or set()) | |
| cycle = self.UNIVERSAL_VERB_CYCLE | |
| start = 0 | |
| if self.recent_actions: | |
| last = self.recent_actions[-1].lower() | |
| idx = next((i for i, a in enumerate(cycle) if a == last), -1) | |
| start = (idx + 1) % len(cycle) | |
| for i in range(len(cycle)): | |
| cand = cycle[(start + i) % len(cycle)] | |
| if cand not in skip: | |
| return cand | |
| return "look" | |
| def _heuristic_action(self, observation: str) -> str: | |
| """Heuristic when LLM empty: result-based first, then generic verb cycle.""" | |
| action = self._result_based_heuristic(observation) | |
| if action is None: | |
| action = self._generic_verb_cycle() | |
| return f"THOUGHT: Try {action}.\nTOOL: play_action\nARGS: {{\"action\": \"{action}\"}}" | |
| def _parse_multiple_actions(self, response: str) -> list[str]: | |
| """Extract multiple action candidates from response (ALTERNATIVE1, ALTERNATIVE2, or ARGS blocks).""" | |
| actions: list[str] = [] | |
| for line in response.strip().split("\n"): | |
| lc = line.strip() | |
| for prefix in ["ALTERNATIVE1:", "ALTERNATIVE2:", "ALTERNATIVE3:", "ACTION1:", "ACTION2:"]: | |
| if lc.upper().startswith(prefix.upper()): | |
| rest = lc.split(":", 1)[1].strip().strip('"\'') | |
| if rest and len(rest) < 80: | |
| actions.append(rest) | |
| break | |
| m = re.findall(r'"action"\s*:\s*"([^"]+)"', response) | |
| for a in m: | |
| if a not in actions: | |
| actions.append(a) | |
| return actions[:5] | |
| def _parse_response(self, response: str, valid_tools: list[str]) -> tuple[str, str, dict]: | |
| """Parse LLM response; fallback to extracting action from raw text.""" | |
| thought = "No reasoning provided" | |
| tool_name = "play_action" | |
| tool_args = {"action": "look"} | |
| for line in response.strip().split("\n"): | |
| lc = line.strip() | |
| lu = lc.upper() | |
| if lu.startswith("THOUGHT:"): | |
| thought = lc.split(":", 1)[1].strip() or thought | |
| elif lu.startswith("TOOL:"): | |
| raw = lc.split(":", 1)[1].strip().lower().replace("**", "").replace("*", "") | |
| raw = raw.split()[0] if raw else "play_action" | |
| tool_name = raw | |
| elif lu.startswith("ARGS:"): | |
| s = lc.split(":", 1)[1].strip().replace("'", '"') | |
| try: | |
| tool_args = json.loads(s) | |
| except json.JSONDecodeError: | |
| m = re.search(r'"action"\s*:\s*"([^"]+)"', s) | |
| if m: | |
| tool_args = {"action": m.group(1)} | |
| # we fallback: if still "look", try to extract action from raw response | |
| if tool_args.get("action", "look") == "look" and response.strip(): | |
| r = response.lower() | |
| for cmd in ["east", "north", "south", "west", "inventory", "take all", "take lamp"]: | |
| if cmd in r: | |
| tool_args = {"action": cmd} | |
| break | |
| return thought, tool_name, tool_args | |
| def _validate_tool_call(self, tool_name: str, tool_args: dict, valid_tools: list[str]) -> tuple[str, dict]: | |
| """Validate and fix tool call.""" | |
| if tool_name not in valid_tools: | |
| tool_name = "play_action" | |
| if tool_name == "play_action": | |
| action = tool_args.get("action", "look") | |
| invalid = {"check": "examine", "inspect": "examine", "search": "look", "grab": "take", "pick": "take"} | |
| words = action.lower().split() | |
| if words and words[0] in invalid: | |
| words[0] = invalid[words[0]] | |
| action = " ".join(words) | |
| action = action.lower().strip().replace("**", "").replace("*", "") | |
| action = " ".join(action.split()) | |
| tool_args["action"] = action | |
| return tool_name, tool_args | |
| def _extract_result(self, result) -> str: | |
| """Extract text from MCP result.""" | |
| if hasattr(result, "content") and result.content: | |
| return result.content[0].text | |
| if isinstance(result, list) and result: | |
| return result[0].text if hasattr(result[0], "text") else str(result[0]) | |
| return str(result) | |
| def _update_score(self, text: str) -> None: | |
| """Update score from text.""" | |
| for pat in [r"Score:\s*(\d+)", r"\[Score:\s*(\d+)", r"Total:\s*(\d+)"]: | |
| m = re.search(pat, text, re.IGNORECASE) | |
| if m: | |
| self.score = max(self.score, int(m.group(1))) | |
| break | |
| def _is_game_over(self, text: str) -> bool: | |
| """Check game over.""" | |
| t = text.lower() | |
| return any(p in t for p in ["game over", "you have died", "you are dead", "*** you have died ***"]) | |
| def _is_failure_result(self, result: str, action: str) -> bool: | |
| """Check if result indicates action failed (rejection, no progress).""" | |
| r = result.lower() | |
| failure_phrases = [ | |
| "don't understand", "you can't", "can't do that", "can't go that way", | |
| "there is no", "no such", "you'll have to", "have to get", "get out of bed first", | |
| "verb error", "not recognized", "i don't see", "can't see any", | |
| ] | |
| if any(p in r for p in failure_phrases): | |
| return True | |
| if "get in big trouble" in r or "grunk get in big trouble" in r: | |
| return True | |
| return False | |
async def test_agent():
    """Smoke-test the agent locally against the stdio MCP server (game: lostpig).

    Spawns mcp_server.py as a subprocess via StdioTransport, runs the agent
    for up to 10 steps, and prints the final score/move count.
    """
    # project-local deps imported lazily so the module stays importable without them
    from fastmcp import Client
    from fastmcp.client.transports import StdioTransport

    # FIX: removed redundant local `import sys` / `from pathlib import Path`
    # that shadowed the module-level imports.
    server_path = Path(__file__).parent / "mcp_server.py"
    env = os.environ.copy()
    env["GAME"] = "lostpig"  # the server reads the target game from its environment
    transport = StdioTransport(command=sys.executable, args=[str(server_path)], env=env)
    agent = StudentAgent()
    async with Client(transport) as client:
        result = await agent.run(client=client, game="lostpig", max_steps=10, seed=42, verbose=True)
    print(f"\nFinal: score={result.final_score}, moves={result.moves}")
if __name__ == "__main__":
    # FIX: dropped the redundant local `import asyncio` — the module already
    # imports asyncio at the top of the file.
    asyncio.run(test_agent())