{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Chess move-prediction model: training\n",
    "\n",
    "Loads PGN games from `data/`, encodes every position that precedes a mainline move as a 13x8x8 array labelled with the move played, trains `ChessModel` as a move classifier with cross-entropy, and saves the weights plus the move vocabulary to `models/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b753ca88",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import chess\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from chess import Board, pgn\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Project-local modules\n",
    "from dataset import ChessDataset\n",
    "from model import ChessModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config-cell",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = \"data\"  # folder scanned for *.pgn files\n",
    "BATCH_SIZE = 64\n",
    "LEARNING_RATE = 1e-4\n",
    "SEED = 42\n",
    "\n",
    "# Seed torch so weight initialisation and DataLoader shuffling are reproducible.\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-load",
   "metadata": {},
   "source": [
    "## Load games"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a116056b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 93%|█████████▎| 14/15 [03:50<00:16, 16.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Games loaded: 61363\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def load_pgn(file_path):\n",
    "    \"\"\"Read every game in a PGN file and return them as a list of chess.pgn.Game objects.\"\"\"\n",
    "    games = []\n",
    "    # Explicit encoding: the platform default is not guaranteed to be UTF-8,\n",
    "    # and utf-8-sig additionally strips a leading BOM if the exporter wrote one.\n",
    "    with open(file_path, 'r', encoding='utf-8-sig') as pgn_file:\n",
    "        while True:\n",
    "            game = pgn.read_game(pgn_file)\n",
    "            if game is None:  # end of file\n",
    "                break\n",
    "            games.append(game)\n",
    "    return games\n",
    "\n",
    "# Sorted so games (and therefore training samples) arrive in the same order on every run;\n",
    "# os.listdir order is arbitrary.\n",
    "files = sorted(file for file in os.listdir(DATA_DIR) if file.endswith(\".pgn\"))\n",
    "LIMIT_OF_FILES = len(files)  # load all files; lower this for a quick run\n",
    "games = []\n",
    "# Slicing instead of a counter + break, so tqdm also records the final iteration.\n",
    "for file in tqdm(files[:LIMIT_OF_FILES]):\n",
    "    games.extend(load_pgn(os.path.join(DATA_DIR, file)))\n",
    "print(\"Games loaded:\", len(games))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-encode",
   "metadata": {},
   "source": [
    "## Encode positions and moves\n",
    "\n",
    "Each training sample is the board *before* a mainline move, labelled with that move (a UCI string, later mapped to an integer class index)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ae41f58",
   "metadata": {},
   "outputs": [],
   "source": [
    "def board_to_matrix(board: Board):\n",
    "    \"\"\"Encode a position as a (13, 8, 8) float32 array.\n",
    "\n",
    "    Channels 0-5: white pieces, one channel per piece type; channels 6-11: black pieces;\n",
    "    channel 12: destination squares of the legal moves in this position.\n",
    "    Indexed as [channel, square // 8, square % 8].\n",
    "    \"\"\"\n",
    "    # float32 from the start: millions of these are collected before stacking, and the\n",
    "    # default float64 doubled peak memory only to be downcast in create_input_for_nn.\n",
    "    matrix = np.zeros((13, 8, 8), dtype=np.float32)\n",
    "    for square, piece in board.piece_map().items():\n",
    "        row, col = divmod(square, 8)\n",
    "        piece_type = piece.piece_type - 1  # python-chess piece types are 1-based\n",
    "        piece_color = 0 if piece.color else 6  # chess.WHITE is True -> channels 0-5\n",
    "        matrix[piece_type + piece_color, row, col] = 1\n",
    "    for move in board.legal_moves:\n",
    "        row_to, col_to = divmod(move.to_square, 8)\n",
    "        matrix[12, row_to, col_to] = 1\n",
    "    return matrix\n",
    "\n",
    "\n",
    "def create_input_for_nn(games):\n",
    "    \"\"\"Build training pairs from games.\n",
    "\n",
    "    Returns (X, y): X is a float32 array holding one board_to_matrix() encoding per position\n",
    "    that precedes a mainline move; y holds the move played from that position as a UCI string.\n",
    "    \"\"\"\n",
    "    X = []\n",
    "    y = []\n",
    "    for game in games:\n",
    "        board = game.board()\n",
    "        for move in game.mainline_moves():\n",
    "            X.append(board_to_matrix(board))\n",
    "            y.append(move.uci())\n",
    "            board.push(move)\n",
    "    return np.array(X, dtype=np.float32), np.array(y)\n",
    "\n",
    "\n",
    "def encode_moves(moves):\n",
    "    \"\"\"Map UCI move strings to integer class indices.\n",
    "\n",
    "    Returns (indices, move_to_int); indices are int64 so they convert directly to torch.long.\n",
    "    The vocabulary is sorted: enumerating a bare set of strings depends on PYTHONHASHSEED,\n",
    "    so the mapping (and the meaning of every model output) changed between kernel restarts.\n",
    "    \"\"\"\n",
    "    move_to_int = {move: idx for idx, move in enumerate(sorted(set(moves)))}\n",
    "    return np.array([move_to_int[move] for move in moves], dtype=np.int64), move_to_int"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a128d5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NUM of moves 5538627\n"
     ]
    }
   ],
   "source": [
    "X_np, y_moves = create_input_for_nn(games)  # positions, and the UCI move played from each\n",
    "print(\"NUM of moves\", len(y_moves))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "302f7579",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_idx, move_to_int = encode_moves(y_moves)  # class index per move\n",
    "num_classes = len(move_to_int)\n",
    "\n",
    "# New names instead of overwriting X / y, so re-running this cell never re-encodes\n",
    "# already-encoded labels. from_numpy shares memory rather than copying the large X array.\n",
    "X = torch.from_numpy(X_np)  # float32\n",
    "y = torch.from_numpy(y_idx)  # int64, i.e. torch.long"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-model",
   "metadata": {},
   "source": [
    "## Model setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca936e40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "dataset = ChessDataset(X, y)\n",
    "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "model = ChessModel(num_classes=num_classes).to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-train",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "854f506e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:26<00:00, 152.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/101, Loss: 3.2092, Time: 9m26s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:25<00:00, 153.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/101, Loss: 2.5455, Time: 9m25s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:29<00:00, 152.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 53/101, Loss: 2.3894, Time: 9m29s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:29<00:00, 151.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 54/101, Loss: 2.2996, Time: 9m29s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:29<00:00, 152.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 55/101, Loss: 2.2376, Time: 9m29s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:29<00:00, 151.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 56/101, Loss: 2.1911, Time: 9m29s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type":
"stream", "text": [ "100%|██████████| 86542/86542 [09:29<00:00, 151.87it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 57/101, Loss: 2.1532, Time: 9m29s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:29<00:00, 152.07it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 58/101, Loss: 2.1213, Time: 9m29s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.33it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 59/101, Loss: 2.0945, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.15it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 60/101, Loss: 2.0712, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.31it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 61/101, Loss: 2.0509, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.26it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 62/101, Loss: 2.0328, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:27<00:00, 152.39it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 63/101, Loss: 2.0171, Time: 9m27s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:30<00:00, 151.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 64/101, Loss: 2.0031, Time: 9m30s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:29<00:00, 151.89it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 65/101, Loss: 1.9905, Time: 9m29s\n" ] }, { "name": "stderr", 
"output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:32<00:00, 151.06it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 66/101, Loss: 1.9789, Time: 9m32s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:48<00:00, 146.96it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 67/101, Loss: 1.9688, Time: 9m48s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:47<00:00, 147.41it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 68/101, Loss: 1.9591, Time: 9m47s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:41<00:00, 148.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 69/101, Loss: 1.9503, Time: 9m41s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:29<00:00, 152.02it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 70/101, Loss: 1.9424, Time: 9m29s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:27<00:00, 152.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 71/101, Loss: 1.9349, Time: 9m27s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.16it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 72/101, Loss: 1.9279, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.35it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 73/101, Loss: 1.9212, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:27<00:00, 152.48it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 74/101, Loss: 1.9154, Time: 9m27s\n" ] }, { "name": 
"stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.35it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 75/101, Loss: 1.9101, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:27<00:00, 152.39it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 76/101, Loss: 1.9049, Time: 9m27s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:27<00:00, 152.41it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 77/101, Loss: 1.9002, Time: 9m27s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:27<00:00, 152.44it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 78/101, Loss: 1.8961, Time: 9m27s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:27<00:00, 152.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 79/101, Loss: 1.8919, Time: 9m27s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:29<00:00, 151.95it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 80/101, Loss: 1.8881, Time: 9m29s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:20<00:00, 154.40it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 81/101, Loss: 1.8848, Time: 9m20s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:21<00:00, 154.02it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 82/101, Loss: 1.8816, Time: 9m21s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:25<00:00, 152.96it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 83/101, Loss: 1.8788, Time: 9m25s\n" ] }, { 
"name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.32it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 84/101, Loss: 1.8761, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.20it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 85/101, Loss: 1.8733, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:30<00:00, 151.65it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 86/101, Loss: 1.8709, Time: 9m30s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:26<00:00, 152.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 87/101, Loss: 1.8688, Time: 9m26s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:25<00:00, 153.01it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 88/101, Loss: 1.8668, Time: 9m25s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.24it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 89/101, Loss: 1.8647, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:26<00:00, 152.85it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 90/101, Loss: 1.8628, Time: 9m26s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:28<00:00, 152.32it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 91/101, Loss: 1.8612, Time: 9m28s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 86542/86542 [09:32<00:00, 151.28it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 92/101, Loss: 1.8597, Time: 9m32s\n" 
] },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:31<00:00, 151.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 93/101, Loss: 1.8581, Time: 9m31s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:31<00:00, 151.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 94/101, Loss: 1.8566, Time: 9m31s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:30<00:00, 151.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 95/101, Loss: 1.8553, Time: 9m30s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:31<00:00, 151.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 96/101, Loss: 1.8542, Time: 9m31s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:31<00:00, 151.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 97/101, Loss: 1.8532, Time: 9m31s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:31<00:00, 151.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 98/101, Loss: 1.8519, Time: 9m31s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:31<00:00, 151.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99/101, Loss: 1.8510, Time: 9m31s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86542/86542 [09:30<00:00, 151.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100/101, Loss: 1.8500, Time: 9m30s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# Epoch numbering offset. This cell does not load a checkpoint: START_EPOCH = 50 is only\n",
    "# accurate when `model` already carries 50 epochs of training in this kernel. Use 0 for a fresh run.\n",
    "START_EPOCH = 50\n",
    "num_epochs = 50\n",
    "total_epochs = START_EPOCH + num_epochs  # also used for the checkpoint filename\n",
    "for epoch in range(num_epochs):\n",
    "    start_time = time.time()\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for inputs, labels in tqdm(dataloader):\n",
    "        inputs, labels = inputs.to(device), labels.to(device)  # move the batch to the GPU if available\n",
    "        optimizer.zero_grad()  # reset gradients left over from the previous step\n",
    "\n",
    "        outputs = model(inputs)\n",
    "\n",
    "        # Compute the loss and backpropagate\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "\n",
    "        # Clip the total gradient norm to 1.0 to guard against exploding gradients\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "    minutes, seconds = divmod(int(time.time() - start_time), 60)\n",
    "    # Reported loss is the mean of the per-batch losses over the epoch.\n",
    "    print(f'Epoch {START_EPOCH + epoch + 1}/{total_epochs}, Loss: {running_loss / len(dataloader):.4f}, Time: {minutes}m{seconds}s')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-save",
   "metadata": {},
   "source": [
    "## Save artifacts\n",
    "\n",
    "Keep the weights and the move vocabulary together: the model's output index *i* only corresponds to a move through `move_to_int`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3521671f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the weights to models/; the epoch count in the filename follows the training cell.\n",
    "os.makedirs(\"models\", exist_ok=True)\n",
    "torch.save(model.state_dict(), f\"models/TORCH_1_{total_epochs}EPOCHS.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c7a53a41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Persist the move -> class-index mapping that was used to build the training labels.\n",
    "with open(\"models/heavy_move_to_int_1\", \"wb\") as file:\n",
    "    pickle.dump(move_to_int, file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chess_game",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}