Spaces:
Configuration error
Configuration error
File size: 3,390 Bytes
5899740 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | # stratego/benchmarking/run_game.py
import textarena as ta
from stratego.env.stratego_env import StrategoEnv
def get_last_board_observation(state, player_id):
    """Return the text of the newest game-board observation for a player.

    The player's observation history (``state.observations[player_id]``) is
    scanned from the most recent entry backwards. The first entry that
    contains ``ta.ObservationType.GAME_BOARD`` and also holds at least one
    ``str`` element wins, and that first string element is returned.

    Args:
        state: Object exposing ``observations``, indexable by player id.
        player_id: Id of the player whose history is searched.

    Returns:
        The board text, or ``""`` when no matching observation exists.
    """
    # NOTE(review): entries are presumably (sender, message, type) tuples;
    # only membership and iteration are relied on here.
    history = state.observations[player_id]
    for entry in reversed(history):
        if ta.ObservationType.GAME_BOARD not in entry:
            continue
        board_text = next(
            (part for part in entry if isinstance(part, str)),
            None,
        )
        if board_text is not None:
            return board_text
        # A board entry without any string payload: keep looking further back.
    return ""
def _normalize_reason(raw):
    """Return *raw* as one string; list/tuple reasons are joined with '; '."""
    if isinstance(raw, (list, tuple)):
        return "; ".join(map(str, raw))
    return str(raw)


def _find_end_reason(game_info, player_id):
    """Return the raw termination reason stored in *game_info*.

    The top-level ``"reason"`` key is honoured first. When it is missing or
    empty, the per-player entries are searched, starting with *player_id*.
    This is the same per-player layout used for the ``"invalid_move"`` flag.
    """
    raw = game_info.get("reason", "")
    if raw:
        return raw
    candidates = [game_info.get(player_id)]
    candidates.extend(
        entry for key, entry in game_info.items() if key != player_id
    )
    for entry in candidates:
        if isinstance(entry, dict) and entry.get("reason"):
            return entry["reason"]
    # Nothing better found: hand back whatever the top-level lookup gave.
    return raw


def _describe_end_reason(raw_reason):
    """Map a raw reason string to ``(game_end_reason, flag_captured)``."""
    lowered = raw_reason.lower()
    if "flag" in lowered:
        return raw_reason, True
    # Checked in order; the first label whose keyword matches is used.
    labelled_keywords = (
        (
            "Opponent had no legal moves",
            ("no legal moves", "no more movable pieces", "no moves"),
        ),
        ("Stalemate", ("stalemate",)),
        ("Turn limit reached", ("turn limit",)),
        ("Two-squares repetition rule violation", ("repetition",)),
    )
    for label, keywords in labelled_keywords:
        if any(keyword in lowered for keyword in keywords):
            return label, False
    return raw_reason or "Game ended without explicit winner", False


def _winner_from_rewards(rewards):
    """Return the single top-reward player id, or -1 for a tie / no rewards."""
    if not rewards:
        return -1
    best = max(rewards.values())
    leaders = [player for player, reward in rewards.items() if reward == best]
    return leaders[0] if len(leaders) == 1 else -1


def run_game(agent0, agent1, size=6, seed=None):
    """Play one two-player Stratego game and return summary metrics.

    Args:
        agent0: Agent for player 0; either a callable taking the board text
            or an object with an ``act(observation)`` method.
        agent1: Agent for player 1, same contract as *agent0*.
        size: Board size forwarded to ``StrategoEnv``.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        A dict with the keys ``winner`` (player id, or -1 when no single
        winner can be derived), ``turns``, ``invalid_moves_p0``,
        ``invalid_moves_p1``, ``repetitions``, ``flag_captured`` and
        ``game_end_reason``.
    """
    env = StrategoEnv(env_id="Stratego-custom", size=size)
    env.reset(num_players=2, seed=seed)

    invalid_moves = {0: 0, 1: 0}
    repetitions = 0
    turns = 0
    done = False

    while not done:
        # NOTE(review): `state` is fetched before the step and read again
        # afterwards, so this relies on the env updating that object in
        # place -- confirm against StrategoEnv.get_state().
        state = env.get_state()
        rep = env.repetition_count()
        pid = state.current_player_id
        agent = agent0 if pid == 0 else agent1
        obs = get_last_board_observation(state, pid)
        action = agent(obs) if callable(agent) else agent.act(obs)
        done, _ = env.step(action)
        turns += 1
        if state.game_info.get(pid, {}).get("invalid_move"):
            invalid_moves[pid] += 1
        repetitions += rep.get(pid, 0)

    # The loop only exits once `done` is true, so `state` and `pid` describe
    # the final move of the game.
    flag_captured = False
    game_state = state.game_state
    if game_state.get("termination") == "invalid":
        detail = game_state.get("invalid_reason", "Invalid move")
        reason_verbose = f"Invalid move: {detail}"
    else:
        # Normalize the reason to a string for downstream metrics/logs.
        raw_reason = _normalize_reason(_find_end_reason(state.game_info, pid))
        reason_verbose, flag_captured = _describe_end_reason(raw_reason)

    # TextArena does not store a winner in game_info; derive it from rewards.
    winner = _winner_from_rewards(getattr(state, "rewards", None))

    return {
        "winner": winner,
        "turns": turns,
        "invalid_moves_p0": invalid_moves[0],
        "invalid_moves_p1": invalid_moves[1],
        "repetitions": repetitions,
        "flag_captured": flag_captured,
        "game_end_reason": reason_verbose,
    }
|