"""Shared utilities: environment factories, torch helpers, distributions,
running statistics, and logging/gif helpers for RL training scripts."""

import math
import os
import random
from collections import deque

import gym
import metaworld
import metaworld.envs.mujoco.env_dict as _env_dict
import numpy as np
import torch
import torch.nn.functional as F
from gym.wrappers.time_limit import TimeLimit
from moviepy.editor import ImageSequenceClip
from rlkit.envs.wrappers import NormalizedBoxEnv
from skimage.util.shape import view_as_windows
from softgym.registered_env import env_arg_dict, SOFTGYM_ENVS
from softgym.utils.normalized_env import normalize
from torch import distributions as pyd
from torch import nn


def make_softgym_env(cfg):
    """Build the softgym environment named by ``cfg.env`` (prefix
    ``softgym_`` stripped), wrapped with softgym's ``normalize``."""
    env_name = cfg.env.replace('softgym_', '')
    env_kwargs = env_arg_dict[env_name]
    env = normalize(SOFTGYM_ENVS[env_name](**env_kwargs))
    return env


def make_classic_control_env(cfg):
    """Build a classic-control environment (currently only CartPole),
    normalized and time-limited to the env's own horizon.

    Raises:
        NotImplementedError: if ``cfg.env`` names an unsupported env.
    """
    if "CartPole" in cfg.env:
        from envs.cartpole import CartPoleEnv
        env = CartPoleEnv()
    else:
        raise NotImplementedError
    return TimeLimit(NormalizedBoxEnv(env), env.horizon)


def tie_weights(src, trg):
    """Make ``trg`` share (not copy) the weight and bias tensors of ``src``.

    Both layers must be the same type; after this call, updating one
    updates the other.
    """
    assert type(src) == type(trg)
    trg.weight = src.weight
    trg.bias = src.bias


def make_metaworld_env(cfg):
    """Build a seeded Meta-World environment named by ``cfg.env`` (prefix
    ``metaworld_`` stripped), preferring the V2 registry over V1, wrapped
    with normalization and a time limit of ``env.max_path_length``."""
    env_name = cfg.env.replace('metaworld_', '')
    if env_name in _env_dict.ALL_V2_ENVIRONMENTS:
        env_cls = _env_dict.ALL_V2_ENVIRONMENTS[env_name]
    else:
        env_cls = _env_dict.ALL_V1_ENVIRONMENTS[env_name]
    env = env_cls(render_mode='rgb_array')
    env.camera_name = env_name
    # Allow the task's random vector to vary between resets and mark the
    # task as set so the env is usable without an explicit set_task call.
    env._freeze_rand_vec = False
    env._set_task_called = True
    env.seed(cfg.seed)
    return TimeLimit(NormalizedBoxEnv(env), env.max_path_length)


class eval_mode(object):
    """Context manager that puts the given models in eval mode, restoring
    each model's previous training flag on exit."""

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False


class train_mode(object):
    """Context manager that puts the given models in train mode, restoring
    each model's previous training flag on exit."""

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(True)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False


def soft_update_params(net, target_net, tau):
    """Polyak-average ``net``'s parameters into ``target_net``:
    target <- tau * net + (1 - tau) * target."""
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)


def set_seed_everywhere(seed):
    """Seed torch (CPU and all CUDA devices), numpy, and ``random``."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def make_dir(*path_parts):
    """Join ``path_parts`` into a path, create the directory if possible
    (silently ignoring OSError, e.g. when it already exists), and return
    the joined path."""
    dir_path = os.path.join(*path_parts)
    try:
        os.mkdir(dir_path)
    except OSError:
        pass
    return dir_path


def weight_init(m):
    """Custom weight init for Linear layers: orthogonal weights, zero bias."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)


class MLP(nn.Module):
    """Thin module wrapper around :func:`mlp` with orthogonal weight init."""

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 hidden_depth,
                 output_mod=None):
        super().__init__()
        self.trunk = mlp(input_dim, hidden_dim, output_dim, hidden_depth,
                         output_mod)
        self.apply(weight_init)

    def forward(self, x):
        return self.trunk(x)


class TanhTransform(pyd.transforms.Transform):
    """Bijective tanh transform mapping R -> (-1, 1), with a numerically
    stable log-abs-det-Jacobian."""
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the
        # performance of certain algorithms. One should use
        # `cache_size=1` instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable, see details in
        # the following link:
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal(loc, scale) squashed through tanh into (-1, 1)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        # The distribution's "mean" is taken as the transformed mode of the
        # base Gaussian, i.e. tanh(loc).
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu


class TorchRunningMeanStd:
    """Running mean/variance over torch tensors, updated batch-by-batch.

    ``epsilon`` seeds the count to avoid division by zero on the first
    update.
    """

    def __init__(self, epsilon=1e-4, shape=(), device=None):
        self.mean = torch.zeros(shape, device=device)
        self.var = torch.ones(shape, device=device)
        self.count = epsilon

    def update(self, x):
        """Fold a batch ``x`` (batch dimension first) into the statistics."""
        with torch.no_grad():
            batch_mean = torch.mean(x, axis=0)
            batch_var = torch.var(x, axis=0)
            batch_count = x.shape[0]
            self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )

    @property
    def std(self):
        return torch.sqrt(self.var)


def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Merge (mean, var, count) with a batch's moments using the parallel
    algorithm of Chan et al., as in OpenAI Gym's RunningMeanStd.

    Returns:
        (new_mean, new_var, new_count) for the combined sample.
    """
    delta = batch_mean - mean
    tot_count = count + batch_count

    # BUG FIX: this previously read ``mean + delta + batch_count / tot_count``
    # (addition instead of multiplication), which corrupts the running mean.
    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + torch.pow(delta, 2) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count


def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    """Build an MLP as ``nn.Sequential``: ``hidden_depth`` hidden layers of
    width ``hidden_dim`` with ReLU activations (a single Linear layer when
    ``hidden_depth == 0``), optionally followed by ``output_mod``."""
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for i in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    if output_mod is not None:
        mods.append(output_mod)
    trunk = nn.Sequential(*mods)
    return trunk


def to_np(t):
    """Convert a torch tensor to a numpy array (None passes through; an
    empty tensor becomes an empty array)."""
    if t is None:
        return None
    elif t.nelement() == 0:
        return np.array([])
    else:
        return t.cpu().detach().numpy()


def save_numpy_as_gif(array, filename, fps=20, scale=1.0):
    """Write a (T, H, W[, C]) uint8 frame array to ``filename`` as a gif
    (extension forced to ``.gif``) and return the moviepy clip."""
    # ensure that the file has the .gif extension
    fname, _ = os.path.splitext(filename)
    filename = fname + '.gif'

    # copy into the color dimension if the images are black and white
    if array.ndim == 3:
        array = array[..., np.newaxis] * np.ones(3)

    # make the moviepy clip
    clip = ImageSequenceClip(list(array), fps=fps).resize(scale)
    clip.write_gif(filename, fps=fps)
    return clip


def get_info_stats(infos):
    """Aggregate per-step info dicts over trajectories.

    Args:
        infos: list of N_traj trajectories, each a list of T per-step
            info dicts sharing the same keys.

    Returns:
        dict mapping ``<key>_mean`` to the mean over all steps of all
        trajectories, and ``<key>_final`` to the mean over each
        trajectory's last step.
    """
    all_keys = infos[0][0].keys()
    stat_dict = {}
    for key in all_keys:
        step_values = []
        final_values = []
        for ep_info in infos:
            for info in ep_info:
                step_values.append(info[key])
            # BUG FIX: '_final' previously accumulated every timestep,
            # making it identical to '_mean'; record only the last step.
            final_values.append(ep_info[-1][key])
        stat_dict[key + '_mean'] = np.mean(step_values)
        stat_dict[key + '_final'] = np.mean(final_values)
    return stat_dict