File size: 10,971 Bytes
1c8c60e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 | import torch
def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor:
    """Normalize advantages to zero mean and unit standard deviation.

    A small epsilon guards against division by zero when the advantages
    are (nearly) constant.
    """
    eps = 1e-9
    centered = advantages - advantages.mean()
    return centered / (advantages.std() + eps)
def whiten_advantages_time_step_wise(
    advantages: torch.Tensor,  # (B, T)
) -> torch.Tensor:
    """Normalize advantages per time step, using statistics over the batch.

    Each column t is centered and scaled by the mean/std taken across the
    batch dimension; an epsilon avoids division by zero for constant columns.
    """
    assert advantages.dim() == 2, "Wrong dimensions."
    eps = 1e-9
    batch_mean = torch.mean(advantages, dim=0, keepdim=True)
    batch_std = torch.std(advantages, dim=0, keepdim=True)
    return (advantages - batch_mean) / (batch_std + eps)
def get_discounted_state_visitation_credits(
    credits: torch.Tensor, discount_factor: float  # (B, T)
) -> torch.Tensor:
    """Scale the credit at each step t by discount_factor ** t.

    Multiplies credits element-wise by [gamma^0, gamma^1, ..., gamma^(T-1)],
    broadcast over the batch dimension.
    """
    num_steps = credits.shape[1]
    step_indices = torch.arange(num_steps, device=credits.device)
    per_step_discounts = discount_factor ** step_indices
    return credits * per_step_discounts
def get_discounted_returns(
    rewards: torch.Tensor,  # (B, T)
    discount_factor: float,
) -> torch.Tensor:
    """
    Compute Monte Carlo discounted returns with a backward recursion.

    Uses G_t = r_t + gamma * G_{t+1}, with G_T = 0.

    Args:
        rewards (torch.Tensor): Per-timestep rewards, shape (B, T).
        discount_factor (float): Discount factor gamma.
    Returns:
        torch.Tensor: Discounted returns, shape (B, T).
    """
    assert rewards.dim() == 2, "Wrong dimensions."
    batch_size, horizon = rewards.shape
    returns = torch.zeros_like(rewards)
    running = torch.zeros(batch_size, device=rewards.device, dtype=rewards.dtype)
    for step in range(horizon - 1, -1, -1):
        running = rewards[:, step] + discount_factor * running
        returns[:, step] = running
    return returns
def get_rloo_credits(credits: torch.Tensor):  # (B, S)
    """Leave-one-out (RLOO) baseline subtraction across the batch.

    Each sample's baseline is the mean of the *other* samples' credits.
    With a single sample there is no leave-one-out set, so the credits are
    returned untouched alongside an all-zero baseline.

    Returns a tuple (rloo_credits, rloo_baselines), both shaped like
    `credits`.
    """
    assert credits.dim() == 2, "Wrong dimensions."
    batch_size = credits.shape[0]
    if batch_size == 1:
        # No other samples to build a baseline from.
        return credits, torch.zeros_like(credits)
    totals = torch.sum(credits, dim=0, keepdim=True)
    baselines = (totals - credits) / (batch_size - 1)
    return credits - baselines, baselines
def get_generalized_advantage_estimates(
    rewards: torch.Tensor,  # (B, T)
    value_estimates: torch.Tensor,  # (B, T+1)
    discount_factor: float,
    lambda_coef: float,
) -> torch.Tensor:
    """
    Compute Generalized Advantage Estimates (GAE).

    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    A_t     = delta_t + gamma * lambda * A_{t+1}

    See https://arxiv.org/pdf/1506.02438 for details.

    Returns:
        torch.Tensor: GAE values, shape (B, T).
    """
    assert rewards.dim() == value_estimates.dim() == 2, "Wrong dimensions."
    assert (
        rewards.shape[0] == value_estimates.shape[0]
    ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."
    assert (
        rewards.shape[1] == value_estimates.shape[1] - 1
    ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."
    horizon = rewards.shape[1]
    next_values = value_estimates[:, 1:]
    current_values = value_estimates[:, :-1]
    deltas = rewards + discount_factor * next_values - current_values
    estimates = torch.zeros_like(deltas)
    running = torch.zeros_like(deltas[:, 0])
    for step in range(horizon - 1, -1, -1):
        running = deltas[:, step] + lambda_coef * discount_factor * running
        estimates[:, step] = running
    return estimates
def get_advantage_alignment_weights(
    advantages: torch.Tensor,  # (B, T)
    exclude_k_equals_t: bool,
    gamma: float,
    discount_t: bool = False,
) -> torch.Tensor:
    """
    Compute the opponent-shaping weights of advantage alignment.

    The advantage alignment credit is calculated as
    \[
    A^*(s_t, a_t, b_t) = A^1(s_t, a_t, b_t) + \beta \cdot
    \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \right)
    A^2(s_t, a_t, b_t)
    \]
    Here, the weights are defined as \( \beta \cdot
    \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \)

    Args:
        advantages: Per-step advantages A^1, shape (B, T).
        exclude_k_equals_t: If True, the inner sum runs over k < t;
            otherwise over k <= t.
        gamma: Discount applied as gamma^(t-k) inside the sum.
        discount_t: If True, replace the undiscounted k = t contribution
            (A_t) with its time-discounted version gamma^t * A_t.
            Defaults to False so call sites that omit it keep working
            (previously the missing default caused a TypeError).

    Returns:
        torch.Tensor: Alignment weights, shape (B, T).
    """
    T = advantages.shape[1]
    device = advantages.device
    time_idx = torch.arange(0, T, 1, device=device)
    # Scale column k by gamma^(-k); combined with the gamma^t factor below
    # this yields gamma^(t-k) for every (k, t) pair.
    # NOTE(review): gamma^(-k) can overflow for small gamma and long
    # horizons; this presumably assumes gamma close to 1 — confirm.
    discounted_advantages = advantages * (
        gamma * torch.ones((1, T), device=device)
    ) ** (-time_idx)
    # Upper-triangular ones realize the prefix sum over k <= t; subtracting
    # the identity turns it into a strict k < t sum.
    if exclude_k_equals_t:
        sub = torch.eye(T, device=device)
    else:
        sub = torch.zeros((T, T), device=device)
    ad_align_weights = discounted_advantages @ (
        torch.triu(torch.ones((T, T), device=device)) - sub
    )
    t_discounts = (gamma * torch.ones((1, T), device=device)) ** time_idx
    ad_align_weights = t_discounts * ad_align_weights
    if discount_t:
        # Swap the undiscounted k = t contribution for gamma^t * A_t.
        # NOTE(review): when exclude_k_equals_t is True the k = t term was
        # never included, so subtracting `advantages` here changes the sum
        # rather than replacing a term — confirm this is intended.
        time_discounted_advantages = advantages * (
            gamma * torch.ones((1, T), device=device)
        ) ** time_idx
        ad_align_weights = ad_align_weights - advantages + time_discounted_advantages
    return ad_align_weights
def get_advantage_alignment_credits(
    a1: torch.Tensor,  # (B, S)
    a1_alternative: torch.Tensor,  # (B, S, A)
    a2: torch.Tensor,  # (B, S)
    exclude_k_equals_t: bool,
    beta: float,
    gamma: float = 1.0,
    use_old_ad_align: bool = False,
    use_sign: bool = False,
    clipping: float | None = None,
    use_time_regularization: bool = False,
    force_coop_first_step: bool = False,
    use_variance_regularization: bool = False,
    rloo_branch: bool = False,
    reuse_baseline: bool = False,
    mean_normalize_ad_align: bool = False,
    whiten_adalign_advantages: bool = False,
    whiten_adalign_advantages_time_step_wise: bool = False,
    discount_t: bool = False,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """
    Calculate the advantage alignment credits with vectorization, as described in https://arxiv.org/abs/2406.14662.
    Recall that the advantage opponent shaping term of the AdAlign policy gradient is:
    \[
    \beta \mathbb{E}_{\substack{
    \tau \sim \text{Pr}_{\mu}^{\pi^1, \pi^2} \\
    a_t' \sim \pi^1(\cdot \mid s_t)
    }}
    \left[\sum_{t=0}^\infty \gamma^{t}\left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t)\nabla_{\theta^1}\text{log } \pi^1(a_t|s_t) \right]
    \]
    This method computes the following:
    \[
    Credit(s_t, a_t, b_t) = \gamma^t \left[ A^1(s_t, a_t, b_t) + \beta \left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t) \right]
    \]
    Args:
        a1: Advantages of the main trajectories for the current agent, (B, S).
        a1_alternative: Advantages of the alternative trajectories for the
            current agent, (B, S, A). May be None only when use_old_ad_align.
        a2: Advantages of the main trajectories for the other agent, (B, S).
        exclude_k_equals_t: Whether the inner sum runs over k < t (True) or k <= t (False).
        beta: Scale of the opponent-shaping term.
        gamma: Discount factor for the advantage alignment weights.
        use_old_ad_align: Build the weights from a1 itself instead of the
            alternative-action advantages.
        use_sign: Replace each weight by its sign (+1 / -1; zeros stay 0).
            Requires beta == 1.
        clipping: If not in (0, None), clip the weights to [-clipping, clipping].
        use_time_regularization: Divide the weight at step t by (t + 1).
        force_coop_first_step: Overwrite the first-step weight with 1.
        use_variance_regularization: Currently unused; kept for interface
            compatibility (see TODO below).
        rloo_branch: Apply leave-one-out (RLOO) baselines to a1 / a1_alternative.
        reuse_baseline: Reuse a1's RLOO baseline for a1_alternative instead of
            computing a fresh one.
        mean_normalize_ad_align: Subtract the batch mean from the credits.
        whiten_adalign_advantages: Whiten the credits globally.
        whiten_adalign_advantages_time_step_wise: Whiten the credits per time step.
        discount_t: Forwarded to get_advantage_alignment_weights.
    Returns:
        tuple[torch.Tensor, dict[str, torch.Tensor]]: The (B, S) advantage
        alignment credits and a dict of intermediate tensors for logging.
    """
    assert a1.dim() == a2.dim() == 2, "Advantages must be of shape (B, S)"
    if a1_alternative is not None:
        assert (
            a1_alternative.dim() == 3
        ), "Alternative advantages must be of shape (B, S, A)"
        B, T, A = a1_alternative.shape
    else:
        B, T = a1.shape
    assert a1.shape == a2.shape, "Not the same shape"
    sub_tensors = {}
    if use_old_ad_align:
        ad_align_weights = get_advantage_alignment_weights(
            advantages=a1,
            exclude_k_equals_t=exclude_k_equals_t,
            gamma=gamma,
            discount_t=discount_t,
        )
        sub_tensors["ad_align_weights_prev"] = ad_align_weights
        if exclude_k_equals_t:
            # One extra discount step since the k = t term is excluded.
            ad_align_weights = gamma * ad_align_weights
    else:
        assert a1_alternative is not None, "Alternative advantages must be provided"
        if rloo_branch:
            # Fold the taken action's advantage into the alternatives, average
            # over actions, then subtract leave-one-out baselines.
            a1_alternative = torch.cat([a1.unsqueeze(2), a1_alternative], dim=2)
            a1_alternative = a1_alternative.mean(dim=2)
            a1, baseline = get_rloo_credits(a1)
            if reuse_baseline:
                a1_alternative = a1_alternative - baseline
            else:
                a1_alternative, _ = get_rloo_credits(a1_alternative)
            assert a1.shape == a1_alternative.shape, "Not the same shape"
        # BUG FIX: discount_t was previously not forwarded here; since the
        # parameter had no default, every call reaching this branch raised a
        # TypeError.
        # NOTE(review): without rloo_branch, a1_alternative is still (B, S, A)
        # here, which does not broadcast inside the weights computation unless
        # A == S — this branch presumably expects rloo_branch=True; confirm.
        ad_align_weights = get_advantage_alignment_weights(
            advantages=a1_alternative,
            exclude_k_equals_t=exclude_k_equals_t,
            gamma=gamma,
            discount_t=discount_t,
        )
        sub_tensors["ad_align_weights"] = ad_align_weights
    ###################
    # Process weights
    ###################
    # Use sign
    if use_sign:
        assert beta == 1.0, "beta should be 1.0 when using sign"
        # torch.sign maps >0 to 1, <0 to -1 and keeps 0 at 0, matching the
        # previous mask-based assignment — but without mutating the tensor
        # already stored under sub_tensors["ad_align_weights"].
        ad_align_weights = torch.sign(ad_align_weights)
        sub_tensors["ad_align_weights_sign"] = ad_align_weights
    # Use clipping
    if clipping not in [0.0, None]:
        # BUG FIX: the ratio previously divided by `upper_mask.size`, which is
        # a method on torch tensors (a numpy idiom) and raised a TypeError; the
        # masks also tested against +/-1 regardless of the actual threshold.
        upper_mask = ad_align_weights > clipping
        lower_mask = ad_align_weights < -clipping
        ad_align_weights = torch.clip(
            ad_align_weights,
            -clipping,
            clipping,
        )
        clipping_ratio = (
            torch.sum(upper_mask) + torch.sum(lower_mask)
        ) / upper_mask.numel()
        sub_tensors["ad_align_clipping_ratio"] = clipping_ratio
        sub_tensors["clipped_ad_align_weights"] = ad_align_weights
    # 1/(1+t) regularization: damp the weights of later time steps.
    if use_time_regularization:
        t_values = torch.arange(1, T + 1).to(ad_align_weights.device)
        ad_align_weights = ad_align_weights / t_values
        sub_tensors["time_regularized_ad_align_weights"] = ad_align_weights
    # Use coop on t=0
    if force_coop_first_step:
        # NOTE(review): in-place write; also mutates the tensor stored under
        # earlier sub_tensors keys that alias ad_align_weights.
        ad_align_weights[:, 0] = 1
        sub_tensors["coop_first_step_ad_align_weights"] = ad_align_weights
    # TODO(review): use_variance_regularization is accepted but not
    # implemented; an earlier draft rescaled the shaping terms by
    # std(a1) / std(opp_shaping_terms).
    ####################################
    # Compose elements together
    ####################################
    opp_shaping_terms = beta * ad_align_weights * a2
    sub_tensors["ad_align_opp_shaping_terms"] = opp_shaping_terms
    credits = a1 + opp_shaping_terms
    if mean_normalize_ad_align:
        credits = credits - credits.mean(dim=0)
        sub_tensors["mean_normalized_ad_align_credits"] = credits
    if whiten_adalign_advantages:
        credits = (credits - credits.mean()) / (credits.std() + 1e-9)
        sub_tensors["whitened_ad_align_credits"] = credits
    if whiten_adalign_advantages_time_step_wise:
        credits = (credits - credits.mean(dim=0, keepdim=True)) / (
            credits.std(dim=0, keepdim=True) + 1e-9
        )
        sub_tensors["whitened_ad_align_credits_time_step_wise"] = credits
    sub_tensors["final_ad_align_credits"] = credits
    return credits, sub_tensors
|