| |
| |
| |
| |
| |
|
|
| """ |
| Data models for OpenSpiel Environment. |
| |
| This module defines the Action, Observation, and State types for OpenSpiel games. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Any, Dict, List, Optional |
|
|
| from openenv_core.env_server import Action, Observation, State |
|
|
|
|
@dataclass
class OpenSpielAction(Action):
    """A single move submitted to an OpenSpiel environment.

    The action carries the chosen move together with the game it is meant
    for, so one message is self-describing.

    Attributes:
        action_id: Integer ID of the move to play; expected to be one of the
            IDs listed in the observation's ``legal_actions``.
        game_name: OpenSpiel game identifier, e.g. ``"catch"`` or
            ``"tic_tac_toe"``. Defaults to ``"catch"``.
        game_params: Optional game-specific settings, e.g.
            ``{"rows": 8, "columns": 6}``. Defaults to an empty dict.
    """

    # Required: no default, so callers must always supply the move.
    # NOTE(review): a required field in a subclass only works if the base
    # ``Action`` declares no positional fields with defaults (e.g. it is
    # keyword-only) — confirm against openenv_core.
    action_id: int

    # Target game; "catch" is the default.
    game_name: str = "catch"

    # default_factory gives every instance its own dict rather than a
    # shared mutable default.
    game_params: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
@dataclass
class OpenSpielObservation(Observation):
    """What the agent sees after an action in an OpenSpiel environment.

    In single-player games this is simply the game's observation. In
    multi-player games it is given from the perspective of the player the
    agent controls.

    Attributes:
        info_state: Information-state tensor for the agent, flattened to a
            list of floats; holds all information available to the agent.
        legal_actions: Action IDs the agent is currently allowed to play.
        game_phase: Label for the current phase, e.g. ``"playing"`` or
            ``"terminal"``. Defaults to ``"playing"``.
        current_player_id: ID of the player to move; documented as ``-1``
            for simultaneous-move situations. Defaults to ``0``.
        opponent_last_action: The opponent's most recent action ID, or
            ``None`` when it is not available.
    """

    # Required fields: an observation always carries the state tensor and
    # the set of legal moves.
    # NOTE(review): required fields in a subclass only work if the base
    # ``Observation`` declares no positional fields with defaults — confirm
    # against openenv_core.
    info_state: List[float]
    legal_actions: List[int]

    # Phase label; "playing" is the default.
    game_phase: str = "playing"

    # NOTE(review): OpenSpiel's own PlayerId enum uses -2 for simultaneous
    # nodes and -1 for chance nodes — confirm which convention the server
    # actually emits here.
    current_player_id: int = 0

    # None means "no opponent action available".
    opponent_last_action: Optional[int] = None
|
|
|
|
@dataclass
class OpenSpielState(State):
    """Configuration-level state of an OpenSpiel environment.

    Every field has a default, so the class can be built with no arguments
    of its own.

    Attributes:
        game_name: OpenSpiel game identifier. Defaults to ``"catch"``.
        agent_player: Player ID controlled by the agent. Defaults to ``0``.
        opponent_policy: Name of the opponent's policy, e.g. ``"random"``
            or ``"fixed"``. Defaults to ``"random"``.
        game_params: Game-specific parameters. Defaults to an empty dict.
        num_players: Total number of players in the game. Defaults to ``1``.
    """

    # Which game is being played.
    game_name: str = "catch"

    # Seat occupied by the agent.
    agent_player: int = 0

    # How the non-agent player(s) choose moves.
    opponent_policy: str = "random"

    # default_factory gives every instance its own dict rather than a
    # shared mutable default.
    game_params: Dict[str, Any] = field(default_factory=dict)

    # Player count; the default of 1 describes a single-player game.
    num_players: int = 1
|
|