Spaces:
Configuration error
Configuration error
Upload stratego\benchmarking\run_game.py with huggingface_hub
Browse files
stratego//benchmarking//run_game.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# stratego/benchmarking/run_game.py
|
| 2 |
+
|
| 3 |
+
import textarena as ta
|
| 4 |
+
from stratego.env.stratego_env import StrategoEnv
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_last_board_observation(state, player_id):
|
| 8 |
+
for obs in reversed(state.observations[player_id]):
|
| 9 |
+
if ta.ObservationType.GAME_BOARD in obs:
|
| 10 |
+
for elem in obs:
|
| 11 |
+
if isinstance(elem, str):
|
| 12 |
+
return elem
|
| 13 |
+
return ""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def run_game(agent0, agent1, size=6, seed=None):
|
| 17 |
+
env = StrategoEnv(env_id = "Stratego-custom",size=size)
|
| 18 |
+
env.reset(num_players=2, seed=seed)
|
| 19 |
+
|
| 20 |
+
invalid_moves = {0: 0, 1: 0}
|
| 21 |
+
repetitions = 0
|
| 22 |
+
turns = 0
|
| 23 |
+
|
| 24 |
+
done = False
|
| 25 |
+
winner = None
|
| 26 |
+
reason_verbose = "Unknown termination reason"
|
| 27 |
+
flag_captured = False
|
| 28 |
+
|
| 29 |
+
while not done:
|
| 30 |
+
state = env.get_state()
|
| 31 |
+
rep = env.repetition_count()
|
| 32 |
+
pid = state.current_player_id
|
| 33 |
+
agent = agent0 if pid == 0 else agent1
|
| 34 |
+
|
| 35 |
+
obs = get_last_board_observation(state, pid)
|
| 36 |
+
action = agent(obs) if callable(agent) else agent.act(obs)
|
| 37 |
+
|
| 38 |
+
done, _ = env.step(action)
|
| 39 |
+
turns += 1
|
| 40 |
+
|
| 41 |
+
if state.game_info.get(pid, {}).get("invalid_move"):
|
| 42 |
+
invalid_moves[pid] += 1
|
| 43 |
+
|
| 44 |
+
repetitions += rep.get(pid, 0)
|
| 45 |
+
|
| 46 |
+
if done:
|
| 47 |
+
gs = state.game_state
|
| 48 |
+
gi = state.game_info
|
| 49 |
+
|
| 50 |
+
if gs.get("termination") == "invalid":
|
| 51 |
+
reason_verbose = f"Invalid move: {gs.get('invalid_reason', 'Invalid move')}"
|
| 52 |
+
|
| 53 |
+
else:
|
| 54 |
+
raw = gi.get("reason", "")
|
| 55 |
+
# Normalize reason to string for downstream metrics/logs
|
| 56 |
+
if isinstance(raw, (list, tuple)):
|
| 57 |
+
raw_reason = "; ".join(map(str, raw))
|
| 58 |
+
else:
|
| 59 |
+
raw_reason = str(raw)
|
| 60 |
+
|
| 61 |
+
raw_lower = raw_reason.lower()
|
| 62 |
+
|
| 63 |
+
if "flag" in raw_lower:
|
| 64 |
+
flag_captured = True
|
| 65 |
+
reason_verbose = raw_reason
|
| 66 |
+
elif "no legal moves" in raw_lower or "no more movable pieces" in raw_lower or "no moves" in raw_lower:
|
| 67 |
+
reason_verbose = "Opponent had no legal moves"
|
| 68 |
+
elif "stalemate" in raw_lower:
|
| 69 |
+
reason_verbose = "Stalemate"
|
| 70 |
+
elif "turn limit" in raw_lower:
|
| 71 |
+
reason_verbose = "Turn limit reached"
|
| 72 |
+
elif "repetition" in raw_lower:
|
| 73 |
+
reason_verbose = "Two-squares repetition rule violation"
|
| 74 |
+
else:
|
| 75 |
+
reason_verbose = raw_reason or "Game ended without explicit winner"
|
| 76 |
+
|
| 77 |
+
# TextArena does not store a winner in game_info; derive from rewards
|
| 78 |
+
rewards = getattr(state, "rewards", None)
|
| 79 |
+
if rewards:
|
| 80 |
+
max_reward = max(rewards.values())
|
| 81 |
+
winners = [player for player, reward in rewards.items() if reward == max_reward]
|
| 82 |
+
if len(winners) == 1:
|
| 83 |
+
winner = winners[0]
|
| 84 |
+
else:
|
| 85 |
+
winner = -1
|
| 86 |
+
|
| 87 |
+
return {
|
| 88 |
+
"winner": winner if winner is not None else -1,
|
| 89 |
+
"turns": turns,
|
| 90 |
+
"invalid_moves_p0": invalid_moves[0],
|
| 91 |
+
"invalid_moves_p1": invalid_moves[1],
|
| 92 |
+
"repetitions": repetitions,
|
| 93 |
+
"flag_captured": flag_captured,
|
| 94 |
+
"game_end_reason": reason_verbose
|
| 95 |
+
}
|