from collections import deque
from typing import Iterator, List, Tuple

import chess
import numpy as np
import torch


class Game:
    """
    Represents a single chess game trajectory with all the data needed for RL training.
    Acts as a temporary buffer inside the self-play loop.

    Handles:
    - Storing trajectory data (fens, repetition counts, actions, log_probs, values, invalid_masks)
    - Tracking game status (active/complete, valid/invalid)
    """

    def __init__(self):
        self.active = True
        self.valid = True
        self.completion_reason = None
        self.game_result = None
        self.fens = []
        self.repetition_counts = []
        self.actions = []
        self.values = []
        self.log_probs = []
        self.invalid_masks = []

    def update_trajectory(self, fen, rep, act, val, logp, inv_m):
        """Append one step: position, repetition count, action, value, log-prob, invalid-action mask."""
        self.fens.append(fen)
        self.repetition_counts.append(rep)
        self.actions.append(act)
        self.values.append(val)
        self.log_probs.append(logp)
        self.invalid_masks.append(inv_m)

    def update_game_status(self, done, reason):
        """Mark the game complete; only standard results ("1-0", "0-1", "1/2-1/2") count as valid."""
        if done:
            self.active = False
            if reason in ["1-0", "0-1", "1/2-1/2"]:
                self.completion_reason = reason
                self.game_result = reason
            else:
                # Aborted or abnormal termination: keep the reason but flag the game invalid
                self.completion_reason = reason
                self.game_result = None
                self.valid = False

    def _get_trajectory(self, color: bool) -> dict:
        """Collect the steps where `color` (chess.WHITE or chess.BLACK) was to move.

        The last stored entry is skipped (presumably the terminal position,
        after which no action was taken).
        """
        indices = [
            i for i in range(len(self.fens) - 1)
            if chess.Board(self.fens[i]).turn == color
        ]
        return {
            'fens': [self.fens[i] for i in indices],
            'repetition_counts': [self.repetition_counts[i] for i in indices],
            'actions': [self.actions[i] for i in indices],
            'values': [self.values[i] for i in indices],
            'log_probs': [self.log_probs[i] for i in indices],
            'invalid_masks': [self.invalid_masks[i] for i in indices],
        }

    def get_white_trajectory(self):
        """Extract the trajectory for white."""
        return self._get_trajectory(chess.WHITE)

    def get_black_trajectory(self):
        """Extract the trajectory for black."""
        return self._get_trajectory(chess.BLACK)
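
# A sketch of the intended fill pattern for Game during self-play; `model` and
# `act_to_move` are hypothetical stand-ins for the project's actual network and
# action decoding, shown only to illustrate the call order:
#
#     game = Game()
#     board = chess.Board()
#     while game.active:
#         act, logp, val, inv_m = model.step(board)   # hypothetical API
#         game.update_trajectory(board.fen(), rep=0, act=act,
#                                val=val, logp=logp, inv_m=inv_m)
#         board.push(act_to_move(act))
#         if board.is_game_over():
#             game.update_game_status(done=True, reason=board.result())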


class ReplayBuffer:
    """
    Experience buffer for PPO reinforcement learning.

    Handles:
    - storing samples, each consisting of:
        1. fens
        2. repetition_counts
        3. actions
        4. log_probs
        5. values
        6. invalid_masks
    - computing advantages from rewards and values (7. advantages)
    - yielding samples in minibatches

    Since PPO is largely on-policy, the buffer stays small, so a deque is an
    appropriate backing store.
    """

    def __init__(self,
                 capacity: int,
                 batch_size: int,
                 gamma: float,
                 gae_lambda: float,
                 shuffle: bool = True
                 ):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Parallel deques: index i across all of them describes one step
        self.fens = deque(maxlen=capacity)
        self.repetition_counts = deque(maxlen=capacity)
        self.actions = deque(maxlen=capacity)
        self.log_probs = deque(maxlen=capacity)
        self.values = deque(maxlen=capacity)
        self.invalid_masks = deque(maxlen=capacity)
        self.advantages = deque(maxlen=capacity)

    def push_game(self, game: Game):
        """
        Process a completed game and add its trajectories to the buffer.
        Handles reward computation for both the white and black players.
        """
        if not game.valid:
            return
        result = game.game_result
        if result not in ["1-0", "0-1", "1/2-1/2"]:
            raise ValueError(
                f"Result not recognized: {result}. Either an incomplete game was "
                f"passed in, or game.update_game_status() is wrong."
            )
        # Scalar reward from white's perspective: win +1, loss -1, draw 0
        result = {"1-0": 1, "0-1": -1, "1/2-1/2": 0}[result]
        white_traj = game.get_white_trajectory()
        if white_traj["fens"]:
            self._process_trajectory(
                white_traj["fens"],
                white_traj["repetition_counts"],
                white_traj["actions"],
                white_traj["log_probs"],
                white_traj["values"],
                white_traj["invalid_masks"],
                result
            )
        black_traj = game.get_black_trajectory()
        if black_traj["fens"]:
            self._process_trajectory(
                black_traj["fens"],
                black_traj["repetition_counts"],
                black_traj["actions"],
                black_traj["log_probs"],
                black_traj["values"],
                black_traj["invalid_masks"],
                -result  # flip the reward for black's perspective
            )

    def _process_trajectory(self, fens, reps, actions, log_probs, values, invalid_masks, final_reward):
        """Compute advantages for one player's trajectory and append every step to the buffer."""
        values_tensor = values if torch.is_tensor(values) else torch.tensor(values)
        advantages = self._compute_advantage(values_tensor, final_reward)
        for i in range(len(fens)):
            self.fens.append(fens[i])
            self.repetition_counts.append(reps[i])
            self.actions.append(actions[i])
            self.log_probs.append(log_probs[i])
            self.values.append(values[i])
            self.invalid_masks.append(invalid_masks[i])
            self.advantages.append(advantages[i])

    def _compute_advantage(self, value_traj: torch.Tensor, final_reward: float) -> torch.Tensor:
        """
        Compute GAE with only a terminal reward: r_t = 0 for t < T-1 and r_{T-1} = final_reward.

        Args:
            value_traj: the model's value predictions along the trajectory
            final_reward: the game result from this player's perspective
        Returns:
            advantages, a torch.Tensor with the same shape as value_traj
        """
        vals = value_traj.detach().cpu().float()
        T = vals.shape[0] if vals.dim() > 0 else 1
        adv = torch.zeros(T)
        next_value = 0.0  # V(s_T) = 0: the episode ends at the terminal state
        gae = 0.0
        for t in reversed(range(T)):
            reward = final_reward if t == T - 1 else 0.0
            # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            delta = reward + self.gamma * next_value - vals[t]
            # GAE recursion: A_t = delta_t + gamma * lambda * A_{t+1}
            gae = delta + self.gamma * self.gae_lambda * gae
            adv[t] = gae
            next_value = vals[t]
        return adv
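
    # Worked example of the recursion above, assuming gamma=0.99 and
    # gae_lambda=0.95, with values [0.1, 0.2, 0.3] and final_reward=1.0:
    #   t=2: delta = 1.0 + 0.99*0.0 - 0.3 = 0.7000 -> adv[2] = 0.7000
    #   t=1: delta = 0.0 + 0.99*0.3 - 0.2 = 0.0970 -> adv[1] = 0.0970 + 0.9405*0.7000 ~= 0.7554
    #   t=0: delta = 0.0 + 0.99*0.2 - 0.1 = 0.0980 -> adv[0] = 0.0980 + 0.9405*0.7554 ~= 0.8084
    # (0.9405 = gamma * gae_lambda)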

    def sample(self) -> Iterator[Tuple[List[str],     # fens
                                       torch.Tensor,  # repetition counts
                                       torch.Tensor,  # actions
                                       torch.Tensor,  # log_probs
                                       torch.Tensor,  # values
                                       torch.Tensor,  # invalid masks
                                       torch.Tensor]]:  # advantages
        """Yield minibatches of size batch_size for training; a trailing short batch is dropped."""
        n = len(self.fens)
        if n < self.batch_size:
            return
        idxs = np.arange(n)
        if self.shuffle:
            np.random.shuffle(idxs)

        def to_tensor(x):
            # Stored entries may be tensors or plain Python numbers
            return x.detach().clone() if torch.is_tensor(x) else torch.tensor(x)

        for start in range(0, n, self.batch_size):
            batch_idx = idxs[start:start + self.batch_size]
            if len(batch_idx) < self.batch_size:
                break  # drop the final incomplete batch
            yield (
                [self.fens[i] for i in batch_idx],
                torch.stack([to_tensor(self.repetition_counts[i]) for i in batch_idx]),
                torch.stack([to_tensor(self.actions[i]) for i in batch_idx]),
                torch.stack([to_tensor(self.log_probs[i]) for i in batch_idx]),
                torch.stack([to_tensor(self.values[i]) for i in batch_idx]),
                torch.stack([to_tensor(self.invalid_masks[i]) for i in batch_idx]),
                torch.stack([to_tensor(self.advantages[i]) for i in batch_idx]),
            )

    def __len__(self) -> int:
        return len(self.fens)

    def clear(self) -> None:
        self.fens.clear()
        self.repetition_counts.clear()
        self.actions.clear()
        self.log_probs.clear()
        self.values.clear()
        self.invalid_masks.clear()
        self.advantages.clear()
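

# A minimal end-to-end sketch, runnable as a smoke test. The zero-valued
# entries are placeholders standing in for real policy/value-network outputs,
# and the mask width of 8 is an arbitrary choice for illustration only.
if __name__ == "__main__":
    buffer = ReplayBuffer(capacity=1024, batch_size=2,
                          gamma=0.99, gae_lambda=0.95)
    game = Game()
    board = chess.Board()
    for _ in range(7):  # six dummy half-moves plus the terminal position
        game.update_trajectory(
            fen=board.fen(),
            rep=0,                                   # repetition-count placeholder
            act=0,                                   # action-index placeholder
            val=0.0,                                 # value-head output placeholder
            logp=0.0,                                # log-prob placeholder
            inv_m=torch.zeros(8, dtype=torch.bool),  # invalid-action mask placeholder
        )
        if board.is_game_over():
            break
        board.push(next(iter(board.legal_moves)))    # play the first legal move
    game.update_game_status(done=True, reason="1-0")  # pretend white won
    buffer.push_game(game)
    print(f"buffer size: {len(buffer)}")
    for fens, reps, acts, logps, vals, invs, advs in buffer.sample():
        print(len(fens), advs)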