PyDreamerV1 / utils /models.py
minhphd's picture
Upload 30 files
ce3feed verified
"""
Author: Minh Pham-Dinh
Created: Jan 26th, 2024
Last Modified: Feb 10th, 2024
Email: mhpham26@colby.edu
Description:
File containing all models that will be used in Dreamer.
The implementation is based on:
Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
[Online]. Available: https://arxiv.org/abs/1912.01603
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def initialize_weights(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu")
nn.init.constant_(m.bias.data, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight.data)
nn.init.constant_(m.bias.data, 0)
class RSSM(nn.Module):
"""Reccurent State Space Model (RSSM)
The main model that we will use to learn the latent dynamic of the environment
"""
def __init__(self, stochastic_size, obs_embed_size, deterministic_size, hidden_size, action_size, activation=nn.ELU):
super().__init__()
self.stochastic_size = stochastic_size
self.action_size = action_size
self.deterministic_size = deterministic_size
self.obs_embed_size = obs_embed_size
self.action_size = action_size
# recurrent
self.recurrent_linear = nn.Sequential(
nn.Linear(stochastic_size + action_size, hidden_size),
activation(),
)
self.gru_cell = nn.GRUCell(hidden_size, deterministic_size)
# representation model, for calculating posterior
self.representatio_model = nn.Sequential(
nn.Linear(deterministic_size + obs_embed_size, hidden_size),
activation(),
nn.Linear(hidden_size, stochastic_size*2)
)
# transition model, for calculating prior, use for imagining trajectories
self.transition_model = nn.Sequential(
nn.Linear(deterministic_size, hidden_size),
activation(),
nn.Linear(hidden_size, stochastic_size*2)
)
def recurrent(self, stoch_state, action, deterministic):
"""The recurrent model, calculate the deterministic state given the stochastic state
the action, and the prior deterministic
Args:
a_t-1 (batch_size, action_size): action at time step, cannot be None.
s_t-1 (batch_size, stoch_size): stochastic state at time step. Defaults to None.
h_t-1 (batch_size, deterministic_size): deterministic at timestep. Defaults to None.
Returns:
h_t: deterministic at next time step
"""
# initialize some sizes
x = torch.cat((action, stoch_state), -1)
out = self.recurrent_linear(x)
out = self.gru_cell(out, deterministic)
return out
def representation(self, embed_obs, deterministic):
"""Calculate the distribution p of the stochastic state.
Args:
o_t (batch_size, embeded_obs_size): embedded observation (encoded)
h_t (batch_size, deterministic_size): determinstic size
Returns:
s_t posterior_distribution: distribution of stochastic states
s_t posterior: sampled stochastic states
"""
x = torch.cat((embed_obs, deterministic), -1)
out = self.representatio_model(x)
mean, std = torch.chunk(out, 2, -1)
std = F.softplus(std) + 0.1
post_dist = torch.distributions.Normal(mean, std)
post = post_dist.rsample()
return post_dist, post
def transition(self, deterministic):
"""Calculate the distribution q of the stochastic state.
Args:
h_t (batch_size, deterministic_size): determinstic size
Returns:
s_t prior_distribution: distribution of stochastic states
s_t prior: sampled stochastic states
"""
out = self.transition_model(deterministic)
mean, std = torch.chunk(out, 2, -1)
std = F.softplus(std) + 0.1
prior_dist = torch.distributions.Normal(mean, std)
prior = prior_dist.rsample()
return prior_dist, prior
class ConvEncoder(nn.Module):
def __init__(self, depth=32, input_shape=(3,64,64), activation=nn.ReLU):
super().__init__()
self.depth = depth
self.input_shape = input_shape
self.conv_layer = nn.Sequential(
nn.Conv2d(
in_channels=input_shape[0],
out_channels=depth * 1,
kernel_size=4,
stride=2,
padding="valid"
),
activation(),
nn.Conv2d(
in_channels=depth * 1,
out_channels=depth * 2,
kernel_size=4,
stride=2,
padding="valid"
),
activation(),
nn.Conv2d(
in_channels=depth * 2,
out_channels=depth * 4,
kernel_size=4,
stride=2,
padding="valid"
),
activation(),
nn.Conv2d(
in_channels=depth * 4,
out_channels=depth * 8,
kernel_size=4,
stride=2,
padding="valid"
),
activation()
)
self.conv_layer.apply(initialize_weights)
def forward(self, x):
batch_shape = x.shape[:-len(self.input_shape)]
if not batch_shape:
batch_shape = (1, )
x = x.reshape(-1, *self.input_shape)
out = self.conv_layer(x)
#flatten output
return out.reshape(*batch_shape, -1)
class ConvDecoder(nn.Module):
"""Decode latent dynamic
Also referred to as observation model by the official Dreamer paper
"""
def __init__(self, stochastic_size, deterministic_size, depth=32, out_shape=(3,64,64), activation=nn.ReLU):
super().__init__()
self.out_shape = out_shape
self.net = nn.Sequential(
nn.Linear(deterministic_size + stochastic_size, depth*32),
nn.Unflatten(1, (depth * 32, 1)),
nn.Unflatten(2, (1, 1)),
nn.ConvTranspose2d(
depth * 32,
depth * 4,
kernel_size=5,
stride=2,
),
activation(),
nn.ConvTranspose2d(
depth * 4,
depth * 2,
kernel_size=5,
stride=2,
),
activation(),
nn.ConvTranspose2d(
depth * 2,
depth * 1,
kernel_size=5 + 1,
stride=2,
),
activation(),
nn.ConvTranspose2d(
depth * 1,
out_shape[0],
kernel_size=5+1,
stride=2,
),
)
self.net.apply(initialize_weights)
def forward(self, posterior, deterministic, mps_flatten=False):
"""take in the stochastic state (posterior) and deterministic to construct the latent state then
output reconstructed pixel observation
Args:
s_t (batch_sz, stoch_size): stochastic state (or posterior)
h_t (batch_sz, deterministic_size): deterministic state
mps_flatten (boolean): whether to flattening the output for mps device or not. This is because M1 GPU can
only support max 4 dimension (stupid af)
Returns:
o'_t: reconstructed_obs
"""
x = torch.cat((posterior, deterministic), -1)
batch_shape = x.shape[:-1]
if not batch_shape:
batch_shape = (1, )
x = x.reshape(-1, x.shape[-1])
if mps_flatten:
batch_shape = (-1, )
mean = self.net(x).reshape(*batch_shape, *self.out_shape)
dist = torch.distributions.Normal(mean, 1)
# #flatten output
return torch.distributions.Independent(dist, len(self.out_shape))
class RewardNet(nn.Module):
"""reward prediction model. It take in the stochastic state and the deterministic to construct
latent state. It then output the reward prediciton
Args:
nn (_type_): _description_
"""
def __init__(self, input_size, hidden_size, activation=nn.ELU):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_size, hidden_size),
activation(),
nn.Linear(hidden_size, 1)
)
def forward(self, stoch_state, deterministic):
"""take in the stochastic state and deterministic to construct the latent state then
output reard prediction
Args:
s_t (batch_sz, stoch_size): stochastic state (or posterior)
h_t (batch_sz, deterministic_size): deterministic state
Returns:
r_t: rewards
"""
x = torch.cat((stoch_state, deterministic), -1)
batch_shape = x.shape[:-1]
if not batch_shape:
batch_shape = (1, )
x = x.reshape(-1, x.shape[-1])
return self.net(x).reshape(*batch_shape, 1)
class ContinuoNet(nn.Module):
"""continuity prediction model. It take in the stochastic state and the deterministic to construct
latent state. It then output the prediction of whether the termination state has been reached
Args:
nn (_type_): _description_
"""
def __init__(self, input_size, hidden_size, activation=nn.ELU):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_size, hidden_size),
activation(),
nn.Linear(hidden_size, hidden_size),
activation(),
nn.Linear(hidden_size, 1)
)
def forward(self, stoch_state, deterministic):
"""take in the stochastic state and deterministic to construct the latent state then
output reard prediction
Args:
s_t stoch_state (batch_sz, stoch_size): stochastic state (or posterior)
h_t deterministic (batch_sz, deterministic_size): deterministic state
Returns:
dist: Beurnoulli distribution of done
"""
x = torch.cat((stoch_state, deterministic), -1)
batch_shape = x.shape[:-1]
if not batch_shape:
batch_shape = (1, )
x = x.reshape(-1, x.shape[-1])
x = self.net(x).reshape(*batch_shape, 1)
return x, torch.distributions.Independent(torch.distributions.Bernoulli(logits=x), 1)
class Actor(nn.Module):
"""actor network
"""
def __init__(self,
latent_size,
hidden_size,
action_size,
discrete=True,
activation=nn.ELU,
min_std=1e-4,
init_std=5,
mean_scale=5):
super().__init__()
self.latent_size = latent_size
self.hidden_size = hidden_size
self.action_size = (action_size if discrete else action_size*2)
self.discrete = discrete
self.min_std=min_std
self.init_std = init_std
self.mean_scale = mean_scale
self.net = nn.Sequential(
nn.Linear(latent_size, hidden_size),
activation(),
nn.Linear(hidden_size, self.action_size)
)
def forward(self, stoch_state, deterministic):
"""actor network. get in stochastic state and deterministic state to construct latent state
and then use latent state to predict appropriate action
Args:
s_t stoch_state (batch_sz, stoch_size): stochastic state (or posterior)
h_t deterministic (batch_sz, deterministic_size): deterministic state
Returns:
action distribution. OneHot if discrete, else is tanhNormal
"""
latent_state = torch.cat((stoch_state, deterministic), -1)
x = self.net(latent_state)
if self.discrete:
# straight through gradient (mentioned in DreamerV2)
dist = torch.distributions.OneHotCategorical(logits=x)
action = dist.sample() + dist.probs - dist.probs.detach()
else:
#ensure that the softplut output proper init_std
raw_init_std = np.log(np.exp(self.init_std) - 1)
mean, std = torch.chunk(x, 2, -1)
mean = self.mean_scale * F.tanh(mean / self.mean_scale)
std = F.softplus(std + raw_init_std) + self.min_std
dist = torch.distributions.Normal(mean, std)
dist = torch.distributions.TransformedDistribution(dist, torch.distributions.TanhTransform())
action = torch.distributions.Independent(dist, 1).rsample()
return action
class Critic(nn.Module):
"""
critic network
"""
def __init__(self, latent_size, hidden_size, activation=nn.ELU):
super().__init__()
self.latent_size = latent_size
self.net = nn.Sequential(
nn.Linear(latent_size, hidden_size),
activation(),
nn.Linear(hidden_size, hidden_size),
activation(),
nn.Linear(hidden_size, 1)
)
def forward(self, stoch_state, deterministic):
"""critic network. get in stochastic state and deterministic state to construct latent state
and then use latent state to predict state value
Args:
s_t stoch_state (batch_sz, seq_len, stoch_size): stochastic state (or posterior)
h_t deterministic (batch_sz, seq_len, deterministic_size): deterministic state
Returns:
state value distribution.
"""
latent_state = torch.cat((stoch_state, deterministic), -1)
batch_shape = latent_state.shape[:-1]
if not batch_shape:
batch_shape = (1, )
latent_state = latent_state.reshape(-1, self.latent_size)
x = self.net(latent_state)
x = x.reshape(*batch_shape, 1)
dist = torch.distributions.Normal(x, 1)
dist = torch.distributions.Independent(dist, 1)
return dist