| """多任务损失合并:GradNorm(自适应权重)+ PCGrad(正交化梯度)。 |
| |
| GradNorm(Chen et al. 2018) |
| 维护任务可学习权重 ``w_i = softplus(raw_w_i)``,按各任务相对训练速度 |
| 自适应调整。``r_i(t) = (L_i / L_i(0)) / mean_j(L_j / L_j(0))``,目标 |
| ``G_i = mean_norm * r_i^alpha``;以 ``L_grad = sum_i |‖∇w_i L_i‖ - G_i|`` |
| 回传更新 ``raw_w_i``。最后把 ``w_i`` 重归一化使 ``sum w = N``。 |
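
A worked example with illustrative numbers (not from the paper), alpha = 1: if two
tasks' losses have fallen to 0.5x and 0.9x of their initial values and the mean
gradient norm is 2.0, then ``r ≈ (0.714, 1.286)`` and ``G ≈ (1.43, 2.57)``; the
slower task (ratio 0.9) gets the larger target norm, so its weight is pushed up.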
| |
PCGrad (Yu et al. 2020)
For each task, take ``autograd.grad`` of its loss w.r.t. the **shared parameters**
to obtain ``g_i``. For every pair (i, j) with ``<g_i, g_j> < 0``, project ``g_i``
onto the orthogonal complement of ``g_j``; the task order is shuffled every step to
avoid ordering bias. The adjusted gradients are then summed and written back into
``param.grad``. Task-specific parameters (touched only by their own loss) need no
PCGrad and are handled by the ordinary backward path.
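
A minimal numeric sketch of one projection (illustrative values only)::

    >>> import torch
    >>> g1, g2 = torch.tensor([1.0, 0.0]), torch.tensor([-1.0, 1.0])
    >>> dot = torch.dot(g1, g2)              # -1 < 0: the gradients conflict
    >>> g1_adj = g1 - dot / g2.dot(g2) * g2  # project g1 onto g2's orthogonal complement
    >>> torch.dot(g1_adj, g2).item()         # conflict removed
    0.0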
| """ |
|
|
| from __future__ import annotations |
|
|
| import random |
| from dataclasses import dataclass |
| from typing import Sequence |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| class GradNormBalancer(nn.Module): |
| """GradNorm 自适应任务权重。 |
| |
| 维护 ``raw_weights``(softplus 参数化),对外暴露归一化权重 ``task_weights``。 |
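
    A usage sketch (assumes an external training loop; ``backbone`` and the loss
    names are illustrative)::

        balancer = GradNormBalancer(num_tasks=2)
        losses = torch.stack([loss_a, loss_b])            # unweighted, graph kept
        balancer.step(losses, backbone.out_proj.weight)   # updates raw_weights
        total = (balancer.task_weights.detach() * losses).sum()
        total.backward()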
| """ |
|
|
| def __init__( |
| self, |
| num_tasks: int, |
| alpha: float = 1.5, |
| gradnorm_lr: float = 0.025, |
| eps: float = 1e-8, |
| ) -> None: |
| super().__init__() |
| self.num_tasks = num_tasks |
| self.alpha = alpha |
| self.eps = eps |
| |
| self.raw_weights = nn.Parameter(torch.ones(num_tasks)) |
| self.optimizer = torch.optim.Adam([self.raw_weights], lr=gradnorm_lr) |
| self.register_buffer("initial_losses", torch.zeros(num_tasks)) |
| self._initialized = False |
|
|
| @property |
| def task_weights(self) -> torch.Tensor: |
| """重归一化后 sum=N 的权重(保留计算图,可用于反传到 raw_weights)。""" |
| w = F.softplus(self.raw_weights) + self.eps |
| return w * (self.num_tasks / w.sum()) |
|
|
| def initialize(self, losses: torch.Tensor) -> None: |
| with torch.no_grad(): |
| self.initial_losses.copy_(losses.detach()) |
| self._initialized = True |
|
|
    def step(self, losses: torch.Tensor, shared_param: torch.Tensor) -> None:
        """Update the task weights according to the GradNorm rule.

        Parameters
        ----------
        losses : ``[N]`` tensor of unweighted per-task losses (graph retained).
        shared_param : proxy parameter used to estimate ``‖∇_W (w_i L_i)‖``
            (typically the weight of the last shared backbone layer).
        """
        if not self._initialized:
            self.initialize(losses)
            return
        N = self.num_tasks
        weights = self.task_weights

        # Per-task gradient norms on the proxy parameter. The norms of the
        # *unweighted* losses are detached and multiplied by the weights, so
        # ``gnorms_t`` stays differentiable w.r.t. ``raw_weights`` without
        # needing ``create_graph=True``.
        raw_norms = []
        for i in range(N):
            (g,) = torch.autograd.grad(
                losses[i], shared_param, retain_graph=True, create_graph=False
            )
            raw_norms.append(g.detach().norm(p=2))
        gnorms_t = weights * torch.stack(raw_norms)

        # Target norms G_i = mean_norm * r_i^alpha, where r_i is the relative
        # inverse training rate (slower tasks get larger targets).
        with torch.no_grad():
            mean_g = gnorms_t.mean()
            losses_ratio = losses.detach() / self.initial_losses.clamp_min(self.eps)
            rt = losses_ratio / losses_ratio.mean().clamp_min(self.eps)
            target = mean_g * rt.pow(self.alpha)

        # L_grad = sum_i |G_i(t) - target_i|; backprop only reaches raw_weights
        # (the per-task norms and the targets are detached).
        L_grad = (gnorms_t - target).abs().sum()

        self.optimizer.zero_grad(set_to_none=True)
        L_grad.backward()
        self.optimizer.step()
|
|
|
|
| class PCGradCombiner: |
| """PCGrad:对共享参数的多任务梯度做正交投影。""" |
|
|
| def __init__(self, shuffle: bool = True) -> None: |
| self.shuffle = shuffle |
|
|
| @torch.no_grad() |
| def project(self, grads_per_task: list[torch.Tensor]) -> list[torch.Tensor]: |
| """对一组扁平的 task 梯度做 PCGrad 投影;返回投影后的列表。""" |
| n = len(grads_per_task) |
| adjusted = [g.clone() for g in grads_per_task] |
| order_template = list(range(n)) |
| for i in range(n): |
| order = order_template.copy() |
| if self.shuffle: |
| random.shuffle(order) |
| for j in order: |
| if j == i: |
| continue |
| gi = adjusted[i] |
| gj = grads_per_task[j] |
| dot = torch.dot(gi, gj) |
| if dot.item() < 0: |
| denom = gj.dot(gj).clamp_min(1e-12) |
| adjusted[i] = gi - (dot / denom) * gj |
| return adjusted |
|
|
|
|
| @dataclass |
| class MultiTaskOptimizerConfig: |
| enable_gradnorm: bool = True |
| enable_pcgrad: bool = False |
| gradnorm_alpha: float = 1.5 |
| gradnorm_lr: float = 0.025 |
| pcgrad_shuffle: bool = True |
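
# An illustrative configuration (values are just an example, not a recommendation):
#   cfg = MultiTaskOptimizerConfig(enable_pcgrad=True, gradnorm_alpha=1.0)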
|
|
|
|
| class MultiTaskOptimizer: |
| """整合 GradNorm + PCGrad 的多任务训练 helper。 |
| |
| 使用流程: |
| mto = MultiTaskOptimizer(num_tasks, shared_params, proxy, head_params, cfg) |
| for step in ...: |
| optimizer.zero_grad(set_to_none=True) |
| losses_main = torch.stack([...]) # [N], 未加权 |
| loss_aux = ... # 标量正则 |
| total, w = mto.backward(losses_main, loss_aux, all_trainable_params) |
| optimizer.step() |
| """ |
|
|
| def __init__( |
| self, |
| num_main_tasks: int, |
| shared_params: list[nn.Parameter], |
| gradnorm_proxy_param: nn.Parameter, |
| cfg: MultiTaskOptimizerConfig, |
| ) -> None: |
| self.cfg = cfg |
| self.num_main = num_main_tasks |
| self.shared_params = list(shared_params) |
| self.shared_set = set(id(p) for p in self.shared_params) |
| self.proxy = gradnorm_proxy_param |
| self.gradnorm = ( |
| GradNormBalancer(num_main_tasks, alpha=cfg.gradnorm_alpha, gradnorm_lr=cfg.gradnorm_lr) |
| if cfg.enable_gradnorm |
| else None |
| ) |
| self.pcgrad = PCGradCombiner(shuffle=cfg.pcgrad_shuffle) if cfg.enable_pcgrad else None |
|
|
| def task_weights(self, losses_main: torch.Tensor) -> torch.Tensor: |
| """获取(并按需更新)任务权重。返回 detach 版本用于加权 loss。""" |
| if self.gradnorm is None: |
| return torch.ones(losses_main.shape[0], device=losses_main.device) |
| |
| self.gradnorm.step(losses_main, self.proxy) |
| return self.gradnorm.task_weights.detach() |
|
|
| def backward( |
| self, |
| losses_main: torch.Tensor, |
| loss_aux: torch.Tensor, |
| all_trainable_params: Sequence[nn.Parameter], |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """完成一次反传 + 梯度合并;返回 (total_unweighted_view, weights)。""" |
| weights = self.task_weights(losses_main) |
| weighted_main = weights * losses_main |
|
|
        if self.pcgrad is None:
            # No PCGrad: a single standard backward over the weighted sum suffices.
            total = weighted_main.sum() + loss_aux
| total.backward() |
            return total.detach(), weights

        # PCGrad path: take per-task gradients on the shared parameters, flatten
        # them, project away pairwise conflicts, and write the combined result
        # back into ``param.grad``.
        per_task_flat: list[torch.Tensor] = []
| shapes = [p.shape for p in self.shared_params] |
| for i in range(self.num_main): |
| grads = torch.autograd.grad( |
| weighted_main[i], |
| self.shared_params, |
| retain_graph=True, |
| allow_unused=True, |
| ) |
| grads = [ |
| g if g is not None else torch.zeros_like(p) |
| for g, p in zip(grads, self.shared_params) |
| ] |
| per_task_flat.append(torch.cat([g.reshape(-1) for g in grads], dim=0)) |
|
|
        adjusted = self.pcgrad.project(per_task_flat)

        # Free the pre-projection copies, then sum the projected task gradients,
        # dropping each addend as we go to limit peak memory.
        del per_task_flat
        combined_main_flat = adjusted[0]
        for k in range(1, len(adjusted)):
            combined_main_flat = combined_main_flat + adjusted[k]
            adjusted[k] = None
        del adjusted

        # Gradient of the auxiliary (regularization) loss on the shared
        # parameters; it needs no projection and is simply added on.
        aux_grads = torch.autograd.grad(
| loss_aux, |
| self.shared_params, |
| retain_graph=True, |
| allow_unused=True, |
| ) |
| aux_grads = [ |
| g if g is not None else torch.zeros_like(p) |
| for g, p in zip(aux_grads, self.shared_params) |
| ] |
| aux_flat = torch.cat([g.reshape(-1) for g in aux_grads], dim=0) |
        shared_flat = combined_main_flat + aux_flat

        # Unflatten the combined shared gradient and accumulate it into param.grad.
        cursor = 0
        for p, shp in zip(self.shared_params, shapes):
            n = p.numel()
            chunk = shared_flat[cursor : cursor + n].view(shp)
            if p.grad is None:
                p.grad = chunk.detach().clone()
            else:
                p.grad = p.grad + chunk.detach()
            cursor += n

        # Task-specific (non-shared) parameters: ordinary gradients of the
        # weighted total; no projection needed.
        non_shared = [p for p in all_trainable_params if id(p) not in self.shared_set]
| if non_shared: |
| total_for_ns = weighted_main.sum() + loss_aux |
| grads_ns = torch.autograd.grad( |
| total_for_ns, |
| non_shared, |
| retain_graph=False, |
| allow_unused=True, |
| ) |
| for p, g in zip(non_shared, grads_ns): |
| if g is None: |
| continue |
| if p.grad is None: |
| p.grad = g.detach().clone() |
| else: |
| p.grad = p.grad + g.detach() |
|
|
| return (weighted_main.sum().detach() + loss_aux.detach()), weights |
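

if __name__ == "__main__":
    # Smoke test on a tiny synthetic setup (illustrative only, not part of the
    # training pipeline): two regression heads on a shared linear backbone.
    torch.manual_seed(0)
    backbone = nn.Linear(4, 8)
    heads = nn.ModuleList([nn.Linear(8, 1) for _ in range(2)])
    cfg = MultiTaskOptimizerConfig(enable_gradnorm=True, enable_pcgrad=True)
    mto = MultiTaskOptimizer(
        num_main_tasks=2,
        shared_params=list(backbone.parameters()),
        gradnorm_proxy_param=backbone.weight,
        cfg=cfg,
    )
    params = list(backbone.parameters()) + list(heads.parameters())
    opt = torch.optim.SGD(params, lr=0.1)
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        x = torch.randn(16, 4)
        feats = backbone(x)
        losses = torch.stack([(heads[i](feats) - float(i)).pow(2).mean() for i in range(2)])
        aux = 1e-4 * feats.pow(2).mean()
        total, w = mto.backward(losses, aux, params)
        opt.step()
        print(f"total={total.item():.4f} weights={[round(v, 3) for v in w.tolist()]}")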
|
|