|
|
""" |
|
|
Author: Minh Pham-Dinh |
|
|
Created: Jan 26th, 2024 |
|
|
Last Modified: Feb 10th, 2024 |
|
|
Email: mhpham26@colby.edu |
|
|
|
|
|
Description: |
|
|
File containing all models that will be used in Dreamer. |
|
|
|
|
|
The implementation is based on: |
|
|
Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019. |
|
|
[Online]. Available: https://arxiv.org/abs/1912.01603 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
|
|
|
def initialize_weights(m): |
|
|
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): |
|
|
nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu") |
|
|
nn.init.constant_(m.bias.data, 0) |
|
|
elif isinstance(m, nn.Linear): |
|
|
nn.init.kaiming_uniform_(m.weight.data) |
|
|
nn.init.constant_(m.bias.data, 0) |
|
|
|
|
|
|
|
|
class RSSM(nn.Module): |
|
|
"""Reccurent State Space Model (RSSM) |
|
|
The main model that we will use to learn the latent dynamic of the environment |
|
|
""" |
|
|
def __init__(self, stochastic_size, obs_embed_size, deterministic_size, hidden_size, action_size, activation=nn.ELU): |
|
|
super().__init__() |
|
|
self.stochastic_size = stochastic_size |
|
|
self.action_size = action_size |
|
|
self.deterministic_size = deterministic_size |
|
|
self.obs_embed_size = obs_embed_size |
|
|
self.action_size = action_size |
|
|
|
|
|
|
|
|
self.recurrent_linear = nn.Sequential( |
|
|
nn.Linear(stochastic_size + action_size, hidden_size), |
|
|
activation(), |
|
|
) |
|
|
self.gru_cell = nn.GRUCell(hidden_size, deterministic_size) |
|
|
|
|
|
|
|
|
self.representatio_model = nn.Sequential( |
|
|
nn.Linear(deterministic_size + obs_embed_size, hidden_size), |
|
|
activation(), |
|
|
nn.Linear(hidden_size, stochastic_size*2) |
|
|
) |
|
|
|
|
|
|
|
|
self.transition_model = nn.Sequential( |
|
|
nn.Linear(deterministic_size, hidden_size), |
|
|
activation(), |
|
|
nn.Linear(hidden_size, stochastic_size*2) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def recurrent(self, stoch_state, action, deterministic): |
|
|
"""The recurrent model, calculate the deterministic state given the stochastic state |
|
|
the action, and the prior deterministic |
|
|
|
|
|
Args: |
|
|
a_t-1 (batch_size, action_size): action at time step, cannot be None. |
|
|
s_t-1 (batch_size, stoch_size): stochastic state at time step. Defaults to None. |
|
|
h_t-1 (batch_size, deterministic_size): deterministic at timestep. Defaults to None. |
|
|
|
|
|
Returns: |
|
|
h_t: deterministic at next time step |
|
|
""" |
|
|
|
|
|
|
|
|
x = torch.cat((action, stoch_state), -1) |
|
|
out = self.recurrent_linear(x) |
|
|
out = self.gru_cell(out, deterministic) |
|
|
return out |
|
|
|
|
|
|
|
|
def representation(self, embed_obs, deterministic): |
|
|
"""Calculate the distribution p of the stochastic state. |
|
|
|
|
|
Args: |
|
|
o_t (batch_size, embeded_obs_size): embedded observation (encoded) |
|
|
h_t (batch_size, deterministic_size): determinstic size |
|
|
|
|
|
Returns: |
|
|
s_t posterior_distribution: distribution of stochastic states |
|
|
s_t posterior: sampled stochastic states |
|
|
""" |
|
|
x = torch.cat((embed_obs, deterministic), -1) |
|
|
out = self.representatio_model(x) |
|
|
mean, std = torch.chunk(out, 2, -1) |
|
|
std = F.softplus(std) + 0.1 |
|
|
|
|
|
post_dist = torch.distributions.Normal(mean, std) |
|
|
post = post_dist.rsample() |
|
|
|
|
|
return post_dist, post |
|
|
|
|
|
|
|
|
def transition(self, deterministic): |
|
|
"""Calculate the distribution q of the stochastic state. |
|
|
|
|
|
Args: |
|
|
h_t (batch_size, deterministic_size): determinstic size |
|
|
|
|
|
Returns: |
|
|
s_t prior_distribution: distribution of stochastic states |
|
|
s_t prior: sampled stochastic states |
|
|
""" |
|
|
out = self.transition_model(deterministic) |
|
|
mean, std = torch.chunk(out, 2, -1) |
|
|
std = F.softplus(std) + 0.1 |
|
|
|
|
|
prior_dist = torch.distributions.Normal(mean, std) |
|
|
prior = prior_dist.rsample() |
|
|
return prior_dist, prior |
|
|
|
|
|
|
|
|
class ConvEncoder(nn.Module): |
|
|
def __init__(self, depth=32, input_shape=(3,64,64), activation=nn.ReLU): |
|
|
super().__init__() |
|
|
self.depth = depth |
|
|
self.input_shape = input_shape |
|
|
self.conv_layer = nn.Sequential( |
|
|
nn.Conv2d( |
|
|
in_channels=input_shape[0], |
|
|
out_channels=depth * 1, |
|
|
kernel_size=4, |
|
|
stride=2, |
|
|
padding="valid" |
|
|
), |
|
|
activation(), |
|
|
nn.Conv2d( |
|
|
in_channels=depth * 1, |
|
|
out_channels=depth * 2, |
|
|
kernel_size=4, |
|
|
stride=2, |
|
|
padding="valid" |
|
|
), |
|
|
activation(), |
|
|
nn.Conv2d( |
|
|
in_channels=depth * 2, |
|
|
out_channels=depth * 4, |
|
|
kernel_size=4, |
|
|
stride=2, |
|
|
padding="valid" |
|
|
), |
|
|
activation(), |
|
|
nn.Conv2d( |
|
|
in_channels=depth * 4, |
|
|
out_channels=depth * 8, |
|
|
kernel_size=4, |
|
|
stride=2, |
|
|
padding="valid" |
|
|
), |
|
|
activation() |
|
|
) |
|
|
self.conv_layer.apply(initialize_weights) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
batch_shape = x.shape[:-len(self.input_shape)] |
|
|
if not batch_shape: |
|
|
batch_shape = (1, ) |
|
|
|
|
|
x = x.reshape(-1, *self.input_shape) |
|
|
|
|
|
out = self.conv_layer(x) |
|
|
|
|
|
|
|
|
return out.reshape(*batch_shape, -1) |
|
|
|
|
|
|
|
|
class ConvDecoder(nn.Module): |
|
|
"""Decode latent dynamic |
|
|
Also referred to as observation model by the official Dreamer paper |
|
|
|
|
|
""" |
|
|
def __init__(self, stochastic_size, deterministic_size, depth=32, out_shape=(3,64,64), activation=nn.ReLU): |
|
|
super().__init__() |
|
|
self.out_shape = out_shape |
|
|
self.net = nn.Sequential( |
|
|
nn.Linear(deterministic_size + stochastic_size, depth*32), |
|
|
nn.Unflatten(1, (depth * 32, 1)), |
|
|
nn.Unflatten(2, (1, 1)), |
|
|
nn.ConvTranspose2d( |
|
|
depth * 32, |
|
|
depth * 4, |
|
|
kernel_size=5, |
|
|
stride=2, |
|
|
), |
|
|
activation(), |
|
|
nn.ConvTranspose2d( |
|
|
depth * 4, |
|
|
depth * 2, |
|
|
kernel_size=5, |
|
|
stride=2, |
|
|
), |
|
|
activation(), |
|
|
nn.ConvTranspose2d( |
|
|
depth * 2, |
|
|
depth * 1, |
|
|
kernel_size=5 + 1, |
|
|
stride=2, |
|
|
), |
|
|
activation(), |
|
|
nn.ConvTranspose2d( |
|
|
depth * 1, |
|
|
out_shape[0], |
|
|
kernel_size=5+1, |
|
|
stride=2, |
|
|
), |
|
|
) |
|
|
self.net.apply(initialize_weights) |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, posterior, deterministic, mps_flatten=False): |
|
|
"""take in the stochastic state (posterior) and deterministic to construct the latent state then |
|
|
output reconstructed pixel observation |
|
|
|
|
|
Args: |
|
|
s_t (batch_sz, stoch_size): stochastic state (or posterior) |
|
|
h_t (batch_sz, deterministic_size): deterministic state |
|
|
mps_flatten (boolean): whether to flattening the output for mps device or not. This is because M1 GPU can |
|
|
only support max 4 dimension (stupid af) |
|
|
Returns: |
|
|
o'_t: reconstructed_obs |
|
|
""" |
|
|
x = torch.cat((posterior, deterministic), -1) |
|
|
batch_shape = x.shape[:-1] |
|
|
if not batch_shape: |
|
|
batch_shape = (1, ) |
|
|
|
|
|
x = x.reshape(-1, x.shape[-1]) |
|
|
|
|
|
if mps_flatten: |
|
|
batch_shape = (-1, ) |
|
|
|
|
|
mean = self.net(x).reshape(*batch_shape, *self.out_shape) |
|
|
|
|
|
dist = torch.distributions.Normal(mean, 1) |
|
|
|
|
|
|
|
|
return torch.distributions.Independent(dist, len(self.out_shape)) |
|
|
|
|
|
|
|
|
class RewardNet(nn.Module): |
|
|
"""reward prediction model. It take in the stochastic state and the deterministic to construct |
|
|
latent state. It then output the reward prediciton |
|
|
|
|
|
Args: |
|
|
nn (_type_): _description_ |
|
|
""" |
|
|
def __init__(self, input_size, hidden_size, activation=nn.ELU): |
|
|
super().__init__() |
|
|
|
|
|
self.net = nn.Sequential( |
|
|
nn.Linear(input_size, hidden_size), |
|
|
activation(), |
|
|
nn.Linear(hidden_size, 1) |
|
|
) |
|
|
|
|
|
|
|
|
def forward(self, stoch_state, deterministic): |
|
|
"""take in the stochastic state and deterministic to construct the latent state then |
|
|
output reard prediction |
|
|
|
|
|
Args: |
|
|
s_t (batch_sz, stoch_size): stochastic state (or posterior) |
|
|
h_t (batch_sz, deterministic_size): deterministic state |
|
|
|
|
|
Returns: |
|
|
r_t: rewards |
|
|
""" |
|
|
x = torch.cat((stoch_state, deterministic), -1) |
|
|
batch_shape = x.shape[:-1] |
|
|
if not batch_shape: |
|
|
batch_shape = (1, ) |
|
|
|
|
|
x = x.reshape(-1, x.shape[-1]) |
|
|
|
|
|
return self.net(x).reshape(*batch_shape, 1) |
|
|
|
|
|
|
|
|
class ContinuoNet(nn.Module): |
|
|
"""continuity prediction model. It take in the stochastic state and the deterministic to construct |
|
|
latent state. It then output the prediction of whether the termination state has been reached |
|
|
|
|
|
Args: |
|
|
nn (_type_): _description_ |
|
|
""" |
|
|
def __init__(self, input_size, hidden_size, activation=nn.ELU): |
|
|
super().__init__() |
|
|
|
|
|
self.net = nn.Sequential( |
|
|
nn.Linear(input_size, hidden_size), |
|
|
activation(), |
|
|
nn.Linear(hidden_size, hidden_size), |
|
|
activation(), |
|
|
nn.Linear(hidden_size, 1) |
|
|
) |
|
|
|
|
|
|
|
|
def forward(self, stoch_state, deterministic): |
|
|
"""take in the stochastic state and deterministic to construct the latent state then |
|
|
output reard prediction |
|
|
|
|
|
Args: |
|
|
s_t stoch_state (batch_sz, stoch_size): stochastic state (or posterior) |
|
|
h_t deterministic (batch_sz, deterministic_size): deterministic state |
|
|
|
|
|
Returns: |
|
|
dist: Beurnoulli distribution of done |
|
|
""" |
|
|
x = torch.cat((stoch_state, deterministic), -1) |
|
|
batch_shape = x.shape[:-1] |
|
|
if not batch_shape: |
|
|
batch_shape = (1, ) |
|
|
|
|
|
x = x.reshape(-1, x.shape[-1]) |
|
|
|
|
|
x = self.net(x).reshape(*batch_shape, 1) |
|
|
return x, torch.distributions.Independent(torch.distributions.Bernoulli(logits=x), 1) |
|
|
|
|
|
|
|
|
class Actor(nn.Module): |
|
|
"""actor network |
|
|
""" |
|
|
def __init__(self, |
|
|
latent_size, |
|
|
hidden_size, |
|
|
action_size, |
|
|
discrete=True, |
|
|
activation=nn.ELU, |
|
|
min_std=1e-4, |
|
|
init_std=5, |
|
|
mean_scale=5): |
|
|
|
|
|
super().__init__() |
|
|
self.latent_size = latent_size |
|
|
self.hidden_size = hidden_size |
|
|
self.action_size = (action_size if discrete else action_size*2) |
|
|
self.discrete = discrete |
|
|
self.min_std=min_std |
|
|
self.init_std = init_std |
|
|
self.mean_scale = mean_scale |
|
|
|
|
|
self.net = nn.Sequential( |
|
|
nn.Linear(latent_size, hidden_size), |
|
|
activation(), |
|
|
nn.Linear(hidden_size, self.action_size) |
|
|
) |
|
|
|
|
|
|
|
|
def forward(self, stoch_state, deterministic): |
|
|
"""actor network. get in stochastic state and deterministic state to construct latent state |
|
|
and then use latent state to predict appropriate action |
|
|
|
|
|
Args: |
|
|
s_t stoch_state (batch_sz, stoch_size): stochastic state (or posterior) |
|
|
h_t deterministic (batch_sz, deterministic_size): deterministic state |
|
|
|
|
|
Returns: |
|
|
action distribution. OneHot if discrete, else is tanhNormal |
|
|
""" |
|
|
latent_state = torch.cat((stoch_state, deterministic), -1) |
|
|
x = self.net(latent_state) |
|
|
|
|
|
if self.discrete: |
|
|
|
|
|
dist = torch.distributions.OneHotCategorical(logits=x) |
|
|
action = dist.sample() + dist.probs - dist.probs.detach() |
|
|
else: |
|
|
|
|
|
raw_init_std = np.log(np.exp(self.init_std) - 1) |
|
|
|
|
|
mean, std = torch.chunk(x, 2, -1) |
|
|
mean = self.mean_scale * F.tanh(mean / self.mean_scale) |
|
|
std = F.softplus(std + raw_init_std) + self.min_std |
|
|
|
|
|
dist = torch.distributions.Normal(mean, std) |
|
|
dist = torch.distributions.TransformedDistribution(dist, torch.distributions.TanhTransform()) |
|
|
action = torch.distributions.Independent(dist, 1).rsample() |
|
|
|
|
|
return action |
|
|
|
|
|
|
|
|
class Critic(nn.Module): |
|
|
""" |
|
|
critic network |
|
|
""" |
|
|
def __init__(self, latent_size, hidden_size, activation=nn.ELU): |
|
|
super().__init__() |
|
|
self.latent_size = latent_size |
|
|
|
|
|
self.net = nn.Sequential( |
|
|
nn.Linear(latent_size, hidden_size), |
|
|
activation(), |
|
|
nn.Linear(hidden_size, hidden_size), |
|
|
activation(), |
|
|
nn.Linear(hidden_size, 1) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, stoch_state, deterministic): |
|
|
"""critic network. get in stochastic state and deterministic state to construct latent state |
|
|
and then use latent state to predict state value |
|
|
|
|
|
Args: |
|
|
s_t stoch_state (batch_sz, seq_len, stoch_size): stochastic state (or posterior) |
|
|
h_t deterministic (batch_sz, seq_len, deterministic_size): deterministic state |
|
|
|
|
|
Returns: |
|
|
state value distribution. |
|
|
""" |
|
|
latent_state = torch.cat((stoch_state, deterministic), -1) |
|
|
|
|
|
batch_shape = latent_state.shape[:-1] |
|
|
if not batch_shape: |
|
|
batch_shape = (1, ) |
|
|
|
|
|
latent_state = latent_state.reshape(-1, self.latent_size) |
|
|
|
|
|
x = self.net(latent_state) |
|
|
|
|
|
x = x.reshape(*batch_shape, 1) |
|
|
|
|
|
dist = torch.distributions.Normal(x, 1) |
|
|
dist = torch.distributions.Independent(dist, 1) |
|
|
|
|
|
return dist |
|
|
|