File size: 1,454 Bytes
1fab54b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Iterator, Tuple

import numpy as np

class Buffer:
    """Growable in-memory store of (state, value, policy) samples.

    Samples are appended one at a time with :meth:`store_experience` and read
    back as shuffled mini-batches with :meth:`sample`. The three internal
    lists are index-aligned: entry ``i`` of ``state``, ``value`` and
    ``policy`` all belong to the same stored experience.
    """

    def __init__(self, n_action: int, obs_shape: Tuple[int, int, int]):
        """Create an empty buffer.

        Args:
            n_action: Recorded on the instance; not read by this class.
            obs_shape: Recorded on the instance; not read by this class.
        """
        self.n_action = n_action
        self.obs_shape = obs_shape
        self.mem_size = 0

        # Plain Python lists so the buffer can grow without a preset capacity.
        self.state = []
        self.value = []
        self.policy = []

    def store_experience(self, state: np.ndarray, value: float, policy: np.ndarray):
        """Append one experience to the end of the buffer.

        Args:
            state: Observation to store.
            value: Value associated with ``state``.
            policy: Policy associated with ``state``.
        """
        self.state.append(state)
        self.value.append(value)
        self.policy.append(policy)

        self.mem_size += 1

    def sample(
        self, batch_size: int
    ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Yield every stored experience once, in shuffled mini-batches.

        A single random permutation is applied to all three lists, so row
        ``k`` of each yielded array refers to the same experience. The stored
        order of the buffer itself is left unchanged.

        Args:
            batch_size: Maximum number of experiences per batch. The final
                batch is smaller when the buffer size is not a multiple of
                ``batch_size``.

        Yields:
            ``(states, values, policies)`` as NumPy arrays whose first
            dimension is the batch length.

        Raises:
            ValueError: If ``batch_size`` is less than 1. Because this is a
                generator, the error surfaces on the first iteration.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        # One shared permutation keeps state/value/policy aligned; shuffling
        # each list independently would pair unrelated entries.
        order = np.random.permutation(self.mem_size)

        for start_idx in range(0, self.mem_size, batch_size):
            # Slicing clamps at the end, so the last batch may be short.
            batch_idx = order[start_idx:start_idx + batch_size]

            yield (
                np.array([self.state[i] for i in batch_idx]),
                np.array([self.value[i] for i in batch_idx]),
                np.array([self.policy[i] for i in batch_idx]),
            )

    def reset(self) -> None:
        """Discard all stored experiences so the buffer can be refilled."""
        self.state = []
        self.value = []
        self.policy = []

        self.mem_size = 0

    def __len__(self):
        """Return the number of experiences currently stored."""
        return self.mem_size