malekfeki14 committed on
Commit
2ede02f
·
verified ·
1 Parent(s): 0d5fd5e

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +129 -47
agent.py CHANGED
@@ -10,6 +10,11 @@ Key upgrades:
10
  - Oscillation detection to avoid A<->B loops
11
  - More robust fail/stall detection
12
  - Fully tool-safe (never trusts tool field)
 
 
 
 
 
13
  """
14
 
15
  import os
@@ -49,17 +54,37 @@ def call_llm(prompt: str, system: str, seed: int) -> str:
49
 
50
 
51
  # ==========================================================
52
- # RESULT STRUCTURE
53
  # ==========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  @dataclass
55
  class RunResult:
56
  final_score: int
57
  max_score: int
58
- moves: int
 
59
  locations_visited: set[str]
60
  game_completed: bool
61
  error: Optional[str] = None
62
- history: list = field(default_factory=list)
63
 
64
 
65
  # ==========================================================
@@ -141,8 +166,8 @@ class StudentAgent:
141
  def __init__(self):
142
  self.tried: Dict[str, Set[str]] = {}
143
  self.failed: Dict[str, Set[str]] = {}
144
- self.score = 0
145
 
 
146
  self.global_failed: Set[str] = set()
147
  self.last_locations: List[str] = []
148
  self.last_actions: List[str] = []
@@ -151,13 +176,21 @@ class StudentAgent:
151
  # MAIN LOOP
152
  # ------------------------------------------------------
153
  async def run(self, client, game, max_steps, seed, verbose=False):
154
- history: List[Tuple[str, str]] = []
155
  locations_visited: Set[str] = set()
156
- moves = 0
157
 
 
 
 
 
158
  obs = self._tool_to_text(await client.call_tool("play_action", {"action": "look"}))
 
 
 
159
 
160
  for step in range(max_steps):
 
 
161
  # Stable location from MCP if available
162
  loc = await self._safe_tool(client, "location", {}, fallback=None)
163
  if not loc:
@@ -181,26 +214,26 @@ class StudentAgent:
181
  action = forced
182
  else:
183
  prompt = f"""Observation:
184
- {obs}
185
-
186
- Location: {loc}
187
-
188
- Inventory: {inv_txt}
189
- Objects here: {objs_here_txt}
190
- Explored locations (map):
191
- {map_txt}
192
-
193
- Memory:
194
- {mem}
195
-
196
- Valid actions:
197
- {valid_txt}
198
-
199
- Already tried here: {sorted(self.tried.get(loc, set()))}
200
- Failed here: {sorted(self.failed.get(loc, set()))}
201
- Recent locations: {self.last_locations}
202
- Recent actions: {self.last_actions[-10:]}
203
- """
204
  llm_out = call_llm(prompt, SYSTEM_PROMPT, seed + step)
205
  action = self._extract_action_only(llm_out)
206
 
@@ -216,47 +249,84 @@ class StudentAgent:
216
  # Record attempt
217
  self.tried.setdefault(loc, set()).add(action)
218
 
 
 
 
 
219
  # Execute
220
  new_obs = self._tool_to_text(await client.call_tool("play_action", {"action": action}))
221
- new_score = self._extract_score(new_obs)
 
 
222
 
223
  failed = self._is_failure(new_obs)
224
- stalled = self._is_stalled(obs, new_obs)
225
 
226
- if failed or (stalled and new_score <= self.score):
 
227
  self.failed.setdefault(loc, set()).add(action)
228
  if self._should_global_fail(action):
229
  self.global_failed.add(action)
230
 
231
- if new_score > self.score:
232
- self.score = new_score
233
- # if it just paid off, don't blacklist it globally by accident
234
  if action in self.global_failed:
235
  self.global_failed.discard(action)
236
 
237
- obs = new_obs
238
- moves += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  self.last_actions.append(action)
240
  self.last_actions = self.last_actions[-20:]
241
- history.append((action, obs[:240]))
242
 
 
 
 
 
243
  if verbose:
244
  print("\n" + "=" * 70)
245
  print(f"STEP {step + 1}/{max_steps}")
246
  print(f"Location: {loc}")
247
  print(f"Action: {action}")
248
- print(f"Score: {self.score}")
 
 
249
  print("-" * 70)
250
- print(obs.strip())
251
  print("=" * 70)
252
 
253
  if "GAME OVER" in obs:
254
  break
255
 
 
 
 
 
 
 
 
 
256
  return RunResult(
257
- final_score=self.score,
258
  max_score=350,
259
- moves=moves,
 
260
  locations_visited=locations_visited,
261
  game_completed=("GAME OVER" in obs),
262
  history=history,
@@ -322,7 +392,7 @@ class StudentAgent:
322
 
323
  def _parse_valid_actions(self, valid_txt: str) -> List[str]:
324
  acts: List[str] = []
325
- for line in valid_txt.splitlines():
326
  line = line.strip()
327
  if not line:
328
  continue
@@ -375,7 +445,7 @@ class StudentAgent:
375
  low = (obs or "").lower()
376
  if "pitch black" in low or "grue" in low or "too dark" in low:
377
  # try to light if possible
378
- for cand in ("light lamp", "turn on lamp", "light lantern", "turn on lantern"):
379
  mapped = self._map_to_valid(cand, valid_actions)
380
  if mapped and mapped in valid_actions:
381
  return mapped
@@ -452,7 +522,8 @@ class StudentAgent:
452
  # ------------------------------------------------------
453
  def _normalize_obs(self, text: str) -> str:
454
  t = (text or "").lower().strip()
455
- t = re.sub(r"\[score:\s*\d+\s*\|\s*moves:\s*\d+\]\s*$", "", t).strip()
 
456
  t = re.sub(r"\s+", " ", t).strip()
457
  return t
458
 
@@ -482,13 +553,24 @@ class StudentAgent:
482
  return i
483
 
484
  # ------------------------------------------------------
485
- # SCORE / LOCATION HELPERS
486
  # ------------------------------------------------------
487
- def _extract_score(self, text: str) -> int:
488
- m = re.search(r"Score:\s*(\d+)", text)
 
 
 
 
 
 
 
489
  if m:
490
- return int(m.group(1))
491
- return self.score
 
 
 
 
492
 
493
  def _best_effort_location(self, text: str) -> str:
494
  lines = [ln.strip() for ln in (text or "").splitlines() if ln.strip()]
 
10
  - Oscillation detection to avoid A<->B loops
11
  - More robust fail/stall detection
12
  - Fully tool-safe (never trusts tool field)
13
+
14
+ FIXES / EXTENSIONS:
15
+ - Stores MUCH richer history (StepRecord), including before/after observations and before/after score/moves values
16
+ - Prints the FULL result (new_obs) of each step when verbose=True
17
+ - RunResult.moves now reflects TRUE game moves from server footer when available
18
  """
19
 
20
  import os
 
54
 
55
 
56
  # ==========================================================
57
+ # STEP + RUN RESULT STRUCTURES (RICH HISTORY)
58
  # ==========================================================
59
+ @dataclass
60
+ class StepRecord:
61
+ step: int
62
+ location: str
63
+ action: str
64
+
65
+ observation_before: str
66
+ observation_after: str
67
+
68
+ score_before: int
69
+ score_after: int
70
+
71
+ moves_before: int
72
+ moves_after: int
73
+
74
+ failed: bool
75
+ stalled: bool
76
+
77
+
78
  @dataclass
79
  class RunResult:
80
  final_score: int
81
  max_score: int
82
+ moves: int # true in-game moves when available
83
+ agent_steps: int # number of loop iterations (max_steps cap)
84
  locations_visited: set[str]
85
  game_completed: bool
86
  error: Optional[str] = None
87
+ history: list[StepRecord] = field(default_factory=list)
88
 
89
 
90
  # ==========================================================
 
166
  def __init__(self):
167
  self.tried: Dict[str, Set[str]] = {}
168
  self.failed: Dict[str, Set[str]] = {}
 
169
 
170
+ self.score = 0
171
  self.global_failed: Set[str] = set()
172
  self.last_locations: List[str] = []
173
  self.last_actions: List[str] = []
 
176
  # MAIN LOOP
177
  # ------------------------------------------------------
178
  async def run(self, client, game, max_steps, seed, verbose=False):
179
+ history: List[StepRecord] = []
180
  locations_visited: Set[str] = set()
 
181
 
182
+ # Agent loop iterations (can differ from in-game moves)
183
+ agent_steps = 0
184
+
185
+ # Prime observation
186
  obs = self._tool_to_text(await client.call_tool("play_action", {"action": "look"}))
187
+ # Seed initial score from footer (if present)
188
+ s0, m0 = self._extract_score_moves(obs)
189
+ self.score = s0
190
 
191
  for step in range(max_steps):
192
+ agent_steps += 1
193
+
194
  # Stable location from MCP if available
195
  loc = await self._safe_tool(client, "location", {}, fallback=None)
196
  if not loc:
 
214
  action = forced
215
  else:
216
  prompt = f"""Observation:
217
+ {obs}
218
+
219
+ Location: {loc}
220
+
221
+ Inventory: {inv_txt}
222
+ Objects here: {objs_here_txt}
223
+ Explored locations (map):
224
+ {map_txt}
225
+
226
+ Memory:
227
+ {mem}
228
+
229
+ Valid actions:
230
+ {valid_txt}
231
+
232
+ Already tried here: {sorted(self.tried.get(loc, set()))}
233
+ Failed here: {sorted(self.failed.get(loc, set()))}
234
+ Recent locations: {self.last_locations}
235
+ Recent actions: {self.last_actions[-10:]}
236
+ """
237
  llm_out = call_llm(prompt, SYSTEM_PROMPT, seed + step)
238
  action = self._extract_action_only(llm_out)
239
 
 
249
  # Record attempt
250
  self.tried.setdefault(loc, set()).add(action)
251
 
252
+ # BEFORE snapshot (for richer history)
253
+ obs_before = obs
254
+ score_before, moves_before = self._extract_score_moves(obs_before)
255
+
256
  # Execute
257
  new_obs = self._tool_to_text(await client.call_tool("play_action", {"action": action}))
258
+
259
+ # AFTER snapshot
260
+ score_after, moves_after = self._extract_score_moves(new_obs)
261
 
262
  failed = self._is_failure(new_obs)
263
+ stalled = self._is_stalled(obs_before, new_obs)
264
 
265
+ # Fail/stall bookkeeping
266
+ if failed or (stalled and score_after <= self.score):
267
  self.failed.setdefault(loc, set()).add(action)
268
  if self._should_global_fail(action):
269
  self.global_failed.add(action)
270
 
271
+ # Update score tracker
272
+ if score_after > self.score:
273
+ self.score = score_after
274
  if action in self.global_failed:
275
  self.global_failed.discard(action)
276
 
277
+ # Save step record (FULL before/after result)
278
+ history.append(
279
+ StepRecord(
280
+ step=step + 1,
281
+ location=loc,
282
+ action=action,
283
+ observation_before=obs_before,
284
+ observation_after=new_obs,
285
+ score_before=score_before,
286
+ score_after=score_after,
287
+ moves_before=moves_before,
288
+ moves_after=moves_after,
289
+ failed=failed,
290
+ stalled=stalled,
291
+ )
292
+ )
293
+
294
+ # Update rolling recent lists
295
  self.last_actions.append(action)
296
  self.last_actions = self.last_actions[-20:]
 
297
 
298
+ # IMPORTANT: set obs for next loop
299
+ obs = new_obs
300
+
301
+ # Print the full result of this step
302
  if verbose:
303
  print("\n" + "=" * 70)
304
  print(f"STEP {step + 1}/{max_steps}")
305
  print(f"Location: {loc}")
306
  print(f"Action: {action}")
307
+ print(f"Score: {score_after} (was {score_before})")
308
+ print(f"Moves: {moves_after} (was {moves_before})")
309
+ print(f"failed={failed} stalled={stalled}")
310
  print("-" * 70)
311
+ print(new_obs.strip())
312
  print("=" * 70)
313
 
314
  if "GAME OVER" in obs:
315
  break
316
 
317
+ # Final true score/moves from last obs footer if possible
318
+ final_score, final_moves = self._extract_score_moves(obs)
319
+ if final_score < 0:
320
+ final_score = self.score
321
+ if final_moves < 0:
322
+ # Fallback: if footer missing, estimate from last record
323
+ final_moves = history[-1].moves_after if history and history[-1].moves_after >= 0 else agent_steps
324
+
325
  return RunResult(
326
+ final_score=final_score,
327
  max_score=350,
328
+ moves=final_moves,
329
+ agent_steps=agent_steps,
330
  locations_visited=locations_visited,
331
  game_completed=("GAME OVER" in obs),
332
  history=history,
 
392
 
393
  def _parse_valid_actions(self, valid_txt: str) -> List[str]:
394
  acts: List[str] = []
395
+ for line in (valid_txt or "").splitlines():
396
  line = line.strip()
397
  if not line:
398
  continue
 
445
  low = (obs or "").lower()
446
  if "pitch black" in low or "grue" in low or "too dark" in low:
447
  # try to light if possible
448
+ for cand in ("light lamp", "turn on lamp", "light lantern", "turn on lantern", "light torch"):
449
  mapped = self._map_to_valid(cand, valid_actions)
450
  if mapped and mapped in valid_actions:
451
  return mapped
 
522
  # ------------------------------------------------------
523
  def _normalize_obs(self, text: str) -> str:
524
  t = (text or "").lower().strip()
525
+ # remove server footer if present
526
+ t = re.sub(r"\[score:\s*\d+\s*\|\s*moves:\s*\d+\]\s*$", "", t, flags=re.I).strip()
527
  t = re.sub(r"\s+", " ", t).strip()
528
  return t
529
 
 
553
  return i
554
 
555
  # ------------------------------------------------------
556
+ # SCORE / MOVES / LOCATION HELPERS
557
  # ------------------------------------------------------
558
+ def _extract_score_moves(self, text: str) -> tuple[int, int]:
559
+ """
560
+ Prefer the MCP server footer: [Score: X | Moves: Y]
561
+ Returns (score, moves). Moves=-1 if unavailable.
562
+ """
563
+ if not text:
564
+ return (self.score, -1)
565
+
566
+ m = re.search(r"\[Score:\s*(\d+)\s*\|\s*Moves:\s*(\d+)\]", text)
567
  if m:
568
+ return (int(m.group(1)), int(m.group(2)))
569
+
570
+ # fallback
571
+ m2 = re.search(r"Score:\s*(\d+)", text)
572
+ score = int(m2.group(1)) if m2 else self.score
573
+ return (score, -1)
574
 
575
  def _best_effort_location(self, text: str) -> str:
576
  lines = [ln.strip() for ln in (text or "").splitlines() if ln.strip()]