"""Core KantBench environment implementing the OpenEnv Environment interface.""" from __future__ import annotations import uuid from typing import Any, Callable, Optional from openenv.core.env_server.interfaces import Environment from env.models import GameAction, GameObservation, GameState, RoundResult from common.games import GameConfig, get_game, GAMES from common.strategies import get_strategy, STRATEGIES, OpponentStrategy from constant_definitions.game_constants import DEFAULT_NUM_ROUNDS _ONE = int(bool(True)) _ZERO_F = float() class KantEnvironment(Environment[GameObservation, GameAction, GameState]): """Game-theory environment hosting multiple classic games. The agent plays against a built-in opponent strategy or another agent function. The opponent's move is computed automatically inside ``step()`` via the selected strategy or the provided ``opponent_fn``. """ SUPPORTS_CONCURRENT_SESSIONS = True def __init__(self) -> None: super().__init__() self._game: Optional[GameConfig] = None self._strategy: Optional[OpponentStrategy] = None self._strategy_name: str = "" self._opponent_fn: Optional[Callable[[GameObservation], GameAction]] = None self._state: GameState = GameState() # ------------------------------------------------------------------ # OpenEnv interface # ------------------------------------------------------------------ def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, ) -> GameObservation: game_name: str = kwargs.get("game", "prisoners_dilemma") strategy_name: str = kwargs.get("strategy", "tit_for_tat") num_rounds: Optional[int] = kwargs.get("num_rounds") opponent_fn: Optional[Callable[[GameObservation], GameAction]] = kwargs.get( "opponent_fn", ) self._game = get_game(game_name) self._opponent_fn = opponent_fn if opponent_fn is not None: self._strategy = None self._strategy_name = "agent" else: self._strategy = get_strategy(strategy_name) self._strategy_name = strategy_name rounds = num_rounds if num_rounds is not None else self._game.default_rounds self._state = GameState( episode_id=episode_id or str(uuid.uuid4()), game_name=game_name, opponent_strategy=strategy_name, total_rounds=rounds, ) return self._build_observation() def step( self, action: GameAction, **kwargs: Any, ) -> GameObservation: if self._game is None: raise RuntimeError("Call reset() before step().") if self._state.is_done: raise RuntimeError("Episode already finished. Call reset().") if action.action not in self._game.actions: raise ValueError( f"Invalid action '{action.action}'. " f"Choose from: {self._game.actions}" ) player_action = action.action opponent_action = self._auto_play_opponent(player_action) p_pay, o_pay = self._game.payoff_fn(player_action, opponent_action) new_round = len(self._state.history) + _ONE result = RoundResult( round_number=new_round, player_action=player_action, opponent_action=opponent_action, player_payoff=p_pay, opponent_payoff=o_pay, ) history = list(self._state.history) + [result] p_score = self._state.player_score + p_pay o_score = self._state.opponent_score + o_pay done = new_round >= self._state.total_rounds self._state = GameState( episode_id=self._state.episode_id, step_count=self._state.step_count + _ONE, game_name=self._state.game_name, opponent_strategy=self._state.opponent_strategy, current_round=new_round, total_rounds=self._state.total_rounds, player_score=p_score, opponent_score=o_score, history=history, is_done=done, ) return self._build_observation(reward=p_pay, last_round=result, done=done) @property def state(self) -> GameState: return self._state # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _auto_play_opponent(self, player_action: str) -> str: assert self._game is not None if self._opponent_fn is not None: opp_obs = self._build_opponent_observation() opp_action = self._opponent_fn(opp_obs) opp_actions = self._opponent_actions() if opp_action.action not in opp_actions: raise ValueError( f"Opponent returned invalid action '{opp_action.action}'. " f"Choose from: {opp_actions}" ) return opp_action.action assert self._strategy is not None hist = [ { "player_action": r.player_action, "opponent_action": r.opponent_action, } for r in self._state.history ] opp_actions = self._opponent_actions() return self._strategy.choose_action( self._game.game_type, opp_actions, hist, ) def _opponent_actions(self) -> list[str]: assert self._game is not None if self._game.opponent_actions is not None: return list(self._game.opponent_actions) gt = self._game.game_type if gt == "ultimatum": return ["accept", "reject"] if gt == "trust": return _trust_return_actions() # matrix, public_goods, auction, commons, dictator, centipede, # stackelberg, and all generated games share action space return list(self._game.actions) def _build_opponent_observation(self) -> GameObservation: """Build a GameObservation from the opponent's perspective. Swaps player/opponent in history, scores, and payoffs so the opponent agent sees itself as the "player". """ assert self._game is not None flipped_history = [ RoundResult( round_number=r.round_number, player_action=r.opponent_action, opponent_action=r.player_action, player_payoff=r.opponent_payoff, opponent_payoff=r.player_payoff, ) for r in self._state.history ] opp_actions = self._opponent_actions() return GameObservation( done=False, reward=_ZERO_F, game_name=self._state.game_name, game_description=self._game.description, available_actions=opp_actions, current_round=self._state.current_round, total_rounds=self._state.total_rounds, history=flipped_history, player_score=self._state.opponent_score, opponent_score=self._state.player_score, opponent_strategy="agent", ) def _build_observation( self, reward: float = _ZERO_F, last_round: Optional[RoundResult] = None, done: bool = False, ) -> GameObservation: assert self._game is not None return GameObservation( done=done, reward=reward, game_name=self._state.game_name, game_description=self._game.description, available_actions=list(self._game.actions), current_round=self._state.current_round, total_rounds=self._state.total_rounds, history=list(self._state.history), player_score=self._state.player_score, opponent_score=self._state.opponent_score, opponent_strategy=self._strategy_name, last_round=last_round, ) def _trust_return_actions() -> list[str]: from constant_definitions.game_constants import TRUST_ENDOWMENT, TRUST_MULTIPLIER cap = TRUST_ENDOWMENT * TRUST_MULTIPLIER return [f"return_{i}" for i in range(cap + _ONE)]