Upload 6 files
Browse files- 64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth +3 -0
- autoplay_muliproc.py +151 -0
- chesstransformer.py +251 -0
- environment.yml +66 -0
- play.py +72 -0
- tokenizer.py +163 -0
64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:82fb0554f04255f854344432380ba0719af4e14c631ff8a0c9905a8e99cfbaf2
|
| 3 |
+
size 9746197380
|
autoplay_muliproc.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import chess
|
| 3 |
+
import chess.engine
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
import argparse
|
| 7 |
+
import multiprocessing as mp
|
| 8 |
+
from chesstransformer import ChessTransformer
|
| 9 |
+
import tokenizer as tk
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
# Logging configuration: INFO level, each record shows timestamp, process name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Command-line options controlling the tournament against Stockfish.
parser = argparse.ArgumentParser(description='Chess Transformer Testing')
parser.add_argument(
    '--cores', type=int, default=2,
    help='Cores to use for CPU chess engine',
)
parser.add_argument(
    '--games', type=int, default=10,
    help='Number of games to play',
)
parser.add_argument(
    '--stockfish_elo', type=int, default=1320,
    help='ELO rating for Stockfish. Min 1320',
)
parser.add_argument(
    '--stockfish_path', type=str, default='./stockfish/stockfish-ubuntu-x86-64',
    help='Path to Stockfish binary',
)

# Parsed at import time; main() reads the values from this module-level namespace.
args = parser.parse_args()
|
| 23 |
+
|
| 24 |
+
def setup_model(checkpoint_path='./64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth', device='cuda'):
    """Build a ChessTransformer, restore its weights and put it in eval mode.

    Args:
        checkpoint_path: Checkpoint file whose ``"model_state_dict"`` entry
            holds the weights. Defaults to the checkpoint shipped with the repo.
        device: Device the model is moved to after loading.

    Returns:
        The model in evaluation mode on ``device``.
    """
    logger.info("Loading ChessTransformer model...")
    model = ChessTransformer()
    # Deserialize onto the CPU first. Without map_location, torch.load restores
    # tensors to the device they were saved from, so the checkpoint could
    # occupy GPU memory in addition to the model copy made below.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.eval().to(device)
    logger.info("Model loaded successfully.")
    return model
|
| 31 |
+
|
| 32 |
+
def predict_top_k_moves(model, tokenizer, game_sequence, k=100, device='cuda'):
    """Return the model's k most probable next moves for a game so far.

    Args:
        model: Callable mapping a ``(1, seq_len)`` long tensor to logits that
            are indexed as ``[batch, position, token]``.
        tokenizer: Object providing ``tokenize_game(list)`` and ``get_move(int)``.
        game_sequence: List of move strings played so far.
        k: Maximum number of candidates to return. It is clamped to the
            vocabulary size because ``torch.topk`` raises when k is larger.
        device: Device the input tensor is moved to.

    Returns:
        A list of ``(move, probability)`` pairs, most probable first.
    """
    tokens = torch.tensor([tokenizer.tokenize_game(game_sequence)], dtype=torch.long).to(device)

    with torch.no_grad():
        # Only the distribution predicted after the last token is needed.
        last_logits = model(tokens)[0, -1, :]
        probabilities = torch.nn.functional.softmax(last_logits, dim=-1)
        k = min(k, probabilities.size(-1))
        top_k_probs, top_k_indices = torch.topk(probabilities, k)

    top_k_moves = [tokenizer.get_move(index) for index in top_k_indices.tolist()]
    return list(zip(top_k_moves, top_k_probs.tolist()))
|
| 43 |
+
|
| 44 |
+
def get_legal_move(board, moves):
    """Pick the first candidate that is legal on ``board``.

    Args:
        board: Board whose ``legal_moves`` is used for the membership test.
        moves: Iterable of ``(uci_string, probability)`` pairs, tried in order.

    Returns:
        The first ``(uci_string, probability)`` pair that is legal, or
        ``(None, None)`` when no candidate qualifies.
    """
    for candidate, probability in moves:
        try:
            is_legal = chess.Move.from_uci(candidate) in board.legal_moves
        except ValueError:
            # Strings that cannot be handled as a move are skipped.
            continue
        if is_legal:
            return candidate, probability
    return None, None
|
| 52 |
+
|
| 53 |
+
def play_game(model, tokenizer, stockfish_path, stockfish_elo, model_is_white, game_number):
    """Play one game between the model and a strength-limited Stockfish.

    Args:
        model: Model passed to ``predict_top_k_moves``.
        tokenizer: Tokenizer passed to ``predict_top_k_moves``.
        stockfish_path: Path of the UCI engine binary to launch.
        stockfish_elo: Value used for the engine's ``UCI_Elo`` option.
        model_is_white: True when the model plays the white pieces.
        game_number: Identifier used only in log messages.

    Returns:
        ``(result, move_count)`` where ``result`` is the board result string,
        or a loss for the model's colour when none of its top-k predictions is
        legal, and ``move_count`` is the number of loop iterations (plies).
    """
    engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    # try/finally so the engine process is shut down on every exit path,
    # including the early return below and any exception.
    try:
        engine.configure({"UCI_LimitStrength": True, "UCI_Elo": stockfish_elo})

        board = chess.Board()
        game_sequence = ['start']
        move_count = 0

        while not board.is_game_over():
            move_count += 1
            if (board.turn == chess.WHITE) == model_is_white:
                top_k_moves = predict_top_k_moves(model, tokenizer, game_sequence)
                legal_move, prob = get_legal_move(board, top_k_moves)
                if legal_move is None:
                    # Treated as a forfeit by the model.
                    logger.warning(f"Game {game_number}: No legal moves found in top-k on move {move_count}. Game over.")
                    return "0-1" if model_is_white else "1-0", move_count
                board.push_uci(legal_move)
                game_sequence.append(legal_move)
                logger.debug(f"Game {game_number}: Model's move: {legal_move} (probability: {prob:.4f})")
            else:
                result = engine.play(board, chess.engine.Limit(time=0.1))
                board.push(result.move)
                game_sequence.append(result.move.uci())
                logger.debug(f"Game {game_number}: Stockfish's move: {result.move.uci()}")

        return board.result(), move_count
    finally:
        engine.quit()
|
| 83 |
+
|
| 84 |
+
def worker(args):
    """Pool entry point: unpack one task tuple and play that game.

    Args:
        args: Tuple ``(model, tokenizer, stockfish_path, stockfish_elo, game_number)``.

    Returns:
        ``(result, game_number, move_count)`` for the finished game.
    """
    model, tokenizer, stockfish_path, stockfish_elo, game_number = args
    # Even-numbered games are played with the model as white.
    plays_white = game_number % 2 == 0
    outcome, plies = play_game(
        model, tokenizer, stockfish_path, stockfish_elo, plays_white, game_number
    )
    return outcome, game_number, plies
|
| 89 |
+
|
| 90 |
+
def calculate_elo_from_win_rate(win_rate, opponent_elo):
    """Estimate a rating from the score fraction achieved against one opponent.

    Args:
        win_rate: Score fraction; 0 and 1 are handled as special cases.
        opponent_elo: Rating of the opponent.

    Returns:
        ``opponent_elo - 400 * log10(1 / win_rate - 1)``; negative infinity for
        a win rate of 0 and positive infinity for a win rate of 1.
    """
    if win_rate == 0:
        return float('-inf')
    if win_rate == 1:
        return float('inf')
    rating_gap = -400 * math.log10(1 / win_rate - 1)
    return opponent_elo + rating_gap
|
| 98 |
+
|
| 99 |
+
def main():
    """Run the tournament against Stockfish in a process pool and log the results.

    Reads the game count, engine rating, engine path and worker count from the
    module-level ``args``. Logs wins, losses, draws, the average number of
    moves and the ELO estimate derived from the score fraction.
    """
    mp.set_start_method('spawn')  # Set start method to 'spawn' for CUDA support

    num_games = args.games
    stockfish_elo = args.stockfish_elo
    stockfish_path = args.stockfish_path

    if num_games <= 0:
        # The win rate and the move average below both divide by num_games.
        logger.error(f"Nothing to do: --games must be positive (got {num_games}).")
        return

    logger.info(f"Starting tournament: {num_games} games, Stockfish ELO: {stockfish_elo}")

    model = setup_model()
    tokenizer = tk.Tokenizer()

    num_processes = args.cores
    logger.info(f"Using {num_processes} CPU cores for parallel processing")

    # One task per game; every task tuple carries the model and the tokenizer.
    tasks = [(model, tokenizer, stockfish_path, stockfish_elo, i) for i in range(num_games)]

    results = []
    with mp.Pool(processes=num_processes) as pool:
        with tqdm(total=num_games, desc="Games Progress") as pbar:
            # Results arrive in completion order; the game number travels with each one.
            for result in pool.imap_unordered(worker, tasks):
                results.append(result)
                pbar.update()

    # Process results
    wins = draws = losses = 0
    total_moves = 0
    for result, game_number, move_count in results:
        # worker() gives the model white in even-numbered games.
        model_was_white = game_number % 2 == 0
        if result == "1/2-1/2":
            draws += 1
        elif result == ("1-0" if model_was_white else "0-1"):
            wins += 1
        else:
            # Any other result string is counted as a loss.
            losses += 1
        total_moves += move_count

    win_rate = (wins + 0.5 * draws) / num_games
    final_model_elo = calculate_elo_from_win_rate(win_rate, stockfish_elo)
    elo_change = final_model_elo - stockfish_elo

    logger.info("Tournament completed. Final results:")
    logger.info(f"Total games: {num_games}")
    logger.info(f"Wins: {wins}, Losses: {losses}, Draws: {draws}")
    logger.info(f"Win rate: {win_rate:.2%}")
    logger.info(f"Average moves per game: {total_moves/num_games:.2f}")
    logger.info(f"Stockfish ELO: {stockfish_elo}")
    logger.info(f"Final Model ELO: {final_model_elo:.2f}")
    logger.info(f"ELO Change: {elo_change:+.2f}")
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
main()
|
chesstransformer.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The table is stored in the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``: even feature columns hold sines, odd columns
    hold cosines.
    """

    def __init__(self, d_model, max_len=5000):
        """Precompute the encoding table.

        Args:
            d_model: Number of feature columns; odd values are supported.
            max_len: Number of positions in the table.
        """
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine block is trimmed to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions."""
        return x + self.pe[:x.size(0), :]
|
| 20 |
+
|
| 21 |
+
class StochasticDepth(nn.Module):
    """Residual connection whose branch is randomly dropped during training.

    ``p`` is the probability of keeping the residual branch. In training mode
    one random draw per call decides whether the whole branch is added; in
    evaluation mode the branch is always added, scaled by ``p``.
    """

    def __init__(self, p=0.8):
        super().__init__()
        self.p = p

    def forward(self, x, residual):
        """Combine ``x`` with ``residual`` according to the current mode."""
        if not self.training:
            return x + self.p * residual

        # A single scalar draw: the branch is kept or dropped for the whole call.
        keep_branch = torch.rand(1).item() < self.p
        if keep_branch:
            return x + residual
        return x
|
| 34 |
+
|
| 35 |
+
class AdvancedTransformerLayer(nn.Module):
    """Pre-norm transformer layer with dropout and stochastic-depth residuals.

    Each sub-layer (self-attention, then feed-forward) is applied to a
    layer-normalized copy of the input and merged back through the same
    ``StochasticDepth`` module.
    """

    def __init__(self, d_model, nhead, dropout=0.1, stoch_depth_p=0.8):
        super().__init__()
        dim_feedforward = 4 * d_model
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.stoch_depth = StochasticDepth(stoch_depth_p)

    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        """Apply attention and feed-forward sub-layers to ``x``.

        Args:
            x: Input of shape ``(seq_len, batch_size, d_model)``.
            src_mask: Attention mask forwarded as ``attn_mask``.
            src_key_padding_mask: Optional boolean mask; True entries are
                converted to ``-inf`` and False entries to ``0.0`` before use.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if src_key_padding_mask is not None:
            # Boolean mask -> additive float mask (-inf where True, 0 elsewhere).
            is_padding = src_key_padding_mask
            src_key_padding_mask = torch.zeros_like(is_padding, dtype=torch.float).masked_fill(
                is_padding, float('-inf')
            )

        attn_input = self.norm1(x)
        attn_output, _ = self.self_attn(
            attn_input, attn_input, attn_input,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        x = self.stoch_depth(x, self.dropout(attn_output))

        ff_output = self.ff(self.norm2(x))
        return self.stoch_depth(x, self.dropout(ff_output))
|
| 68 |
+
|
| 69 |
+
class ChessTransformer(nn.Module):
    """Causal transformer over move tokens that outputs per-position token logits."""

    def __init__(self, num_layers=64, d_model=1024, nhead=8, dropout=0.1, stoch_depth_p=0.9, num_tokens=2066, pad_token_id=2064):
        """Create the embedding, positional encoding, layer stack and output head.

        Args:
            num_layers: Number of ``AdvancedTransformerLayer`` blocks.
            d_model: Embedding width.
            nhead: Attention heads per layer.
            dropout: Dropout probability passed to every layer.
            stoch_depth_p: Keep probability passed to every layer.
            num_tokens: Vocabulary size (embedding rows and output logits).
            pad_token_id: Token id treated as padding in ``forward``.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_tokens, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            AdvancedTransformerLayer(d_model, nhead, dropout, stoch_depth_p)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, num_tokens)
        self.d_model = d_model
        self.padding_idx = pad_token_id

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return a ``(sz, sz)`` float mask: 0.0 on and below the diagonal, -inf above.

        Args:
            sz: Sequence length.
            device: Optional device to create the mask on (default: torch's
                default device), which avoids a separate copy afterwards.
        """
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def pad_sequences(self, sequences):
        """Right-pad token lists with the padding id to a common length.

        Args:
            sequences: Sequence of token-id sequences.

        Returns:
            A long tensor of shape ``(len(sequences), max_len)``; an empty
            ``(0, 0)`` tensor when ``sequences`` is empty.
        """
        if not sequences:
            return torch.empty((0, 0), dtype=torch.long)
        padding_value = self.padding_idx
        max_len = max(len(seq) for seq in sequences)
        padded_seqs = [list(seq) + [padding_value] * (max_len - len(seq)) for seq in sequences]
        return torch.LongTensor(padded_seqs)

    def forward(self, x):
        """Compute logits for every position.

        Args:
            x: Long tensor of token ids with shape ``(batch_size, seq_len)``.

        Returns:
            Logits of shape ``(batch_size, seq_len, num_tokens)``.
        """
        _, seq_len = x.size()

        # True where the input holds the padding token.
        padding_mask = (x == self.padding_idx)

        # Built directly on the input's device.
        causal_mask = self.generate_square_subsequent_mask(seq_len, device=x.device)

        # Layers work sequence-first: (seq_len, batch_size, d_model).
        x = self.embedding(x).transpose(0, 1) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)

        for layer in self.layers:
            x = layer(x, src_mask=causal_mask, src_key_padding_mask=padding_mask)

        x = self.norm(x)
        return self.output(x.transpose(0, 1))
|
| 116 |
+
|
| 117 |
+
def winning_moves_loss(output, ground_truth, win_labels, pad_token_id=2064, start_token_id=2065):
    """Compute the mean NLL over the winning side's moves only.

    Targets are the tokens shifted left by one position. A target at even
    shifted index is kept when the game's label is 1, and at odd shifted index
    when the label is 0; padding and start targets are always excluded.

    Args:
        output: Logits of shape ``(batch_size, seq_len, num_tokens)``.
        ground_truth: Token ids of shape ``(batch_size, seq_len)``.
        win_labels: One label per game, shape ``(batch_size,)``.
        pad_token_id: Target id to ignore as padding.
        start_token_id: Target id to ignore as the start marker.

    Returns:
        A scalar tensor on ``output``'s device; 0 when no move is selected.
    """
    # Follow the logits' device instead of forcing CUDA, so the loss also
    # works on machines without a GPU.
    device = output.device
    ground_truth = ground_truth.to(device)
    win_labels = win_labels.to(device)

    batch_size, seq_len, num_tokens = output.shape

    # Position i predicts token i + 1.
    targets = ground_truth[:, 1:].reshape(-1)
    log_probs = F.log_softmax(output[:, :-1, :].reshape(-1, num_tokens), dim=-1)

    # One label and one in-game index per flattened target.
    win_labels_expanded = win_labels.unsqueeze(1).repeat(1, seq_len - 1).view(-1)
    move_indices = torch.arange(seq_len - 1, device=device).unsqueeze(0).repeat(batch_size, 1).view(-1)

    white_win_mask = (win_labels_expanded == 1) & (move_indices % 2 == 0)
    black_win_mask = (win_labels_expanded == 0) & (move_indices % 2 == 1)
    selected_moves_mask = (
        (white_win_mask | black_win_mask)
        & (targets != pad_token_id)
        & (targets != start_token_id)
    )

    loss = F.nll_loss(log_probs, targets, reduction='none') * selected_moves_mask.float()

    selected_moves_count = selected_moves_mask.float().sum()
    if selected_moves_count > 0:
        return loss.sum() / selected_moves_count
    # Every term was masked out, so the sum is a zero-valued tensor.
    return loss.sum()
|
| 162 |
+
|
| 163 |
+
def all_moves_loss(output, ground_truth, pad_token_id=2064, start_token_id=2065):
    """Compute the mean NLL over all targets except padding and start tokens.

    Args:
        output: Logits of shape ``(batch_size, seq_len, num_tokens)``.
        ground_truth: Token ids of shape ``(batch_size, seq_len)``.
        pad_token_id: Target id to ignore as padding.
        start_token_id: Target id to ignore as the start marker.

    Returns:
        A scalar tensor on ``output``'s device; 0 when no target is valid.
    """
    batch_size, seq_len, num_tokens = output.shape

    # Follow the logits' device instead of forcing CUDA, so the loss also
    # works on machines without a GPU.
    ground_truth = ground_truth.to(output.device)

    # Position i predicts token i + 1.
    targets = ground_truth[:, 1:].reshape(-1)
    log_probs = F.log_softmax(output[:, :-1, :].reshape(-1, num_tokens), dim=-1)

    valid_moves_mask = (targets != pad_token_id) & (targets != start_token_id)

    loss = F.nll_loss(log_probs, targets, reduction='none') * valid_moves_mask.float()

    valid_moves_count = valid_moves_mask.float().sum()
    if valid_moves_count > 0:
        return loss.sum() / valid_moves_count
    # Every term was masked out, so the sum is a zero-valued tensor.
    return loss.sum()
|
| 201 |
+
|
| 202 |
+
def weighted_chess_loss(output, ground_truth, win_labels, winning_weight=1.0, losing_weight=0.1, pad_token_id=2064, start_token_id=2065):
    """Compute a weighted NLL over all valid moves, favouring the winner's moves.

    Targets are the tokens shifted left by one position. A target counts as a
    winning move at even shifted index when the game's label is 1 and at odd
    shifted index when the label is 0. Padding and start targets contribute
    nothing. The weighted sum is divided by the number of valid targets.

    Args:
        output: Logits of shape ``(batch_size, seq_len, num_tokens)``.
        ground_truth: Token ids of shape ``(batch_size, seq_len)``.
        win_labels: One label per game, shape ``(batch_size,)``.
        winning_weight: Weight for the winning side's moves.
        losing_weight: Weight for every other valid move.
        pad_token_id: Target id to ignore as padding.
        start_token_id: Target id to ignore as the start marker.

    Returns:
        A scalar tensor on ``output``'s device; 0 when no target is valid.
    """
    # Follow the logits' device instead of forcing CUDA, so the loss also
    # works on machines without a GPU.
    device = output.device
    ground_truth = ground_truth.to(device)
    win_labels = win_labels.to(device)

    batch_size, seq_len, num_tokens = output.shape

    # Position i predicts token i + 1.
    targets = ground_truth[:, 1:].reshape(-1)
    log_probs = F.log_softmax(output[:, :-1, :].reshape(-1, num_tokens), dim=-1)

    # One label and one in-game index per flattened target.
    win_labels_expanded = win_labels.unsqueeze(1).repeat(1, seq_len - 1).view(-1)
    move_indices = torch.arange(seq_len - 1, device=device).unsqueeze(0).repeat(batch_size, 1).view(-1)

    white_win_mask = (win_labels_expanded == 1) & (move_indices % 2 == 0)
    black_win_mask = (win_labels_expanded == 0) & (move_indices % 2 == 1)
    winning_moves_mask = white_win_mask | black_win_mask

    valid_moves_mask = (targets != pad_token_id) & (targets != start_token_id)

    loss = F.nll_loss(log_probs, targets, reduction='none')

    # Scalar weights are wrapped as tensors so they share the loss' dtype and device.
    weights = torch.where(
        winning_moves_mask & valid_moves_mask,
        loss.new_tensor(winning_weight),
        loss.new_tensor(losing_weight),
    )
    weighted_loss = loss * weights * valid_moves_mask.float()

    valid_moves_count = valid_moves_mask.float().sum()
    if valid_moves_count > 0:
        return weighted_loss.sum() / valid_moves_count
    # Every term was masked out, so the sum is a zero-valued tensor.
    return weighted_loss.sum()
|
environment.yml
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: chessbot
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=main
|
| 6 |
+
- _openmp_mutex=5.1=1_gnu
|
| 7 |
+
- bzip2=1.0.8=h5eee18b_6
|
| 8 |
+
- ca-certificates=2024.9.24=h06a4308_0
|
| 9 |
+
- expat=2.6.3=h6a678d5_0
|
| 10 |
+
- ld_impl_linux-64=2.40=h12ee557_0
|
| 11 |
+
- libffi=3.4.4=h6a678d5_1
|
| 12 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 13 |
+
- libgomp=11.2.0=h1234567_1
|
| 14 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 15 |
+
- libuuid=1.41.5=h5eee18b_0
|
| 16 |
+
- ncurses=6.4=h6a678d5_0
|
| 17 |
+
- openssl=3.0.15=h5eee18b_0
|
| 18 |
+
- pip=24.2=py312h06a4308_0
|
| 19 |
+
- python=3.12.7=h5148396_0
|
| 20 |
+
- readline=8.2=h5eee18b_0
|
| 21 |
+
- setuptools=75.1.0=py312h06a4308_0
|
| 22 |
+
- sqlite=3.45.3=h5eee18b_0
|
| 23 |
+
- tk=8.6.14=h39e8969_0
|
| 24 |
+
- wheel=0.44.0=py312h06a4308_0
|
| 25 |
+
- xz=5.4.6=h5eee18b_1
|
| 26 |
+
- zlib=1.2.13=h5eee18b_1
|
| 27 |
+
- pip:
|
| 28 |
+
- absl-py==2.1.0
|
| 29 |
+
- chess==1.11.0
|
| 30 |
+
- filelock==3.13.1
|
| 31 |
+
- fsspec==2024.2.0
|
| 32 |
+
- grpcio==1.66.2
|
| 33 |
+
- jinja2==3.1.3
|
| 34 |
+
- markdown==3.7
|
| 35 |
+
- markupsafe==2.1.5
|
| 36 |
+
- mpmath==1.3.0
|
| 37 |
+
- networkx==3.2.1
|
| 38 |
+
- numpy==2.1.2
|
| 39 |
+
- nvidia-cublas-cu12==12.4.2.65
|
| 40 |
+
- nvidia-cuda-cupti-cu12==12.4.99
|
| 41 |
+
- nvidia-cuda-nvrtc-cu12==12.4.99
|
| 42 |
+
- nvidia-cuda-runtime-cu12==12.4.99
|
| 43 |
+
- nvidia-cudnn-cu12==9.1.0.70
|
| 44 |
+
- nvidia-cufft-cu12==11.2.0.44
|
| 45 |
+
- nvidia-curand-cu12==10.3.5.119
|
| 46 |
+
- nvidia-cusolver-cu12==11.6.0.99
|
| 47 |
+
- nvidia-cusparse-cu12==12.3.0.142
|
| 48 |
+
- nvidia-nccl-cu12==2.20.5
|
| 49 |
+
- nvidia-nvjitlink-cu12==12.4.99
|
| 50 |
+
- nvidia-nvtx-cu12==12.4.99
|
| 51 |
+
- packaging==24.1
|
| 52 |
+
- pandas==2.2.3
|
| 53 |
+
- protobuf==5.28.2
|
| 54 |
+
- pyarrow==17.0.0
|
| 55 |
+
- python-dateutil==2.9.0.post0
|
| 56 |
+
- pytz==2024.2
|
| 57 |
+
- six==1.16.0
|
| 58 |
+
- sympy==1.12
|
| 59 |
+
- tensorboard==2.18.0
|
| 60 |
+
- tensorboard-data-server==0.7.2
|
| 61 |
+
- torch==2.4.1+cu124
|
| 62 |
+
- tqdm==4.66.5
|
| 63 |
+
- triton==3.0.0
|
| 64 |
+
- typing-extensions==4.9.0
|
| 65 |
+
- tzdata==2024.2
|
| 66 |
+
- werkzeug==3.0.4
|
play.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from chesstransformer import ChessTransformer
|
| 4 |
+
import tokenizer as tk
|
| 5 |
+
|
| 6 |
+
# Build the network and restore the trained weights.
# map_location='cpu' makes torch.load deserialize onto the CPU; without it the
# tensors are restored to the device they were saved from, in addition to the
# copy that .cuda() creates below.
model = ChessTransformer()
model.load_state_dict(
    torch.load('./64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth', map_location='cpu')["model_state_dict"]
)
model.eval().cuda()

# Initialize tokenizer
t = tk.Tokenizer()
|
| 12 |
+
|
| 13 |
+
def predict_move(model, game_sequence, tokenizer, device='cuda', top_k=5):
    """Sample the bot's next move from the model's top-k predictions.

    Args:
        model: Module mapping a ``(1, seq_len)`` long tensor to logits that
            are indexed as ``[batch, position, token]``.
        game_sequence: List of move strings played so far.
        tokenizer: Object providing ``tokenize_game`` and ``untokenize_game``;
            if it exposes a ``move_dict`` with "pad"/"start" entries, those
            token ids are excluded from the candidates.
        device: Device the input tensor is moved to.
        top_k: Maximum number of candidates; clamped to the number of
            available non-special tokens.

    Returns:
        ``(sampled_move, top_k_moves, top_k_probs)`` where ``top_k_probs`` is a
        NumPy array of probabilities renormalized over the top-k candidates.
    """
    model.eval()
    tokens = torch.tensor([tokenizer.tokenize_game(game_sequence)], dtype=torch.long).to(device)

    with torch.no_grad():
        # Logits for the move after the last token; cloned because they are edited below.
        logits = model(tokens)[0, -1, :].clone()

        # The special tokens are not moves: sampling one would put "[pad]" or
        # "[start]" into the game history and break the next tokenize_game call.
        vocabulary = getattr(tokenizer, "move_dict", {})
        special_ids = [
            vocabulary[name]
            for name in ("pad", "start")
            if name in vocabulary and 0 <= vocabulary[name] < logits.numel()
        ]
        if special_ids:
            logits[special_ids] = float('-inf')

        # torch.topk raises when k exceeds the number of entries.
        top_k = min(top_k, logits.numel() - len(special_ids))
        top_k_logits, top_k_indices = torch.topk(logits, top_k)

        # Apply softmax to get probabilities
        probs = F.softmax(top_k_logits, dim=-1)

        # Sample from the probability distribution
        sampled_index = torch.multinomial(probs, 1).item()
        sampled_token = top_k_indices[sampled_index].item()

    sampled_move = tokenizer.untokenize_game([sampled_token])[0]

    # Get all top_k moves and their probabilities for display
    top_k_moves = tokenizer.untokenize_game(top_k_indices.tolist())
    top_k_probs = probs.cpu().numpy()

    return sampled_move, top_k_moves, top_k_probs
|
| 36 |
+
|
| 37 |
+
def play_game():
    """Interactive loop: read the user's moves and answer with the bot's moves.

    Commands: 'exit' ends the game and 'undo' removes the last user/bot move
    pair. Any other input must be a move known to the tokenizer; unknown input
    is rejected and asked for again instead of ending the game.
    """
    input_game = []
    print("Let's play chess! Enter your moves in UCI format (e.g., 'e2e4'). Type 'exit' to quit or 'undo' to undo the last move.")

    while True:
        user_move = input("Your move: ").strip()
        command = user_move.lower()

        if command == 'exit':
            print("Game over. Thanks for playing!")
            break

        if command == 'undo':
            if len(input_game) >= 2:
                input_game.pop()  # Remove bot's move
                input_game.pop()  # Remove user's move
                print("Last move undone. Current game sequence:", input_game)
            else:
                print("Cannot undo. No moves to undo.")
            continue

        # Reject anything the tokenizer cannot encode. 'start' stays usable as
        # the very first entry (as before), 'pad' is never a move.
        known = user_move in t.move_dict and user_move != 'pad'
        if not known or (user_move == 'start' and input_game):
            print(f"Unrecognized move '{user_move}'. Please use UCI format (e.g., 'e2e4').")
            continue

        input_game.append(user_move)
        print("Current game sequence:", input_game)

        try:
            # autoplay_muliproc.py begins every game sequence with the 'start'
            # token, so the same prefix is supplied here when it is missing.
            if input_game[0] == 'start':
                model_sequence = input_game
            else:
                model_sequence = ['start'] + input_game
            bot_move, top_moves, top_probs = predict_move(model, model_sequence, t)

            # Display top moves and their probabilities
            moves_probs_str = ', '.join(f"{move} ({prob:.2%})" for move, prob in zip(top_moves, top_probs))
            print(f"Top {len(top_moves)} moves and probabilities: {moves_probs_str}")

            print(f"Bot's sampled move: {bot_move}")
            input_game.append(bot_move)
        except Exception as e:
            print("An error occurred:", e)
            break
|
| 70 |
+
|
| 71 |
+
if __name__ == "__main__":
|
| 72 |
+
play_game()
|
tokenizer.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class Tokenizer:
    """Translate between move strings and integer token ids.

    ``move_dict`` maps strings to ids and ``inverse_dict`` maps ids back to
    strings; both are built once in the constructor.
    """

    def __init__(self):
        self.move_dict = create_move_dict()
        self.inverse_dict = inverse_move_dict(self.move_dict)

    def tokenize_game(self, moves_list):
        """Return the token id of every move in ``moves_list``, in order."""
        return [self.move_dict[move] for move in moves_list]

    def untokenize_game(self, tokenized_moves):
        """Return the string of every token; ids 2064 and 2065 are shown as "[pad]" and "[start]"."""

        def decode(token):
            if token == 2064:
                return "[pad]"
            if token == 2065:
                return "[start]"
            return self.inverse_dict[token]

        return [decode(token) for token in tokenized_moves]

    def tokenize_move(self, move):
        """Return the token id of a single move string."""
        return self.move_dict[move]

    def get_move(self, tokenized_move):
        """Return the string stored for a single token id."""
        return self.inverse_dict[tokenized_move]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Helper function to convert square index to algebraic notation
|
| 32 |
+
def square_to_algebraic(square):
    """Convert a square index to algebraic notation (0 -> 'a1', 63 -> 'h8')."""
    # The quotient selects the rank character, the remainder the file letter.
    rank_index, file_index = divmod(square, 8)
    return 'abcdefgh'[file_index] + '12345678'[rank_index]
|
| 38 |
+
|
| 39 |
+
# Modified chess_moves function to account for all moves
|
| 40 |
+
def chess_moves(starting_square):
    """List the ``(start, target)`` square pairs generated for one start square.

    Candidates are produced in a fixed order: along the row (towards lower
    indices, then higher), along the column (up, then down), the four
    diagonals, every square of the surrounding 5x5 area that does not wrap
    around a row, and finally pawn-style steps for squares in row index 1 or 6.
    Duplicates are dropped, keeping the first occurrence. The order matters
    because ``create_move_dict`` numbers moves in the order returned here.
    """
    ss = starting_square
    candidates = []

    # First and last square index of the row containing ss.
    row_first = (ss // 8) * 8
    row_last = row_first + 7

    # Same row: towards lower indices, then towards higher indices.
    candidates.extend((ss, target) for target in range(ss - 1, row_first - 1, -1))
    candidates.extend((ss, target) for target in range(ss + 1, row_last + 1))

    # Same column: upwards, then downwards.
    candidates.extend((ss, target) for target in range(ss + 8, 64, 8))
    candidates.extend((ss, target) for target in range(ss - 8, -1, -8))

    # Diagonals in the order upper left, lower left, upper right, lower right.
    # A walk stops at the board edge or when the column index wraps around.
    for step, wrapped_column in ((7, 7), (-9, 7), (9, 0), (-7, 0)):
        target = ss + step
        while (target < 64 if step > 0 else target >= 0) and target % 8 != wrapped_column:
            candidates.append((ss, target))
            target += step

    # Surrounding 5x5 area; the row check rejects targets that wrapped to another row.
    for row_offset in range(-2, 3):
        for col_offset in range(-2, 3):
            target = ss + col_offset + row_offset * 8
            if target == ss or not 0 <= target < 64:
                continue
            if target // 8 == (ss // 8) + row_offset:
                candidates.append((ss, target))

    # Pawn moves (including promotions)
    row = ss // 8
    if row == 1:  # White pawn's initial position
        if ss + 8 < 64:
            candidates.append((ss, ss + 8))
            if (ss + 16) < 64:
                candidates.append((ss, ss + 16))
        if ss + 9 < 64 and (ss + 9) % 8 != 0:
            candidates.append((ss, ss + 9))
        if ss + 7 < 64 and (ss + 7) % 8 != 7:
            candidates.append((ss, ss + 7))
    elif row == 6:  # Black pawn's initial position
        if ss - 8 >= 0:
            candidates.append((ss, ss - 8))
            if (ss - 16) >= 0:
                candidates.append((ss, ss - 16))
        if ss - 9 >= 0 and (ss - 9) % 8 != 7:
            candidates.append((ss, ss - 9))
        if ss - 7 >= 0 and (ss - 7) % 8 != 0:
            candidates.append((ss, ss - 7))

    # dict keys keep insertion order, so this removes duplicates while
    # preserving the first occurrence of every pair.
    return list(dict.fromkeys(candidates))
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Function to create a dictionary of moves with promotion
|
| 124 |
+
def create_move_dict():
    """Build the mapping from move strings to token ids.

    Ids are handed out consecutively from 0 in the order in which
    ``chess_moves`` yields moves for squares 0..63. A move from row index 6 to
    row index 7, or from row index 1 to row index 0, is immediately followed
    by its four promotion variants (suffix q, r, b, n). The special entries
    "pad" and "start" are assigned the fixed ids 2064 and 2065.
    """
    promotion_pieces = ['q', 'r', 'b', 'n']  # Queen, Rook, Bishop, Knight
    move_dict = {}
    next_id = 0

    for origin in range(64):
        origin_row = origin // 8
        for start_sq, end_sq in chess_moves(origin):
            uci = square_to_algebraic(start_sq) + square_to_algebraic(end_sq)
            move_dict[uci] = next_id
            next_id += 1

            # Promotion variants for moves onto the last row from the row before it.
            end_row = end_sq // 8
            reaches_top = end_row == 7 and origin_row == 6
            reaches_bottom = end_row == 0 and origin_row == 1
            if reaches_top or reaches_bottom:
                for piece in promotion_pieces:
                    move_dict[f"{uci}{piece}"] = next_id
                    next_id += 1

    move_dict["pad"] = 2064
    move_dict["start"] = 2065
    return move_dict
|
| 148 |
+
|
| 149 |
+
def inverse_move_dict(move_dict):
    """Return a new dict mapping each value of ``move_dict`` back to its key.

    When several keys share a value, the key that comes last in iteration
    order is the one kept.
    """
    return {token: move for move, token in move_dict.items()}
|
| 154 |
+
|
| 155 |
+
def tokenize_game(moves_list):
    """Return the token id of every move in ``moves_list``, in order.

    The move dictionary is rebuilt on each call.
    """
    vocabulary = create_move_dict()
    return [vocabulary[move] for move in moves_list]
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
|
| 163 |
+
t = Tokenizer()
|