import os

# NOTE: torchrl's ChessEnv looks for san_moves.txt inside the installed
# package, so the bundled move list is moved there first (unix-only hack;
# the target path is specific to this machine's Python install).
os.system("mv NeoChess/san_moves.txt /usr/local/python/3.12.1/lib/python3.12/site-packages/torchrl/envs/custom/")

import chess
import chess.engine
import gymnasium
import numpy as np
import tensordict
import torch
import torch.nn.functional as F
import torchrl
from collections import defaultdict
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torch.distributions import Categorical
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
)
from torchrl.envs.custom.chess import ChessEnv
from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import (
    ActorCriticWrapper,
    MaskedCategorical,
    ProbabilisticActor,
    TanhNormal,
    ValueOperator,
)
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def board_to_tensor(board):
    """Encode a chess.Board as a (1, 64) LongTensor of piece codes (0 = empty)."""
    piece_encoding = {
        'P': 1, 'N': 2, 'B': 3, 'R': 4, 'Q': 5, 'K': 6,
        'p': 7, 'n': 8, 'b': 9, 'r': 10, 'q': 11, 'k': 12,
    }

    tensor = torch.zeros(64, dtype=torch.long)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            tensor[square] = piece_encoding[piece.symbol()]

    return tensor.unsqueeze(0)

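# Sanity check (illustrative, not required for training): in the starting
# position, squares 8..15 (the second rank) hold white pawns, encoded as 1.
#   t = board_to_tensor(chess.Board())   # shape (1, 64)
#   assert (t[0, 8:16] == 1).all()
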
class Policy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 32)   # 12 piece codes + empty
        self.attention = nn.MultiheadAttention(embed_dim=32, num_heads=16)
        self.neu = 256
        self.neurons = nn.Sequential(
            nn.Linear(64 * 32, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, 128),
            nn.ReLU(),
            nn.Linear(128, 29275),   # one logit per SAN move in the vocabulary
        )

    def forward(self, x):
        board = chess.Board(x)
        # +1 if white to move, -1 if black. (Multiplying by the raw bool, as
        # the original did, zeroed every logit whenever black was to move.)
        sign = 1.0 if board.turn == chess.WHITE else -1.0
        x = board_to_tensor(board).to(self.embedding.weight.device)
        x = self.embedding(x)
        x = x.permute(1, 0, 2)   # (seq, batch, embed) for nn.MultiheadAttention
        attn_output, _ = self.attention(x, x, x)
        x = attn_output.permute(1, 0, 2).contiguous()
        x = x.view(x.size(0), -1)
        x = self.neurons(x) * sign
        return x

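# Shape sketch (assumes an untrained Policy on CPU; illustrative only):
#   logits = Policy()(chess.STARTING_FEN)   # torch.Size([1, 29275])
# Each column is the logit for one SAN move index in san_moves.txt.
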
class Value(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 64)
        self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
        self.neu = 512
        # Twelve hidden Linear+ReLU pairs, built in a loop; the resulting
        # nn.Sequential indices (and saved state_dict keys) are unchanged.
        hidden = []
        for _ in range(12):
            hidden += [nn.Linear(self.neu, self.neu), nn.ReLU()]
        self.neurons = nn.Sequential(
            nn.Linear(64 * 64, self.neu),
            nn.ReLU(),
            *hidden,
            nn.Linear(self.neu, 64),
            nn.ReLU(),
            nn.Linear(64, 4),
        )

    def forward(self, x):
        board = chess.Board(x)
        color = board.turn
        x = board_to_tensor(board).to(self.embedding.weight.device)
        x = self.embedding(x)
        x = x.permute(1, 0, 2)   # (seq, batch, embed) for nn.MultiheadAttention
        attn_output, _ = self.attention(x, x, x)
        x = attn_output.permute(1, 0, 2).contiguous()
        x = x.view(x.size(0), -1)
        x = self.neurons(x)
        # Only the first of the four head outputs is used, scaled down by 10;
        # the sign flip when white is to move is the original convention,
        # kept as-is.
        x = x[0][0] / 10
        if color == chess.WHITE:
            x = -x
        return x

with set_gym_backend("gymnasium"):
    env = ChessEnv(
        stateful=True,
        include_fen=True,
        include_san=False,
    )

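# Optional sanity check (uses the already-imported helper): verifies that
# reset/step outputs match the env's declared specs. Commented out because
# it performs a few rollout steps.
# check_env_specs(env)
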
policy = Policy().to(device)
value = Value().to(device)
# weights_only=False loads arbitrary pickles; only use it on checkpoints
# you trust.
valweight = torch.load("NeoChess/chessy_modelt-1.pth", map_location=device, weights_only=False)
value.load_state_dict(valweight)
polweight = torch.load("NeoChess/chessy_policy.pth", map_location=device, weights_only=False)
policy.load_state_dict(polweight)

def sample_masked_action(logits, mask):
    """Sample an action from `logits`, restricted to entries where `mask` is True."""
    masked_logits = logits.clone()
    masked_logits[~mask] = float("-inf")
    probs = F.softmax(masked_logits, dim=-1)
    dist = Categorical(probs=probs)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    return action, log_prob

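# Example usage (hypothetical sizes): restrict sampling to two legal indices.
#   logits = torch.randn(29275)
#   legal = torch.zeros(29275, dtype=torch.bool)
#   legal[[0, 7]] = True
#   action, log_prob = sample_masked_action(logits, legal)
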
class FENPolicyWrapper(nn.Module):
    """Adapts the FEN-string policy to batched TensorDict inputs."""

    def __init__(self, policy_net):
        super().__init__()
        self.policy_net = policy_net

    def forward(self, fens, action_mask=None) -> torch.Tensor:
        if isinstance(fens, (TensorDict, dict)):
            fens = fens["fen"]

        if isinstance(fens, str):
            fens = [fens]

        # Collectors may nest the FEN strings; unwrap to a flat list.
        while isinstance(fens[0], list):
            fens = fens[0]

        if action_mask is not None:
            if isinstance(action_mask, torch.Tensor):
                action_mask = action_mask.unsqueeze(0) if action_mask.ndim == 1 else action_mask
            if not isinstance(action_mask, list):
                action_mask = [action_mask[i] for i in range(len(fens))]

        logits_list = []
        for i, fen in enumerate(fens):
            logits = self.policy_net(fen)
            if action_mask is not None:
                mask = action_mask[i].bool()
                logits = logits.masked_fill(~mask, float("-inf"))
            logits_list.append(logits)

        # Drop the per-position batch dim of 1 introduced by the policy net.
        return torch.stack(logits_list).squeeze(-2).squeeze(-2)

class FENValueWrapper(nn.Module):
    """Adapts the FEN-string value net to batched TensorDict inputs."""

    def __init__(self, value_net):
        super().__init__()
        self.value_net = value_net

    def forward(self, fens) -> torch.Tensor:
        if isinstance(fens, (TensorDict, dict)):
            fens = fens["fen"]
        if isinstance(fens, str):
            fens = [fens]
        while isinstance(fens[0], list):
            fens = fens[0]
        state_value = torch.stack([self.value_net(fen) for fen in fens])
        if state_value.ndim == 0:
            state_value = state_value.unsqueeze(0)
        return state_value

# 64 squares x 73 move planes (the AlphaZero-style action space); kept for
# reference, though the networks above output one logit per SAN move in
# torchrl's san_moves.txt vocabulary (29275 entries) instead.
ACTION_DIM = 64 * 73

policy_module = TensorDictModule(
    FENPolicyWrapper(policy),
    in_keys=["fen"],
    out_keys=["logits"],
)
value_module = TensorDictModule(
    FENValueWrapper(value),
    in_keys=["fen"],
    out_keys=["state_value"],
)

def masked_categorical_factory(logits, action_mask):
    # ProbabilisticActor passes distribution kwargs by in_key name; this
    # adapter renames "action_mask" to MaskedCategorical's "mask" argument.
    return MaskedCategorical(logits=logits, mask=action_mask)

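# Quick demonstration (illustrative): with a single legal index, the masked
# distribution can only ever sample that index.
#   d = masked_categorical_factory(
#       logits=torch.zeros(5),
#       action_mask=torch.tensor([False, False, True, False, False]),
#   )
#   assert d.sample().item() == 2
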
actor = ProbabilisticActor(
    module=policy_module,
    in_keys=["logits", "action_mask"],
    out_keys=["action"],
    distribution_class=masked_categorical_factory,
    return_log_prob=True,
)

# Quick smoke test: run one reset through each module before training.
obs = env.reset()
print(obs)
print(policy_module(obs))
print(value_module(obs))
print(actor(obs))

rollout = env.rollout(3)

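# The rollout is a stacked TensorDict; inspecting it shows what the collector
# will feed the loss (keys below assume include_fen=True and TorchRL's usual
# ("next", "reward") layout).
#   print(rollout["fen"])             # FEN string per step
#   print(rollout["next", "reward"])  # reward tensor per step
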
from torchrl.record.loggers import generate_exp_name, get_logger


def train_ppo_chess(chess_env, num_iterations=1, frames_per_batch=1000,
                    num_epochs=100, lr=3e-4, gamma=0.99, lmbda=0.95,
                    clip_epsilon=0.2, device="cpu"):
    """
    Main PPO training loop for chess.

    Args:
        chess_env: Your ChessEnv instance
        num_iterations: Number of training iterations
        frames_per_batch: Number of environment steps per batch
        num_epochs: Number of PPO epochs per iteration
        lr: Learning rate
        gamma: Discount factor
        lmbda: GAE lambda parameter
        clip_epsilon: PPO clipping parameter
        device: Training device
    """
    global actor_module, value_module, loss_module

    env = chess_env
    actor_module = actor

    collector = SyncDataCollector(
        env,
        actor_module,
        frames_per_batch=frames_per_batch,
        total_frames=-1,
        device=device,
    )

    replay_buffer = ReplayBuffer(
        storage=LazyTensorStorage(frames_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=256,
    )

    loss_module = ClipPPOLoss(
        actor_network=actor_module,
        critic_network=value_module,
        clip_epsilon=clip_epsilon,
        entropy_bonus=True,
        entropy_coef=0.01,
        critic_coef=1.0,
        normalize_advantage=True,
    )
    # Wire gamma/lmbda into the GAE value estimator; previously these
    # arguments were accepted but never used.
    loss_module.make_value_estimator(gamma=gamma, lmbda=lmbda)

    optim = torch.optim.Adam(loss_module.parameters(), lr=lr)

    logger = get_logger("tensorboard", logger_name="ppo_chess", experiment_name=generate_exp_name("PPO", "Chess"))

    collected_frames = 0

    for iteration in range(num_iterations):
        print(f"\n=== Iteration {iteration + 1}/{num_iterations} ===")

        # Collect at least frames_per_batch transitions.
        batch_data = []
        for i, batch in enumerate(collector):
            batch_data.append(batch)
            collected_frames += batch.numel()
            if len(batch_data) * collector.frames_per_batch >= frames_per_batch:
                break

        if batch_data:
            full_batch = torch.cat(batch_data, dim=0)

            # Compute GAE advantages and value targets before storing.
            with torch.no_grad():
                full_batch = loss_module.value_estimator(full_batch)

            replay_buffer.extend(full_batch)

        total_loss = 0
        total_actor_loss = 0
        total_critic_loss = 0
        total_entropy_loss = 0

        for epoch in range(num_epochs):
            epoch_loss = 0
            epoch_actor_loss = 0
            epoch_critic_loss = 0
            epoch_entropy_loss = 0
            num_batches = 0

            for batch in replay_buffer:
                # print(batch)  # debug
                # Flatten stray singleton dims so the loss sees 1-D values.
                if "state_value" in batch.keys() and batch["state_value"].dim() > 1:
                    batch["state_value"] = batch["state_value"].squeeze(-1)
                if batch["value_target"].dim() > 1:
                    batch["value_target"] = batch["value_target"].squeeze(1)

                loss_dict = loss_module(batch)
                loss = loss_dict["loss_objective"] + loss_dict["loss_critic"] + loss_dict["loss_entropy"]

                optim.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=0.5)
                optim.step()

                epoch_loss += loss.item()
                epoch_actor_loss += loss_dict["loss_objective"].item()
                epoch_critic_loss += loss_dict["loss_critic"].item()
                epoch_entropy_loss += loss_dict["loss_entropy"].item()
                num_batches += 1

            if num_batches > 0:
                total_loss += epoch_loss / num_batches
                total_actor_loss += epoch_actor_loss / num_batches
                total_critic_loss += epoch_critic_loss / num_batches
                total_entropy_loss += epoch_entropy_loss / num_batches

        avg_total_loss = total_loss / num_epochs
        avg_actor_loss = total_actor_loss / num_epochs
        avg_critic_loss = total_critic_loss / num_epochs
        avg_entropy_loss = total_entropy_loss / num_epochs

        metrics = {
            "train/total_loss": avg_total_loss,
            "train/actor_loss": avg_actor_loss,
            "train/critic_loss": avg_critic_loss,
            "train/entropy_loss": avg_entropy_loss,
            "train/collected_frames": collected_frames,
        }

        # Rewards live under ("next", "reward") in TorchRL's collected data;
        # the root-level "reward" key checked before never exists there.
        if ("next", "reward") in batch.keys(include_nested=True):
            avg_reward = batch["next", "reward"].mean().item()
            metrics["train/avg_reward"] = avg_reward
            print(f"Average Reward: {avg_reward:.3f}")

        for key, value in metrics.items():
            logger.log_scalar(key, value, step=iteration)

        print(f"Total Loss: {avg_total_loss:.4f}")
        print(f"Actor Loss: {avg_actor_loss:.4f}")
        print(f"Critic Loss: {avg_critic_loss:.4f}")
        print(f"Entropy Loss: {avg_entropy_loss:.4f}")
        print(f"Collected Frames: {collected_frames}")

        # PPO is on-policy: discard this iteration's data before recollecting.
        replay_buffer.empty()

    collector.shutdown()
    print("\nTraining completed!")

train_ppo_chess(env, device=device)
torch.save(value.state_dict(), "NeoChess/chessy_model.pth")
torch.save(policy.state_dict(), "NeoChess/chessy_policy.pth")
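
# To resume from these checkpoints later (sketch, same paths as above):
#   value.load_state_dict(torch.load("NeoChess/chessy_model.pth", map_location=device))
#   policy.load_state_dict(torch.load("NeoChess/chessy_policy.pth", map_location=device))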