bouhss commited on
Commit
42e80f7
·
verified ·
1 Parent(s): 92b59a5

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +407 -321
agent.py CHANGED
@@ -1,49 +1,49 @@
1
  """
2
- Student Agent (Best practical submission)
3
-
4
- - Works even if HF_TOKEN is missing (no crash).
5
- - Uses peek_action + get_valid_actions + server meta tags to explore and gain score.
6
- - Uses LLM only as fallback when HF_TOKEN is available.
7
- - Always returns non-zero moves (internal counter).
 
 
 
 
 
 
 
8
  """
9
 
10
  import json
11
  import os
12
  import re
13
- import time
 
14
  from dataclasses import dataclass, field
15
- from typing import Optional, Any
16
- from collections import defaultdict, deque
17
 
18
  from dotenv import load_dotenv
19
  from huggingface_hub import InferenceClient
20
 
21
  load_dotenv()
22
 
23
- LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
24
  _hf_token = os.getenv("HF_TOKEN")
25
  LLM_CLIENT = InferenceClient(token=_hf_token) if _hf_token else None
26
 
27
 
28
- def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 180) -> str:
29
  if LLM_CLIENT is None:
30
- raise RuntimeError("LLM unavailable (HF_TOKEN missing).")
31
- messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
32
- for attempt in range(3):
33
- try:
34
- resp = LLM_CLIENT.chat.completions.create(
35
- model=LLM_MODEL,
36
- messages=messages,
37
- temperature=0.0,
38
- max_tokens=max_tokens,
39
- seed=seed,
40
- )
41
- return resp.choices[0].message.content
42
- except Exception:
43
- if attempt < 2:
44
- time.sleep(2 ** attempt)
45
- continue
46
- raise
47
 
48
 
49
  @dataclass
@@ -51,337 +51,423 @@ class RunResult:
51
  final_score: int
52
  max_score: int
53
  moves: int
54
- locations_visited: set[str]
55
  game_completed: bool
56
  error: Optional[str] = None
57
  history: list[tuple[str, str, str]] = field(default_factory=list)
58
 
59
 
60
- SYSTEM_PROMPT = """You are an expert text-adventure agent.
 
 
 
 
61
 
62
- Output EXACTLY:
63
- THOUGHT: ...
64
- TOOL: play_action
65
- ARGS: {"action": "<one candidate action>"}
 
 
66
 
67
- Rules:
68
- - Choose exactly one action from the candidate list.
69
- - Do not invent actions outside the list.
70
- - No extra text, no markdown.
71
- """
72
 
 
 
73
 
74
- MOVE_ALIASES = {"n":"north","s":"south","e":"east","w":"west","u":"up","d":"down","ne":"northeast","nw":"northwest","se":"southeast","sw":"southwest"}
75
- BAD_PREFIXES = ("save", "restore", "quit", "restart", "help", "verbose", "script", "unscript", "version")
76
- BAD_EXACT = {"wait", "z"}
 
 
77
 
78
 
79
  class StudentAgent:
80
- def __init__(self):
81
- self.score = 0
82
- self.max_score = 0
83
- self.moves = 0
84
- self._internal_moves = 0
 
 
 
 
85
 
86
- self.locations_visited: set[str] = set()
87
- self.last_location = "Unknown"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- self.tried = defaultdict(int)
90
- self.recent_actions = deque(maxlen=10)
91
- self.recent_obs = deque(maxlen=6)
 
 
 
92
 
93
- self.valid_cache = {}
 
 
 
 
 
94
 
95
- async def run(self, client, game: str, max_steps: int, seed: int, verbose: bool = False) -> RunResult:
96
- history: list[tuple[str, str, str]] = []
 
97
 
98
- try:
99
- tools = await client.list_tools()
100
- tool_names = {t.name for t in tools}
101
-
102
- def has(name: str) -> bool:
103
- return name in tool_names
104
-
105
- # initial look
106
- obs = await self._call_tool_text(client, "play_action", {"action": "look"})
107
- self._internal_moves += 1
108
- self._update_from_text(obs)
109
- self.last_location = self._extract_location(obs)
110
- self.locations_visited.add(self.last_location)
111
-
112
- for step in range(1, max_steps + 1):
113
- loc = self._extract_location(obs)
114
- self.last_location = loc
115
- self.locations_visited.add(loc)
116
-
117
- stuck = self._is_stuck(obs)
118
-
119
- valid_actions = self.valid_cache.get(loc, [])
120
- if has("get_valid_actions") and (stuck or not valid_actions or step % 6 == 0):
121
- va_txt = await self._call_tool_text(client, "get_valid_actions", {"limit": 60})
122
- valid_actions = self._parse_valid_actions(va_txt)
123
- if valid_actions:
124
- self.valid_cache[loc] = valid_actions
125
-
126
- inv_txt = ""
127
- if has("inventory") and (step == 1 or stuck or step % 8 == 0):
128
- inv_txt = await self._call_tool_text(client, "inventory", {})
129
-
130
- candidates = self._make_candidates(obs, inv_txt, valid_actions, loc)
131
-
132
- action, thought = None, ""
133
- if has("peek_action") and candidates:
134
- action, thought = await self._choose_by_lookahead(client, loc, obs, candidates)
135
-
136
- if not action:
137
- action, thought = await self._choose_fallback(obs, inv_txt, candidates, seed, step)
138
-
139
- action = self._norm_action(action or "look")
140
-
141
- obs2 = await self._call_tool_text(client, "play_action", {"action": action})
142
- self._internal_moves += 1
143
-
144
- self.tried[(loc, action.lower())] += 1
145
- self.recent_actions.append(action.lower())
146
- self.recent_obs.append((obs2 or "")[:220])
147
-
148
- self._update_from_text(obs2)
149
- new_loc = self._extract_location(obs2)
150
- self.locations_visited.add(new_loc)
151
-
152
- history.append((thought, f"play_action({action})", (obs2 or "")[:260]))
153
- if verbose:
154
- print(f"\n--- step {step} ---\nTHOUGHT: {thought}\nACTION: {action}\n{obs2}")
155
-
156
- obs = obs2
157
- if self._is_game_over(obs):
158
- break
159
-
160
- return RunResult(
161
- final_score=self.score,
162
- max_score=self.max_score,
163
- moves=max(self.moves, self._internal_moves),
164
- locations_visited=set(self.locations_visited),
165
- game_completed=self._is_game_over(obs),
166
- history=history,
167
- )
168
 
169
- except Exception as e:
170
- return RunResult(
171
- final_score=self.score,
172
- max_score=self.max_score,
173
- moves=max(self.moves, self._internal_moves),
174
- locations_visited=set(self.locations_visited),
175
- game_completed=False,
176
- error=f"{type(e).__name__}: {e}",
177
- history=history,
178
- )
179
 
180
- async def _call_tool_text(self, client, tool: str, args: dict) -> str:
181
- r = await client.call_tool(tool, args)
182
- return self._extract_text(r)
183
-
184
- def _extract_text(self, result: Any) -> str:
185
- if result is None:
186
- return ""
187
- if isinstance(result, list) and result:
188
- part = result[0]
189
- if hasattr(part, "text"):
190
- return part.text or ""
191
- if isinstance(part, dict) and "text" in part:
192
- return part["text"] or ""
193
- return str(part)
194
- return str(result)
195
 
196
- def _norm_action(self, a: str) -> str:
197
- a = (a or "").strip()
198
- low = a.lower()
199
- return MOVE_ALIASES.get(low, a)
200
 
201
- def _update_from_text(self, text: str) -> None:
202
- m = re.search(r"\[Score:\s*(\d+)\s*/\s*(\d+)\s*\|\s*Moves:\s*(\d+)\s*\|\s*Location:\s*(.+?)\]", text or "")
203
- if m:
204
- self.score = int(m.group(1))
205
- self.max_score = int(m.group(2))
206
- self.moves = int(m.group(3))
207
- self.last_location = m.group(4).strip()
208
-
209
- def _extract_location(self, text: str) -> str:
210
- m = re.search(r"\|\s*Location:\s*(.+?)\]", text or "")
211
- if m and m.group(1).strip():
212
- return m.group(1).strip()
213
- for line in (text or "").splitlines():
214
- line = line.strip()
215
- if line and not line.startswith("[Score:"):
216
- return line
217
- return "Unknown"
218
-
219
- def _extract_untried(self, text: str) -> list[str]:
220
- m = re.search(r"\[Untried exits:\s*(.+?)\]", text or "")
221
- if not m:
222
- return []
223
- return [self._norm_action(x.strip()).lower() for x in m.group(1).split(",") if x.strip()]
224
-
225
- def _extract_interactions(self, text: str) -> list[str]:
226
- m = re.search(r"\[Interactions:\s*(.+?)\]", text or "")
227
- if not m:
228
- return []
229
- return [x.strip() for x in m.group(1).split(",") if x.strip()]
230
 
231
- def _is_game_over(self, text: str) -> bool:
232
- t = (text or "").lower()
233
- return ("game over" in t) or ("you have died" in t) or ("you are dead" in t)
234
 
235
- def _is_stuck(self, text: str) -> bool:
236
- t = (text or "").lower()
237
- bad = ["i don't understand", "you can't", "that's not", "not a verb", "nothing happens", "beg your pardon"]
238
- rep = len(self.recent_obs) >= 3 and all(self.recent_obs[-1] == x for x in list(self.recent_obs)[-3:])
239
- return any(b in t for b in bad) or rep
240
 
241
- def _parse_valid_actions(self, txt: str) -> list[str]:
242
- out = []
243
- for line in (txt or "").splitlines():
244
- line = line.strip()
245
- if line.startswith("- "):
246
- a = self._norm_action(line[2:].strip())
247
- low = a.lower()
248
- if not a:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  continue
250
- if low.startswith(BAD_PREFIXES) or low in BAD_EXACT:
251
  continue
252
- out.append(a)
253
- # dedup
254
- seen = set()
255
- uniq = []
256
- for a in out:
257
- low = a.lower()
258
- if low not in seen:
259
- seen.add(low)
260
- uniq.append(a)
261
- return uniq
262
-
263
- def _make_candidates(self, obs: str, inv_txt: str, valid_actions: list[str], loc: str) -> list[str]:
264
- candidates, seen = [], set()
265
-
266
- def add(a: str):
267
- a = self._norm_action(a)
268
- low = a.lower().strip()
269
- if not a:
270
- return
271
- if low.startswith(BAD_PREFIXES) or low in BAD_EXACT:
272
- return
273
- if low not in seen:
274
- seen.add(low)
275
- candidates.append(a)
 
 
 
 
 
 
 
 
 
 
 
276
 
277
- # from tags
278
- for d in self._extract_untried(obs):
279
- add(d)
280
- for a in self._extract_interactions(obs):
281
- add(a)
282
 
283
- # from valid actions
284
- for a in valid_actions[:25]:
285
- add(a)
 
 
 
 
 
 
 
 
286
 
287
- # basics
288
- add("look")
289
- add("inventory")
290
- add("take all")
291
 
292
- # avoid too repeated
293
- cleaned = []
294
  for a in candidates:
295
- if list(self.recent_actions).count(a.lower()) >= 3:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  continue
297
- cleaned.append(a)
298
- return cleaned[:20]
299
 
300
- async def _choose_by_lookahead(self, client, loc: str, obs: str, candidates: list[str]) -> tuple[Optional[str], str]:
301
- base_score = self.score
302
- untried = set(self._extract_untried(obs))
303
 
304
- # shortlist
305
- pr = []
306
- for a in candidates:
307
- low = a.lower().strip()
308
- pr.append((0 if low in untried else 1, self.tried[(loc, low)], low, a))
309
- pr.sort()
310
- shortlist = [x[-1] for x in pr][:10]
311
-
312
- best_a, best_u, best_th = None, -10**18, ""
313
- for a in shortlist:
314
- low = a.lower().strip()
315
- if self.tried[(loc, low)] >= 4:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  continue
317
- peek = await self._call_tool_text(client, "peek_action", {"action": a})
318
- if self._is_game_over(peek):
319
- u = -1_000_000_000
320
- else:
321
- s_after = base_score
322
- m = re.search(r"\[Score:\s*(\d+)\s*/", peek or "")
323
- if m:
324
- s_after = int(m.group(1))
325
- delta = max(0, s_after - base_score)
326
- loc_after = self._extract_location(peek)
327
- new_loc_bonus = 280 if (loc_after and loc_after not in self.locations_visited and loc_after != self._extract_location(obs)) else 0
328
- untried_bonus = 120 if low in untried else 0
329
- loop_pen = 90 * list(self.recent_actions).count(low)
330
- stuck_pen = 180 if self._is_stuck(peek) else 0
331
- u = delta * 900 + new_loc_bonus + untried_bonus - loop_pen - stuck_pen
332
-
333
- if u > best_u:
334
- best_u, best_a = u, a
335
- best_th = f"Look-ahead chose '{a}' (utility={u})."
336
-
337
- if best_a is None or best_u < -10000:
338
- return None, "Look-ahead no good action; fallback."
339
- return best_a, best_th
340
-
341
- async def _choose_fallback(self, obs: str, inv_txt: str, candidates: list[str], seed: int, step: int) -> tuple[str, str]:
342
- untried = self._extract_untried(obs)
343
- if untried:
344
- return untried[0], "Heuristic: try untried exit."
345
-
346
- if not candidates:
347
- return "look", "No candidates; fallback."
348
-
349
- # LLM only if available
350
  try:
351
- prompt = self._llm_prompt(obs, inv_txt, candidates[:10])
352
- resp = call_llm(prompt, SYSTEM_PROMPT, seed + step, max_tokens=160)
353
- thought, args = self._parse_llm(resp)
354
- act = self._norm_action(str(args.get("action", "")).strip())
355
- canon = {x.lower(): x for x in candidates[:10]}
356
- if act.lower() in canon:
357
- return canon[act.lower()], thought or "LLM chose candidate."
358
  except Exception:
359
  pass
 
360
 
361
- return candidates[0], "Fallback: first candidate."
362
 
363
- def _llm_prompt(self, obs: str, inv_txt: str, candidates: list[str]) -> str:
364
- parts = [
365
- f"Score: {self.score}/{self.max_score} | Moves: {max(self.moves, self._internal_moves)}",
366
- f"Location: {self.last_location}",
367
- "\nCurrent observation:\n" + (obs or "")[:1100],
368
- "\nCandidate actions (choose exactly one):",
369
- ]
370
- for a in candidates:
371
- parts.append(f"- {a}")
372
- return "\n".join(parts)
373
 
374
- def _parse_llm(self, resp: str) -> tuple[str, dict]:
375
- thought = ""
376
- args = {"action": "look"}
377
- m = re.search(r"(?im)^THOUGHT:\s*(.+)$", resp or "")
378
- if m:
379
- thought = m.group(1).strip()
380
- m = re.search(r"(?is)^ARGS:\s*(\{.*\})\s*$", resp or "")
381
- if m:
382
- raw = m.group(1)
383
- try:
384
- args = json.loads(raw)
385
- except Exception:
386
- pass
387
- return thought, args
 
1
  """
2
+ Exploration-first hybrid agent (score + locations) for text adventures.
3
+
4
+ Key points:
5
+ - Deterministic policy driven by server status() JSON.
6
+ - Priority:
7
+ A) Valid untried exits (Jericho-validated) + obs-boosted directions
8
+ B) Bounded suggested_interactions (game-validated)
9
+ C) BFS backtrack to nearest frontier (room with untried exits)
10
+ D) Stuck recovery (look/inventory/examine noun)
11
+ E) Optional single LLM fallback if HF_TOKEN is present (never required)
12
+
13
+ - Uses peek_action (if available) to score a small candidate set quickly.
14
+ - All verbose/debug output goes to stderr only.
15
  """
16
 
17
  import json
18
  import os
19
  import re
20
+ import sys
21
+ from collections import deque
22
  from dataclasses import dataclass, field
23
+ from typing import Optional
 
24
 
25
  from dotenv import load_dotenv
26
  from huggingface_hub import InferenceClient
27
 
28
  load_dotenv()
29
 
30
+ LLM_MODEL = os.getenv("HF_MODEL", "Qwen/Qwen2.5-72B-Instruct")
31
  _hf_token = os.getenv("HF_TOKEN")
32
  LLM_CLIENT = InferenceClient(token=_hf_token) if _hf_token else None
33
 
34
 
35
+ def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 120) -> str:
36
  if LLM_CLIENT is None:
37
+ raise RuntimeError("HF_TOKEN missing => LLM unavailable")
38
+ r = LLM_CLIENT.chat.completions.create(
39
+ model=LLM_MODEL,
40
+ messages=[{"role": "system", "content": system_prompt},
41
+ {"role": "user", "content": prompt}],
42
+ temperature=0.0,
43
+ max_tokens=max_tokens,
44
+ seed=seed,
45
+ )
46
+ return r.choices[0].message.content or ""
 
 
 
 
 
 
 
47
 
48
 
49
  @dataclass
 
51
  final_score: int
52
  max_score: int
53
  moves: int
54
+ locations_visited: set
55
  game_completed: bool
56
  error: Optional[str] = None
57
  history: list[tuple[str, str, str]] = field(default_factory=list)
58
 
59
 
60
+ # Tunables
61
+ MAX_INTERACTIONS = 4
62
+ STUCK_THRESHOLD = 10
63
+ MEMORY_LEN = 20
64
+ PEEK_K = 6 # lower if too slow; higher can improve decisions but costs time
65
 
66
+ UNSAFE_STARTS = (
67
+ "burn ", "set fire", "ignite ",
68
+ "attack ", "kill ", "hit ", "stab ", "shoot ", "punch ", "fight ",
69
+ "destroy ", "break ", "smash ",
70
+ "eat ",
71
+ )
72
 
73
+ DIR_WORD_RE = re.compile(
74
+ r"\b(north(?:east|west)?|south(?:east|west)?|east|west|"
75
+ r"northeast|northwest|southeast|southwest|up|down|in|out)\b",
76
+ re.IGNORECASE,
77
+ )
78
 
79
+ DISAMBIG_RE = re.compile(r"which do you mean|do you mean|be more specific|what do you want", re.IGNORECASE)
80
+ OPTION_RE = re.compile(r"\bthe\s+([a-z]+(?:\s+[a-z]+)?)", re.IGNORECASE)
81
 
82
+ LLM_SYSTEM = (
83
+ "You play a text adventure game. Propose ONE action (<= 5 words) that helps "
84
+ "explore a new location or gain points. Reply with exactly one line:\n"
85
+ "ACTION: <command>"
86
+ )
87
 
88
 
89
  class StudentAgent:
90
+ def __init__(self) -> None:
91
+ self.visited: set[int] = set()
92
+ self.graph: dict[int, dict[str, int]] = {}
93
+ self.loc_untried: dict[int, list[str]] = {}
94
+ self.interactions_done: dict[int, int] = {}
95
+ self.recent_memory = deque(maxlen=MEMORY_LEN) # (action, loc_id, score, obs_snip)
96
+ self.no_progress_steps = 0
97
+ self.llm_calls = 0
98
+ self.last_action = ""
99
 
100
+ async def run(self, client, game: str, max_steps: int, seed: int, verbose: bool = False) -> RunResult:
101
+ history = []
102
+ moves_taken = 0
103
+ final_score = 0
104
+ max_score = 0
105
+ game_completed = False
106
+ last_status = {}
107
+
108
+ tools = await client.list_tools()
109
+ tool_names = {t.name for t in tools}
110
+ has_peek = "peek_action" in tool_names
111
+
112
+ # prime game
113
+ _ = await client.call_tool("play_action", {"action": "look"})
114
+ moves_taken += 1
115
+ self.last_action = "look"
116
+
117
+ prev_score = 0
118
+ prev_loc = -1
119
+
120
+ while moves_taken < max_steps:
121
+ # status (no move cost)
122
+ try:
123
+ raw = await client.call_tool("status", {})
124
+ status = json.loads(self._text(raw))
125
+ last_status = status
126
+ except Exception:
127
+ status = last_status
128
 
129
+ if not status:
130
+ # emergency
131
+ res = await client.call_tool("play_action", {"action": "look"})
132
+ moves_taken += 1
133
+ history.append(("No status; look", "look", self._text(res)[:140]))
134
+ continue
135
 
136
+ loc_id = int(status["loc_id"])
137
+ score = int(status.get("score", 0))
138
+ final_score = score
139
+ max_score = int(status.get("max_score", max_score) or max_score)
140
+ done = bool(status.get("done", False))
141
+ obs = status.get("last_observation", "") or ""
142
 
143
+ self.visited.add(loc_id)
144
+ self._merge_edges(loc_id, status.get("edges_here", {}) or {})
145
+ self.loc_untried[loc_id] = list(status.get("untried_directions", []) or [])
146
 
147
+ if score == prev_score and loc_id == prev_loc:
148
+ self.no_progress_steps += 1
149
+ else:
150
+ self.no_progress_steps = 0
151
+ prev_score, prev_loc = score, loc_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ if done:
154
+ game_completed = True
155
+ break
 
 
 
 
 
 
 
156
 
157
+ thought, action = self._decide(status, seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
+ if has_peek:
160
+ action = await self._peek_pick(client, status, action)
 
 
161
 
162
+ action = self._sanitize_action(action)
163
+ res = await client.call_tool("play_action", {"action": action})
164
+ moves_taken += 1
165
+ obs2 = self._text(res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ self.recent_memory.append((action.lower().strip(), loc_id, score, obs2[:60]))
168
+ self.last_action = action
 
169
 
170
+ if verbose:
171
+ print(
172
+ f"[step] loc={loc_id} score={score} stuck={self.no_progress_steps} -> {action!r}",
173
+ file=sys.stderr,
174
+ )
175
 
176
+ history.append((thought, action, obs2[:160]))
177
+
178
+ if self._is_game_over(obs2):
179
+ game_completed = True
180
+ break
181
+
182
+ # final status (best effort)
183
+ try:
184
+ raw = await client.call_tool("status", {})
185
+ st2 = json.loads(self._text(raw))
186
+ final_score = max(final_score, int(st2.get("score", 0)))
187
+ max_score = max_score or int(st2.get("max_score", 0))
188
+ self.visited.add(int(st2["loc_id"]))
189
+ except Exception:
190
+ pass
191
+
192
+ return RunResult(
193
+ final_score=final_score,
194
+ max_score=max_score,
195
+ moves=moves_taken,
196
+ locations_visited=self.visited,
197
+ game_completed=game_completed,
198
+ history=history,
199
+ )
200
+
201
+ # -----------------
202
+ # decision logic
203
+ # -----------------
204
+ def _decide(self, status: dict, seed: int) -> tuple[str, str]:
205
+ loc_id = int(status["loc_id"])
206
+ obs = status.get("last_observation", "") or ""
207
+ outcomes = status.get("outcomes_here", {}) or {}
208
+
209
+ banned = {str(x).lower().strip() for x in (status.get("banned_actions_here", []) or [])}
210
+ untried = status.get("untried_directions", []) or []
211
+ valid_exits = status.get("valid_exits", []) or []
212
+ suggested = status.get("suggested_interactions", []) or []
213
+
214
+ # 0) disambiguation
215
+ if DISAMBIG_RE.search(obs):
216
+ opt = self._extract_option(obs)
217
+ if opt and not self._repeat_noop(opt, loc_id):
218
+ return "Disambiguate", opt
219
+
220
+ # A1) Jericho-validated untried exits
221
+ untried_set = set(untried)
222
+ obs_dirs = self._mentioned_dirs(obs)
223
+
224
+ for d in valid_exits:
225
+ dl = d.lower().strip()
226
+ if d in untried_set and dl not in banned and not self._repeat_noop(d, loc_id):
227
+ return f"Valid exit {d}", d
228
+
229
+ # A2) obs-boosted untried dirs
230
+ for d in obs_dirs:
231
+ if d in untried_set and d.lower() not in banned and not self._repeat_noop(d, loc_id):
232
+ return f"Obs-boost {d}", d
233
+
234
+ # A3) any untried
235
+ for d in untried:
236
+ if d.lower() not in banned and not self._repeat_noop(d, loc_id):
237
+ return f"Untried {d}", d
238
+
239
+ # B) bounded interactions (game-validated)
240
+ n = self.interactions_done.get(loc_id, 0)
241
+ if n < MAX_INTERACTIONS:
242
+ for a in suggested:
243
+ al = a.lower().strip()
244
+ if al in banned:
245
  continue
246
+ if any(al.startswith(x) for x in UNSAFE_STARTS):
247
  continue
248
+ if a in outcomes:
249
+ continue
250
+ if self._repeat_noop(a, loc_id):
251
+ continue
252
+ self.interactions_done[loc_id] = n + 1
253
+ return f"Interaction {n+1}", a
254
+
255
+ # C) BFS backtrack to frontier
256
+ avoid = self._oscillation_avoid()
257
+ step_dir = self._bfs_step(loc_id, avoid)
258
+ if step_dir:
259
+ return "BFS backtrack", step_dir
260
+
261
+ # D) stuck recovery
262
+ if self.no_progress_steps >= STUCK_THRESHOLD:
263
+ for a in ("look", "inventory"):
264
+ if not self._repeat_noop(a, loc_id):
265
+ return "Stuck recovery", a
266
+ noun = self._extract_noun(obs)
267
+ if noun and not self._repeat_noop(f"examine {noun}", loc_id):
268
+ return "Stuck examine", f"examine {noun}"
269
+
270
+ # E) optional LLM fallback
271
+ if LLM_CLIENT is not None:
272
+ try:
273
+ self.llm_calls += 1
274
+ prompt = self._llm_prompt(status)
275
+ resp = call_llm(prompt, LLM_SYSTEM, seed + self.llm_calls)
276
+ act = self._parse_llm(resp)
277
+ if act and act.lower().strip() not in banned and not self._repeat_noop(act, loc_id):
278
+ return "LLM fallback", act
279
+ except Exception:
280
+ pass
281
+
282
+ return "Fallback", "look"
283
 
284
+ async def _peek_pick(self, client, status: dict, current_action: str) -> str:
285
+ """Use peek_action to score a small candidate set and pick best."""
286
+ loc_id = int(status["loc_id"])
287
+ score = int(status.get("score", 0))
 
288
 
289
+ candidates = []
290
+ if current_action:
291
+ candidates.append(current_action)
292
+
293
+ # add a few candidates
294
+ for d in (status.get("untried_directions", []) or [])[:4]:
295
+ if d not in candidates:
296
+ candidates.append(d)
297
+ for a in (status.get("suggested_interactions", []) or [])[:4]:
298
+ if a not in candidates:
299
+ candidates.append(a)
300
 
301
+ candidates = candidates[:PEEK_K]
302
+ best = current_action
303
+ best_u = -10**18
 
304
 
 
 
305
  for a in candidates:
306
+ try:
307
+ raw = await client.call_tool("peek_action", {"action": a})
308
+ st = json.loads(self._text(raw))
309
+ new_score = int(st.get("score", score))
310
+ new_loc = int(st.get("loc_id", loc_id))
311
+ delta = max(0, new_score - score)
312
+
313
+ if new_loc != loc_id:
314
+ moved_bonus = 600 if (new_loc not in self.visited) else 80
315
+ else:
316
+ moved_bonus = 0
317
+
318
+ repeat_pen = 120 if self._repeat_noop(a, loc_id) else 0
319
+ u = delta * 900 + moved_bonus - repeat_pen
320
+
321
+ if u > best_u:
322
+ best_u = u
323
+ best = a
324
+ except Exception:
325
  continue
 
 
326
 
327
+ return best
 
 
328
 
329
+ # -----------------
330
+ # graph / BFS
331
+ # -----------------
332
+ def _merge_edges(self, loc_id: int, edges_here: dict) -> None:
333
+ if not edges_here:
334
+ return
335
+ node = self.graph.setdefault(loc_id, {})
336
+ for d, nid in edges_here.items():
337
+ try:
338
+ node[str(d)] = int(nid)
339
+ except Exception:
340
+ pass
341
+
342
+ def _oscillation_avoid(self) -> Optional[int]:
343
+ locs = [x[1] for x in self.recent_memory]
344
+ if len(locs) >= 4 and locs[-1] == locs[-3] and locs[-2] == locs[-4]:
345
+ return locs[-2]
346
+ return None
347
+
348
+ def _bfs_step(self, from_loc: int, avoid_loc: Optional[int]) -> Optional[str]:
349
+ frontier = {lid for lid, u in self.loc_untried.items() if u and lid != from_loc}
350
+ if not frontier:
351
+ return None
352
+
353
+ q = deque()
354
+ seen = {from_loc}
355
+
356
+ for d, nid in self.graph.get(from_loc, {}).items():
357
+ if nid not in seen and nid != avoid_loc:
358
+ q.append((nid, d))
359
+ seen.add(nid)
360
+
361
+ while q:
362
+ cur, first_dir = q.popleft()
363
+ if cur in frontier:
364
+ return first_dir
365
+ for d, nid in self.graph.get(cur, {}).items():
366
+ if nid not in seen:
367
+ seen.add(nid)
368
+ q.append((nid, first_dir))
369
+ return None
370
+
371
+ # -----------------
372
+ # loop / parsing helpers
373
+ # -----------------
374
+ def _repeat_noop(self, action: str, loc_id: int) -> bool:
375
+ a = (action or "").lower().strip()
376
+ return any(prev_a == a and prev_loc == loc_id for (prev_a, prev_loc, _sc, _o) in self.recent_memory)
377
+
378
+ def _mentioned_dirs(self, obs: str) -> list[str]:
379
+ out = []
380
+ for m in DIR_WORD_RE.finditer(obs or ""):
381
+ d = m.group(1).lower()
382
+ if d not in out:
383
+ out.append(d)
384
+ return out
385
+
386
+ def _extract_option(self, obs: str) -> Optional[str]:
387
+ m = OPTION_RE.search(obs or "")
388
+ if m:
389
+ return m.group(1).strip().lower()
390
+ return None
391
+
392
+ def _extract_noun(self, obs: str) -> Optional[str]:
393
+ m = re.search(r"\bthe\s+([a-z]{3,})\b", (obs or "").lower())
394
+ if m:
395
+ noun = m.group(1)
396
+ if noun not in CANONICAL_DIR_SET:
397
+ return noun
398
+ return None
399
+
400
+ def _sanitize_action(self, a: str) -> str:
401
+ a = (a or "").strip()
402
+ a = re.sub(r"[`\"']", "", a)
403
+ a = re.sub(r"\s+", " ", a).strip()
404
+ words = a.split()[:6]
405
+ return " ".join(words) if words else "look"
406
+
407
+ def _llm_prompt(self, status: dict) -> str:
408
+ inv = ", ".join(status.get("inventory", [])) or "empty"
409
+ tried = ", ".join(list((status.get("outcomes_here") or {}).keys())[:20]) or "none"
410
+ banned = ", ".join(status.get("banned_actions_here", [])) or "none"
411
+ return (
412
+ f"Location: {status.get('loc_name')} (id={status.get('loc_id')})\n"
413
+ f"Score: {status.get('score')}/{status.get('max_score')} Moves: {status.get('moves')}\n"
414
+ f"Inventory: {inv}\n"
415
+ f"Untried dirs: {', '.join((status.get('untried_directions') or [])[:12])}\n"
416
+ f"Tried here: {tried}\n"
417
+ f"BANNED: {banned}\n\n"
418
+ f"Observation:\n{(status.get('last_observation') or '')[:500]}\n"
419
+ )
420
+
421
+ def _parse_llm(self, resp: str) -> str:
422
+ for line in (resp or "").splitlines():
423
+ line = line.strip()
424
+ if not line:
425
  continue
426
+ if line.upper().startswith("ACTION:"):
427
+ line = line.split(":", 1)[1].strip()
428
+ line = line.lower()
429
+ m = re.match(
430
+ r"^(?:go\s+)?(north(?:east|west)?|south(?:east|west)?|east|west|up|down|in|out)\b",
431
+ line,
432
+ )
433
+ if m:
434
+ return m.group(1)
435
+ return " ".join(line.split()[:5])
436
+ return "look"
437
+
438
+ def _is_game_over(self, text: str) -> bool:
439
+ t = (text or "").lower()
440
+ return any(x in t for x in ("game over", "you have died", "you are dead", "you have won"))
441
+
442
+ def _text(self, result) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  try:
444
+ if hasattr(result, "content") and result.content:
445
+ return result.content[0].text
446
+ if isinstance(result, list) and result:
447
+ return result[0].text
 
 
 
448
  except Exception:
449
  pass
450
+ return str(result)
451
 
 
452
 
453
+ # Optional smoke-test
454
+ async def _test() -> None:
455
+ from fastmcp import Client
456
+ from fastmcp.client.transports import StdioTransport
457
+ import sys as _sys
458
+ import os as _os
 
 
 
 
459
 
460
+ transport = StdioTransport(
461
+ command=_sys.executable,
462
+ args=[_os.path.join(_os.path.dirname(__file__), "mcp_server.py")],
463
+ env={**_os.environ, "GAME": "lostpig"},
464
+ )
465
+ agent = StudentAgent()
466
+ async with Client(transport) as client:
467
+ res = await agent.run(client, game="lostpig", max_steps=30, seed=42, verbose=True)
468
+ print(f"Score: {res.final_score}/{res.max_score} | Moves: {res.moves} | Locations: {len(res.locations_visited)}", file=sys.stderr)
469
+
470
+
471
+ if __name__ == "__main__":
472
+ import asyncio
473
+ asyncio.run(_test())