File size: 10,971 Bytes
1c8c60e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import torch


def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor:
    """
    Normalize advantages to zero mean and (approximately) unit variance.

    The small epsilon in the denominator guards against division by zero
    when the advantages have no spread.
    """
    centered = advantages - advantages.mean()
    return centered / (advantages.std() + 1e-9)


def whiten_advantages_time_step_wise(
    advantages: torch.Tensor,  # (B, T)
) -> torch.Tensor:
    """
    Normalize advantages independently at each time step: every column t is
    shifted by its batch mean and scaled by its batch standard deviation
    (plus a small epsilon to avoid division by zero).
    """
    assert advantages.dim() == 2, "Wrong dimensions."
    per_step_mean = advantages.mean(dim=0, keepdim=True)
    per_step_std = advantages.std(dim=0, keepdim=True)
    return (advantages - per_step_mean) / (per_step_std + 1e-9)


def get_discounted_state_visitation_credits(
    credits: torch.Tensor, discount_factor: float  # (B, T)
) -> torch.Tensor:
    """
    Scale the credit at each time step t by discount_factor ** t,
    broadcast across the batch dimension.
    """
    horizon = credits.shape[1]
    timesteps = torch.arange(horizon, device=credits.device)
    return credits * discount_factor**timesteps


def get_discounted_returns(
    rewards: torch.Tensor,  # (B, T)
    discount_factor: float,
) -> torch.Tensor:
    """
    Computes Monte Carlo discounted returns for a sequence of rewards using
    the backward recursion G_t = r_t + discount_factor * G_{t+1}.

    Args:
        rewards (torch.Tensor): Array of rewards for each timestep.
        discount_factor (float): Per-step discount applied to future rewards.

    Returns:
        torch.Tensor: Array of discounted returns, same shape as `rewards`.
    """
    assert rewards.dim() == 2, "Wrong dimensions."
    batch_size, horizon = rewards.shape
    returns = torch.zeros_like(rewards)
    running = torch.zeros(batch_size, device=rewards.device, dtype=rewards.dtype)
    # Walk the trajectory back-to-front so the running sum carries the future.
    for step in range(horizon - 1, -1, -1):
        running = rewards[:, step] + discount_factor * running
        returns[:, step] = running
    return returns


def get_rloo_credits(credits: torch.Tensor):  # (B, S)
    """
    Compute REINFORCE leave-one-out (RLOO) credits: each sample's baseline is
    the mean of the other B-1 samples at the same position.

    Returns:
        tuple: (credits - baselines, baselines). With a single sample the
        credits are returned unchanged alongside a zero baseline.
    """
    assert credits.dim() == 2, "Wrong dimensions."
    batch_size = credits.shape[0]
    if batch_size == 1:
        # No peers to average over — fall back to a zero baseline.
        return credits, torch.zeros_like(credits)
    leave_one_out_sums = credits.sum(dim=0, keepdim=True) - credits
    rloo_baselines = leave_one_out_sums / (batch_size - 1)
    return credits - rloo_baselines, rloo_baselines


def get_generalized_advantage_estimates(
    rewards: torch.Tensor,  # (B, T)
    value_estimates: torch.Tensor,  # (B, T+1)
    discount_factor: float,
    lambda_coef: float,
) -> torch.Tensor:
    """
    Computes Generalized Advantage Estimates (GAE) for a sequence of rewards and value estimates.
    See https://arxiv.org/pdf/1506.02438 for details.

    The TD errors are accumulated back-to-front with decay
    lambda_coef * discount_factor.

    Returns:
        torch.Tensor: Array of GAE values.
    """
    assert rewards.dim() == value_estimates.dim() == 2, "Wrong dimensions."

    assert (
        rewards.shape[0] == value_estimates.shape[0]
    ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."
    assert (
        rewards.shape[1] == value_estimates.shape[1] - 1
    ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."

    horizon = rewards.shape[1]
    # One-step TD errors: r_t + gamma * V(s_{t+1}) - V(s_t).
    td_errors = (
        rewards + discount_factor * value_estimates[:, 1:] - value_estimates[:, :-1]
    )
    gae_values = torch.zeros_like(td_errors)
    running = 0.0
    for step in range(horizon - 1, -1, -1):
        running = td_errors[:, step] + lambda_coef * discount_factor * running
        gae_values[:, step] = running
    return gae_values


def get_advantage_alignment_weights(
    advantages: torch.Tensor,  # (B, T)
    exclude_k_equals_t: bool,
    gamma: float,
    discount_t: bool = False,
) -> torch.Tensor:
    r"""
    The advantage alignment credit is calculated as

    \[
        A^*(s_t, a_t, b_t) = A^1(s_t, a_t, b_t) + \beta \cdot
        \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \right)
        A^2(s_t, a_t, b_t)
    \]

    Here, the weights are defined as \( \beta \cdot
        \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \)

    Args:
        advantages: Advantages of shape (B, T).
        exclude_k_equals_t: If True, sum over k < t; otherwise k <= t.
        gamma: Discount used inside the cumulative sum.
        discount_t: If True, additionally replaces the t=k contribution with
            a gamma^t-discounted advantage. Defaults to False so callers that
            only need the plain weights can omit it.

    Returns:
        torch.Tensor: Alignment weights of shape (B, T).
    """
    T = advantages.shape[1]
    device = advantages.device
    timesteps = torch.arange(0, T, 1, device=device)
    gamma_row = gamma * torch.ones((1, T), device=device)
    # Rescale column k by gamma^{-k}; combined with the gamma^{t} factor
    # below this yields the gamma^{t-k} weighting of the inner sum.
    inv_discounted_advantages = advantages * gamma_row ** (-timesteps)
    if exclude_k_equals_t:
        diag = torch.eye(T, device=device)
    else:
        diag = torch.zeros((T, T), device=device)
    # The upper-triangular ones matrix implements \( \sum_{k \leq t} \);
    # subtracting the identity restricts it to \( k < t \).
    ad_align_weights = inv_discounted_advantages @ (
        torch.triu(torch.ones((T, T), device=device)) - diag
    )
    ad_align_weights = gamma_row**timesteps * ad_align_weights
    if discount_t:
        time_discounted_advantages = advantages * gamma_row**timesteps
        ad_align_weights = (
            ad_align_weights - advantages + time_discounted_advantages
        )
    return ad_align_weights


def get_advantage_alignment_credits(
    a1: torch.Tensor,  # (B, S)
    a1_alternative: torch.Tensor,  # (B, S, A)
    a2: torch.Tensor,  # (B, S)
    exclude_k_equals_t: bool,
    beta: float,
    gamma: float = 1.0,
    use_old_ad_align: bool = False,
    use_sign: bool = False,
    clipping: float | None = None,
    use_time_regularization: bool = False,
    force_coop_first_step: bool = False,
    use_variance_regularization: bool = False,
    rloo_branch: bool = False,
    reuse_baseline: bool = False,
    mean_normalize_ad_align: bool = False,
    whiten_adalign_advantages: bool = False,
    whiten_adalign_advantages_time_step_wise: bool = False,
    discount_t: bool = False,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    r"""
    Calculate the advantage alignment credits with vectorization, as described in https://arxiv.org/abs/2406.14662.

    Recall that the advantage opponent shaping term of the AdAlign policy gradient is:
    \[
        \beta \mathbb{E}_{\substack{
        \tau \sim \text{Pr}_{\mu}^{\pi^1, \pi^2} \\
        a_t' \sim \pi^1(\cdot \mid s_t)
        }}
        \left[\sum_{t=0}^\infty  \gamma^{t}\left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t)\nabla_{\theta^1}\text{log } \pi^1(a_t|s_t) \right]
    \]

    This method computes the following:
    \[
        Credit(s_t, a_t, b_t) = \gamma^t \left[ A^1(s_t, a_t, b_t) + \beta \left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t) \right]
    \]

    Args:
        a1: Advantages of the main trajectories for the current agent.
        a1_alternative: Advantages of the alternative trajectories for the
            current agent; may be None only when ``use_old_ad_align`` is True.
        a2: Advantages of the main trajectories for the other agent.
        exclude_k_equals_t: Whether the inner sum runs over k < t (True)
            or k <= t (False).
        beta: Beta parameter scaling the opponent-shaping term.
        gamma: Gamma parameter for the advantage alignment.
        use_old_ad_align: Compute weights from ``a1`` instead of
            ``a1_alternative``.
        use_sign: Replace each weight with its sign (+1/-1, zeros stay 0).
            Requires ``beta == 1.0``.
        clipping: When set (and non-zero), clip weights to [-clipping, clipping].
        use_time_regularization: Divide the weight at step t by (t + 1).
        force_coop_first_step: Force the weight at t=0 to 1.
        use_variance_regularization: Currently a no-op (see TODO below).
        rloo_branch: Apply leave-one-out baselines to ``a1`` (and to the
            mean-pooled alternatives, optionally reusing the same baseline).
        reuse_baseline: With ``rloo_branch``, subtract the ``a1`` baseline
            from the alternatives instead of computing a fresh one.
        mean_normalize_ad_align: Subtract the batch mean from the credits.
        whiten_adalign_advantages: Whiten the credits globally.
        whiten_adalign_advantages_time_step_wise: Whiten per time step.
        discount_t: Forwarded to ``get_advantage_alignment_weights``.

    Returns:
        tuple[torch.Tensor, dict[str, torch.Tensor]]: The advantage alignment
        credits and a dict of intermediate tensors for logging. NOTE: sign
        replacement and first-step forcing mutate their tensor in place, so
        earlier dict entries referring to the same tensor reflect the change.
    """

    assert a1.dim() == a2.dim() == 2, "Advantages must be of shape (B, S)"
    if a1_alternative is not None:
        assert (
            a1_alternative.dim() == 3
        ), "Alternative advantages must be of shape (B, S, A)"
        _, T, _ = a1_alternative.shape
    else:
        _, T = a1.shape
    assert a1.shape == a2.shape, "Not the same shape"

    sub_tensors = {}

    if use_old_ad_align:
        ad_align_weights = get_advantage_alignment_weights(
            advantages=a1,
            exclude_k_equals_t=exclude_k_equals_t,
            gamma=gamma,
            discount_t=discount_t,
        )
        sub_tensors["ad_align_weights_prev"] = ad_align_weights
        if exclude_k_equals_t:
            ad_align_weights = gamma * ad_align_weights
    else:
        assert a1_alternative is not None, "Alternative advantages must be provided"
        if rloo_branch:
            # Pool the on-policy advantage with its alternatives before
            # baselining, so the alternative estimate uses all A+1 samples.
            a1_alternative = torch.cat([a1.unsqueeze(2), a1_alternative], dim=2)
            a1_alternative = a1_alternative.mean(dim=2)
            a1, baseline = get_rloo_credits(a1)
            if reuse_baseline:
                a1_alternative = a1_alternative - baseline
            else:
                a1_alternative, _ = get_rloo_credits(a1_alternative)
        assert a1.shape == a1_alternative.shape, "Not the same shape"
        # BUG FIX: discount_t was previously not forwarded here, which made
        # this branch raise a TypeError (the parameter had no default).
        ad_align_weights = get_advantage_alignment_weights(
            advantages=a1_alternative,
            exclude_k_equals_t=exclude_k_equals_t,
            gamma=gamma,
            discount_t=discount_t,
        )
        sub_tensors["ad_align_weights"] = ad_align_weights

    # Use sign
    if use_sign:
        assert beta == 1.0, "beta should be 1.0 when using sign"
        positive_signs = ad_align_weights > 0
        negative_signs = ad_align_weights < 0
        # In-place: collapses weights to {-1, 0, +1}.
        ad_align_weights[positive_signs] = 1
        ad_align_weights[negative_signs] = -1
        sub_tensors["ad_align_weights_sign"] = ad_align_weights
        # (rest are 0)

    ###################
    # Process weights
    ###################

    # Use clipping
    if clipping not in [0.0, None]:
        # BUG FIX: masks previously compared against +/-1 regardless of the
        # clipping threshold, and the ratio divided by the bound method
        # `.size` (a TypeError); use the threshold and `.numel()` instead.
        upper_mask = ad_align_weights > clipping
        lower_mask = ad_align_weights < -clipping

        ad_align_weights = torch.clip(
            ad_align_weights,
            -clipping,
            clipping,
        )
        clipping_ratio = (
            torch.sum(upper_mask) + torch.sum(lower_mask)
        ) / upper_mask.numel()
        sub_tensors["ad_align_clipping_ratio"] = clipping_ratio
        sub_tensors["clipped_ad_align_weights"] = ad_align_weights

    # 1/1+t Regularization
    if use_time_regularization:
        t_values = torch.arange(1, T + 1, device=ad_align_weights.device)
        ad_align_weights = ad_align_weights / t_values
        sub_tensors["time_regularized_ad_align_weights"] = ad_align_weights

    # Use coop on t=0
    if force_coop_first_step:
        ad_align_weights[:, 0] = 1
        sub_tensors["coop_first_step_ad_align_weights"] = ad_align_weights
    # TODO: implement `use_variance_regularization` (normalize the shaping
    # terms across the batch at the final step); currently a no-op.

    ####################################
    # Compose elements together
    ####################################

    opp_shaping_terms = beta * ad_align_weights * a2
    sub_tensors["ad_align_opp_shaping_terms"] = opp_shaping_terms

    credits = a1 + opp_shaping_terms
    if mean_normalize_ad_align:
        credits = credits - credits.mean(dim=0)
        sub_tensors["mean_normalized_ad_align_credits"] = credits
    if whiten_adalign_advantages:
        credits = (credits - credits.mean()) / (credits.std() + 1e-9)
        sub_tensors["whitened_ad_align_credits"] = credits
    if whiten_adalign_advantages_time_step_wise:
        credits = (credits - credits.mean(dim=0, keepdim=True)) / (
            credits.std(dim=0, keepdim=True) + 1e-9
        )
        sub_tensors["whitened_ad_align_credits_time_step_wise"] = credits
    sub_tensors["final_ad_align_credits"] = credits

    return credits, sub_tensors