from typing import List, Tuple

import numpy as np

from buffer import Buffer
from game import Connect4
from mcts import MCTS_NN
from model import Model


class Agent:
    def __init__(
        self,
        row: int,
        col: int,
        n_action: int,
        obs_shape: Tuple[int, int, int],
        model: Model,
        iteration: int,
        temperature: float,
    ):
        self.row = row
        self.col = col
        self.n_action = n_action
        self.obs_shape = obs_shape
        self.iteration = iteration
        self.temperature = temperature

        # Create the buffer instance
        self.buffer = Buffer(n_action=self.n_action, obs_shape=self.obs_shape)

        # Target model instance used to guide the tree search
        self.target_model = model

    # Reset the MCTS class instance and, optionally, the buffer
    def reset(self, state: Connect4, reset_buffer: bool = False) -> None:
        # Rebuild the Monte Carlo tree search instance from the given state
        self.mcts = MCTS_NN(state=state, model=self.target_model)
        if reset_buffer:
            self.reset_buffer()

    # Reset the buffer
    def reset_buffer(self) -> None:
        self.buffer.reset()

    # Run the MCTS simulations and return the resulting policy
    def perform_mcts(self) -> np.ndarray:
        for _ in range(self.iteration):
            self.mcts.selection(self.mcts.root, add_dirichlet=True)
        policy = self.mcts.get_policy_pie(self.temperature)
        return policy

    # Sample an action for the current state from the MCTS policy
    def get_action(self) -> Tuple[int, np.ndarray]:
        policy = self.perform_mcts()
        action = np.random.choice(self.n_action, p=policy)
        return action, policy

    # This method assigns values to the episode's experiences
    # and sends them to the buffer object
    def update_buffer(self, episodic_buffer: List) -> None:
        # Get the last index of the episodic buffer
        idx = len(episodic_buffer) - 1
        # The last state always gets value 1, as it holds the winning move
        value = 1
        while idx >= 0:
            episodic_buffer[idx][1] = value
            value *= -1  # For the parent, the value is negated
            idx -= 1  # Go to the previous experience tuple
        for state, value, policy in episodic_buffer:
            self.buffer.store_experience(
                state=state,
                value=value,
                policy=policy,
            )

    # Update the root to set it to one of its child nodes,
    # based on the action taken in `get_action()` above
    def update(self, action: int) -> None:
        self.mcts.update_root(action)
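
# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal self-play episode showing how the methods above compose:
# get_action() drives MCTS, update() advances the root, and update_buffer()
# backfills the alternating +1/-1 values once the game ends. The Connect4
# methods used here (`get_obs`, `step`) and the no-arg Model/Connect4
# constructors are assumptions; substitute the real interfaces.
if __name__ == "__main__":
    game = Connect4()  # assumed no-arg constructor
    model = Model()  # assumed no-arg constructor
    agent = Agent(
        row=6,
        col=7,
        n_action=7,
        obs_shape=(2, 6, 7),  # assumed (planes, rows, cols) layout
        model=model,
        iteration=100,
        temperature=1.0,
    )
    agent.reset(state=game)

    episodic_buffer: List = []
    done = False
    while not done:
        action, policy = agent.get_action()
        # The value slot (index 1) is backfilled later by update_buffer()
        episodic_buffer.append([game.get_obs(), 0, policy])
        done = game.step(action)  # assumed to return True when the game ends
        agent.update(action)  # move the MCTS root to the chosen child

    agent.update_buffer(episodic_buffer)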