eyaa99 committed on
Commit
50e7ecd
·
verified ·
1 Parent(s): 68314ab

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +78 -10
agent.py CHANGED
@@ -27,6 +27,13 @@ from typing import Optional
27
  from dotenv import load_dotenv
28
  from huggingface_hub import InferenceClient
29
 
 
 
 
 
 
 
 
30
  # Load environment variables
31
  load_dotenv()
32
 
@@ -88,14 +95,16 @@ def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300)
88
  ]
89
 
90
  if USE_LOCAL_MODEL and _local_pipeline is not None:
 
91
  outputs = _local_pipeline(
92
  messages,
93
- max_new_tokens=max_tokens,
94
  temperature=0.0001, # Near-deterministic (0.0 unsupported by some backends)
95
  do_sample=True,
96
  )
97
  return outputs[0]["generated_text"][-1]["content"]
98
 
 
99
  response = LLM_CLIENT.chat.completions.create(
100
  model=LLM_MODEL,
101
  messages=messages,
@@ -248,25 +257,35 @@ class StudentAgent:
248
  response = self._call_llm(prompt, SYSTEM_PROMPT, seed + step)
249
  thought, tool_name, args = self._parse_response(response)
250
 
 
251
  tool_name = "play_action"
252
  action = str(args.get("action", "look")).strip() if isinstance(args, dict) else "look"
253
  if not action:
254
  action = "look"
255
 
 
256
  if action in self.failed_actions:
257
- action = self._fallback_action()
 
 
 
 
 
 
 
258
 
259
  # avoid repeating exact action too much
260
  if len(self.recent_actions) >= 2 and self.recent_actions[-1] == action and self.recent_actions[-2] == action:
261
- action = self._fallback_action()
262
 
263
  new_observation = str(await client.call_tool("play_action", {"action": action}))
264
  self._update_score(new_observation)
265
 
266
- # mark failure if no change
267
  new_norm = self._norm_obs(new_observation)
268
  if new_norm == self.last_obs_norm:
269
- self.failed_actions.add(action)
 
270
  self.last_obs_norm = new_norm
271
 
272
  self.recent_actions.append(action)
@@ -284,7 +303,7 @@ class StudentAgent:
284
  if "GAME OVER" in observation:
285
  return RunResult(
286
  final_score=final_score,
287
- max_score=350,
288
  moves=moves,
289
  locations_visited=locations_visited,
290
  game_completed=True,
@@ -293,7 +312,7 @@ class StudentAgent:
293
 
294
  return RunResult(
295
  final_score=final_score,
296
- max_score=350,
297
  moves=moves,
298
  locations_visited=locations_visited,
299
  game_completed=False,
@@ -355,11 +374,28 @@ class StudentAgent:
355
  """
356
  return call_llm(prompt, system_prompt, seed)
357
 
358
- def _fallback_action(self) -> str:
359
- # Simple exploration fallback
360
  for a in ["north", "south", "east", "west", "up", "down", "in", "out"]:
361
  if a not in self.failed_actions:
362
  return a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  return "look"
364
 
365
  def _update_score(self, text: str):
@@ -371,4 +407,36 @@ class StudentAgent:
371
  s = re.sub(r"\[Score:.*?\]", "", text, flags=re.I)
372
  s = re.sub(r"Score:\s*\d+|Moves:\s*\d+", "", s, flags=re.I)
373
  s = re.sub(r"\s+", " ", s).strip()
374
- return s[:700]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  from dotenv import load_dotenv
28
  from huggingface_hub import InferenceClient
29
 
30
# Silence transformers warnings in local mode (prevents repeated max_length/max_new_tokens spam)
# Best-effort: transformers may not be installed at all (hosted-inference mode),
# and the logging helper's location can vary between library versions, so we
# catch exactly those two failure modes instead of swallowing every exception.
try:
    import transformers

    transformers.utils.logging.set_verbosity_error()
except (ImportError, AttributeError):
    # transformers absent, or the logging API moved — hosted inference still works.
    pass
36
+
37
  # Load environment variables
38
  load_dotenv()
39
 
 
95
  ]
96
 
97
  if USE_LOCAL_MODEL and _local_pipeline is not None:
98
+ # Keep local generation shorter + quieter
99
  outputs = _local_pipeline(
100
  messages,
101
+ max_new_tokens=min(max_tokens, 128),
102
  temperature=0.0001, # Near-deterministic (0.0 unsupported by some backends)
103
  do_sample=True,
104
  )
105
  return outputs[0]["generated_text"][-1]["content"]
106
 
107
+ # Hosted inference (may fail with 402 if credits depleted)
108
  response = LLM_CLIENT.chat.completions.create(
109
  model=LLM_MODEL,
110
  messages=messages,
 
257
  response = self._call_llm(prompt, SYSTEM_PROMPT, seed + step)
258
  thought, tool_name, args = self._parse_response(response)
259
 
260
+ # Keep it simple: always call play_action
261
  tool_name = "play_action"
262
  action = str(args.get("action", "look")).strip() if isinstance(args, dict) else "look"
263
  if not action:
264
  action = "look"
265
 
266
+ # Simple avoidance: don't repeat known-failed actions
267
  if action in self.failed_actions:
268
+ action = self._fallback_action_from_observation(observation)
269
+
270
+ # Hard anti-stuck rule: if we keep doing "look", force exploration
271
+ if len(self.recent_actions) >= 2 and self.recent_actions[-1] == "look" and self.recent_actions[-2] == "look":
272
+ if "inventory" not in self.failed_actions:
273
+ action = "inventory"
274
+ else:
275
+ action = self._fallback_action_from_observation(observation)
276
 
277
  # avoid repeating exact action too much
278
  if len(self.recent_actions) >= 2 and self.recent_actions[-1] == action and self.recent_actions[-2] == action:
279
+ action = self._fallback_action_from_observation(observation)
280
 
281
  new_observation = str(await client.call_tool("play_action", {"action": action}))
282
  self._update_score(new_observation)
283
 
284
+ # mark failure if no change (but do not mark "look" as failed)
285
  new_norm = self._norm_obs(new_observation)
286
  if new_norm == self.last_obs_norm:
287
+ if action != "look":
288
+ self.failed_actions.add(action)
289
  self.last_obs_norm = new_norm
290
 
291
  self.recent_actions.append(action)
 
303
  if "GAME OVER" in observation:
304
  return RunResult(
305
  final_score=final_score,
306
+ max_score=350, # Zork1 max score, adjust if needed
307
  moves=moves,
308
  locations_visited=locations_visited,
309
  game_completed=True,
 
312
 
313
  return RunResult(
314
  final_score=final_score,
315
+ max_score=350, # Zork1 max score, adjust if needed
316
  moves=moves,
317
  locations_visited=locations_visited,
318
  game_completed=False,
 
374
  """
375
  return call_llm(prompt, system_prompt, seed)
376
 
377
def _fallback_action_from_observation(self, observation: str) -> str:
    """Pick an exploratory action that has not already failed.

    Preference order: compass/vertical movement first, then simple
    ``verb noun`` interactions built from words seen in *observation*.
    Returns ``"look"`` when every candidate has already failed.

    Args:
        observation: Latest game text; mined for candidate nouns.

    Returns:
        A game command string that is not in ``self.failed_actions``
        (or ``"look"`` as the last resort).
    """
    # Try movement first — cheap and the most likely way to make progress.
    for direction in ("north", "south", "east", "west", "up", "down", "in", "out"):
        if direction not in self.failed_actions:
            return direction

    # Try simple object interactions based on words in the observation.
    words = re.findall(r"[A-Za-z]{3,}", observation.lower())
    stop = {
        "the", "and", "you", "are", "with", "that", "this", "from", "your",
        "have", "here", "there", "into", "over", "under", "would", "could",
        "should", "what", "when", "then", "than", "them", "been", "were",
        "will", "just", "about", "some", "where", "which",
    }
    # De-duplicate while preserving first-seen order so the 25-word budget
    # is not wasted on repeated words (the original kept duplicates, which
    # could re-generate identical candidate commands).
    candidates = list(dict.fromkeys(w for w in words if w not in stop))[:25]

    for noun in candidates:
        for verb in ("examine", "take", "open"):
            cmd = f"{verb} {noun}"
            if cmd not in self.failed_actions:
                return cmd

    return "look"
400
 
401
  def _update_score(self, text: str):
 
407
  s = re.sub(r"\[Score:.*?\]", "", text, flags=re.I)
408
  s = re.sub(r"Score:\s*\d+|Moves:\s*\d+", "", s, flags=re.I)
409
  s = re.sub(r"\s+", " ", s).strip()
410
+ return s[:700]
411
+
412
+
413
+ # =============================================================================
414
+ # For local testing
415
+ # =============================================================================
416
+
417
async def test_agent():
    """Run the agent once against a local MCP server and print a summary."""
    from fastmcp import Client

    mcp_script = "mcp_server.py"  # Path to your MCP server
    player = StudentAgent()

    async with Client(mcp_script) as session:
        outcome = await player.run(
            client=session,
            game="zork1",
            max_steps=10,
            seed=42,
            verbose=True,
        )

    for line in (
        f"\nFinal Score: {outcome.final_score}",
        f"Moves: {outcome.moves}",
        f"Locations: {outcome.locations_visited}",
    ):
        print(line)
439
+
440
if __name__ == "__main__":
    # Ad-hoc entry point: run the local smoke test under a fresh event loop.
    import asyncio

    asyncio.run(test_agent())