# KantBench/env/nplayer/models.py
# jtowarek's picture
# Upload folder using huggingface_hub
# f7e2ae6 verified
"""Data models for the N-player environment."""
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field
from constant_definitions.game_constants import (
DEFAULT_FALSE,
DEFAULT_NONE,
DEFAULT_ZERO_FLOAT,
DEFAULT_ZERO_INT,
MIN_STEP_COUNT,
)
class NPlayerRoundResult(BaseModel):
    """Record of a single round in an N-player game.

    Bundles the round number with the actions and payoffs reported for
    all players. All three fields are required.
    """

    # One-indexed position of the round within the episode.
    round_number: int = Field(
        ...,
        description="Round number (one-indexed)",
    )
    # One action string per entry; presumably ordered by player index —
    # TODO confirm against the environment that builds this record.
    actions: list[str] = Field(
        ...,
        description="Actions taken by all players",
    )
    # Payoff values for the same round as ``actions``.
    payoffs: list[float] = Field(
        ...,
        description="Payoffs received by all players",
    )
class NPlayerAction(BaseModel):
    """Action submitted for one round of an N-player game.

    ``action`` is required; ``metadata`` defaults to a fresh empty dict
    for every instance.
    """

    # The chosen action, as a plain string.
    action: str = Field(
        ...,
        description="The action to take this round",
    )
    # Free-form dict; default_factory gives each instance its own dict.
    metadata: dict = Field(
        default_factory=dict,
    )
class NPlayerObservation(BaseModel):
    """Observation for one player in the N-player environment.

    Every field has a default, so the model can be constructed with no
    arguments. Container fields use ``default_factory`` so instances
    never share a list or dict.
    """

    # --- Step outcome -----------------------------------------------------
    done: bool = Field(
        default=DEFAULT_FALSE,
        description="Whether the episode is over",
    )
    reward: float = Field(
        default=DEFAULT_ZERO_FLOAT,
        description="Reward for this step",
    )

    # --- Game description -------------------------------------------------
    game_name: str = Field(
        default="",
        description="Name of the current game",
    )
    game_description: str = Field(
        default="",
        description="Description of the game rules",
    )
    available_actions: list[str] = Field(
        default_factory=list,
        description="Valid actions",
    )

    # --- Progress through the episode -------------------------------------
    current_round: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Current round number",
    )
    total_rounds: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Total rounds in episode",
    )
    history: list[NPlayerRoundResult] = Field(
        default_factory=list,
        description="Round history",
    )
    scores: list[float] = Field(
        default_factory=list,
        description="Cumulative scores for all players",
    )

    # --- Player information -----------------------------------------------
    num_players: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Number of players",
    )
    player_index: int = Field(
        default=DEFAULT_ZERO_INT,
        description="This player's index",
    )

    # Defaults to DEFAULT_NONE when no round result is supplied.
    last_round: Optional[NPlayerRoundResult] = Field(
        default=DEFAULT_NONE,
        description="Most recent round",
    )
    # Free-form dict; default_factory gives each instance its own dict.
    metadata: dict = Field(
        default_factory=dict,
    )
class NPlayerGameState(BaseModel):
    """State of an N-player game episode.

    Every field has a default, so the model can be constructed with no
    arguments. ``step_count`` is the only field with a validation bound.
    """

    # --- Episode bookkeeping ----------------------------------------------
    episode_id: Optional[str] = Field(
        default=DEFAULT_NONE,
        description="Episode identifier",
    )
    # Validated as >= MIN_STEP_COUNT through pydantic's ``ge`` constraint.
    step_count: int = Field(
        default=DEFAULT_ZERO_INT,
        ge=MIN_STEP_COUNT,
        description="Steps taken",
    )

    # --- Game configuration and progress ----------------------------------
    game_name: str = Field(
        default="",
        description="Current game name",
    )
    current_round: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Current round",
    )
    total_rounds: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Total rounds",
    )
    num_players: int = Field(
        default=DEFAULT_ZERO_INT,
        description="Number of players",
    )

    # --- Accumulated results ----------------------------------------------
    scores: list[float] = Field(
        default_factory=list,
        description="Cumulative scores for all players",
    )
    history: list[NPlayerRoundResult] = Field(
        default_factory=list,
        description="Round history",
    )
    is_done: bool = Field(
        default=DEFAULT_FALSE,
        description="Whether episode has ended",
    )