File size: 11,747 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
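"""Multi-task loss combination: GradNorm (adaptive weights) + PCGrad (gradient projection).

GradNorm (Chen et al., 2018)
    Maintains learnable task weights ``w_i = softplus(raw_w_i)`` and adapts
    them to each task's relative training rate:
    ``r_i(t) = (L_i / L_i(0)) / mean_j(L_j / L_j(0))``, with per-task target
    ``G_i = mean_norm * r_i^alpha``. The reference objective is
    ``L_grad = sum_i | ||grad_W (w_i * L_i)|| - G_i |``; the implementation in
    this module updates ``raw_w_i`` with a sign-based surrogate of that
    objective (plus a small L2 anchor pulling the weights towards 1) rather
    than back-propagating ``L_grad`` itself. The exposed weights are
    renormalised so that ``sum w = N``.

PCGrad (Yu et al., 2020)
    For every task, ``g_i`` is obtained with ``autograd.grad`` on the
    **shared parameters**. For each pair (i, j), if ``<g_i, g_j> < 0``,
    ``g_i`` is projected onto the orthogonal complement of ``g_j``. The task
    order is shuffled to avoid a systematic bias. Finally all adjusted
    gradients are summed and written back to ``param.grad``. Task-specific
    parameters (influenced only by their own loss) do not need PCGrad and are
    handled by the regular backward path.
"""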

from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


class GradNormBalancer(nn.Module):
    """Adaptive per-task loss weights in the spirit of GradNorm.

    One learnable ``raw_weights`` entry is kept per task (softplus
    parameterisation). ``task_weights`` exposes the positive weights
    renormalised so that they sum to ``num_tasks``. The module owns its own
    Adam optimizer, which only updates ``raw_weights``.

    NOTE: ``_initialized`` is a plain Python attribute, so unlike the
    ``initial_losses`` buffer it is not part of ``state_dict``; a freshly
    constructed balancer re-captures the reference losses on its first
    ``step`` call.
    """

    def __init__(
        self,
        num_tasks: int,
        alpha: float = 1.5,
        gradnorm_lr: float = 0.025,
        eps: float = 1e-8,
    ) -> None:
        """Create the balancer.

        Parameters
        ----------
        num_tasks : number of task losses to balance.
        alpha : exponent applied to the relative training rate when building
            the per-task target gradient norm.
        gradnorm_lr : learning rate of the internal Adam optimizer.
        eps : small constant guarding divisions and keeping weights positive.
        """
        super().__init__()
        self.num_tasks = num_tasks
        self.alpha = alpha
        self.eps = eps
        # raw_weights = 1 -> softplus(1) ~= 1.31; after normalisation the
        # initial task weights are uniform.
        self.raw_weights = nn.Parameter(torch.ones(num_tasks))
        self.optimizer = torch.optim.Adam([self.raw_weights], lr=gradnorm_lr)
        self.register_buffer("initial_losses", torch.zeros(num_tasks))
        self._initialized = False

    @property
    def task_weights(self) -> torch.Tensor:
        """Return weights renormalised to sum to ``num_tasks``.

        The autograd graph is kept, so gradients flow back to ``raw_weights``.
        """
        w = F.softplus(self.raw_weights) + self.eps
        return w * (self.num_tasks / w.sum())

    def initialize(self, losses: torch.Tensor) -> None:
        """Store ``losses`` (detached) as the reference losses ``L_i(0)``."""
        with torch.no_grad():
            self.initial_losses.copy_(losses.detach())
        self._initialized = True

    def step(self, losses: torch.Tensor, shared_param: torch.Tensor) -> None:
        """Update the task weights from the current losses.

        The first call only records ``losses`` as the reference values and
        returns without touching the weights.

        Parameters
        ----------
        losses : ``[N]`` tensor of unweighted task losses with their autograd
            graph attached.
        shared_param : proxy parameter on which the per-task gradient norm is
            measured (typically the last shared layer's weight). If it is not
            part of a task's graph, that task's norm is taken as zero instead
            of raising.
        """
        if not self._initialized:
            self.initialize(losses)
            return

        weights = self.task_weights
        weighted = weights * losses

        # Gradient norm of each weighted task loss w.r.t. the proxy parameter.
        # The graph is retained because the caller still back-propagates the
        # losses afterwards.
        gnorms = []
        for i in range(self.num_tasks):
            (g,) = torch.autograd.grad(
                weighted[i],
                shared_param,
                retain_graph=True,
                create_graph=False,
                allow_unused=True,
            )
            if g is None:
                # Proxy parameter unreachable from this loss: no gradient.
                gnorms.append(shared_param.new_zeros(()))
            else:
                gnorms.append(g.detach().norm(p=2))
        gnorms_t = torch.stack(gnorms)

        with torch.no_grad():
            # Relative inverse training rate r_i and target norm G_i.
            losses_ratio = losses.detach() / self.initial_losses.clamp_min(self.eps)
            rt = losses_ratio / losses_ratio.mean().clamp_min(self.eps)
            target = gnorms_t.mean() * rt.pow(self.alpha)
            # Direction in which each task's norm misses its target.
            sign = torch.sign(gnorms_t - target)

        # Sign-based surrogate of the GradNorm objective: only `weights`
        # carries gradient, so minimising it shrinks the weight of a task
        # whose gradient norm exceeds its target and grows it otherwise.
        surrogate = (weights * sign).sum()
        # Weak L2 anchor pulling the normalised weights towards 1.
        anchor = (weights - 1.0).pow(2).sum() * 1e-3
        objective = anchor + surrogate

        self.optimizer.zero_grad(set_to_none=True)
        # This graph only spans raw_weights -> weights, so freeing it leaves
        # the task-loss graph intact.
        objective.backward()
        self.optimizer.step()


class PCGradCombiner:
    """PCGrad: remove pairwise conflicts between flattened task gradients."""

    def __init__(self, shuffle: bool = True) -> None:
        # When True, the order in which the other tasks are visited is
        # randomised for every task being projected.
        self.shuffle = shuffle

    @torch.no_grad()
    def project(self, grads_per_task: list[torch.Tensor]) -> list[torch.Tensor]:
        """Return a new list with each flat task gradient de-conflicted.

        For task ``i`` the running gradient is compared against the
        *original* gradient of every other task ``j``; whenever their dot
        product is negative, the component along ``g_j`` is subtracted.
        The input tensors are never modified.
        """
        task_count = len(grads_per_task)
        projected: list[torch.Tensor] = []
        for i, own_grad in enumerate(grads_per_task):
            visit_order = list(range(task_count))
            if self.shuffle:
                random.shuffle(visit_order)
            current = own_grad.clone()
            for j in visit_order:
                if j == i:
                    continue
                other = grads_per_task[j]
                inner = torch.dot(current, other)
                if inner.item() < 0:
                    # Clamp guards against a zero-norm reference gradient.
                    squared_norm = other.dot(other).clamp_min(1e-12)
                    current = current - (inner / squared_norm) * other
            projected.append(current)
        return projected


@dataclass
class MultiTaskOptimizerConfig:
    """Switches and hyper-parameters read by ``MultiTaskOptimizer``."""
    # Build a GradNormBalancer for adaptive task weights; when False every
    # task weight is 1.
    enable_gradnorm: bool = True
    # Build a PCGradCombiner and project shared-parameter gradients.
    enable_pcgrad: bool = False
    # Forwarded to GradNormBalancer as ``alpha`` (exponent on the relative
    # training rate).
    gradnorm_alpha: float = 1.5
    # Forwarded to GradNormBalancer as the learning rate of its own optimizer.
    gradnorm_lr: float = 0.025
    # Forwarded to PCGradCombiner as ``shuffle``.
    pcgrad_shuffle: bool = True


class MultiTaskOptimizer:
    """Training helper that combines GradNorm task weighting with PCGrad.

    Typical usage::

        mto = MultiTaskOptimizer(num_tasks, shared_params, proxy, cfg)
        for step in ...:
            optimizer.zero_grad(set_to_none=True)
            losses_main = torch.stack([...])  # [N], unweighted
            loss_aux = ...                    # scalar regulariser
            total, w = mto.backward(losses_main, loss_aux, all_trainable_params)
            optimizer.step()

    NOTE(review): the GradNorm balancer is constructed without an explicit
    device; confirm the caller moves it next to the losses when training off
    the default device.
    """

    def __init__(
        self,
        num_main_tasks: int,
        shared_params: list[nn.Parameter],
        gradnorm_proxy_param: nn.Parameter,
        cfg: MultiTaskOptimizerConfig,
    ) -> None:
        """Store the parameter groups and build the enabled components.

        Parameters
        ----------
        num_main_tasks : number of main task losses ``N``.
        shared_params : parameters whose gradients go through PCGrad.
        gradnorm_proxy_param : parameter on which GradNorm measures per-task
            gradient norms.
        cfg : feature switches and hyper-parameters.
        """
        self.cfg = cfg
        self.num_main = num_main_tasks
        self.shared_params = list(shared_params)
        # Identity-based membership: Parameters must not be compared by value.
        self.shared_set = {id(p) for p in self.shared_params}
        self.proxy = gradnorm_proxy_param
        self.gradnorm = (
            GradNormBalancer(num_main_tasks, alpha=cfg.gradnorm_alpha, gradnorm_lr=cfg.gradnorm_lr)
            if cfg.enable_gradnorm
            else None
        )
        self.pcgrad = PCGradCombiner(shuffle=cfg.pcgrad_shuffle) if cfg.enable_pcgrad else None

    def task_weights(self, losses_main: torch.Tensor) -> torch.Tensor:
        """Return detached task weights, updating GradNorm first if enabled.

        Without GradNorm every weight is 1.
        """
        if self.gradnorm is None:
            return torch.ones(losses_main.shape[0], device=losses_main.device)
        # Runs the balancer's own internal optimizer step.
        self.gradnorm.step(losses_main, self.proxy)
        return self.gradnorm.task_weights.detach()

    def _flat_shared_grad(self, scalar: torch.Tensor) -> torch.Tensor:
        """Return d(scalar)/d(shared_params) as one flat vector (zeros where unused)."""
        grads = torch.autograd.grad(
            scalar,
            self.shared_params,
            retain_graph=True,
            allow_unused=True,
        )
        return torch.cat(
            [
                (g if g is not None else torch.zeros_like(p)).reshape(-1)
                for g, p in zip(grads, self.shared_params)
            ],
            dim=0,
        )

    @staticmethod
    def _accumulate_grad(param: nn.Parameter, grad: torch.Tensor) -> None:
        """Add a detached ``grad`` into ``param.grad``, creating it if missing."""
        grad = grad.detach()
        param.grad = grad.clone() if param.grad is None else param.grad + grad

    def _backward_shared_pcgrad(
        self, weighted_main: torch.Tensor, loss_aux: torch.Tensor
    ) -> None:
        """Project per-task gradients of the shared parameters and store them."""
        per_task_flat = [
            self._flat_shared_grad(weighted_main[i]) for i in range(self.num_main)
        ]
        adjusted = self.pcgrad.project(per_task_flat)
        # The originals are no longer needed once projection has read them;
        # dropping them early lowers the peak (N x flat) memory.
        del per_task_flat
        # Running sum instead of torch.stack to avoid an extra [N, P] tensor.
        shared_flat = adjusted[0]
        for k in range(1, len(adjusted)):
            shared_flat = shared_flat + adjusted[k]
            adjusted[k] = None  # type: ignore[assignment]
        del adjusted

        # The auxiliary loss is added unprojected. A constant auxiliary term
        # (no autograd graph) contributes nothing and must not be
        # differentiated.
        if loss_aux.requires_grad:
            shared_flat = shared_flat + self._flat_shared_grad(loss_aux)

        cursor = 0
        for p in self.shared_params:
            n = p.numel()
            self._accumulate_grad(p, shared_flat[cursor : cursor + n].view(p.shape))
            cursor += n

    def backward(
        self,
        losses_main: torch.Tensor,        # [N], unweighted
        loss_aux: torch.Tensor,            # scalar
        all_trainable_params: Sequence[nn.Parameter],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Back-propagate all losses and populate ``.grad``.

        Without PCGrad this is a single ``backward()`` on
        ``sum(weights * losses_main) + loss_aux``. With PCGrad, gradients of
        the shared parameters are computed per task, projected and summed,
        while the remaining parameters in ``all_trainable_params`` receive the
        ordinary gradient of the same total.

        Returns
        -------
        (total, weights) : the detached weighted total loss and the detached
            task weights that were applied.
        """
        weights = self.task_weights(losses_main)
        weighted_main = weights * losses_main  # [N]

        if self.pcgrad is None:
            total = weighted_main.sum() + loss_aux
            total.backward()
            return total.detach(), weights

        # Shared parameters: per-task gradients, PCGrad projection, write-back.
        # Skipped when there is nothing shared (autograd.grad and torch.cat
        # both reject empty inputs).
        if self.shared_params:
            self._backward_shared_pcgrad(weighted_main, loss_aux)

        # Non-shared parameters: plain gradient of the weighted total. This is
        # the last use of the graph, so it is released here.
        non_shared = [p for p in all_trainable_params if id(p) not in self.shared_set]
        if non_shared:
            grads_ns = torch.autograd.grad(
                weighted_main.sum() + loss_aux,
                non_shared,
                retain_graph=False,
                allow_unused=True,
            )
            for p, g in zip(non_shared, grads_ns):
                if g is not None:
                    self._accumulate_grad(p, g)

        return (weighted_main.sum().detach() + loss_aux.detach()), weights