KantBench / env /nplayer /environment.py
jtowarek's picture
Upload folder using huggingface_hub
f7e2ae6 verified
"""N-player game environment."""
from __future__ import annotations
import uuid
from typing import Any, Callable, Optional
from common.games_meta.nplayer_config import NPlayerGameConfig, get_nplayer_game
from env.nplayer.models import (
NPlayerAction,
NPlayerGameState,
NPlayerObservation,
NPlayerRoundResult,
)
from env.nplayer.strategies import get_nplayer_strategy, NPlayerStrategy
_ONE = int(bool(True))
_ZERO = int()
_ZERO_F = float()
class NPlayerEnvironment:
"""Game-theory environment for N-player games.
Player zero is the primary agent controlled via ``step()``.
Players one through N-minus-one are auto-played by strategies or
caller-provided functions (``opponent_fns``).
"""
def __init__(self) -> None:
self._game: Optional[NPlayerGameConfig] = None
self._strategies: list[Optional[NPlayerStrategy]] = []
self._opponent_fns: list[Optional[Callable[[NPlayerObservation], NPlayerAction]]] = []
self._state: NPlayerGameState = NPlayerGameState()
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def reset(
self,
game: str,
*,
num_rounds: Optional[int] = None,
opponent_strategies: Optional[list[str]] = None,
opponent_fns: Optional[list[Optional[Callable[[NPlayerObservation], NPlayerAction]]]] = None,
episode_id: Optional[str] = None,
) -> NPlayerObservation:
"""Start a new episode.
Parameters
----------
game:
Key in ``NPLAYER_GAMES``.
num_rounds:
Override the default round count.
opponent_strategies:
Strategy names for players one through N-minus-one. If shorter
than needed, the last entry is repeated. Defaults to all
``"random"``.
opponent_fns:
Callable opponents for players one through N-minus-one. ``None``
entries fall back to the corresponding strategy.
episode_id:
Optional identifier for the episode.
"""
self._game = get_nplayer_game(game)
n = self._game.num_players
num_opponents = n - _ONE
# Resolve strategies
if opponent_strategies is None:
strat_names = ["random"] * num_opponents
else:
strat_names = list(opponent_strategies)
while len(strat_names) < num_opponents:
strat_names.append(strat_names[-_ONE])
self._strategies = [get_nplayer_strategy(s) for s in strat_names]
# Resolve opponent fns
if opponent_fns is None:
self._opponent_fns = [None] * num_opponents
else:
fns: list[Optional[Callable]] = list(opponent_fns)
while len(fns) < num_opponents:
fns.append(None)
self._opponent_fns = fns
rounds = num_rounds if num_rounds is not None else self._game.default_rounds
self._state = NPlayerGameState(
episode_id=episode_id or str(uuid.uuid4()),
game_name=game,
total_rounds=rounds,
num_players=n,
scores=[_ZERO_F] * n,
)
return self._build_observation(_ZERO)
def step(self, action: NPlayerAction) -> NPlayerObservation:
"""Execute one round.
The caller supplies the action for player zero. Opponents are
auto-played.
"""
if self._game is None:
raise RuntimeError("Call reset() before step().")
if self._state.is_done:
raise RuntimeError("Episode already finished. Call reset().")
if action.action not in self._game.actions:
raise ValueError(
f"Invalid action '{action.action}'. "
f"Choose from: {self._game.actions}"
)
# Collect all actions: player zero first, then opponents
all_actions: list[str] = [action.action]
for idx in range(len(self._strategies)):
player_idx = idx + _ONE
opp_action = self._get_opponent_action(idx, player_idx)
all_actions.append(opp_action)
actions_tuple = tuple(all_actions)
payoffs_tuple = self._game.payoff_fn(actions_tuple)
new_round = len(self._state.history) + _ONE
result = NPlayerRoundResult(
round_number=new_round,
actions=list(all_actions),
payoffs=list(payoffs_tuple),
)
history = list(self._state.history) + [result]
new_scores = [
s + p for s, p in zip(self._state.scores, payoffs_tuple)
]
done = new_round >= self._state.total_rounds
self._state = NPlayerGameState(
episode_id=self._state.episode_id,
step_count=self._state.step_count + _ONE,
game_name=self._state.game_name,
current_round=new_round,
total_rounds=self._state.total_rounds,
num_players=self._state.num_players,
scores=new_scores,
history=history,
is_done=done,
)
return self._build_observation(
_ZERO,
reward=payoffs_tuple[_ZERO],
last_round=result,
done=done,
)
@property
def state(self) -> NPlayerGameState:
return self._state
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _get_opponent_action(self, opp_idx: int, player_idx: int) -> str:
"""Get the action for opponent at opp_idx (player player_idx)."""
assert self._game is not None
fn = self._opponent_fns[opp_idx]
if fn is not None:
obs = self._build_observation(player_idx)
opp_action = fn(obs)
if opp_action.action not in self._game.actions:
raise ValueError(
f"Opponent {player_idx} returned invalid action "
f"'{opp_action.action}'. Choose from: {self._game.actions}"
)
return opp_action.action
strategy = self._strategies[opp_idx]
assert strategy is not None
obs = self._build_observation(player_idx)
return strategy.choose_action(obs)
def _build_observation(
self,
player_index: int,
reward: float = _ZERO_F,
last_round: Optional[NPlayerRoundResult] = None,
done: bool = False,
) -> NPlayerObservation:
assert self._game is not None
return NPlayerObservation(
done=done,
reward=reward,
game_name=self._state.game_name,
game_description=self._game.description,
available_actions=list(self._game.actions),
current_round=self._state.current_round,
total_rounds=self._state.total_rounds,
history=list(self._state.history),
scores=list(self._state.scores),
num_players=self._state.num_players,
player_index=player_index,
last_round=last_round,
)