File size: 5,664 Bytes
ce3feed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Author: Minh Pham-Dinh
Created: Jan 26th, 2024
Last Modified: Feb 5th, 2024
Email: mhpham26@colby.edu

Description:
    File containing the ReplayBuffer that will be used in Dreamer.
    
    The implementation is based on:
    Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019. 
    [Online]. Available: https://arxiv.org/abs/1912.01603
"""

import numpy as np
from gymnasium import Env
import torch
from addict import Dict

class ReplayBuffer:
    def __init__(self, capacity, obs_size, action_size):
        
        # check if the env is gymnasium or dmc
        self.obs_size = obs_size
        self.action_size = action_size

        # from SimpleDreamer implementation, saving memory
        state_type = np.uint8 if len(self.obs_size) < 3 else np.float32
        
        self.observation = np.zeros((capacity, ) + self.obs_size, dtype=state_type)
        
        self.actions = np.zeros((capacity, ) + self.action_size, dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.dones = np.zeros((capacity, 1), dtype=np.float32)

        self.pointer = 0
        self.full = False
    
        print(f'''
-----------initialized memory----------              

obs_buffer_shape: {self.observation.shape}
actions_buffer_shape: {self.actions.shape}
rewards_buffer_shape: {self.rewards.shape}
dones_buffer_shape: {self.dones.shape}

----------------------------------------
              ''')

    def add(self, obs, action, reward, done):
        """Add method for buffer

        Args:
            obs (np.array): current observation
            action (np.array): action taken
            reward (float): reward received after action
            next_obs (np.array): next observation
            done (bool): boolean value of termination or truncation
        """
        self.observation[self.pointer] = obs
        self.actions[self.pointer] = action
        self.rewards[self.pointer] = reward
        self.dones[self.pointer] = done
        self.pointer = (self.pointer + 1) % self.observation.shape[0]
        if self.pointer == 0:
            self.full = True

    def sample(self, batch_size, seq_len, device):
        """
        Samples batches of experiences of fixed sequence length from the replay buffer, 
        taking into account the circular nature of the buffer to avoid crossing the 
        "end" of the buffer when it is full.

        This method ensures that sampled sequences are continuous and do not wrap around 
        the end of the buffer, maintaining the temporal integrity of experiences. This is 
        particularly important when the buffer is full, and the pointer marks the boundary 
        between the newest and oldest data in the buffer.

        Args:
            batch_size (int): The number of sequences to sample.
            seq_len (int): The length of each sequence to sample.
            device (torch.device): The device on which the sampled data will be loaded.

        Raises:
            Exception: If there is not enough data in the buffer to sample a full sequence.

        Returns:
            Dict: A dictionary containing the sampled sequences of observations, actions, 
            rewards, and dones. Each item in the dictionary is a tensor of shape 
            (batch_size, seq_len, feature_dimension), except for 'dones' which is of shape 
            (batch_size, seq_len, 1).

        Notes:
            - The method handles different scenarios based on the buffer's state (full or not) 
            and the pointer's position to ensure valid sequence sampling without wrapping.
            - When the buffer is not full, sequences can start from index 0 up to the 
            index where `seq_len` sequences can fit without surpassing the current pointer.
            - When the buffer is full, the method ensures sequences do not start in a way 
            that would cause them to wrap around past the pointer, effectively crossing 
            the boundary between the newest and oldest data.
            - This approach guarantees the sampled sequences respect the temporal order 
            and continuity necessary for algorithms that rely on sequences of experiences.
        """
        
        # Ensure there's enough data to sample
        if self.pointer < seq_len and not self.full:
            raise Exception('not enough data to sample')

        # detail: handling different cases for circular sampling
        if self.full:
            if self.pointer - seq_len < 0:
                valid_range = np.arange(self.pointer, self.observation.shape[0] - (self.pointer - seq_len) + 1)
            else:
                range_1 = np.arange(0, self.pointer - seq_len + 1)
                range_2 = np.arange(self.pointer, self.observation.shape[0])
                valid_range = np.concatenate((range_1, range_2), -1)
        else:
            valid_range = np.arange(0, self.pointer-seq_len+1)

        start_index = np.random.choice(valid_range, (batch_size, 1))
        
        seq_len = np.arange(seq_len)
        sample_idcs = (start_index + seq_len) % self.observation.shape[0]
        
        batch = Dict()
        
        batch.obs = torch.from_numpy(self.observation[sample_idcs]).to(device)
        batch.actions = torch.from_numpy(self.actions[sample_idcs]).to(device)
        batch.rewards = torch.from_numpy(self.rewards[sample_idcs]).to(device)
        batch.dones = torch.from_numpy(self.dones[sample_idcs]).to(device)
        
        return batch
    
    def clear(self, ):
        self.pointer = 0
        self.full = False

    def __len__(self, ):
        return self.pointer