DarshanScripts committed on
Commit
b2b76f2
·
verified ·
1 Parent(s): 06bc894

Upload stratego\benchmarking\run_game.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stratego//benchmarking//run_game.py +95 -0
stratego//benchmarking//run_game.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stratego/benchmarking/run_game.py
2
+
3
+ import textarena as ta
4
+ from stratego.env.stratego_env import StrategoEnv
5
+
6
+
7
def get_last_board_observation(state, player_id):
    """Return the text of the most recent game-board observation of a player.

    The entries of ``state.observations[player_id]`` are scanned from newest
    to oldest. The first string element found inside an entry that contains
    ``ta.ObservationType.GAME_BOARD`` is returned.

    Args:
        state: Object exposing an ``observations`` mapping indexed by player id.
        player_id: Key of the player whose observation history is searched.

    Returns:
        The matching string, or ``""`` when no entry qualifies.
    """
    # Lazy generator: evaluation stops at the first qualifying string, so the
    # scan order (newest entry first, elements in entry order) is preserved.
    board_texts = (
        part
        for entry in reversed(state.observations[player_id])
        if ta.ObservationType.GAME_BOARD in entry
        for part in entry
        if isinstance(part, str)
    )
    return next(board_texts, "")
14
+
15
+
16
def run_game(agent0, agent1, size=6, seed=None):
    """Play a single Stratego game between two agents and summarise it.

    Args:
        agent0: Agent used when the current player id is 0. Either a callable
            taking the board observation text, or an object with an
            ``act(observation)`` method.
        agent1: Agent used for any other player id; same contract as
            ``agent0``.
        size: Value forwarded to ``StrategoEnv`` as ``size``.
        seed: Value forwarded to ``env.reset`` as ``seed``.

    Returns:
        A dict with the keys ``winner`` (player id, or ``-1`` when rewards
        are missing/empty or several players share the top reward),
        ``turns``, ``invalid_moves_p0``, ``invalid_moves_p1``,
        ``repetitions``, ``flag_captured`` and ``game_end_reason``.
    """
    env = StrategoEnv(env_id="Stratego-custom", size=size)
    env.reset(num_players=2, seed=seed)

    # Keyword groups checked in order against the lower-cased end reason;
    # the first group with a hit decides the reported label.
    outcome_labels = (
        (
            ("no legal moves", "no more movable pieces", "no moves"),
            "Opponent had no legal moves",
        ),
        (("stalemate",), "Stalemate"),
        (("turn limit",), "Turn limit reached"),
        (("repetition",), "Two-squares repetition rule violation"),
    )

    invalid_moves = {0: 0, 1: 0}
    repetitions = 0
    turns = 0

    winner = None
    reason_verbose = "Unknown termination reason"
    flag_captured = False

    finished = False
    while not finished:
        state = env.get_state()
        # Repetition counters are sampled before the move is applied.
        repetition_counts = env.repetition_count()
        pid = state.current_player_id
        mover = agent1 if pid != 0 else agent0

        board_text = get_last_board_observation(state, pid)
        if callable(mover):
            move = mover(board_text)
        else:
            move = mover.act(board_text)

        finished, _ = env.step(move)
        turns += 1

        if state.game_info.get(pid, {}).get("invalid_move"):
            invalid_moves[pid] += 1

        repetitions += repetition_counts.get(pid, 0)

        if not finished:
            continue

        # --- Game over: work out why it ended and who won. ---
        gs = state.game_state
        gi = state.game_info

        if gs.get("termination") == "invalid":
            reason_verbose = f"Invalid move: {gs.get('invalid_reason', 'Invalid move')}"
        else:
            raw = gi.get("reason", "")
            # Normalise the reason to a string for downstream metrics/logs.
            if isinstance(raw, (list, tuple)):
                reason_text = "; ".join(str(item) for item in raw)
            else:
                reason_text = str(raw)

            lowered = reason_text.lower()

            if "flag" in lowered:
                flag_captured = True
                reason_verbose = reason_text
            else:
                reason_verbose = next(
                    (
                        label
                        for needles, label in outcome_labels
                        if any(needle in lowered for needle in needles)
                    ),
                    reason_text or "Game ended without explicit winner",
                )

        # TextArena does not store a winner in game_info; derive it from the
        # rewards: a unique top reward wins, a shared top reward yields -1.
        rewards = getattr(state, "rewards", None)
        if rewards:
            top_reward = max(rewards.values())
            leaders = [
                player for player, reward in rewards.items() if reward == top_reward
            ]
            winner = leaders[0] if len(leaders) == 1 else -1

    return {
        "winner": -1 if winner is None else winner,
        "turns": turns,
        "invalid_moves_p0": invalid_moves[0],
        "invalid_moves_p1": invalid_moves[1],
        "repetitions": repetitions,
        "flag_captured": flag_captured,
        "game_end_reason": reason_verbose,
    }