Spaces:
Running
Running
File size: 2,801 Bytes
"""Data models for the N-player environment."""
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field
from constant_definitions.game_constants import (
DEFAULT_FALSE,
DEFAULT_NONE,
DEFAULT_ZERO_FLOAT,
DEFAULT_ZERO_INT,
MIN_STEP_COUNT,
)
class NPlayerRoundResult(BaseModel):
    """Outcome of one completed round.

    All three fields are required. ``actions`` and ``payoffs`` each hold
    one entry per player, as stated by their field descriptions.
    """

    round_number: int = Field(
        ...,
        description="Round number (one-indexed)",
    )
    actions: list[str] = Field(
        ...,
        description="Actions taken by all players",
    )
    payoffs: list[float] = Field(
        ...,
        description="Payoffs received by all players",
    )
class NPlayerAction(BaseModel):
    """Action submitted for a round.

    ``action`` is required; ``metadata`` is an optional free-form dict
    that defaults to a fresh empty dict for every instance.
    """

    action: str = Field(
        ...,
        description="The action to take this round",
    )
    # default_factory gives each instance its own dict (no shared state).
    metadata: dict = Field(default_factory=dict)
class NPlayerObservation(BaseModel):
    """Observation model for the N-player environment.

    Every field has a default, so the model can be built with no
    arguments. Container fields use ``default_factory`` so that each
    instance receives its own list/dict.
    """

    # --- step outcome ---
    done: bool = Field(
        default=DEFAULT_FALSE,
        description="Whether the episode is over",
    )
    reward: float = Field(
        default=DEFAULT_ZERO_FLOAT,
        description="Reward for this step",
    )

    # --- game description ---
    game_name: str = Field(
        default="",
        description="Name of the current game",
    )
    game_description: str = Field(
        default="",
        description="Description of the game rules",
    )
    available_actions: list[str] = Field(
        default_factory=list,
        description="Valid actions",
    )

    # --- progress through the episode ---
    current_round: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Current round number",
    )
    total_rounds: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Total rounds in episode",
    )
    history: list[NPlayerRoundResult] = Field(
        default_factory=list,
        description="Round history",
    )
    scores: list[float] = Field(
        default_factory=list,
        description="Cumulative scores for all players",
    )

    # --- player identity ---
    num_players: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Number of players",
    )
    player_index: int = Field(
        default=DEFAULT_ZERO_INT,
        description="This player's index",
    )

    # --- extras ---
    last_round: Optional[NPlayerRoundResult] = Field(
        default=DEFAULT_NONE,
        description="Most recent round",
    )
    metadata: dict = Field(default_factory=dict)
class NPlayerGameState(BaseModel):
    """State model for an N-player game episode.

    Every field has a default. ``step_count`` is the only constrained
    field: it is validated to be greater than or equal to
    ``MIN_STEP_COUNT``.
    """

    episode_id: Optional[str] = Field(
        default=DEFAULT_NONE,
        description="Episode identifier",
    )
    step_count: int = Field(
        default=DEFAULT_ZERO_INT,
        ge=MIN_STEP_COUNT,  # lower bound enforced by pydantic validation
        description="Steps taken",
    )
    game_name: str = Field(
        default="",
        description="Current game name",
    )
    current_round: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Current round",
    )
    total_rounds: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Total rounds",
    )
    num_players: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Number of players",
    )
    # List fields use default_factory so instances never share a list.
    scores: list[float] = Field(
        default_factory=list,
        description="Cumulative scores for all players",
    )
    history: list[NPlayerRoundResult] = Field(
        default_factory=list,
        description="Round history",
    )
    is_done: bool = Field(
        default=DEFAULT_FALSE,
        description="Whether episode has ended",
    )
|