AlphaZero-TicTacToe / game_logic.py
NihalGazi's picture
Create game_logic.py
2ef2e11 verified
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# --- CONFIGURATION ---
# Side length of the square board; GomokuGame uses it as its `n` and Net uses
# it to size its input reshape and fully connected heads.
BOARD_SIZE = 5
# Number of identical stones in an unbroken horizontal, vertical or diagonal
# run that GomokuGame.get_game_ended treats as a win.
WIN_LENGTH = 4
# One action per board cell (flat index = row * BOARD_SIZE + col); also the
# width of the policy head's output layer.
ACTION_SIZE = BOARD_SIZE * BOARD_SIZE
class GomokuGame:
    """Rules of an n x n "k in a row" game played by players +1 and -1.

    A board is a 2-D NumPy array in which 0 marks an empty cell and a
    non-zero entry holds the value of the player who moved there. Actions
    are flat cell indices: ``action = row * n + col``.

    Args:
        n: Side length of the board. ``None`` (the default) uses the
            module-level ``BOARD_SIZE``.
        win_length: Length of the run that wins the game. ``None`` (the
            default) uses the module-level ``WIN_LENGTH``.
    """

    def __init__(self, n=None, win_length=None):
        # The globals are resolved at call time so that the no-argument
        # constructor keeps its original behaviour.
        self.n = BOARD_SIZE if n is None else n
        self.win_length = WIN_LENGTH if win_length is None else win_length

    def get_init_board(self):
        """Return an empty (all-zero) ``n x n`` board."""
        return np.zeros((self.n, self.n))

    def get_board_size(self):
        """Return the board dimensions as ``(rows, cols)``."""
        return (self.n, self.n)

    def get_action_size(self):
        """Return the number of distinct actions (one per cell)."""
        return self.n * self.n

    def get_next_state(self, board, player, action):
        """Apply ``action`` for ``player`` on a copy of ``board``.

        The input board is not modified.

        Returns:
            A tuple ``(new_board, next_player)`` where ``next_player`` is
            ``-player``.
        """
        new_board = np.copy(board)
        row, col = divmod(action, self.n)
        new_board[row, col] = player
        return (new_board, -player)

    def get_valid_moves(self, board):
        """Return a flat 0/1 integer mask with 1 for every empty cell."""
        return (board.reshape(-1) == 0).astype(int)

    def _has_winning_line(self, board, p):
        """Return True if ``p`` owns a run of ``win_length`` cells."""
        n = self.n
        wl = self.win_length
        starts = range(n - wl + 1)

        # Horizontal runs.
        for r in range(n):
            for c in starts:
                if np.all(board[r, c:c + wl] == p):
                    return True
        # Vertical runs.
        for r in starts:
            for c in range(n):
                if np.all(board[r:r + wl, c] == p):
                    return True
        # Diagonals running down-right.
        for r in starts:
            for c in starts:
                if all(board[r + k, c + k] == p for k in range(wl)):
                    return True
        # Diagonals running down-left.
        for r in starts:
            for c in range(wl - 1, n):
                if all(board[r + k, c - k] == p for k in range(wl)):
                    return True
        return False

    def get_game_ended(self, board, player):
        """Report the game result from ``player``'s point of view.

        Returns:
            ``1`` if ``player`` has a winning run, ``-1`` if the opponent
            has one, ``1e-4`` (a small non-zero value, so it is
            distinguishable from "not finished") when the board is full
            with no winner, and ``0`` while the game is still in progress.
        """
        if self._has_winning_line(board, player):
            return 1
        if self._has_winning_line(board, -player):
            return -1
        if np.sum(board == 0) == 0:
            return 1e-4
        return 0

    def get_canonical_form(self, board, player):
        """Return the board as seen by ``player`` (own stones become +1)."""
        return player * board
class Net(nn.Module):
    """Convolutional policy/value network for a square board.

    The trunk is four 3x3 convolutions (128 channels each), each followed by
    batch normalisation and ReLU. Two heads branch off the trunk:

    * policy head: 1x1 convolution to 2 channels, then a linear layer to
      ``action_size`` logits, returned as log-probabilities;
    * value head: 1x1 convolution to 1 channel, a 64-unit hidden layer and
      a single ``tanh`` output.

    Args:
        board_size: Side length of the board. ``None`` (the default) uses
            the module-level ``BOARD_SIZE``.
        action_size: Number of policy outputs. ``None`` (the default) uses
            the module-level ``ACTION_SIZE`` when ``board_size`` is also
            ``None``, otherwise ``board_size * board_size``.
    """

    def __init__(self, board_size=None, action_size=None):
        super().__init__()
        if action_size is None:
            action_size = (
                ACTION_SIZE if board_size is None else board_size * board_size
            )
        if board_size is None:
            board_size = BOARD_SIZE
        self.board_size = board_size
        self.action_size = action_size
        cells = board_size * board_size

        # Layer names and registration order are kept stable so that
        # previously saved state dicts keep loading.
        self.conv1 = nn.Conv2d(1, 128, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(128)
        self.pi_conv = nn.Conv2d(128, 2, 1)
        self.pi_fc = nn.Linear(2 * cells, action_size)
        self.v_conv = nn.Conv2d(128, 1, 1)
        self.v_fc1 = nn.Linear(1 * cells, 64)
        self.v_fc2 = nn.Linear(64, 1)

    def forward(self, s):
        """Evaluate one board or a batch of boards.

        Args:
            s: Tensor whose element count is a multiple of
                ``board_size * board_size``; it is reshaped to
                ``(batch, 1, board_size, board_size)``.

        Returns:
            A tuple ``(pi, v)``: ``pi`` has shape ``(batch, action_size)``
            and holds log-probabilities; ``v`` has shape ``(batch, 1)`` and
            lies in ``[-1, 1]``.
        """
        size = self.board_size
        cells = size * size

        # reshape (rather than view) so flattening does not depend on the
        # memory layout of the incoming or intermediate tensors.
        s = s.reshape(-1, 1, size, size)
        s = F.relu(self.bn1(self.conv1(s)))
        s = F.relu(self.bn2(self.conv2(s)))
        s = F.relu(self.bn3(self.conv3(s)))
        s = F.relu(self.bn4(self.conv4(s)))

        # Policy head.
        pi = F.relu(self.pi_conv(s))
        pi = pi.reshape(-1, 2 * cells)
        pi = self.pi_fc(pi)
        pi = F.log_softmax(pi, dim=1)

        # Value head.
        v = F.relu(self.v_conv(s))
        v = v.reshape(-1, 1 * cells)
        v = F.relu(self.v_fc1(v))
        v = torch.tanh(self.v_fc2(v))
        return pi, v
class MCTS:
    """Monte Carlo tree search guided by a policy/value network.

    All boards handed to this class are expected in canonical form, i.e.
    from the point of view of the player to move (the search always plays
    as player ``1`` and re-canonicalises after every move).

    Args:
        game: Object providing ``get_action_size``, ``get_game_ended``,
            ``get_valid_moves``, ``get_next_state`` and
            ``get_canonical_form``.
        net: Callable returning ``(log_policy, value)`` for a board tensor
            and exposing ``eval()``.
        cpuct: Exploration constant of the UCT formula (default ``1.0``).
    """

    def __init__(self, game, net, cpuct=1.0):
        self.game = game
        self.net = net
        self.cpuct = cpuct
        self.Qsa = {}  # (state, action) -> mean backed-up value
        self.Nsa = {}  # (state, action) -> visit count of the edge
        self.Ns = {}   # state -> number of simulations that left this state
        self.Ps = {}   # state -> prior policy, masked to legal moves
        self.Es = {}   # state -> cached get_game_ended result
        self.Vs = {}   # state -> cached legal-move mask

    @staticmethod
    def _state_key(board):
        """Return a hashable key for ``board`` (its raw bytes)."""
        # tobytes() is the supported spelling of the deprecated tostring().
        return board.tobytes()

    def getActionProb(self, canonicalBoard, temp=1, sims=500, device='cpu'):
        """Run ``sims`` simulations and return the move distribution.

        Args:
            canonicalBoard: Board from the point of view of the player to
                move.
            temp: Temperature. ``0`` yields a one-hot vector on a most
                visited action (ties broken at random); otherwise visit
                counts are raised to ``1 / temp`` and normalised.
            sims: Number of simulations to run from this position.
            device: Torch device on which the network is evaluated.

        Returns:
            A list with one probability per action. If no action below the
            root was visited (for example ``sims=0`` or a finished game),
            the legal-move mask is used in place of the visit counts, or
            all actions if there is no legal move, instead of dividing by
            zero.
        """
        for _ in range(sims):
            self.search(canonicalBoard, device)

        s = self._state_key(canonicalBoard)
        n_actions = self.game.get_action_size()
        counts = [self.Nsa.get((s, a), 0) for a in range(n_actions)]

        if sum(counts) == 0:
            legal = [int(m) for m in self.game.get_valid_moves(canonicalBoard)]
            counts = legal if any(legal) else [1] * n_actions

        if temp == 0:
            counts_arr = np.asarray(counts)
            best_actions = np.flatnonzero(counts_arr == counts_arr.max())
            best_action = np.random.choice(best_actions)
            probs = [0] * n_actions
            probs[best_action] = 1
            return probs

        weights = [c ** (1. / temp) for c in counts]
        total = float(sum(weights))
        return [w / total for w in weights]

    def search(self, canonicalBoard, device):
        """Run one simulation from ``canonicalBoard``.

        Descends by maximum upper confidence bound until it reaches a
        finished game or a state without a cached prior, then backs the
        result up along the visited edges.

        Returns:
            The negated value of ``canonicalBoard`` for its player to move,
            i.e. the value from the point of view of the player who moved
            into this position.
        """
        s = self._state_key(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.get_game_ended(canonicalBoard, 1)
        if self.Es[s] != 0:
            return -self.Es[s]

        if s not in self.Ps:
            return -self._expand(s, canonicalBoard, device)

        a = self._select_action(s)
        next_s, next_player = self.game.get_next_state(canonicalBoard, 1, a)
        next_s = self.game.get_canonical_form(next_s, next_player)
        v = self.search(next_s, device)

        edge = (s, a)
        if edge in self.Qsa:
            visits = self.Nsa[edge]
            # Incremental update of the running mean.
            self.Qsa[edge] = (visits * self.Qsa[edge] + v) / (visits + 1)
            self.Nsa[edge] = visits + 1
        else:
            self.Qsa[edge] = v
            self.Nsa[edge] = 1
        self.Ns[s] += 1
        return -v

    def _expand(self, s, canonicalBoard, device):
        """Evaluate a new leaf with the network and return its value."""
        # ascontiguousarray copies/converts like the original astype() did.
        board_tensor = torch.from_numpy(
            np.ascontiguousarray(canonicalBoard, dtype=np.float32)
        ).to(device)
        self.net.eval()
        with torch.no_grad():
            pi, v = self.net(board_tensor)

        valid_moves = self.game.get_valid_moves(canonicalBoard)
        prior = torch.exp(pi).detach().cpu().numpy()[0] * valid_moves
        prior_sum = np.sum(prior)
        if prior_sum > 0:
            prior = prior / prior_sum
        else:
            # The network put all its mass on illegal moves: fall back to a
            # uniform prior over the legal ones.
            prior = valid_moves / np.sum(valid_moves)

        self.Ps[s] = prior
        self.Vs[s] = valid_moves
        self.Ns[s] = 0
        return v.item()

    def _select_action(self, s):
        """Return the legal action with the highest upper confidence bound."""
        valid_moves = self.Vs[s]
        prior = self.Ps[s]
        visits = self.Ns[s]
        cpuct = self.cpuct

        best_uct = -float('inf')
        best_a = -1
        for a in range(self.game.get_action_size()):
            if not valid_moves[a]:
                continue
            edge = (s, a)
            if edge in self.Qsa:
                u = self.Qsa[edge] + cpuct * prior[a] * math.sqrt(visits) / (1 + self.Nsa[edge])
            else:
                # The epsilon keeps the prior relevant before any visit.
                u = cpuct * prior[a] * math.sqrt(visits + 1e-8)
            if u > best_uct:
                best_uct = u
                best_a = a
        return best_a