Commit b59a6e3 · Parent(s): 615a63b
My submission agent, using just-in-time RL for action choosing
Browse files:
- agent.py +729 -135
- cross_episode_memory.py +442 -0
- mcp_server.py +81 -31
- requirements.txt +6 -0
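In outline, the commit re-ranks LLM-proposed actions with a value estimate learned across episodes. The sketch below is illustrative only (simplified names, no exploration bonus or validity filtering, and it assumes q_values is non-empty); the actual logic lives in StudentAgent.update_scores_with_memory in the agent.py diff that follows.

# Illustrative sketch of the selection rule (not part of the commit).
# q_values: average discounted returns retrieved from cross-episode memory.
# options: LLM proposals, each {"action": str, "confidence": float in [0, 1]}.
def rerank(options, q_values, episode):
    baseline = sum(q_values.values()) / len(q_values)            # state value V(s)
    advantages = {a: q - baseline for a, q in q_values.items()}  # A(s, a) = Q(s, a) - V(s)
    scale = max((v for v in advantages.values() if v > 0), default=1.0)
    weight = min(1.0 + (episode / 50.0) * 0.5, 1.5)              # trust memory more over time
    return max(
        options.values(),
        key=lambda o: o["confidence"] + weight * advantages.get(o["action"], 0.0) / scale,
    )["action"]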
agent.py
CHANGED
@@ -26,12 +26,16 @@ Tips:
 import json
 import os
 import re
+import random
+import time
 from dataclasses import dataclass, field
 from typing import Optional
 
 from dotenv import load_dotenv
 from huggingface_hub import InferenceClient
 
+from cross_episode_memory import CrossEpisodeMemory
+
 # Load environment variables
 load_dotenv()
 
@@ -67,7 +71,7 @@ else:
 LLM_CLIENT = InferenceClient(token=_hf_token)
 
 
-def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
+def call_llm(prompt: str, system_prompt: str, seed: int = 42, temperature: float = 0.0001, max_tokens: int = 1000, response_format: Optional[dict] = None) -> str:
     """
     Call the LLM with the given prompt. Use this function in your agent.
 
@@ -79,35 +83,32 @@ def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300)
 
     Returns:
         The LLM's response text
-
-    Example:
-        response = call_llm(
-            prompt="You are in a forest. What do you do?",
-            system_prompt=SYSTEM_PROMPT,
-            seed=42,
-        )
     """
     messages = [
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": prompt},
     ]
-
-
-
-
-
-
-
+    try:
+        if USE_LOCAL_MODEL and _local_pipeline is not None:
+            outputs = _local_pipeline(
+                messages,
+                max_new_tokens=max_tokens,
+                temperature=temperature,  # Near-deterministic (0.0 is unsupported by some backends)
+                do_sample=True,
+            )
+            return outputs[0]["generated_text"][-1]["content"]
+
+        response = LLM_CLIENT.chat.completions.create(
+            model=LLM_MODEL,
+            messages=messages,
+            temperature=temperature,  # Deterministic for reproducibility
+            max_tokens=max_tokens,
+            seed=seed,
+            response_format=response_format,  # run() passes a JSON-schema format here
         )
-
-
-
-
-        messages=messages,
-        temperature=0.0,  # Deterministic for reproducibility
-        max_tokens=max_tokens,
-        seed=seed,
-    )
+
+    except Exception as e:
+        print(f"[LLM Error] {e}")
+        return "THOUGHT: LLM error, trying look."
 
     return response.choices[0].message.content
 
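A minimal usage sketch for call_llm, assuming HF_TOKEN is set and LLM_MODEL points at a hosted chat model (both are configured earlier in agent.py, outside the hunks shown):

reply = call_llm(
    prompt="You are standing in an open field west of a white house. What do you do?",
    system_prompt="You are playing a classic text adventure game.",
    seed=42,
    max_tokens=200,
)
print(reply)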
@@ -124,153 +125,746 @@ class RunResult:
     history: list[tuple[str, str, str]] = field(default_factory=list)
 
 
-# =============================================================================
-# System Prompt - Customize this for your agent
-# =============================================================================
-
-SYSTEM_PROMPT = """You are playing a classic text adventure game.
-
-GOAL: Explore the world, solve puzzles, and maximize your score.
-
-AVAILABLE TOOLS (use via MCP):
-- play_action: Execute a game command (north, take lamp, open mailbox, etc.)
-- memory: Get current game state and history (if implemented)
-- inventory: Check what you're carrying (if implemented)
-
-VALID GAME COMMANDS for play_action:
-- Movement: north, south, east, west, up, down, enter, exit
-- Objects: take <item>, drop <item>, open <thing>, close <thing>, examine <thing>
-- Other: look, inventory, read <thing>, turn on lamp
-
-RESPOND IN THIS EXACT FORMAT (no markdown):
-THOUGHT: <your reasoning about what to do next>
-TOOL: <tool_name>
-ARGS: <JSON arguments, e.g., {"action": "look"}>
-
-Example:
-THOUGHT: I should look around to see where I am.
-TOOL: play_action
-ARGS: {"action": "look"}
-"""
-
-
 # =============================================================================
 # Student Agent - IMPLEMENT THIS CLASS
 # =============================================================================
 
 class StudentAgent:
     """
-
-
-    TODO:
-    1. Implement the run() method with the ReAct loop
-    2. Parse LLM responses to extract tool calls
-    3. Track state and avoid loops
+    MCP ReAct Agent - A complete working example.
 
-
+    This agent demonstrates:
+    - ReAct loop (Thought -> Tool -> Observation)
+    - Loop detection
+    - Action validation
+    - Score tracking via memory tool
     """
 
-    def __init__(self):
-        """Initialize
-
-
-
-
+    def __init__(self, game: str = "zork1", guiding_prompt: str = None, top_actions: int = 5, exploration_alpha: float = 0.4, gamma: float = 0.95):
+        """Initialize the agent state."""
+        self.history: list[dict] = []
+        self.recent_actions: list[str] = []
+        self.score: int = 0
+        self.guiding_prompt = guiding_prompt or """Explore the environment and try to maximize your score."""
+        self.top_actions = top_actions
+
+        # Create a game-specific memory directory
+        game_memory_dir = os.path.join("memory", game)
+        self.cross_mem = CrossEpisodeMemory(base_dir=game_memory_dir, eval_llm_model=LLM_MODEL, gamma=gamma)
+        self.exploration_alpha = exploration_alpha
 
     async def run(
         self,
-        client,
+        client,
         game: str,
         max_steps: int,
         seed: int,
         verbose: bool = False,
     ) -> RunResult:
-        """
-
+        """Run the agent for a game session."""
+        locations_visited = set()
+        history = []
+        moves = 0
 
-
-
-
-        max_steps: Maximum number of steps to take
-        seed: Random seed for reproducibility (use for LLM calls)
-        verbose: Whether to print detailed output
-
-        Returns:
-            RunResult with final score and statistics
-        """
-        # TODO: Implement your ReAct loop here
-        #
-        # Basic structure:
-        # 1. Get initial observation (call play_action with "look")
-        # 2. Loop for max_steps:
-        #    a. Build prompt with current observation and history
-        #    b. Call LLM to get thought and action
-        #    c. Parse the response to extract tool and args
-        #    d. Call the tool via client.call_tool(tool_name, args)
-        #    e. Update history and state
-        #    f. Check for game over
-        # 3. Return RunResult with final statistics
+        # Get the list of available tools
+        tools = await client.list_tools()
+        tool_names = [t.name for t in tools]
 
-        #
-
-
+        # Get the initial observation
+        result = await client.call_tool("reset_game", {})
+        raw = self._extract_result(result)
+
+        payload = json.loads(raw)
+        state = payload["state"]  # dict
+        observation = state["observation"]
 
-        #
-
-
-        #     system_prompt=SYSTEM_PROMPT,
-        #     seed=seed,
-        # )
+        # Track the initial location
+        location = observation.split("\n")[0] if observation else "Unknown"
+        locations_visited.add(location)
 
-
-
-        history = []
-        final_score = 0
-        moves = 0
+        if verbose:
+            print(f"\n{observation}")
 
-        #
-
+        # Parse the initial score
+        self._update_score(state)
+
+        # Main ReAct loop
+        for step in range(1, max_steps + 1):
+            # Build the prompt with context
+            user_prompt, sys_prompt, memory_text = self._build_prompt(state)
+
+            # Get the response format
+            response_format = self._get_response_format()
+
+            # Call the LLM for reasoning (use a step-based seed for variety)
+            response = call_llm(user_prompt, sys_prompt, seed=seed + step, response_format=response_format)
+
+            # Parse the response
+            action, options_with_confidences = self._parse_response(response)
+
+            # Update action scores with memory
+            updated_options_with_logits = self.update_scores_with_memory(
+                observation,
+                options_with_confidences,
+                k=10,
+                r=0.9,
+                memory_text=memory_text)
+
+            if updated_options_with_logits:
+                # Select the valid action with the highest corrected confidence
+                valid_actions = state.get("valid_actions", []) or []
+                updated_options_valid_items = []
+                for option in updated_options_with_logits.items():
+                    if option[1]['action'] in valid_actions:
+                        updated_options_valid_items.append(option)
+
+                if updated_options_valid_items:  # guard: keep the parsed action if nothing valid remains
+                    best_option = max(
+                        updated_options_valid_items,
+                        key=lambda x: x[1].get('corrected_confidence', 0)
+                    )
+                    action = best_option[1]['action']
+
+            if verbose:
+                print(f"\n--- Step {step} | Score: {self.score} ---")
+                print(f"[RAW_LLM_OUTPUT]\n{response}")
+                print(f"[ACTION] {action}")
+
+            # Loop detection
+            self.recent_actions.append(action)
+            if len(self.recent_actions) > 5:
+                self.recent_actions = self.recent_actions[-5:]
+
+            # Detect loops - if the same action repeats 3 times, force "look"
+            if len(self.recent_actions) >= 3 and len(set(self.recent_actions[-3:])) == 1:
+                if verbose:
+                    print("[WARNING] Loop detected - forcing 'look'")
+                action = "look"
+                self.recent_actions.append(action)
+
+            moves += 1
+
+            # Update history
+            self.history.append({
+                "old_state": observation,
+                "action": action,
+                "full_response": response,
+                "reward": None,
+                "score": None,
+            })
+
+            # Execute the tool
+            try:
+                result = await client.call_tool("play_action", {"action": action})
+                raw = self._extract_result(result)
+                payload = json.loads(raw)
+                state = payload["state"]
+                observation = state["observation"]
+
+                if verbose:
+                    print(f"[RESULT] {observation[:200]}...")
+            except Exception as e:
+                observation = f"Error: {e}"
+                if verbose:
+                    print(f"[ERROR] {e}")
+
+            # Track the location
+            location = observation.split("\n")[0] if observation else "Unknown"
+            locations_visited.add(location)
+
+            old_score = self.score
+
+            # Track the score from the observation
+            self._update_score(state)
+
+            self.history[-1]["reward"] = self.score - old_score
+            self.history[-1]["score"] = self.score
+            self.history[-1]["state"] = observation
+
+            # Record in the result history
+            history.append((response, f"play_action(action: {action})", observation[:100]))
+
+            # Check for game over
+            if self._is_game_over(observation):
+                if verbose:
+                    print("\n*** GAME OVER ***")
+                break
+
+        success = self._is_game_over(observation) and self.score > 0
+
+        # Add the current episode to memory for future episodes
+        self.cross_mem.add_episode(
+            game_history=self.history,
+            final_score=self.score,
+            success=success,
+            state=observation
+        )
 
         return RunResult(
-            final_score=
-            max_score=350,
+            final_score=self.score,
+            max_score=350,
             moves=moves,
            locations_visited=locations_visited,
-            game_completed=
+            game_completed=self._is_game_over(observation),
             history=history,
         )
 
-    def
-
-
+    def update_scores_with_memory(self, current_state, options_with_confidences, k, r, memory_text, exploration_prob=0.4):
+        nearest_trajectories = self.cross_mem.retrieve_similar(
+            game_history=self.history,
+            current_state=current_state,
+            current_summary=memory_text,
+            k=k,
+            r=r
+        )
+        if not nearest_trajectories:
+            return {}
 
-
-
-
-
-
-
-
-
+        # Get the actions the LLM already proposed
+        existing_actions = set()
+        for option_data in options_with_confidences.values():
+            if isinstance(option_data, dict) and 'action' in option_data:
+                existing_actions.add(option_data['action'])
+
+        # Aggregate action rewards from nearest_trajectories
+        action_rewards = {}
+        for result_dict in nearest_trajectories:
+            action = result_dict.get('action', '').strip()
+            discounted_reward = result_dict.get('discounted_reward', 0)
+            if action not in action_rewards:
+                action_rewards[action] = []
+            action_rewards[action].append(discounted_reward)
 
-
+        # Add remembered actions with positive average reward that the LLM did not propose
+        for action in action_rewards:
+            found = False
+            for option_data in options_with_confidences.values():
+                if isinstance(option_data, dict) and option_data.get('action', '') == action:
+                    found = True
+                    break
+            if not found:
+                rewards = action_rewards[action]
+                if len(rewards) >= 1 and sum(rewards) / len(rewards) > 0:
+                    if options_with_confidences:
+                        new_option_num = max(options_with_confidences.keys()) + 1
+                    else:
+                        new_option_num = 1
+                    options_with_confidences[new_option_num] = {
+                        'action': action,
+                        'confidence': 0
+                    }
 
+        # Calculate the average discounted reward for each action, Q(s, a)
+        action_avg_rewards = {}
+        for action, rewards in action_rewards.items():
+            if rewards:  # Ensure not empty
+                action_avg_rewards[action] = sum(rewards) / len(rewards)
+            else:
+                action_avg_rewards[action] = 0
+
+        # Calculate the overall average discounted reward: the baseline "state value"
+        count = 0
+        if action_rewards:
+            all_rewards = []
+            for rewards_list in action_rewards.values():
+                all_rewards.extend(rewards_list)
+            count = len(all_rewards)
+            overall_avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0
+        else:
+            overall_avg_reward = 0
+
+        # Exploration bonus: complete Q(s, a) for proposed actions that are not in memory
+        alpha = self.exploration_alpha
+        new_actions_added = False
+        for option_num, option_data in options_with_confidences.items():
+            action = option_data['action']
+            if action not in action_avg_rewards:
+                rad = random.random()
+                print(f"action: {action} rad:{rad}")
+                if rad < exploration_prob:
+                    # UCB-style exploration bonus: decreases as the sample count increases
+                    if count > 0:
+                        action_avg_rewards[action] = overall_avg_reward + alpha / count
+                    else:
+                        action_avg_rewards[action] = overall_avg_reward + alpha
+                else:
+                    action_avg_rewards[action] = 0
+                new_actions_added = True
+
+        # If new actions were added, recalculate overall_avg_reward
+        if new_actions_added:
+            all_avg_rewards = list(action_avg_rewards.values())
+            overall_avg_reward = sum(all_avg_rewards) / len(all_avg_rewards) if all_avg_rewards else 0
+
+        # Calculate the advantage value for each action
+        action_advantages = {}
+        for action, avg_reward in action_avg_rewards.items():
+            action_advantages[action] = avg_reward - overall_avg_reward
+
+        # Normalize the advantage values
+        if action_advantages:
+            adv_values = list(action_advantages.values())
+
+            positive_advs = [adv for adv in adv_values if adv > 0]
+
+            if positive_advs:
+                max_positive = max(positive_advs)
+                normalized_advantages = {action: adv / max_positive for action, adv in action_advantages.items()}
+            else:
+                max_negative_abs = abs(min(adv_values))
+                if max_negative_abs > 0:
+                    normalized_advantages = {action: adv / max_negative_abs for action, adv in action_advantages.items()}
+                else:
+                    normalized_advantages = {action: 0 for action in action_advantages.keys()}
+        else:
+            normalized_advantages = {}
+
+        updated_options = {}
+        if options_with_confidences and normalized_advantages:
+            print("\n=== Action Advantage Analysis & Logit Correction ===")
+            print(f"Overall action average reward baseline: {overall_avg_reward:.4f}")
+
+            for option_num, option_data in options_with_confidences.items():
+                if isinstance(option_data, dict) and 'action' in option_data:
+                    action = option_data['action']
+
+                    # Get the normalized advantage value
+                    normalized_advantage = normalized_advantages.get(action, 0)
+                    raw_advantage = action_advantages.get(action, 0)
+                    avg_reward = action_avg_rewards.get(action, 0)
+
+                    # Calculate the episode-based weight for the normalized advantage.
+                    # The weight increases from 1.0 (episode 1) to 1.5 (episode 50):
+                    # weight = 1.0 + (current_episode / 50) * 0.5,
+                    # clamped to a maximum of 1.5 for episodes beyond 50.
+                    if self.cross_mem:
+                        current_episode = self.cross_mem.current_episode_number
+                        episode_weight = min(1.0 + (current_episode / 50.0) * 0.5, 1.5)
+                    else:
+                        episode_weight = 1.0
+
+                    # Apply the episode weight to the normalized advantage
+                    weighted_normalized_advantage = normalized_advantage * episode_weight
+
+                    # Correct the logit (add the weighted normalized advantage value)
+                    normalized_prob = option_data.get('confidence', 0)
+                    corrected_logprob = normalized_prob + weighted_normalized_advantage
+
+                    # Create the corrected option data
+                    updated_options[option_num] = {
+                        'action': action,
+                        'normalized_advantage': normalized_advantage,
+                        'corrected_confidence': corrected_logprob,
+                        'avg_reward': avg_reward,
+                        'raw_advantage': raw_advantage,
+                        'confidence': option_data.get('confidence', 0)
+                    }
+
+                    print(f"  Action: {action}")
+                    print(f"  Average reward: {avg_reward:.4f} | Raw advantage: {raw_advantage:.4f}")
+                    print(f"  Normalized advantage: {normalized_advantage:.4f}")
+                    print(f"  Original logprob: {normalized_prob:.4f} -> Corrected: {corrected_logprob:.4f}")
+                else:
+                    # If there is no matching advantage value, keep the original value
+                    updated_options[option_num] = option_data.copy()
+                    updated_options[option_num]['corrected_confidence'] = option_data.get('confidence', 0)
+                    updated_options[option_num]['normalized_advantage'] = 0
+            print("=" * 50)
+            return updated_options
+
+        else:
+            return {}
+
+    def _build_prompt(self, state: dict) -> tuple[str, str, str]:
+        """Build the prompt for the LLM with context."""
+
+        valid_actions = state.get("valid_actions", []) or []
+        current_inventory_list = state.get("inventory", []) or []
+        current_inventory = ", ".join(current_inventory_list) if current_inventory_list else "(empty)"
+
+        if self.history:
+            summary, memory_text = self.generate_history_summary(self.history, state.get("observation", ""), temperature=0.0, max_tokens=1000, current_inventory=current_inventory)
+        else:
+            summary = ""
+            memory_text = "No game history."
+            # Add the inventory even if there is no game history
+            if current_inventory:
+                memory_text += f"\n\n[INVENTORY]\n{current_inventory}"
+        sys_prompt = """You are an expert player aiming to complete a text-based adventure game. Points are given for making progress in the game. Select promising actions based on the game state and memory of past interactions.
+
+**EXPLORATION PRIORITY**: When you arrive at a NEW location you haven't fully explored before, you MUST thoroughly explore it FIRST before leaving."""
+        if self.guiding_prompt:
+            sys_prompt += f"\n\nFollow this guide: {self.guiding_prompt}"
+
+        # Add a reminder about action-list order only when valid actions are provided
+        if valid_actions:
+            sys_prompt += """\n\n**CRITICAL CONSTRAINT**: When REFERENCE ACTIONS are provided, you MUST ONLY choose actions from that list. Any action not in the REFERENCE ACTIONS list is INVALID and will fail. Do NOT create custom actions. The list is unordered - position doesn't indicate quality."""
+
+        recent_history = ""
+        if self.history:
+            # Only keep the last 5 steps
+            recent_game_history = self.history[-5:]
+            start_index = max(0, len(self.history) - 5)
+            for idx, entry in enumerate(recent_game_history):
+                actual_step = start_index + idx
+                recent_history += f"Step {actual_step}:\n"
+                recent_history += f"State: {entry.get('state', '')}\n"
+                recent_history += f"Action: {entry.get('action', '')}\n"
+                if entry.get('reward') is not None:
+                    recent_history += f"Reward: {entry.get('reward', 0)}\n"
+                recent_history += "\n"
+
+        # Add the valid-actions section if available
+        valid_actions_text = ""
+        if valid_actions:
+            valid_actions = valid_actions[:]  # make a copy to avoid side effects
+            rng = random.Random(time.time_ns())
+            rng.shuffle(valid_actions)
+            print(f"Valid actions (shuffled): {valid_actions}")
+            valid_actions_text = f"\nREFERENCE ACTIONS (ONLY VALID ACTIONS):\n{valid_actions}\n\n**STRICT REQUIREMENT**: These are the ONLY valid actions for the current state. You MUST select your actions EXCLUSIVELY from this list. Any action not in this list is INVALID and will be rejected by the game. Do NOT create, modify, or suggest any custom actions.**\n"
+
+        # Build the option fields with confidence
+        option_fields = []
+        for i in range(1, self.top_actions + 1):
+            option_fields.append(f'  "reasoning{i}": "Why this action makes sense",')
+            option_fields.append(f'  "option{i}": "action command",')
+            option_fields.append(f'  "confidence{i}": 80,')
+        option_format = '\n'.join(option_fields)
+
+        user_prompt = f"""
+GAME HISTORY:
+{summary}
+{memory_text}
+
+RECENT STEPS:
+{recent_history}
+
+CURRENT STATE: {state.get("observation", "")}
+
+TASK:
+1. Analyze your progress: What have you achieved? What's your next objective?
+2. **MANDATORY: Check the REFERENCE ACTIONS list below - you MUST ONLY select from this list**
+3. Propose {self.top_actions} different actions with reasoning
+4. For EACH action, provide your confidence as an integer from 0 to 100 (e.g., 80 means 80% confidence that this action will help achieve the goal)
+
+RESPONSE FORMAT (JSON):
+{{
+  "progress_analysis": "What you've achieved and current challenges",
+  "next_objective": "Your next goal",
+{option_format}
+  "best_action": 1
+}}
+
+IMPORTANT:
+- **ABSOLUTE REQUIREMENT**: ALL actions (option1, option2, etc.) MUST be selected EXACTLY from the REFERENCE ACTIONS list below
+- Actions NOT in the REFERENCE ACTIONS list are INVALID and will cause the game to fail
+- DO NOT create custom actions, DO NOT modify actions from the list, DO NOT combine actions
+- **CRITICAL**: The sum of all confidence values (confidence1 + confidence2 + ... + confidence{self.top_actions}) MUST equal 100
+- The confidence values MUST have meaningful differences between them
+- Higher confidence means you believe this action is more likely to succeed
+- Lower confidence means you believe this action is less likely to succeed
+- Pay attention to game hints and clues in state descriptions
+- Don't repeat failed actions or create loops (e.g., north→south→north)
+
+{valid_actions_text}
+"""
+        return user_prompt, sys_prompt, memory_text
+
+    def generate_history_summary(self, game_history, current_state=None, temperature=0.8, max_tokens=1000, current_inventory=None):
+        """
+        Generate structured LLM-summarized history context.
+        Summarizes the past states and optionally includes the current state.
+
+        Args:
+            game_history: List of game history entries containing state, action, score, reward, etc.
+            current_state: The current state to include in the summary (optional)
+            temperature: Temperature for LLM generation
+            max_tokens: Maximum tokens for the summary
+            current_inventory: The current inventory to include in the summary (optional)
+
         Returns:
-
+            tuple: (summary_text, structured_summary)
+            - summary_text: Natural language summary under [SUMMARY]
+            - structured_summary: Structured sections [PROGRESS], [LOCATION], [NEXT_OBJECTIVE], [INVENTORY]
         """
+        if not game_history and not current_state:
+            return "", ""
+
+        # Extract states and actions from game_history
+        states = [entry.get('state', '') for entry in game_history]
+        actions = [entry.get('action', '') for entry in game_history]
+        current_step = len(game_history)  # Current step index
+
+        # Build the trajectory text for all past steps
+        earlier_trajectory_text = ""
+        for i in range(current_step):
+            earlier_trajectory_text += f"Step {i}:\n"
+            if i < len(states):
+                earlier_trajectory_text += f"State: {states[i]}\n"
+            if i < len(actions) and actions[i]:
+                earlier_trajectory_text += f"Action: {actions[i]}\n"
+            earlier_trajectory_text += "\n"
+
+        # Add the current state if provided
+        if current_state:
+            earlier_trajectory_text += f"Step {current_step} (Current):\n"
+            earlier_trajectory_text += f"State: {current_state}\n"
+
+        sys_prompt = """You are an expert at analyzing game trajectories and creating highly distinctive summaries.
+Your milestones must be CONCISE and DISTINCTIVE - use specific keywords that clearly differentiate different game states.
+
+Output Format Requirements:
+[SUMMARY]: Provide a natural language summary that describes:
+- The game's objective or goal (inferred from the trajectory)
+- Current progress toward that goal
+- Key accomplishments so far
+- Current situation/status
+
+[PROGRESS]: List milestones in format "✓ M#: <action>→<key object/result>"
+- If NO steps have score increases, output ONLY "No Progress" (no milestones)
+
+[LOCATION]: Track location changes in format "Location: A→B→C"
+- Extract location names from state descriptions
+- Only record when location actually changes
+- CRITICAL: IGNORE unproductive loops - if the player goes back and forth between locations WITHOUT score increases, skip those redundant location movements
+- Focus on the MEANINGFUL location trajectory that led to progress or new discoveries
+- Example: If player went "room A→room B→room A→room B→room A→room C" with no score increase for the A↔B movements, record as "room A→room C"
+
+Milestone Format Rules:
+- Use format: <verb>→<critical object/state>
+- Focus on STATE CHANGES, not descriptions
+- Use specific nouns not generic terms
+- Each milestone should capture ONE concrete action or discovery
+- CRITICAL: Only record steps where the score increased compared to the previous step
+
+Examples:
+BAD SUMMARY: "The player has been exploring."
+GOOD SUMMARY: "The game's objective is to escape the haunted mansion. So far, the player has found a key in the library and unlocked the basement door, earning 15 points. Currently in the dark basement, the player needs to find a light source to continue exploring."
+
+BAD: "✓ Milestone 1: Entered the library and proceeded to the ground floor stacks"
+GOOD: "✓ M1: enter→library ground floor"
+
+Location Examples:
+BAD: "Location: You are in a library→You moved to another area→You are somewhere else"
+GOOD: "Location: entrance→library→north corridor→atrium"
+BAD: "Location: room A→room B→room A→room B→room A" (unproductive loop)
+GOOD: "Location: room A" (ignore the meaningless back-and-forth)
+"""
+
+        user_prompt = f"""Analyze the game trajectory and generate a natural language summary followed by structured sections.
+
+Game History:
+{earlier_trajectory_text}
+
+Generate summary in this format:
+
+[SUMMARY]
+<Natural language paragraph (2-4 sentences) describing: game objective, current progress, key accomplishments, current situation>
+
+[PROGRESS]
+✓ M1: <verb>→<result>
+✓ M2: <verb>→<result>
+...
+
+[LOCATION]
+Location: A→B→C→...
+
+Requirements:
+- SUMMARY: Write a natural language paragraph explaining the game's goal and current status
+- Use specific object names
+- Focus on concrete state changes and discoveries
+- Avoid repetitive phrasing between milestones
+- CRITICAL: ONLY record progress where the score increased from the previous step
+- IMPORTANT: If there are NO steps with score increases, output ONLY "No Progress" under the [PROGRESS] section (do not create any milestone items)
+- For locations: Extract actual location names from state descriptions and track when they change
+- Location format must be: "Location: " followed by location names separated by →
+- CRITICAL FOR LOCATION: Detect and REMOVE unproductive loops from the location trajectory
+  * Look at the score/reward at each step
+  * If the player moves back and forth between locations WITHOUT any score increase, those movements are loops - SKIP them
+  * Only include location changes that either: (1) led to score increases, OR (2) represent meaningful forward exploration
+  * Example: If steps show "A→B (no reward)→A (no reward)→B (no reward)→A (no reward)→C (reward +5)", output "Location: A→C"
+
+Generate the complete summary:"""
+
+        full_response = call_llm(user_prompt, sys_prompt, temperature=temperature, max_tokens=max_tokens).strip()
+
+        if full_response == "":
+            print("[Warning] No response from LLM for history summary generation")
+
+        # Extract the [SUMMARY] section
+        summary_text = ""
+        if "[SUMMARY]" in full_response:
+            summary_start = full_response.find("[SUMMARY]") + len("[SUMMARY]")
+            # Find the next section marker
+            next_section = full_response.find("[PROGRESS]", summary_start)
+            if next_section == -1:
+                next_section = full_response.find("[LOCATION]", summary_start)
+
+            if next_section != -1:
+                summary_text = full_response[summary_start:next_section].strip()
+            else:
+                summary_text = full_response[summary_start:].strip()
+
+        # Extract the structured sections (everything except [SUMMARY])
+        structured_summary = full_response
+        if "[SUMMARY]" in full_response:
+            # Remove the [SUMMARY] section from the structured summary
+            progress_start = full_response.find("[PROGRESS]")
+            if progress_start != -1:
+                structured_summary = full_response[progress_start:]
+
+        # Add the inventory section to the structured summary if available
+        if current_inventory:
+            structured_summary += f"\n\n[INVENTORY]\n{current_inventory}"
+
+        return summary_text, structured_summary
+
+    def _get_response_format(self) -> dict:
+        # Dynamically create a JSON Schema that forces the final choice to be a valid option number
+        properties = {}
+        required_fields = []
+        valid_choices = []
+
+        # Add the progress-analysis fields
+        properties["progress_analysis"] = {
+            "type": "string",
+            "description": "Analysis of achievements and progress so far"
+        }
+        properties["next_objective"] = {
+            "type": "string",
+            "description": "Overall objective for the next steps"
+        }
+        required_fields.extend(["progress_analysis", "next_objective"])
+
+        # Dynamically generate the reasoning and option fields
+        for i in range(1, self.top_actions + 1):
+            reasoning_key = f"reasoning{i}"
+            option_key = f"option{i}"
+
+            properties[reasoning_key] = {
+                "type": "string",
+                "description": f"Reasoning for option {i}"
+            }
+            properties[option_key] = {
+                "type": "string",
+                "description": f"Option {i} possible action"
+            }
+            required_fields.append(reasoning_key)
+            required_fields.append(option_key)
+            valid_choices.append(i)
+
+            confidence_key = f"confidence{i}"
+            properties[confidence_key] = {
+                "type": "integer",
+                "minimum": 0,
+                "maximum": 100,
+                "description": f"Confidence for option {i} (0-100)"
+            }
+            required_fields.append(confidence_key)
+
+        # Add the best_action field
+        properties["best_action"] = {
+            "type": "number",
+            "minimum": 1,
+            "maximum": self.top_actions,
+            "description": f"The number of the best option (must be one of: {valid_choices})"
+        }
+        required_fields.append("best_action")
+
+        response_format = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "game_action_choice",
+                "strict": True,
+                "schema": {
+                    "type": "object",
+                    "properties": properties,
+                    "required": required_fields,
+                    "additionalProperties": False
+                }
+            }
+        }
+
+        return response_format
 
-    def
-        """
-
+    def _parse_response(self, response: str) -> tuple[str, dict]:
+        """Parse the LLM's JSON response into the chosen action and the option dict."""
+        # Clean up markdown code fences if present
+        cleaned_response = response.strip()
+        if cleaned_response.startswith('```'):
+            # Remove ```json or ``` at the start and ``` at the end
+            cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response)
+            cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
+
+        try:
+            json_response = json.loads(cleaned_response)
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing error: {e}")
+            print(f"Raw response: {response}")
+            return "look", {}
+
+        # Extract and print the progress analysis
+        progress_analysis = json_response.get("progress_analysis", "No analysis provided")
+        next_objective = json_response.get("next_objective", "No objective specified")
+
+        print("\n" + "=" * 50)
+        print("=== PROGRESS ANALYSIS ===")
+        print("=" * 50)
+        print(f"Achievements: {progress_analysis}")
+        print(f"\nNext Objective: {next_objective}")
+        print("=" * 50 + "\n")
+
+        # Extract all options with their reasoning and confidence
+        options_with_confidences = {}
 
-
-
-
+        for i in range(1, self.top_actions + 1):
+            reasoning_key = f"reasoning{i}"
+            option_key = f"option{i}"
+            confidence_key = f"confidence{i}"
+
+            reasoning = json_response.get(reasoning_key, "No reasoning provided")
+            option = json_response.get(option_key, "look")
+            confidence = json_response.get(confidence_key, 50)
+
+            options_with_confidences[i] = {
+                'action': option.strip(),
+                'reasoning': reasoning,
+                'confidence': confidence / 100.0
+            }
+
+        # Get the best action
+        best_choice = int(json_response.get("best_action", 1))
+        action_text = options_with_confidences.get(best_choice, {}).get('action', "look")
+
+        return action_text, options_with_confidences
+
+    def _extract_result(self, result) -> str:
+        """Extract text from an MCP tool result."""
+        if hasattr(result, 'content') and result.content:
+            return result.content[0].text
+        if isinstance(result, list) and result:
+            return result[0].text if hasattr(result[0], 'text') else str(result[0])
+        return str(result)
+
+    def _update_score(self, state: dict) -> None:
+        if not state:
+            return
+        if "score" in state:
+            self.score = int(state["score"])
+
+    def _is_game_over(self, text: str) -> bool:
+        """Check if the game is over."""
+        game_over_phrases = [
+            "game over",
+            "you have died",
+            "you are dead",
+            "*** you have died ***",
+        ]
+        text_lower = text.lower()
+        return any(phrase in text_lower for phrase in game_over_phrases)
 
 
 # =============================================================================
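A hypothetical driver for the agent above. The client is assumed to be an MCP client exposing list_tools/call_tool against mcp_server.py (also touched by this commit); connect_client is a placeholder, not a real API:

import asyncio

async def main():
    agent = StudentAgent(game="zork1")
    async with connect_client() as client:  # placeholder for an MCP client session
        result = await agent.run(client, game="zork1", max_steps=50, seed=42, verbose=True)
    print(f"Score: {result.final_score}/{result.max_score} in {result.moves} moves")

asyncio.run(main())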
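To make update_scores_with_memory concrete, here is the correction worked through on invented numbers:

# Invented example. Memory returns discounted rewards per action:
#   "open window": [4.0, 6.0]  ->  Q = 5.0
#   "north":       [2.0]       ->  Q = 2.0
# Baseline V(s) = (4.0 + 6.0 + 2.0) / 3 = 4.0, so raw advantages are +1.0 and -2.0.
# The largest positive advantage (1.0) sets the scale, giving normalized
# advantages of +1.0 and -2.0. At episode 10 the weight is
# min(1.0 + (10 / 50) * 0.5, 1.5) = 1.1, so an LLM confidence of 0.40 for
# "open window" is corrected to 0.40 + 1.1 * 1.0 = 1.50, while "north" at
# 0.60 drops to 0.60 + 1.1 * (-2.0) = -1.60 and is no longer selected.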
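For reference, a well-formed model reply to the prompt built by _build_prompt, shown for top_actions=2 (contents invented; note the confidences sum to 100 as the prompt demands):

example_reply = {
    "progress_analysis": "Opened the mailbox and read the leaflet; no points yet.",
    "next_objective": "Get inside the white house.",
    "reasoning1": "The window is ajar and may be an entrance.",
    "option1": "open window",
    "confidence1": 70,
    "reasoning2": "Circling the house may reveal another way in.",
    "option2": "north",
    "confidence2": 30,
    "best_action": 1,
}

This shape also validates against the schema that _get_response_format builds: a strict json_schema with every field required and best_action bounded by top_actions.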
cross_episode_memory.py
ADDED
@@ -0,0 +1,442 @@
"""
|
| 2 |
+
Cross-Episode Memory for Jericho Agent with Vector Similarity Search
|
| 3 |
+
Uses state embeddings to find similar past experiences and update action selection.
|
| 4 |
+
|
| 5 |
+
Installation:
|
| 6 |
+
pip install sentence-transformers faiss-cpu numpy --break-system-packages
|
| 7 |
+
|
| 8 |
+
Alternative models:
|
| 9 |
+
- 'all-MiniLM-L6-v2' (default, 384 dims, fast)
|
| 10 |
+
- 'all-mpnet-base-v2' (768 dims, better quality)
|
| 11 |
+
- 'paraphrase-multilingual-MiniLM-L12-v2' (multilingual)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import json
|
| 16 |
+
import pickle
|
| 17 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 18 |
+
from collections import defaultdict, Counter
|
| 19 |
+
import numpy as np
|
| 20 |
+
from dotenv import load_dotenv
|
| 21 |
+
from sentence_transformers import SentenceTransformer
|
| 22 |
+
from huggingface_hub import InferenceClient
|
| 23 |
+
import faiss
|
| 24 |
+
import traceback
|
| 25 |
+
from openai import OpenAI
|
| 26 |
+
|
| 27 |
+
load_dotenv()
|
| 28 |
+
|
| 29 |
+
# Set USE_LOCAL_MODEL=1 in your .env to use a locally downloaded model
|
| 30 |
+
USE_LOCAL_MODEL = os.getenv("USE_LOCAL_MODEL", "0").strip() in ("1", "true", "yes")
|
| 31 |
+
LOCAL_MODEL_ID = os.getenv("LOCAL_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")
|
| 32 |
+
|
| 33 |
+
# Initialize the LLM client based on mode
|
| 34 |
+
_local_pipeline = None
|
| 35 |
+
|
| 36 |
+
if USE_LOCAL_MODEL:
|
| 37 |
+
import torch
|
| 38 |
+
from transformers import pipeline as _hf_pipeline
|
| 39 |
+
|
| 40 |
+
_local_pipeline = _hf_pipeline(
|
| 41 |
+
"text-generation",
|
| 42 |
+
model=LOCAL_MODEL_ID,
|
| 43 |
+
torch_dtype=torch.bfloat16,
|
| 44 |
+
device_map="auto",
|
| 45 |
+
)
|
| 46 |
+
LLM_CLIENT = None
|
| 47 |
+
else:
|
| 48 |
+
_hf_token = os.getenv("HF_TOKEN")
|
| 49 |
+
if not _hf_token:
|
| 50 |
+
raise ValueError("HF_TOKEN not found. Set it in your .env file.")
|
| 51 |
+
LLM_CLIENT = InferenceClient(token=_hf_token)
|
| 52 |
+
|
| 53 |
+
class CrossEpisodeMemory:
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
base_dir: str,
|
| 57 |
+
eval_llm_model: str = "Qwen/Qwen2.5-72B-Instruct",
|
| 58 |
+
gamma: float = 0.95,
|
| 59 |
+
embedding_model: str = 'all-MiniLM-L6-v2',
|
| 60 |
+
vector_dim: int = 384
|
| 61 |
+
):
|
| 62 |
+
self.base_dir = base_dir
|
| 63 |
+
self.gamma = gamma
|
| 64 |
+
self.vector_dim = vector_dim
|
| 65 |
+
os.makedirs(self.base_dir, exist_ok=True)
|
| 66 |
+
|
| 67 |
+
# File paths
|
| 68 |
+
self.episodes_path = os.path.join(self.base_dir, 'episodes.jsonl')
|
| 69 |
+
self.vector_index_path = os.path.join(self.base_dir, 'state_vectors.index')
|
| 70 |
+
self.step_metadata_path = os.path.join(self.base_dir, 'step_metadata.pkl')
|
| 71 |
+
|
| 72 |
+
self.eval_llm_model = eval_llm_model
|
| 73 |
+
|
| 74 |
+
self.query_vector_cache = []
|
| 75 |
+
|
| 76 |
+
# Episode counter
|
| 77 |
+
self.current_episode_number = self._get_episode_count()
|
| 78 |
+
|
| 79 |
+
self.encoder = SentenceTransformer(embedding_model)
|
| 80 |
+
print(f"✓ Loaded embedding model: {embedding_model} (dim: {vector_dim})")
|
| 81 |
+
# Initialize FAISS vector database
|
| 82 |
+
self._init_vector_database()
|
| 83 |
+
|
| 84 |
+
def _init_vector_database(self):
|
| 85 |
+
# Load or create FAISS index
|
| 86 |
+
if os.path.exists(self.vector_index_path):
|
| 87 |
+
self.vector_index = faiss.read_index(self.vector_index_path)
|
| 88 |
+
print(f"✓ Loaded vector index: {self.vector_index.ntotal} vectors")
|
| 89 |
+
else:
|
| 90 |
+
# Use Inner Product for cosine similarity (with normalized vectors)
|
| 91 |
+
self.vector_index = faiss.IndexFlatIP(self.vector_dim)
|
| 92 |
+
print(f"✓ Created new vector index")
|
| 93 |
+
|
| 94 |
+
# Load metadata
|
| 95 |
+
if os.path.exists(self.step_metadata_path):
|
| 96 |
+
with open(self.step_metadata_path, 'rb') as f:
|
| 97 |
+
self.step_metadata = pickle.load(f)
|
| 98 |
+
print(f"✓ Loaded metadata: {len(self.step_metadata)} entries")
|
| 99 |
+
else:
|
| 100 |
+
self.step_metadata = []
|
| 101 |
+
|
| 102 |
+
def _save_vector_database(self):
|
| 103 |
+
"""Save FAISS index and metadata to disk."""
|
| 104 |
+
if self.vector_index is not None:
|
| 105 |
+
faiss.write_index(self.vector_index, self.vector_index_path)
|
| 106 |
+
if self.step_metadata:
|
| 107 |
+
with open(self.step_metadata_path, 'wb') as f:
|
| 108 |
+
pickle.dump(self.step_metadata, f)
|
| 109 |
+
|
| 110 |
+
def _get_episode_count(self) -> int:
|
| 111 |
+
"""Count existing episodes."""
|
| 112 |
+
if not os.path.exists(self.episodes_path):
|
| 113 |
+
return 0
|
| 114 |
+
with open(self.episodes_path, 'r') as f:
|
| 115 |
+
return sum(1 for _ in f)
|
| 116 |
+
|
| 117 |
+
def _encode_state(self, state_text: str, history_summary: str = ""):
|
| 118 |
+
# Combine current state with history context
|
| 119 |
+
if history_summary:
|
| 120 |
+
combined_text = f"History: {history_summary}\n\nCurrent: {state_text}"
|
| 121 |
+
else:
|
| 122 |
+
combined_text = f"Current: {state_text}"
|
| 123 |
+
|
| 124 |
+
try:
|
| 125 |
+
# Encode and normalize for cosine similarity
|
| 126 |
+
embedding = self.encoder.encode([combined_text])[0]
|
| 127 |
+
embedding = embedding / (np.linalg.norm(embedding) + 1e-8)
|
| 128 |
+
return embedding.astype('float32')
|
| 129 |
+
except Exception as e:
|
| 130 |
+
print(f"Encoding error: {e}")
|
| 131 |
+
return None
|
| 132 |
+
|
| 133 |
+
# def _encode_state(self, state_text: str, location: str = ""):
|
| 134 |
+
# # Combine current state with history context
|
| 135 |
+
# if location:
|
| 136 |
+
# combined_text = f"location: {location}\n\nCurrent: {state_text}"
|
| 137 |
+
# else:
|
| 138 |
+
# combined_text = f"Current: {state_text}"
|
| 139 |
+
|
| 140 |
+
# try:
|
| 141 |
+
# # Encode and normalize for cosine similarity
|
| 142 |
+
# embedding = self.encoder.encode([combined_text])[0]
|
| 143 |
+
# embedding = embedding / (np.linalg.norm(embedding) + 1e-8)
|
| 144 |
+
# return embedding.astype('float32')
|
| 145 |
+
# except Exception as e:
|
| 146 |
+
# print(f"Encoding error: {e}")
|
| 147 |
+
# return None
|
| 148 |
+
|
| 149 |
+
    def add_episode(
        self,
        game_history: List[Dict[str, Any]],
        final_score: float,
        state: str,
        success: bool = False
    ):
        self.current_episode_number += 1

        # Save episode metadata
        episode_data = {
            'episode_number': self.current_episode_number,
            'final_score': final_score,
            'success': success,
            'steps': len(game_history),
            'history': game_history
        }

        with open(self.episodes_path, 'a') as f:
            f.write(json.dumps(episode_data) + '\n')

        llm_step_scores, step_reasonings = self.evaluate_step_scores_with_llm(
            game_history=game_history,
            state=state,
            final_score=final_score,
            success=success,
            temperature=0.8
        )
        # Guard against the LLM returning the wrong number of scores
        if len(llm_step_scores) != len(game_history):
            print(f"[Warning] Expected {len(game_history)} step scores, got {len(llm_step_scores)}; padding with zeros")
            llm_step_scores = (llm_step_scores + [0] * len(game_history))[:len(game_history)]

        self._add_steps_to_vector_db(game_history, final_score, llm_step_scores)

        # Save vector database
        self._save_vector_database()

    def _add_steps_to_vector_db(self, game_history: List[Dict], final_score: float, llm_step_scores):
        n = len(llm_step_scores)

        for i, step in enumerate(game_history):
            # Get current state
            state_text = step.get('state', '')
            action = step.get('action', '')

            # Reuse the query vector cached for this step during retrieve_similar
            vector = self.query_vector_cache[i]
            if vector is None:
                continue

            # Calculate discounted future reward from this step
            discounted_reward = 0.0
            for j in range(i, n):
                # Blend the LLM score with the environment reward at step j
                env_reward = game_history[j].get('reward', 0)
                if env_reward > 0:
                    reward = max(llm_step_scores[j], env_reward)
                else:
                    reward = llm_step_scores[j]
                discounted_reward += (self.gamma ** (j - i)) * reward
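            # Worked example (editor's note, hypothetical numbers): with
            # gamma = 0.95 and rewards [0, 5, 10] from step i onward, the
            # discounted return is 0 + 0.95 * 5 + 0.95**2 * 10 = 13.775.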

            # Store metadata for this step
            metadata = {
                'episode_number': self.current_episode_number,
                'step_index': i,
                'state': state_text,
                'action': action,
                'reward': step.get('reward', 0),
                'score': step.get('score', 0),
                'llm_step_score': llm_step_scores[i] if i < len(llm_step_scores) else 0,
                'discounted_reward': discounted_reward,
                'episode_final_score': final_score
            }

            # Add to vector index and metadata list
            self.vector_index.add(np.array([vector]))
            self.step_metadata.append(metadata)

    def retrieve_similar(
        self,
        current_state: str,
        game_history: List[Dict],
        current_summary: str,
        k: int = 5,
        r: float = 0.9
    ):
        # Encode current state with history
        query_vector = self._encode_state(current_state, current_summary)

        if query_vector is None:
            return []

        self.query_vector_cache.append(query_vector)

        if self.vector_index is None:
            print("Dual vector database not available")
            return []
        if self.vector_index.ntotal == 0:
            print("Dual vector database is empty")
            return []

        # Phase 1: Recall top candidates using vector similarity (fast)
        # Use 10x more candidates than final k to ensure good coverage
        recall_size = min(max(k * 10, 100), self.vector_index.ntotal)

        scores, indices = self.vector_index.search(query_vector.reshape(1, -1), recall_size)

        high_similarity_trajectories = []    # similarity > r
        medium_similarity_trajectories = []  # similarity <= r

        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or idx >= len(self.step_metadata):
                continue

            metadata = self.step_metadata[idx]

            if score > r:
                high_similarity_trajectories.append((score, metadata["discounted_reward"], metadata))
            else:
                medium_similarity_trajectories.append((score, metadata["discounted_reward"], metadata))

        # Sort both lists by similarity first, then by discounted reward
        high_similarity_trajectories.sort(key=lambda x: (x[0], x[1]), reverse=True)
        medium_similarity_trajectories.sort(key=lambda x: (x[0], x[1]), reverse=True)

        # Combine results: all high-similarity matches (> r) first, then top
        # candidates from the medium-similarity pool if more are needed
        filtered_trajectories = high_similarity_trajectories[:]
        if len(filtered_trajectories) < k:
            remaining_needed = k - len(filtered_trajectories)
            filtered_trajectories.extend(medium_similarity_trajectories[:remaining_needed])

        return [x[2] for x in filtered_trajectories[:k]]

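    # Usage sketch (editor's note; names are illustrative, not part of the commit):
    #   memories = memory.retrieve_similar(obs_text, history, summary, k=5)
    #   hints = "\n".join(
    #       f"- similar state -> action '{m['action']}' (value {m['discounted_reward']:.1f})"
    #       for m in memories
    #   )
    # The hints string can then be folded into the action-selection prompt.
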
    def evaluate_step_scores_with_llm(self, game_history, state, final_score, success, temperature=0.8):
        """
        Use the LLM to assign a score to each step in the trajectory.

        Args:
            game_history: List of game history records
            state: Final observed state after the last action
            final_score: Final score
            success: Whether the episode succeeded
            temperature: LLM temperature parameter

        Returns:
            tuple: (list of per-step scores, list of per-step reasonings)
        """
        if not game_history:
            return [], []

trajectory_text = ""
|
| 300 |
+
for i, entry in enumerate(game_history):
|
| 301 |
+
trajectory_text += f"Step {i}:\n"
|
| 302 |
+
trajectory_text += f"Current state: {entry.get('old_state', '')}\n"
|
| 303 |
+
if entry.get('action'):
|
| 304 |
+
trajectory_text += f"Action: {entry.get('action')}\n"
|
| 305 |
+
trajectory_text += f"Result state: {entry.get('state', '')}\n"
|
| 306 |
+
score = entry.get('score', 0)
|
| 307 |
+
reward = entry.get('reward', 0)
|
| 308 |
+
trajectory_text += f"Action reward: {reward}\n"
|
| 309 |
+
trajectory_text += "\n"
|
| 310 |
+
trajectory_text += f"Step: {len(game_history)}:\nState: {state}\n"
|
| 311 |
+
sys_prompt = """You are scoring game actions to build training data for future gameplay.
|
| 312 |
+
|
| 313 |
+
PURPOSE: Rate each action based on its overall impact - positive scores for actions worth repeating, negative for actions to avoid.
|
| 314 |
+
|
| 315 |
+
SCORING RULES:
|
| 316 |
+
- Positive: Action led to progress, a reward, useful discoveries or new locations discovered
|
| 317 |
+
- Negative: Action wasted time, caused loops, or had no benefit
|
| 318 |
+
- Magnitude: Match the game's typical reward scale (calibrate based on rewards shown in trajectory)
|
| 319 |
+
- Evaluation: Judge by full consequence chain, not just immediate result
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
user_prompt = f"""Score each action in this game session.
|
| 323 |
+
|
| 324 |
+
Final Result: {"SUCCESS" if success else "FAILURE"}, Final Score: {final_score}
|
| 325 |
+
|
| 326 |
+
Trajectory:
|
| 327 |
+
{trajectory_text}
|
| 328 |
+
|
| 329 |
+
JSON FORMAT:
|
| 330 |
+
{{
|
| 331 |
+
"step_analysis": [
|
| 332 |
+
{{
|
| 333 |
+
"step": 0,
|
| 334 |
+
"action": "exact action taken",
|
| 335 |
+
"detailed_reasoning": "What happened after this action and its consequences? Why should this be repeated or avoided?",
|
| 336 |
+
"score": 5,
|
| 337 |
+
}}
|
| 338 |
+
],
|
| 339 |
+
}}
|
| 340 |
+
|
| 341 |
+
Provide complete JSON response:"""
|
| 342 |
+
|
| 343 |
+
        try:
            messages = [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt},
            ]

            if USE_LOCAL_MODEL and _local_pipeline is not None:
                outputs = _local_pipeline(
                    messages,
                    max_new_tokens=8192,  # match the remote token budget below
                    temperature=temperature,
                    do_sample=True,
                )
                scores_text = outputs[0]["generated_text"][-1]["content"]
            else:
                response = LLM_CLIENT.chat.completions.create(
                    model=self.eval_llm_model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=8192,
                )
                scores_text = response.choices[0].message.content.strip()
        except Exception as e:
            if getattr(e, 'response', None) is not None:
                print(e.response.text)
            print(traceback.format_exc())
            print(f"[Error] LLM call failed: {e}")
            print("[Warning] Falling back to reward-based scoring")
            return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]

        if scores_text:
            if scores_text.startswith('```json'):
                scores_text = scores_text.replace('```json', '').replace('```', '').strip()
            elif scores_text.startswith('```'):
                scores_text = scores_text.replace('```', '').strip()

            try:
                response_data = json.loads(scores_text)
            except json.JSONDecodeError as e:
                print(f"[Warning] JSON decode error: {e}")
                print(f"[Warning] Problematic JSON text (first 500 chars): {scores_text[:500]}")
                print("[Warning] Falling back to reward-based scoring")
                return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]

            if isinstance(response_data, dict):
                if 'step_analysis' in response_data:
                    step_analyses = response_data['step_analysis']

                    analysis_dict = {}
                    for analysis in step_analyses:
                        step_num = analysis.get('step', -1)
                        if step_num >= 0:
                            analysis_dict[step_num] = analysis

                    scores = []
                    reasonings = []
                    missing_steps = []

                    print("\n=== LLM Step Analysis ===")
                    for i in range(len(game_history)):
                        if i in analysis_dict:
                            analysis = analysis_dict[i]
                            step_score = analysis.get('score', 0)
                            reasoning = analysis.get('detailed_reasoning', 'No reasoning')
                            scores.append(step_score)
                            reasonings.append(reasoning)
                            print(f"Step {i}: {analysis.get('action', 'Unknown')}")
                            print(f"  Reasoning: {reasoning}")
                            print(f"  Score: {step_score}")
                            print()
                        else:
                            scores.append(0)
                            reasonings.append('Missing from LLM analysis - filled with score 0')
                            missing_steps.append(i)
                            print(f"Step {i}: [MISSING FROM LLM ANALYSIS - FILLED WITH 0]")
                            print(f"  Action: {game_history[i].get('action', 'Unknown')}")
                            print("  Score: 0 (auto-filled)")
                            print()

                    if missing_steps:
                        print(f"[Warning] LLM missed analyzing steps: {missing_steps}")
                        print("[Info] Filled missing steps with score 0")

                    if 'overall_assessment' in response_data:
                        print(f"Overall Assessment: {response_data['overall_assessment']}")
                    print("=" * 50)
                else:
                    print(f"[Warning] No step_analysis found in JSON format: {response_data}")
                    return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]
            elif isinstance(response_data, list):
                scores = response_data
                reasonings = ["Legacy format - no reasoning available" for _ in game_history]
            else:
                print(f"[Warning] Unexpected JSON format: {response_data}")
                return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]

            return scores, reasonings
        else:
            print("[Warning] No valid response from LLM for step scoring")
            return [entry.get('reward', 0) for entry in game_history], ["No reasoning available" for _ in game_history]
mcp_server.py
CHANGED
@@ -26,6 +26,8 @@ Then open the MCP Inspector in your browser to test the tools interactively.

import sys
import os
import json
from dataclasses import asdict

# Add parent directory to path to import games module
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

@@ -60,9 +62,9 @@ class GameManager:

        self.state = None
        self.game_name: str = ""
        # TODO: Add more state tracking
        self.history: list[tuple[str, str]] = []
        self.explored_locations: dict[str, set[str]] = {}
        self.current_location: str = ""

    def initialize(self, game: str = "zork1"):
        """Initialize or reset the game."""

@@ -71,7 +73,21 @@

        self.state = self.env.reset()
        # TODO: Reset your state tracking here
        return self.state.observation

    def get_full_state(self, include_valid_actions: bool = True) -> dict:
        """
        Convert current env GameState (dataclass from TextAdventureEnv) to a JSON-ready dict.
        """
        data = asdict(self.state)

        if include_valid_actions:
            try:
                data["valid_actions"] = self.env.get_valid_actions()
            except Exception:
                data["valid_actions"] = []

        return data

    def step(self, action: str) -> str:
        """Execute an action and return the result."""
        if self.env is None:

@@ -135,12 +151,15 @@ def play_action(action: str) -> str:

    # TODO: You might want to add action validation here
    # TODO: You might want to include score changes in the response

    game.step(action)

    # Optional: Append score info
    # result += f"\n[Score: {game.get_score()} | Moves: {game.get_moves()}]"

    payload = {
        "state": game.get_full_state(include_valid_actions=True),
    }
    return json.dumps(payload)


# TODO: Implement additional tools to help your agent

@@ -158,17 +177,17 @@

    # pass


@mcp.tool()
def inventory() -> str:
    """
    Check what the player is carrying.

    Returns:
        List of items in the player's inventory
    """
    game = get_game()
    result = game.step("inventory")
    return result


# @mcp.tool()

@@ -184,21 +203,52 @@ def play_action(action: str) -> str:

    # pass


@mcp.tool()
def get_valid_actions() -> str:
    """
    Get a list of likely valid actions from the current location.

    Returns:
        List of actions that might work here
    """
    # This is a hint: Jericho provides get_valid_actions()
    game = get_game()
    if game.env and game.env.env:
        valid = game.env.env.get_valid_actions()
        return "Valid actions: " + ", ".join(valid[:20])
    return "Could not determine valid actions"

@mcp.tool()
def reset_game() -> str:
    """
    Reset the game to the beginning.

    Use this to start over if you get stuck or die. The game to load is
    taken from the GAME environment variable (default 'zork1'); use
    list_games() to see available options.

    Returns:
        The initial game text
    """
    global _game
    try:
        if _game.env is None:
            game = get_game()
        else:
            game_name = os.environ.get("GAME", "zork1")
            _game.initialize(game_name)
            game = _game

        payload = {
            "state": game.get_full_state(include_valid_actions=True),
        }
        return json.dumps(payload)

    except ValueError as e:
        return json.dumps({"error": str(e)})

# =============================================================================
# Run the server
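Editor's note: play_action and reset_game now return a JSON payload instead of raw observation text, so the agent loop has to parse it. A minimal sketch of that parsing, assuming the serialized GameState dataclass exposes an observation field; the helper name and exact field set are assumptions, not part of the commit:

import json

def parse_play_action_result(raw: str) -> tuple[str, list[str]]:
    """Extract observation text and valid actions from the tool's JSON payload."""
    data = json.loads(raw)
    if "error" in data:
        return data["error"], []
    state = data.get("state", {})
    # 'observation' and 'valid_actions' follow the payload built in play_action
    return state.get("observation", ""), state.get("valid_actions", [])

# Quick check against a hand-built payload of the same shape
raw = json.dumps({"state": {"observation": "Taken.", "valid_actions": ["north", "open door"]}})
obs, actions = parse_play_action_result(raw)
assert obs == "Taken." and "north" in actions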
requirements.txt
CHANGED
@@ -7,3 +7,9 @@

# Add any additional packages your agent needs below:
# numpy
# requests

python-dotenv
spacy

faiss-cpu
sentence-transformers