| """ |
| ChessEcon Training — Self-Play Loop |
| Runs games between the trainable agent and a heuristic opponent, |
| collects episodes, and triggers RL training every N games. |
| Emits training metrics to the backend WebSocket for live dashboard display. |
| """ |
| from __future__ import annotations |
| import asyncio |
| import json |
| import logging |
| import random |
| import uuid |
| import time |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import chess |
|
|
| from training.config import TrainingConfig |
| from training.reward import combined_reward, build_prompt, game_reward, economic_reward |
| from shared.models import Episode, GameOutcome, TrainingStep |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| class SelfPlayLoop: |
| """ |
| Orchestrates self-play games between the trainable LLM agent |
| and a heuristic opponent. Collects episodes for RL training. |
| """ |
|
|
| def __init__(self, model, tokenizer, trainer, config: TrainingConfig): |
| self.model = model |
| self.tokenizer = tokenizer |
| self.trainer = trainer |
| self.config = config |
| self.episode_buffer: List[Episode] = [] |
| self.game_count = 0 |
| self.training_steps = 0 |
| self.win_count = 0 |
| self.total_profit = 0.0 |
| self.coaching_calls = 0 |
|
|
| def run(self, total_games: Optional[int] = None) -> None: |
| """Run the full self-play training loop.""" |
| total = total_games or self.config.total_games |
| logger.info(f"Starting self-play: {total} games, training every {self.config.train_every}") |
| logger.info(f"Model: {self.config.player_model} | Method: {self.config.rl_method}") |
|
|
| for game_num in range(1, total + 1): |
| episode = self._run_game(game_num) |
| self.episode_buffer.append(episode) |
| self.game_count += 1 |
|
|
| |
| if episode.outcome == GameOutcome.WHITE_WIN: |
| self.win_count += 1 |
| self.total_profit += episode.net_profit |
| self.coaching_calls += episode.coaching_calls |
|
|
| logger.info( |
| f"Game {game_num}/{total} | " |
| f"Outcome: {episode.outcome.value} | " |
| f"Reward: {episode.combined_reward:.3f} | " |
| f"Net P&L: {episode.net_profit:.1f}" |
| ) |
|
|
| |
| self._save_episode(episode) |
|
|
| |
| if game_num % self.config.train_every == 0: |
| self._train() |
|
|
| logger.info( |
| f"Self-play complete. " |
| f"Win rate: {self.win_count/max(self.game_count,1):.1%} | " |
| f"Avg profit: {self.total_profit/max(self.game_count,1):.2f} | " |
| f"Total training steps: {self.training_steps}" |
| ) |
|
|
| def _run_game(self, game_num: int) -> Episode: |
| """Run a single game and return the collected episode.""" |
| board = chess.Board() |
| agent_is_white = random.random() > 0.5 |
| agent_color = chess.WHITE if agent_is_white else chess.BLACK |
|
|
| |
| wallet = self.config.initial_wallet - self.config.entry_fee |
| coaching_calls = 0 |
| coaching_cost = 0.0 |
|
|
| prompts: List[str] = [] |
| responses: List[str] = [] |
| moves: List[str] = [] |
|
|
| max_moves = 150 |
| move_count = 0 |
|
|
| while not board.is_game_over() and move_count < max_moves: |
| legal = [m.uci() for m in board.legal_moves] |
| is_agent_turn = (board.turn == agent_color) |
|
|
| if is_agent_turn: |
| |
| can_coach = wallet >= self.config.coaching_fee + 5.0 |
| prompt = build_prompt( |
| fen=board.fen(), |
| legal_moves=legal, |
| wallet=wallet, |
| coaching_fee=self.config.coaching_fee, |
| move_number=board.fullmove_number, |
| can_afford_coaching=can_coach, |
| ) |
| |
| response = self._generate_response(prompt, legal) |
| move_uci = self._extract_move(response, legal) |
| buy_coaching = "BUY_COACHING: yes" in response.upper() and can_coach |
|
|
| if buy_coaching: |
| wallet -= self.config.coaching_fee |
| coaching_calls += 1 |
| coaching_cost += self.config.coaching_fee |
| |
| move_uci = self._heuristic_move(board) |
|
|
| prompts.append(prompt) |
| responses.append(response) |
| moves.append(move_uci) |
| else: |
| |
| move_uci = self._heuristic_move(board) |
|
|
| try: |
| board.push(chess.Move.from_uci(move_uci)) |
| except Exception: |
| move_uci = random.choice(legal) |
| board.push(chess.Move.from_uci(move_uci)) |
|
|
| move_count += 1 |
|
|
| |
| outcome = self._board_outcome(board) |
|
|
| |
| prize = 0.0 |
| if (outcome == GameOutcome.WHITE_WIN and agent_is_white) or \ |
| (outcome == GameOutcome.BLACK_WIN and not agent_is_white): |
| prize = self.config.entry_fee * 2 * self.config.prize_multiplier |
| elif outcome == GameOutcome.DRAW: |
| prize = self.config.entry_fee * self.config.prize_multiplier |
|
|
| net_profit = prize - self.config.entry_fee - coaching_cost |
| gr = game_reward(outcome, agent_is_white) |
| er = economic_reward(net_profit) |
| cr = combined_reward(outcome, agent_is_white, net_profit) |
|
|
| return Episode( |
| episode_id=str(uuid.uuid4()), |
| game_id=str(uuid.uuid4()), |
| agent_id="trainable_agent", |
| prompts=prompts, |
| responses=responses, |
| moves=moves, |
| outcome=outcome, |
| game_reward=gr, |
| economic_reward=er, |
| combined_reward=cr, |
| coaching_calls=coaching_calls, |
| coaching_cost=coaching_cost, |
| net_profit=net_profit, |
| ) |
|
|
| def _train(self) -> None: |
| """Run one training step on buffered episodes.""" |
| if not self.episode_buffer: |
| return |
|
|
| logger.info(f"Training on {len(self.episode_buffer)} episodes...") |
| metrics = self.trainer.train_step(self.episode_buffer) |
| self.training_steps += 1 |
|
|
| |
| metrics["win_rate"] = self.win_count / max(self.game_count, 1) |
| metrics["avg_profit"] = self.total_profit / max(self.game_count, 1) |
| metrics["coaching_rate"] = self.coaching_calls / max(self.game_count, 1) |
|
|
| |
| if self.training_steps % self.config.save_every == 0: |
| self.trainer.save_checkpoint(metrics) |
|
|
| |
| self._emit_training_metrics(metrics) |
|
|
| |
| self.episode_buffer.clear() |
|
|
| def _generate_response(self, prompt: str, legal_moves: list) -> str: |
| """Generate a response from the LLM model.""" |
| try: |
| from training.model_loader import generate_move |
| return generate_move( |
| self.model, self.tokenizer, prompt, |
| max_new_tokens=self.config.max_new_tokens, |
| temperature=self.config.temperature, |
| device=self.config.device, |
| ) |
| except Exception as e: |
| logger.debug(f"Model generation failed: {e} — using heuristic") |
| return f"MOVE: {random.choice(legal_moves)}\nBUY_COACHING: no\nREASONING: heuristic" |
|
|
| def _extract_move(self, response: str, legal_moves: list) -> str: |
| """Extract UCI move from model response.""" |
| import re |
| match = re.search(r"MOVE:\s*([a-h][1-8][a-h][1-8][qrbn]?)", response, re.IGNORECASE) |
| if match: |
| move = match.group(1).lower() |
| if move in legal_moves: |
| return move |
| |
| for token in re.findall(r"\b([a-h][1-8][a-h][1-8][qrbn]?)\b", response): |
| if token.lower() in legal_moves: |
| return token.lower() |
| return random.choice(legal_moves) |
|
|
| def _heuristic_move(self, board: chess.Board) -> str: |
| """Simple heuristic: prefer captures, then center moves.""" |
| legal = list(board.legal_moves) |
| captures = [m for m in legal if board.is_capture(m)] |
| if captures: |
| return random.choice(captures).uci() |
| center = [chess.Move.from_uci(m) for m in ["e2e4","d2d4","e7e5","d7d5","g1f3","b1c3"] |
| if chess.Move.from_uci(m) in legal] |
| if center: |
| return random.choice(center).uci() |
| return random.choice(legal).uci() |
|
|
| def _board_outcome(self, board: chess.Board) -> GameOutcome: |
| if not board.is_game_over(): |
| return GameOutcome.DRAW |
| result = board.result() |
| if result == "1-0": |
| return GameOutcome.WHITE_WIN |
| elif result == "0-1": |
| return GameOutcome.BLACK_WIN |
| return GameOutcome.DRAW |
|
|
| def _save_episode(self, episode: Episode) -> None: |
| """Append episode to JSONL file for later analysis.""" |
| data_dir = Path(self.config.data_dir) |
| data_dir.mkdir(parents=True, exist_ok=True) |
| file_path = data_dir / f"episodes_{time.strftime('%Y%m%d')}.jsonl" |
| with open(file_path, "a") as f: |
| f.write(episode.model_dump_json() + "\n") |
|
|
| def _emit_training_metrics(self, metrics: dict) -> None: |
| """Send training metrics to backend WebSocket (best-effort).""" |
| try: |
| import requests |
| backend_url = self.config.backend_url |
| requests.post( |
| f"{backend_url}/api/training/ingest", |
| json={"step": self.training_steps, **metrics}, |
| timeout=2, |
| ) |
| except Exception: |
| pass |
|
|