jangwon-kim-cocel's picture
Upload 11 files
96170c3 verified
import numpy as np
import random
import torch
import torch.nn as nn
import os
import inspect
import pickle
import gdown
from network import Actor
def weight_init(m):
"""Custom weight init for Conv2D and Linear layers.
Reference: https://github.com/MishaLaskin/rad/blob/master/curl_sac.py"""
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
m.bias.data.fill_(0.0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
# delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
assert m.weight.size(2) == m.weight.size(3)
m.weight.data.fill_(0.0)
m.bias.data.fill_(0.0)
mid = m.weight.size(2) // 2
gain = nn.init.calculate_gain('relu')
nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
def set_seed(random_seed):
if random_seed <= 0:
random_seed = np.random.randint(1, 9999)
else:
random_seed = random_seed
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
return random_seed
def make_env(env_name, seed):
import gymnasium as gym
# openai gym
env = gym.make(env_name)
env.action_space.seed(seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = [env.action_space.low[0], env.action_space.high[0]]
env_info = {'name': env_name, 'state_dim': state_dim, 'action_dim': action_dim, 'action_bound': action_bound, 'seed': seed}
return env, env_info
def get_learning_info(args, seed):
env, env_info = make_env(args.env_name, seed)
device = 'cuda'
alpha_dict = {'HalfCheetah-v3': args.alpha_threshold, 'Walker2d-v3': args.alpha_threshold,
'Ant-v3': args.alpha_threshold, 'Hopper-v3': args.alpha_threshold}
thresholds = {"ALPHA_THRESHOLD": alpha_dict[args.env_name], "THETA_THRESHOLD": args.theta_threshold}
max_action = 1
t_p = Actor(env_info['state_dim'], env_info['action_dim'], (400, 300), 1)
num_teacher_param = sum(p2.numel() for p2 in t_p.parameters())
kwargs = {
"env": env,
"args": args,
"env_info": env_info,
"thresholds": thresholds,
"discount": args.discount,
"datasize": args.datasize,
"tau": args.tau,
"device": device,
"num_teacher_param": num_teacher_param,
"noise_clip": args.noise_clip * max_action,
"policy_freq": args.policy_freq,
"h": args.h,
}
return kwargs
def get_compression_ratio(num_teacher_param, agent):
kep_w = 0
for c in agent.actor.children():
kep_w += c.get_num_remained_weights()
#
return kep_w / num_teacher_param
def load_buffer(env_name, level, datasize):
current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
file_path = os.path.join(current_dir, "teacher_buffer", "[" + level + "_buffer]_" + env_name + ".pickle")
try:
with open(file_path, "rb") as fr:
buffer = pickle.load(fr)
buffer.size = datasize
except FileNotFoundError:
# Download the file
if level == 'expert':
print("Downloading the teacher buffer...")
if env_name == "Ant-v3":
file_id = "10VBf3bM38bNw9WsniQvirpNjRFWp8HZO"
elif env_name == "Walker2d-v3":
file_id = "1ungLoqNKS4NIldZ9H2mswwGh-3Ipgy0D"
elif env_name == "HalfCheetah-v3":
file_id = "1wO0HwDi1GNf9d9SrDJrf9x8XMZDOTkzl"
elif env_name == "Hopper-v3":
file_id ="10pqCliJSM_Iyb05dxHZfYs9VlmCmPryE"
else:
raise ValueError("Invalid Environment Name")
url = f"https://drive.google.com/uc?id={file_id}"
gdown.download(url, file_path, quiet=False)
print("Download Complete!")
elif level == 'medium':
if env_name == "Ant-v3":
file_id = "1-SKleNu6l-tY2awkx3tgVDUKbjkOaj_D"
elif env_name == "Walker2d-v3":
file_id = "1x6nkBBSWMRb3bENxUzcntHT1WlSNJmoh"
elif env_name == "HalfCheetah-v3":
file_id = "1OHkB6yVK3QcqbuJH0B_iNW_2cBnv96mR"
elif env_name == "Hopper-v3":
file_id ="1uqH2pgKKrhadsCXCwQWrvDvZ4ZyYFkM-"
else:
raise ValueError("Invalid Environment Name")
url = f"https://drive.google.com/uc?id={file_id}"
gdown.download(url, file_path, quiet=False)
else:
raise ValueError("Invalid Level. Choose from ['expert', 'medium']")
with open(file_path, "rb") as fr:
buffer = pickle.load(fr)
buffer.size = datasize
return buffer