import torch
def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor:
"""
Whitens the advantages.
"""
whitened_advantages = (advantages - torch.mean(advantages)) / (
torch.std(advantages) + 1e-9
)
return whitened_advantages
def whiten_advantages_time_step_wise(
advantages: torch.Tensor, # (B, T)
) -> torch.Tensor:
"""
Whitens the advantages.
"""
assert advantages.dim() == 2, "Wrong dimensions."
whitened_advantages_time_step_wise = (
advantages - advantages.mean(dim=0, keepdim=True)
) / (advantages.std(dim=0, keepdim=True) + 1e-9)
return whitened_advantages_time_step_wise
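
# Illustrative usage sketch (an addition for exposition, not part of the
# original module): batch-wise whitening normalizes over all elements, while
# the time-step-wise variant normalizes each timestep across the batch. The
# demo shape (4, 5) is an arbitrary assumption.
def _example_whitening() -> None:
    advantages = torch.randn(4, 5)  # (B, T)
    flat = whiten_advantages(advantages)
    assert torch.allclose(flat.mean(), torch.tensor(0.0), atol=1e-5)
    per_step = whiten_advantages_time_step_wise(advantages)
    assert torch.allclose(per_step.mean(dim=0), torch.zeros(5), atol=1e-5)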
def get_discounted_state_visitation_credits(
credits: torch.Tensor, discount_factor: float # (B, T)
) -> torch.Tensor:
"""
Computes discounted state visitation credits for a sequence of credits.
"""
return credits * (
discount_factor ** torch.arange(credits.shape[1], device=credits.device)
)
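
# Illustrative sketch (not part of the original module): with
# discount_factor = 0.5, credits at t = 0, 1, 2 are scaled by 1, 0.5, 0.25.
def _example_discounted_credits() -> None:
    credits = torch.ones(1, 3)  # (B, T)
    discounted = get_discounted_state_visitation_credits(credits, 0.5)
    assert torch.allclose(discounted, torch.tensor([[1.0, 0.5, 0.25]]))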
def get_discounted_returns(
rewards: torch.Tensor, # (B, T)
discount_factor: float,
) -> torch.Tensor:
"""
Computes Monte Carlo discounted returns for a sequence of rewards.
    Args:
        rewards (torch.Tensor): Rewards for each timestep, shape (B, T).
        discount_factor (float): Discount factor in [0, 1].
    Returns:
        torch.Tensor: Discounted returns, shape (B, T).
    """
assert rewards.dim() == 2, "Wrong dimensions."
B, T = rewards.shape
discounted_returns = torch.zeros_like(rewards)
accumulator = torch.zeros(B, device=rewards.device, dtype=rewards.dtype)
for t in reversed(range(T)):
accumulator = rewards[:, t] + discount_factor * accumulator
discounted_returns[:, t] = accumulator
return discounted_returns
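
# Illustrative sketch (not part of the original module): for rewards
# [1, 1, 1] and discount_factor = 0.5 the backward recursion gives
# G_2 = 1, G_1 = 1 + 0.5 * 1 = 1.5, G_0 = 1 + 0.5 * 1.5 = 1.75.
def _example_discounted_returns() -> None:
    rewards = torch.ones(1, 3)  # (B, T)
    returns = get_discounted_returns(rewards, discount_factor=0.5)
    assert torch.allclose(returns, torch.tensor([[1.75, 1.5, 1.0]]))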
def get_rloo_credits(
    credits: torch.Tensor,  # (B, S)
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes REINFORCE leave-one-out (RLOO) credits: each sample's baseline is
    the mean of the other samples in the batch. Returns a tuple of
    (rloo_credits, rloo_baselines); with a single sample the credits are
    returned unchanged with a zero baseline.
    """
    assert credits.dim() == 2, "Wrong dimensions."
rloo_baselines = torch.zeros_like(credits)
n = credits.shape[0]
if n == 1:
return credits, rloo_baselines
rloo_baselines = (torch.sum(credits, dim=0, keepdim=True) - credits) / (n - 1)
rloo_credits = credits - rloo_baselines
return rloo_credits, rloo_baselines
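
# Illustrative sketch (not part of the original module): with two rollouts,
# each row's leave-one-out baseline is simply the other row.
def _example_rloo_credits() -> None:
    credits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # (B, S) with B = 2
    rloo, baselines = get_rloo_credits(credits)
    assert torch.allclose(baselines, torch.tensor([[3.0, 4.0], [1.0, 2.0]]))
    assert torch.allclose(rloo, torch.tensor([[-2.0, -2.0], [2.0, 2.0]]))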
def get_generalized_advantage_estimates(
rewards: torch.Tensor, # (B, T)
value_estimates: torch.Tensor, # (B, T+1)
discount_factor: float,
lambda_coef: float,
) -> torch.Tensor:
"""
Computes Generalized Advantage Estimates (GAE) for a sequence of rewards and value estimates.
See https://arxiv.org/pdf/1506.02438 for details.
Returns:
torch.Tensor: Array of GAE values.
"""
assert rewards.dim() == value_estimates.dim() == 2, "Wrong dimensions."
assert (
rewards.shape[0] == value_estimates.shape[0]
), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."
assert (
rewards.shape[1] == value_estimates.shape[1] - 1
), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."
T = rewards.shape[1]
tds = rewards + discount_factor * value_estimates[:, 1:] - value_estimates[:, :-1]
gaes = torch.zeros_like(tds)
acc = 0.0
for t in reversed(range(T)):
acc = tds[:, t] + lambda_coef * discount_factor * acc
gaes[:, t] = acc
return gaes
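
# Illustrative sketch (not part of the original module): with zero value
# estimates and lambda = 1, GAE reduces to the Monte Carlo discounted return.
def _example_gae() -> None:
    rewards = torch.ones(1, 3)  # (B, T)
    value_estimates = torch.zeros(1, 4)  # (B, T+1), V = 0 everywhere
    gaes = get_generalized_advantage_estimates(
        rewards, value_estimates, discount_factor=0.5, lambda_coef=1.0
    )
    assert torch.allclose(gaes, get_discounted_returns(rewards, 0.5))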
def get_advantage_alignment_weights(
advantages: torch.Tensor, # (B, T)
exclude_k_equals_t: bool,
gamma: float,
discount_t: bool,
) -> torch.Tensor:
"""
The advantage alignment credit is calculated as
\[
A^*(s_t, a_t, b_t) = A^1(s_t, a_t, b_t) + \beta \cdot
\left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \right)
A^2(s_t, a_t, b_t)
\]
Here, the weights are defined as \( \beta \cdot
\left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \)
"""
    T = advantages.shape[1]
    # Scale column k by gamma^{-k}; a later multiplication by gamma^t then
    # yields the gamma^{t-k} factor for every (k, t) pair.
    discounted_advantages = advantages * (
        gamma * torch.ones((1, T), device=advantages.device)
    ) ** (-torch.arange(0, T, 1, device=advantages.device))
    if exclude_k_equals_t:
        sub = torch.eye(T, device=advantages.device)
    else:
        sub = torch.zeros((T, T), device=advantages.device)
    # The upper-triangular mask sums over k <= t; subtracting the identity
    # drops the k = t term, leaving k < t.
    ad_align_weights = discounted_advantages @ (
        torch.triu(torch.ones((T, T), device=advantages.device)) - sub
    )
    t_discounts = (gamma * torch.ones((1, T), device=advantages.device)) ** (
        torch.arange(0, T, 1, device=advantages.device)
    )
    ad_align_weights = t_discounts * ad_align_weights
    if discount_t:
        # Swap the undiscounted k = t contribution for its gamma^t-discounted
        # counterpart.
        time_discounted_advantages = advantages * (
            gamma * torch.ones((1, T), device=advantages.device)
        ) ** (torch.arange(0, T, 1, device=advantages.device))
        ad_align_weights = ad_align_weights - advantages + time_discounted_advantages
    return ad_align_weights
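
# Illustrative cross-check (not part of the original module): for the plain
# configuration (exclude_k_equals_t=False, discount_t=False) the vectorized
# weights should equal the naive loop
# weights[b, t] = sum_{k <= t} gamma^(t-k) * advantages[b, k].
def _example_ad_align_weights() -> None:
    torch.manual_seed(0)
    advantages = torch.randn(2, 4)  # (B, T)
    gamma = 0.9
    fast = get_advantage_alignment_weights(
        advantages=advantages,
        exclude_k_equals_t=False,
        gamma=gamma,
        discount_t=False,
    )
    slow = torch.zeros_like(advantages)
    for b in range(advantages.shape[0]):
        for t in range(advantages.shape[1]):
            for k in range(t + 1):
                slow[b, t] += gamma ** (t - k) * advantages[b, k]
    assert torch.allclose(fast, slow, atol=1e-5)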
def get_advantage_alignment_credits(
a1: torch.Tensor, # (B, S)
    a1_alternative: torch.Tensor | None,  # (B, S, A); may be None
a2: torch.Tensor, # (B, S)
exclude_k_equals_t: bool,
beta: float,
gamma: float = 1.0,
use_old_ad_align: bool = False,
use_sign: bool = False,
clipping: float | None = None,
use_time_regularization: bool = False,
force_coop_first_step: bool = False,
use_variance_regularization: bool = False,
rloo_branch: bool = False,
reuse_baseline: bool = False,
mean_normalize_ad_align: bool = False,
whiten_adalign_advantages: bool = False,
whiten_adalign_advantages_time_step_wise: bool = False,
discount_t: bool = False,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""
Calculate the advantage alignment credits with vectorization, as described in https://arxiv.org/abs/2406.14662.
Recall that the advantage opponent shaping term of the AdAlign policy gradient is:
\[
\beta \mathbb{E}_{\substack{
\tau \sim \text{Pr}_{\mu}^{\pi^1, \pi^2} \\
a_t' \sim \pi^1(\cdot \mid s_t)
}}
\left[\sum_{t=0}^\infty \gamma^{t}\left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t)\nabla_{\theta^1}\text{log } \pi^1(a_t|s_t) \right]
\]
This method computes the following:
\[
Credit(s_t, a_t, b_t) = \gamma^t \left[ A^1(s_t, a_t, b_t) + \beta \left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t) \right]
\]
Args:
a1: Advantages of the main trajectories for the current agent.
a1_alternative: Advantages of the alternative trajectories for the current agent.
a2: Advantages of the main trajectories for the other agent.
discount_factor: Discount factor for the advantage alignment.
beta: Beta parameter for the advantage alignment.
gamma: Gamma parameter for the advantage alignment.
use_sign_in_ad_align: Whether to use sign in the advantage alignment.
Returns:
torch.Tensor: The advantage alignment credits.
"""
assert a1.dim() == a2.dim() == 2, "Advantages must be of shape (B, S)"
if a1_alternative is not None:
assert (
a1_alternative.dim() == 3
), "Alternative advantages must be of shape (B, S, A)"
B, T, A = a1_alternative.shape
else:
B, T = a1.shape
assert a1.shape == a2.shape, "Not the same shape"
sub_tensors = {}
if use_old_ad_align:
ad_align_weights = get_advantage_alignment_weights(
advantages=a1,
exclude_k_equals_t=exclude_k_equals_t,
gamma=gamma,
discount_t=discount_t,
)
sub_tensors["ad_align_weights_prev"] = ad_align_weights
if exclude_k_equals_t:
ad_align_weights = gamma * ad_align_weights
else:
assert a1_alternative is not None, "Alternative advantages must be provided"
        if rloo_branch:
            # Include the on-policy sample among the alternatives and average,
            # then recenter with leave-one-out baselines.
            a1_alternative = torch.cat([a1.unsqueeze(2), a1_alternative], dim=2)
            a1_alternative = a1_alternative.mean(dim=2)
            a1, baseline = get_rloo_credits(a1)
            if reuse_baseline:
                a1_alternative = a1_alternative - baseline
            else:
                a1_alternative, _ = get_rloo_credits(a1_alternative)
        assert a1.shape == a1_alternative.shape, "Not the same shape"
        ad_align_weights = get_advantage_alignment_weights(
            advantages=a1_alternative,
            exclude_k_equals_t=exclude_k_equals_t,
            gamma=gamma,
            discount_t=discount_t,
        )
sub_tensors["ad_align_weights"] = ad_align_weights
    # Use sign: map positive weights to +1 and negative to -1 (the rest stay
    # 0); torch.sign avoids mutating the tensor already stored in sub_tensors.
    if use_sign:
        assert beta == 1.0, "beta should be 1.0 when using sign"
        ad_align_weights = torch.sign(ad_align_weights)
        sub_tensors["ad_align_weights_sign"] = ad_align_weights
###################
# Process weights
###################
    # Use clipping
    if clipping not in [0.0, None]:
        upper_mask = ad_align_weights > clipping
        lower_mask = ad_align_weights < -clipping
        ad_align_weights = torch.clip(
            ad_align_weights,
            -clipping,
            clipping,
        )
        # Fraction of weights that hit the clipping bounds.
        clipping_ratio = (
            torch.sum(upper_mask) + torch.sum(lower_mask)
        ) / upper_mask.numel()
        sub_tensors["ad_align_clipping_ratio"] = clipping_ratio
        sub_tensors["clipped_ad_align_weights"] = ad_align_weights
# 1/1+t Regularization
if use_time_regularization:
t_values = torch.arange(1, T + 1).to(ad_align_weights.device)
ad_align_weights = ad_align_weights / t_values
sub_tensors["time_regularized_ad_align_weights"] = ad_align_weights
# Use coop on t=0
if force_coop_first_step:
ad_align_weights[:, 0] = 1
sub_tensors["coop_first_step_ad_align_weights"] = ad_align_weights
    # # Normalize alignment terms (across the same time step)
    # if use_variance_regularization:
    #     # TODO: verify
    #     reg_coef = torch.std(a1[:, -1]) / (torch.std(opp_shaping_terms[:, -1]) + 1e-9)
    #     opp_shaping_terms *= reg_coef
####################################
# Compose elements together
####################################
opp_shaping_terms = beta * ad_align_weights * a2
sub_tensors["ad_align_opp_shaping_terms"] = opp_shaping_terms
credits = a1 + opp_shaping_terms
if mean_normalize_ad_align:
credits = credits - credits.mean(dim=0)
sub_tensors["mean_normalized_ad_align_credits"] = credits
if whiten_adalign_advantages:
credits = (credits - credits.mean()) / (credits.std() + 1e-9)
sub_tensors["whitened_ad_align_credits"] = credits
if whiten_adalign_advantages_time_step_wise:
credits = (credits - credits.mean(dim=0, keepdim=True)) / (
credits.std(dim=0, keepdim=True) + 1e-9
)
sub_tensors["whitened_ad_align_credits_time_step_wise"] = credits
sub_tensors["final_ad_align_credits"] = credits
return credits, sub_tensors
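
# Illustrative end-to-end sketch (not part of the original module): exercises
# the use_old_ad_align path, which needs only the main-trajectory advantages
# (a1_alternative may be None there). Shapes and hyperparameters are arbitrary
# assumptions.
def _example_ad_align_credits() -> None:
    torch.manual_seed(0)
    B, S = 3, 5
    a1, a2 = torch.randn(B, S), torch.randn(B, S)
    credits, sub_tensors = get_advantage_alignment_credits(
        a1=a1,
        a1_alternative=None,
        a2=a2,
        exclude_k_equals_t=True,
        beta=0.1,
        gamma=0.9,
        use_old_ad_align=True,
    )
    assert credits.shape == (B, S)
    assert "final_ad_align_credits" in sub_tensors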