File size: 6,004 Bytes
f7e2ae6
5d2f027
047aab1
688c130
 
5d2f027
 
 
 
047aab1
5d2f027
 
 
 
 
f7e2ae6
 
047aab1
 
 
 
 
 
5d2f027
688c130
 
 
5d2f027
 
047aab1
5d2f027
047aab1
 
688c130
 
 
5d2f027
 
 
 
f7e2ae6
047aab1
 
 
5d2f027
f7e2ae6
047aab1
688c130
 
 
 
 
 
 
 
 
 
047aab1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d2f027
f7e2ae6
047aab1
 
 
 
 
 
 
 
5d2f027
f7e2ae6
 
047aab1
 
 
 
f7e2ae6
 
 
 
5d2f027
 
f7e2ae6
 
 
 
 
 
 
 
 
 
5d2f027
f7e2ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
047aab1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""KantBench environment adapter for the HF Space.

Thin wrapper that delegates to the real KantEnvironment (90+ 2-player games,
17 strategies, meta-games, composable variants) and NPlayerEnvironment
(3 N-player games) instead of a standalone reimplementation.
"""

from __future__ import annotations

from typing import Any, Optional

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from models import KantBenchAction, KantBenchObservation
from env.environment import KantEnvironment
from env.models import GameAction
from env.nplayer.environment import NPlayerEnvironment
from env.nplayer.models import NPlayerAction, NPlayerObservation

# Register built-in N-player games into the registry
import common.games_meta.nplayer_games  # noqa: F401
from common.games_meta.nplayer_config import NPLAYER_GAMES

from common.games import GAMES
from common.variants import compose_game


class KantbenchEnvironment(Environment):
    """Game theory environment exposing 90+ two-player and N-player games.

    Wraps the real KantEnvironment and NPlayerEnvironment, routing
    automatically based on the requested game name.

    Supports a ``variant`` reset parameter for dynamic game composition
    (e.g. ``variant="constitutional"`` or ``variant="cheap_talk"``).
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        self._env_2p = KantEnvironment()
        self._env_np = NPlayerEnvironment()
        # Which wrapped backend owns the current episode; read by step() and
        # state. Only updated after a backend reset has succeeded.
        self._is_nplayer: bool = False

    def reset(self, **kwargs: Any) -> KantBenchObservation:
        """Start a new episode on the 2-player or N-player backend.

        Keyword Args:
            game: Game name. Names found in ``NPLAYER_GAMES`` go to
                NPlayerEnvironment; everything else goes to KantEnvironment.
                When absent, ``"prisoners_dilemma"`` is assumed for routing
                and variant composition only; without a variant the
                2-player backend then applies its own default.
            variant: Optional variant name composed onto a game from
                ``GAMES`` via ``compose_game``. It is silently ignored for
                games not present in ``GAMES``.
            strategy: On the N-player path, passed as a one-element
                ``opponent_strategies`` list.
            num_rounds: On the N-player path, forwarded as ``num_rounds``.

            On the 2-player path every remaining keyword argument (including
            ``strategy`` and ``num_rounds``) is forwarded unchanged to
            ``KantEnvironment.reset``. Other keywords are dropped on the
            N-player path.

        Returns:
            The backend's initial observation as a KantBenchObservation.
        """
        game_name: str = kwargs.get("game", "prisoners_dilemma")
        variant: Optional[str] = kwargs.pop("variant", None)

        # Dynamic variant composition — compose the game on the fly and
        # register it so KantEnvironment can look it up via get_game().
        # The constitutional variant creates a fresh mutable closure per call.
        # NOTE(review): GAMES is module-global and this key is shared by every
        # session resetting the same variant/game pair, so with concurrent
        # sessions one reset can replace another's composed game — confirm
        # KantEnvironment captures the game object during reset().
        if variant and game_name in GAMES:
            composed_key = f"_composed_{variant}_{game_name}"
            GAMES[composed_key] = compose_game(game_name, variant)
            kwargs["game"] = composed_key

        if game_name in NPLAYER_GAMES:
            # Map Space kwargs onto the NPlayerEnvironment.reset signature.
            strategy = kwargs.get("strategy")
            opponent_strategies: Optional[list[str]] = [strategy] if strategy else None
            np_obs = self._env_np.reset(
                game_name,
                num_rounds=kwargs.get("num_rounds"),
                opponent_strategies=opponent_strategies,
            )
            # Flip routing only after the backend reset succeeded, so a
            # failing reset does not leave step()/state pointed at a backend
            # that never started an episode.
            self._is_nplayer = True
            return _nplayer_to_space_obs(np_obs)

        obs = self._env_2p.reset(**kwargs)
        self._is_nplayer = False
        return _to_space_obs(obs)

    def step(self, action: KantBenchAction, **kwargs: Any) -> KantBenchObservation:
        """Apply ``action.move`` on whichever backend owns the current episode.

        Extra keyword arguments are forwarded on the 2-player path only.
        """
        if self._is_nplayer:
            np_obs = self._env_np.step(NPlayerAction(action=action.move))
            return _nplayer_to_space_obs(np_obs)
        obs = self._env_2p.step(GameAction(action=action.move), **kwargs)
        return _to_space_obs(obs)

    @property
    def state(self) -> State:
        """Episode id and step count of the backend owning the episode.

        A missing (falsy) episode id is reported as the empty string.
        """
        s = self._env_np.state if self._is_nplayer else self._env_2p.state
        return State(episode_id=s.episode_id or "", step_count=s.step_count)


def _to_space_obs(obs) -> KantBenchObservation:
    """Translate a 2-player GameObservation into a KantBenchObservation.

    The per-round move/payoff fields come from ``obs.last_round``. When no
    round is available (``last_round`` is falsy) they fall back to ``""``
    and ``0.0``. The whole round history is flattened into plain dicts.
    """
    latest = obs.last_round
    if latest:
        your_move, opponent_move = latest.player_action, latest.opponent_action
        your_payoff, opponent_payoff = latest.player_payoff, latest.opponent_payoff
    else:
        your_move = opponent_move = ""
        your_payoff = opponent_payoff = 0.0

    rounds = []
    for rnd in obs.history:
        rounds.append(
            {
                "round": rnd.round_number,
                "your_move": rnd.player_action,
                "opponent_move": rnd.opponent_action,
                "your_payoff": rnd.player_payoff,
                "opponent_payoff": rnd.opponent_payoff,
            }
        )

    if obs.done:
        note = "Game over — call reset() to start a new episode."
    else:
        note = ""

    return KantBenchObservation(
        game_name=obs.game_name,
        game_description=obs.game_description,
        available_moves=list(obs.available_actions),
        your_move=your_move,
        opponent_move=opponent_move,
        your_payoff=your_payoff,
        opponent_payoff=opponent_payoff,
        cumulative_score=obs.player_score,
        round_number=obs.current_round,
        max_rounds=obs.total_rounds,
        opponent_strategy=obs.opponent_strategy,
        history=rounds,
        done=obs.done,
        reward=obs.reward,
        message=note,
    )


def _nplayer_to_space_obs(obs: NPlayerObservation) -> KantBenchObservation:
    """Convert an NPlayerObservation to the Space-facing KantBenchObservation.

    The single-opponent fields (``opponent_move``, ``opponent_payoff``,
    ``opponent_strategy``) do not fit N-player games and are left empty.
    Callers should read ``history`` (per-round actions/payoffs of every
    player) and ``all_scores`` instead.

    The agent-centric fields (``your_move``, ``your_payoff``,
    ``cumulative_score``) are read at ``obs.player_index`` rather than a
    hard-coded seat 0, so they stay correct if the agent is not player 0.
    """
    # Fall back to seat 0 (the previously hard-coded seat) when no index is set.
    me = obs.player_index or 0
    last = obs.last_round
    history = [
        {
            "round": r.round_number,
            "actions": r.actions,
            "payoffs": r.payoffs,
        }
        for r in obs.history
    ]
    scores = list(obs.scores)
    return KantBenchObservation(
        game_name=obs.game_name,
        game_description=obs.game_description,
        available_moves=list(obs.available_actions),
        your_move=last.actions[me] if last else "",
        opponent_move="",  # N-player: see history for all actions
        your_payoff=last.payoffs[me] if last else 0.0,
        opponent_payoff=0.0,  # N-player: see history for all payoffs
        # Guard covers both an empty score list and an out-of-range index.
        cumulative_score=scores[me] if me < len(scores) else 0.0,
        round_number=obs.current_round,
        max_rounds=obs.total_rounds,
        opponent_strategy="",
        history=history,
        done=obs.done,
        reward=obs.reward,
        message="Game over — call reset() to start a new episode." if obs.done else "",
        num_players=obs.num_players,
        player_index=obs.player_index,
        all_scores=scores,
    )