Commit
·
c2bc086
1
Parent(s):
a7471c9
pushing model
Browse files
- CP_DQN.cleanrl_model +0 -0
- README.md +2 -2
- dqn.py +14 -11
- events.out.tfevents.1679277727.redi.183238.0 → events.out.tfevents.1679332037.redi.1483546.0 +2 -2
- replay.mp4 +0 -0
- videos/CartPole-v1__CP_DQN__4__1679277723-eval/rl-video-episode-0.mp4 +0 -0
- videos/CartPole-v1__CP_DQN__4__1679277723-eval/rl-video-episode-1.mp4 +0 -0
- videos/CartPole-v1__CP_DQN__4__1679277723-eval/rl-video-episode-8.mp4 +0 -0
- videos/CartPole-v1__CP_DQN__4__1679332033-eval/rl-video-episode-0.mp4 +0 -0
- videos/CartPole-v1__CP_DQN__4__1679332033-eval/rl-video-episode-1.mp4 +0 -0
- videos/CartPole-v1__CP_DQN__4__1679332033-eval/rl-video-episode-8.mp4 +0 -0
CP_DQN.cleanrl_model
CHANGED
|
Binary files a/CP_DQN.cleanrl_model and b/CP_DQN.cleanrl_model differ
|
|
|
README.md
CHANGED
|
@@ -16,7 +16,7 @@ model-index:
|
|
| 16 |
type: CartPole-v1
|
| 17 |
metrics:
|
| 18 |
- type: mean_reward
|
| 19 |
-
value:
|
| 20 |
name: mean_reward
|
| 21 |
verified: false
|
| 22 |
---
|
|
@@ -67,7 +67,7 @@ python dqn.py --track --wandb-entity pfunk --wandb-project-name dqpn --capture-v
|
|
| 67 |
'save_model': True,
|
| 68 |
'seed': 4,
|
| 69 |
'start_e': 1.0,
|
| 70 |
-
'target_network_frequency':
|
| 71 |
'target_tau': 1.0,
|
| 72 |
'torch_deterministic': True,
|
| 73 |
'total_timesteps': 500000,
|
|
|
|
| 16 |
type: CartPole-v1
|
| 17 |
metrics:
|
| 18 |
- type: mean_reward
|
| 19 |
+
value: 499.44 +/- 0.00
|
| 20 |
name: mean_reward
|
| 21 |
verified: false
|
| 22 |
---
|
|
|
|
| 67 |
'save_model': True,
|
| 68 |
'seed': 4,
|
| 69 |
'start_e': 1.0,
|
| 70 |
+
'target_network_frequency': 100,
|
| 71 |
'target_tau': 1.0,
|
| 72 |
'torch_deterministic': True,
|
| 73 |
'total_timesteps': 500000,
|
dqn.py
CHANGED
|
@@ -54,7 +54,7 @@ def parse_args():
|
|
| 54 |
help="the discount factor gamma")
|
| 55 |
parser.add_argument("--target-tau", type=float, default=1.0,
|
| 56 |
help="the target network update rate")
|
| 57 |
-
parser.add_argument("--target-network-frequency", type=int, default=
|
| 58 |
help="the timesteps it takes to update the target network")
|
| 59 |
parser.add_argument("--batch-size", type=int, default=256,
|
| 60 |
help="the batch size of sample from the reply memory")
|
|
@@ -137,17 +137,16 @@ if __name__ == "__main__":
|
|
| 137 |
wandb.log({name: x, "global_step": y})
|
| 138 |
|
| 139 |
# TRY NOT TO MODIFY: seeding
|
| 140 |
-
random.seed(args.seed)
|
| 141 |
-
np.random.seed(args.seed)
|
| 142 |
torch.manual_seed(args.seed)
|
| 143 |
-
torch.backends.cudnn.deterministic = args.torch_deterministic
|
|
|
|
|
|
|
| 144 |
|
| 145 |
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
|
| 146 |
|
| 147 |
# env setup
|
| 148 |
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
|
| 149 |
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
|
| 150 |
-
envs.seed(args.seed)
|
| 151 |
|
| 152 |
q_network = QNetwork(envs).to(device)
|
| 153 |
optimizer = optim.RMSprop(q_network.parameters(), lr=args.learning_rate)
|
|
@@ -159,10 +158,11 @@ if __name__ == "__main__":
|
|
| 159 |
envs.single_observation_space,
|
| 160 |
envs.single_action_space,
|
| 161 |
device,
|
| 162 |
-
optimize_memory_usage=True,
|
| 163 |
handle_timeout_termination=True,
|
| 164 |
)
|
| 165 |
start_time = time.time()
|
|
|
|
| 166 |
policy_update_counter = 0
|
| 167 |
episode_returns = []
|
| 168 |
|
|
@@ -247,10 +247,10 @@ if __name__ == "__main__":
|
|
| 247 |
log_value("td/a_below", a_below, global_step)
|
| 248 |
log_value("td/above", above, global_step)
|
| 249 |
log_value("td/a_above", a_above, global_step)
|
| 250 |
-
log_value("
|
| 251 |
-
log_value("
|
| 252 |
-
log_value("
|
| 253 |
-
log_value("
|
| 254 |
log_value("debug/steps_per_second", int(global_step / (time.time() - start_time)), global_step)
|
| 255 |
|
| 256 |
# optimize the model
|
|
@@ -260,13 +260,16 @@ if __name__ == "__main__":
|
|
| 260 |
|
| 261 |
# update target network
|
| 262 |
if global_step % args.target_network_frequency == 0:
|
|
|
|
| 263 |
for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
|
| 264 |
target_network_param.data.copy_(
|
| 265 |
args.target_tau * q_network_param.data + (1.0 - args.target_tau) * target_network_param.data
|
| 266 |
)
|
| 267 |
policy_update_counter += 1
|
|
|
|
| 268 |
if global_step % 100 == 0:
|
| 269 |
-
log_value("
|
|
|
|
| 270 |
|
| 271 |
if args.save_model:
|
| 272 |
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
|
|
|
|
| 54 |
help="the discount factor gamma")
|
| 55 |
parser.add_argument("--target-tau", type=float, default=1.0,
|
| 56 |
help="the target network update rate")
|
| 57 |
+
parser.add_argument("--target-network-frequency", type=int, default=100,
|
| 58 |
help="the timesteps it takes to update the target network")
|
| 59 |
parser.add_argument("--batch-size", type=int, default=256,
|
| 60 |
help="the batch size of sample from the reply memory")
|
|
|
|
| 137 |
wandb.log({name: x, "global_step": y})
|
| 138 |
|
| 139 |
# TRY NOT TO MODIFY: seeding
|
|
|
|
|
|
|
| 140 |
torch.manual_seed(args.seed)
|
| 141 |
+
# torch.backends.cudnn.deterministic = args.torch_deterministic
|
| 142 |
+
np.random.seed(args.seed)
|
| 143 |
+
random.seed(args.seed)
|
| 144 |
|
| 145 |
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
|
| 146 |
|
| 147 |
# env setup
|
| 148 |
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
|
| 149 |
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
|
|
|
|
| 150 |
|
| 151 |
q_network = QNetwork(envs).to(device)
|
| 152 |
optimizer = optim.RMSprop(q_network.parameters(), lr=args.learning_rate)
|
|
|
|
| 158 |
envs.single_observation_space,
|
| 159 |
envs.single_action_space,
|
| 160 |
device,
|
| 161 |
+
# optimize_memory_usage=True,
|
| 162 |
handle_timeout_termination=True,
|
| 163 |
)
|
| 164 |
start_time = time.time()
|
| 165 |
+
target_update_counter = 0
|
| 166 |
policy_update_counter = 0
|
| 167 |
episode_returns = []
|
| 168 |
|
|
|
|
| 247 |
log_value("td/a_below", a_below, global_step)
|
| 248 |
log_value("td/above", above, global_step)
|
| 249 |
log_value("td/a_above", a_above, global_step)
|
| 250 |
+
log_value("alg/pu_scalar", pu_scalar, global_step)
|
| 251 |
+
log_value("alg/a_pu_scalar", a_pu_scalar, global_step)
|
| 252 |
+
log_value("alg/policy_frequency_scalar_ratio", policy_frequency_scalar_ratio, global_step)
|
| 253 |
+
log_value("alg/a_policy_frequency_scalar_ratio", a_policy_frequency_scalar_ratio, global_step)
|
| 254 |
log_value("debug/steps_per_second", int(global_step / (time.time() - start_time)), global_step)
|
| 255 |
|
| 256 |
# optimize the model
|
|
|
|
| 260 |
|
| 261 |
# update target network
|
| 262 |
if global_step % args.target_network_frequency == 0:
|
| 263 |
+
target_update_counter += 1
|
| 264 |
for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
|
| 265 |
target_network_param.data.copy_(
|
| 266 |
args.target_tau * q_network_param.data + (1.0 - args.target_tau) * target_network_param.data
|
| 267 |
)
|
| 268 |
policy_update_counter += 1
|
| 269 |
+
|
| 270 |
if global_step % 100 == 0:
|
| 271 |
+
log_value("alg/n_target_update", target_update_counter, global_step)
|
| 272 |
+
log_value("alg/n_policy_update", policy_update_counter, global_step)
|
| 273 |
|
| 274 |
if args.save_model:
|
| 275 |
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
|
events.out.tfevents.1679277727.redi.183238.0 → events.out.tfevents.1679332037.redi.1483546.0
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec2f3a5396fd9f4dd2f39620190713722ea55bd71a5e676744a4a7336e1edec8
|
| 3 |
+
size 628
|
replay.mp4
CHANGED
|
Binary files a/replay.mp4 and b/replay.mp4 differ
|
|
|
videos/CartPole-v1__CP_DQN__4__1679277723-eval/rl-video-episode-0.mp4
DELETED
|
Binary file (44.1 kB)
|
|
|
videos/CartPole-v1__CP_DQN__4__1679277723-eval/rl-video-episode-1.mp4
DELETED
|
Binary file (43.9 kB)
|
|
|
videos/CartPole-v1__CP_DQN__4__1679277723-eval/rl-video-episode-8.mp4
DELETED
|
Binary file (41.1 kB)
|
|
|
videos/CartPole-v1__CP_DQN__4__1679332033-eval/rl-video-episode-0.mp4
ADDED
|
Binary file (42.1 kB). View file
|
|
|
videos/CartPole-v1__CP_DQN__4__1679332033-eval/rl-video-episode-1.mp4
ADDED
|
Binary file (43.4 kB). View file
|
|
|
videos/CartPole-v1__CP_DQN__4__1679332033-eval/rl-video-episode-8.mp4
ADDED
|
Binary file (42.2 kB). View file
|
|
|