# connect-4-API / mcts.py
# Author: Gruhit Patel
# Branch: init-backend (commit 1fab54b, verified)
from model import Model
from typing import Union, Tuple
from game import Connect4
from config import Config
import torch
from torch import Tensor
import numpy as np
class Node:
    """A single node of the neural-network-guided MCTS search tree.

    Each node lazily queries ``model`` for a value estimate and a prior
    policy over actions the first time they are needed, and keeps the
    standard MCTS statistics (visit count ``N`` and accumulated value ``W``).
    """

    def __init__(self, state: Union[Connect4, None], model: Model, name: str):
        # Game state this node represents; None until the node is expanded.
        self.state = state
        # Human-readable path name ("root_3_1"...) used to trace the node.
        self.name = name
        # Model instance used to produce this node's value and prior policy.
        self.model = model
        # Visit count N(s).
        self.N = 0
        # Accumulated backed-up value W(s).
        self.W = 0
        # Network value estimate; computed lazily by forward().
        self.value = None
        # Prior policy over actions; computed lazily by forward().
        self.policy = None
        # Winner recorded for terminal nodes.
        # None by default, indicating no one has won (non-terminal/unknown).
        self.win = None
        # Child nodes keyed by action (column) index.
        self.children = {}
        # Boolean masks of valid / invalid actions from this node's state.
        self.valid_actions = None
        self.invalid_actions = None
        # Set the valid and invalid actions.
        self.set_valid_actions()
        # Initialize the edges to the (unexpanded) children.
        self.initialize_edges()

    def set_valid_actions(self) -> None:
        """Cache the valid/invalid action masks for this node's state.

        No-op while the node is unexpanded (``state is None``).
        """
        if self.state is not None:
            self.valid_actions = self.state.get_valid_moves()
            # Boolean complement of the valid-move mask.
            self.invalid_actions = ~self.valid_actions

    def initialize_edges(self) -> None:
        """Create one placeholder child per valid action.

        Children start with ``state=None``; their states are only computed
        when the search actually expands them.
        """
        if self.state is not None:
            self.children = {}
            for act, valid_move in enumerate(self.valid_actions):
                if valid_move:
                    # State is None for children as we do not have it yet.
                    self.children[act] = Node(
                        state=None,
                        model=self.model,
                        name=self.name + '_' + str(act)
                    )

    def preprocess_state(self, x: np.ndarray) -> Tensor:
        """Convert a raw board array into a batched float tensor on Config.device."""
        x = torch.tensor(x, dtype=torch.float32, device=Config.device)
        # Add a leading batch dimension of size 1.
        x = x.unsqueeze(0)
        return x

    def forward(self) -> None:
        """Run one network evaluation and cache the masked value/policy.

        Stores ``self.value`` (scalar estimate for the current player) and
        ``self.policy`` (prior probabilities with invalid actions at zero).
        """
        with torch.no_grad():
            value, policy = self.model(self.preprocess_state(self.state.get_state()))
            value = value[0, 0]
            policy = policy[0]
            # Mask the invalid actions.
            policy[self.invalid_actions] = 0.
            # Prevent all probability mass from turning 0: fall back to a
            # uniform prior over the valid actions.
            if policy.sum() == 0:
                policy[self.valid_actions] = 1.
            # BUG FIX: the original applied softmax here, which re-assigns
            # non-zero probability to the masked actions (softmax(0) > 0)
            # and distorts the valid priors.  Renormalizing keeps invalid
            # actions at exactly zero while preserving relative priors.
            # (The sum==0 fallback above implies the model output is
            # non-negative, so plain renormalization is well-defined.)
            policy = policy / policy.sum()
        self.value = value.detach().cpu().numpy()
        self.policy = policy.detach().cpu().numpy()

    def get_policy(self) -> np.ndarray:
        """Return the (lazily computed) prior policy for this node."""
        if self.policy is None:
            self.forward()
        return self.policy

    def get_value(self) -> float:
        """Return the (lazily computed) value estimate for this node."""
        if self.value is None:
            self.forward()
        return self.value
class MCTS_NN:
    """Monte-Carlo Tree Search guided by a neural network (AlphaZero-style).

    The tree is rooted at the supplied game state; repeated calls to
    :meth:`selection` grow the tree, and :meth:`get_policy_pie` turns the
    root's visit counts into a move distribution.
    """

    def __init__(self, state: Connect4, model: Model, log=None):
        self.root = Node(state=state, model=model, name='root')
        # BUG FIX: the original only assigned self.log when log was not
        # None, leaving the attribute undefined otherwise (a latent
        # AttributeError for any code that reads it).  Always assign it.
        self.log = log

    def selection(self, node: Node, add_dirichlet: bool = False, iter: int = 0) -> float:
        """Run one simulation step from ``node`` and back up the result.

        Returns the negated value so the parent (the opposing player)
        receives the reward from its own perspective.
        """
        # Get the best child of the current node by PUCT score.
        best_child, best_action = self.get_best_child(node, add_dirichlet, iter)
        # If the child is a leaf node (i.e. terminal or not yet expanded),
        # expand/evaluate it; otherwise keep descending the tree.
        if best_child.state is None:
            val = self.expolore_and_expand(parent=node, child=best_child, action=best_action, iter=iter)
        else:
            # As per the AlphaZero paper, Dirichlet noise is only added for
            # the root node's child selection, never deeper in the tree.
            val = self.selection(node=best_child, add_dirichlet=False, iter=iter)
        # Back up the simulation result into this node's statistics.
        node.N += 1
        node.W += val
        return -val

    # NOTE(review): method name is a typo for "explore_and_expand" but is
    # kept as-is to avoid breaking existing callers.
    def expolore_and_expand(self, parent: Node, child: Node, action: int, iter=0) -> float:
        """Evaluate a leaf: play ``action`` from ``parent``, detect terminal
        outcomes, otherwise expand ``child`` and score it with the network.

        Returns the negated value from the child's perspective (what is
        good for the child's player is bad for the parent's player).
        """
        # Check whether this leaf has already been marked terminal.
        if child.win is None:
            # Not expanded and not known-terminal: perform the action on the
            # parent state to obtain the next state.
            next_state, win = parent.state.drop_piece(action)
            # First check if someone won in this next state.
            if win is not None:
                # -1 if the parent's player won (the child's player lost).
                val = -1 if win == parent.state.player_1 else 1
                child.win = win
            # Else check if the next state results in a draw.
            elif next_state.is_draw():
                # 0 value and win == 0 both encode "no one won".
                val = 0
                child.win = 0
            # Neither a win nor a draw: expand the child normally.
            else:
                child.state = next_state
                child.set_valid_actions()
                child.initialize_edges()
                # Value for the freshly expanded state comes from the network.
                val = child.get_value()
        else:
            # Terminal node revisited: return the stored outcome.
            if child.win == 0:
                # Draw.
                val = 0
            elif child.win == parent.state.player_1:
                # The winner is the player who moved at the parent, i.e. the
                # child's player lost.
                val = -1
            else:
                # The child's own player won.
                val = 1
        # Update the visit count and accumulated reward of the child node.
        child.N += 1
        child.W += val
        # Negate: the parent's player is the opponent of the child's player.
        return -val

    def get_puct_score(self, parent: Node, child: Node, prior: float) -> float:
        """Return the PUCT score Q(s, a) + U(s, a) for one parent→child edge."""
        # Q is the negated mean child value (parent and child alternate
        # players); 0 for unvisited children.
        if child.N == 0:
            q_value = 0
        else:
            q_value = -child.W / child.N
        # c_puct is the exploration constant.
        c_puct = 1
        u_sa = c_puct * prior * (np.sqrt(parent.N)) / (1 + child.N)
        return q_value + u_sa

    def get_dirichlet_noise(self, node: Node) -> np.ndarray:
        """Sample Dirichlet noise over the valid actions of ``node``.

        Invalid actions get exactly zero noise.
        """
        num_valid_action = node.valid_actions.sum()
        noise_vec = np.random.dirichlet([Config.DIRICHLET_ALPHA] * num_valid_action)
        # Scatter the sampled noise back into full action-space positions.
        noise_arr = np.zeros((len(node.valid_actions),), dtype=noise_vec.dtype)
        noise_arr[node.valid_actions] = noise_vec
        return noise_arr

    def get_best_child(self, node: Node, add_dirichlet: bool, iter=0) -> Tuple[Node, int]:
        """Return the (child, action) pair with the highest PUCT score."""
        policy = node.get_policy()
        if add_dirichlet:
            # Blend the prior with Dirichlet noise (root exploration).
            noise_arr = self.get_dirichlet_noise(node)
            policy = (1 - Config.EPSILON) * policy + Config.EPSILON * noise_arr
        best_puct = float('-inf')
        best_child = None
        best_action = None
        for action, child in node.children.items():
            puct = self.get_puct_score(parent=node, child=child, prior=policy[action])
            if puct > best_puct:
                best_puct = puct
                best_child = child
                best_action = action
        return best_child, best_action

    def get_policy_pie(self, temperature: float = 1) -> np.ndarray:
        """Return the search policy π for the root based on visit counts.

        Lower ``temperature`` sharpens the distribution; must be > 0.
        Assumes at least one simulation has been run (otherwise the
        normalization divides by zero).
        """
        actions = np.zeros((len(self.root.valid_actions),))
        for action, child in self.root.children.items():
            actions[action] = (child.N) ** (1 / temperature)
        actions /= actions.sum()
        return actions

    def update_root(self, action: int) -> None:
        """Step the tree forward by making the chosen child the new root."""
        self.root = self.root.children[action]