|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import gym |
|
|
import os |
|
|
import random |
|
|
import math |
|
|
import metaworld |
|
|
import metaworld.envs.mujoco.env_dict as _env_dict |
|
|
from moviepy.editor import ImageSequenceClip |
|
|
from collections import deque |
|
|
from gym.wrappers.time_limit import TimeLimit |
|
|
from rlkit.envs.wrappers import NormalizedBoxEnv |
|
|
from collections import deque |
|
|
from skimage.util.shape import view_as_windows |
|
|
from torch import nn |
|
|
from torch import distributions as pyd |
|
|
from softgym.registered_env import env_arg_dict, SOFTGYM_ENVS |
|
|
from softgym.utils.normalized_env import normalize |
|
|
|
|
|
def make_softgym_env(cfg):
    """Build the SoftGym environment selected by ``cfg.env``.

    The ``'softgym_'`` marker is removed from ``cfg.env`` and the remainder
    is used as the key into both ``env_arg_dict`` (constructor kwargs) and
    ``SOFTGYM_ENVS`` (environment classes). The constructed environment is
    passed through softgym's ``normalize`` wrapper before being returned.
    """
    task_name = cfg.env.replace('softgym_', '')

    ctor_kwargs = env_arg_dict[task_name]
    env_cls = SOFTGYM_ENVS[task_name]

    return normalize(env_cls(**ctor_kwargs))
|
|
|
|
|
def make_classic_control_env(cfg):
    """Build a time-limited, action-normalised classic-control environment.

    Only environments whose name contains ``"CartPole"`` are supported; any
    other ``cfg.env`` raises ``NotImplementedError``. The episode length cap
    is read from the environment's own ``horizon`` attribute.
    """
    if "CartPole" not in cfg.env:
        raise NotImplementedError

    # Imported lazily so the project-local env module is only needed when
    # a CartPole environment is actually requested.
    from envs.cartpole import CartPoleEnv

    base_env = CartPoleEnv()
    wrapped = NormalizedBoxEnv(base_env)
    return TimeLimit(wrapped, base_env.horizon)
|
|
|
|
|
|
|
|
def tie_weights(src, trg):
    """Make ``trg`` share ``src``'s ``weight`` and ``bias`` objects.

    Both layers must be of exactly the same type. After the call the two
    layers reference the same parameter objects, so updating one updates
    the other.
    """
    src_type, trg_type = type(src), type(trg)
    assert src_type == trg_type

    trg.weight, trg.bias = src.weight, src.bias
|
|
|
|
|
def make_metaworld_env(cfg):
    """Build a time-limited, action-normalised Meta-World environment.

    ``cfg.env`` with the ``'metaworld_'`` marker removed is looked up in the
    V2 environment table first and in the V1 table otherwise. The env is
    created with ``render_mode='rgb_array'``, seeded with ``cfg.seed`` and
    capped at its own ``max_path_length``.
    """
    task_name = cfg.env.replace('metaworld_', '')

    # Prefer the V2 registry; fall back to V1 only when the task is not in V2.
    registry = (
        _env_dict.ALL_V2_ENVIRONMENTS
        if task_name in _env_dict.ALL_V2_ENVIRONMENTS
        else _env_dict.ALL_V1_ENVIRONMENTS
    )
    base_env = registry[task_name](render_mode='rgb_array')
    base_env.camera_name = task_name

    # NOTE(review): these private flags presumably enable per-reset task
    # randomisation and bypass the "set_task must be called" check --
    # confirm against the installed metaworld version.
    base_env._freeze_rand_vec = False
    base_env._set_task_called = True
    base_env.seed(cfg.seed)

    wrapped = NormalizedBoxEnv(base_env)
    return TimeLimit(wrapped, base_env.max_path_length)
|
|
|
|
|
class eval_mode:
    """Context manager that puts the given models into evaluation mode.

    On entry each model's current ``training`` flag is recorded and the
    model is switched with ``train(False)``; on exit every model is put
    back into the mode it had before. Exceptions raised inside the ``with``
    body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        saved = self.prev_states = []
        for module in self.models:
            # Record the flag before flipping it, one model at a time.
            saved.append(module.training)
            module.train(False)

    def __exit__(self, *args):
        for module, was_training in zip(self.models, self.prev_states):
            module.train(was_training)
        # False => let any in-flight exception propagate.
        return False
|
|
|
|
|
|
|
|
class train_mode:
    """Context manager that puts the given models into training mode.

    On entry each model's current ``training`` flag is recorded and the
    model is switched with ``train(True)``; on exit every model is put
    back into the mode it had before. Exceptions raised inside the ``with``
    body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        saved = self.prev_states = []
        for module in self.models:
            # Record the flag before flipping it, one model at a time.
            saved.append(module.training)
            module.train(True)

    def __exit__(self, *args):
        for module, was_training in zip(self.models, self.prev_states):
            module.train(was_training)
        # False => let any in-flight exception propagate.
        return False
|
|
|
|
|
def soft_update_params(net, target_net, tau):
    """Polyak-average ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes
    ``tau * param + (1 - tau) * target_param``; ``net`` is left unchanged.
    Parameters are paired in the order returned by ``parameters()``.
    """
    keep = 1 - tau
    for source, target in zip(net.parameters(), target_net.parameters()):
        blended = tau * source.data + keep * target.data
        target.data.copy_(blended)
|
|
|
|
|
def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    All CUDA devices are seeded as well when CUDA is available.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
def make_dir(*path_parts):
    """Join ``path_parts`` into a directory path, create it, and return it.

    Missing parent directories are created too, and an already existing
    directory is not an error. Genuine failures (for example insufficient
    permissions, or the path existing as a regular file) raise ``OSError``
    instead of being silently ignored, so the returned path is always a
    directory that exists.

    Args:
        *path_parts: Path components, combined with ``os.path.join``.

    Returns:
        The joined directory path.

    Raises:
        OSError: If the directory cannot be created.
    """
    dir_path = os.path.join(*path_parts)
    # exist_ok=True keeps the "already there is fine" behaviour without
    # also hiding real errors.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path
|
|
|
|
|
def weight_init(m):
    """Custom weight init for Linear layers.

    ``nn.Linear`` modules get an orthogonal weight matrix and, when they
    have a bias, a zero bias. Every other module type is left untouched.
    Intended to be passed to ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.orthogonal_(m.weight.data)

    # ``m.bias`` is None for bias-free layers, which has no ``data``.
    bias = m.bias
    if hasattr(bias, 'data'):
        bias.data.fill_(0.0)
|
|
|
|
|
class MLP(nn.Module):
    """Module wrapper around the ``mlp`` builder with ``weight_init`` applied.

    The network produced by ``mlp`` is stored as ``self.trunk`` and every
    submodule is initialised through ``weight_init``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, hidden_depth,
                 output_mod=None):
        super().__init__()
        trunk = mlp(
            input_dim,
            hidden_dim,
            output_dim,
            hidden_depth,
            output_mod,
        )
        self.trunk = trunk
        # Applied after the trunk is registered so its layers are visited.
        self.apply(weight_init)

    def forward(self, x):
        """Run ``x`` through the trunk and return the result."""
        out = self.trunk(x)
        return out
|
|
|
|
|
class TanhTransform(pyd.transforms.Transform):
    """Bijective transform ``y = tanh(x)`` from the reals onto (-1, 1)."""

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Return the inverse hyperbolic tangent of ``x`` via ``log1p``."""
        log_one_plus = torch.log1p(x)
        log_one_minus = torch.log1p(-x)
        return 0.5 * (log_one_plus - log_one_minus)

    def __eq__(self, other):
        # The transform has no parameters, so any two instances are equal.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # No clamping is applied here; inputs at exactly +/-1 map to +/-inf.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        """Return ``log |d tanh(x) / dx|`` element-wise.

        Algebraically equal to ``log(1 - tanh(x) ** 2)``, written in terms
        of softplus; only ``x`` is used, ``y`` is ignored.
        """
        softplus_term = F.softplus(-2.0 * x)
        return 2.0 * (math.log(2.0) - x - softplus_term)
|
|
|
|
|
class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal distribution pushed through a ``TanhTransform``.

    ``loc`` and ``scale`` parameterise the underlying ``pyd.Normal`` and
    are also kept as attributes on the instance.
    """

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

        self.base_dist = pyd.Normal(loc, scale)
        squashers = [TanhTransform()]
        super().__init__(self.base_dist, squashers)

    @property
    def mean(self):
        """Return ``loc`` passed through every transform in order.

        This is the transformed location of the base distribution, not an
        expectation computed over the transformed distribution.
        """
        squashed = self.loc
        for transform in self.transforms:
            squashed = transform(squashed)
        return squashed
|
|
|
|
|
class TorchRunningMeanStd:
    """Running mean / variance tracker backed by torch tensors.

    Starts at mean 0, variance 1 and a pseudo-count of ``epsilon``; each
    ``update`` folds the statistics of a new batch into the running values
    via ``update_mean_var_count_from_moments``.
    """

    def __init__(self, epsilon=1e-4, shape=(), device=None):
        self.mean = torch.zeros(shape, device=device)
        self.var = torch.ones(shape, device=device)
        self.count = epsilon

    @torch.no_grad()
    def update(self, x):
        """Fold a batch ``x`` (reduced over dimension 0) into the statistics."""
        mu = torch.mean(x, axis=0)
        # NOTE(review): torch.var defaults to the unbiased (n - 1) estimator,
        # which is NaN for a single-row batch -- confirm this is intended.
        sigma_sq = torch.var(x, axis=0)
        n_rows = x.shape[0]
        self.update_from_moments(mu, sigma_sq, n_rows)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge pre-computed batch moments into the running statistics."""
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count,
            batch_mean, batch_var, batch_count,
        )
        self.mean, self.var, self.count = merged

    @property
    def std(self):
        """Square root of the running variance."""
        return self.var.sqrt()
|
|
|
|
|
|
|
|
def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Merge running moments with the moments of a new batch.

    Uses the parallel mean/variance combination formula: the merged mean is
    the count-weighted average of the two means, and the merged variance
    comes from the two sums of squared deviations plus a correction term
    for the difference between the means.

    Args:
        mean: Running mean (tensor).
        var: Running variance (tensor).
        count: Number of samples behind the running statistics.
        batch_mean: Mean of the new batch (tensor).
        batch_var: Variance of the new batch (tensor).
        batch_count: Number of samples in the new batch.

    Returns:
        Tuple ``(new_mean, new_var, new_count)``.
    """
    delta = batch_mean - mean
    tot_count = count + batch_count

    # Shift the old mean toward the batch mean in proportion to the batch's
    # share of all samples. The shift must be *scaled* by that share; adding
    # the share to delta produces a wrong mean.
    new_mean = mean + delta * batch_count / tot_count

    # Sums of squared deviations of each part, plus the between-part term.
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + torch.pow(delta, 2) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
|
|
|
|
|
def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    """Build a fully connected network as an ``nn.Sequential``.

    With ``hidden_depth == 0`` the network is a single linear map from
    ``input_dim`` to ``output_dim``. Otherwise it has ``hidden_depth``
    hidden ``Linear`` + in-place ``ReLU`` pairs of width ``hidden_dim``
    followed by a final linear layer. ``output_mod``, when given, is
    appended as the last module.
    """
    if hidden_depth == 0:
        layers = [nn.Linear(input_dim, output_dim)]
    else:
        # Input width of each hidden layer: the first reads the input, the
        # remaining ones read the previous hidden layer.
        fan_ins = [input_dim] + [hidden_dim] * (hidden_depth - 1)
        layers = []
        for fan_in in fan_ins:
            layers.extend((nn.Linear(fan_in, hidden_dim), nn.ReLU(inplace=True)))
        layers.append(nn.Linear(hidden_dim, output_dim))

    if output_mod is not None:
        layers.append(output_mod)

    return nn.Sequential(*layers)
|
|
|
|
|
def to_np(t):
    """Convert a torch tensor to a NumPy array.

    ``None`` is passed through unchanged, a tensor with no elements becomes
    an empty ``np.array([])``, and any other tensor is moved to the CPU,
    detached from the autograd graph and converted.
    """
    if t is None:
        return None
    if t.nelement() == 0:
        return np.array([])
    on_cpu = t.cpu()
    return on_cpu.detach().numpy()
|
|
|
|
|
def save_numpy_as_gif(array, filename, fps=20, scale=1.0):
    """Write a stack of frames to a GIF file and return the moviepy clip.

    The extension of ``filename`` is replaced by ``.gif``. A 3-dimensional
    ``array`` gets a trailing axis of size 3 added by repeating its values,
    and the clip is resized by ``scale`` before being written at ``fps``.
    """
    # Force the .gif extension regardless of what was passed in.
    stem = os.path.splitext(filename)[0]
    filename = stem + '.gif'

    # Replicate the values into a trailing size-3 axis for 3-D input.
    if array.ndim == 3:
        array = array[..., np.newaxis] * np.ones(3)

    # Build the clip frame by frame, rescale it, then write it out.
    frames = list(array)
    clip = ImageSequenceClip(frames, fps=fps)
    clip = clip.resize(scale)
    clip.write_gif(filename, fps=fps)
    return clip
|
|
|
|
|
def get_info_stats(infos):
    """Summarise per-step info dicts across a set of trajectories.

    For every key of the first step's info dict, two statistics are
    computed: ``<key>_mean``, the mean over all steps of all trajectories,
    and ``<key>_final``, the mean over the last step of each trajectory.

    Args:
        infos: A sequence of trajectories, each a sequence of per-step
            info dicts that contain the keys of ``infos[0][0]``.

    Returns:
        Dict mapping ``<key>_mean`` and ``<key>_final`` to their values.
        Empty when there are no trajectories or the first one has no steps.
    """
    if not infos or not infos[0]:
        return {}

    stat_dict = {}
    for key in infos[0][0]:
        all_values = [info[key] for ep_info in infos for info in ep_info]
        # Index the last step explicitly rather than relying on a loop
        # variable left over from iterating the trajectory; trajectories
        # with no steps contribute nothing to the final-step statistic.
        final_values = [ep_info[-1][key] for ep_info in infos if ep_info]

        stat_dict[key + '_mean'] = np.mean(all_values)
        stat_dict[key + '_final'] = np.mean(final_values)

    return stat_dict