File size: 2,801 Bytes
f7e2ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Data models for the N-player environment."""

from __future__ import annotations

from typing import Optional

from pydantic import BaseModel, Field

from constant_definitions.game_constants import (
    DEFAULT_FALSE,
    DEFAULT_NONE,
    DEFAULT_ZERO_FLOAT,
    DEFAULT_ZERO_INT,
    MIN_STEP_COUNT,
)


class NPlayerRoundResult(BaseModel):
    """Record of one completed round of an N-player game.

    ``actions`` and ``payoffs`` each cover all players in the game.
    NOTE(review): presumably entry ``i`` of both lists belongs to player
    ``i``; confirm against the code that builds these records.
    """

    # Required (no default): position of this round, counted from one.
    round_number: int = Field(
        default=...,
        description="Round number (one-indexed)",
    )
    # Required: the action taken by every player in this round.
    actions: list[str] = Field(
        default=...,
        description="Actions taken by all players",
    )
    # Required: the payoff every player received in this round.
    payoffs: list[float] = Field(
        default=...,
        description="Payoffs received by all players",
    )


class NPlayerAction(BaseModel):
    """The action a player submits for one round of the N-player environment."""

    # Required (no default): the action to play this round.
    action: str = Field(
        default=...,
        description="The action to take this round",
    )
    # Free-form extra data. ``default_factory`` gives each instance its own
    # fresh empty dict instead of sharing one mutable default.
    metadata: dict = Field(default_factory=dict)


class NPlayerObservation(BaseModel):
    """Observation of the N-player environment for one player.

    The player is identified by ``player_index``. The model bundles the
    outcome of the latest step (``done``, ``reward``), a description of the
    game being played, round counters, cumulative scores for all players,
    and the round history. Every field has a default, so an observation can
    be created with no arguments; the scalar defaults are taken from
    ``constant_definitions.game_constants``.
    """

    done: bool = Field(
        description="Whether the episode is over",
        default=DEFAULT_FALSE,
    )
    reward: float = Field(
        description="Reward for this step",
        default=DEFAULT_ZERO_FLOAT,
    )
    game_name: str = Field(
        description="Name of the current game",
        default="",
    )
    game_description: str = Field(
        description="Description of the game rules",
        default="",
    )
    # Mutable containers use default_factory so instances never share state.
    available_actions: list[str] = Field(
        description="Valid actions",
        default_factory=list,
    )
    current_round: int = Field(
        description="Current round number",
        default=DEFAULT_ZERO_INT,
    )
    total_rounds: int = Field(
        description="Total rounds in episode",
        default=DEFAULT_ZERO_INT,
    )
    history: list[NPlayerRoundResult] = Field(
        description="Round history",
        default_factory=list,
    )
    scores: list[float] = Field(
        description="Cumulative scores for all players",
        default_factory=list,
    )
    num_players: int = Field(
        description="Number of players",
        default=DEFAULT_ZERO_INT,
    )
    player_index: int = Field(
        description="This player's index",
        default=DEFAULT_ZERO_INT,
    )
    last_round: Optional[NPlayerRoundResult] = Field(
        description="Most recent round",
        default=DEFAULT_NONE,
    )
    # Free-form extra data attached to the observation.
    metadata: dict = Field(default_factory=dict)


class NPlayerGameState(BaseModel):
    """Episode-level state of the N-player environment.

    Tracks the episode identifier, how many steps have been taken, which
    game is running, round progress, cumulative scores for all players, the
    round history and whether the episode has finished. Every field has a
    default, so a fresh state can be created with no arguments.

    NOTE: the ``ge=MIN_STEP_COUNT`` bound on ``step_count`` is checked when
    the model is validated (e.g. on construction). This class does not
    enable ``validate_assignment``, so pydantic does not re-check plain
    attribute assignments such as ``state.step_count += 1``.
    """

    episode_id: Optional[str] = Field(
        description="Episode identifier",
        default=DEFAULT_NONE,
    )
    step_count: int = Field(
        description="Steps taken",
        ge=MIN_STEP_COUNT,
        default=DEFAULT_ZERO_INT,
    )
    game_name: str = Field(
        description="Current game name",
        default="",
    )
    current_round: int = Field(
        description="Current round",
        default=DEFAULT_ZERO_INT,
    )
    total_rounds: int = Field(
        description="Total rounds",
        default=DEFAULT_ZERO_INT,
    )
    num_players: int = Field(
        description="Number of players",
        default=DEFAULT_ZERO_INT,
    )
    # Mutable containers use default_factory so instances never share state.
    scores: list[float] = Field(
        description="Cumulative scores for all players",
        default_factory=list,
    )
    history: list[NPlayerRoundResult] = Field(
        description="Round history",
        default_factory=list,
    )
    is_done: bool = Field(
        description="Whether episode has ended",
        default=DEFAULT_FALSE,
    )