# connect-4-API / agent.py
# Author: Gruhit Patel
# Commit 1fab54b ("init-backend")
from model import Model
from buffer import Buffer
from game import Connect4
from mcts import MCTS_NN
import numpy as np
from typing import Tuple, List
class Agent:
    """Self-play agent that chooses moves with a model-guided MCTS.

    The agent owns an experience ``Buffer`` and an ``MCTS_NN`` search tree.
    ``self.mcts`` is only created by :meth:`reset`, so call :meth:`reset`
    with the starting game state before :meth:`perform_mcts`,
    :meth:`get_action` or :meth:`update`.
    """

    def __init__(self, row: int, col: int, n_action: int,
                 obs_shape: Tuple[int, int, int], model: Model,
                 iteration: int, temperature: float):
        """Store the configuration and create the experience buffer.

        Args:
            row: Board row count (stored on the instance).
            col: Board column count (stored on the instance).
            n_action: Size of the action space; actions are sampled from
                ``range(n_action)``.
            obs_shape: Observation shape forwarded to ``Buffer``.
            model: Model handed to every ``MCTS_NN`` built by :meth:`reset`.
            iteration: Number of MCTS selection passes run per move.
            temperature: Temperature passed to the MCTS policy computation.
        """
        self.row = row
        self.col = col
        self.n_action = n_action
        self.obs_shape = obs_shape
        self.iteration = iteration
        self.temperature = temperature
        # Experience buffer that episodes are appended to via update_buffer().
        self.buffer = Buffer(n_action=self.n_action, obs_shape=self.obs_shape)
        # Model used by the search tree.
        self.target_model = model

    def reset(self, state: Connect4, reset_buffer: bool = False) -> None:
        """Build a fresh search tree rooted at ``state``.

        Args:
            state: Game position the new ``MCTS_NN`` tree starts from.
            reset_buffer: If true, also clear the experience buffer.
                (Previously this flag was accepted but ignored.)
        """
        self.mcts = MCTS_NN(state=state, model=self.target_model)
        if reset_buffer:
            self.reset_buffer()

    def reset_buffer(self) -> None:
        """Clear the experience buffer."""
        self.buffer.reset()

    def perform_mcts(self) -> np.ndarray:
        """Run ``self.iteration`` selection passes and return the move policy.

        Every pass starts at the current root and is called with
        ``add_dirichlet=True``. The returned policy comes from
        ``get_policy_pie(self.temperature)`` on the search tree.
        """
        for _ in range(self.iteration):
            self.mcts.selection(self.mcts.root, add_dirichlet=True)
        return self.mcts.get_policy_pie(self.temperature)

    def get_action(self) -> Tuple[int, np.ndarray]:
        """Search from the current root, then sample a move from the policy.

        Returns:
            ``(action, policy)``: ``action`` is drawn from ``range(n_action)``
            with probabilities ``policy``, and ``policy`` is the array
            returned by :meth:`perform_mcts`.
        """
        policy = self.perform_mcts()
        # np.random.choice yields a numpy integer; cast to a plain int to
        # match the annotation.
        action = int(np.random.choice(self.n_action, p=policy))
        return action, policy

    def update_buffer(self, episodic_buffer: List,
                      final_value: float = 1) -> None:
        """Assign outcome values to one episode and store it in the buffer.

        Each entry of ``episodic_buffer`` must be a mutable
        ``[state, value, policy]`` sequence; its ``value`` slot is overwritten
        in place. The last entry receives ``final_value``. Walking backwards,
        the sign flips at every step, because each earlier entry is valued
        from the opposing player's side. The entries are then stored in
        their original order.

        Args:
            episodic_buffer: Experiences of a single episode, oldest first.
            final_value: Value given to the last entry. The default ``1``
                treats the last move as the winning move. Pass ``0`` for a
                drawn game so every entry gets value 0.
        """
        value = final_value
        for entry in reversed(episodic_buffer):
            entry[1] = value
            value = -value  # parent position is valued from the other side

        for state, target, policy in episodic_buffer:
            self.buffer.store_experience(
                state=state,
                value=target,
                policy=policy,
            )

    def update(self, action: int) -> None:
        """Re-root the search tree at the child reached by ``action``.

        ``action`` is normally the move returned by :meth:`get_action`
        (presumably so the already-explored subtree is reused — confirm in
        ``MCTS_NN.update_root``).
        """
        self.mcts.update_root(action)