Spaces:
Running
Running
| """N-player game environment.""" | |
| from __future__ import annotations | |
| import uuid | |
| from typing import Any, Callable, Optional | |
| from common.games_meta.nplayer_config import NPlayerGameConfig, get_nplayer_game | |
| from env.nplayer.models import ( | |
| NPlayerAction, | |
| NPlayerGameState, | |
| NPlayerObservation, | |
| NPlayerRoundResult, | |
| ) | |
| from env.nplayer.strategies import get_nplayer_strategy, NPlayerStrategy | |
| _ONE = int(bool(True)) | |
| _ZERO = int() | |
| _ZERO_F = float() | |
class NPlayerEnvironment:
    """Game-theory environment for N-player games.

    Player zero is the primary agent controlled via ``step()``.
    Players one through N-minus-one are auto-played by strategies or
    caller-provided functions (``opponent_fns``).
    """

    def __init__(self) -> None:
        # No game is loaded until reset() runs; step() refuses to run before.
        self._game: Optional[NPlayerGameConfig] = None
        # Both lists hold one entry per opponent: index i is player i + 1.
        self._strategies: list[Optional[NPlayerStrategy]] = []
        self._opponent_fns: list[
            Optional[Callable[[NPlayerObservation], NPlayerAction]]
        ] = []
        self._state: NPlayerGameState = NPlayerGameState()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def reset(
        self,
        game: str,
        *,
        num_rounds: Optional[int] = None,
        opponent_strategies: Optional[list[str]] = None,
        opponent_fns: Optional[
            list[Optional[Callable[[NPlayerObservation], NPlayerAction]]]
        ] = None,
        episode_id: Optional[str] = None,
    ) -> NPlayerObservation:
        """Start a new episode.

        Everything is resolved into locals first and only committed to the
        instance once all lookups succeeded, so a failing lookup (unknown
        game or strategy name) leaves the previous episode untouched.

        Parameters
        ----------
        game:
            Key in ``NPLAYER_GAMES``.
        num_rounds:
            Override the default round count.
        opponent_strategies:
            Strategy names for players one through N-minus-one. If shorter
            than needed, the last entry is repeated; surplus entries are
            ignored. ``None`` or an empty list defaults to all ``"random"``.
        opponent_fns:
            Callable opponents for players one through N-minus-one. ``None``
            entries (and missing trailing entries) fall back to the
            corresponding strategy; surplus entries are ignored.
        episode_id:
            Optional identifier for the episode. A random UUID string is
            generated when omitted or empty.

        Returns
        -------
        NPlayerObservation
            The initial observation from player zero's point of view.
        """
        game_config = get_nplayer_game(game)
        n = game_config.num_players
        num_opponents = n - _ONE

        # Resolve strategies. An empty list has no "last entry" to repeat, so
        # it is treated like None instead of crashing on strat_names[-1].
        strat_names = list(opponent_strategies) if opponent_strategies else ["random"]
        # Multiplying a list by a negative count yields [], so this only pads.
        strat_names.extend(
            [strat_names[-_ONE]] * (num_opponents - len(strat_names))
        )
        # Truncate so there is exactly one strategy per opponent; step()
        # relies on the strategy and fn lists having the same length.
        strategies: list[Optional[NPlayerStrategy]] = [
            get_nplayer_strategy(name) for name in strat_names[:num_opponents]
        ]

        # Resolve opponent fns, padded with None and truncated the same way.
        fns: list[Optional[Callable[[NPlayerObservation], NPlayerAction]]] = (
            list(opponent_fns) if opponent_fns is not None else []
        )
        fns.extend([None] * (num_opponents - len(fns)))
        fns = fns[:num_opponents]

        rounds = num_rounds if num_rounds is not None else game_config.default_rounds
        state = NPlayerGameState(
            episode_id=episode_id or str(uuid.uuid4()),
            game_name=game,
            total_rounds=rounds,
            num_players=n,
            scores=[_ZERO_F] * n,
        )

        # Commit only after every step above succeeded.
        self._game = game_config
        self._strategies = strategies
        self._opponent_fns = fns
        self._state = state
        return self._build_observation(_ZERO)

    def step(self, action: NPlayerAction) -> NPlayerObservation:
        """Execute one round.

        The caller supplies the action for player zero. Opponents are
        auto-played. The episode state is replaced only after all actions
        were collected and the payoffs computed.

        Returns
        -------
        NPlayerObservation
            Player zero's observation after the round; ``reward`` is player
            zero's payoff for this round and ``done`` is set once the number
            of played rounds reaches ``total_rounds``.

        Raises
        ------
        RuntimeError
            If ``reset()`` has not been called or the episode is finished.
        ValueError
            If ``action`` (or an action returned by an opponent function)
            is not one of the game's actions.
        """
        game = self._require_game()
        if self._state.is_done:
            raise RuntimeError("Episode already finished. Call reset().")
        if action.action not in game.actions:
            raise ValueError(
                f"Invalid action '{action.action}'. "
                f"Choose from: {game.actions}"
            )

        # Collect all actions: player zero first, then opponents in order.
        all_actions: list[str] = [action.action]
        all_actions.extend(
            self._get_opponent_action(opp_idx, opp_idx + _ONE)
            for opp_idx in range(len(self._strategies))
        )
        payoffs_tuple = game.payoff_fn(tuple(all_actions))

        # Rounds are numbered from one, derived from the history length.
        new_round = len(self._state.history) + _ONE
        result = NPlayerRoundResult(
            round_number=new_round,
            actions=list(all_actions),
            payoffs=list(payoffs_tuple),
        )
        history = list(self._state.history) + [result]
        new_scores = [
            score + payoff
            for score, payoff in zip(self._state.scores, payoffs_tuple)
        ]
        done = new_round >= self._state.total_rounds

        self._state = NPlayerGameState(
            episode_id=self._state.episode_id,
            step_count=self._state.step_count + _ONE,
            game_name=self._state.game_name,
            current_round=new_round,
            total_rounds=self._state.total_rounds,
            num_players=self._state.num_players,
            scores=new_scores,
            history=history,
            is_done=done,
        )
        return self._build_observation(
            _ZERO,
            reward=payoffs_tuple[_ZERO],
            last_round=result,
            done=done,
        )

    def state(self) -> NPlayerGameState:
        """Return the current episode state object."""
        return self._state

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _require_game(self) -> NPlayerGameConfig:
        """Return the loaded game config, raising if ``reset()`` never ran.

        An explicit raise is used instead of ``assert`` because assertions
        are removed when Python runs with ``-O``.
        """
        if self._game is None:
            raise RuntimeError("Call reset() before step().")
        return self._game

    def _get_opponent_action(self, opp_idx: int, player_idx: int) -> str:
        """Get the action for opponent at opp_idx (player player_idx).

        A caller-provided function takes precedence over the strategy
        configured for the same slot. Actions returned by such a function
        are validated against the game's action list.

        Raises
        ------
        ValueError
            If the opponent function returns an action the game does not
            offer.
        RuntimeError
            If no game is loaded or the slot has no strategy to fall back on.
        """
        game = self._require_game()
        # Both code paths hand the opponent the same per-player observation.
        obs = self._build_observation(player_idx)

        fn = self._opponent_fns[opp_idx]
        if fn is not None:
            opp_action = fn(obs)
            if opp_action.action not in game.actions:
                raise ValueError(
                    f"Opponent {player_idx} returned invalid action "
                    f"'{opp_action.action}'. Choose from: {game.actions}"
                )
            return opp_action.action

        strategy = self._strategies[opp_idx]
        if strategy is None:
            raise RuntimeError(
                f"No strategy configured for opponent {player_idx}."
            )
        return strategy.choose_action(obs)

    def _build_observation(
        self,
        player_index: int,
        reward: float = _ZERO_F,
        last_round: Optional[NPlayerRoundResult] = None,
        done: bool = False,
    ) -> NPlayerObservation:
        """Build the observation for ``player_index`` from the current state.

        Lists taken from the game config and the state are copied so the
        returned observation does not share them with the environment.
        """
        game = self._require_game()
        return NPlayerObservation(
            done=done,
            reward=reward,
            game_name=self._state.game_name,
            game_description=game.description,
            available_actions=list(game.actions),
            current_round=self._state.current_round,
            total_rounds=self._state.total_rounds,
            history=list(self._state.history),
            scores=list(self._state.scores),
            num_players=self._state.num_players,
            player_index=player_index,
            last_round=last_round,
        )