"""多任务损失合并:GradNorm(自适应权重)+ PCGrad(正交化梯度)。
GradNorm(Chen et al. 2018)
维护任务可学习权重 ``w_i = softplus(raw_w_i)``,按各任务相对训练速度
自适应调整。``r_i(t) = (L_i / L_i(0)) / mean_j(L_j / L_j(0))``,目标
``G_i = mean_norm * r_i^alpha``;以 ``L_grad = sum_i |‖∇w_i L_i‖ - G_i|``
回传更新 ``raw_w_i``。最后把 ``w_i`` 重归一化使 ``sum w = N``。
PCGrad(Yu et al. 2020)
分别对每个任务在 **共享参数** 上做 ``autograd.grad`` 得到 ``g_i``,对
每对 (i, j),若 ``<g_i, g_j> < 0``,把 ``g_i`` 投影到 ``g_j`` 的正交补;
每步随机打乱任务顺序避免偏置;最后把所有调整后的梯度求和写回
``param.grad``。任务专属参数(仅自身 loss 影响)不需要 PCGrad,由普通
backward 路径处理。
"""
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
class GradNormBalancer(nn.Module):
    """Adaptive task weights via GradNorm (Chen et al., 2018).

    Holds unconstrained parameters ``raw_weights`` and exposes
    ``task_weights = softplus(raw_weights) + eps`` rescaled to sum to
    ``num_tasks``. After the first :meth:`step` (which only records the
    initial losses), every step minimises::

        L_grad = sum_i | w_i * ||dL_i/dtheta|| - mean_g * r_i ** alpha |
               + anchor_coef * sum_i (w_i - 1) ** 2

    with respect to ``raw_weights`` using a private Adam optimizer. Here
    ``theta`` is a shared proxy parameter, and ``mean_g`` is the mean of the
    weighted gradient norms ``w_i * ||dL_i/dtheta||``. ``r_i`` is
    ``L_i / L_i(0)`` divided by its mean over tasks.

    Args:
        num_tasks: number of tasks N.
        alpha: GradNorm asymmetry exponent.
        gradnorm_lr: learning rate of the internal Adam optimizer.
        eps: numerical floor for the weights and the loss ratios.
        anchor_coef: strength of the weak quadratic pull towards ``w_i = 1``.
    """

    def __init__(
        self,
        num_tasks: int,
        alpha: float = 1.5,
        gradnorm_lr: float = 0.025,
        eps: float = 1e-8,
        anchor_coef: float = 1e-3,
    ) -> None:
        super().__init__()
        self.num_tasks = num_tasks
        self.alpha = alpha
        self.eps = eps
        self.anchor_coef = anchor_coef
        # raw_weights = 1 -> softplus(1) ~= 1.31 for every task, so the
        # renormalised initial weights are uniform.
        self.raw_weights = nn.Parameter(torch.ones(num_tasks))
        # Private optimizer; step() uses it to update raw_weights.
        self.optimizer = torch.optim.Adam([self.raw_weights], lr=gradnorm_lr)
        # L_i(0): filled in by the first call to step().
        self.register_buffer("initial_losses", torch.zeros(num_tasks))
        self._initialized = False

    @property
    def task_weights(self) -> torch.Tensor:
        """Positive task weights rescaled to sum to ``num_tasks``.

        The result keeps its autograd graph back to ``raw_weights``.
        """
        w = F.softplus(self.raw_weights) + self.eps
        return w * (self.num_tasks / w.sum())

    def initialize(self, losses: torch.Tensor) -> None:
        """Record ``losses`` (detached) as the reference initial losses L_i(0)."""
        with torch.no_grad():
            self.initial_losses.copy_(losses.detach())
        self._initialized = True

    def _unweighted_grad_norms(
        self, losses: torch.Tensor, shared_param: torch.Tensor
    ) -> torch.Tensor:
        """Return the detached ``[N]`` norms ``||dL_i/d shared_param||``.

        A task whose loss does not require grad, or does not depend on
        ``shared_param``, gets norm 0 instead of making autograd raise.
        """
        zero = shared_param.new_zeros(())
        norms = []
        for i in range(self.num_tasks):
            loss_i = losses[i]
            if not loss_i.requires_grad:
                norms.append(zero)
                continue
            # retain_graph=True: the caller still backprops through `losses`.
            (g,) = torch.autograd.grad(
                loss_i, shared_param, retain_graph=True, allow_unused=True
            )
            norms.append(zero if g is None else g.detach().norm(p=2))
        return torch.stack(norms)

    def step(self, losses: torch.Tensor, shared_param: torch.Tensor) -> None:
        """Update the task weights by one GradNorm step.

        The first call only records ``losses`` as L_i(0) and returns.

        Args:
            losses: ``[N]`` unweighted per-task losses. Their autograd graph is
                retained, so the caller can still backpropagate through them.
            shared_param: proxy parameter on which the gradient norms are
                measured (typically the weight of the last shared layer).
        """
        if not self._initialized:
            self.initialize(losses)
            return

        weights = self.task_weights
        # Because w_i > 0, ||grad(w_i * L_i)|| == w_i * ||grad L_i||. Measuring
        # the unweighted norms once makes G_i(w) linear in w_i, so L_grad below
        # gives raw_weights the exact GradNorm gradient without a second-order
        # graph.
        base_norms = self._unweighted_grad_norms(losses, shared_param)
        with torch.no_grad():
            gnorms = weights * base_norms
            mean_g = gnorms.mean()
            losses_ratio = losses.detach() / self.initial_losses.clamp_min(self.eps)
            rt = losses_ratio / losses_ratio.mean().clamp_min(self.eps)
            # The target is treated as a constant for this update.
            target = mean_g * rt.pow(self.alpha)

        l_grad = (weights * base_norms - target).abs().sum()
        anchor = (weights - 1.0).pow(2).sum() * self.anchor_coef
        self.optimizer.zero_grad(set_to_none=True)
        # Only depends on raw_weights (norms and target are detached), so the
        # caller's loss graph is left untouched.
        (l_grad + anchor).backward()
        self.optimizer.step()
class PCGradCombiner:
    """PCGrad (projecting conflicting gradients) for shared parameters.

    Each task gradient ``g_i`` is compared against every other task ``j``,
    visited in a freshly shuffled order when ``shuffle`` is enabled. When
    ``<g_i, g_j> < 0``, the ``g_j`` component is removed from ``g_i``.
    Projections are always taken against the *original* ``g_j``, never
    against an already-adjusted one.
    """

    def __init__(self, shuffle: bool = True) -> None:
        self.shuffle = shuffle

    def _project_one(
        self, idx: int, originals: list[torch.Tensor], task_count: int
    ) -> torch.Tensor:
        """Return a PCGrad-adjusted copy of ``originals[idx]``."""
        visit = list(range(task_count))
        if self.shuffle:
            random.shuffle(visit)
        current = originals[idx].clone()
        for other in visit:
            if other == idx:
                continue
            reference = originals[other]
            inner = torch.dot(current, reference)
            if not inner.item() < 0:
                continue
            # Remove the component of `current` along `reference`.
            scale = inner / reference.dot(reference).clamp_min(1e-12)
            current = current - scale * reference
        return current

    @torch.no_grad()
    def project(self, grads_per_task: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply PCGrad to flat 1-D per-task gradients.

        Returns new tensors (same order as the input); the inputs are not
        modified.
        """
        task_count = len(grads_per_task)
        return [
            self._project_one(idx, grads_per_task, task_count)
            for idx in range(task_count)
        ]
@dataclass(init=True, repr=True, eq=True)
class MultiTaskOptimizerConfig:
    """Switches and hyper-parameters for ``MultiTaskOptimizer``."""

    # Adapt per-task loss weights with GradNorm.
    enable_gradnorm: bool = True
    # Resolve conflicting shared-parameter gradients with PCGrad.
    enable_pcgrad: bool = False
    # GradNorm asymmetry exponent (alpha).
    gradnorm_alpha: float = 1.5
    # Learning rate of GradNorm's internal optimizer.
    gradnorm_lr: float = 0.025
    # Visit the other tasks in random order during PCGrad projection.
    pcgrad_shuffle: bool = True
class MultiTaskOptimizer:
    """Multi-task training helper combining GradNorm and PCGrad.

    Typical loop::

        mto = MultiTaskOptimizer(num_tasks, shared_params, proxy, cfg)
        for step in ...:
            optimizer.zero_grad(set_to_none=True)
            losses_main = torch.stack([...])  # [N], unweighted
            loss_aux = ...                    # scalar regulariser
            total, w = mto.backward(losses_main, loss_aux, all_trainable_params)
            optimizer.step()

    ``backward`` only fills ``.grad``; stepping the model optimizer is left to
    the caller.
    """

    def __init__(
        self,
        num_main_tasks: int,
        shared_params: list[nn.Parameter],
        gradnorm_proxy_param: nn.Parameter,
        cfg: MultiTaskOptimizerConfig,
    ) -> None:
        """
        Args:
            num_main_tasks: number N of main losses passed to ``backward``.
            shared_params: parameters shared by all tasks; PCGrad acts on these.
            gradnorm_proxy_param: parameter on which GradNorm measures the
                per-task gradient norms.
            cfg: feature switches and hyper-parameters.
        """
        self.cfg = cfg
        self.num_main = num_main_tasks
        self.shared_params = list(shared_params)
        # Membership is tracked by id(): tensor ``==`` is element-wise, so
        # parameters cannot be compared by value.
        self.shared_set = {id(p) for p in self.shared_params}
        self.proxy = gradnorm_proxy_param
        self.gradnorm = (
            GradNormBalancer(num_main_tasks, alpha=cfg.gradnorm_alpha, gradnorm_lr=cfg.gradnorm_lr)
            if cfg.enable_gradnorm
            else None
        )
        self.pcgrad = PCGradCombiner(shuffle=cfg.pcgrad_shuffle) if cfg.enable_pcgrad else None

    def task_weights(self, losses_main: torch.Tensor) -> torch.Tensor:
        """Return detached ``[N]`` task weights, stepping GradNorm first if enabled.

        Without GradNorm every weight is 1.
        """
        if self.gradnorm is None:
            return torch.ones(losses_main.shape[0], device=losses_main.device)
        # GradNorm updates its own weights with its internal optimizer.
        self.gradnorm.step(losses_main, self.proxy)
        return self.gradnorm.task_weights.detach()

    @staticmethod
    def _flat_grad(output: torch.Tensor, params: Sequence[torch.Tensor]) -> torch.Tensor:
        """Gradient of ``output`` w.r.t. ``params``, flattened and concatenated.

        Unused parameters, and outputs that do not require grad (e.g. a
        constant aux loss), contribute zeros instead of making autograd raise.
        The graph is retained for later autograd calls.
        """
        if output.requires_grad:
            grads = torch.autograd.grad(output, params, retain_graph=True, allow_unused=True)
        else:
            grads = (None,) * len(params)
        return torch.cat(
            [(torch.zeros_like(p) if g is None else g).reshape(-1) for g, p in zip(grads, params)]
        )

    @staticmethod
    def _accumulate_grad(p: torch.Tensor, g: torch.Tensor) -> None:
        """Add ``g`` (detached) to ``p.grad``, creating it if absent."""
        g = g.detach()
        p.grad = g.clone() if p.grad is None else p.grad + g

    def _combined_shared_grad(self, weighted_main: torch.Tensor, loss_aux: torch.Tensor) -> torch.Tensor:
        """Return ``sum_i PCGrad(grad w_i L_i) + grad loss_aux`` over the shared params, flattened."""
        per_task = [self._flat_grad(weighted_main[i], self.shared_params) for i in range(self.num_main)]
        adjusted = self.pcgrad.project(per_task)
        # The raw per-task gradients are no longer needed; free them early to
        # lower the peak (N x flat) memory.
        del per_task
        # Accumulate in place into the first adjusted tensor (project() returns
        # fresh tensors) instead of stacking into an [N, P] intermediate.
        combined = None
        for k, g in enumerate(adjusted):
            adjusted[k] = None  # drop the list reference so g can be freed
            combined = g if combined is None else combined.add_(g)
        del adjusted
        # The aux gradient is added after projection; it is not PCGrad-projected.
        aux_flat = self._flat_grad(loss_aux, self.shared_params)
        return aux_flat if combined is None else combined.add_(aux_flat)

    def _write_shared_grads(self, flat: torch.Tensor) -> None:
        """Split ``flat`` back into per-parameter chunks and accumulate into ``.grad``."""
        offset = 0
        for p in self.shared_params:
            numel = p.numel()
            # view_as handles any shape (including 0-dim); .to keeps each
            # grad in its parameter's dtype if torch.cat promoted it.
            chunk = flat[offset : offset + numel].view_as(p).to(dtype=p.dtype)
            self._accumulate_grad(p, chunk)
            offset += numel

    def backward(
        self,
        losses_main: torch.Tensor,  # [N], unweighted
        loss_aux: torch.Tensor,  # scalar
        all_trainable_params: Sequence[nn.Parameter],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Backpropagate the weighted multi-task loss and fill ``.grad``.

        Without PCGrad this is one ``(sum_i w_i * L_i + loss_aux).backward()``.
        With PCGrad:

        1. Shared parameters receive ``sum_i PCGrad(grad w_i L_i) + grad loss_aux``.
        2. Every other parameter in ``all_trainable_params`` receives the plain
           gradient of the total loss.

        In the PCGrad path, losses that do not require grad (e.g. a constant
        ``loss_aux``) contribute zero instead of raising, matching the plain
        path. Gradients are accumulated into any existing ``.grad``.

        Returns:
            ``(total, weights)``: the detached ``sum_i w_i * L_i + loss_aux``
            and the detached ``[N]`` task weights that were used.
        """
        weights = self.task_weights(losses_main)
        weighted_main = weights * losses_main  # [N]
        total = weighted_main.sum() + loss_aux

        if self.pcgrad is None:
            total.backward()
            return total.detach(), weights

        # 1) Shared parameters: per-task grads -> PCGrad -> + aux -> .grad.
        if self.shared_params:
            self._write_shared_grads(self._combined_shared_grad(weighted_main, loss_aux))

        # 2) Task-specific parameters: standard gradient of the total loss.
        #    This is the last autograd call, so the graph is not retained.
        non_shared = [p for p in all_trainable_params if id(p) not in self.shared_set]
        if non_shared and total.requires_grad:
            grads_ns = torch.autograd.grad(total, non_shared, retain_graph=False, allow_unused=True)
            for p, g in zip(non_shared, grads_ns):
                if g is not None:
                    self._accumulate_grad(p, g)
        return total.detach(), weights
|