"""
ChessEcon — Shared Data Models
Pydantic models used by both the backend API and the training pipeline.
"""
from __future__ import annotations
from enum import Enum
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
import time
# ── Enums ──────────────────────────────────────────────────────────────────────
class GameStatus(str, Enum):
    """String-valued enumeration of the status values a game can carry.

    Members also subclass ``str``, so each one compares equal to its
    lowercase value.
    """

    WAITING = "waiting"
    ACTIVE = "active"
    FINISHED = "finished"
class GameOutcome(str, Enum):
    """String-valued enumeration of game results, including "ongoing"."""

    WHITE_WIN = "white_win"
    BLACK_WIN = "black_win"
    DRAW = "draw"
    ONGOING = "ongoing"
class EventType(str, Enum):
    """String-valued enumeration of event type tags.

    Used as the type of ``WSEvent.type`` in this module.
    """

    GAME_START = "game_start"
    MOVE = "move"
    COACHING_REQUEST = "coaching_request"
    COACHING_RESULT = "coaching_result"
    GAME_END = "game_end"
    TRAINING_STEP = "training_step"
    ECONOMY_UPDATE = "economy_update"
    ERROR = "error"
class PositionComplexity(str, Enum):
    """String-valued enumeration of position complexity levels."""

    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"
    CRITICAL = "critical"
class RLMethod(str, Enum):
    """String-valued enumeration of RL method identifiers.

    Used as the type of the ``method`` field on ``TrainingStep`` and
    ``TrainingConfig`` in this module.
    """

    GRPO = "grpo"
    PPO = "ppo"
    RLOO = "rloo"
    REINFORCE = "reinforce"
    DPO = "dpo"
# ── Chess Models ───────────────────────────────────────────────────────────────
class MoveRequest(BaseModel):
    """Request model: a game id, the moving player and a UCI move string."""

    game_id: str
    # Expected to be "white" or "black"; the type is a plain str, so the
    # model itself does not restrict the value.
    player: str
    move_uci: str
class MoveResponse(BaseModel):
    """Response model describing a game after a move.

    Every field is required.
    """

    game_id: str
    move_uci: str

    # Resulting position and the moves available from it.
    fen: str
    legal_moves: List[str]

    outcome: GameOutcome
    move_number: int

    # Status flags for the resulting position.
    is_check: bool
    is_checkmate: bool
    is_stalemate: bool
class GameState(BaseModel):
    """Snapshot of a single game.

    The first five fields are required; the remaining ones have defaults.
    """

    game_id: str
    fen: str
    legal_moves: List[str]
    outcome: GameOutcome
    move_number: int

    # A fresh list is created per instance via the factory.
    move_history: List[str] = Field(default_factory=list)
    status: GameStatus = GameStatus.ACTIVE
    white_player: str = "white"
    black_player: str = "black"
    # Defaults to time.time() evaluated when the instance is created.
    created_at: float = Field(default_factory=time.time)
class NewGameResponse(BaseModel):
    """Response model for a newly created game; all fields are required."""

    game_id: str
    fen: str
    legal_moves: List[str]
    status: GameStatus
# ── Economy Models ─────────────────────────────────────────────────────────────
class Transaction(BaseModel):
    """A single wallet transaction for one agent."""

    tx_id: str
    agent_id: str
    # Sign convention: positive values are credits, negative values debits.
    amount: float
    description: str
    # Defaults to time.time() evaluated when the instance is created.
    timestamp: float = Field(default_factory=time.time)
class WalletState(BaseModel):
    """Wallet balance plus running counters for one agent.

    Only ``agent_id`` and ``balance`` are required; every counter starts
    at zero.
    """

    agent_id: str
    balance: float

    # Monetary totals.
    total_earned: float = 0.0
    total_spent: float = 0.0

    # Activity counters.
    coaching_calls: int = 0
    games_played: int = 0
    games_won: int = 0
class TournamentResult(BaseModel):
    """Financial and sporting result of one game.

    ``winner`` holds the winning agent's id, or ``None`` for a draw.
    """

    game_id: str
    # agent_id of the winner, or None for a draw. The default is spelled out
    # because a bare ``Optional[str]`` is implicitly optional under
    # Pydantic v1 but required under Pydantic v2; the explicit ``None``
    # gives the same behaviour on both.
    winner: Optional[str] = None
    outcome: GameOutcome
    prize_paid: float
    entry_fees_collected: float
    organizer_cut: float
# ── Agent Models ───────────────────────────────────────────────────────────────
class ComplexityAnalysis(BaseModel):
    """Complexity assessment attached to a position given as a FEN string."""

    fen: str
    # Intended range is 0.0 to 1.0; the float type does not enforce it.
    score: float
    level: PositionComplexity
    # Named contributions to the score; a fresh dict per instance.
    factors: Dict[str, float] = Field(default_factory=dict)
    recommend_coaching: bool = False
class CoachingRequest(BaseModel):
    """Request model for coaching on a position; all fields are required."""

    game_id: str
    agent_id: str

    # Position the request refers to.
    fen: str
    legal_moves: List[str]

    wallet_balance: float
    complexity: ComplexityAnalysis
class CoachingResponse(BaseModel):
    """Response model carrying a recommended move, its analysis and cost."""

    game_id: str
    agent_id: str
    recommended_move: str
    analysis: str
    cost: float
    # NOTE(review): Pydantic v2 reserves the ``model_`` prefix as a protected
    # namespace and may warn about this field name; confirm against the
    # Pydantic version in use.
    model_used: str
    tokens_used: int = 0
# ── Training Models ────────────────────────────────────────────────────────────
class Episode(BaseModel):
    """One finished game, recorded as an episode for RL training."""

    episode_id: str
    game_id: str
    agent_id: str

    # Per-move records.
    prompts: List[str]  # LLM prompt at each move
    responses: List[str]  # LLM response at each move
    moves: List[str]  # UCI moves that were played

    outcome: GameOutcome

    # Reward signals.
    game_reward: float  # +1 for a win, 0 for a draw, -1 for a loss
    economic_reward: float  # normalized net profit
    combined_reward: float  # weighted combination

    # Coaching and profit bookkeeping; all default to zero.
    coaching_calls: int = 0
    coaching_cost: float = 0.0
    net_profit: float = 0.0

    # Defaults to time.time() evaluated when the instance is created.
    created_at: float = Field(default_factory=time.time)
class TrainingStep(BaseModel):
    """Metrics recorded for a single training step."""

    step: int
    method: RLMethod

    # Optimisation metrics.
    loss: float
    policy_reward: float
    kl_divergence: float

    # Game and economy metrics.
    win_rate: float
    avg_profit: float
    coaching_rate: float

    episodes_used: int
    # Defaults to time.time() evaluated when the instance is created.
    timestamp: float = Field(default_factory=time.time)
class TrainingConfig(BaseModel):
    """Training settings; every field has a default."""

    # NOTE(review): Pydantic v2 reserves the ``model_`` prefix as a protected
    # namespace and may warn about this field name; confirm against the
    # Pydantic version in use.
    model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"
    method: RLMethod = RLMethod.GRPO

    # Optimisation and generation settings.
    learning_rate: float = 1e-5
    batch_size: int = 4
    num_generations: int = 4
    max_new_tokens: int = 128
    temperature: float = 0.9
    kl_coef: float = 0.1

    # Schedule settings.
    train_every: int = 5
    total_games: int = 100
    save_every: int = 10

    # Device and output locations.
    device: str = "cpu"
    checkpoint_dir: str = "training/checkpoints"
    data_dir: str = "training/data"
# ── WebSocket Event Models ─────────────────────────────────────────────────────
class WSEvent(BaseModel):
    """Envelope wrapping any WebSocket event.

    Only ``type`` is required.
    """

    type: EventType
    # Defaults to time.time() evaluated when the instance is created.
    timestamp: float = Field(default_factory=time.time)
    # Free-form payload; a fresh dict per instance.
    data: Dict[str, Any] = Field(default_factory=dict)
class GameStartEvent(BaseModel):
    """Event model for a game start; all fields are required."""

    game_id: str

    # Participants.
    white_agent: str
    black_agent: str

    # Wallet figures and the entry fee.
    white_wallet: float
    black_wallet: float
    entry_fee: float
class MoveEvent(BaseModel):
    """Event model for a single move.

    All fields are required except ``used_coaching``.
    """

    game_id: str
    player: str
    move_uci: str
    fen: str
    move_number: int

    # Wallet figures for both sides.
    wallet_white: float
    wallet_black: float

    used_coaching: bool = False
class GameEndEvent(BaseModel):
    """Event model for the end of a game.

    ``winner`` may be ``None``; the remaining fields are required.
    """

    game_id: str
    outcome: GameOutcome
    # The default is spelled out because a bare ``Optional[str]`` is
    # implicitly optional under Pydantic v1 but required under Pydantic v2;
    # the explicit ``None`` gives the same behaviour on both.
    winner: Optional[str] = None
    white_wallet_final: float
    black_wallet_final: float
    prize_paid: float
    total_moves: int
class TrainingStepEvent(BaseModel):
    """Event model carrying training-step metrics; all fields are required."""

    step: int

    # Optimisation metrics.
    loss: float
    reward: float
    kl_div: float

    # Game and economy metrics.
    win_rate: float
    avg_profit: float
    coaching_rate: float
class EconomyUpdateEvent(BaseModel):
    """Event model carrying economy figures; all fields are required."""

    game_number: int

    # Wallet figures for both sides.
    white_wallet: float
    black_wallet: float

    # Income and cost components.
    prize_income: float
    coaching_cost: float
    entry_fee: float

    # Profit-and-loss figures.
    net_pnl: float
    cumulative_pnl: float
|