# chessecon/shared/models.py
# Page-header residue kept for provenance: suvasis — "code add" — e4d7d50
"""
ChessEcon — Shared Data Models
Pydantic models used by both the backend API and the training pipeline.
"""
from __future__ import annotations
from enum import Enum
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
import time
# ── Enums ──────────────────────────────────────────────────────────────────────
class GameStatus(str, Enum):
    """Status values a game can carry.

    Members mix in ``str``, so each one compares equal to its raw value.
    ``GameState.status`` defaults to ``ACTIVE``.
    """

    WAITING = "waiting"
    ACTIVE = "active"
    FINISHED = "finished"
class GameOutcome(str, Enum):
    """Possible results of a game, including the not-yet-decided state.

    Members mix in ``str``, so each one compares equal to its raw value.
    """

    # Decisive results
    WHITE_WIN = "white_win"
    BLACK_WIN = "black_win"

    # Non-decisive results
    DRAW = "draw"
    ONGOING = "ongoing"
class EventType(str, Enum):
    """Type tags for events; used as the ``type`` field of ``WSEvent``.

    Members mix in ``str``, so each one compares equal to its raw value.
    """

    # Game flow
    GAME_START = "game_start"
    MOVE = "move"

    # Coaching
    COACHING_REQUEST = "coaching_request"
    COACHING_RESULT = "coaching_result"

    GAME_END = "game_end"

    # Training / economy reporting
    TRAINING_STEP = "training_step"
    ECONOMY_UPDATE = "economy_update"

    ERROR = "error"
class PositionComplexity(str, Enum):
    """Complexity levels for a position; used as ``ComplexityAnalysis.level``.

    Members mix in ``str``, so each one compares equal to its raw value.
    """

    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"
    CRITICAL = "critical"
class RLMethod(str, Enum):
    """Names of the selectable RL methods.

    Used by ``TrainingStep.method`` and ``TrainingConfig.method`` (the latter
    defaults to ``GRPO``). Members mix in ``str``, so each one compares equal
    to its raw value.
    """

    GRPO = "grpo"
    PPO = "ppo"
    RLOO = "rloo"
    REINFORCE = "reinforce"
    DPO = "dpo"
# ── Chess Models ───────────────────────────────────────────────────────────────
class MoveRequest(BaseModel):
    """Input model naming a game, the moving player, and the move.

    All three fields are required strings; ``player`` is not restricted to
    specific values by this model.
    """

    game_id: str

    # Expected values: "white" | "black"
    player: str

    move_uci: str
class MoveResponse(BaseModel):
    """Response model for a move: position, legal moves, outcome, and flags.

    Every field is required.
    """

    game_id: str
    move_uci: str

    # Position data
    fen: str
    legal_moves: List[str]

    outcome: GameOutcome
    move_number: int

    # Boolean position flags
    is_check: bool
    is_checkmate: bool
    is_stalemate: bool
class GameState(BaseModel):
    """Snapshot of a game: position, legal moves, outcome, and bookkeeping.

    The first five fields are required; the remaining fields have defaults.
    """

    # Required fields
    game_id: str
    fen: str
    legal_moves: List[str]
    outcome: GameOutcome
    move_number: int

    # A fresh empty list is created for each instance.
    move_history: List[str] = Field(default_factory=list)

    status: GameStatus = GameStatus.ACTIVE
    white_player: str = "white"
    black_player: str = "black"

    # Filled with ``time.time()`` when the instance is created, unless given.
    created_at: float = Field(default_factory=time.time)
class NewGameResponse(BaseModel):
    """Response model for a newly created game.

    Carries the game id, the position string, the legal moves, and the game
    status. Every field is required.
    """

    game_id: str
    fen: str
    legal_moves: List[str]
    status: GameStatus
# ── Economy Models ─────────────────────────────────────────────────────────────
class Transaction(BaseModel):
    """A single signed amount recorded against an agent."""

    tx_id: str
    agent_id: str

    # Sign convention: positive = credit, negative = debit
    amount: float

    description: str

    # Filled with ``time.time()`` when the instance is created, unless given.
    timestamp: float = Field(default_factory=time.time)
class WalletState(BaseModel):
    """Balance and running counters for one agent.

    Only ``agent_id`` and ``balance`` are required; every counter starts at
    zero unless a value is supplied.
    """

    agent_id: str
    balance: float

    # Running money totals
    total_earned: float = 0.0
    total_spent: float = 0.0

    # Running event counts
    coaching_calls: int = 0
    games_played: int = 0
    games_won: int = 0
class TournamentResult(BaseModel):
    """Result record for one game: the winner and the amounts involved.

    Attributes:
        game_id: Identifier of the game this result belongs to.
        winner: ``agent_id`` of the winner, or ``None`` for a draw. Optional;
            defaults to ``None``.
        outcome: Result of the game as a ``GameOutcome`` member.
        prize_paid: Prize amount (float) recorded for the game.
        entry_fees_collected: Entry-fee amount (float) recorded for the game.
        organizer_cut: Organizer amount (float) recorded for the game.
    """

    game_id: str

    # Explicit ``None`` default: Pydantic v2 treats a bare ``Optional[str]``
    # annotation as a *required* field, so without it a draw result could only
    # be built by passing ``winner=None`` by hand. Pydantic v1 already
    # defaulted such fields to ``None``, so this keeps both major versions in
    # agreement and stays backward-compatible for callers that pass a value.
    winner: Optional[str] = None

    outcome: GameOutcome
    prize_paid: float
    entry_fees_collected: float
    organizer_cut: float
# ── Agent Models ───────────────────────────────────────────────────────────────
class ComplexityAnalysis(BaseModel):
    """Complexity assessment of one position.

    ``fen``, ``score`` and ``level`` are required; ``factors`` and
    ``recommend_coaching`` have defaults.
    """

    fen: str

    # Documented range: 0.0 – 1.0 (no validator on this model enforces it)
    score: float

    level: PositionComplexity

    # Named float contributions; a fresh empty dict per instance.
    factors: Dict[str, float] = Field(default_factory=dict)

    recommend_coaching: bool = False
class CoachingRequest(BaseModel):
    """Input model for a coaching request.

    Bundles the game and agent ids, the position and its legal moves, the
    wallet balance, and a nested ``ComplexityAnalysis``. Every field is
    required.
    """

    # Who is asking, and in which game
    game_id: str
    agent_id: str

    # Position data
    fen: str
    legal_moves: List[str]

    wallet_balance: float
    complexity: ComplexityAnalysis
class CoachingResponse(BaseModel):
    """Response model for a coaching request.

    Holds the recommended move, the analysis text, the cost, and the model
    that produced them. Only ``tokens_used`` has a default.
    """

    game_id: str
    agent_id: str

    # Coaching content
    recommended_move: str
    analysis: str

    # Accounting
    cost: float
    model_used: str
    tokens_used: int = 0
# ── Training Models ────────────────────────────────────────────────────────────
class Episode(BaseModel):
    """One finished game recorded as an episode for RL training."""

    # Identifiers
    episode_id: str
    game_id: str
    agent_id: str

    # LLM prompts at each move
    prompts: List[str]
    # LLM responses at each move
    responses: List[str]
    # UCI moves played
    moves: List[str]

    outcome: GameOutcome

    # +1 win, 0 draw, -1 loss
    game_reward: float
    # Normalized net profit
    economic_reward: float
    # Weighted combination
    combined_reward: float

    # Economic bookkeeping; all default to zero.
    coaching_calls: int = 0
    coaching_cost: float = 0.0
    net_profit: float = 0.0

    # Filled with ``time.time()`` when the instance is created, unless given.
    created_at: float = Field(default_factory=time.time)
class TrainingStep(BaseModel):
    """Metrics recorded for one training step.

    Every field except ``timestamp`` is required.
    """

    step: int
    method: RLMethod

    # Optimisation metrics
    loss: float
    policy_reward: float
    kl_divergence: float

    # Rate / average metrics
    win_rate: float
    avg_profit: float
    coaching_rate: float

    episodes_used: int

    # Filled with ``time.time()`` when the instance is created, unless given.
    timestamp: float = Field(default_factory=time.time)
class TrainingConfig(BaseModel):
    """Training settings.

    Every field has a default, so ``TrainingConfig()`` with no arguments is a
    valid configuration.
    """

    # NOTE(review): Pydantic v2 releases before 2.10 emit a UserWarning because
    # ``model_name`` falls inside the protected ``model_`` namespace — confirm
    # which Pydantic version is pinned before deciding whether to act on it.
    model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"
    method: RLMethod = RLMethod.GRPO

    # Numeric hyperparameters
    learning_rate: float = 1e-5
    batch_size: int = 4
    num_generations: int = 4
    max_new_tokens: int = 128
    temperature: float = 0.9
    kl_coef: float = 0.1

    # Schedule counters
    train_every: int = 5
    total_games: int = 100
    save_every: int = 10

    device: str = "cpu"

    # Relative directory paths
    checkpoint_dir: str = "training/checkpoints"
    data_dir: str = "training/data"
# ── WebSocket Event Models ─────────────────────────────────────────────────────
class WSEvent(BaseModel):
    """Envelope for a WebSocket event: type tag, timestamp, and payload."""

    type: EventType

    # Filled with ``time.time()`` when the instance is created, unless given.
    timestamp: float = Field(default_factory=time.time)

    # Free-form payload; a fresh empty dict per instance.
    data: Dict[str, Any] = Field(default_factory=dict)
class GameStartEvent(BaseModel):
    """Fields describing the start of a game.

    Every field is required.
    NOTE(review): presumably sent as the ``data`` of a ``WSEvent`` — confirm
    against the code that emits it.
    """

    game_id: str

    # Agents on each side
    white_agent: str
    black_agent: str

    # Wallet values for each side
    white_wallet: float
    black_wallet: float

    entry_fee: float
class MoveEvent(BaseModel):
    """Fields describing one move and the wallet values that go with it.

    Only ``used_coaching`` has a default.
    NOTE(review): presumably sent as the ``data`` of a ``WSEvent`` — confirm
    against the code that emits it.
    """

    game_id: str
    player: str
    move_uci: str
    fen: str
    move_number: int

    # Wallet values for each side
    wallet_white: float
    wallet_black: float

    used_coaching: bool = False
class GameEndEvent(BaseModel):
    """Fields describing the end of a game.

    Attributes:
        game_id: Identifier of the finished game.
        outcome: Result of the game as a ``GameOutcome`` member.
        winner: Winner as a string, or ``None`` when there is none. Optional;
            defaults to ``None``.
        white_wallet_final: Final wallet value (float) for white.
        black_wallet_final: Final wallet value (float) for black.
        prize_paid: Prize amount (float) recorded for the game.
        total_moves: Number of moves as an integer.

    NOTE(review): presumably sent as the ``data`` of a ``WSEvent`` — confirm
    against the code that emits it.
    """

    game_id: str
    outcome: GameOutcome

    # Explicit ``None`` default: Pydantic v2 treats a bare ``Optional[str]``
    # annotation as a *required* field, so without it an event with no winner
    # could only be built by passing ``winner=None`` by hand. Pydantic v1
    # already defaulted such fields to ``None``, so this keeps both major
    # versions in agreement and stays backward-compatible.
    winner: Optional[str] = None

    white_wallet_final: float
    black_wallet_final: float
    prize_paid: float
    total_moves: int
class TrainingStepEvent(BaseModel):
    """Fields describing one training step for event reporting.

    Every field is required.
    NOTE(review): looks like a trimmed counterpart of ``TrainingStep`` with
    shorter names (``reward``, ``kl_div``) — confirm the mapping with the
    code that builds it.
    """

    step: int

    # Optimisation metrics
    loss: float
    reward: float
    kl_div: float

    # Rate / average metrics
    win_rate: float
    avg_profit: float
    coaching_rate: float
class EconomyUpdateEvent(BaseModel):
    """Fields describing the economy figures reported for one game number.

    Every field is required.
    NOTE(review): presumably sent as the ``data`` of a ``WSEvent`` — confirm
    against the code that emits it.
    """

    game_number: int

    # Wallet values for each side
    white_wallet: float
    black_wallet: float

    # Per-game amounts
    prize_income: float
    coaching_cost: float
    entry_fee: float

    # Profit-and-loss figures
    net_pnl: float
    cumulative_pnl: float