File size: 4,915 Bytes
96170c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import numpy as np
import random
import torch
import torch.nn as nn
import os
import inspect
import pickle
import gdown
from network import Actor
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    - ``nn.Linear``: orthogonal weight matrix, zero bias.
    - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: delta-orthogonal init. All kernel
      taps are zero except the centre tap, which is orthogonal with ReLU gain.
      The bias is zero.

    Biases are only touched when the layer has one (``bias=False`` layers are
    supported). Modules of any other type are left unchanged.

    Reference: https://github.com/MishaLaskin/rad/blob/master/curl_sac.py

    Raises:
        ValueError: if a convolution kernel is not square.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        # bias is None for layers constructed with bias=False
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        # The init only defines a single centre tap, so it needs a square kernel.
        if m.weight.size(2) != m.weight.size(3):
            raise ValueError(
                f"delta-orthogonal init requires a square kernel, got "
                f"{tuple(m.weight.shape[2:])}"
            )
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
def set_seed(random_seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch.

    Args:
        random_seed: Seed to apply. ``None`` or any value ``<= 0`` means
            "choose one": a random seed in ``[1, 9999)`` is drawn instead.

    Returns:
        The seed that was actually applied.
    """
    if random_seed is None or random_seed <= 0:
        random_seed = np.random.randint(1, 9999)
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    return random_seed
def make_env(env_name, seed):
    """Create a Gymnasium environment and summarise its spaces.

    Args:
        env_name: Environment id passed to ``gymnasium.make``.
        seed: Seed applied to the environment's RNG (via ``reset``) and to
            its action and observation spaces.

    Returns:
        ``(env, env_info)`` where ``env_info`` holds ``name``, ``state_dim``
        (first dim of the observation space), ``action_dim`` (first dim of
        the action space), ``action_bound`` (``[low, high]`` of action
        dimension 0) and ``seed``.
    """
    # Local import: gymnasium is only needed when an environment is built.
    import gymnasium as gym

    env = gym.make(env_name)
    # Gymnasium seeds the environment's own RNG only through reset(seed=...);
    # the global NumPy/PyTorch seeds never reach it, so without this call the
    # initial states are not reproducible.
    env.reset(seed=seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    # NOTE: assumes every action dimension shares the bounds of dimension 0.
    action_bound = [env.action_space.low[0], env.action_space.high[0]]
    env_info = {
        'name': env_name,
        'state_dim': state_dim,
        'action_dim': action_dim,
        'action_bound': action_bound,
        'seed': seed,
    }
    return env, env_info
def get_learning_info(args, seed):
    """Build the keyword arguments used to construct the learning agent.

    Args:
        args: Run configuration. Reads ``env_name``, ``alpha_threshold``,
            ``theta_threshold``, ``discount``, ``datasize``, ``tau``,
            ``noise_clip``, ``policy_freq`` and ``h``.
        seed: Seed forwarded to ``make_env``.

    Returns:
        A dict holding the created ``env`` and its ``env_info``, the raw
        ``args``, the ``thresholds`` dict (``ALPHA_THRESHOLD`` /
        ``THETA_THRESHOLD``), the training hyper-parameters, the torch
        ``device`` and ``num_teacher_param``. The last is the parameter count
        of an ``Actor`` with hidden sizes ``(400, 300)``.

    Raises:
        KeyError: if ``args.env_name`` is not a supported environment.
    """
    supported_envs = ('HalfCheetah-v3', 'Walker2d-v3', 'Ant-v3', 'Hopper-v3')
    # Validate before creating the env so an unsupported name does not leak one.
    if args.env_name not in supported_envs:
        raise KeyError(
            f"Unsupported environment {args.env_name!r}; expected one of {supported_envs}"
        )
    env, env_info = make_env(args.env_name, seed)
    # Fall back to CPU so the code still runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # All supported environments currently share the same alpha threshold.
    thresholds = {
        "ALPHA_THRESHOLD": args.alpha_threshold,
        "THETA_THRESHOLD": args.theta_threshold,
    }
    # NOTE(review): assumes actions are normalised to [-1, 1]; env_info's
    # action_bound is not consulted here — confirm.
    max_action = 1
    # Reference actor, built only to count its parameters. The last Actor
    # argument is presumably max_action — TODO confirm against network.Actor.
    teacher_actor = Actor(env_info['state_dim'], env_info['action_dim'], (400, 300), 1)
    num_teacher_param = sum(p.numel() for p in teacher_actor.parameters())
    kwargs = {
        "env": env,
        "args": args,
        "env_info": env_info,
        "thresholds": thresholds,
        "discount": args.discount,
        "datasize": args.datasize,
        "tau": args.tau,
        "device": device,
        "num_teacher_param": num_teacher_param,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        "h": args.h,
    }
    return kwargs
def get_compression_ratio(num_teacher_param, agent):
    """Return how many weights the agent's actor keeps, relative to the teacher.

    Each direct child of ``agent.actor`` reports its surviving weight count via
    ``get_num_remained_weights()``. The total is divided by
    ``num_teacher_param``.
    """
    remaining = sum(
        layer.get_num_remained_weights() for layer in agent.actor.children()
    )
    return remaining / num_teacher_param
# Google Drive file ids of the pre-collected teacher buffers, keyed by
# buffer level and then by environment name.
_TEACHER_BUFFER_FILE_IDS = {
    "expert": {
        "Ant-v3": "10VBf3bM38bNw9WsniQvirpNjRFWp8HZO",
        "Walker2d-v3": "1ungLoqNKS4NIldZ9H2mswwGh-3Ipgy0D",
        "HalfCheetah-v3": "1wO0HwDi1GNf9d9SrDJrf9x8XMZDOTkzl",
        "Hopper-v3": "10pqCliJSM_Iyb05dxHZfYs9VlmCmPryE",
    },
    "medium": {
        "Ant-v3": "1-SKleNu6l-tY2awkx3tgVDUKbjkOaj_D",
        "Walker2d-v3": "1x6nkBBSWMRb3bENxUzcntHT1WlSNJmoh",
        "HalfCheetah-v3": "1OHkB6yVK3QcqbuJH0B_iNW_2cBnv96mR",
        "Hopper-v3": "1uqH2pgKKrhadsCXCwQWrvDvZ4ZyYFkM-",
    },
}


def _download_teacher_buffer(env_name, level, file_path):
    """Download the (level, env_name) teacher buffer from Google Drive to file_path."""
    if level not in _TEACHER_BUFFER_FILE_IDS:
        raise ValueError("Invalid Level. Choose from ['expert', 'medium']")
    file_id = _TEACHER_BUFFER_FILE_IDS[level].get(env_name)
    if file_id is None:
        raise ValueError("Invalid Environment Name")
    # The teacher_buffer/ directory may not exist yet on a fresh checkout.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    print("Downloading the teacher buffer...")
    url = f"https://drive.google.com/uc?id={file_id}"
    result = gdown.download(url, file_path, quiet=False)
    # Some gdown versions return None on failure instead of raising.
    if result is None or not os.path.isfile(file_path):
        raise FileNotFoundError(f"Failed to download teacher buffer to {file_path}")
    print("Download Complete!")


def load_buffer(env_name, level, datasize):
    """Load a pickled teacher buffer, downloading it first if it is missing.

    The file is ``teacher_buffer/[<level>_buffer]_<env_name>.pickle`` in the
    directory containing this source file.

    Args:
        env_name: Environment name used in the file name. It must have an
            entry in ``_TEACHER_BUFFER_FILE_IDS`` if a download is needed.
        level: Buffer level, ``'expert'`` or ``'medium'``.
        datasize: Value assigned to the loaded buffer's ``size`` attribute.

    Returns:
        The unpickled buffer object, with ``size`` set to ``datasize``.

    Raises:
        ValueError: if the file is missing and ``level``/``env_name`` has no
            known download.
        FileNotFoundError: if the download did not produce the file.
    """
    current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    file_path = os.path.join(current_dir, "teacher_buffer", "[" + level + "_buffer]_" + env_name + ".pickle")
    if not os.path.exists(file_path):
        _download_teacher_buffer(env_name, level, file_path)
    # SECURITY NOTE: pickle.load can execute arbitrary code. The file may come
    # from a network download, so only use trusted file ids.
    with open(file_path, "rb") as fr:
        buffer = pickle.load(fr)
    buffer.size = datasize
    return buffer
|