from model import Model
from typing import Union, Tuple
from game import Connect4
from config import Config
import torch
from torch import Tensor
import numpy as np
class Node:
    def __init__(self, state: Union[Connect4, None], model: Model, name: str):
        # Current state that the node represents
        self.state = state
        # Name of the node, used to trace it
        self.name = name
        # A model instance that the node will use to get value and policy
        self.model = model
        # Visit count
        self.N = 0
        # Intermediate reward value
        self.W = 0
        # Value of the node
        self.value = None
        # Prior policy for actions from this node
        self.policy = None
        # Winner at the current node; None by default, indicating no one has won
        self.win = None
        # Children of the current node
        self.children = {}
        # Valid and invalid actions that can be taken from this node
        self.valid_actions = None
        self.invalid_actions = None
        # Set the valid and invalid actions
        self.set_valid_actions()
        # Initialize the edges to the children
        self.initialize_edges()
    # Set the valid actions that can be taken from the state that
    # the node represents
    def set_valid_actions(self) -> None:
        if self.state is not None:
            self.valid_actions = self.state.get_valid_moves()
            # Invalid actions are just the boolean complement of the valid ones
            self.invalid_actions = ~self.valid_actions
    # Initialize the edges from this node to its potential children
    def initialize_edges(self) -> None:
        if self.state is not None:
            self.children = {}
            for act, valid_move in enumerate(self.valid_actions):
                if valid_move:
                    # Children start with state=None, since their states are
                    # not known until they are expanded
                    self.children[act] = Node(
                        state=None,
                        model=self.model,
                        name=self.name + '_' + str(act)
                    )
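    # In other words, expansion is lazy: a child Node object exists for every
    # valid action, but its state stays None until explore_and_expand fills
    # it in during search.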
    # Convert the raw board array into a float tensor on the configured
    # device and add a batch dimension, since the model expects batched input
    def preprocess_state(self, x: np.ndarray) -> Tensor:
        x = torch.tensor(x, dtype=torch.float32, device=Config.device)
        x = x.unsqueeze(0)
        return x
    # Define the forward pass for the current node
    def forward(self) -> None:
        with torch.no_grad():
            value, policy = self.model(self.preprocess_state(self.state.get_state()))
            value = value[0, 0]
            policy = policy[0]
            # Mask the invalid actions
            policy[self.invalid_actions] = 0.
            # Prevent all probabilities from turning 0
            if policy.sum() == 0:
                policy[self.valid_actions] = 1.
            # Renormalize so the masked policy sums to 1. A softmax here
            # would leak probability back into the masked invalid actions,
            # since softmax maps zero entries to non-zero probabilities.
            policy = policy / policy.sum()
            self.value = value.item()
            self.policy = policy.cpu().numpy()
    # Get the policy for the current node
    def get_policy(self) -> np.ndarray:
        if self.policy is None:
            self.forward()
        return self.policy
    # Get the value associated with the node
    def get_value(self) -> float:
        if self.value is None:
            self.forward()
        return self.value
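# A quick sketch of evaluating a Node in isolation (hypothetical driver; the
# zero-argument `Connect4()` and `Model()` constructors are assumptions):
#
#   node = Node(state=Connect4(), model=Model().to(Config.device), name='root')
#   p = node.get_policy()  # masked, renormalized prior over the columns
#   v = node.get_value()   # scalar value estimate of the position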
class MCTS_NN:
    def __init__(self, state: Connect4, model: Model, log=None):
        self.root = Node(state=state, model=model, name='root')
        # Store the log handle unconditionally so self.log always exists
        self.log = log
    # Run one simulation (the selection phase) down the Monte Carlo tree
    def selection(self, node: Node, add_dirichlet: bool = False, iter: int = 0) -> float:
        # Get the best child of the current node
        # self.log.write(f'\nSelecting best child of {node.name}')
        best_child, best_action = self.get_best_child(node, add_dirichlet, iter)
        # self.log.write(f"Iteration {iter} - Best Action - {best_action} - Node: {node.name}")
        # If the child is a leaf node (i.e. it is either terminal or not yet
        # expanded), expand that node
        if best_child.state is None:
            # self.log.write(f'\nExpanding node {best_child.name}')
            val = self.explore_and_expand(parent=node, child=best_child, action=best_action, iter=iter)
        # If the node is already expanded, then traverse it further
        else:
            # As per the paper, only add Dirichlet noise for the root node's
            # child selection and not deeper in the tree
            # self.log.write(f'\nSelecting further down from {best_child.name}')
            val = self.selection(node=best_child, add_dirichlet=False, iter=iter)
        node.N += 1
        node.W += val
        return -val
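    # Note: selection() implements one full MCTS simulation. It walks down
    # the tree by PUCT score, expands and evaluates a leaf, and then backs
    # the value up as the recursion unwinds, flipping its sign at every
    # level because the players alternate between consecutive nodes.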
    # Explore and expand the tree
    def explore_and_expand(self, parent: Node, child: Node, action: int, iter=0) -> float:
        # self.log.write(f'\n<========== Explore or Expand Iteration {iter} ==========>')
        # Check whether the child already represents a known terminal state
        if child.win is None:
            # It is not expanded and not known to be terminal.
            # Perform the action on the parent state to get the next state
            next_state, win = parent.state.drop_piece(action)
            # First check if someone won in this next state
            if win is not None:
                val = -1 if win == parent.state.player_1 else 1
                child.win = win
                # self.log.write(f'\nPlayer turn for child is {next_state.player_1} | [Winner Found]')
                # self.log.write(f'\nWinner in that state {win} - child value is {val}')
                # self.log.write(f'\nWinning child in state {child.name}: state\n{next_state}\n')
                # self.log.write('='*100)
                # self.log.write('\n')
            # Otherwise check if the next state results in a draw
            elif next_state.is_draw():
                # Value 0 if no one has won in the state
                val = 0
                # win == 0 means no one won
                child.win = 0
                # self.log.write(f'\nPlayer turn for child is {next_state.player_1}')
                # self.log.write(f'\nDraw child in state {child.name}: state\n{next_state}\n')
                # self.log.write('='*100)
                # self.log.write('\n')
            # If the next state is neither a win nor a draw,
            # expand it normally
            else:
                # If no one has won yet, get the value for the new
                # state from the child's model and set it
                child.state = next_state
                child.set_valid_actions()
                child.initialize_edges()
                val = child.get_value()
                # self.log.write(f'\nPlayer turn for child is {next_state.player_1} | [No Winner]')
                # self.log.write(f'\nLeaf node expanded for "{child.name}" with val {val:.5f}\n')
                # self.log.write('='*100)
                # self.log.write('\n')
        else:
            # If the child represents a draw state, give it value 0
            if child.win == 0:
                # self.log.write(f'\nTerminal DRAW state reached for child {child.name}\n')
                # self.log.write('='*100)
                # self.log.write('\n')
                val = 0
            # If the winner at the child node is the player who moved in the
            # parent node, set the value to -1, meaning the player to move at
            # the child node has lost
            elif child.win == parent.state.player_1:
                # self.log.write(f'\nTerminal parent-winning state reached for child {child.name}\n')
                # self.log.write('='*100)
                # self.log.write('\n')
                val = -1
            # If the winner at the child node is the child node's own player,
            # give it value +1
            else:
                # self.log.write(f'\nTerminal child-winning state reached for child {child.name}\n')
                # self.log.write('='*100)
                # self.log.write('\n')
                val = 1
        # Update the visit count and intermediate reward of the child node
        child.N += 1
        child.W += val
        # Return the negative of val because the player at the parent node is
        # the opposite of the player at the current node; what is good for the
        # current node's player is bad for the parent node's player
        return -val
    # Calculate the PUCT score for one of a node's children
    def get_puct_score(self, parent: Node, child: Node, prior: float) -> float:
        # PUCT is the child's Q value plus the exploration term U(s, a)
        if child.N == 0:
            q_value = 0
        else:
            # Negated because child.W is accumulated from the child's
            # perspective, while the score is from the parent's perspective
            # q_value = 1 - ((child.W/child.N) + 1)/2
            q_value = -child.W / child.N
        # c_puct is the exploration constant
        c_puct = 1
        u_sa = c_puct * prior * np.sqrt(parent.N) / (1 + child.N)
        return q_value + u_sa
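    # For reference, this is the AlphaZero-style PUCT rule:
    #     PUCT(s, a) = Q(s, a) + c_puct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))
    # where Q(s, a) is the mean backed-up value of the child from the
    # parent's perspective and P(s, a) is the prior from the policy head.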
    # Sample Dirichlet noise over the valid actions of a node
    def get_dirichlet_noise(self, node: Node) -> np.ndarray:
        num_valid_actions = int(node.valid_actions.sum())
        noise_vec = np.random.dirichlet([Config.DIRICHLET_ALPHA] * num_valid_actions)
        # Scatter the noise back into a full-width action vector, leaving
        # zeros at the invalid actions
        noise_arr = np.zeros((len(node.valid_actions),), dtype=noise_vec.dtype)
        noise_arr[node.valid_actions] = noise_vec
        return noise_arr
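    # The noise is mixed into the prior in get_best_child as
    #     P'(s, a) = (1 - epsilon) * P(s, a) + epsilon * eta_a,  eta ~ Dir(alpha)
    # so the root keeps exploring even when the prior is very concentrated.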
    # Get the best child of a node
    def get_best_child(self, node: Node, add_dirichlet: bool, iter=0) -> Tuple[Node, int]:
        # The best child is simply the one with the highest PUCT score
        policy = node.get_policy()
        if add_dirichlet:
            noise_arr = self.get_dirichlet_noise(node)
            policy = (1 - Config.EPSILON) * policy + Config.EPSILON * noise_arr
        best_puct = float('-inf')
        best_child = None
        best_action = None
        # self.log.write(f'\n\n==================== Iteration {iter} ====================\n')
        for action, child in node.children.items():
            puct = self.get_puct_score(parent=node, child=child, prior=policy[action])
            # self.log.write(f'{action} - PUCT: {puct:.4f} | N = {child.N} | W = {child.W:.4f} | P = {policy[action]:.4f}\n')
            if puct > best_puct:
                best_puct = puct
                best_child = child
                best_action = action
        return best_child, best_action
    # Return the policy pi for the root node based on the visit counts:
    # pi(a) is proportional to N(a)^(1/temperature), so a low temperature
    # sharpens the policy towards the most-visited action
    def get_policy_pie(self, temperature: float = 1) -> np.ndarray:
        actions = np.zeros((len(self.root.valid_actions),))
        for action, child in self.root.children.items():
            actions[action] = child.N ** (1 / temperature)
        actions /= actions.sum()
        return actions
    # Advance the tree by stepping to one of the root node's children
    def update_root(self, action: int) -> None:
        self.root = self.root.children[action]
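# A minimal single-move driver sketch (hypothetical; the zero-argument
# `Connect4()` and `Model()` constructors and the simulation budget of 100
# are assumptions, not part of this module):
if __name__ == '__main__':
    game = Connect4()                         # assumed constructor signature
    model = Model().to(Config.device).eval()  # assumed constructor signature
    mcts = MCTS_NN(state=game, model=model)
    # Run a fixed budget of simulations from the root; Dirichlet noise is
    # added only for the root's child selection, as noted in selection()
    for i in range(100):
        mcts.selection(mcts.root, add_dirichlet=True, iter=i)
    # Visit counts at the root define the play policy pi
    pi = mcts.get_policy_pie(temperature=1)
    action = int(np.argmax(pi))
    mcts.update_root(action)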