# src/wjad/train/multitask.py
"""多任务损失合并:GradNorm(自适应权重)+ PCGrad(正交化梯度)。
GradNorm(Chen et al. 2018)
维护任务可学习权重 ``w_i = softplus(raw_w_i)``,按各任务相对训练速度
自适应调整。``r_i(t) = (L_i / L_i(0)) / mean_j(L_j / L_j(0))``,目标
``G_i = mean_norm * r_i^alpha``;以 ``L_grad = sum_i |‖∇w_i L_i‖ - G_i|``
回传更新 ``raw_w_i``。最后把 ``w_i`` 重归一化使 ``sum w = N``。
PCGrad(Yu et al. 2020)
分别对每个任务在 **共享参数** 上做 ``autograd.grad`` 得到 ``g_i``,对
每对 (i, j),若 ``<g_i, g_j> < 0``,把 ``g_i`` 投影到 ``g_j`` 的正交补;
每步随机打乱任务顺序避免偏置;最后把所有调整后的梯度求和写回
``param.grad``。任务专属参数(仅自身 loss 影响)不需要 PCGrad,由普通
backward 路径处理。
"""
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
class GradNormBalancer(nn.Module):
"""GradNorm 自适应任务权重。
维护 ``raw_weights``(softplus 参数化),对外暴露归一化权重 ``task_weights``。
"""
def __init__(
self,
num_tasks: int,
alpha: float = 1.5,
gradnorm_lr: float = 0.025,
eps: float = 1e-8,
) -> None:
super().__init__()
self.num_tasks = num_tasks
self.alpha = alpha
self.eps = eps
        # raw_weights = 1 → softplus(1) ≈ 1.31; after renormalization the initial weights are uniform.
self.raw_weights = nn.Parameter(torch.ones(num_tasks))
self.optimizer = torch.optim.Adam([self.raw_weights], lr=gradnorm_lr)
self.register_buffer("initial_losses", torch.zeros(num_tasks))
self._initialized = False
@property
def task_weights(self) -> torch.Tensor:
"""重归一化后 sum=N 的权重(保留计算图,可用于反传到 raw_weights)。"""
w = F.softplus(self.raw_weights) + self.eps
return w * (self.num_tasks / w.sum())
def initialize(self, losses: torch.Tensor) -> None:
with torch.no_grad():
self.initial_losses.copy_(losses.detach())
self._initialized = True
    def step(self, losses: torch.Tensor, shared_param: torch.Tensor) -> None:
        """Update the task weights following the GradNorm rule.

        Parameters
        ----------
        losses : ``[N]``, unweighted per-task losses (graph retained).
        shared_param : proxy parameter for estimating ``||grad_W(w_i L_i)||``
            (usually the weight of the last shared backbone layer).
        """
        if not self._initialized:
            self.initialize(losses)
            return
        N = self.num_tasks
        weights = self.task_weights
        # Per-task gradient norm of the *unweighted* loss on the proxy parameter.
        # These norms are detached; since grad_W(w_i * L_i) = w_i * grad_W(L_i),
        # the dependence on the weights is restored analytically below, which
        # keeps L_grad differentiable w.r.t. raw_weights.
        gnorms = []
        for i in range(N):
            (g,) = torch.autograd.grad(
                losses[i], shared_param, retain_graph=True, create_graph=False
            )
            gnorms.append(g.detach().norm(p=2))
        base_gnorms = torch.stack(gnorms)  # ||grad_W L_i||, constants
        # Differentiable weighted norms: G_i(w) = w_i * ||grad_W L_i||.
        weighted_gnorms = weights * base_gnorms
        with torch.no_grad():
            # Relative inverse training rate r_i = (L_i / L_i(0)) / mean_j(L_j / L_j(0)).
            loss_ratio = losses.detach() / self.initial_losses.clamp_min(self.eps)
            rt = loss_ratio / loss_ratio.mean().clamp_min(self.eps)
            # Target norm per task; treated as a constant w.r.t. the weights.
            target = weighted_gnorms.mean() * rt.pow(self.alpha)
        # L_grad reaches raw_weights only through weights = f(raw_weights);
        # the model parameters receive no gradient from this backward pass.
        L_grad = (weighted_gnorms - target).abs().sum()
        self.optimizer.zero_grad(set_to_none=True)
        L_grad.backward()
        self.optimizer.step()
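
# Illustrative smoke test for the balancer (a sketch with made-up toy shapes,
# not part of the WJAD training loop): two tiny regression "tasks" share one
# linear layer; the task whose loss shrinks faster should see its weight decay.
def _gradnorm_sketch() -> None:
    torch.manual_seed(0)
    shared = nn.Linear(4, 4)
    heads = nn.ModuleList([nn.Linear(4, 1) for _ in range(2)])
    balancer = GradNormBalancer(num_tasks=2)
    opt = torch.optim.SGD(
        list(shared.parameters()) + list(heads.parameters()), lr=1e-2
    )
    x = torch.randn(32, 4)
    targets = [torch.randn(32, 1), 10.0 * torch.randn(32, 1)]
    for _ in range(50):
        feats = shared(x)
        losses = torch.stack(
            [F.mse_loss(h(feats), t) for h, t in zip(heads, targets)]
        )
        balancer.step(losses, shared.weight)  # first call only records L_i(0)
        opt.zero_grad(set_to_none=True)
        (balancer.task_weights.detach() * losses).sum().backward()
        opt.step()
    print("final task weights:", balancer.task_weights.detach())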
class PCGradCombiner:
"""PCGrad:对共享参数的多任务梯度做正交投影。"""
def __init__(self, shuffle: bool = True) -> None:
self.shuffle = shuffle
@torch.no_grad()
def project(self, grads_per_task: list[torch.Tensor]) -> list[torch.Tensor]:
"""对一组扁平的 task 梯度做 PCGrad 投影;返回投影后的列表。"""
n = len(grads_per_task)
adjusted = [g.clone() for g in grads_per_task]
order_template = list(range(n))
for i in range(n):
order = order_template.copy()
if self.shuffle:
random.shuffle(order)
for j in order:
if j == i:
continue
gi = adjusted[i]
gj = grads_per_task[j]
dot = torch.dot(gi, gj)
if dot.item() < 0:
denom = gj.dot(gj).clamp_min(1e-12)
adjusted[i] = gi - (dot / denom) * gj
return adjusted
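
# A tiny worked example (illustrative only): g0 = (1, 0) and g1 = (-1, 1)
# conflict because <g0, g1> = -1 < 0, so PCGrad removes from g0 its component
# along g1: g0' = g0 - (<g0, g1> / ||g1||^2) * g1 = (0.5, 0.5), after which
# <g0', g1> = 0.
def _pcgrad_sketch() -> None:
    combiner = PCGradCombiner(shuffle=False)
    g0 = torch.tensor([1.0, 0.0])
    g1 = torch.tensor([-1.0, 1.0])
    adjusted = combiner.project([g0, g1])
    print(adjusted[0])                 # tensor([0.5000, 0.5000])
    print(torch.dot(adjusted[0], g1))  # tensor(0.)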
@dataclass
class MultiTaskOptimizerConfig:
enable_gradnorm: bool = True
enable_pcgrad: bool = False
gradnorm_alpha: float = 1.5
gradnorm_lr: float = 0.025
pcgrad_shuffle: bool = True
class MultiTaskOptimizer:
"""整合 GradNorm + PCGrad 的多任务训练 helper。
使用流程:
mto = MultiTaskOptimizer(num_tasks, shared_params, proxy, head_params, cfg)
for step in ...:
optimizer.zero_grad(set_to_none=True)
losses_main = torch.stack([...]) # [N], 未加权
loss_aux = ... # 标量正则
total, w = mto.backward(losses_main, loss_aux, all_trainable_params)
optimizer.step()
"""
def __init__(
self,
num_main_tasks: int,
shared_params: list[nn.Parameter],
gradnorm_proxy_param: nn.Parameter,
cfg: MultiTaskOptimizerConfig,
) -> None:
self.cfg = cfg
self.num_main = num_main_tasks
self.shared_params = list(shared_params)
self.shared_set = set(id(p) for p in self.shared_params)
self.proxy = gradnorm_proxy_param
self.gradnorm = (
GradNormBalancer(num_main_tasks, alpha=cfg.gradnorm_alpha, gradnorm_lr=cfg.gradnorm_lr)
if cfg.enable_gradnorm
else None
)
self.pcgrad = PCGradCombiner(shuffle=cfg.pcgrad_shuffle) if cfg.enable_pcgrad else None
def task_weights(self, losses_main: torch.Tensor) -> torch.Tensor:
"""获取(并按需更新)任务权重。返回 detach 版本用于加权 loss。"""
if self.gradnorm is None:
return torch.ones(losses_main.shape[0], device=losses_main.device)
        # GradNorm steps its own internal optimizer.
self.gradnorm.step(losses_main, self.proxy)
return self.gradnorm.task_weights.detach()
def backward(
self,
        losses_main: torch.Tensor,  # [N], unweighted
        loss_aux: torch.Tensor,  # scalar
all_trainable_params: Sequence[nn.Parameter],
) -> tuple[torch.Tensor, torch.Tensor]:
"""完成一次反传 + 梯度合并;返回 (total_unweighted_view, weights)。"""
weights = self.task_weights(losses_main)
weighted_main = weights * losses_main # [N]
if self.pcgrad is None:
            # Plain path: one backward pass over sum(weighted) + aux.
total = weighted_main.sum() + loss_aux
total.backward()
return total.detach(), weights
        # === PCGrad path ===
        # 1) Shared parameters: take autograd.grad per task, orthogonalize the
        #    gradients, and write the sum back into .grad.
        #    Task-specific (non-shared) parameters are handled below with the
        #    gradients of (sum(weighted_main) + aux).
        # 1a) per-task gradients on the shared parameters
per_task_flat: list[torch.Tensor] = []
for i in range(self.num_main):
grads = torch.autograd.grad(
weighted_main[i],
self.shared_params,
retain_graph=True,
allow_unused=True,
)
grads = [
g if g is not None else torch.zeros_like(p)
for g, p in zip(grads, self.shared_params)
]
per_task_flat.append(torch.cat([g.reshape(-1) for g in grads], dim=0))
adjusted = self.pcgrad.project(per_task_flat)
        # The raw per_task_flat tensors are no longer needed: the projection has
        # already consumed every g_j reference. Free them now to lower the peak
        # memory (N flat tensors).
        del per_task_flat
        # Accumulate in place; torch.stack would allocate an extra [N, P]
        # intermediate tensor (one more full copy in memory). adjusted[0] is a
        # clone owned here, so mutating it is safe.
        combined_main_flat = adjusted[0]
        for k in range(1, len(adjusted)):
            combined_main_flat.add_(adjusted[k])
            adjusted[k] = None  # type: ignore[assignment]
        del adjusted
        # 1b) gradient of the aux loss w.r.t. the shared parameters
aux_grads = torch.autograd.grad(
loss_aux,
self.shared_params,
retain_graph=True,
allow_unused=True,
)
aux_grads = [
g if g is not None else torch.zeros_like(p)
for g, p in zip(aux_grads, self.shared_params)
]
aux_flat = torch.cat([g.reshape(-1) for g in aux_grads], dim=0)
shared_flat = combined_main_flat + aux_flat
        # 1c) write the combined gradient back into the shared params' .grad
        cursor = 0
        for p in self.shared_params:
            n = p.numel()
            chunk = shared_flat[cursor : cursor + n].view_as(p)
            if p.grad is None:
                p.grad = chunk.detach().clone()
            else:
                p.grad = p.grad + chunk.detach()
            cursor += n
        # 2) non-shared parameters: standard path, grads of the plain weighted sum + aux
non_shared = [p for p in all_trainable_params if id(p) not in self.shared_set]
if non_shared:
total_for_ns = weighted_main.sum() + loss_aux
grads_ns = torch.autograd.grad(
total_for_ns,
non_shared,
retain_graph=False,
allow_unused=True,
)
for p, g in zip(non_shared, grads_ns):
if g is None:
continue
if p.grad is None:
p.grad = g.detach().clone()
else:
p.grad = p.grad + g.detach()
return (weighted_main.sum().detach() + loss_aux.detach()), weights
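
# End-to-end smoke test for the combined helper (a sketch under assumed toy
# shapes; the real WJAD entry points build this from their own configs). Runs
# a few optimizer steps with both GradNorm and PCGrad enabled.
if __name__ == "__main__":
    torch.manual_seed(0)
    backbone = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    heads = nn.ModuleList([nn.Linear(8, 1) for _ in range(3)])
    params = list(backbone.parameters()) + list(heads.parameters())
    opt = torch.optim.Adam(params, lr=1e-3)
    cfg = MultiTaskOptimizerConfig(enable_gradnorm=True, enable_pcgrad=True)
    mto = MultiTaskOptimizer(
        num_main_tasks=3,
        shared_params=list(backbone.parameters()),
        gradnorm_proxy_param=backbone[2].weight,
        cfg=cfg,
    )
    x = torch.randn(16, 8)
    ys = [torch.randn(16, 1) for _ in range(3)]
    for step_idx in range(3):
        opt.zero_grad(set_to_none=True)
        feats = backbone(x)
        losses = torch.stack(
            [F.mse_loss(h(feats), y) for h, y in zip(heads, ys)]
        )
        aux = 1e-4 * sum(p.pow(2).sum() for p in params)
        total, w = mto.backward(losses, aux, params)
        opt.step()
        print(f"step {step_idx}: total={total.item():.4f} weights={w.tolist()}")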