# basketball_code / utils.py
# youqiwong's picture
# Upload folder using huggingface_hub
# 0c51b93 verified
import numpy as np
import torch
import torch.nn.functional as F
import gym
import os
import random
import math
import metaworld
import metaworld.envs.mujoco.env_dict as _env_dict
from moviepy.editor import ImageSequenceClip
from collections import deque
from gym.wrappers.time_limit import TimeLimit
from rlkit.envs.wrappers import NormalizedBoxEnv
from collections import deque
from skimage.util.shape import view_as_windows
from torch import nn
from torch import distributions as pyd
from softgym.registered_env import env_arg_dict, SOFTGYM_ENVS
from softgym.utils.normalized_env import normalize
def make_softgym_env(cfg):
    """Build the SoftGym environment named by ``cfg.env``.

    The ``softgym_`` marker is removed from ``cfg.env`` and the remaining
    name is used to look up both the constructor keyword arguments
    (``env_arg_dict``) and the environment class (``SOFTGYM_ENVS``). The
    constructed environment is returned wrapped by ``normalize``.
    """
    task_name = cfg.env.replace('softgym_','')
    ctor_kwargs = env_arg_dict[task_name]
    env_cls = SOFTGYM_ENVS[task_name]
    return normalize(env_cls(**ctor_kwargs))
def make_classic_control_env(cfg):
    """Build a classic-control environment selected by ``cfg.env``.

    Only names containing ``"CartPole"`` are supported; anything else
    raises ``NotImplementedError``. The environment is wrapped in
    ``NormalizedBoxEnv`` and then in ``TimeLimit`` using the environment's
    own ``horizon`` attribute as the step limit.
    """
    if "CartPole" not in cfg.env:
        raise NotImplementedError

    # Imported lazily so the dependency is only needed when requested.
    from envs.cartpole import CartPoleEnv

    base_env = CartPoleEnv()
    return TimeLimit(NormalizedBoxEnv(base_env), base_env.horizon)
def tie_weights(src, trg):
    """Make ``trg`` share the ``weight`` and ``bias`` objects of ``src``.

    Both arguments must be of exactly the same type (checked with an
    assertion). After the call ``trg.weight`` and ``trg.bias`` refer to the
    very same objects as on ``src``; nothing is copied.
    """
    same_type = type(src) == type(trg)
    assert same_type
    # Assign one attribute at a time, weight first and then bias.
    for attr_name in ('weight', 'bias'):
        setattr(trg, attr_name, getattr(src, attr_name))
def make_metaworld_env(cfg):
    """Build the Meta-World environment named by ``cfg.env``.

    The ``metaworld_`` marker is removed from ``cfg.env``. The class is
    taken from ``ALL_V2_ENVIRONMENTS`` when the name is present there and
    from ``ALL_V1_ENVIRONMENTS`` otherwise. The instance is created with
    ``render_mode='rgb_array'``, configured, seeded with ``cfg.seed`` and
    returned wrapped in ``NormalizedBoxEnv`` and ``TimeLimit`` (limit taken
    from the environment's ``max_path_length``).
    """
    task_name = cfg.env.replace('metaworld_','')

    # Prefer the V2 registry; fall back to V1 for names it does not know.
    v2_registry = _env_dict.ALL_V2_ENVIRONMENTS
    registry = v2_registry if task_name in v2_registry else _env_dict.ALL_V1_ENVIRONMENTS
    base_env = registry[task_name](render_mode='rgb_array')

    base_env.camera_name = task_name
    base_env._freeze_rand_vec = False
    base_env._set_task_called = True
    base_env.seed(cfg.seed)

    return TimeLimit(NormalizedBoxEnv(base_env), base_env.max_path_length)
class eval_mode:
    """Context manager that temporarily puts the given models in eval mode.

    On entry each model's ``training`` flag is recorded and the model is
    switched with ``train(False)``; on exit every model is switched back to
    its recorded flag.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        # Record and flip one model at a time, in the order they were given.
        self.prev_states = []
        for module in self.models:
            was_training = module.training
            self.prev_states.append(was_training)
            module.train(False)

    def __exit__(self, *args):
        for module, was_training in zip(self.models, self.prev_states):
            module.train(was_training)
        # Returning False lets any exception from the ``with`` body propagate.
        return False
class train_mode:
    """Context manager that temporarily puts the given models in train mode.

    On entry each model's ``training`` flag is recorded and the model is
    switched with ``train(True)``; on exit every model is switched back to
    its recorded flag.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        # Record and flip one model at a time, in the order they were given.
        self.prev_states = []
        for module in self.models:
            was_training = module.training
            self.prev_states.append(was_training)
            module.train(True)

    def __exit__(self, *args):
        for module, was_training in zip(self.models, self.prev_states):
            module.train(was_training)
        # Returning False lets any exception from the ``with`` body propagate.
        return False
def soft_update_params(net, target_net, tau):
    """Blend the parameters of ``net`` into ``target_net`` in place.

    Every target parameter is overwritten with
    ``tau * param + (1 - tau) * target_param``; parameters are paired in
    the order returned by ``parameters()``.
    """
    pairs = zip(net.parameters(), target_net.parameters())
    for online, target in pairs:
        blended = tau * online.data + (1 - tau) * target.data
        target.data.copy_(blended)
def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same ``seed``.

    All CUDA devices are seeded as well when CUDA is available.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
def make_dir(*path_parts):
    """Join ``path_parts`` into one path, try to create it, and return it.

    Creation is best-effort: any ``OSError`` from ``os.mkdir`` (for example
    because the directory already exists) is ignored, and the joined path
    is returned either way.
    """
    target = os.path.join(*path_parts)
    try:
        os.mkdir(target)
    except OSError:
        # Deliberately ignored; callers only need the path back.
        return target
    return target
def weight_init(m):
    """Initialise ``nn.Linear`` modules: orthogonal weights, zero bias.

    Any other module type is left untouched. The bias is only filled when
    it exposes a ``data`` attribute, so layers built with ``bias=False``
    are handled.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight.data)
    bias = m.bias
    if hasattr(bias, 'data'):
        bias.data.fill_(0.0)
class MLP(nn.Module):
    """Module wrapper around the network returned by ``mlp``.

    The network is stored as ``self.trunk`` and ``weight_init`` is applied
    to every sub-module after construction.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 hidden_depth,
                 output_mod=None):
        super().__init__()

        trunk = mlp(input_dim, hidden_dim, output_dim, hidden_depth,
                    output_mod)
        self.trunk = trunk
        self.apply(weight_init)

    def forward(self, x):
        """Return ``self.trunk`` applied to ``x``."""
        out = self.trunk(x)
        return out
class TanhTransform(pyd.transforms.Transform):
    """Bijective transform ``y = tanh(x)`` from the reals to (-1, 1)."""

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent computed from two ``log1p`` calls."""
        log_plus = x.log1p()
        log_minus = (-x).log1p()
        return 0.5 * (log_plus - log_minus)

    def __eq__(self, other):
        # Every TanhTransform instance is considered equal to every other.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # No clamping to the boundary is done here because it may degrade
        # the performance of certain algorithms; rely on ``cache_size=1``
        # to recover the pre-image of values produced by this transform.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # Numerically stable form of log(1 - tanh(x)^2); see
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        softplus_term = F.softplus(-2.0 * x)
        return 2.0 * (math.log(2.0) - x - softplus_term)
class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal distribution pushed through a ``TanhTransform``.

    ``loc`` and ``scale`` parameterise the underlying ``pyd.Normal``, which
    is kept as ``self.base_dist``.
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

        self.base_dist = pyd.Normal(loc, scale)
        super().__init__(self.base_dist, [TanhTransform()])

    @property
    def mean(self):
        """Return ``loc`` passed through each transform in turn."""
        squashed = self.loc
        for transform in self.transforms:
            squashed = transform(squashed)
        return squashed
class TorchRunningMeanStd:
    """Running mean / variance tracker backed by torch tensors.

    ``mean`` starts at zeros and ``var`` at ones, both of the given
    ``shape`` on ``device``; ``count`` starts at ``epsilon``.
    """

    def __init__(self, epsilon=1e-4, shape=(), device=None):
        self.mean = torch.zeros(shape, device=device)
        self.var = torch.ones(shape, device=device)
        self.count = epsilon

    def update(self, x):
        """Fold a batch ``x`` (reduced over dim 0) into the statistics."""
        with torch.no_grad():
            n_samples = x.shape[0]
            mean_of_batch = x.mean(dim=0)
            var_of_batch = x.var(dim=0)
            self.update_from_moments(mean_of_batch, var_of_batch, n_samples)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge pre-computed batch moments into the running statistics."""
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )
        self.mean, self.var, self.count = merged

    @property
    def std(self):
        """Square root of the running variance."""
        return self.var.sqrt()
def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Merge running statistics with the moments of a new batch.

    Uses the parallel-variance combination rule: the new mean is the old
    mean shifted by the count-weighted difference to the batch mean, and
    the new variance comes from the combined sum of squared deviations.

    Args:
        mean: running mean (tensor).
        var: running variance (tensor).
        count: number of samples behind ``mean`` / ``var``.
        batch_mean: mean of the new batch (tensor).
        batch_var: variance of the new batch (tensor).
        batch_count: number of samples in the new batch.

    Returns:
        Tuple ``(new_mean, new_var, new_count)``.
    """
    delta = batch_mean - mean
    tot_count = count + batch_count

    # The shift is delta scaled by the batch's share of all samples; adding
    # the share to delta instead of multiplying by it corrupts the mean.
    new_mean = mean + delta * batch_count / tot_count

    m_a = var * count
    m_b = batch_var * batch_count
    # Combined sum of squared deviations, including the between-group term.
    M2 = m_a + m_b + torch.pow(delta, 2) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    """Build a fully connected network as an ``nn.Sequential``.

    With ``hidden_depth == 0`` the network is a single linear layer from
    ``input_dim`` to ``output_dim``. Otherwise it is a stack of
    ``Linear`` + in-place ``ReLU`` pairs of width ``hidden_dim`` followed by
    a final linear layer to ``output_dim``. ``output_mod``, when given, is
    appended as the last module.
    """
    # Input widths of the consecutive Linear layers, ending with the one
    # that feeds the output layer.
    if hidden_depth == 0:
        widths = [input_dim]
    else:
        widths = [input_dim, hidden_dim] + [hidden_dim] * (hidden_depth - 1)

    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(nn.Linear(fan_in, fan_out))
        layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Linear(widths[-1], output_dim))

    if output_mod is not None:
        layers.append(output_mod)
    return nn.Sequential(*layers)
def to_np(t):
    """Convert a tensor to a NumPy array.

    ``None`` is passed through unchanged, a tensor with no elements becomes
    an empty array, and any other tensor is moved to the CPU, detached and
    converted.
    """
    if t is None:
        return None
    if t.nelement() == 0:
        return np.array([])
    on_cpu = t.cpu()
    return on_cpu.detach().numpy()
def save_numpy_as_gif(array, filename, fps=20, scale=1.0):
    """Write ``array`` as a GIF and return the moviepy clip.

    Whatever extension ``filename`` carries is replaced by ``.gif``. A
    3-dimensional array is given a trailing axis of size 3 first. The clip
    is resized by ``scale`` and written at ``fps`` frames per second.
    """
    # Force the .gif extension regardless of what the caller passed.
    stem = os.path.splitext(filename)[0]
    filename = stem + '.gif'

    # Images without a colour axis get their values copied into three.
    if array.ndim == 3:
        array = array[..., np.newaxis] * np.ones(3)

    frames = list(array)
    clip = ImageSequenceClip(frames, fps=fps)
    clip = clip.resize(scale)
    clip.write_gif(filename, fps=fps)
    return clip
def get_info_stats(infos):
    """Aggregate per-step info dicts from several trajectories.

    Args:
        infos: list of trajectories, each a list of per-step info dicts.
            The keys are taken from the first step of the first
            trajectory, so that trajectory must not be empty.

    Returns:
        Dict with two entries per key: ``<key>_mean`` is the mean of the
        value over every step of every trajectory, and ``<key>_final`` is
        the mean over trajectories of the value at each trajectory's last
        step.
    """
    all_keys = infos[0][0].keys()
    stat_dict = {}
    for key in all_keys:
        per_step = [info[key] for ep_info in infos for info in ep_info]
        # Read the last step explicitly and skip empty trajectories, so an
        # empty one cannot silently contribute a stale value taken from the
        # trajectory before it.
        finals = [ep_info[-1][key] for ep_info in infos if len(ep_info) > 0]
        stat_dict[key + '_mean'] = np.mean(per_step)
        stat_dict[key + '_final'] = np.mean(finals)
    return stat_dict