|
|
|
|
|
import os
import random
import time
from dataclasses import dataclass
from typing import Optional

import ale_py  # noqa: F401  (imported for its side effect of registering the ALE environments)
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
|
|
|
|
|
|
|
|
|
@dataclass
class Args:
    """Command-line hyperparameters for PPO training on Atari.

    Parsed with ``tyro.cli``; each field's docstring is its help text. The
    last three fields (``batch_size``, ``minibatch_size``, ``num_iterations``)
    are placeholders that the training script overwrites after parsing.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    # Optional: None means "use the default wandb entity".
    wandb_entity: Optional[str] = None
    """the entity (team) of wandb's project"""
    capture_video: bool = True
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm-specific arguments
    env_id: str = "SpaceInvadersNoFrameskip-v4"
    """the id of the environment"""
    total_timesteps: int = 10_000_000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 8
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 4
    """the number of mini-batches"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.1
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.01
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    # Optional: None disables KL-based early stopping of the update epochs.
    target_kl: Optional[float] = None
    """the target KL divergence threshold (None disables early stopping)"""

    # Filled in at runtime by the training script.
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""
|
|
|
|
|
|
|
|
|
def make_env(env_id, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one wrapped Atari environment.

    The factory is intended for ``gym.vector.SyncVectorEnv``. Only the
    environment with ``idx == 0`` records video, and only when
    ``capture_video`` is truthy; videos go to ``videos/<run_name>``.
    """

    def thunk():
        record_this_env = capture_video and idx == 0
        if not record_this_env:
            env = gym.make(env_id)
        else:
            # RecordVideo needs rendered frames, so request RGB-array rendering.
            env = gym.wrappers.RecordVideo(
                gym.make(env_id, render_mode="rgb_array"),
                f"videos/{run_name}",
            )

        # Statistics sit inside ClipRewardEnv, so they see unclipped rewards.
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)

        # 84x84 grayscale frames stacked 4 deep: the input Agent's conv trunk is sized for.
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayscaleObservation(env)
        return gym.wrappers.FrameStackObservation(env, 4)

    return thunk
|
|
|
|
|
|
|
|
|
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer`` in place and return it.

    Args:
        layer: a module with a ``weight`` tensor and an optional ``bias``
            (e.g. ``nn.Linear`` or ``nn.Conv2d``).
        std: gain passed to ``torch.nn.init.orthogonal_``; the default
            sqrt(2) is PyTorch's recommended gain for ReLU layers.
        bias_const: value every bias entry is filled with.

    Returns:
        The same ``layer`` object, so calls can be nested inside
        ``nn.Sequential(...)``.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with bias=False have ``bias is None``; there is nothing to fill.
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
|
|
|
|
|
|
|
|
|
class Agent(nn.Module):
    """Actor-critic network with a shared Nature-DQN-style convolutional trunk.

    The trunk's first conv takes 4 input channels. The ``64 * 7 * 7`` flatten
    size corresponds to 84x84 frames (84 -> 20 -> 9 -> 7). Inputs are divided
    by 255 before the trunk, presumably because they are raw 0-255 pixels.
    ``actor`` emits one logit per discrete action; ``critic`` emits one value.
    """

    def __init__(self, envs):
        super().__init__()
        num_actions = envs.single_action_space.n
        feature_dim = 512
        trunk = [
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, feature_dim)),
            nn.ReLU(),
        ]
        # Attribute names (network/actor/critic) define the state_dict keys of saved checkpoints.
        self.network = nn.Sequential(*trunk)
        # Small policy-head gain keeps the initial action distribution close to uniform.
        self.actor = layer_init(nn.Linear(feature_dim, num_actions), std=0.01)
        self.critic = layer_init(nn.Linear(feature_dim, 1), std=1)

    def _encode(self, x):
        """Divide observations by 255 and run them through the shared trunk."""
        return self.network(x / 255.0)

    def get_value(self, x):
        """Return the critic's value estimates (last dimension of size 1) for ``x``."""
        return self.critic(self._encode(x))

    def get_action_and_value(self, x, action=None):
        """Return ``(action, log_prob, entropy, value)`` for observations ``x``.

        When ``action`` is None a new action is sampled from the policy;
        otherwise the given actions are scored under the current policy.
        """
        features = self._encode(x)
        dist = Categorical(logits=self.actor(features))
        chosen = action if action is not None else dist.sample()
        return chosen, dist.log_prob(chosen), dist.entropy(), self.critic(features)
|
|
|
|
|
|
|
|
|
def extract_finished_episodes(infos):
    """Return ``[(episodic_return, episodic_length), ...]`` for episodes that ended this step.

    ``infos`` is the info dict returned by a gymnasium vector env's ``step``.
    Two layouts are handled:

    * a ``"final_info"`` entry that is a per-env sequence of info dicts
      (``None``/empty for envs that did not finish), as in gymnasium < 1.0;
    * a vectorised dict whose ``"episode"`` entry holds per-env arrays and whose
      ``"_episode"`` entry is a boolean mask of finished envs (gymnasium >= 1.0).
      If ``"final_info"`` is itself such a dict, it is read instead of ``infos``.
    """
    final_info = infos.get("final_info")
    if final_info is not None and not isinstance(final_info, dict):
        return [
            (float(info["episode"]["r"]), int(info["episode"]["l"]))
            for info in final_info
            if info and "episode" in info
        ]

    vec_info = final_info if final_info is not None else infos
    if "episode" not in vec_info:
        return []
    episode = vec_info["episode"]
    mask = vec_info.get("_episode", np.ones(len(episode["r"]), dtype=bool))
    return [(float(episode["r"][i]), int(episode["l"][i])) for i in np.flatnonzero(mask)]


if __name__ == "":
    # NOTE(review): this condition is never true, so the training script below does not
    # run when the file is executed (only the evaluation entry point at the bottom does).
    # Presumably disabled on purpose; change it to "__main__" to train again.
    args = tyro.cli(Args)
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

    if args.track:
        # Imported at module scope, so later ``wandb.*`` calls (all guarded by
        # ``args.track``) can use this binding without re-importing.
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )

    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # Seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    print("Using device:", device)

    # Environment setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.num_envs)],
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # Rollout storage, indexed [step, env].
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    episodic_returns = []
    num_episodes_completed = 0

    # global_step advances by batch_size (num_envs * num_steps) per iteration and by
    # num_envs per env step, so exact "global_step % N == 0" tests can silently never
    # fire (with the defaults, lcm(1024, 500_000) = 16M > total_timesteps). Track the
    # next threshold to cross instead.
    checkpoint_interval = 500_000
    next_checkpoint_step = checkpoint_interval
    progress_interval = 100_000
    next_progress_step = progress_interval

    for iteration in range(1, args.num_iterations + 1):
        # Periodic checkpoint (taken at the start of the first iteration past each threshold).
        if global_step >= next_checkpoint_step:
            model_path = f"{run_name}_agent_{global_step}.pth"
            torch.save(agent.state_dict(), model_path)
            if args.track:
                wandb.save(model_path)
            next_checkpoint_step = (global_step // checkpoint_interval + 1) * checkpoint_interval

        # Linearly anneal the learning rate from learning_rate towards 0 over training.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        # ---- Rollout ----
        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

            for ep_return, ep_length in extract_finished_episodes(infos):
                episodic_returns.append(ep_return)
                num_episodes_completed += 1
                writer.add_scalar("charts/episodic_return", ep_return, global_step)
                writer.add_scalar("charts/episodic_length", ep_length, global_step)
                if args.track:
                    wandb.log({
                        "charts/episodic_return": ep_return,
                        "charts/episodic_length": ep_length,
                        "global_step": global_step,
                    })
                print(f"[Step {global_step}] Episode {num_episodes_completed} finished | Return: {ep_return:.2f} | Length: {ep_length}")

            if global_step >= next_progress_step:
                next_progress_step = (global_step // progress_interval + 1) * progress_interval
                if episodic_returns:
                    recent_returns = episodic_returns[-10:]
                    avg_return = np.mean(recent_returns)
                    print(f"\n{'='*60}")
                    print(f"[Step {global_step}] Progress Update:")
                    print(f" Episodes completed: {num_episodes_completed}")
                    print(f" Avg return (last {len(recent_returns)} episodes): {avg_return:.2f}")
                    print(f"{'='*60}\n")
                else:
                    print(f"\n[Step {global_step}] No episodes completed yet (this is normal in early training)")

        # Per-iteration score summary.
        if episodic_returns:
            avg_reward = np.mean(episodic_returns[-10:])
            running_score = np.mean(episodic_returns[-100:])
            print(f"Episode {num_episodes_completed:5d} | Step {global_step:8,} | Score: {episodic_returns[-1]:8.2f} | Avg(10): {avg_reward:8.2f} | Running: {running_score:8.2f}")

        # ---- Generalised Advantage Estimation ----
        # NOTE(review): gymnasium >= 1.0 vector envs default to "next-step" autoreset, where the
        # step after an episode ends only performs the reset (action ignored). Those transitions
        # are still stored and trained on below — consider masking them. TODO confirm.
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # Flatten the [step, env] batch.
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # ---- PPO update ----
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # Two KL estimators: the naive one and the lower-variance (ratio - 1) - log(ratio).
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Clipped surrogate policy loss.
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss, optionally clipped around the rollout-time value estimates.
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            # Early-stop the remaining epochs once the policy has moved too far.
            if args.target_kl is not None and approx_kl > args.target_kl:
                break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)

        avg_return = np.mean(episodic_returns[-10:]) if episodic_returns else 0.0
        sps = int(global_step / (time.time() - start_time))
        print(f"Iter {iteration} | SPS: {sps} | VLoss: {v_loss.item():.4f} | PLoss: {pg_loss.item():.4f} | Ent: {entropy_loss.item():.4f} | ExpVar: {explained_var:.4f} | AvgRet: {avg_return:.2f} | Episodes: {num_episodes_completed}")
        writer.add_scalar("charts/SPS", sps, global_step)
        if args.track:
            wandb.log({
                "charts/learning_rate": optimizer.param_groups[0]["lr"],
                "losses/value_loss": v_loss.item(),
                "losses/policy_loss": pg_loss.item(),
                "losses/entropy": entropy_loss.item(),
                "losses/old_approx_kl": old_approx_kl.item(),
                "losses/approx_kl": approx_kl.item(),
                "losses/clipfrac": np.mean(clipfracs),
                "losses/explained_variance": explained_var,
                "charts/SPS": sps,
                "avg_return": avg_return,
                "global_step": global_step,
                "episodes_completed": num_episodes_completed,
            })

    print(f"\n{'='*60}")
    print("Training Complete!")
    print(f"Total episodes completed: {num_episodes_completed}")
    if episodic_returns:
        final_avg = np.mean(episodic_returns[-100:])
        print(f"Final average return: {final_avg:.2f}")
    print(f"{'='*60}\n")

    final_model_path = f"{run_name}_agent_final.pth"
    torch.save(agent.state_dict(), final_model_path)
    if args.track:
        wandb.save(final_model_path)
    envs.close()
    writer.close()
|
|
|
|
|
|
def evaluate_agent(model_path, env_id="SpaceInvadersNoFrameskip-v4", num_episodes=10, seed=1, render=False, video_dir="eval_videos2"):
    """Load a PPO agent from ``model_path`` and evaluate it for ``num_episodes`` episodes.

    Each episode is a full game (all lives). Unlike the training pipeline, no
    life-loss termination or reward clipping wrapper is applied, so the
    reported rewards are raw game scores.

    Args:
        model_path: path to a ``state_dict`` saved from ``Agent``.
        env_id: gymnasium Atari environment id.
        num_episodes: number of evaluation episodes to run.
        seed: seed for the environment's first reset and for its action space.
        render: if True, record episodes with ``gym.wrappers.RecordVideo``
            into ``video_dir``.
        video_dir: output directory for recorded videos (only used when
            ``render`` is True).

    Returns:
        The mean per-episode reward, or ``nan`` if no episodes were run.
    """
    import types

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    env = gym.make(env_id, render_mode="rgb_array")
    if render:
        os.makedirs(video_dir, exist_ok=True)
        env = gym.wrappers.RecordVideo(env, video_dir)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = NoopResetEnv(env, noop_max=30)
    # Same frame skip as the training env pipeline; the policy was learned at this rate.
    env = MaxAndSkipEnv(env, skip=4)
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayscaleObservation(env)
    env = gym.wrappers.FrameStackObservation(env, 4)
    env.action_space.seed(seed)

    episode_rewards = []
    try:
        # Agent only reads ``single_action_space`` from its argument; a lightweight stand-in
        # avoids wrapping this env in a second, never-closed vector env.
        spaces = types.SimpleNamespace(
            single_action_space=env.action_space,
            single_observation_space=env.observation_space,
        )
        agent = Agent(spaces).to(device)
        # The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
        agent.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        agent.eval()

        for ep in range(num_episodes):
            # Seed only the first reset; later resets continue the env's RNG stream.
            obs, _ = env.reset(seed=seed if ep == 0 else None)
            obs = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            done = False
            total_reward = 0.0
            while not done:
                with torch.no_grad():
                    action, _, _, _ = agent.get_action_and_value(obs)
                obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy()[0])
                obs = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                done = terminated or truncated
                total_reward += reward
            episode_rewards.append(total_reward)
            print(f"Episode {ep+1}: Reward = {total_reward:.2f}")
    finally:
        env.close()

    if not episode_rewards:
        print("No evaluation episodes were run.")
        return float("nan")

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib is not installed. Install it to see reward plots.")
    else:
        plt.figure(figsize=(10, 5))
        plt.plot(episode_rewards, marker='o')
        plt.title(f'Rewards over {num_episodes} Episodes')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    avg_reward = np.mean(episode_rewards)
    print(f"\nEvaluated {num_episodes} episodes | Average Reward: {avg_reward:.2f}")
    return avg_reward
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Evaluate a saved checkpoint. The checkpoint path may be given as the first
    # command-line argument; otherwise the default path below is used. Arguments
    # starting with "-" are left alone (they would be option flags, not a path).
    import sys

    default_model_path = r"D:\Fall25\RL\Assignment-5\Reinforcement-Learning\Assignment-5\SpaceInvadersNoFrameskip-v4__PPO_atari__1__1766540445_agent_final.pth"
    if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
        model_path = sys.argv[1]
    else:
        model_path = default_model_path
    env_id = "SpaceInvadersNoFrameskip-v4"
    num_episodes = 100
    seed = 1
    render = False

    evaluate_agent(
        model_path=model_path,
        env_id=env_id,
        num_episodes=num_episodes,
        seed=seed,
        render=render,
    )