File size: 2,660 Bytes
f7e2ae6
5d2f027
 
 
 
 
 
 
f7e2ae6
5d2f027
 
f7e2ae6
dd8e198
5d2f027
 
f7e2ae6
5d2f027
f7e2ae6
 
5d2f027
 
f7e2ae6
5d2f027
f7e2ae6
 
5d2f027
f7e2ae6
 
5d2f027
f7e2ae6
 
5d2f027
f7e2ae6
5d2f027
 
f7e2ae6
 
5d2f027
f7e2ae6
5d2f027
f7e2ae6
 
 
 
 
 
 
 
 
 
 
 
 
5d2f027
 
f7e2ae6
5d2f027
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""KantBench Environment Client."""

from typing import Dict

from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from openenv.core import EnvClient

from .models import KantBenchAction, KantBenchObservation


class KantBenchEnv(
    EnvClient[KantBenchAction, KantBenchObservation, State]
):
    """
    Client for the KantBench game theory environment.

    Maintains a persistent WebSocket connection to the environment server.
    Each client instance has its own dedicated environment session.

    Example:
        >>> with KantBenchEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.game_name)
        ...     print(result.observation.available_moves)
        ...
        ...     result = client.step(KantBenchAction(move="cooperate"))
        ...     print(result.observation.your_payoff)

    Example with HF Space:
        >>> with KantBenchEnv(base_url="https://openenv-community-kantbench.hf.space") as client:
        ...     result = client.reset()
        ...     result = client.step(KantBenchAction(move="cooperate"))
    """

    def _step_payload(self, action: KantBenchAction) -> Dict:
        """Serialize *action* into the request body sent for a step.

        Only the action's ``move`` field is transmitted.
        """
        return {"move": action.move}

    def _parse_result(self, payload: Dict) -> StepResult[KantBenchObservation]:
        """Build a ``StepResult`` from a decoded server response.

        Args:
            payload: Decoded response. Observation fields are read from
                ``payload["observation"]``. ``done`` and ``reward`` are read
                from the top level of *payload*.

        Returns:
            A ``StepResult`` carrying a ``KantBenchObservation``. Any
            observation field absent from the payload falls back to the
            default given below.
        """
        # `or {}` also covers an explicit `"observation": null`; with
        # `.get(key, {})` that would come back as None and the field
        # lookups below would raise AttributeError.
        obs_data = payload.get("observation") or {}
        # Read episode-level fields once so the observation and the
        # StepResult are guaranteed to carry identical values.
        done = payload.get("done", False)
        reward = payload.get("reward")

        observation = KantBenchObservation(
            game_name=obs_data.get("game_name", ""),
            game_description=obs_data.get("game_description", ""),
            available_moves=obs_data.get("available_moves", []),
            your_move=obs_data.get("your_move", ""),
            opponent_move=obs_data.get("opponent_move", ""),
            your_payoff=obs_data.get("your_payoff", 0.0),
            opponent_payoff=obs_data.get("opponent_payoff", 0.0),
            cumulative_score=obs_data.get("cumulative_score", 0.0),
            round_number=obs_data.get("round_number", 0),
            max_rounds=obs_data.get("max_rounds", 10),
            opponent_strategy=obs_data.get("opponent_strategy", ""),
            history=obs_data.get("history", []),
            done=done,
            reward=reward,
            message=obs_data.get("message", ""),
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """Build a ``State`` from a decoded state response.

        Only ``episode_id`` and ``step_count`` are read. ``step_count``
        defaults to 0 when absent.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )