InesManelB committed on
Commit b59a6e3 · Parent(s): 615a63b

My submission agent using just-in-time RL for action selection

Files changed (4):
  1. agent.py +729 -135
  2. cross_episode_memory.py +442 -0
  3. mcp_server.py +81 -31
  4. requirements.txt +6 -0
agent.py CHANGED
@@ -26,12 +26,16 @@ Tips:
 import json
 import os
 import re
+import random
+import time
 from dataclasses import dataclass, field
 from typing import Optional
 
 from dotenv import load_dotenv
 from huggingface_hub import InferenceClient
 
+from cross_episode_memory import CrossEpisodeMemory
+
 # Load environment variables
 load_dotenv()
 
@@ -67,7 +71,7 @@ else:
 LLM_CLIENT = InferenceClient(token=_hf_token)
 
 
-def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
+def call_llm(prompt: str, system_prompt: str, seed: int = 42, temperature: float = 0.0001, max_tokens: int = 1000, response_format: dict = None) -> str:
@@ -79,35 +83,32 @@ def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300)
     """
     Call the LLM with the given prompt. Use this function in your agent.
 
     Returns:
         The LLM's response text
-
-    Example:
-        response = call_llm(
-            prompt="You are in a forest. What do you do?",
-            system_prompt=SYSTEM_PROMPT,
-            seed=42,
-        )
     """
     messages = [
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": prompt},
     ]
-
-    if USE_LOCAL_MODEL and _local_pipeline is not None:
-        outputs = _local_pipeline(
-            messages,
-            max_new_tokens=max_tokens,
-            temperature=0.0001,  # Near-deterministic (0.0 unsupported by some backends)
-            do_sample=True,
-        )
-        return outputs[0]["generated_text"][-1]["content"]
-
-    response = LLM_CLIENT.chat.completions.create(
-        model=LLM_MODEL,
-        messages=messages,
-        temperature=0.0,  # Deterministic for reproducibility
-        max_tokens=max_tokens,
-        seed=seed,
-    )
+    try:
+        if USE_LOCAL_MODEL and _local_pipeline is not None:
+            outputs = _local_pipeline(
+                messages,
+                max_new_tokens=max_tokens,
+                temperature=temperature,  # Near-deterministic (0.0 unsupported by some backends)
+                do_sample=True,
+            )
+            return outputs[0]["generated_text"][-1]["content"]
+
+        response = LLM_CLIENT.chat.completions.create(
+            model=LLM_MODEL,
+            messages=messages,
+            temperature=temperature,  # Deterministic for reproducibility
+            max_tokens=max_tokens,
+            seed=seed,
+            response_format=response_format,  # optional JSON-schema constraint; run() passes this
+        )
+    except Exception as e:
+        print(f"[LLM Error] {e}")
+        return "THOUGHT: LLM error, trying look."
 
     return response.choices[0].message.content
 
@@ -124,153 +125,746 @@ class RunResult:
     history: list[tuple[str, str, str]] = field(default_factory=list)
 
 
-# =============================================================================
-# System Prompt - Customize this for your agent
-# =============================================================================
-
-SYSTEM_PROMPT = """You are playing a classic text adventure game.
-
-GOAL: Explore the world, solve puzzles, and maximize your score.
-
-AVAILABLE TOOLS (use via MCP):
-- play_action: Execute a game command (north, take lamp, open mailbox, etc.)
-- memory: Get current game state and history (if implemented)
-- inventory: Check what you're carrying (if implemented)
-
-VALID GAME COMMANDS for play_action:
-- Movement: north, south, east, west, up, down, enter, exit
-- Objects: take <item>, drop <item>, open <thing>, close <thing>, examine <thing>
-- Other: look, inventory, read <thing>, turn on lamp
-
-RESPOND IN THIS EXACT FORMAT (no markdown):
-THOUGHT: <your reasoning about what to do next>
-TOOL: <tool_name>
-ARGS: <JSON arguments, e.g., {"action": "look"}>
-
-Example:
-THOUGHT: I should look around to see where I am.
-TOOL: play_action
-ARGS: {"action": "look"}
-"""
-
-
 # =============================================================================
 # Student Agent - IMPLEMENT THIS CLASS
 # =============================================================================
 
 class StudentAgent:
     """
-    Your ReAct agent implementation.
-
-    TODO:
-    1. Implement the run() method with the ReAct loop
-    2. Parse LLM responses to extract tool calls
-    3. Track state and avoid loops
-
-    Use the provided call_llm() function to interact with the LLM.
+    MCP ReAct Agent - A complete working example.
+
+    This agent demonstrates:
+    - ReAct loop (Thought -> Tool -> Observation)
+    - Loop detection
+    - Action validation
+    - Score tracking via memory tool
     """
 
-    def __init__(self):
-        """Initialize your agent here."""
-        # TODO: Initialize any state tracking you need
-        # self.history = []
-        # self.visited_locations = set()
-        pass
+    def __init__(self, game: str = "zork1", guiding_prompt: str = None, top_actions: int = 5, exploration_alpha: float = 0.4, gamma: float = 0.95):
+        """Initialize the agent state."""
+        self.history: list[dict] = []
+        self.recent_actions: list[str] = []
+        self.score: int = 0
+        self.guiding_prompt = guiding_prompt or """Explore the environment and try to maximize your score."""
+        self.top_actions = top_actions
+
+        # Create game-specific memory directory
+        game_memory_dir = os.path.join("memory", game)
+        self.cross_mem = CrossEpisodeMemory(base_dir=game_memory_dir, eval_llm_model=LLM_MODEL, gamma=gamma)
+        self.exploration_alpha = exploration_alpha
 
     async def run(
         self,
-        client,  # FastMCP Client connected to your MCP server
+        client,
         game: str,
         max_steps: int,
         seed: int,
         verbose: bool = False,
     ) -> RunResult:
-        """
-        Run the agent for a game session.
-
-        Args:
-            client: FastMCP Client connected to your MCP server
-            game: Name of the game being played (e.g., "zork1")
-            max_steps: Maximum number of steps to take
-            seed: Random seed for reproducibility (use for LLM calls)
-            verbose: Whether to print detailed output
-
-        Returns:
-            RunResult with final score and statistics
-        """
-        # TODO: Implement your ReAct loop here
-        #
-        # Basic structure:
-        # 1. Get initial observation (call play_action with "look")
-        # 2. Loop for max_steps:
-        #    a. Build prompt with current observation and history
-        #    b. Call LLM to get thought and action
-        #    c. Parse the response to extract tool and args
-        #    d. Call the tool via client.call_tool(tool_name, args)
-        #    e. Update history and state
-        #    f. Check for game over
-        # 3. Return RunResult with final statistics
-
-        # Example of calling a tool:
-        # result = await client.call_tool("play_action", {"action": "look"})
-        # observation = result[0].text if result else "No response"
-
-        # Example of calling the LLM:
-        # response = call_llm(
-        #     prompt="Current observation: " + observation,
-        #     system_prompt=SYSTEM_PROMPT,
-        #     seed=seed,
-        # )
-
-        # Placeholder implementation - replace with your code
-        locations_visited = set()
-        history = []
-        final_score = 0
-        moves = 0
-
-        # TODO: Your implementation here
-        # ...
+        """Run the agent for a game session."""
+        locations_visited = set()
+        history = []
+        moves = 0
+
+        # Get list of available tools
+        tools = await client.list_tools()
+        tool_names = [t.name for t in tools]
+
+        # Get initial observation
+        result = await client.call_tool("reset_game", {})
+        raw = self._extract_result(result)
+
+        payload = json.loads(raw)
+        state = payload["state"]  # <-- dict
+        observation = state["observation"]
+
+        # Track initial location
+        location = observation.split("\n")[0] if observation else "Unknown"
+        locations_visited.add(location)
+
+        if verbose:
+            print(f"\n{observation}")
+
+        # Parse initial score
+        self._update_score(state)
+
+        # Main ReAct loop
+        for step in range(1, max_steps + 1):
+            # Build prompt with context
+            user_prompt, sys_prompt, memory_text = self._build_prompt(state)
+
+            # Get response format
+            response_format = self._get_response_format()
+
+            # Call LLM for reasoning (use step-based seed for variety)
+            response = call_llm(user_prompt, sys_prompt, seed=seed + step, response_format=response_format)
+
+            # Parse the response
+            action, options_with_confidences = self._parse_response(response)
+
+            # Update action scores with memory
+            updated_options_with_logits = self.update_scores_with_memory(
+                observation,
+                options_with_confidences,
+                k=10,
+                r=0.9,
+                memory_text=memory_text)
+
+            if updated_options_with_logits:
+                # Select the valid action with the highest corrected confidence
+                valid_actions = state.get("valid_actions", []) or []
+                updated_options_valid_items = []
+                for option in updated_options_with_logits.items():
+                    if option[1]['action'] in valid_actions:
+                        updated_options_valid_items.append(option)
+
+                # Keep the LLM's own choice if no corrected option is valid
+                if updated_options_valid_items:
+                    best_option = max(
+                        updated_options_valid_items,
+                        key=lambda x: x[1].get('corrected_confidence', 0)
+                    )
+                    action = best_option[1]['action']
+
+            if verbose:
+                print(f"\n--- Step {step} | Score: {self.score} ---")
+                print(f"[RAW_LLM_OUTPUT]\n{response}")
+                print(f"[ACTION] {action}")
+
+            # Loop detection
+            self.recent_actions.append(action)
+            if len(self.recent_actions) > 5:
+                self.recent_actions = self.recent_actions[-5:]
+
+            # Detect loops - if same action 3 times, force "look"
+            if len(self.recent_actions) >= 3 and len(set(self.recent_actions[-3:])) == 1:
+                if verbose:
+                    print(f"[WARNING] Loop detected - forcing 'look'")
+                action = "look"
+                self.recent_actions.append(action)
+
+            moves += 1
+
+            # Update history
+            self.history.append({
+                "old_state": observation,
+                "action": action,
+                "full_response": response,
+                "reward": None,
+                "score": None,
+            })
+
+            # Execute the tool
+            try:
+                result = await client.call_tool("play_action", {"action": action})
+                raw = self._extract_result(result)
+                payload = json.loads(raw)
+                state = payload["state"]
+                observation = state["observation"]
+
+                if verbose:
+                    print(f"[RESULT] {observation[:200]}...")
+            except Exception as e:
+                observation = f"Error: {e}"
+                if verbose:
+                    print(f"[ERROR] {e}")
+
+            # Track location
+            location = observation.split("\n")[0] if observation else "Unknown"
+            locations_visited.add(location)
+
+            old_score = self.score
+
+            # Track score from observation
+            self._update_score(state)
+
+            self.history[-1]["reward"] = self.score - old_score
+            self.history[-1]["score"] = self.score
+            self.history[-1]["state"] = observation
+
+            # Record in result history
+            history.append((response, f"play_action(action: {action})", observation[:100]))
+
+            # Check for game over
+            if self._is_game_over(observation):
+                if verbose:
+                    print("\n*** GAME OVER ***")
+                break
+
+        success = self._is_game_over(observation) and self.score > 0
+
+        # Add the current episode to memory for future episodes
+        self.cross_mem.add_episode(
+            game_history=self.history,
+            final_score=self.score,
+            success=success,
+            state=observation
+        )
 
         return RunResult(
-            final_score=final_score,
-            max_score=350,  # Zork1 max score, adjust if needed
+            final_score=self.score,
+            max_score=350,
             moves=moves,
             locations_visited=locations_visited,
-            game_completed=False,
+            game_completed=self._is_game_over(observation),
             history=history,
         )
 
-    def _build_prompt(self, observation: str, history: list) -> str:
-        """
-        Build the prompt for the LLM.
-
-        TODO: Implement this to create effective prompts
-        """
-        # TODO: Combine system prompt, history, and current observation
-        pass
-
-    def _parse_response(self, response: str) -> tuple[str, str, dict]:
-        """
-        Parse LLM response to extract thought, tool name, and arguments.
-
-        TODO: Implement robust parsing
-
-        Returns:
-            Tuple of (thought, tool_name, args_dict)
-        """
-        # TODO: Parse the response format:
-        # THOUGHT: ...
-        # TOOL: ...
-        # ARGS: {...}
-        pass
-
-    def _call_llm(self, prompt: str, system_prompt: str, seed: int) -> str:
-        """
-        Call the LLM with the given prompt.
-
-        This is a convenience wrapper - you can also use call_llm() directly.
-        """
-        return call_llm(prompt, system_prompt, seed)
+    def update_scores_with_memory(self, current_state, options_with_confidences, k, r, memory_text, exploration_prob=0.4):
+        nearest_trajectories = self.cross_mem.retrieve_similar(
+            game_history=self.history,
+            current_state=current_state,
+            current_summary=memory_text,
+            k=k,
+            r=r
+        )
+        if not nearest_trajectories:
+            return {}
+
+        # Get existing actions
+        existing_actions = set()
+        for option_data in options_with_confidences.values():
+            if isinstance(option_data, dict) and 'action' in option_data:
+                existing_actions.add(option_data['action'])
+
+        # Aggregate action rewards from nearest_trajectories
+        action_rewards = {}
+        for result_dict in nearest_trajectories:
+            action = result_dict.get('action', '').strip()
+            discounted_reward = result_dict.get('discounted_reward', 0)
+            if action not in action_rewards:
+                action_rewards[action] = []
+            action_rewards[action].append(discounted_reward)
+
+        # Add remembered actions with positive mean reward that the LLM did not suggest
+        for action in action_rewards:
+            found = False
+            for option_data in options_with_confidences.values():
+                if isinstance(option_data, dict) and option_data.get('action', '') == action:
+                    found = True
+                    break
+            if not found:
+                rewards = action_rewards[action]
+                if len(rewards) >= 1 and sum(rewards) / len(rewards) > 0:
+                    if options_with_confidences:
+                        new_option_num = max(options_with_confidences.keys()) + 1
+                    else:
+                        new_option_num = 1
+                    options_with_confidences[new_option_num] = {
+                        'action': action,
+                        'confidence': 0
+                    }
+
+        # Calculate average discounted reward for each action, Q(s, a)
+        action_avg_rewards = {}
+        for action, rewards in action_rewards.items():
+            if rewards:  # Ensure not empty
+                action_avg_rewards[action] = sum(rewards) / len(rewards)
+            else:
+                action_avg_rewards[action] = 0
+
+        # Calculate overall average discounted reward: the baseline ("state value")
+        count = 0
+        if action_rewards:
+            all_rewards = []
+            for rewards_list in action_rewards.values():
+                all_rewards.extend(rewards_list)
+            count = len(all_rewards)
+            overall_avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0
+        else:
+            overall_avg_reward = 0
+
+        # Exploration bonus: complete Q(s, a) for the actions that are not in memory
+        alpha = self.exploration_alpha
+        new_actions_added = False
+        for option_num, option_data in options_with_confidences.items():
+            action = option_data['action']
+            if action not in action_avg_rewards:
+                rad = random.random()
+                print(f"action: {action} rad:{rad}")
+                if rad < exploration_prob:
+                    # UCB-style exploration bonus: decreases as sample count increases
+                    if count > 0:
+                        action_avg_rewards[action] = overall_avg_reward + alpha / count
+                    else:
+                        action_avg_rewards[action] = overall_avg_reward + alpha
+                else:
+                    action_avg_rewards[action] = 0
+                new_actions_added = True
+
+        # If new actions were added, recalculate overall_avg_reward
+        if new_actions_added:
+            all_avg_rewards = list(action_avg_rewards.values())
+            overall_avg_reward = sum(all_avg_rewards) / len(all_avg_rewards) if all_avg_rewards else 0
+
+        # Calculate advantage value for each action
+        action_advantages = {}
+        for action, avg_reward in action_avg_rewards.items():
+            action_advantages[action] = avg_reward - overall_avg_reward
+
+        # Normalize advantage values
+        if action_advantages:
+            adv_values = list(action_advantages.values())
+
+            positive_advs = [adv for adv in adv_values if adv > 0]
+
+            if positive_advs:
+                max_positive = max(positive_advs)
+                normalized_advantages = {action: adv / max_positive for action, adv in action_advantages.items()}
+            else:
+                max_negative_abs = abs(min(adv_values))
+                if max_negative_abs > 0:
+                    normalized_advantages = {action: adv / max_negative_abs for action, adv in action_advantages.items()}
+                else:
+                    normalized_advantages = {action: 0 for action in action_advantages.keys()}
+        else:
+            normalized_advantages = {}
+
+        updated_options = {}
+        if options_with_confidences and normalized_advantages:
+            print(f"\n=== Action Advantage Analysis & Logit Correction ===")
+            print(f"Overall action average reward baseline: {overall_avg_reward:.4f}")
+
+            for option_num, option_data in options_with_confidences.items():
+                if isinstance(option_data, dict) and 'action' in option_data:
+                    action = option_data['action']
+
+                    # Get normalized advantage value
+                    normalized_advantage = normalized_advantages.get(action, 0)
+                    raw_advantage = action_advantages.get(action, 0)
+                    avg_reward = action_avg_rewards.get(action, 0)
+
+                    # Episode-based weight for the normalized advantage:
+                    # grows from 1.0 (episode 1) to 1.5 (episode 50) via
+                    # weight = 1.0 + (current_episode / 50) * 0.5,
+                    # clamped at 1.5 for later episodes
+                    if self.cross_mem:
+                        current_episode = self.cross_mem.current_episode_number
+                        episode_weight = min(1.0 + (current_episode / 50.0) * 0.5, 1.5)
+                    else:
+                        episode_weight = 1.0
+
+                    # Apply episode weight to normalized advantage
+                    weighted_normalized_advantage = normalized_advantage * episode_weight
+
+                    # Correct logit (add weighted normalized advantage value)
+                    normalized_prob = option_data.get('confidence', 0)
+                    corrected_logprob = normalized_prob + weighted_normalized_advantage
+
+                    # Create corrected option data
+                    updated_options[option_num] = {
+                        'action': action,
+                        'normalized_advantage': normalized_advantage,
+                        'corrected_confidence': corrected_logprob,
+                        'avg_reward': avg_reward,
+                        'raw_advantage': raw_advantage,
+                        'confidence': option_data.get('confidence', 0)
+                    }
+
+                    print(f"  Action: {action}")
+                    print(f"    Average reward: {avg_reward:.4f} | Raw advantage: {raw_advantage:.4f}")
+                    print(f"    Normalized advantage: {normalized_advantage:.4f}")
+                    print(f"    Original logprob: {normalized_prob:.4f} -> Corrected: {corrected_logprob:.4f}")
+                else:
+                    # If no matching advantage value, keep original value
+                    updated_options[option_num] = option_data.copy()
+                    updated_options[option_num]['corrected_confidence'] = option_data.get('confidence', 0)
+                    updated_options[option_num]['normalized_advantage'] = 0
+            print("="*50)
+            return updated_options
+        else:
+            return {}
+
+    def _build_prompt(self, state: dict) -> tuple[str, str, str]:
+        """Build the user prompt, system prompt, and memory text for the LLM."""
+
+        valid_actions = state.get("valid_actions", []) or []
+        current_inventory_list = state.get("inventory", []) or []
+        current_inventory = ", ".join(current_inventory_list) if current_inventory_list else "(empty)"
+
+        if self.history:
+            summary, memory_text = self.generate_history_summary(self.history, state.get("observation", ""), temperature=0.0, max_tokens=1000, current_inventory=current_inventory)
+        else:
+            summary = ""
+            memory_text = "No game history."
+            # Add inventory even if no game history
+            if current_inventory:
+                memory_text += f"\n\n[INVENTORY]\n{current_inventory}"
+        sys_prompt = """You are an expert player aiming to complete a text-based adventure game. Points are given for making progress in the game. Select promising actions based on the game state and memory of past interactions.
+
+**EXPLORATION PRIORITY**: When you arrive at a NEW location you haven't fully explored before, you MUST thoroughly explore it FIRST before leaving."""
+        if self.guiding_prompt:
+            sys_prompt += f"\n\nFollow this guide: {self.guiding_prompt}"
+
+        # Add reminder about the action list only when valid_actions are provided
+        if valid_actions:
+            sys_prompt += """\n\n**CRITICAL CONSTRAINT**: When REFERENCE ACTIONS are provided, you MUST ONLY choose actions from that list. Any action not in the REFERENCE ACTIONS list is INVALID and will fail. Do NOT create custom actions. The list is unordered - position doesn't indicate quality."""
+
+        recent_history = ""
+        if self.history:
+            # Only keep the last 5 steps
+            recent_game_history = self.history[-5:]
+            start_index = max(0, len(self.history) - 5)
+            for idx, entry in enumerate(recent_game_history):
+                actual_step = start_index + idx
+                recent_history += f"Step {actual_step}:\n"
+                recent_history += f"State: {entry.get('state', '')}\n"
+                recent_history += f"Action: {entry.get('action', '')}\n"
+                if entry.get('reward') is not None:
+                    recent_history += f"Reward: {entry.get('reward', 0)}\n"
+                recent_history += "\n"
+
+        # Add valid actions section if available
+        valid_actions_text = ""
+        if valid_actions:
+            valid_actions = valid_actions[:]  # make a copy to avoid side effects
+            rng = random.Random(time.time_ns())
+            rng.shuffle(valid_actions)
+            print(f"Valid actions (shuffled): {valid_actions}")
+            valid_actions_text = f"\nREFERENCE ACTIONS (ONLY VALID ACTIONS):\n{valid_actions}\n\n**STRICT REQUIREMENT**: These are the ONLY valid actions for the current state. You MUST select your actions EXCLUSIVELY from this list. Any action not in this list is INVALID and will be rejected by the game. Do NOT create, modify, or suggest any custom actions.**\n"
+
+        # Build option fields with confidence
+        option_fields = []
+        for i in range(1, self.top_actions + 1):
+            option_fields.append(f'    "reasoning{i}": "Why this action makes sense",')
+            option_fields.append(f'    "option{i}": "action command",')
+            option_fields.append(f'    "confidence{i}": 80,')
+        option_format = '\n'.join(option_fields)
+
+        user_prompt = f"""
+GAME HISTORY:
+{summary}
+{memory_text}
+
+RECENT STEPS:
+{recent_history}
+
+CURRENT STATE: {state.get("observation", "")}
+
+TASK:
+1. Analyze your progress: What have you achieved? What's your next objective?
+2. **MANDATORY: Check the REFERENCE ACTIONS list below - you MUST ONLY select from this list**
+3. Propose {self.top_actions} different actions with reasoning
+4. For EACH action, provide your confidence as an integer from 0 to 100 (e.g., 80 means 80% confidence that this action will help achieve the goal)
+
+RESPONSE FORMAT (JSON):
+{{
+    "progress_analysis": "What you've achieved and current challenges",
+    "next_objective": "Your next goal",
+{option_format}
+    "best_action": 1
+}}
+
+IMPORTANT:
+- **ABSOLUTE REQUIREMENT**: ALL actions (option1, option2, etc.) MUST be selected EXACTLY from the REFERENCE ACTIONS list below
+- Actions NOT in the REFERENCE ACTIONS list are INVALID and will cause the game to fail
+- DO NOT create custom actions, DO NOT modify actions from the list, DO NOT combine actions
+- **CRITICAL**: The sum of all confidence values (confidence1 + confidence2 + ... + confidence{self.top_actions}) MUST equal 100
+- The confidence values MUST have meaningful differences between them
+- Higher confidence means you believe this action is more likely to succeed
+- Lower confidence means you believe this action is less likely to succeed
+- Pay attention to game hints and clues in state descriptions
+- Don't repeat failed actions or create loops (e.g., north→south→north)
+
+{valid_actions_text}
+"""
+        return user_prompt, sys_prompt, memory_text
+
+    def generate_history_summary(self, game_history, current_state=None, temperature=0.8, max_tokens=1000, current_inventory=None):
+        """
+        Generate structured LLM-summarized history context.
+        Summarizes the past states and optionally includes the current state.
+
+        Args:
+            game_history: List of game history entries containing state, action, score, reward, etc.
+            current_state: The current state to include in the summary (optional)
+            temperature: Temperature for LLM generation
+            max_tokens: Maximum tokens for the summary
+            current_inventory: The current inventory to include in the summary (optional)
+
+        Returns:
+            tuple: (summary_text, structured_summary)
+                - summary_text: Natural language summary under [SUMMARY]
+                - structured_summary: Structured sections [PROGRESS], [LOCATION], [INVENTORY]
+        """
+        if not game_history and not current_state:
+            return "", ""
+
+        # Extract states and actions from game_history
+        states = [entry.get('state', '') for entry in game_history]
+        actions = [entry.get('action', '') for entry in game_history]
+        current_step = len(game_history)  # Current step index
+
+        # Build trajectory text for all past steps
+        earlier_trajectory_text = ""
+        for i in range(current_step):
+            earlier_trajectory_text += f"Step {i}:\n"
+            if i < len(states):
+                earlier_trajectory_text += f"State: {states[i]}\n"
+            if i < len(actions) and actions[i]:
+                earlier_trajectory_text += f"Action: {actions[i]}\n"
+            earlier_trajectory_text += "\n"
+
+        # Add current state if provided
+        if current_state:
+            earlier_trajectory_text += f"Step {current_step} (Current):\n"
+            earlier_trajectory_text += f"State: {current_state}\n"
+
+        sys_prompt = """You are an expert at analyzing game trajectories and creating highly distinctive summaries.
+Your milestones must be CONCISE and DISTINCTIVE - use specific keywords that clearly differentiate different game states.
+
+Output Format Requirements:
+[SUMMARY]: Provide a natural language summary that describes:
+- The game's objective or goal (inferred from the trajectory)
+- Current progress toward that goal
+- Key accomplishments so far
+- Current situation/status
+
+[PROGRESS]: List milestones in format "✓ M#: <action>→<key object/result>"
+- If NO steps have score increases, output ONLY "No Progress" (no milestones)
+
+[LOCATION]: Track location changes in format "Location: A→B→C"
+- Extract location names from state descriptions
+- Only record when location actually changes
+- CRITICAL: IGNORE unproductive loops - if the player goes back and forth between locations WITHOUT score increases, skip those redundant location movements
+- Focus on the MEANINGFUL location trajectory that led to progress or new discoveries
+- Example: If player went "room A→room B→room A→room B→room A→room C" with no score increase for the A↔B movements, record as "room A→room C"
+
+Milestone Format Rules:
+- Use format: <verb>→<critical object/state>
+- Focus on STATE CHANGES, not descriptions
+- Use specific nouns not generic terms
+- Each milestone should capture ONE concrete action or discovery
+- CRITICAL: Only record steps where the score increased compared to the previous step
+
+Examples:
+BAD SUMMARY: "The player has been exploring."
+GOOD SUMMARY: "The game's objective is to escape the haunted mansion. So far, the player has found a key in the library and unlocked the basement door, earning 15 points. Currently in the dark basement, the player needs to find a light source to continue exploring."
+
+BAD: "✓ Milestone 1: Entered the library and proceeded to the ground floor stacks"
+GOOD: "✓ M1: enter→library ground floor"
+
+Location Examples:
+BAD: "Location: You are in a library→You moved to another area→You are somewhere else"
+GOOD: "Location: entrance→library→north corridor→atrium"
+BAD: "Location: room A→room B→room A→room B→room A" (unproductive loop)
+GOOD: "Location: room A" (ignore the meaningless back-and-forth)
+"""
+
+        user_prompt = f"""Analyze the game trajectory and generate a natural language summary followed by structured sections.
+
+Game History:
+{earlier_trajectory_text}
+
+Generate summary in this format:
+
+[SUMMARY]
+<Natural language paragraph (2-4 sentences) describing: game objective, current progress, key accomplishments, current situation>
+
+[PROGRESS]
+✓ M1: <verb>→<result>
+✓ M2: <verb>→<result>
+...
+
+[LOCATION]
+Location: A→B→C→...
+
+Requirements:
+- SUMMARY: Write a natural language paragraph explaining the game's goal and current status
+- Use specific object names
+- Focus on concrete state changes and discoveries
+- Avoid repetitive phrasing between milestones
+- CRITICAL: ONLY record progress where the score increased from the previous step
+- IMPORTANT: If there are NO steps with score increases, output ONLY "No Progress" under [PROGRESS] section (do not create any milestone items)
+- For locations: Extract actual location names from state descriptions and track when they change
+- Location format must be: "Location: " followed by location names separated by →
+- CRITICAL FOR LOCATION: Detect and REMOVE unproductive loops from the location trajectory
+  * Look at the score/reward at each step
+  * If the player moves back and forth between locations WITHOUT any score increase, those movements are loops - SKIP them
+  * Only include location changes that either: (1) led to score increases, OR (2) represent meaningful forward exploration
+  * Example: If steps show "A→B (no reward)→A (no reward)→B (no reward)→A (no reward)→C (reward +5)", output "Location: A→C"
+
+Generate the complete summary:"""
+
+        full_response = call_llm(user_prompt, sys_prompt, temperature=temperature, max_tokens=max_tokens).strip()
+
+        if full_response == "":
+            print("[Warning] No response from LLM for history summary generation")
+
+        # Extract [SUMMARY] section
+        summary_text = ""
+        if "[SUMMARY]" in full_response:
+            summary_start = full_response.find("[SUMMARY]") + len("[SUMMARY]")
+            # Find the next section marker
+            next_section = full_response.find("[PROGRESS]", summary_start)
+            if next_section == -1:
+                next_section = full_response.find("[LOCATION]", summary_start)
+
+            if next_section != -1:
+                summary_text = full_response[summary_start:next_section].strip()
+            else:
+                summary_text = full_response[summary_start:].strip()
+
+        # Extract structured sections (everything except [SUMMARY])
+        structured_summary = full_response
+        if "[SUMMARY]" in full_response:
+            # Remove [SUMMARY] section from structured summary
+            progress_start = full_response.find("[PROGRESS]")
+            if progress_start != -1:
+                structured_summary = full_response[progress_start:]
+
+        # Add inventory section to structured summary if available
+        if current_inventory:
+            structured_summary += f"\n\n[INVENTORY]\n{current_inventory}"
+
+        return summary_text, structured_summary
+
+    def _get_response_format(self) -> dict:
+        # Dynamically create a JSON Schema so the final choice can only be a valid option number
+        properties = {}
+        required_fields = []
+        valid_choices = []
+
+        # Add progress analysis fields
+        properties["progress_analysis"] = {
+            "type": "string",
+            "description": "Analysis of achievements and progress so far"
+        }
+        properties["next_objective"] = {
+            "type": "string",
+            "description": "Overall objective for the next steps"
+        }
+        required_fields.extend(["progress_analysis", "next_objective"])
+
+        # Dynamically generate reasoning, option, and confidence fields
+        for i in range(1, self.top_actions + 1):
+            reasoning_key = f"reasoning{i}"
+            option_key = f"option{i}"
+
+            properties[reasoning_key] = {
+                "type": "string",
+                "description": f"Reasoning for option {i}"
+            }
+            properties[option_key] = {
+                "type": "string",
+                "description": f"Option {i} possible action"
+            }
+            required_fields.append(reasoning_key)
+            required_fields.append(option_key)
+            valid_choices.append(i)
+
+            confidence_key = f"confidence{i}"
+            properties[confidence_key] = {
+                "type": "integer",
+                "minimum": 0,
+                "maximum": 100,
+                "description": f"Confidence for option {i} (0-100)"
+            }
+            required_fields.append(confidence_key)
+
+        # Add best_action field
+        properties["best_action"] = {
+            "type": "number",
+            "minimum": 1,
+            "maximum": self.top_actions,
+            "description": f"The number of the best option (must be one of: {valid_choices})"
+        }
+        required_fields.append("best_action")
+
+        response_format = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "game_action_choice",
+                "strict": True,
+                "schema": {
+                    "type": "object",
+                    "properties": properties,
+                    "required": required_fields,
+                    "additionalProperties": False
+                }
+            }
+        }
+
+        return response_format
+
+    def _parse_response(self, response: str) -> tuple[str, dict]:
+        """Parse the LLM's JSON response into the chosen action and the proposed options with confidences."""
+        # Clean up markdown code blocks if present
+        cleaned_response = response.strip()
+        if cleaned_response.startswith('```'):
+            # Remove ```json or ``` at the start and ``` at the end
+            cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response)
+            cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
+
+        try:
+            json_response = json.loads(cleaned_response)
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing error: {e}")
+            print(f"Raw response: {response}")
+            return "look", {}
+
+        # Extract and print progress analysis
+        progress_analysis = json_response.get("progress_analysis", "No analysis provided")
+        next_objective = json_response.get("next_objective", "No objective specified")
+
+        print("\n" + "="*50)
+        print("=== PROGRESS ANALYSIS ===")
+        print("="*50)
+        print(f"Achievements: {progress_analysis}")
+        print(f"\nNext Objective: {next_objective}")
+        print("="*50 + "\n")
+
+        # Extract all options with their reasoning and confidence
+        options_with_confidences = {}
+
+        for i in range(1, self.top_actions + 1):
+            reasoning_key = f"reasoning{i}"
+            option_key = f"option{i}"
+            confidence_key = f"confidence{i}"
+
+            reasoning = json_response.get(reasoning_key, "No reasoning provided")
+            option = json_response.get(option_key, "look")
+            confidence = json_response.get(confidence_key, 50)
+
+            options_with_confidences[i] = {
+                'action': option.strip(),
+                'reasoning': reasoning,
+                'confidence': confidence / 100.0
+            }
+
+        # Get the best action
+        best_choice = json_response.get("best_action", 1)
+        action_text = options_with_confidences.get(best_choice, {}).get('action', "look")
+
+        return action_text, options_with_confidences
+
+    def _extract_result(self, result) -> str:
+        """Extract text from MCP tool result."""
+        if hasattr(result, 'content') and result.content:
+            return result.content[0].text
+        if isinstance(result, list) and result:
+            return result[0].text if hasattr(result[0], 'text') else str(result[0])
+        return str(result)
+
+    def _update_score(self, state: dict) -> None:
+        if not state:
+            return
+        if "score" in state:
+            self.score = int(state["score"])
+
+    def _is_game_over(self, text: str) -> bool:
+        """Check if the game is over."""
+        game_over_phrases = [
+            "game over",
+            "you have died",
+            "you are dead",
+            "*** you have died ***",
+        ]
+        text_lower = text.lower()
+        return any(phrase in text_lower for phrase in game_over_phrases)
 
 
 # =============================================================================
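The heart of the commit is the logit correction in update_scores_with_memory(): retrieved steps give a Q(s, a) estimate per candidate action, their mean acts as a state-value baseline, and each option's LLM confidence is shifted by an episode-weighted, normalized advantage. A minimal self-contained sketch of that arithmetic follows; the helper name and all numbers are hypothetical, and the baseline recomputation is simplified relative to the committed code.

    # Sketch of the advantage-based confidence correction in
    # update_scores_with_memory(). All values are illustrative.
    def correct_confidences(options, action_rewards, episode_number, alpha=0.4):
        if not action_rewards:
            return options
        # Q(s, a): mean discounted reward per action from retrieved steps
        q = {a: sum(rs) / len(rs) for a, rs in action_rewards.items()}
        samples = sum(len(rs) for rs in action_rewards.values())

        # Unseen actions get a UCB-style exploration bonus on top of the baseline
        baseline = sum(q.values()) / len(q)
        for opt in options.values():
            if opt["action"] not in q:
                q[opt["action"]] = baseline + (alpha / samples if samples else alpha)
        baseline = sum(q.values()) / len(q)  # recompute with the new actions

        # Advantage A(s, a) = Q(s, a) - V(s), normalized by the largest positive value
        adv = {a: v - baseline for a, v in q.items()}
        pos = [x for x in adv.values() if x > 0]
        scale = max(pos) if pos else max(abs(min(adv.values())), 1e-8)
        weight = min(1.0 + (episode_number / 50.0) * 0.5, 1.5)  # grows with experience

        return {
            n: {**opt, "corrected_confidence":
                opt["confidence"] + weight * adv[opt["action"]] / scale}
            for n, opt in options.items()
        }

    options = {1: {"action": "open mailbox", "confidence": 0.5},
               2: {"action": "go north", "confidence": 0.5}}
    memory = {"open mailbox": [4.0, 6.0], "go north": [0.5]}
    print(correct_confidences(options, memory, episode_number=10))
    # -> option 1 rises to ~1.6, option 2 falls to ~-0.6 (weight 1.1 at episode 10)

Running it shows "open mailbox" overtaking "go north" once memory favors it, which is exactly the just-in-time correction the agent applies at every step.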
cross_episode_memory.py ADDED
@@ -0,0 +1,442 @@
+"""
+Cross-Episode Memory for Jericho Agent with Vector Similarity Search
+Uses state embeddings to find similar past experiences and update action selection.
+
+Installation:
+    pip install sentence-transformers faiss-cpu numpy --break-system-packages
+
+Alternative models:
+- 'all-MiniLM-L6-v2' (default, 384 dims, fast)
+- 'all-mpnet-base-v2' (768 dims, better quality)
+- 'paraphrase-multilingual-MiniLM-L12-v2' (multilingual)
+"""
+
+import os
+import json
+import pickle
+import traceback
+from typing import List, Dict, Any, Optional, Tuple
+import numpy as np
+from dotenv import load_dotenv
+from sentence_transformers import SentenceTransformer
+from huggingface_hub import InferenceClient
+import faiss
+
+load_dotenv()
+
+# Set USE_LOCAL_MODEL=1 in your .env to use a locally downloaded model
+USE_LOCAL_MODEL = os.getenv("USE_LOCAL_MODEL", "0").strip() in ("1", "true", "yes")
+LOCAL_MODEL_ID = os.getenv("LOCAL_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")
+
+# Initialize the LLM client based on mode
+_local_pipeline = None
+
+if USE_LOCAL_MODEL:
+    import torch
+    from transformers import pipeline as _hf_pipeline
+
+    _local_pipeline = _hf_pipeline(
+        "text-generation",
+        model=LOCAL_MODEL_ID,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+    LLM_CLIENT = None
+else:
+    _hf_token = os.getenv("HF_TOKEN")
+    if not _hf_token:
+        raise ValueError("HF_TOKEN not found. Set it in your .env file.")
+    LLM_CLIENT = InferenceClient(token=_hf_token)
+
+
+class CrossEpisodeMemory:
+    def __init__(
+        self,
+        base_dir: str,
+        eval_llm_model: str = "Qwen/Qwen2.5-72B-Instruct",
+        gamma: float = 0.95,
+        embedding_model: str = 'all-MiniLM-L6-v2',
+        vector_dim: int = 384
+    ):
+        self.base_dir = base_dir
+        self.gamma = gamma
+        self.vector_dim = vector_dim
+        os.makedirs(self.base_dir, exist_ok=True)
+
+        # File paths
+        self.episodes_path = os.path.join(self.base_dir, 'episodes.jsonl')
+        self.vector_index_path = os.path.join(self.base_dir, 'state_vectors.index')
+        self.step_metadata_path = os.path.join(self.base_dir, 'step_metadata.pkl')
+
+        self.eval_llm_model = eval_llm_model
+
+        # One query vector is cached per retrieve_similar() call, so entry i
+        # corresponds to step i of the current episode
+        self.query_vector_cache = []
+
+        # Episode counter
+        self.current_episode_number = self._get_episode_count()
+
+        self.encoder = SentenceTransformer(embedding_model)
+        print(f"✓ Loaded embedding model: {embedding_model} (dim: {vector_dim})")
+        # Initialize FAISS vector database
+        self._init_vector_database()
+
+    def _init_vector_database(self):
+        # Load or create FAISS index
+        if os.path.exists(self.vector_index_path):
+            self.vector_index = faiss.read_index(self.vector_index_path)
+            print(f"✓ Loaded vector index: {self.vector_index.ntotal} vectors")
+        else:
+            # Use Inner Product for cosine similarity (with normalized vectors)
+            self.vector_index = faiss.IndexFlatIP(self.vector_dim)
+            print(f"✓ Created new vector index")
+
+        # Load metadata
+        if os.path.exists(self.step_metadata_path):
+            with open(self.step_metadata_path, 'rb') as f:
+                self.step_metadata = pickle.load(f)
+            print(f"✓ Loaded metadata: {len(self.step_metadata)} entries")
+        else:
+            self.step_metadata = []
+
+    def _save_vector_database(self):
+        """Save FAISS index and metadata to disk."""
+        if self.vector_index is not None:
+            faiss.write_index(self.vector_index, self.vector_index_path)
+        if self.step_metadata:
+            with open(self.step_metadata_path, 'wb') as f:
+                pickle.dump(self.step_metadata, f)
+
+    def _get_episode_count(self) -> int:
+        """Count existing episodes."""
+        if not os.path.exists(self.episodes_path):
+            return 0
+        with open(self.episodes_path, 'r') as f:
+            return sum(1 for _ in f)
+
+    def _encode_state(self, state_text: str, history_summary: str = ""):
+        # Combine current state with history context
+        if history_summary:
+            combined_text = f"History: {history_summary}\n\nCurrent: {state_text}"
+        else:
+            combined_text = f"Current: {state_text}"
+
+        try:
+            # Encode and normalize for cosine similarity
+            embedding = self.encoder.encode([combined_text])[0]
+            embedding = embedding / (np.linalg.norm(embedding) + 1e-8)
+            return embedding.astype('float32')
+        except Exception as e:
+            print(f"Encoding error: {e}")
+            return None
+
+    def add_episode(
+        self,
+        game_history: List[Dict[str, Any]],
+        final_score: float,
+        state: str,
+        success: bool = False
+    ):
+        self.current_episode_number += 1
+
+        # Save episode metadata
+        episode_data = {
+            'episode_number': self.current_episode_number,
+            'final_score': final_score,
+            'success': success,
+            'steps': len(game_history),
+            'history': game_history
+        }
+
+        with open(self.episodes_path, 'a') as f:
+            f.write(json.dumps(episode_data) + '\n')
+
+        llm_step_scores, step_reasonings = self.evaluate_step_scores_with_llm(
+            game_history=game_history,
+            state=state,
+            final_score=final_score,
+            success=success,
+            temperature=0.8
+        )
+        if len(llm_step_scores) != len(game_history):
+            print(f"[Warning] LLM returned {len(llm_step_scores)} step scores "
+                  f"for {len(game_history)} steps")
+
+        self._add_steps_to_vector_db(game_history, final_score, llm_step_scores)
+
+        # Save vector database
+        self._save_vector_database()
+
+    def _add_steps_to_vector_db(self, game_history: List[Dict], final_score: float, llm_step_scores):
+        n = len(llm_step_scores)
+
+        for i, step in enumerate(game_history):
+            # Get current state
+            state_text = step.get('state', '')
+            action = step.get('action', '')
+
+            # Reuse the query vector cached for this step during the episode
+            vector = self.query_vector_cache[i]
+            if vector is None:
+                continue
+
+            # Calculate discounted future reward from this step
+            discounted_reward = 0.0
+            for j in range(i, n):
+                step_j = game_history[j]
+                if step_j.get('reward', 0) > 0:
+                    reward = max(llm_step_scores[j], step_j.get('reward', 0))
+                else:
+                    reward = llm_step_scores[j]
+                discounted_reward += (self.gamma ** (j - i)) * reward
+
+            # Store metadata for this step
+            metadata = {
+                'episode_number': self.current_episode_number,
+                'step_index': i,
+                'state': state_text,
+                'action': action,
+                'reward': step.get('reward', 0),
+                'score': step.get('score', 0),
+                'llm_step_score': llm_step_scores[i] if i < len(llm_step_scores) else 0,
+                'discounted_reward': discounted_reward,
+                'episode_final_score': final_score
+            }
+
+            # Add to vector index and metadata list
+            self.vector_index.add(np.array([vector]))
+            self.step_metadata.append(metadata)
+
+    def retrieve_similar(
+        self,
+        current_state: str,
+        game_history: List[Dict],
+        current_summary: str,
+        k: int = 5,
+        r: float = 0.9
+    ):
+        # Encode current state with history
+        query_vector = self._encode_state(current_state, current_summary)
+
+        if query_vector is None:
+            return []
+
+        self.query_vector_cache.append(query_vector)
+
+        if self.vector_index is None:
+            print("Vector database not available")
+            return []
+        if self.vector_index.ntotal == 0:
+            print("Vector database is empty")
+            return []
+
+        # Phase 1: Recall top candidates using vector similarity (fast)
+        # Use 10x more candidates than final k to ensure good coverage
+        recall_size = min(max(k * 10, 100), self.vector_index.ntotal)
+
+        scores, indices = self.vector_index.search(query_vector.reshape(1, -1), recall_size)
+
+        high_similarity_trajectories = []    # similarity > r
+        medium_similarity_trajectories = []  # similarity <= r
+
+        for score, idx in zip(scores[0], indices[0]):
+            if idx < 0 or idx >= len(self.step_metadata):
+                continue
+
+            metadata = self.step_metadata[idx]
+
+            if score > r:
+                high_similarity_trajectories.append((score, metadata["discounted_reward"], metadata))
+            else:
+                medium_similarity_trajectories.append((score, metadata["discounted_reward"], metadata))
+
+        # Sort both lists by similarity first, then by discounted reward
+        high_similarity_trajectories.sort(key=lambda x: (x[0], x[1]), reverse=True)
+        medium_similarity_trajectories.sort(key=lambda x: (x[0], x[1]), reverse=True)
+
+        # Combine results: first all high-similarity entries (> r),
+        # then top up from medium similarity if needed
+        filtered_trajectories = high_similarity_trajectories[:]
+        if len(filtered_trajectories) < k:
+            remaining_needed = k - len(filtered_trajectories)
+            filtered_trajectories.extend(medium_similarity_trajectories[:remaining_needed])
+
+        return [x[2] for x in filtered_trajectories[:k]]
+
+    def evaluate_step_scores_with_llm(self, game_history, state, final_score, success, temperature=0.8):
+        """
+        Use an LLM to assign a score to each step in the trajectory.
+
+        Args:
+            game_history: List of game history records
+            state: Final state text
+            final_score: Final score
+            success: Whether the episode succeeded
+            temperature: LLM temperature parameter
+
+        Returns:
+            tuple: (score list, reasoning list), one entry per step
+        """
+        if not game_history:
+            return [], []
+
+        trajectory_text = ""
+        for i, entry in enumerate(game_history):
+            trajectory_text += f"Step {i}:\n"
+            trajectory_text += f"Current state: {entry.get('old_state', '')}\n"
+            if entry.get('action'):
+                trajectory_text += f"Action: {entry.get('action')}\n"
+            trajectory_text += f"Result state: {entry.get('state', '')}\n"
+            reward = entry.get('reward', 0)
+            trajectory_text += f"Action reward: {reward}\n"
+            trajectory_text += "\n"
+        trajectory_text += f"Step {len(game_history)}:\nState: {state}\n"
+
+        sys_prompt = """You are scoring game actions to build training data for future gameplay.
+
+PURPOSE: Rate each action based on its overall impact - positive scores for actions worth repeating, negative for actions to avoid.
+
+SCORING RULES:
+- Positive: Action led to progress, a reward, useful discoveries or new locations discovered
+- Negative: Action wasted time, caused loops, or had no benefit
+- Magnitude: Match the game's typical reward scale (calibrate based on rewards shown in trajectory)
+- Evaluation: Judge by full consequence chain, not just immediate result
+"""
+
+        user_prompt = f"""Score each action in this game session.
+
+Final Result: {"SUCCESS" if success else "FAILURE"}, Final Score: {final_score}
+
+Trajectory:
+{trajectory_text}
+
+JSON FORMAT:
+{{
+    "step_analysis": [
+        {{
+            "step": 0,
+            "action": "exact action taken",
+            "detailed_reasoning": "What happened after this action and its consequences? Why should this be repeated or avoided?",
+            "score": 5
+        }}
+    ]
+}}
+
+Provide complete JSON response:"""
+
+        try:
+            messages = [
+                {"role": "system", "content": sys_prompt},
+                {"role": "user", "content": user_prompt},
+            ]
+
+            if USE_LOCAL_MODEL and _local_pipeline is not None:
+                outputs = _local_pipeline(
+                    messages,
+                    max_new_tokens=8192,
+                    temperature=temperature,
+                    do_sample=True,
+                )
+                scores_text = outputs[0]["generated_text"][-1]["content"].strip()
+            else:
+                response = LLM_CLIENT.chat.completions.create(
+                    model=self.eval_llm_model,
+                    messages=messages,
+                    temperature=temperature,
+                    max_tokens=8192,
+                )
+                scores_text = response.choices[0].message.content.strip()
+        except Exception as e:
+            if getattr(e, 'response', None) is not None:
+                print(e.response.text)
+            print(traceback.format_exc())
+            print(f"[Error] LLM call failed: {e}")
+            print("[Warning] Falling back to reward-based scoring")
+            return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]
+
+        if scores_text:
+            if scores_text.startswith('```json'):
+                scores_text = scores_text.replace('```json', '').replace('```', '').strip()
+            elif scores_text.startswith('```'):
+                scores_text = scores_text.replace('```', '').strip()
+
+            try:
+                response_data = json.loads(scores_text)
+            except json.JSONDecodeError as e:
+                print(f"[Warning] JSON decode error: {e}")
+                print(f"[Warning] Problematic JSON text (first 500 chars): {scores_text[:500]}")
+                print(f"[Warning] Falling back to reward-based scoring")
+                return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]
+
+            if isinstance(response_data, dict):
+                if 'step_analysis' in response_data:
+                    step_analyses = response_data['step_analysis']
+
+                    analysis_dict = {}
+                    for analysis in step_analyses:
+                        step_num = analysis.get('step', -1)
+                        if step_num >= 0:
+                            analysis_dict[step_num] = analysis
+
+                    scores = []
+                    reasonings = []
+                    missing_steps = []
+
+                    print("\n=== LLM Step Analysis ===")
+                    for i in range(len(game_history)):
+                        if i in analysis_dict:
+                            analysis = analysis_dict[i]
+                            step_score = analysis.get('score', 0)
+                            reasoning = analysis.get('detailed_reasoning', 'No reasoning')
+                            scores.append(step_score)
+                            reasonings.append(reasoning)
+                            print(f"Step {i}: {analysis.get('action', 'Unknown')}")
+                            print(f"  Reasoning: {reasoning}")
+                            print(f"  Score: {step_score}")
+                            print()
+                        else:
+                            scores.append(0)
+                            reasonings.append('Missing from LLM analysis - filled with score 0')
+                            missing_steps.append(i)
+                            print(f"Step {i}: [MISSING FROM LLM ANALYSIS - FILLED WITH 0]")
+                            print(f"  Action: {game_history[i].get('action', 'Unknown')}")
+                            print(f"  Score: 0 (auto-filled)")
+                            print()
+
+                    if missing_steps:
+                        print(f"[Warning] LLM missed analyzing steps: {missing_steps}")
+                        print(f"[Info] Filled missing steps with score 0")
+
+                    if 'overall_assessment' in response_data:
+                        print(f"Overall Assessment: {response_data['overall_assessment']}")
+                    print("=" * 50)
+                else:
+                    print(f"[Warning] No step_analysis found in JSON format: {response_data}")
+                    return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]
+            elif isinstance(response_data, list):
+                scores = response_data
+                reasonings = ["Legacy format - no reasoning available" for _ in game_history]
+            else:
+                print(f"[Warning] Unexpected JSON format: {response_data}")
+                return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]
+
+            return scores, reasonings
+        else:
+            print("[Warning] No valid response from LLM for step scoring")
+            return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]
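For reference, the value stored per step is the discounted return computed in _add_steps_to_vector_db(): the gamma-discounted sum of future LLM step scores, with the environment reward used as a floor whenever a step actually scored points. A standalone sketch of that computation; the function name and example values are hypothetical.

    # Sketch of the per-step discounted return from _add_steps_to_vector_db().
    # gamma matches the module default; scores and rewards are illustrative.
    def discounted_returns(llm_scores, env_rewards, gamma=0.95):
        n = len(llm_scores)
        returns = []
        for i in range(n):
            g = 0.0
            for j in range(i, n):
                # the env reward acts as a floor when the step actually scored points
                r = max(llm_scores[j], env_rewards[j]) if env_rewards[j] > 0 else llm_scores[j]
                g += (gamma ** (j - i)) * r
            returns.append(g)
        return returns

    print(discounted_returns([1.0, -0.5, 3.0], [0, 0, 5]))
    # step 2 uses max(3.0, 5) = 5, so returns ≈ [5.0375, 4.25, 5.0]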
mcp_server.py CHANGED
@@ -26,6 +26,8 @@ Then open the MCP Inspector in your browser to test the tools interactively.
 import sys
 import os
+import json
+from dataclasses import asdict
 
 # Add parent directory to path to import games module
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -60,9 +62,9 @@ class GameManager:
         self.state = None
         self.game_name: str = ""
         # TODO: Add more state tracking
-        # self.history: list[tuple[str, str]] = []
-        # self.explored_locations: dict[str, set[str]] = {}
-        # self.current_location: str = ""
+        self.history: list[tuple[str, str]] = []
+        self.explored_locations: dict[str, set[str]] = {}
+        self.current_location: str = ""
 
     def initialize(self, game: str = "zork1"):
         """Initialize or reset the game."""
@@ -71,7 +73,21 @@ class GameManager:
         self.state = self.env.reset()
         # TODO: Reset your state tracking here
         return self.state.observation
 
+    def get_full_state(self, include_valid_actions: bool = True) -> dict:
+        """
+        Convert current env GameState (dataclass from TextAdventureEnv) to a JSON-ready dict.
+        """
+        data = asdict(self.state)
+
+        if include_valid_actions:
+            try:
+                data["valid_actions"] = self.env.get_valid_actions()
+            except Exception:
+                data["valid_actions"] = []
+
+        return data
+
     def step(self, action: str) -> str:
         """Execute an action and return the result."""
         if self.env is None:
@@ -135,12 +151,15 @@ def play_action(action: str) -> str:
     # TODO: You might want to add action validation here
     # TODO: You might want to include score changes in the response
 
-    result = game.step(action)
+    game.step(action)
 
     # Optional: Append score info
     # result += f"\n[Score: {game.get_score()} | Moves: {game.get_moves()}]"
 
-    return result
+    payload = {
+        "state": game.get_full_state(include_valid_actions=True),
+    }
+    return json.dumps(payload)
 
 
 # TODO: Implement additional tools to help your agent
@@ -158,17 +177,17 @@ def play_action(action: str) -> str:
 #     pass
 
 
-# @mcp.tool()
-# def inventory() -> str:
-#     """
-#     Check what the player is carrying.
-#
-#     Returns:
-#         List of items in the player's inventory
-#     """
-#     game = get_game()
-#     result = game.step("inventory")
-#     return result
+@mcp.tool()
+def inventory() -> str:
+    """
+    Check what the player is carrying.
+
+    Returns:
+        List of items in the player's inventory
+    """
+    game = get_game()
+    result = game.step("inventory")
+    return result
 
 
 # @mcp.tool()
@@ -184,21 +203,52 @@ def play_action(action: str) -> str:
 #     pass
 
 
-# @mcp.tool()
-# def get_valid_actions() -> str:
-#     """
-#     Get a list of likely valid actions from the current location.
-#
-#     Returns:
-#         List of actions that might work here
-#     """
-#     # This is a hint: Jericho provides get_valid_actions()
-#     game = get_game()
-#     if game.env and game.env.env:
-#         valid = game.env.env.get_valid_actions()
-#         return "Valid actions: " + ", ".join(valid[:20])
-#     return "Could not determine valid actions"
+@mcp.tool()
+def get_valid_actions() -> str:
+    """
+    Get a list of likely valid actions from the current location.
+
+    Returns:
+        List of actions that might work here
+    """
+    # This is a hint: Jericho provides get_valid_actions()
+    game = get_game()
+    if game.env and game.env.env:
+        valid = game.env.env.get_valid_actions()
+        return "Valid actions: " + ", ".join(valid[:20])
+    return "Could not determine valid actions"
+
+@mcp.tool()
+def reset_game() -> str:
+    """
+    Reset the game to the beginning.
+
+    Use this to start over if you get stuck or die. The game name is taken
+    from the GAME environment variable (e.g., 'zork1', 'zork2', 'advent',
+    'enchanter'); use list_games() to see available options.
+
+    Returns:
+        JSON payload with the full initial game state
+    """
+    global _game
+    try:
+        if _game.env is None:
+            game = get_game()
+        else:
+            game_name = os.environ.get("GAME", "zork1")
+            _game.initialize(game_name)
+            game = _game
+
+        payload = {
+            "state": game.get_full_state(include_valid_actions=True),
+        }
+        return json.dumps(payload)
+
+    except ValueError as e:
+        return json.dumps({"error": str(e)})
 
 # =============================================================================
 # Run the server
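Since play_action and reset_game now return a JSON string instead of raw game text, a client has to unwrap the payload before using it. A minimal sketch of that client-side handling; the example payload is illustrative, and the field names follow the dict built by get_full_state().

    import json

    def unwrap_payload(raw: str) -> dict:
        """Decode the JSON returned by play_action/reset_game into a state dict."""
        payload = json.loads(raw)
        if "error" in payload:
            raise RuntimeError(payload["error"])
        return payload["state"]

    # Hypothetical round-trip: what the agent sees after one tool call.
    raw = '{"state": {"observation": "West of House", "score": 0, "valid_actions": ["open mailbox", "north"]}}'
    state = unwrap_payload(raw)
    print(state["observation"], state.get("valid_actions"))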
requirements.txt CHANGED
@@ -7,3 +7,9 @@
 # Add any additional packages your agent needs below:
 # numpy
 # requests
+
+python-dotenv
+spacy
+
+faiss-cpu
+sentence-transformers
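A note on the faiss-cpu pin: the plain `faiss` name does not install the official FAISS wheels from PyPI, so the requirement above matches the install note in cross_episode_memory.py. The memory's cosine-similarity search reduces to an inner-product index over L2-normalized vectors; a standalone sketch:

    # Standalone sketch of the FAISS setup used by CrossEpisodeMemory:
    # cosine similarity via inner product over L2-normalized vectors.
    import numpy as np
    import faiss

    dim = 384  # matches all-MiniLM-L6-v2 embeddings
    index = faiss.IndexFlatIP(dim)

    vecs = np.random.rand(5, dim).astype("float32")
    vecs /= np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-8  # normalize
    index.add(vecs)

    query = vecs[0:1]  # a stored vector queries itself with similarity ~1.0
    scores, ids = index.search(query, 3)
    print(ids[0], scores[0])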