File size: 14,686 Bytes
ce3feed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
"""
Author: Minh Pham-Dinh
Created: Jan 26th, 2024
Last Modified: Feb 10th, 2024
Email: mhpham26@colby.edu

Description:
    File containing all models that will be used in Dreamer.
    
    The implementation is based on:
    Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019. 
    [Online]. Available: https://arxiv.org/abs/1912.01603
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def initialize_weights(m):
    """Initialize one layer in place; meant to be passed to ``nn.Module.apply``.

    Conv2d / ConvTranspose2d weights get Kaiming-uniform init with the ReLU
    gain, Linear weights get Kaiming-uniform init with PyTorch's default
    nonlinearity, and the bias (when the layer has one) is set to zero.
    Every other module type is left untouched.

    Args:
        m (nn.Module): the module currently visited by ``apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu")
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
    else:
        return

    # layers created with bias=False expose ``bias`` as None, so guard the reset
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0)


class RSSM(nn.Module):
    """Recurrent State Space Model (RSSM).

    The main model used to learn the latent dynamics of the environment. It
    bundles three pieces:

    * a recurrent model (linear layer + GRU cell) that advances the
      deterministic state,
    * a representation model that produces the posterior over the stochastic
      state from the deterministic state and an embedded observation,
    * a transition model that produces the prior over the stochastic state
      from the deterministic state alone (used for imagining trajectories).

    NOTE: the attribute name ``representatio_model`` is misspelled but kept
    as-is, since renaming it would change ``state_dict`` keys.
    """

    # constant added to the softplus output so the predicted std never collapses to 0
    _MIN_STD = 0.1

    def __init__(self, stochastic_size, obs_embed_size, deterministic_size, hidden_size, action_size, activation=nn.ELU):
        """Build the recurrent, representation and transition sub-networks.

        Args:
            stochastic_size: size of the stochastic state.
            obs_embed_size: size of the embedded (encoded) observation.
            deterministic_size: size of the deterministic (GRU hidden) state.
            hidden_size: width of the hidden linear layers.
            action_size: size of the action vector.
            activation: activation module class, instantiated with no arguments.
        """
        super().__init__()
        self.stochastic_size = stochastic_size
        self.action_size = action_size
        self.deterministic_size = deterministic_size
        self.obs_embed_size = obs_embed_size
        self.hidden_size = hidden_size

        # recurrent
        self.recurrent_linear = nn.Sequential(
            nn.Linear(stochastic_size + action_size, hidden_size),
            activation(),
        )
        self.gru_cell = nn.GRUCell(hidden_size, deterministic_size)

        # representation model, for calculating posterior
        self.representatio_model = nn.Sequential(
            nn.Linear(deterministic_size + obs_embed_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, stochastic_size*2)
        )

        # transition model, for calculating prior, use for imagining trajectories
        self.transition_model = nn.Sequential(
            nn.Linear(deterministic_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, stochastic_size*2)
        )

    def _sample_gaussian(self, params):
        """Split ``params`` into mean / raw std halves and draw a reparameterized sample.

        Args:
            params (..., 2 * stochastic_size): output of a distribution head.

        Returns:
            (Normal distribution, sample drawn with ``rsample``)
        """
        mean, raw_std = torch.chunk(params, 2, -1)
        std = F.softplus(raw_std) + self._MIN_STD

        dist = torch.distributions.Normal(mean, std)
        return dist, dist.rsample()

    def recurrent(self, stoch_state, action, deterministic):
        """The recurrent model: compute the next deterministic state from the
        stochastic state, the action, and the previous deterministic state.

        Args:
            stoch_state (batch_size, stoch_size): stochastic state s_t-1.
            action (batch_size, action_size): action a_t-1.
            deterministic (batch_size, deterministic_size): deterministic state h_t-1,
                passed to the GRU cell as its hidden state.

        Returns:
            h_t: deterministic state at the next time step
        """
        # action comes first in the concatenation; the linear layer only sees the joint vector
        x = torch.cat((action, stoch_state), -1)
        out = self.recurrent_linear(x)
        out = self.gru_cell(out, deterministic)
        return out

    def representation(self, embed_obs, deterministic):
        """Calculate the posterior distribution of the stochastic state.

        Args:
            embed_obs (batch_size, embeded_obs_size): embedded observation o_t (encoded)
            deterministic (batch_size, deterministic_size): deterministic state h_t

        Returns:
            s_t posterior_distribution: distribution of stochastic states
            s_t posterior: sampled stochastic states
        """
        x = torch.cat((embed_obs, deterministic), -1)
        return self._sample_gaussian(self.representatio_model(x))

    def transition(self, deterministic):
        """Calculate the prior distribution of the stochastic state.

        Args:
            deterministic (batch_size, deterministic_size): deterministic state h_t

        Returns:
            s_t prior_distribution: distribution of stochastic states
            s_t prior: sampled stochastic states
        """
        return self._sample_gaussian(self.transition_model(deterministic))
        

class ConvEncoder(nn.Module):
    """Convolutional observation encoder.

    Stacks four unpadded, stride-2 convolutions with 4x4 kernels, each followed
    by the given activation. The channel count goes
    input -> depth -> 2*depth -> 4*depth -> 8*depth.
    """
    def __init__(self, depth=32, input_shape=(3,64,64), activation=nn.ReLU):
        """
        Args:
            depth: base number of convolution channels.
            input_shape: shape of a single observation, channels first.
            activation: activation module class, instantiated with no arguments.
        """
        super().__init__()
        self.depth = depth
        self.input_shape = input_shape

        # consecutive pairs of this list are the (in, out) channels of each conv
        channels = [input_shape[0], depth * 1, depth * 2, depth * 4, depth * 8]
        stages = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages.append(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding="valid")
            )
            stages.append(activation())

        self.conv_layer = nn.Sequential(*stages)
        self.conv_layer.apply(initialize_weights)

    def forward(self, x):
        """Encode observations into flat feature vectors.

        Args:
            x (*batch_dims, *input_shape): observations; any number of leading
                batch dimensions (including none) is accepted.

        Returns:
            (*batch_dims, features): flattened convolution output. An input
            without batch dimensions comes back with a leading dimension of 1.
        """
        n_obs_dims = len(self.input_shape)
        lead_shape = x.shape[:-n_obs_dims]
        if len(lead_shape) == 0:
            lead_shape = (1, )

        # collapse every leading dimension into one batch axis for the conv stack
        flat_batch = x.reshape(-1, *self.input_shape)
        features = self.conv_layer(flat_batch)

        # restore the leading dimensions and flatten channels/height/width
        return features.reshape(*lead_shape, -1)
    

class ConvDecoder(nn.Module):
    """Decode the latent state back into an observation.

    Also referred to as the observation model by the official Dreamer paper.
    A linear layer maps the latent state to ``depth * 32`` features, which are
    viewed as a 1x1 feature map and upsampled by four stride-2 transposed
    convolutions (kernel sizes 5, 5, 6, 6).
    """
    def __init__(self, stochastic_size, deterministic_size, depth=32, out_shape=(3,64,64), activation=nn.ReLU):
        """
        Args:
            stochastic_size: size of the stochastic state.
            deterministic_size: size of the deterministic state.
            depth: base number of channels of the transposed convolutions.
            out_shape: shape of a single reconstructed observation, channels first.
            activation: activation module class, instantiated with no arguments.
        """
        super().__init__()
        self.out_shape = out_shape

        # project the latent state and view it as a (depth*32, 1, 1) feature map
        layers = [
            nn.Linear(deterministic_size + stochastic_size, depth*32),
            nn.Unflatten(1, (depth * 32, 1)),
            nn.Unflatten(2, (1, 1)),
        ]

        # (in_channels, out_channels, kernel_size) of every upsampling stage
        deconv_specs = (
            (depth * 32, depth * 4, 5),
            (depth * 4, depth * 2, 5),
            (depth * 2, depth * 1, 5 + 1),
            (depth * 1, out_shape[0], 5 + 1),
        )
        final_stage = len(deconv_specs) - 1
        for stage, (c_in, c_out, kernel) in enumerate(deconv_specs):
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=2))
            # the last stage outputs the raw mean, so it gets no activation
            if stage != final_stage:
                layers.append(activation())

        self.net = nn.Sequential(*layers)
        self.net.apply(initialize_weights)

    def forward(self, posterior, deterministic, mps_flatten=False):
        """Concatenate the stochastic state (posterior) and the deterministic
        state into the latent state and reconstruct the pixel observation.

        Args:
            posterior (*batch_dims, stoch_size): stochastic state s_t (or posterior)
            deterministic (*batch_dims, deterministic_size): deterministic state h_t
            mps_flatten (boolean): when True, the leading batch dimensions of the
                output stay collapsed into a single axis. Per the original
                author's note this is a workaround for the MPS (Apple M1 GPU)
                backend supporting at most 4 dimensions.

        Returns:
            o'_t: distribution over reconstructed observations, a unit-variance
            Normal centred on the decoder output with the ``out_shape``
            dimensions treated as event dimensions.
        """
        latent = torch.cat((posterior, deterministic), -1)

        lead_shape = latent.shape[:-1]
        if len(lead_shape) == 0:
            lead_shape = (1, )
        if mps_flatten:
            lead_shape = (-1, )

        # the network works on a single flat batch axis
        flat_latent = latent.reshape(-1, latent.shape[-1])
        mean = self.net(flat_latent).reshape(*lead_shape, *self.out_shape)

        pixel_dist = torch.distributions.Normal(mean, 1)
        return torch.distributions.Independent(pixel_dist, len(self.out_shape))
    
    
class RewardNet(nn.Module):
    """Reward prediction model.

    Takes the stochastic state and the deterministic state, joins them into the
    latent state and outputs one reward prediction per latent state using a
    single-hidden-layer MLP.
    """
    def __init__(self, input_size, hidden_size, activation=nn.ELU):
        """
        Args:
            input_size: size of the latent state (stochastic + deterministic).
            hidden_size: width of the hidden layer.
            activation: activation module class, instantiated with no arguments.
        """
        super().__init__()

        layers = [
            nn.Linear(input_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, stoch_state, deterministic):
        """Build the latent state from its two halves and predict the reward.

        Args:
            stoch_state (*batch_dims, stoch_size): stochastic state s_t (or posterior)
            deterministic (*batch_dims, deterministic_size): deterministic state h_t

        Returns:
            r_t (*batch_dims, 1): predicted rewards. Inputs without batch
            dimensions produce a (1, 1) tensor.
        """
        latent = torch.cat((stoch_state, deterministic), -1)

        # an empty leading shape means a single unbatched latent state
        lead_shape = tuple(latent.shape[:-1]) or (1, )

        flat_latent = latent.reshape(-1, latent.shape[-1])
        reward = self.net(flat_latent)
        return reward.reshape(*lead_shape, 1)
    

class ContinuoNet(nn.Module):
    """Continuity prediction model.

    Takes the stochastic state and the deterministic state, joins them into the
    latent state and predicts whether the termination state has been reached,
    using an MLP with two hidden layers.
    """
    def __init__(self, input_size, hidden_size, activation=nn.ELU):
        """
        Args:
            input_size: size of the latent state (stochastic + deterministic).
            hidden_size: width of both hidden layers.
            activation: activation module class, instantiated with no arguments.
        """
        super().__init__()

        layers = [
            nn.Linear(input_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, stoch_state, deterministic):
        """Build the latent state from its two halves and predict termination.

        Args:
            stoch_state (*batch_dims, stoch_size): stochastic state s_t (or posterior)
            deterministic (*batch_dims, deterministic_size): deterministic state h_t

        Returns:
            logits (*batch_dims, 1): raw network output
            dist: Bernoulli distribution built from those logits, with the last
            dimension treated as the event dimension
        """
        latent = torch.cat((stoch_state, deterministic), -1)

        # an empty leading shape means a single unbatched latent state
        lead_shape = tuple(latent.shape[:-1]) or (1, )

        flat_latent = latent.reshape(-1, latent.shape[-1])
        logits = self.net(flat_latent).reshape(*lead_shape, 1)

        bernoulli = torch.distributions.Bernoulli(logits=logits)
        return logits, torch.distributions.Independent(bernoulli, 1)
    
    
class Actor(nn.Module):
    """Actor network: maps the latent state to a sampled action.
    """
    def __init__(
        self,
        latent_size,
        hidden_size,
        action_size,
        discrete=True,
        activation=nn.ELU,
        min_std=1e-4,
        init_std=5,
        mean_scale=5,
    ):
        """
        Args:
            latent_size: size of the latent state (stochastic + deterministic).
            hidden_size: width of the hidden layer.
            action_size: number of action dimensions.
            discrete: True for one-hot actions, False for continuous actions.
            activation: activation module class, instantiated with no arguments.
            min_std: constant added to the std of the continuous policy.
            init_std: offsets the raw std before the softplus (continuous policy).
            mean_scale: scale of the tanh soft-clipping of the mean (continuous policy).
        """
        super().__init__()
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        self.discrete = discrete
        # the continuous head emits a mean and a raw std per action dimension,
        # so the stored size is doubled in that case
        self.action_size = action_size if discrete else action_size*2
        self.min_std = min_std
        self.init_std = init_std
        self.mean_scale = mean_scale

        layers = [
            nn.Linear(latent_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, self.action_size),
        ]
        self.net = nn.Sequential(*layers)

    def _sample_discrete(self, logits):
        """Sample a one-hot action with straight-through gradients (mentioned in DreamerV2)."""
        policy = torch.distributions.OneHotCategorical(logits=logits)
        # forward value is the sample; gradients flow through the probabilities
        return policy.sample() + policy.probs - policy.probs.detach()

    def _sample_continuous(self, params):
        """Sample a tanh-squashed Normal action with the reparameterization trick."""
        # inverse softplus of init_std: shifts the raw std by this constant before the softplus
        raw_init_std = np.log(np.exp(self.init_std) - 1)

        mean, std = torch.chunk(params, 2, -1)
        mean = self.mean_scale * F.tanh(mean / self.mean_scale)
        std = F.softplus(std + raw_init_std) + self.min_std

        base = torch.distributions.Normal(mean, std)
        squashed = torch.distributions.TransformedDistribution(base, torch.distributions.TanhTransform())
        return torch.distributions.Independent(squashed, 1).rsample()

    def forward(self, stoch_state, deterministic):
        """Build the latent state from the stochastic and deterministic states
        and sample an action for it.

        Args:
            stoch_state (batch_sz, stoch_size): stochastic state s_t (or posterior)
            deterministic (batch_sz, deterministic_size): deterministic state h_t

        Returns:
            sampled action: one-hot (with straight-through gradients) if
            discrete, otherwise a reparameterized sample of a tanh-squashed Normal
        """
        latent_state = torch.cat((stoch_state, deterministic), -1)
        head_out = self.net(latent_state)

        if not self.discrete:
            return self._sample_continuous(head_out)
        return self._sample_discrete(head_out)
    
    
class Critic(nn.Module):
    """
    Critic network: maps the latent state to a distribution over state values
    using an MLP with two hidden layers.
    """
    def __init__(self, latent_size, hidden_size, activation=nn.ELU):
        """
        Args:
            latent_size: size of the latent state (stochastic + deterministic).
            hidden_size: width of both hidden layers.
            activation: activation module class, instantiated with no arguments.
        """
        super().__init__()
        self.latent_size = latent_size

        layers = [
            nn.Linear(latent_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, stoch_state, deterministic):
        """Build the latent state from the stochastic and deterministic states
        and predict the value of that state.

        Args:
            stoch_state (batch_sz, seq_len, stoch_size): stochastic state s_t (or posterior)
            deterministic (batch_sz, seq_len, deterministic_size): deterministic state h_t

        Returns:
            state value distribution: unit-variance Normal centred on the
            predicted value, with the last dimension treated as the event dimension.
        """
        latent_state = torch.cat((stoch_state, deterministic), -1)

        # an empty leading shape means a single unbatched latent state
        lead_shape = tuple(latent_state.shape[:-1]) or (1, )

        flat_latent = latent_state.reshape(-1, self.latent_size)
        value = self.net(flat_latent).reshape(*lead_shape, 1)

        value_dist = torch.distributions.Normal(value, 1)
        return torch.distributions.Independent(value_dist, 1)