bouhss commited on
Commit
0c8ceb9
·
verified ·
1 Parent(s): 615a63b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +508 -227
agent.py CHANGED
@@ -1,33 +1,23 @@
1
  """
2
- Student Agent for Text Adventure Games
3
-
4
- This is your submission file. Implement the StudentAgent class to play
5
- text adventure games using the MCP server you also implement.
6
-
7
- Your agent should:
8
- 1. Connect to the MCP server via the provided client
9
- 2. Use the ReAct pattern (Thought -> Action -> Observation)
10
- 3. Call MCP tools to interact with the game
11
- 4. Maximize the game score within the step limit
12
-
13
- Required method:
14
- async def run(self, client, game, max_steps, seed, verbose) -> RunResult
15
-
16
- The 'client' is a FastMCP Client already connected to your MCP server.
17
- Use it to call tools like: await client.call_tool("play_action", {"action": "look"})
18
-
19
- Tips:
20
- - Start by looking around and understanding your environment
21
- - Keep track of visited locations to avoid loops
22
- - Pick up useful items (lamp, sword, etc.)
23
- - The seed parameter should be used to set your LLM's seed for reproducibility
24
  """
25
 
26
  import json
27
  import os
28
  import re
29
  from dataclasses import dataclass, field
30
- from typing import Optional
 
31
 
32
  from dotenv import load_dotenv
33
  from huggingface_hub import InferenceClient
@@ -35,80 +25,29 @@ from huggingface_hub import InferenceClient
35
  # Load environment variables
36
  load_dotenv()
37
 
38
- # Set USE_LOCAL_MODEL=1 in your .env to use a locally downloaded model
39
- USE_LOCAL_MODEL = os.getenv("USE_LOCAL_MODEL", "0").strip() in ("1", "true", "yes")
40
- LOCAL_MODEL_ID = os.getenv("LOCAL_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")
41
-
42
  # =============================================================================
43
  # LLM Configuration - DO NOT MODIFY
44
  # =============================================================================
45
-
46
- # Model to use (fixed for fair evaluation)
47
  LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
48
 
49
- # Initialize the LLM client based on mode
50
- _local_pipeline = None
 
 
51
 
52
- if USE_LOCAL_MODEL:
53
- import torch
54
- from transformers import pipeline as _hf_pipeline
55
 
56
- _local_pipeline = _hf_pipeline(
57
- "text-generation",
58
- model=LOCAL_MODEL_ID,
59
- torch_dtype=torch.bfloat16,
60
- device_map="auto",
61
- )
62
- LLM_CLIENT = None
63
- else:
64
- _hf_token = os.getenv("HF_TOKEN")
65
- if not _hf_token:
66
- raise ValueError("HF_TOKEN not found. Set it in your .env file.")
67
- LLM_CLIENT = InferenceClient(token=_hf_token)
68
-
69
-
70
- def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
71
- """
72
- Call the LLM with the given prompt. Use this function in your agent.
73
-
74
- Args:
75
- prompt: The user prompt (current game state, history, etc.)
76
- system_prompt: The system prompt (instructions for the agent)
77
- seed: Random seed for reproducibility
78
- max_tokens: Maximum tokens in response (default: 300)
79
-
80
- Returns:
81
- The LLM's response text
82
-
83
- Example:
84
- response = call_llm(
85
- prompt="You are in a forest. What do you do?",
86
- system_prompt=SYSTEM_PROMPT,
87
- seed=42,
88
- )
89
- """
90
  messages = [
91
  {"role": "system", "content": system_prompt},
92
  {"role": "user", "content": prompt},
93
  ]
94
-
95
- if USE_LOCAL_MODEL and _local_pipeline is not None:
96
- outputs = _local_pipeline(
97
- messages,
98
- max_new_tokens=max_tokens,
99
- temperature=0.0001, # Near-deterministic (0.0 unsupported by some backends)
100
- do_sample=True,
101
- )
102
- return outputs[0]["generated_text"][-1]["content"]
103
-
104
  response = LLM_CLIENT.chat.completions.create(
105
  model=LLM_MODEL,
106
  messages=messages,
107
- temperature=0.0, # Deterministic for reproducibility
108
  max_tokens=max_tokens,
109
  seed=seed,
110
  )
111
-
112
  return response.choices[0].message.content
113
 
114
 
@@ -125,181 +64,523 @@ class RunResult:
125
 
126
 
127
  # =============================================================================
128
- # System Prompt - Customize this for your agent
129
  # =============================================================================
130
 
131
- SYSTEM_PROMPT = """You are playing a classic text adventure game.
132
 
133
- GOAL: Explore the world, solve puzzles, and maximize your score.
134
 
135
- AVAILABLE TOOLS (use via MCP):
136
- - play_action: Execute a game command (north, take lamp, open mailbox, etc.)
137
- - memory: Get current game state and history (if implemented)
138
- - inventory: Check what you're carrying (if implemented)
139
 
140
- VALID GAME COMMANDS for play_action:
141
- - Movement: north, south, east, west, up, down, enter, exit
142
- - Objects: take <item>, drop <item>, open <thing>, close <thing>, examine <thing>
143
- - Other: look, inventory, read <thing>, turn on lamp
 
 
144
 
145
- RESPOND IN THIS EXACT FORMAT (no markdown):
146
- THOUGHT: <your reasoning about what to do next>
147
- TOOL: <tool_name>
148
- ARGS: <JSON arguments, e.g., {"action": "look"}>
149
 
150
- Example:
151
- THOUGHT: I should look around to see where I am.
152
- TOOL: play_action
153
- ARGS: {"action": "look"}
154
- """
155
 
 
 
 
156
 
157
- # =============================================================================
158
- # Student Agent - IMPLEMENT THIS CLASS
159
- # =============================================================================
160
 
161
  class StudentAgent:
162
- """
163
- Your ReAct agent implementation.
164
-
165
- TODO:
166
- 1. Implement the run() method with the ReAct loop
167
- 2. Parse LLM responses to extract tool calls
168
- 3. Track state and avoid loops
169
-
170
- Use the provided call_llm() function to interact with the LLM.
171
- """
172
-
173
  def __init__(self):
174
- """Initialize your agent here."""
175
- # TODO: Initialize any state tracking you need
176
- # self.history = []
177
- # self.visited_locations = set()
178
- pass
179
-
180
- async def run(
181
- self,
182
- client, # FastMCP Client connected to your MCP server
183
- game: str,
184
- max_steps: int,
185
- seed: int,
186
- verbose: bool = False,
187
- ) -> RunResult:
188
- """
189
- Run the agent for a game session.
190
-
191
- Args:
192
- client: FastMCP Client connected to your MCP server
193
- game: Name of the game being played (e.g., "zork1")
194
- max_steps: Maximum number of steps to take
195
- seed: Random seed for reproducibility (use for LLM calls)
196
- verbose: Whether to print detailed output
197
-
198
- Returns:
199
- RunResult with final score and statistics
200
- """
201
- # TODO: Implement your ReAct loop here
202
- #
203
- # Basic structure:
204
- # 1. Get initial observation (call play_action with "look")
205
- # 2. Loop for max_steps:
206
- # a. Build prompt with current observation and history
207
- # b. Call LLM to get thought and action
208
- # c. Parse the response to extract tool and args
209
- # d. Call the tool via client.call_tool(tool_name, args)
210
- # e. Update history and state
211
- # f. Check for game over
212
- # 3. Return RunResult with final statistics
213
-
214
- # Example of calling a tool:
215
- # result = await client.call_tool("play_action", {"action": "look"})
216
- # observation = result[0].text if result else "No response"
217
-
218
- # Example of calling the LLM:
219
- # response = call_llm(
220
- # prompt="Current observation: " + observation,
221
- # system_prompt=SYSTEM_PROMPT,
222
- # seed=seed,
223
- # )
224
-
225
- # Placeholder implementation - replace with your code
226
- locations_visited = set()
227
- history = []
228
- final_score = 0
229
- moves = 0
230
-
231
- # TODO: Your implementation here
232
- # ...
233
-
234
- return RunResult(
235
- final_score=final_score,
236
- max_score=350, # Zork1 max score, adjust if needed
237
- moves=moves,
238
- locations_visited=locations_visited,
239
- game_completed=False,
240
- history=history,
241
- )
242
-
243
- def _build_prompt(self, observation: str, history: list) -> str:
244
- """
245
- Build the prompt for the LLM.
246
-
247
- TODO: Implement this to create effective prompts
248
- """
249
- # TODO: Combine system prompt, history, and current observation
250
- pass
251
-
252
- def _parse_response(self, response: str) -> tuple[str, str, dict]:
253
- """
254
- Parse LLM response to extract thought, tool name, and arguments.
255
-
256
- TODO: Implement robust parsing
257
-
258
- Returns:
259
- Tuple of (thought, tool_name, args_dict)
260
- """
261
- # TODO: Parse the response format:
262
- # THOUGHT: ...
263
- # TOOL: ...
264
- # ARGS: {...}
265
- pass
266
-
267
- def _call_llm(self, prompt: str, system_prompt: str, seed: int) -> str:
268
- """
269
- Call the LLM with the given prompt.
270
-
271
- This is a convenience wrapper - you can also use call_llm() directly.
272
- """
273
- return call_llm(prompt, system_prompt, seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
 
276
  # =============================================================================
277
- # For local testing
278
  # =============================================================================
279
-
280
  async def test_agent():
281
- """Test the agent locally."""
282
  from fastmcp import Client
283
-
284
- # Path to your MCP server
285
  server_path = "mcp_server.py"
286
-
287
  agent = StudentAgent()
288
-
289
  async with Client(server_path) as client:
290
  result = await agent.run(
291
  client=client,
292
- game="zork1",
293
- max_steps=10,
294
  seed=42,
295
  verbose=True,
296
  )
297
-
298
- print(f"\nFinal Score: {result.final_score}")
299
  print(f"Moves: {result.moves}")
300
- print(f"Locations: {result.locations_visited}")
301
 
302
 
303
  if __name__ == "__main__":
304
  import asyncio
305
- asyncio.run(test_agent())
 
1
  """
2
+ Student Agent for Text Adventure Games (Strong submission)
3
+
4
+ Key ideas:
5
+ - Deterministic & robust
6
+ - Uses MCP tools if available:
7
+ - get_valid_actions: reduce invalid commands
8
+ - peek_action: simulate actions without committing (safe look-ahead)
9
+ - inventory / memory / get_map: optional extra context
10
+ - Exploration + score oriented:
11
+ utility = score_gain * big_weight + new_location_bonus - loop_penalty - stuck_penalty - death_penalty
12
+ - LLM is used only as fallback, to choose among a candidate list.
 
 
 
 
 
 
 
 
 
 
 
13
  """
14
 
15
  import json
16
  import os
17
  import re
18
  from dataclasses import dataclass, field
19
+ from typing import Optional, Any
20
+ from collections import defaultdict, deque
21
 
22
  from dotenv import load_dotenv
23
  from huggingface_hub import InferenceClient
 
25
  # Load environment variables
26
  load_dotenv()
27
 
 
 
 
 
28
  # =============================================================================
29
  # LLM Configuration - DO NOT MODIFY
30
  # =============================================================================
 
 
31
  LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
32
 
33
+ _hf_token = os.getenv("HF_TOKEN")
34
+ if not _hf_token:
35
+ raise ValueError("HF_TOKEN not found. Set it in your .env file.")
36
+ LLM_CLIENT = InferenceClient(token=_hf_token)
37
 
 
 
 
38
 
39
def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 220) -> str:
    """Query the configured chat model deterministically.

    Args:
        prompt: User-turn content (game state, candidates, history).
        system_prompt: System-turn instructions for the agent.
        seed: Seed forwarded to the backend for reproducibility.
        max_tokens: Cap on the length of the reply.

    Returns:
        The raw text of the model's reply.
    """
    chat_turns = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    completion = LLM_CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=chat_turns,
        temperature=0.0,  # deterministic decoding
        max_tokens=max_tokens,
        seed=seed,
    )
    return completion.choices[0].message.content
52
 
53
 
 
64
 
65
 
66
  # =============================================================================
67
+ # LLM Prompt (fallback only)
68
  # =============================================================================
69
 
70
+ SYSTEM_PROMPT = """You are an expert text-adventure agent.
71
 
72
+ Goal: maximize score and explore new locations while avoiding loops.
73
 
74
+ You MUST output EXACTLY:
75
+ THOUGHT: ...
76
+ TOOL: play_action
77
+ ARGS: {"action": "<one candidate action>"}
78
 
79
+ Rules:
80
+ - Choose one action EXACTLY from the candidate list provided by the user.
81
+ - Avoid repeating the same action if it failed before.
82
+ - If darkness is mentioned, prioritize lamp actions if present in candidates.
83
+ - No markdown, no extra text.
84
+ """
85
 
 
 
 
 
86
 
87
+ MOVE_ACTIONS = ["north", "south", "east", "west", "up", "down", "enter", "exit"]
88
+ MOVE_ALIASES = {"n": "north", "s": "south", "e": "east", "w": "west", "u": "up", "d": "down"}
 
 
 
89
 
90
+ # avoid wasting steps on meta commands
91
+ BAD_PREFIXES = ("save", "restore", "quit", "restart", "help", "verbose", "script", "unscript", "version")
92
+ BAD_EXACT = {"wait", "z"}
93
 
 
 
 
94
 
95
  class StudentAgent:
 
 
 
 
 
 
 
 
 
 
 
96
  def __init__(self):
97
+ # parsed from banner
98
+ self.score = 0
99
+ self.max_score = 0
100
+ self.moves = 0
101
+
102
+ # exploration tracking
103
+ self.locations_visited: set[str] = set()
104
+ self.last_location = "Unknown"
105
+ self.edges = defaultdict(dict) # edges[loc][move] = new_loc
106
+
107
+ # loop avoidance
108
+ self.tried = defaultdict(int) # tried[(loc, action)] += 1
109
+ self.recent_actions = deque(maxlen=10)
110
+ self.recent_obs = deque(maxlen=6)
111
+
112
+ # cached valid actions by location
113
+ self.valid_actions_cache = {} # loc -> list[str]
114
+
115
+ # ---------------------------------------------------------------------
116
+ # Main run loop
117
+ # ---------------------------------------------------------------------
118
    async def run(self, client, game: str, max_steps: int, seed: int, verbose: bool = False) -> RunResult:
        """Play up to ``max_steps`` committed actions and return final statistics.

        Args:
            client: FastMCP client connected to the game's MCP server.
            game: Game name (not used directly here; the server owns game choice).
            max_steps: Hard cap on committed actions.
            seed: Base seed for LLM fallback calls (offset by step number).
            verbose: Print a per-step trace when True.

        Returns:
            RunResult with score/moves/locations; on any exception a partial
            RunResult carrying the error string is returned instead of raising.
        """
        history: list[tuple[str, str, str]] = []

        try:
            # Discover which optional tools this server actually exposes.
            tools = await client.list_tools()
            tool_names = {t.name for t in tools}

            def has(tname: str) -> bool:
                return tname in tool_names

            # initial observation
            obs = await self._call_tool_text(client, "play_action", {"action": "look"})
            self._update_from_text(obs)
            self.last_location = self._extract_location(obs)
            self.locations_visited.add(self.last_location)

            if verbose:
                print(obs)

            for step in range(1, max_steps + 1):
                loc = self._extract_location(obs)
                self.last_location = loc
                self.locations_visited.add(loc)

                stuck = self._is_stuck(obs)

                # Refresh valid actions when stuck, in a new location, or
                # periodically (every 6 steps) — avoids a tool call per step.
                valid_actions = self.valid_actions_cache.get(loc, [])
                if has("get_valid_actions") and (stuck or not valid_actions or step % 6 == 0):
                    va_txt = await self._call_tool_text(client, "get_valid_actions", {"limit": 60})
                    valid_actions = self._parse_valid_actions(va_txt)
                    if valid_actions:
                        self.valid_actions_cache[loc] = valid_actions

                # Optional inventory snapshot (cheap context, fetched sparsely).
                inv_txt = ""
                if has("inventory") and (stuck or step % 8 == 0 or step == 1):
                    inv_txt = await self._call_tool_text(client, "inventory", {})

                # build candidates
                candidates = self._make_candidates(obs, inv_txt, valid_actions, loc)

                # Decide: prefer safe look-ahead via peek_action; otherwise
                # heuristics with an LLM fallback.
                action = None
                thought = ""

                if has("peek_action") and candidates:
                    action, thought = await self._choose_by_lookahead(
                        client=client,
                        loc=loc,
                        obs=obs,
                        candidates=candidates,
                        seed=seed,
                        step=step,
                        verbose=verbose,
                    )

                if not action:
                    action, thought = await self._choose_without_peek(
                        obs=obs,
                        inv_txt=inv_txt,
                        candidates=candidates,
                        seed=seed,
                        step=step,
                    )

                action = self._normalize_action(action or "look")

                # commit the action
                obs2 = await self._call_tool_text(client, "play_action", {"action": action})

                # Record a map edge when a movement action changed location.
                new_loc = self._extract_location(obs2)
                if action.lower() in MOVE_ACTIONS and new_loc and new_loc != loc:
                    self.edges[loc][action.lower()] = new_loc

                # Bookkeeping for loop avoidance and score tracking.
                self.tried[(loc, action.lower())] += 1
                self.recent_actions.append(action.lower())
                self.recent_obs.append((obs2 or "")[:220])
                self._update_from_text(obs2)

                # History entries are truncated to keep RunResult small.
                history.append((thought, f"play_action({action})", (obs2 or "")[:250]))

                if verbose:
                    print(f"\n--- step {step} ---")
                    print(f"THOUGHT: {thought}")
                    print(f"ACTION: {action}")
                    print(obs2)

                obs = obs2

                if self._is_game_over(obs):
                    break

            return RunResult(
                final_score=self.score,
                max_score=self.max_score,
                moves=self.moves,
                locations_visited=set(self.locations_visited),
                game_completed=self._is_game_over(obs),
                history=history,
            )

        except Exception as e:
            # Never crash the evaluator: surface partial progress plus the error.
            return RunResult(
                final_score=self.score,
                max_score=self.max_score,
                moves=self.moves,
                locations_visited=set(self.locations_visited),
                game_completed=False,
                error=f"{type(e).__name__}: {e}",
                history=history,
            )
232
+
233
+ # ---------------------------------------------------------------------
234
+ # Tool / text helpers
235
+ # ---------------------------------------------------------------------
236
+ async def _call_tool_text(self, client, tool: str, args: dict) -> str:
237
+ result = await client.call_tool(tool, args)
238
+ return self._extract_text(result)
239
+
240
+ def _extract_text(self, result: Any) -> str:
241
+ if result is None:
242
+ return ""
243
+ if isinstance(result, list) and result:
244
+ part = result[0]
245
+ if hasattr(part, "text"):
246
+ return part.text or ""
247
+ if isinstance(part, dict) and "text" in part:
248
+ return part["text"] or ""
249
+ return str(part)
250
+ return str(result)
251
+
252
+ def _extract_location(self, text: str) -> str:
253
+ if not text:
254
+ return "Unknown"
255
+ for line in text.splitlines():
256
+ line = line.strip()
257
+ if not line:
258
+ continue
259
+ if line.startswith("[Score:"):
260
+ continue
261
+ return line
262
+ return "Unknown"
263
+
264
+ def _update_from_text(self, text: str) -> None:
265
+ # parse banner: [Score: x/y | Moves: n]
266
+ if not text:
267
+ return
268
+ m = re.search(r"\[Score:\s*(\d+)\s*/\s*(\d+)\s*\|\s*Moves:\s*(\d+)\s*\]", text)
269
+ if m:
270
+ self.score = int(m.group(1))
271
+ self.max_score = int(m.group(2))
272
+ self.moves = int(m.group(3))
273
+
274
+ def _parse_valid_actions(self, txt: str) -> list[str]:
275
+ if not txt:
276
+ return []
277
+ actions = []
278
+ for line in txt.splitlines():
279
+ line = line.strip()
280
+ if line.startswith("- "):
281
+ a = line[2:].strip()
282
+ a = self._normalize_action(a)
283
+ low = a.lower()
284
+ if not a:
285
+ continue
286
+ if low.startswith(BAD_PREFIXES) or low in BAD_EXACT:
287
+ continue
288
+ actions.append(a)
289
+ # dedup keep order
290
+ seen = set()
291
+ out = []
292
+ for a in actions:
293
+ if a.lower() not in seen:
294
+ seen.add(a.lower())
295
+ out.append(a)
296
+ return out
297
+
298
+ def _normalize_action(self, action: str) -> str:
299
+ a = (action or "").strip()
300
+ low = a.lower()
301
+ if low in MOVE_ALIASES:
302
+ return MOVE_ALIASES[low]
303
+ return a
304
+
305
+ def _is_game_over(self, text: str) -> bool:
306
+ t = (text or "").lower()
307
+ return ("game over" in t) or ("you have died" in t) or ("you are dead" in t)
308
+
309
+ def _is_stuck(self, text: str) -> bool:
310
+ t = (text or "").lower()
311
+ bad = [
312
+ "i don't understand",
313
+ "you can't go that way",
314
+ "that's not a verb",
315
+ "not a word i know",
316
+ "nothing happens",
317
+ "you can't",
318
+ "can't do that",
319
+ ]
320
+ rep = len(self.recent_obs) >= 3 and all(self.recent_obs[-1] == x for x in list(self.recent_obs)[-3:])
321
+ return any(b in t for b in bad) or rep
322
+
323
+ # ---------------------------------------------------------------------
324
+ # Candidate generation
325
+ # ---------------------------------------------------------------------
326
+ def _make_candidates(self, obs: str, inv_txt: str, valid_actions: list[str], loc: str) -> list[str]:
327
+ obs_l = (obs or "").lower()
328
+ inv_l = (inv_txt or "").lower()
329
+
330
+ candidates = []
331
+ seen = set()
332
+
333
+ def add(a: str):
334
+ a = self._normalize_action(a)
335
+ if not a:
336
+ return
337
+ low = a.lower()
338
+ if low.startswith(BAD_PREFIXES) or low in BAD_EXACT:
339
+ return
340
+ if low not in seen:
341
+ seen.add(low)
342
+ candidates.append(a)
343
+
344
+ # always safe
345
+ add("look")
346
+
347
+ # darkness handling
348
+ if "dark" in obs_l:
349
+ if "lamp" in obs_l or "lamp" in inv_l:
350
+ add("take lamp")
351
+ add("turn on lamp")
352
+
353
+ # split valid actions into move vs object
354
+ move_list = []
355
+ obj_list = []
356
+ for a in valid_actions or []:
357
+ low = a.lower()
358
+ if low in MOVE_ACTIONS:
359
+ move_list.append(a)
360
+ else:
361
+ obj_list.append(a)
362
+
363
+ # prioritize untried moves from this location
364
+ def move_key(m: str):
365
+ return (self.tried[(loc, m.lower())], 0 if m.lower() not in self.edges.get(loc, {}) else 1)
366
+
367
+ for m in sorted(set(move_list), key=move_key):
368
+ add(m)
369
+
370
+ # if no valid moves known, still try generic moves
371
+ if not move_list:
372
+ for m in MOVE_ACTIONS:
373
+ add(m)
374
+
375
+ # prioritize object actions that often give score
376
+ scorey_prefixes = ("take ", "get ", "open ", "read ", "examine ", "look at ", "turn on ", "unlock ", "insert ")
377
+ for a in obj_list:
378
+ if a.lower().startswith(scorey_prefixes):
379
+ add(a)
380
+
381
+ # then the rest (limited)
382
+ for a in obj_list:
383
+ add(a)
384
+ if len(candidates) >= 22:
385
+ break
386
+
387
+ # small generic probes (often good across games)
388
+ add("take all")
389
+ add("inventory")
390
+
391
+ # remove actions repeated too much recently
392
+ cleaned = []
393
+ for a in candidates:
394
+ if list(self.recent_actions).count(a.lower()) >= 3:
395
+ continue
396
+ cleaned.append(a)
397
+
398
+ return cleaned[:20]
399
+
400
+ # ---------------------------------------------------------------------
401
+ # Decision: look-ahead
402
+ # ---------------------------------------------------------------------
403
    async def _choose_by_lookahead(self, client, loc: str, obs: str, candidates: list[str], seed: int, step: int, verbose: bool):
        """Rank candidates by simulating each with ``peek_action`` and scoring a utility.

        utility = 900 * score_gain + 250 (brand-new location) + 40 (moved at all)
                  - 80 * recent repeats - 160 (peek looks stuck),
        with death forced to a huge negative so it can never win. A small
        bonus favors lamp actions in darkness.

        Returns:
            (action, thought) on success, or (None, reason) when no candidate
            scores acceptably — caller then falls back to heuristics/LLM.
        """
        base_score = self.score
        base_loc = loc

        # Prioritize a shortlist for speed: least-tried first, then object
        # interactions, then moves; lowercase form breaks ties.
        priority = []
        for a in candidates:
            low = a.lower()
            is_move = low in MOVE_ACTIONS
            is_obj = low.startswith(("take ", "get ", "open ", "read ", "examine ", "turn on ", "unlock "))
            tried = self.tried[(loc, low)]
            priority.append((tried, 0 if is_obj else 1, 0 if is_move else 1, low, a))
        priority.sort()

        shortlist = [x[-1] for x in priority][:10]  # evaluate at most 10 peeks

        best_a = None
        best_u = -10**18
        best_th = ""

        for a in shortlist:
            low = a.lower()
            if self.tried[(loc, low)] >= 4:
                continue  # hammered too often at this location; skip entirely

            # Simulate without committing a real move.
            peek = await self._call_tool_text(client, "peek_action", {"action": a})
            peek_l = (peek or "").lower()

            if self._is_game_over(peek) or "you have died" in peek_l:
                u = -1_000_000_000  # never walk into death
            else:
                s_after, mx_after, mv_after = self._parse_banner(peek, fallback_score=base_score)
                delta = max(0, s_after - base_score)

                new_loc = self._extract_location(peek)
                changed = (new_loc and new_loc != base_loc)
                new_loc_bonus = 250 if (changed and new_loc not in self.locations_visited) else 0
                changed_bonus = 40 if changed else 0

                loop_pen = 80 * list(self.recent_actions).count(low)
                stuck_pen = 160 if self._is_stuck(peek) else 0

                # MAIN utility: score gain dominates everything else.
                u = delta * 900 + new_loc_bonus + changed_bonus - loop_pen - stuck_pen

                # small preference: in darkness, favor lamp actions
                if "dark" in (obs or "").lower() and ("lamp" in low):
                    u += 120

            if u > best_u:
                best_u = u
                best_a = a
                best_th = f"Look-ahead chose '{a}' (utility={u})."

        # Reject outcomes that are all strongly negative (e.g. only death).
        if best_a is None or best_u < -10000:
            return None, "Look-ahead found no good action; fallback."
        return best_a, best_th
460
+
461
+ def _parse_banner(self, text: str, fallback_score: int):
462
+ score = fallback_score
463
+ mx = self.max_score
464
+ mv = self.moves
465
+ if not text:
466
+ return score, mx, mv
467
+ m = re.search(r"\[Score:\s*(\d+)\s*/\s*(\d+)\s*\|\s*Moves:\s*(\d+)\s*\]", text)
468
+ if m:
469
+ return int(m.group(1)), int(m.group(2)), int(m.group(3))
470
+ return score, mx, mv
471
+
472
+ # ---------------------------------------------------------------------
473
+ # Decision: no peek => heuristic then LLM fallback among candidates
474
+ # ---------------------------------------------------------------------
475
+ async def _choose_without_peek(self, obs: str, inv_txt: str, candidates: list[str], seed: int, step: int):
476
+ loc = self._extract_location(obs)
477
+
478
+ # heuristic: try an untried move
479
+ for m in MOVE_ACTIONS:
480
+ if m in [c.lower() for c in candidates] and self.tried[(loc, m)] == 0:
481
+ return m, "Heuristic: try an untried move to explore."
482
+
483
+ # heuristic: try untried "take/get/open/read/examine"
484
+ for a in candidates:
485
+ low = a.lower()
486
+ if low.startswith(("take ", "get ", "open ", "read ", "examine ", "turn on ")):
487
+ if self.tried[(loc, low)] == 0:
488
+ return a, "Heuristic: try a promising object interaction."
489
+
490
+ # LLM fallback: choose from candidate list exactly
491
+ if not candidates:
492
+ return "look", "No candidates; fallback to look."
493
+
494
+ cand = candidates[:10]
495
+ prompt = self._build_llm_prompt(obs, inv_txt, cand)
496
+ resp = call_llm(prompt, SYSTEM_PROMPT, seed + step, max_tokens=180)
497
+
498
+ thought, tool, args = self._parse_response(resp)
499
+ a = self._normalize_action(str(args.get("action", "")).strip())
500
+
501
+ # force action to be in candidate list
502
+ canon = {x.lower(): x for x in cand}
503
+ if a.lower() in canon:
504
+ return canon[a.lower()], thought or "LLM chose a candidate."
505
+ return cand[0], "LLM invalid; fallback to first candidate."
506
+
507
+ def _build_llm_prompt(self, obs: str, inv_txt: str, candidates: list[str]) -> str:
508
+ obs = (obs or "").strip()[:1100]
509
+ inv_txt = (inv_txt or "").strip()[:350]
510
+
511
+ lines = [
512
+ f"Score: {self.score}/{self.max_score} | Moves: {self.moves}",
513
+ f"Location guess: {self.last_location}",
514
+ ]
515
+ if inv_txt:
516
+ lines.append(f"Inventory:\n{inv_txt}")
517
+ if self.recent_actions:
518
+ lines.append("Recent actions: " + ", ".join(list(self.recent_actions)[-6:]))
519
+
520
+ lines.append("\nCurrent observation:\n" + obs)
521
+ lines.append("\nCandidate actions (choose exactly one):")
522
+ for a in candidates:
523
+ lines.append(f"- {a}")
524
+ lines.append("\nOutput TOOL=play_action and ARGS with one candidate action.")
525
+ return "\n".join(lines)
526
+
527
+ def _parse_response(self, response: str):
528
+ thought = ""
529
+ tool = "play_action"
530
+ args = {"action": "look"}
531
+
532
+ if not response:
533
+ return thought, tool, args
534
+
535
+ m = re.search(r"(?im)^\s*THOUGHT\s*:\s*(.+)$", response)
536
+ if m:
537
+ thought = m.group(1).strip()
538
+
539
+ m = re.search(r"(?im)^\s*TOOL\s*:\s*([a-zA-Z0-9_]+)\s*$", response)
540
+ if m:
541
+ tool = m.group(1).strip()
542
+
543
+ m = re.search(r"(?is)^\s*ARGS\s*:\s*(\{.*\})\s*$", response)
544
+ if m:
545
+ raw = m.group(1).strip()
546
+ try:
547
+ args = json.loads(raw)
548
+ except Exception:
549
+ raw2 = raw.replace("'", '"')
550
+ raw2 = re.sub(r",\s*}", "}", raw2)
551
+ try:
552
+ args = json.loads(raw2)
553
+ except Exception:
554
+ args = {"action": "look"}
555
+
556
+ if not isinstance(args, dict):
557
+ args = {"action": "look"}
558
+
559
+ return thought, tool, args
560
 
561
 
562
  # =============================================================================
563
+ # Local testing
564
  # =============================================================================
 
565
async def test_agent():
    """Smoke-test the agent against a local MCP server (manual run only)."""
    from fastmcp import Client

    agent = StudentAgent()
    server_path = "mcp_server.py"

    async with Client(server_path) as client:
        result = await agent.run(
            client=client,
            game="lostpig",
            max_steps=20,
            seed=42,
            verbose=True,
        )
        print(f"\nFinal Score: {result.final_score}/{result.max_score}")
        print(f"Moves: {result.moves}")
        print(f"Locations visited: {len(result.locations_visited)}")
582
 
583
 
584
  if __name__ == "__main__":
585
  import asyncio
586
+ asyncio.run(test_agent())