Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -10,6 +10,11 @@ Key upgrades:
|
|
| 10 |
- Oscillation detection to avoid A<->B loops
|
| 11 |
- More robust fail/stall detection
|
| 12 |
- Fully tool-safe (never trusts tool field)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
"""
|
| 14 |
|
| 15 |
import os
|
|
@@ -49,17 +54,37 @@ def call_llm(prompt: str, system: str, seed: int) -> str:
|
|
| 49 |
|
| 50 |
|
| 51 |
# ==========================================================
|
| 52 |
-
# RESULT
|
| 53 |
# ==========================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
@dataclass
|
| 55 |
class RunResult:
|
| 56 |
final_score: int
|
| 57 |
max_score: int
|
| 58 |
-
moves: int
|
|
|
|
| 59 |
locations_visited: set[str]
|
| 60 |
game_completed: bool
|
| 61 |
error: Optional[str] = None
|
| 62 |
-
history: list = field(default_factory=list)
|
| 63 |
|
| 64 |
|
| 65 |
# ==========================================================
|
|
@@ -141,8 +166,8 @@ class StudentAgent:
|
|
| 141 |
def __init__(self):
|
| 142 |
self.tried: Dict[str, Set[str]] = {}
|
| 143 |
self.failed: Dict[str, Set[str]] = {}
|
| 144 |
-
self.score = 0
|
| 145 |
|
|
|
|
| 146 |
self.global_failed: Set[str] = set()
|
| 147 |
self.last_locations: List[str] = []
|
| 148 |
self.last_actions: List[str] = []
|
|
@@ -151,13 +176,21 @@ class StudentAgent:
|
|
| 151 |
# MAIN LOOP
|
| 152 |
# ------------------------------------------------------
|
| 153 |
async def run(self, client, game, max_steps, seed, verbose=False):
|
| 154 |
-
history: List[
|
| 155 |
locations_visited: Set[str] = set()
|
| 156 |
-
moves = 0
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
obs = self._tool_to_text(await client.call_tool("play_action", {"action": "look"}))
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
for step in range(max_steps):
|
|
|
|
|
|
|
| 161 |
# Stable location from MCP if available
|
| 162 |
loc = await self._safe_tool(client, "location", {}, fallback=None)
|
| 163 |
if not loc:
|
|
@@ -181,26 +214,26 @@ class StudentAgent:
|
|
| 181 |
action = forced
|
| 182 |
else:
|
| 183 |
prompt = f"""Observation:
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
llm_out = call_llm(prompt, SYSTEM_PROMPT, seed + step)
|
| 205 |
action = self._extract_action_only(llm_out)
|
| 206 |
|
|
@@ -216,47 +249,84 @@ class StudentAgent:
|
|
| 216 |
# Record attempt
|
| 217 |
self.tried.setdefault(loc, set()).add(action)
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
# Execute
|
| 220 |
new_obs = self._tool_to_text(await client.call_tool("play_action", {"action": action}))
|
| 221 |
-
|
|
|
|
|
|
|
| 222 |
|
| 223 |
failed = self._is_failure(new_obs)
|
| 224 |
-
stalled = self._is_stalled(
|
| 225 |
|
| 226 |
-
|
|
|
|
| 227 |
self.failed.setdefault(loc, set()).add(action)
|
| 228 |
if self._should_global_fail(action):
|
| 229 |
self.global_failed.add(action)
|
| 230 |
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
if action in self.global_failed:
|
| 235 |
self.global_failed.discard(action)
|
| 236 |
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
self.last_actions.append(action)
|
| 240 |
self.last_actions = self.last_actions[-20:]
|
| 241 |
-
history.append((action, obs[:240]))
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
if verbose:
|
| 244 |
print("\n" + "=" * 70)
|
| 245 |
print(f"STEP {step + 1}/{max_steps}")
|
| 246 |
print(f"Location: {loc}")
|
| 247 |
print(f"Action: {action}")
|
| 248 |
-
print(f"Score: {
|
|
|
|
|
|
|
| 249 |
print("-" * 70)
|
| 250 |
-
print(
|
| 251 |
print("=" * 70)
|
| 252 |
|
| 253 |
if "GAME OVER" in obs:
|
| 254 |
break
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
return RunResult(
|
| 257 |
-
final_score=
|
| 258 |
max_score=350,
|
| 259 |
-
moves=
|
|
|
|
| 260 |
locations_visited=locations_visited,
|
| 261 |
game_completed=("GAME OVER" in obs),
|
| 262 |
history=history,
|
|
@@ -322,7 +392,7 @@ class StudentAgent:
|
|
| 322 |
|
| 323 |
def _parse_valid_actions(self, valid_txt: str) -> List[str]:
|
| 324 |
acts: List[str] = []
|
| 325 |
-
for line in valid_txt.splitlines():
|
| 326 |
line = line.strip()
|
| 327 |
if not line:
|
| 328 |
continue
|
|
@@ -375,7 +445,7 @@ class StudentAgent:
|
|
| 375 |
low = (obs or "").lower()
|
| 376 |
if "pitch black" in low or "grue" in low or "too dark" in low:
|
| 377 |
# try to light if possible
|
| 378 |
-
for cand in ("light lamp", "turn on lamp", "light lantern", "turn on lantern"):
|
| 379 |
mapped = self._map_to_valid(cand, valid_actions)
|
| 380 |
if mapped and mapped in valid_actions:
|
| 381 |
return mapped
|
|
@@ -452,7 +522,8 @@ class StudentAgent:
|
|
| 452 |
# ------------------------------------------------------
|
| 453 |
def _normalize_obs(self, text: str) -> str:
|
| 454 |
t = (text or "").lower().strip()
|
| 455 |
-
|
|
|
|
| 456 |
t = re.sub(r"\s+", " ", t).strip()
|
| 457 |
return t
|
| 458 |
|
|
@@ -482,13 +553,24 @@ class StudentAgent:
|
|
| 482 |
return i
|
| 483 |
|
| 484 |
# ------------------------------------------------------
|
| 485 |
-
# SCORE / LOCATION HELPERS
|
| 486 |
# ------------------------------------------------------
|
| 487 |
-
def
|
| 488 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
if m:
|
| 490 |
-
return int(m.group(1))
|
| 491 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
|
| 493 |
def _best_effort_location(self, text: str) -> str:
|
| 494 |
lines = [ln.strip() for ln in (text or "").splitlines() if ln.strip()]
|
|
|
|
| 10 |
- Oscillation detection to avoid A<->B loops
|
| 11 |
- More robust fail/stall detection
|
| 12 |
- Fully tool-safe (never trusts tool field)
|
| 13 |
+
|
| 14 |
+
FIXES / EXTENSIONS:
|
| 15 |
+
- Stores MUCH richer history (StepRecord) including before/after obs + score/moves deltas
|
| 16 |
+
- Prints the FULL result (new_obs) of each step when verbose=True
|
| 17 |
+
- RunResult.moves now reflects TRUE game moves from server footer when available
|
| 18 |
"""
|
| 19 |
|
| 20 |
import os
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
# ==========================================================
|
| 57 |
+
# STEP + RUN RESULT STRUCTURES (RICH HISTORY)
|
| 58 |
# ==========================================================
|
| 59 |
+
@dataclass
|
| 60 |
+
class StepRecord:
|
| 61 |
+
step: int
|
| 62 |
+
location: str
|
| 63 |
+
action: str
|
| 64 |
+
|
| 65 |
+
observation_before: str
|
| 66 |
+
observation_after: str
|
| 67 |
+
|
| 68 |
+
score_before: int
|
| 69 |
+
score_after: int
|
| 70 |
+
|
| 71 |
+
moves_before: int
|
| 72 |
+
moves_after: int
|
| 73 |
+
|
| 74 |
+
failed: bool
|
| 75 |
+
stalled: bool
|
| 76 |
+
|
| 77 |
+
|
| 78 |
@dataclass
|
| 79 |
class RunResult:
|
| 80 |
final_score: int
|
| 81 |
max_score: int
|
| 82 |
+
moves: int # true in-game moves when available
|
| 83 |
+
agent_steps: int # number of loop iterations (max_steps cap)
|
| 84 |
locations_visited: set[str]
|
| 85 |
game_completed: bool
|
| 86 |
error: Optional[str] = None
|
| 87 |
+
history: list[StepRecord] = field(default_factory=list)
|
| 88 |
|
| 89 |
|
| 90 |
# ==========================================================
|
|
|
|
| 166 |
def __init__(self):
|
| 167 |
self.tried: Dict[str, Set[str]] = {}
|
| 168 |
self.failed: Dict[str, Set[str]] = {}
|
|
|
|
| 169 |
|
| 170 |
+
self.score = 0
|
| 171 |
self.global_failed: Set[str] = set()
|
| 172 |
self.last_locations: List[str] = []
|
| 173 |
self.last_actions: List[str] = []
|
|
|
|
| 176 |
# MAIN LOOP
|
| 177 |
# ------------------------------------------------------
|
| 178 |
async def run(self, client, game, max_steps, seed, verbose=False):
|
| 179 |
+
history: List[StepRecord] = []
|
| 180 |
locations_visited: Set[str] = set()
|
|
|
|
| 181 |
|
| 182 |
+
# Agent loop iterations (can differ from in-game moves)
|
| 183 |
+
agent_steps = 0
|
| 184 |
+
|
| 185 |
+
# Prime observation
|
| 186 |
obs = self._tool_to_text(await client.call_tool("play_action", {"action": "look"}))
|
| 187 |
+
# Seed initial score from footer (if present)
|
| 188 |
+
s0, m0 = self._extract_score_moves(obs)
|
| 189 |
+
self.score = s0
|
| 190 |
|
| 191 |
for step in range(max_steps):
|
| 192 |
+
agent_steps += 1
|
| 193 |
+
|
| 194 |
# Stable location from MCP if available
|
| 195 |
loc = await self._safe_tool(client, "location", {}, fallback=None)
|
| 196 |
if not loc:
|
|
|
|
| 214 |
action = forced
|
| 215 |
else:
|
| 216 |
prompt = f"""Observation:
|
| 217 |
+
{obs}
|
| 218 |
+
|
| 219 |
+
Location: {loc}
|
| 220 |
+
|
| 221 |
+
Inventory: {inv_txt}
|
| 222 |
+
Objects here: {objs_here_txt}
|
| 223 |
+
Explored locations (map):
|
| 224 |
+
{map_txt}
|
| 225 |
+
|
| 226 |
+
Memory:
|
| 227 |
+
{mem}
|
| 228 |
+
|
| 229 |
+
Valid actions:
|
| 230 |
+
{valid_txt}
|
| 231 |
+
|
| 232 |
+
Already tried here: {sorted(self.tried.get(loc, set()))}
|
| 233 |
+
Failed here: {sorted(self.failed.get(loc, set()))}
|
| 234 |
+
Recent locations: {self.last_locations}
|
| 235 |
+
Recent actions: {self.last_actions[-10:]}
|
| 236 |
+
"""
|
| 237 |
llm_out = call_llm(prompt, SYSTEM_PROMPT, seed + step)
|
| 238 |
action = self._extract_action_only(llm_out)
|
| 239 |
|
|
|
|
| 249 |
# Record attempt
|
| 250 |
self.tried.setdefault(loc, set()).add(action)
|
| 251 |
|
| 252 |
+
# BEFORE snapshot (for richer history)
|
| 253 |
+
obs_before = obs
|
| 254 |
+
score_before, moves_before = self._extract_score_moves(obs_before)
|
| 255 |
+
|
| 256 |
# Execute
|
| 257 |
new_obs = self._tool_to_text(await client.call_tool("play_action", {"action": action}))
|
| 258 |
+
|
| 259 |
+
# AFTER snapshot
|
| 260 |
+
score_after, moves_after = self._extract_score_moves(new_obs)
|
| 261 |
|
| 262 |
failed = self._is_failure(new_obs)
|
| 263 |
+
stalled = self._is_stalled(obs_before, new_obs)
|
| 264 |
|
| 265 |
+
# Fail/stall bookkeeping
|
| 266 |
+
if failed or (stalled and score_after <= self.score):
|
| 267 |
self.failed.setdefault(loc, set()).add(action)
|
| 268 |
if self._should_global_fail(action):
|
| 269 |
self.global_failed.add(action)
|
| 270 |
|
| 271 |
+
# Update score tracker
|
| 272 |
+
if score_after > self.score:
|
| 273 |
+
self.score = score_after
|
| 274 |
if action in self.global_failed:
|
| 275 |
self.global_failed.discard(action)
|
| 276 |
|
| 277 |
+
# Save step record (FULL before/after result)
|
| 278 |
+
history.append(
|
| 279 |
+
StepRecord(
|
| 280 |
+
step=step + 1,
|
| 281 |
+
location=loc,
|
| 282 |
+
action=action,
|
| 283 |
+
observation_before=obs_before,
|
| 284 |
+
observation_after=new_obs,
|
| 285 |
+
score_before=score_before,
|
| 286 |
+
score_after=score_after,
|
| 287 |
+
moves_before=moves_before,
|
| 288 |
+
moves_after=moves_after,
|
| 289 |
+
failed=failed,
|
| 290 |
+
stalled=stalled,
|
| 291 |
+
)
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# Update rolling recent lists
|
| 295 |
self.last_actions.append(action)
|
| 296 |
self.last_actions = self.last_actions[-20:]
|
|
|
|
| 297 |
|
| 298 |
+
# IMPORTANT: set obs for next loop
|
| 299 |
+
obs = new_obs
|
| 300 |
+
|
| 301 |
+
# Print the full result of this step
|
| 302 |
if verbose:
|
| 303 |
print("\n" + "=" * 70)
|
| 304 |
print(f"STEP {step + 1}/{max_steps}")
|
| 305 |
print(f"Location: {loc}")
|
| 306 |
print(f"Action: {action}")
|
| 307 |
+
print(f"Score: {score_after} (was {score_before})")
|
| 308 |
+
print(f"Moves: {moves_after} (was {moves_before})")
|
| 309 |
+
print(f"failed={failed} stalled={stalled}")
|
| 310 |
print("-" * 70)
|
| 311 |
+
print(new_obs.strip())
|
| 312 |
print("=" * 70)
|
| 313 |
|
| 314 |
if "GAME OVER" in obs:
|
| 315 |
break
|
| 316 |
|
| 317 |
+
# Final true score/moves from last obs footer if possible
|
| 318 |
+
final_score, final_moves = self._extract_score_moves(obs)
|
| 319 |
+
if final_score < 0:
|
| 320 |
+
final_score = self.score
|
| 321 |
+
if final_moves < 0:
|
| 322 |
+
# Fallback: if footer missing, estimate from last record
|
| 323 |
+
final_moves = history[-1].moves_after if history and history[-1].moves_after >= 0 else agent_steps
|
| 324 |
+
|
| 325 |
return RunResult(
|
| 326 |
+
final_score=final_score,
|
| 327 |
max_score=350,
|
| 328 |
+
moves=final_moves,
|
| 329 |
+
agent_steps=agent_steps,
|
| 330 |
locations_visited=locations_visited,
|
| 331 |
game_completed=("GAME OVER" in obs),
|
| 332 |
history=history,
|
|
|
|
| 392 |
|
| 393 |
def _parse_valid_actions(self, valid_txt: str) -> List[str]:
|
| 394 |
acts: List[str] = []
|
| 395 |
+
for line in (valid_txt or "").splitlines():
|
| 396 |
line = line.strip()
|
| 397 |
if not line:
|
| 398 |
continue
|
|
|
|
| 445 |
low = (obs or "").lower()
|
| 446 |
if "pitch black" in low or "grue" in low or "too dark" in low:
|
| 447 |
# try to light if possible
|
| 448 |
+
for cand in ("light lamp", "turn on lamp", "light lantern", "turn on lantern", "light torch"):
|
| 449 |
mapped = self._map_to_valid(cand, valid_actions)
|
| 450 |
if mapped and mapped in valid_actions:
|
| 451 |
return mapped
|
|
|
|
| 522 |
# ------------------------------------------------------
|
| 523 |
def _normalize_obs(self, text: str) -> str:
|
| 524 |
t = (text or "").lower().strip()
|
| 525 |
+
# remove server footer if present
|
| 526 |
+
t = re.sub(r"\[score:\s*\d+\s*\|\s*moves:\s*\d+\]\s*$", "", t, flags=re.I).strip()
|
| 527 |
t = re.sub(r"\s+", " ", t).strip()
|
| 528 |
return t
|
| 529 |
|
|
|
|
| 553 |
return i
|
| 554 |
|
| 555 |
# ------------------------------------------------------
|
| 556 |
+
# SCORE / MOVES / LOCATION HELPERS
|
| 557 |
# ------------------------------------------------------
|
| 558 |
+
def _extract_score_moves(self, text: str) -> tuple[int, int]:
|
| 559 |
+
"""
|
| 560 |
+
Prefer the MCP server footer: [Score: X | Moves: Y]
|
| 561 |
+
Returns (score, moves). Moves=-1 if unavailable.
|
| 562 |
+
"""
|
| 563 |
+
if not text:
|
| 564 |
+
return (self.score, -1)
|
| 565 |
+
|
| 566 |
+
m = re.search(r"\[Score:\s*(\d+)\s*\|\s*Moves:\s*(\d+)\]", text)
|
| 567 |
if m:
|
| 568 |
+
return (int(m.group(1)), int(m.group(2)))
|
| 569 |
+
|
| 570 |
+
# fallback
|
| 571 |
+
m2 = re.search(r"Score:\s*(\d+)", text)
|
| 572 |
+
score = int(m2.group(1)) if m2 else self.score
|
| 573 |
+
return (score, -1)
|
| 574 |
|
| 575 |
def _best_effort_location(self, text: str) -> str:
|
| 576 |
lines = [ln.strip() for ln in (text or "").splitlines() if ln.strip()]
|