"""Multi-task loss combination: GradNorm (adaptive weights) + PCGrad (gradient surgery).

GradNorm (Chen et al. 2018)
    Maintains learnable task weights ``w_i = softplus(raw_w_i)``, renormalized so
    that ``sum_i w_i = N``, and adapts them to each task's relative training rate.
    With ``r_i(t) = (L_i / L_i(0)) / mean_j(L_j / L_j(0))`` and the target
    ``G_i = mean_j ‖∇W (w_j L_j)‖ * r_i^alpha`` (treated as a constant), the loss
    ``L_grad = sum_i |‖∇W (w_i L_i)‖ - G_i|`` is back-propagated into ``raw_w_i``
    only. ``W`` is a proxy shared parameter (typically the last backbone weight).

PCGrad (Yu et al. 2020)
    Computes ``g_i`` for every task on the **shared parameters** with
    ``autograd.grad``. For each pair (i, j), if ``<g_i, g_j> < 0``, ``g_i`` is
    projected onto the normal plane of ``g_j``:
    ``g_i <- g_i - (<g_i, g_j> / ‖g_j‖^2) * g_j``. The task order is shuffled
    every step to avoid bias. Finally, all adjusted gradients are summed and
    written into ``param.grad``. Task-specific parameters (affected only by
    their own losses) do not need PCGrad and take the ordinary gradient path.
"""

from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


class GradNormBalancer(nn.Module):
    """GradNorm adaptive task weights.

    Holds ``raw_weights`` (softplus-parameterized) and exposes the normalized
    weights ``task_weights``, which always sum to ``num_tasks``. The balancer
    owns a private Adam optimizer over ``raw_weights``.
    """

    def __init__(
        self,
        num_tasks: int,
        alpha: float = 1.5,
        gradnorm_lr: float = 0.025,
        eps: float = 1e-8,
        device: torch.device | str | None = None,
    ) -> None:
        super().__init__()
        self.num_tasks = num_tasks
        self.alpha = alpha
        self.eps = eps
        # Tensors are created directly on ``device`` *before* the optimizer is
        # built, so the optimizer references the final parameter and
        # ``weights * losses`` never mixes CPU and accelerator tensors.
        # raw_weights = 1 -> softplus(1) ≈ 1.31; after normalization every task
        # starts with weight exactly 1.
        self.raw_weights = nn.Parameter(torch.ones(num_tasks, device=device))
        self.optimizer = torch.optim.Adam([self.raw_weights], lr=gradnorm_lr)
        self.register_buffer("initial_losses", torch.zeros(num_tasks, device=device))
        self._initialized = False

    @property
    def task_weights(self) -> torch.Tensor:
        """Weights renormalized to sum to N (graph kept, so it can backprop into raw_weights)."""
        w = F.softplus(self.raw_weights) + self.eps
        return w * (self.num_tasks / w.sum())

    def initialize(self, losses: torch.Tensor) -> None:
        """Record ``L_i(0)``, the reference losses for the training-rate ratios."""
        with torch.no_grad():
            self.initial_losses.copy_(losses.detach())
        self._initialized = True

    def step(self, losses: torch.Tensor, shared_param: torch.Tensor) -> None:
        """Update the task weights with one GradNorm step.

        The first call only records the initial losses and returns.

        Parameters
        ----------
        losses : ``[N]`` unweighted per-task losses (their graph must be alive).
            The graph is retained so the caller can still back-propagate.
        shared_param : proxy parameter used to measure ``‖∇W (w_i L_i)‖``.
        """
        if not self._initialized:
            self.initialize(losses)
            return

        # Unweighted gradient norms ‖∇W L_i‖. Because w_i > 0 (softplus + eps),
        # ‖∇W (w_i L_i)‖ = w_i · ‖∇W L_i‖, which makes the weighted norm an
        # explicit, differentiable function of raw_weights.
        base_norms = []
        for i in range(self.num_tasks):
            (g,) = torch.autograd.grad(
                losses[i], shared_param, retain_graph=True, allow_unused=True
            )
            # A task that never touches the proxy contributes a zero norm.
            base_norms.append(shared_param.new_zeros(()) if g is None else g.detach().norm(p=2))
        base = torch.stack(base_norms).to(self.raw_weights.device)

        weights = self.task_weights
        gnorms = weights * base  # differentiable w.r.t. raw_weights only

        with torch.no_grad():
            loss_ratio = losses.detach().to(self.initial_losses.device) / (
                self.initial_losses.clamp_min(self.eps)
            )
            rt = loss_ratio / loss_ratio.mean().clamp_min(self.eps)
            # The target is a constant, as in Chen et al. 2018.
            target = gnorms.mean() * rt.pow(self.alpha)

        l_grad = (gnorms - target).abs().sum()

        self.optimizer.zero_grad(set_to_none=True)
        l_grad.backward()
        self.optimizer.step()


class PCGradCombiner:
    """PCGrad: orthogonal projection of conflicting per-task gradients on shared parameters."""

    def __init__(self, shuffle: bool = True) -> None:
        self.shuffle = shuffle

    @torch.no_grad()
    def project(self, grads_per_task: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply PCGrad to a list of flat per-task gradients and return the projected list.

        Each returned tensor is a fresh tensor; the inputs are not modified.
        """
        n = len(grads_per_task)
        adjusted = [g.clone() for g in grads_per_task]
        order_template = list(range(n))
        for i in range(n):
            order = order_template.copy()
            if self.shuffle:
                random.shuffle(order)
            for j in order:
                if j == i:
                    continue
                gi = adjusted[i]
                # Project against the *original* g_j, as in the paper.
                gj = grads_per_task[j]
                dot = torch.dot(gi, gj)
                if dot.item() < 0:
                    denom = gj.dot(gj).clamp_min(1e-12)
                    adjusted[i] = gi - (dot / denom) * gj
        return adjusted


@dataclass
class MultiTaskOptimizerConfig:
    enable_gradnorm: bool = True
    enable_pcgrad: bool = False
    gradnorm_alpha: float = 1.5
    gradnorm_lr: float = 0.025
    pcgrad_shuffle: bool = True


class MultiTaskOptimizer:
    """Multi-task training helper combining GradNorm and PCGrad.

    Usage::

        mto = MultiTaskOptimizer(num_tasks, shared_params, proxy, cfg)
        for step in ...:
            optimizer.zero_grad(set_to_none=True)
            losses_main = torch.stack([...])  # [N], unweighted
            loss_aux = ...                    # scalar regularizer
            total, w = mto.backward(losses_main, loss_aux, all_trainable_params)
            optimizer.step()
    """

    def __init__(
        self,
        num_main_tasks: int,
        shared_params: list[nn.Parameter],
        gradnorm_proxy_param: nn.Parameter,
        cfg: MultiTaskOptimizerConfig,
    ) -> None:
        self.cfg = cfg
        self.num_main = num_main_tasks
        self.shared_params = list(shared_params)
        self.shared_set = set(id(p) for p in self.shared_params)
        self.proxy = gradnorm_proxy_param
        # The balancer lives on the proxy's device so its weights line up with the losses.
        self.gradnorm = (
            GradNormBalancer(
                num_main_tasks,
                alpha=cfg.gradnorm_alpha,
                gradnorm_lr=cfg.gradnorm_lr,
                device=gradnorm_proxy_param.device,
            )
            if cfg.enable_gradnorm
            else None
        )
        self.pcgrad = PCGradCombiner(shuffle=cfg.pcgrad_shuffle) if cfg.enable_pcgrad else None

    def task_weights(self, losses_main: torch.Tensor) -> torch.Tensor:
        """Return the (possibly just updated) task weights, detached, for weighting the losses."""
        if self.gradnorm is None:
            return torch.ones(losses_main.shape[0], device=losses_main.device)
        # GradNorm runs its own internal optimizer step.
        self.gradnorm.step(losses_main, self.proxy)
        return self.gradnorm.task_weights.detach().to(losses_main.device)

    @staticmethod
    def _flat_grad(loss: torch.Tensor, params: list[nn.Parameter]) -> torch.Tensor:
        """Gradient of ``loss`` w.r.t. ``params`` as one 1-D tensor; unused params give zeros."""
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        return torch.cat(
            [(g if g is not None else torch.zeros_like(p)).reshape(-1) for g, p in zip(grads, params)],
            dim=0,
        )

    @staticmethod
    def _accumulate_grad(p: nn.Parameter, g: torch.Tensor) -> None:
        """Add ``g`` into ``p.grad`` the way ``backward()`` would (create or accumulate)."""
        if p.grad is None:
            # Clone so p.grad never aliases a larger temporary buffer.
            p.grad = g.detach().clone()
        else:
            p.grad = p.grad + g.detach()

    def backward(
        self,
        losses_main: torch.Tensor,  # [N], unweighted
        loss_aux: torch.Tensor,  # scalar
        all_trainable_params: Sequence[nn.Parameter],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one backward pass with gradient combination.

        Returns
        -------
        (total, weights): ``total`` is the detached ``sum(weights * losses_main) + loss_aux``;
        ``weights`` is the detached ``[N]`` weight vector used.
        """
        weights = self.task_weights(losses_main)
        weighted_main = weights * losses_main  # [N]

        if self.pcgrad is None:
            # Plain path: a single backward of sum(weighted) + aux.
            total = weighted_main.sum() + loss_aux
            total.backward()
            return total.detach(), weights

        # === PCGrad path ===
        # 1) Shared params: per-task autograd.grad, project, write into .grad.
        # 1a) Per-task gradients on the shared parameters.
        per_task_flat = [
            self._flat_grad(weighted_main[i], self.shared_params) for i in range(self.num_main)
        ]
        adjusted = self.pcgrad.project(per_task_flat)
        # The originals are no longer needed once projection is done; free them
        # early to lower the peak (N x flat tensors).
        del per_task_flat

        # 1b) Sum in place. Every tensor returned by project() is freshly
        # allocated, so mutating adjusted[0] is safe and avoids an [N, P] stack.
        shared_flat = adjusted[0]
        for k in range(1, len(adjusted)):
            shared_flat.add_(adjusted[k])
            adjusted[k] = None  # type: ignore[assignment]
        del adjusted

        # 1c) Aux loss on shared params. It is skipped when the aux loss is a
        # constant, because autograd.grad rejects outputs without a graph.
        if loss_aux.requires_grad:
            shared_flat.add_(self._flat_grad(loss_aux, self.shared_params))

        # 1d) Scatter back into each shared parameter's .grad.
        cursor = 0
        for p in self.shared_params:
            n = p.numel()
            self._accumulate_grad(p, shared_flat[cursor : cursor + n].view_as(p))
            cursor += n

        # 2) Non-shared params: ordinary gradient of the full objective.
        # Frozen params are skipped because autograd.grad raises on them.
        non_shared = [
            p for p in all_trainable_params if p.requires_grad and id(p) not in self.shared_set
        ]
        if non_shared:
            total_for_ns = weighted_main.sum() + loss_aux
            grads_ns = torch.autograd.grad(
                total_for_ns, non_shared, retain_graph=False, allow_unused=True
            )
            for p, g in zip(non_shared, grads_ns):
                if g is not None:
                    self._accumulate_grad(p, g)

        return (weighted_main.sum().detach() + loss_aux.detach()), weights