import torch
from collections import deque
import numpy as np
from typing import List, Iterator, Tuple
import chess
class Game:
"""
Represents a single chess game trajectory with all relevant data for RL training.
Acts as a *temporary* buffer inside loop
Handles:
- Storing trajectory data (fens, reps, actions, log_probs, values, invalid_masks)
- Tracking game status (active/complete)
"""
def __init__(self):
self.active = True
self.valid = True
self.completion_reason = None
self.game_result = None
self.fens = []
self.repetition_counts = []
self.actions = []
self.values = []
self.log_probs = []
self.invalid_masks = []
    def update_trajectory(self, fen, rep, act, val, logp, inv_m):
        """Append one time step: position, repetition count, action, value, log-prob, invalid-action mask."""
self.fens.append(fen)
self.repetition_counts.append(rep)
self.actions.append(act)
self.values.append(val)
self.log_probs.append(logp)
self.invalid_masks.append(inv_m)
    def update_game_status(self, done, reason):
        """Mark the game as finished; only standard results ("1-0", "0-1", "1/2-1/2") count as valid."""
        if done:
            self.active = False
            self.completion_reason = reason
            if reason in ("1-0", "0-1", "1/2-1/2"):
                self.game_result = reason
            else:
                # Aborted or truncated games carry no usable result.
                self.game_result = None
                self.valid = False
    def _get_trajectory(self, want_white: bool) -> dict:
        """Collect the steps where the given side was to move; the final stored
        entry is treated as the terminal position and excluded."""
        indices = [
            i for i in range(len(self.fens) - 1)
            if chess.Board(self.fens[i]).turn == want_white  # board.turn is True when white is to move
        ]
        return {
            'fens': [self.fens[i] for i in indices],
            'repetition_counts': [self.repetition_counts[i] for i in indices],
            'actions': [self.actions[i] for i in indices],
            'values': [self.values[i] for i in indices],
            'log_probs': [self.log_probs[i] for i in indices],
            'invalid_masks': [self.invalid_masks[i] for i in indices]
        }

    def get_white_trajectory(self) -> dict:
        """Extract the trajectory for white."""
        return self._get_trajectory(True)

    def get_black_trajectory(self) -> dict:
        """Extract the trajectory for black."""
        return self._get_trajectory(False)
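# A minimal sketch of how a self-play loop is expected to fill a Game. The
# `policy` and `decode` calls below are hypothetical placeholders for the real
# model and move decoding, not part of this module:
#
#   game = Game()
#   board = chess.Board()
#   while not board.is_game_over():
#       fen = board.fen()
#       act, logp, val, inv_m = policy(fen)              # hypothetical network call
#       game.update_trajectory(fen, 0, act, val, logp, inv_m)
#       board.push(decode(act))                          # hypothetical index -> chess.Move
#   game.update_trajectory(board.fen(), 0, act, val, logp, inv_m)  # terminal entry
#   game.update_game_status(done=True, reason=board.result())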
class ReplayBuffer:
"""
The buffer class for PPO reinforcement learning.
Handles:
- store samples including:
1. fens
2. reps
3. actions
4. log_probs
5. values
6. invalid_masks
- calculate advantage based on reward and value (7. advantage)
- output samples in batches
Since PPO is largely on-policy, so the replay buffer will not be so large that deque is not appropriate
"""
def __init__(self,
capacity: int,
batch_size: int,
gamma: float,
gae_lambda: float,
shuffle: bool=True
):
self.gamma = gamma
self.gae_lambda = gae_lambda
self.fens = deque(maxlen=capacity)
self.repetition_counts = deque(maxlen=capacity)
self.actions = deque(maxlen=capacity)
self.log_probs = deque(maxlen=capacity)
self.values = deque(maxlen=capacity)
self.invalid_masks = deque(maxlen=capacity)
self.advantages = deque(maxlen=capacity)
self.batch_size = batch_size
self.shuffle = shuffle
def push_game(self, game: Game):
"""
Process a completed game and add its trajectories to the buffer.
Handles reward computation for both white and black players.
"""
if not game.valid:
return
result = game.game_result
        if result not in ("1-0", "0-1", "1/2-1/2"):
            raise ValueError(
                f"Result not recognized: {result}. Either an incomplete game was passed in, "
                "or game.update_game_status() is wrong."
            )
        # Map the PGN result string to a scalar reward from white's perspective.
        result = {"1-0": 1, "0-1": -1, "1/2-1/2": 0}[result]
white_traj = game.get_white_trajectory()
if white_traj["fens"]:
self._process_trajectory(
white_traj["fens"],
white_traj["repetition_counts"],
white_traj["actions"],
white_traj["log_probs"],
white_traj["values"],
white_traj["invalid_masks"],
result
)
black_traj = game.get_black_trajectory()
if black_traj["fens"]:
self._process_trajectory(
black_traj["fens"],
black_traj["repetition_counts"],
black_traj["actions"],
black_traj["log_probs"],
black_traj["values"],
black_traj["invalid_masks"],
-result # flip reward for black's perspective
)
    def _process_trajectory(self, fens, reps, actions, log_probs, values, invalid_masks, final_reward):
        """Compute advantages for one player's trajectory and append every step to the buffer."""
        values_tensor = values if torch.is_tensor(values) else torch.tensor(values)
        advantages = self._compute_advantage(values_tensor, final_reward)
for i in range(len(fens)):
self.fens.append(fens[i])
self.repetition_counts.append(reps[i])
self.actions.append(actions[i])
self.log_probs.append(log_probs[i])
self.values.append(values[i])
self.invalid_masks.append(invalid_masks[i])
self.advantages.append(advantages[i])
    def _compute_advantage(self, value_traj: torch.Tensor, final_reward: float) -> torch.Tensor:
        """
        Compute GAE with only a terminal reward: r_t = 0 for t < T-1 and r_{T-1} = final_reward.
        Args:
            value_traj: the model's value predictions along the trajectory
            final_reward: game result from this player's perspective
        Returns:
            advantages, a 1-D torch.Tensor with one entry per step of value_traj
        """
        vals = value_traj.detach().cpu().float().reshape(-1)  # ensure 1-D, even for a single-step trajectory
        T = vals.shape[0]
        adv = torch.zeros(T)
        next_value = 0.0  # the terminal state has no future value
        gae = 0.0
        for t in reversed(range(T)):
            reward = final_reward if t == T - 1 else 0.0
            delta = reward + self.gamma * next_value - vals[t]  # one-step TD error
            gae = delta + self.gamma * self.gae_lambda * gae
            adv[t] = gae
            next_value = vals[t]
        return adv
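    # Worked example (a sanity check with assumed hyper-parameters, not part of
    # the class): gamma=0.99, gae_lambda=0.95, values=[0.1, 0.2], final_reward=1.0.
    #   t=1: delta = 1.0 + 0.99*0.0 - 0.2 = 0.8,   adv[1] = 0.8
    #   t=0: delta = 0.0 + 0.99*0.2 - 0.1 = 0.098, adv[0] = 0.098 + 0.99*0.95*0.8 = 0.8504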
    def sample(self) -> Iterator[Tuple[List[str],     # fens
                                       torch.Tensor,  # repetition counts
                                       torch.Tensor,  # actions
                                       torch.Tensor,  # log_probs
                                       torch.Tensor,  # values
                                       torch.Tensor,  # invalid masks
                                       torch.Tensor]]:  # advantages
        """Yield minibatches of size batch_size for training; a trailing partial batch is dropped."""
        n = len(self.fens)
        if n < self.batch_size:
            return
        idxs = np.arange(n)
        if self.shuffle:
            np.random.shuffle(idxs)

        def _stack(source, batch_idx):
            # Stack the selected entries, converting plain numbers to tensors as needed.
            return torch.stack([
                source[i].detach().clone() if torch.is_tensor(source[i])
                else torch.tensor(source[i])
                for i in batch_idx
            ])

        for start in range(0, n, self.batch_size):
            batch_idx = idxs[start:start + self.batch_size]
            if len(batch_idx) < self.batch_size:
                break
            fens_b = [self.fens[i] for i in batch_idx]
            reps_b = _stack(self.repetition_counts, batch_idx)
            acts_b = _stack(self.actions, batch_idx)
            logps_b = _stack(self.log_probs, batch_idx)
            vals_b = _stack(self.values, batch_idx)
            advs_b = _stack(self.advantages, batch_idx)
            # Masks are read-only, so they are stacked without cloning.
            invs_b = torch.stack([
                self.invalid_masks[i] if torch.is_tensor(self.invalid_masks[i])
                else torch.tensor(self.invalid_masks[i])
                for i in batch_idx
            ])
            yield fens_b, reps_b, acts_b, logps_b, vals_b, invs_b, advs_b
def __len__(self) -> int:
return len(self.fens)
def clear(self) -> None:
self.fens.clear()
self.repetition_counts.clear()
self.actions.clear()
self.log_probs.clear()
self.values.clear()
self.invalid_masks.clear()
self.advantages.clear()
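
if __name__ == "__main__":
    # Minimal smoke test; a sketch only. The action-space size below (4672, the
    # AlphaZero-style chess move encoding) and the zero-valued placeholder
    # policy outputs are assumptions, not the real training loop.
    ACTION_SPACE = 4672

    game = Game()
    board = chess.Board()
    for _ in range(4):
        game.update_trajectory(
            fen=board.fen(),
            rep=torch.tensor(0),
            act=torch.tensor(0),                               # placeholder action index
            val=torch.tensor(0.0),                             # placeholder value estimate
            logp=torch.tensor(0.0),                            # placeholder log-probability
            inv_m=torch.zeros(ACTION_SPACE, dtype=torch.bool),
        )
        board.push(np.random.choice(list(board.legal_moves)))
    # Terminal entry (excluded from both per-colour trajectories).
    game.update_trajectory(board.fen(), torch.tensor(0), torch.tensor(0),
                           torch.tensor(0.0), torch.tensor(0.0),
                           torch.zeros(ACTION_SPACE, dtype=torch.bool))
    game.update_game_status(done=True, reason="1-0")           # pretend white won

    buffer = ReplayBuffer(capacity=1024, batch_size=2, gamma=0.99, gae_lambda=0.95)
    buffer.push_game(game)
    for fens, reps, acts, logps, vals, inv_ms, advs in buffer.sample():
        print(f"batch of {len(fens)} positions, advantages: {advs.tolist()}")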