Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
# --- CONFIGURATION ---
# Side length of the square board.
BOARD_SIZE = 5
# Number of consecutive stones in a line needed to win.
WIN_LENGTH = 4
# One action per board cell.
ACTION_SIZE = BOARD_SIZE ** 2
class GomokuGame:
    """Rules for an n x n Gomoku variant won by ``win_length`` stones in a row.

    Boards are numpy arrays of shape (n, n) holding 0 for an empty cell or a
    player's value; players switch by negation (``-player``), so they are
    presumably +1 / -1. Actions are flat cell indices ``row * n + col``.
    """

    # (row step, col step): horizontal, vertical, diagonal, anti-diagonal.
    _DIRECTIONS = ((0, 1), (1, 0), (1, 1), (1, -1))

    def __init__(self, n=None, win_length=None):
        """Create a game.

        Args:
            n: board side length; defaults to the module-level BOARD_SIZE.
            win_length: stones in a row needed to win; defaults to WIN_LENGTH.
        """
        self.n = BOARD_SIZE if n is None else n
        self.win_length = WIN_LENGTH if win_length is None else win_length

    def get_init_board(self):
        """Return an empty (n, n) float board."""
        return np.zeros((self.n, self.n))

    def get_board_size(self):
        """Return the board shape ``(n, n)``."""
        return (self.n, self.n)

    def get_action_size(self):
        """Return the number of actions, one per cell: ``n * n``."""
        # Derived from self.n (equal to ACTION_SIZE for the default board) so a
        # game built with a custom size reports a matching action count.
        return self.n * self.n

    def get_next_state(self, board, player, action):
        """Place ``player``'s stone at flat index ``action`` on a copy of ``board``.

        Returns ``(new_board, -player)``; the input board is not modified.
        The move is not validated, so callers must pass an empty cell.
        """
        row, col = divmod(action, self.n)
        new_board = np.copy(board)
        new_board[row, col] = player
        return (new_board, -player)

    def get_valid_moves(self, board):
        """Return a flat 0/1 int mask of length n*n marking empty cells."""
        return (board.reshape(-1) == 0).astype(int)

    def get_game_ended(self, board, player):
        """Return the game result from ``player``'s point of view.

        Returns:
            1 if ``player`` has ``win_length`` in a row, -1 if ``-player`` does
            (``player`` is checked first), 1e-4 if the board is full with no
            winner (a small non-zero draw value, distinct from 0), otherwise
            0 (game still in progress).
        """
        if self._has_line(board, player):
            return 1
        if self._has_line(board, -player):
            return -1
        if not np.any(board == 0):
            return 1e-4
        return 0

    def _has_line(self, board, p):
        """Return True if ``p`` occupies ``win_length`` consecutive cells in any direction."""
        n, wl = self.n, self.win_length
        span = wl - 1
        for dr, dc in self._DIRECTIONS:
            for r in range(n):
                for c in range(n):
                    # Skip start cells whose line would run off the board.
                    if not (0 <= r + dr * span < n and 0 <= c + dc * span < n):
                        continue
                    if all(board[r + dr * k, c + dc * k] == p for k in range(wl)):
                        return True
        return False

    def get_canonical_form(self, board, player):
        """Return the board from ``player``'s perspective (their stones become +1)."""
        return player * board
class Net(nn.Module):
    """AlphaZero-style policy/value network for a BOARD_SIZE x BOARD_SIZE board.

    Trunk: four 3x3 Conv2d layers (128 channels, padding=1 keeps the spatial
    size), each followed by BatchNorm2d and ReLU.
    Policy head: 1x1 conv to 2 channels -> ReLU -> Linear to ACTION_SIZE ->
    log_softmax.
    Value head: 1x1 conv to 1 channel -> ReLU -> Linear to 64 -> ReLU ->
    Linear to 1 -> tanh.
    """

    def __init__(self):
        super().__init__()
        width = 128
        cells = BOARD_SIZE * BOARD_SIZE
        # Attribute names and creation order determine state_dict keys,
        # parameter order and seeded initialisation; keep them stable.
        self.conv1 = nn.Conv2d(1, width, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(width, width, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(width, width, 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(width, width, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(width)
        self.bn2 = nn.BatchNorm2d(width)
        self.bn3 = nn.BatchNorm2d(width)
        self.bn4 = nn.BatchNorm2d(width)
        # Policy head.
        self.pi_conv = nn.Conv2d(width, 2, 1)
        self.pi_fc = nn.Linear(2 * cells, ACTION_SIZE)
        # Value head.
        self.v_conv = nn.Conv2d(width, 1, 1)
        self.v_fc1 = nn.Linear(cells, 64)
        self.v_fc2 = nn.Linear(64, 1)

    def forward(self, s):
        """Evaluate one board or a batch of boards.

        Args:
            s: tensor that is viewed as (-1, 1, BOARD_SIZE, BOARD_SIZE).

        Returns:
            (log_pi, value): log-probabilities of shape (batch, ACTION_SIZE)
            and values of shape (batch, 1) in [-1, 1].
        """
        cells = BOARD_SIZE * BOARD_SIZE
        x = s.view(-1, 1, BOARD_SIZE, BOARD_SIZE)
        trunk = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in trunk:
            x = F.relu(bn(conv(x)))

        log_pi = F.relu(self.pi_conv(x)).view(-1, 2 * cells)
        log_pi = F.log_softmax(self.pi_fc(log_pi), dim=1)

        value = F.relu(self.v_conv(x)).view(-1, cells)
        value = F.relu(self.v_fc1(value))
        value = torch.tanh(self.v_fc2(value))
        return log_pi, value
class MCTS:
    """Monte Carlo Tree Search guided by a policy/value network (PUCT selection).

    Every board handed to this class is expected in canonical form (player to
    move is +1); children are canonicalised via ``game.get_canonical_form``.
    Statistics are cached in dicts keyed by the board's raw bytes
    (``ndarray.tobytes()``):

    * ``Qsa[(s, a)]`` - mean backed-up value of taking action ``a`` in ``s``
    * ``Nsa[(s, a)]`` - visit count of that edge
    * ``Ns[s]``       - number of selections made from ``s``
    * ``Ps[s]``       - network prior masked to valid moves and normalised
    * ``Es[s]``       - cached ``game.get_game_ended(board, 1)``
    * ``Vs[s]``       - valid-move mask
    """

    def __init__(self, game, net, cpuct=1.0):
        """Create an empty search tree.

        Args:
            game: game-rules object (action size, moves, terminal check, ...).
            net: callable returning ``(log_pi, v)`` for a board tensor.
            cpuct: exploration constant weighting the prior term in PUCT.
        """
        self.game = game
        self.net = net
        self.cpuct = cpuct
        self.Qsa = {}
        self.Nsa = {}
        self.Ns = {}
        self.Ps = {}
        self.Es = {}
        self.Vs = {}

    def getActionProb(self, canonicalBoard, temp=1, sims=500, device='cpu'):
        """Run ``sims`` simulations from ``canonicalBoard`` and return a move distribution.

        Args:
            canonicalBoard: board in canonical form.
            temp: temperature. ``0`` returns a one-hot list on a most-visited
                action (ties broken with ``np.random.choice``); otherwise visit
                counts are raised to ``1 / temp`` and normalised.
            sims: number of calls to ``search`` before reading visit counts.
            device: torch device used for network evaluation.

        Returns:
            A list of length ``game.get_action_size()`` summing to 1. If no
            action from this board has been visited (e.g. ``sims=0`` or a
            terminal root), the valid-move mask is used in place of the counts.

        Raises:
            ValueError: if there are neither visits nor valid moves.
        """
        for _ in range(sims):
            self.search(canonicalBoard, device)
        s = canonicalBoard.tobytes()
        n_actions = self.game.get_action_size()
        counts = np.array(
            [self.Nsa.get((s, a), 0) for a in range(n_actions)], dtype=np.float64
        )
        if counts.sum() == 0:
            # Avoid dividing by zero and never return mass on illegal moves.
            counts = np.asarray(
                self.game.get_valid_moves(canonicalBoard), dtype=np.float64
            ).reshape(-1)
            if counts.sum() == 0:
                raise ValueError("getActionProb: no valid moves from this board")

        if temp == 0:
            best_actions = np.flatnonzero(counts == counts.max())
            best_a = int(np.random.choice(best_actions))
            probs = [0] * n_actions
            probs[best_a] = 1
            return probs

        # Scale by the max first so values stay in [0, 1] and x ** (1/temp)
        # cannot overflow for small temperatures; the normalised result is
        # mathematically unchanged.
        scaled = (counts / counts.max()) ** (1.0 / temp)
        return (scaled / scaled.sum()).tolist()

    def search(self, canonicalBoard, device):
        """Run one simulation from ``canonicalBoard`` and back up its result.

        Returns the value of this position negated, i.e. from the perspective
        of the player who moved into it, which is what the parent accumulates.

        Raises:
            RuntimeError: if a non-terminal state offers no selectable action.
        """
        s = canonicalBoard.tobytes()

        # Terminal check (cached). The board is canonical, so ask from +1's view.
        if s not in self.Es:
            self.Es[s] = self.game.get_game_ended(canonicalBoard, 1)
        if self.Es[s] != 0:
            return -self.Es[s]

        # Unexpanded leaf: store the network prior and return its value.
        if s not in self.Ps:
            board_tensor = torch.as_tensor(
                canonicalBoard, dtype=torch.float32, device=device
            )
            self.net.eval()
            with torch.no_grad():
                pi, v = self.net(board_tensor)
            priors = torch.exp(pi)[0].cpu().numpy()
            valid_moves = self.game.get_valid_moves(canonicalBoard)
            priors = priors * valid_moves
            total = np.sum(priors)
            if total > 0:
                priors /= total
            else:
                # All prior mass was on illegal moves: use uniform over legal ones.
                priors = valid_moves / np.sum(valid_moves)
            self.Ps[s] = priors
            self.Vs[s] = valid_moves
            self.Ns[s] = 0
            return -v.item()

        # Selection: pick the valid action with the highest PUCT score.
        valid_moves = self.Vs[s]
        priors = self.Ps[s]
        cpuct = self.cpuct
        best_uct = -float('inf')
        best_a = -1
        for a in range(self.game.get_action_size()):
            if not valid_moves[a]:
                continue
            if (s, a) in self.Qsa:
                u = self.Qsa[(s, a)] + cpuct * priors[a] * math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
            else:
                # Unvisited edge: Q counts as 0; the epsilon keeps the prior
                # term non-zero when Ns[s] is still 0.
                u = cpuct * priors[a] * math.sqrt(self.Ns[s] + 1e-8)
            if u > best_uct:
                best_uct = u
                best_a = a
        if best_a == -1:
            raise RuntimeError("MCTS.search: non-terminal state has no selectable action")

        a = best_a
        next_board, next_player = self.game.get_next_state(canonicalBoard, 1, a)
        next_board = self.game.get_canonical_form(next_board, next_player)
        v = self.search(next_board, device)

        # Backup: incremental mean of the child's returned value.
        key = (s, a)
        if key in self.Qsa:
            visits = self.Nsa[key]
            self.Qsa[key] = (visits * self.Qsa[key] + v) / (visits + 1)
            self.Nsa[key] = visits + 1
        else:
            self.Qsa[key] = v
            self.Nsa[key] = 1
        self.Ns[s] += 1
        return -v