from model import Model
from buffer import Buffer
from game import Connect4
from mcts import MCTS_NN

import numpy as np
from typing import Tuple, List

class Agent:
    def __init__(self, row: int, col: int, n_action: int, obs_shape: Tuple[int, int, int],
                 model: Model, iteration: int, temperature: float):

        self.row = row
        self.col = col
        self.n_action = n_action
        self.obs_shape = obs_shape
        self.iteration = iteration
        self.temperature = temperature

        # Create buffer instance
        self.buffer = Buffer(n_action=self.n_action, obs_shape=self.obs_shape)

        # Target model instance
        self.target_model = model

    # Reset the MCTS instance and, optionally, the buffer
    def reset(self, state: Connect4, reset_buffer: bool = False) -> None:
        # Reset the state of the Monte Carlo tree search instance
        self.mcts = MCTS_NN(state=state, model=self.target_model)

        if reset_buffer:
            self.buffer.reset()

    # Reset the buffer
    def reset_buffer(self) -> None:
        self.buffer.reset()

    # Run MCTS simulations and return the resulting policy
    def perform_mcts(self) -> np.ndarray:
        for _ in range(self.iteration):
            self.mcts.selection(self.mcts.root, add_dirichlet=True)

        policy = self.mcts.get_policy_pie(self.temperature)

        return policy

    # Sample an action for the current state from the MCTS policy
    def get_action(self) -> Tuple[int, np.ndarray]:
        policy = self.perform_mcts()
        action = np.random.choice(self.n_action, p=policy)
        return action, policy

    # Back-fill the values of a finished episode and store it in the buffer
    def update_buffer(self, episodic_buffer: List) -> None:
        # Start from the last index of the episodic buffer
        idx = len(episodic_buffer) - 1

        # The last state always has value 1, as it contains the winning move
        value = 1
        while idx >= 0:
            episodic_buffer[idx][1] = value
            value *= -1  # The value is negated for the parent (opponent's turn)
            idx -= 1  # Go to the previous experience tuple

        for state, value, policy in episodic_buffer:
            self.buffer.store_experience(
                state=state,
                value=value,
                policy=policy
            )

    # Update the root, setting it to one of its child nodes
    # based on the action taken in `get_action()` above
    def update(self, action: int) -> None:
        self.mcts.update_root(action)
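
# ----------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not used by the class above):
# `update_buffer` back-fills episode values by assigning +1 to the
# terminal state (it contains the winning move) and alternating the sign
# toward the root, since the two players alternate turns. The same logic
# on a bare list of [state, value, policy] entries:
def _backfill_values_demo(episodic: List[List]) -> List[int]:
    value = 1
    idx = len(episodic) - 1
    while idx >= 0:
        episodic[idx][1] = value
        value *= -1
        idx -= 1
    return [entry[1] for entry in episodic]

# For a three-move episode the values come out as [1, -1, 1]:
# _backfill_values_demo([["s0", None, "p0"], ["s1", None, "p1"], ["s2", None, "p2"]])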