# UniBioTransfer / multiTask_model.py
# (Hugging Face upload header converted to comments: uploaded by scy639 via
#  huggingface_hub, commit 2b534de verified — the original bare text lines
#  would fail to parse as Python.)
from ldm.modules.attention import *
import global_
import torch
import torch.nn as nn
import torch.nn.functional as F
from my_py_lib.torch_util import custom_repr_v3
from confs import *
import cv2, numpy as np
from lmk_util.lmk_extractor import lmkAll_2_lmkMain, get_lmkMain_indices
from MoE import *
from lora_layers import *
import json
import copy
"""
Global knobs for shared experts and routing (no argparse per user preference)
"""
NUM_SHARED_FFN = 8
GATE_TOPK = 2
# Sparse MoE FFN for all FFN blocks (in addition to shared orig + task LoRA)
# default off to keep behavior unchanged; enable by setting EXTRA_MoE_enable to True
EXTRA_MoE_enable :bool = 1
EXTRA_MoE_num_ep = 8 # number of sparse MoE experts (narrow FFN)
EXTRA_MoE_inner_divisor = 64 # each expert intermediate dim = original FFN intermediate dim * this ratio
EXTRA_MoE_topK = 2 # sparse routing selects top-k experts (k fixed to 2)
EXTRA_MoE_add_noise :bool = 1 # add random noise to routing scores for exploration
EXTRA_MoE_noise_std = 0.1 # noise strength (Gaussian standard deviation)
EXTRA_MoE_en_auxLoss :bool = 0 # compute load-balancing auxiliary loss
EXTRA_MoE_aux_coef = 1e-2 # coefficient for auxiliary loss when adding to total loss
EXTRA_MoE_routing_mode = 'sparse' # 'sparse' | 'dense'
LMK_PICK_IDX = None
NUM_lmk_pick = len(LMK_PICK_IDX) if LMK_PICK_IDX is not None else len(get_lmkMain_indices(include_face_oval=True))
print(f"{NUM_lmk_pick=}")
IMAGE_SIZE_FOR_LMK_NORM = 512.0
def _log2(orig_modules, lora_modules):
    """Calculate and log parameter statistics for original and LoRA modules.

    Args:
        orig_modules: a single nn.Module (the original per-task module).
        lora_modules: a single adapter module, or a list/tuple of them
            (e.g. [lora_in, lora_out] for an FFN).
    Prints "orig: ..." and "LoRA: ..." lines with parameter counts, sizes in
    MB, and the adapter rank(s) when the adapter exposes a `.rank` attribute.
    """
    # Calculate original module stats
    orig_params = sum(p.numel() for p in orig_modules.parameters())
    orig_size = sum(p.numel() * p.element_size() for p in orig_modules.parameters())
    # Calculate LoRA stats (handle both single module and tuple/list)
    if isinstance(lora_modules, (list, tuple)):
        lora_params = sum(p.numel() for m in lora_modules for p in m.parameters())
        lora_size = sum(p.numel() * p.element_size() for m in lora_modules for p in m.parameters())
        # Try to get rank from lora modules
        ranks = []
        for m in lora_modules:
            if hasattr(m, 'rank'):
                ranks.append(m.rank)
        if len(ranks) == 2:
            rank_str = f" (rank_in={ranks[0]} rank_out={ranks[1]})"
        elif len(ranks) == 1:
            rank_str = f" (rank={ranks[0]})"
        else:
            rank_str = ""
    else:
        lora_params = sum(p.numel() for p in lora_modules.parameters())
        lora_size = sum(p.numel() * p.element_size() for p in lora_modules.parameters())
        # Try to get rank from lora module
        if hasattr(lora_modules, 'rank'):
            rank_str = f" (rank={lora_modules.rank})"
        else:
            rank_str = ""
    msg1 = f"orig: {orig_params:,} params, {orig_size/1024/1024:.2f}MB"
    msg2 = f"LoRA: {lora_params:,} params, {lora_size/1024/1024:.2f}MB{rank_str}"
    for msg in [msg1, msg2]:
        print(msg)
        continue
        # NOTE: the `continue` above deliberately disables the file append
        # below (same pattern as _log1's early return). `_verify_log_file` is
        # not defined in this file -- presumably imported from `confs` via
        # the star import -- TODO confirm before re-enabling.
        with open(_verify_log_file, 'a') as f:
            f.write(msg + '\n')
def _log1(msg: str):
    """Print message and append to log file (file append currently disabled)."""
    print(msg)
    return
    # Unreachable by design: the early `return` disables the file append
    # below (mirrors the disabled write in _log2). `_verify_log_file` is not
    # defined in this file -- presumably comes from `confs` -- TODO confirm.
    with open(_verify_log_file, 'a') as f:
        f.write(msg + '\n')
def build_ffn_gate_input_common(x: torch.Tensor, token_pos_grid__cur, tasks: list):
    """Build gate input for FFN routing (reusable across FFN classes).

    Args:
        x: token features, shape [b, n, d].
        token_pos_grid__cur: per-token positions; first two dims must be
            (b, n) (asserted). The trailing dim is presumably 2 (x, y) --
            the gate sizing in FFN_Shared_Plus_TaskLoRA assumes 2; confirm.
        tasks: list of task names; the current task index is read from the
            mutable global `global_.task` and one-hot encoded.

    Returns:
        (gate_in, ctx): gate_in = concat([token feat, batch-avg feat,
        task one-hot, token position, token-to-landmark relative positions])
        along the last dim; ctx exposes the intermediates for callers.
    """
    b, n, d = x.shape
    token_feat = x # token
    avg_feat = x.mean(dim=1, keepdim=True).expand(-1, n, -1) # avg(all tokens)
    len_task = len(tasks) # task one-hot
    task_1h = x.new_zeros(b, len_task)
    task_1h[:, global_.task] = 1
    task_1h = task_1h.unsqueeze(1).expand(-1, n, -1)
    token_pos = token_pos_grid__cur # token-position from global_.token_pos_grid__cur
    assert token_pos.shape[:2] == (b, n), (token_pos.shape, (b, n), )
    rel_flat = x.new_zeros(b, n, 2*NUM_lmk_pick) # token-relative-position to lmks
    # landmarks read from module-level global state; presumably [b, L, 2] --
    # verify against the producer that sets global_.lmk_.
    lmk = global_.lmk_
    if 1:
        lmk = lmk.to(x.device).float()# TODO to check is it normed already?
        if LMK_PICK_IDX is None:
            assert NUM_lmk_pick==lmk.shape[1]
        else:
            lmk = lmk[:, LMK_PICK_IDX, :]
        rel = token_pos.unsqueeze(2) - lmk.unsqueeze(1) # [b,n,L,2]
        rel_flat = rel.reshape(b, n, -1)
    gate_in = torch.cat([token_feat, avg_feat, task_1h, token_pos, rel_flat], dim=-1)
    ctx = {'token_feat': token_feat, 'avg_feat': avg_feat, 'task_1h': task_1h, 'token_pos': token_pos, 'lmk': lmk, 'rel': rel, 'rel_flat': rel_flat}
    return gate_in, ctx
def replace_modules_lossless(
    module: nn.Module,
    src_modules: list,
    l_task: list,
    parent_name: str = "",
    depth :int = 0,
    for_refnet: bool = False,
):
    """
    Apply policy:
    - FFN: shared-plus-task (lossless upcycle)
    - CrossAttention linear projections (to_q, to_k, to_v, to_out.0): shared-plus-task
    - Conv2d: keep task-specific or wrap with shared-plus-task if desired
    - Norms: keep task-specific (LayerNorm/GroupNorm)

    Recursively walks `module` and replaces recognized leaf children in place
    via setattr. `src_modules` holds one structurally-identical source module
    per task (aligned with `l_task`). `depth` distinguishes the top-level call
    so Conv2d stats are cleared/reported exactly once. Returns the mutated
    `module`.
    """
    if depth==0:
        CONV2D_PARAM_STATS.clear()
    # Skip modules with no parameters
    if len(list(module.parameters())) == 0:
        # print(f'[replace_modules_lossless] Skipping module with no parameters: {module}')
        return module
    if len(list(module.named_children()))==0:
        print('\n!!!! len(list(module.named_children()))==0',module)
        assert 0
    for name, child in module.named_children():
        full_name = f"{parent_name}.{name}" if parent_name else f".{name}"
        src_child_modules = [getattr(src_module, name) for src_module in src_modules]
        if len({id(s) for s in src_child_modules}) < len(src_child_modules):
            raise Exception('Duplicate source modules detected!')
        # if sources are the same instance(s), clone to ensure distinct expert modules
        # NOTE(review): this clones the FIRST source child for every task
        # unconditionally, discarding the distinct per-task children gathered
        # above (the duplicate check just guaranteed they were distinct). If
        # per-task weights are meant to seed the experts this looks like a
        # bug; if all experts should start from the first task's weights it is
        # intentional. Confirm before changing.
        src_child_modules = [copy.deepcopy(src_child_modules[0]) for _ in src_child_modules]
        if isinstance(child, FeedForward):
            if 0:
                setattr(module, name, TaskSpecific_MoE([s for s in src_child_modules], tasks=l_task))
            else:
                # FFN -> shared average + per-task LoRA
                setattr(module, name, upCycle_module(src_child_modules, l_task, module_name=full_name))
            continue
        if isinstance(child, CrossAttention):
            # replace linear projections
            # if for_refnet:
            if 0:
                for proj_name in ["to_q", "to_k", "to_v"]:
                    src_proj_list = [getattr(s, proj_name) for s in src_child_modules]
                    setattr(child, proj_name, upCycle_module(src_proj_list, l_task, module_name=f"{full_name}.{proj_name}"))
                if hasattr(child.to_out, "__getitem__"):
                    src_linear0 = [s.to_out[0] for s in src_child_modules]
                    child.to_out[0] = upCycle_module(src_linear0, l_task, module_name=f"{full_name}.to_out.0")
            else:
                # current policy: attention projections stay fully task-specific
                for proj_name in ["to_q", "to_k", "to_v"]:
                    src_proj_list = [getattr(s, proj_name) for s in src_child_modules]
                    setattr(child, proj_name, TaskSpecific_MoE([s for s in src_proj_list], tasks=l_task) )
                if hasattr(child.to_out, "__getitem__"):
                    src_linear0 = [s.to_out[0] for s in src_child_modules]
                    child.to_out[0] = TaskSpecific_MoE([s for s in src_linear0], tasks=l_task)
            continue
        if isinstance(child, nn.Conv2d):
            num_params = sum(p.numel() for p in child.parameters())
            CONV2D_PARAM_STATS.append((num_params, full_name))
            # if num_params > CONV2D_PARAM_MOE_THRES and (not any(full_name.startswith(p) for p in FORCE_TASKSPEC_PREFIXES)):
            if 1:
                printC(f"shared+LoRA Conv2d",f"{full_name}")
                setattr(module, name, upCycle_module(src_child_modules, l_task, module_name=full_name))
            else:
                setattr(module, name, TaskSpecific_MoE([s for s in src_child_modules], tasks=l_task))
            continue
        elif isinstance(child, (nn.LayerNorm, nn.GroupNorm)):
            setattr(module, name, TaskSpecific_MoE([s for s in src_child_modules], tasks=l_task))
            continue
        elif isinstance(child, nn.Linear):
            # default linear: task-specific
            setattr(module, name, TaskSpecific_MoE([s for s in src_child_modules], tasks=l_task))
            continue
        else:
            # unrecognized container: recurse into it
            replace_modules_lossless(child, src_child_modules, l_task, parent_name=full_name, depth=depth+1, for_refnet=for_refnet)
    if depth==0:
        stats_sorted = sorted(CONV2D_PARAM_STATS, key=lambda x: x[0], reverse=True)
        if gate_("[Conv2d param stats] count, name (sorted desc):"):
            for cnt, n in stats_sorted:
                print(f" {cnt:12d} {n}")
    return module
def upCycle_module(l_modules, l_task, module_name: str = None):
    """Build a shared+per-task-LoRA wrapper for one module type per task.

    NOTE(review): the unconditional `return TaskSpecific_MoE(...)` below makes
    the constructed `obj` unused and the trailing `dont_lora` branch
    unreachable -- apparently a deliberate override. Constructing `obj` is
    NOT dead work, though: the *_Shared_Plus_TaskLoRA constructors populate
    global_.moduleName_2_adaRank (and emit verification logs), so the call
    must still run. Confirm intent before "cleaning up".
    """
    # all sources must be the same module type
    assert len( set( [type(m) for m in l_modules] ) ) == 1
    m0 = l_modules[0]
    if isinstance(m0, FeedForward):
        obj = FFN_Shared_Plus_TaskLoRA(l_modules, l_task, module_name=module_name)
    elif isinstance(m0, nn.Linear):
        obj = Linear_Shared_Plus_TaskLoRA(l_modules, l_task, module_name=module_name)
    elif isinstance(m0, nn.Conv2d):
        obj = Conv_Shared_Plus_TaskLoRA(l_modules, l_task, module_name=module_name)
    else:
        raise Exception(module_name,m0)
    return TaskSpecific_MoE([s for s in l_modules], tasks=l_task)
    if obj.dont_lora:
        return TaskSpecific_MoE([s for s in l_modules], tasks=l_task)
    return obj
class ResidualAdapterLinearOnly(nn.Module):
    """Full-rank residual adapter for a Linear layer.

    Holds an explicit (out_features x in_features) weight delta and an
    optional bias delta, both zero-initialized. forward() returns only the
    delta term, scaled -- the shared branch is applied by the caller.
    """

    def __init__(self, in_features: int, out_features: int, scaling: float = 1.0, use_bias_delta: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Full rank of a dense (out x in) matrix.
        self.rank = min(in_features, out_features)
        self.scaling = scaling
        self.use_bias_delta = use_bias_delta
        self.delta_weight = nn.Parameter(torch.zeros(out_features, in_features))
        if not use_bias_delta:
            # Register a None parameter so `delta_bias` still exists as a slot.
            self.register_parameter('delta_bias', None)
        else:
            self.delta_bias = nn.Parameter(torch.zeros(out_features))

    @torch.no_grad()
    def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None):
        """Load exact (orig - shared) weight/bias differences into the adapter."""
        self.delta_weight.copy_(weight_diff)
        if bias_diff is not None and self.delta_bias is not None:
            self.delta_bias.copy_(bias_diff)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return scaling * (x @ dW^T [+ db])."""
        out = torch.matmul(x, self.delta_weight.T)
        if self.delta_bias is not None:
            out = out + self.delta_bias
        return out * self.scaling
class ResidualAdapterConv2dOnly(nn.Module):
    """Full-rank residual adapter for a Conv2d layer.

    Stores an explicit conv-weight delta (and optional bias delta), both
    zero-initialized; forward() returns only the scaled delta convolution.
    Int hyper-parameters are normalized to (h, w) tuples.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple, stride: tuple, padding: tuple, dilation: tuple, groups: int = 1, scaling: float = 1.0, use_bias_delta: bool = True):
        super().__init__()

        def _pair(v):
            # Normalize an int to the (h, w) tuple form; tuples pass through.
            return (v, v) if isinstance(v, int) else v

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        kH, kW = kernel_size
        # Full rank of the unfolded (out x in*kH*kW) weight matrix.
        self.rank = min(out_channels, in_channels * kH * kW)
        self.scaling = scaling
        self.use_bias_delta = use_bias_delta
        self.delta_weight = nn.Parameter(torch.zeros(out_channels, in_channels // groups, kH, kW))
        if not use_bias_delta:
            self.register_parameter('delta_bias', None)
        else:
            self.delta_bias = nn.Parameter(torch.zeros(out_channels))

    @torch.no_grad()
    def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None):
        """Load exact (orig - shared) conv weight/bias differences."""
        self.delta_weight.copy_(weight_diff)
        if bias_diff is not None and self.delta_bias is not None:
            self.delta_bias.copy_(bias_diff)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the delta convolution, add the bias delta, scale the result."""
        out = F.conv2d(x, self.delta_weight, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
        if self.delta_bias is not None:
            out = out + self.delta_bias.view(1, -1, 1, 1)
        return out * self.scaling
class Linear_TaskSpecific_Plus_Shared(nn.Module):
    """Per-task Linear projections plus a shared, zero-initialized Linear.

    Output = task_proj[current task](x) + shared(x). The shared branch is
    zeroed at construction, so initial behavior equals the original per-task
    projection; the current task is read from global_.task at call time.
    """

    def __init__(self, l_proj: list, l_task: list):
        super().__init__()
        assert len(l_proj) >= 1
        first = l_proj[0]
        assert isinstance(first, nn.Linear)
        # Shared branch mirrors the first projection's geometry, zero-init.
        shared = nn.Linear(first.in_features, first.out_features, bias=first.bias is not None)
        self.shared = zero_module(shared)
        self.tasks = l_task
        self.task_proj = ModuleDict_W(l_proj, self.tasks)

    def forward(self, x):
        # Route by the globally-selected task, then add the shared branch.
        task = global_.task
        return self.task_proj[task](x) + self.shared(x)
class Conv_TaskSpecific_Plus_Shared(nn.Module):
    """Per-task Conv2d branches plus a shared, zero-initialized Conv2d.

    Output = task_conv[current task](x) + shared(x); the current task is read
    from global_.task at call time and the shared branch starts at zero.
    """

    def __init__(self, l_conv: list, l_task: list):
        super().__init__()
        assert len(l_conv) >= 1
        first = l_conv[0]
        assert isinstance(first, nn.Conv2d)
        # Shared branch clones the first conv's full geometry, zero-init.
        shared = nn.Conv2d(
            first.in_channels,
            first.out_channels,
            kernel_size=first.kernel_size,
            stride=first.stride,
            padding=first.padding,
            dilation=first.dilation,
            groups=first.groups,
            bias=first.bias is not None,
            padding_mode=first.padding_mode,
        )
        self.shared = zero_module(shared)
        self.tasks = l_task
        self.task_conv = ModuleDict_W(l_conv, self.tasks)

    def forward(self, x):
        task = global_.task
        return self.task_conv[task](x) + self.shared(x)
def _average_state_dict(modules: list):
assert len(modules) > 0
sd0 = modules[0].state_dict()
avg = {k: torch.zeros_like(v) for k, v in sd0.items()}
for m in modules:
msd = m.state_dict()
for k in avg:
avg[k] += msd[k]
for k in avg:
avg[k] /= len(modules)
return avg
class FFN_Shared_Plus_TaskLoRA(nn.Module):
    """FeedForward upcycled into: frozen shared FFN (task-average weights)
    + per-task LoRA / full-rank residual adapters + optional sparse MoE branch.

    Built from one FeedForward per task (`l_ffn`, aligned with `l_task`).
    When FOR_upcycle_ckpt_GEN_or_USE, the shared FFN loads the element-wise
    average of the task FFNs and each task's adapters are initialized from
    the exact weight diffs, making the decomposition lossless at init.
    """
    def __init__(self, l_ffn: list, l_task: list, module_name: str = None):
        """l_ffn: one FeedForward per task; l_task: matching task names;
        module_name: key into global_.moduleName_2_adaRank for rank caching."""
        super().__init__()
        self.module_name = module_name
        # _log1(f"-------- {module_name} --------")
        assert len(l_ffn) >= 1
        self.tasks = l_task
        self.num_tasks = len(l_task)
        # set True when LoRA is skipped (tiny dims); the caller then discards
        # this object in favor of a plain task-specific wrapper.
        self.dont_lora = False
        f0: FeedForward = l_ffn[0]
        # build shared from f0 and load avg
        self.shared_ffn: FeedForward = copy.deepcopy(f0)
        if FOR_upcycle_ckpt_GEN_or_USE:
            avg_sd = _average_state_dict(l_ffn)
            self.shared_ffn.load_state_dict(avg_sd)
        # freeze shared
        for p in self.shared_ffn.parameters():
            p.requires_grad = False
        # discover inner layers
        self.is_glu = isinstance(self.shared_ffn.net[0], GEGLU)
        if self.is_glu:
            in_linear: nn.Linear = self.shared_ffn.net[0].proj
        else:
            assert isinstance(self.shared_ffn.net[0], nn.Sequential)
            in_linear: nn.Linear = self.shared_ffn.net[0][0]
        out_linear: nn.Linear = self.shared_ffn.net[2]
        self.in_features = in_linear.in_features
        self.mid_features = in_linear.out_features
        self.out_features = out_linear.out_features
        if 1: # cal/read adaptive rank across tasks
            if FOR_upcycle_ckpt_GEN_or_USE:
                # GEN phase: measure weight diffs vs the shared average and
                # cache the adaptive ranks in global_.moduleName_2_adaRank.
                w_diff_in_list = []
                w_diff_out_list = []
                for f in l_ffn:
                    if self.is_glu:
                        tin: nn.Linear = f.net[0].proj
                    else:
                        tin: nn.Linear = f.net[0][0]
                    tout: nn.Linear = f.net[2]
                    w_diff_in_list.append(tin.weight.data - in_linear.weight.data)
                    w_diff_out_list.append(tout.weight.data - out_linear.weight.data)
                if FORCE_SAME_RANK_ACROSS_TASKS:
                    rank_in = compute_adaptive_rank_for_linear_diffs(w_diff_in_list)
                    rank_out = compute_adaptive_rank_for_linear_diffs(w_diff_out_list)
                    global_.moduleName_2_adaRank[module_name] = [rank_in, rank_out]
                else:
                    ranks_in = compute_adaptive_rank_for_linear_diffs(w_diff_in_list, per_task=True)
                    ranks_out = compute_adaptive_rank_for_linear_diffs(w_diff_out_list, per_task=True)
                    global_.moduleName_2_adaRank[module_name] = [ranks_in, ranks_out]
            else:
                # USE phase: read ranks computed during checkpoint generation.
                r_info = global_.moduleName_2_adaRank[module_name]
                if FORCE_SAME_RANK_ACROSS_TASKS: rank_in, rank_out = r_info
                else: ranks_in, ranks_out = r_info
        if 1:
            # fallback decision: (1) tiny feature dims
            min_dim_in = min(self.in_features, self.mid_features)
            min_dim_out = min(self.mid_features, self.out_features)
            if (min_dim_in < DONT_lora_if_dim_lt) or (min_dim_out < DONT_lora_if_dim_lt):
                # print(f"[LoRA fallback][FFN] {module_name} {min_dim_in=} {min_dim_out=} {DONT_lora_if_dim_lt=}")
                self.dont_lora = True; return
        # per-task LoRA adapters
        _l_in = []
        _l_out = []
        for idx, f in enumerate(l_ffn):
            if self.is_glu:
                tin: nn.Linear = f.net[0].proj
            else:
                tin: nn.Linear = f.net[0][0]
            tout: nn.Linear = f.net[2]
            if not FORCE_SAME_RANK_ACROSS_TASKS:
                rank_in = ranks_in[idx]
                rank_out = ranks_out[idx]
            frac_in = float(rank_in) / min(self.in_features, self.mid_features)
            frac_out = float(rank_out) / min(self.mid_features, self.out_features)
            frac_avg = 0.5 * (frac_in + frac_out)
            if frac_avg > DONT_lora_if_rankFrac_gt:
                # adaptive rank is close to full rank -> LoRA saves nothing;
                # use full-rank residual adapters instead.
                lora_in = ResidualAdapterLinearOnly(self.in_features, self.mid_features, scaling=1.0, use_bias_delta=True)
                lora_out = ResidualAdapterLinearOnly(tout.in_features, tout.out_features, scaling=1.0, use_bias_delta=True)
            else:
                lora_in = LoRAAdapterLinearOnly(self.in_features, self.mid_features, rank=rank_in, dropout=0.0, scaling=1.0)
                lora_out = LoRAAdapterLinearOnly(tout.in_features, tout.out_features, rank=rank_out, dropout=0.0, scaling=1.0)
            # init from diffs
            if FOR_upcycle_ckpt_GEN_or_USE:
                with torch.no_grad():
                    w_diff_in = tin.weight.data - in_linear.weight.data
                    b_diff_in = (tin.bias.data - in_linear.bias.data) if tin.bias is not None else None
                    lora_in.init_from_diff(w_diff_in, b_diff_in)
                    w_diff_out = tout.weight.data - out_linear.weight.data
                    b_diff_out = (tout.bias.data - out_linear.bias.data) if tout.bias is not None else None
                    lora_out.init_from_diff(w_diff_out, b_diff_out)
            _l_in.append(lora_in)
            _l_out.append(lora_out)
        self.task_lora_in = ModuleDict_W(_l_in, self.tasks)
        self.task_lora_out = ModuleDict_W(_l_out, self.tasks)
        # reuse dropout and activation behavior
        self.dropout_p = self.shared_ffn.net[1].p if isinstance(self.shared_ffn.net[1], nn.Dropout) else 0.0
        self.dropout = nn.Dropout(self.dropout_p) if self.dropout_p > 0 else nn.Identity()
        # Sparse/Dense MoE experts (small inner dim) + gate
        if EXTRA_MoE_enable:
            small_inner = self.mid_features // EXTRA_MoE_inner_divisor
            self.num_moe_expert = EXTRA_MoE_num_ep
            # gate input layout must match build_ffn_gate_input_common:
            # token feat + avg feat + task one-hot + token pos (2) + rel-to-lmk (2*L)
            gate_in_dim = self.in_features + self.in_features + len(self.tasks) + 2 + 2*NUM_lmk_pick
            hidden = gate_in_dim // 8
            self.moe_gate_mlp = nn.Sequential(
                nn.Linear(gate_in_dim, hidden),
                nn.SiLU(),
                nn.Linear(hidden, self.num_moe_expert),
            )
            if EXTRA_MoE_routing_mode == 'dense':
                self.moe_experts_batched = BatchedFeedForward(
                    dim=self.in_features, dim_out=self.out_features,
                    glu=self.is_glu, dropout=self.dropout_p,
                    inner_dim=small_inner, num_expert=self.num_moe_expert,
                )
            else:
                # sparse mode: independent narrow FeedForward experts; `mult`
                # rescales FeedForward's inner dim down to small_inner.
                mult = small_inner / self.in_features
                experts = []
                for _ in range(self.num_moe_expert):
                    expert = FeedForward(self.in_features, dim_out=self.out_features, mult=mult, glu=self.is_glu, dropout=self.dropout_p)
                    experts.append(expert)
                self.moe_experts_list = nn.ModuleList(experts)
            if 0: # log MoE gate and expert architecture to file only (no terminal output)
                log_dir = Path("4debug/moe_ffn_struc"); log_dir.mkdir(exist_ok=True)
                mod_name = self.module_name; log_path = log_dir / f"{mod_name}.txt"
                gate_desc = f"GateMLP: Linear({gate_in_dim},{hidden})->SiLU->Linear({hidden},{self.num_moe_expert})"
                if EXTRA_MoE_routing_mode == 'dense':
                    ep_desc = f"BatchedFeedForward(glu={self.is_glu}, num={self.num_moe_expert}, inner={small_inner}, in={self.in_features}, out={self.out_features})"
                else:
                    ep_desc = f"FeedForwardList(glu={self.is_glu}, num={self.num_moe_expert}, inner≈{self.in_features*mult}, in={self.in_features}, out={self.out_features})"
                with open(log_path, 'a') as f:
                    f.write(f"{mod_name} | routing={EXTRA_MoE_routing_mode} | {gate_desc} | {ep_desc}\n")
                print(f"{log_path}")
        if FOR_upcycle_ckpt_GEN_or_USE:
            self.verify_approximation(orig_ffn_list=l_ffn)
    def forward(self, x: torch.Tensor, token_pos_grid__cur=None):
        """x: [b, n, in_features] -> [b, n, out_features].

        token_pos_grid__cur is required when EXTRA_MoE_enable (the gate-input
        builder asserts its shape); the None default only works on the
        MoE-off path -- TODO confirm all callers pass it when MoE is on.
        """
        t = global_.task
        # in linear + LoRA
        if self.is_glu:
            base = self.shared_ffn.net[0].proj(x)
            delta = self.task_lora_in[t](x)
            z = base + delta
            # GEGLU: half the projection is the value, the other half gates it
            v, gate = z.chunk(2, dim=-1)
            h = v * F.gelu(gate)
        else:
            base = self.shared_ffn.net[0][0](x)
            delta = self.task_lora_in[t](x)
            h = F.gelu(base + delta)
        h = self.dropout(h)
        # out linear + LoRA
        y_base = self.shared_ffn.net[2](h)
        y_delta = self.task_lora_out[t](h)
        y = y_base + y_delta
        if EXTRA_MoE_enable:
            # gate input
            gate_in, _ = build_ffn_gate_input_common(x, token_pos_grid__cur, self.tasks)
            scores = self.moe_gate_mlp(gate_in).to(dtype=x.dtype) # b,n,k
            if EXTRA_MoE_add_noise and self.training:
                scores = scores + torch.randn_like(scores) * EXTRA_MoE_noise_std
            v_topk, idx_topk = scores.topk(k=EXTRA_MoE_topK, dim=-1)
            if EXTRA_MoE_routing_mode == 'dense':
                raise Exception('not carefully checked yet')
            else: # sparse: forward only the selected experts and aggregate by top-k weights
                if 1: weights_topk = torch.softmax(v_topk, dim=-1) # b,n,topk
                else: weights_topk = v_topk # b,n,topk. use top-k expert scores directly as weights
                b, n, d = x.shape
                dim_out = self.out_features
                y_moe_flat = x.new_zeros(b*n, dim_out) # flattened tensor accumulating outputs from all experts (bs*N, D_out)
                x_flat = x.reshape(b*n, d) # flatten input tensor (bs*N, D_in)
                unique_experts = torch.unique(idx_topk) # set of expert IDs actually selected in this batch
                for j in unique_experts.tolist(): # iterate only over active experts
                    mask_j = (idx_topk == j) # b,n,topk boolean mask indicating which tokens picked expert j
                    sel_token_mask = mask_j.any(dim=-1) # b,n boolean mask for tokens that selected expert j
                    if not sel_token_mask.any(): # skip if expert j was not selected by any token
                        continue
                    flat_pos = sel_token_mask.view(-1).nonzero(as_tuple=False).squeeze(1) # T_j flattened indices of tokens assigned to expert j
                    x_sel = x_flat.index_select(0, flat_pos) # T_j,d select those tokens from flattened input
                    # run expert only on selected tokens (n = T_j)
                    y_sel = self.moe_experts_list[j](x_sel.view(1, x_sel.shape[0], d)).squeeze(0) # T_j,dim_out expert j handles only its tokens
                    w_tok = (weights_topk * mask_j).sum(dim=-1).view(-1).index_select(0, flat_pos).unsqueeze(-1) # T_j,1 weights for each token assigned to expert j
                    y_moe_flat.index_add_(0, flat_pos, w_tok * y_sel) # add weighted expert output back into flattened tensor (in-place)
                y = y + y_moe_flat.view(b, n, dim_out) # reshape aggregated MoE output and add back to backbone output
            if EXTRA_MoE_en_auxLoss and self.training:
                raise Exception('not carefully checked yet')
                # Unreachable until the raise above is removed: importance/load
                # balancing loss exposed to the training loop via global_.
                importance = torch.zeros(self.num_moe_expert, device=scores.device, dtype=weights_topk.dtype)
                importance = importance.scatter_add(0, idx_topk.reshape(-1), weights_topk.reshape(-1))
                load = torch.zeros(self.num_moe_expert, device=scores.device, dtype=weights_topk.dtype)
                load = load.scatter_add(0, idx_topk.reshape(-1), torch.ones_like(weights_topk.reshape(-1)))
                k = importance.shape[0]
                target_imp = torch.full_like(importance, fill_value=importance.sum() / k)
                target_load = torch.full_like(load, fill_value=load.sum() / k)
                aux_imp = F.mse_loss(importance, target_imp)
                aux_load = F.mse_loss(load, target_load)
                aux = 0.5 * (aux_imp + aux_load) * EXTRA_MoE_aux_coef
                global_.moe_aux_loss = aux # expose aux loss to the training loop for aggregation
        return y
    @torch.no_grad()
    def verify_approximation(self, num_tokens: int = 16, batch_size: int = 2, orig_ffn_list: list = None):
        """Log L2 distance of (avg-only) and (avg+LoRA) outputs vs each
        original task FFN on random input; temporarily switches global_.task.
        Skipped entirely when EXTRA_MoE_enable, since forward() would then
        need token positions / landmark globals unavailable here."""
        if EXTRA_MoE_enable: return
        device = next(self.shared_ffn.parameters()).device
        dtype = next(self.shared_ffn.parameters()).dtype
        x = torch.randn(batch_size, num_tokens, self.in_features, device=device, dtype=dtype)
        old_task = getattr(global_, 'task', None)
        for i,t in enumerate(self.tasks):
            _log2(orig_ffn_list[i], [self.task_lora_in[t], self.task_lora_out[t]])
            global_.task = t
            y_lora = self.forward(x)
            y_avg = self.shared_ffn(x)
            assert orig_ffn_list is not None, "orig_ffn_list must be provided for verification"
            y_orig = orig_ffn_list[i](x)
            d_avg = torch.norm((y_avg - y_orig).float()).item()
            d_lora = torch.norm((y_lora - y_orig).float()).item()
            _log1(f"[FFN verify] task={t} rank_in={self.task_lora_in[t].rank} rank_out={self.task_lora_out[t].rank} L2(avg,orig)={d_avg:.6f} L2(lora,orig)={d_lora:.6f}")
        global_.task = old_task
class Linear_Shared_Plus_TaskLoRA(nn.Module):
    """nn.Linear upcycled into a frozen shared Linear (task-average weights)
    plus per-task LoRA / full-rank residual adapters initialized from exact
    weight diffs (lossless at init when FOR_upcycle_ckpt_GEN_or_USE).

    NOTE(review): when `dont_lora` triggers, __init__ returns early and
    `self.task_lora` is never created, so forward() would raise -- callers
    (upCycle_module) are expected to discard the object in that case.
    """
    def __init__(self, l_proj: list, l_task: list, module_name: str = None):
        """l_proj: one nn.Linear per task; l_task: matching task names;
        module_name: key into global_.moduleName_2_adaRank for rank caching."""
        super().__init__()
        # _log1(f"-------- {module_name} --------")
        assert len(l_proj) >= 1
        self.dont_lora = False
        p0: nn.Linear = l_proj[0]
        # build shared from p0 and load avg
        self.shared: nn.Linear = copy.deepcopy(p0)
        if FOR_upcycle_ckpt_GEN_or_USE:
            avg_sd = _average_state_dict(l_proj)
            self.shared.load_state_dict(avg_sd)
        for p in self.shared.parameters():
            p.requires_grad = False
        self.in_features = self.shared.in_features
        self.out_features = self.shared.out_features
        self.tasks = l_task
        # cal/read adaptive rank across tasks
        if 1:
            if FOR_upcycle_ckpt_GEN_or_USE:
                # GEN phase: derive ranks from weight diffs and cache globally.
                w_diff_list = []
                for lin in l_proj:
                    w_diff_list.append(lin.weight.data - self.shared.weight.data)
                if FORCE_SAME_RANK_ACROSS_TASKS:
                    rank_lin = compute_adaptive_rank_for_linear_diffs(w_diff_list)
                    global_.moduleName_2_adaRank[module_name] = rank_lin
                else:
                    ranks_lin = compute_adaptive_rank_for_linear_diffs(w_diff_list, per_task=True)
                    global_.moduleName_2_adaRank[module_name] = ranks_lin
            else:
                # USE phase: read cached ranks.
                r_info = global_.moduleName_2_adaRank[module_name]
                if FORCE_SAME_RANK_ACROSS_TASKS: rank_lin = r_info
                else: ranks_lin = r_info
        if 1: # fallback decision for Linear
            min_dim = min(self.in_features, self.out_features)
            if min_dim < DONT_lora_if_dim_lt:
                # print(f"[LoRA fallback][Linear] {module_name} {min_dim=} < {DONT_lora_if_dim_lt}")
                self.dont_lora = True; return
        _l = [] # per-task LoRA adapters
        for idx, lin in enumerate(l_proj):
            if not FORCE_SAME_RANK_ACROSS_TASKS:
                rank_lin = ranks_lin[idx]
            frac = float(rank_lin) / min(self.in_features, self.out_features)
            if frac > DONT_lora_if_rankFrac_gt:
                # near-full rank -> full-rank residual adapter instead of LoRA
                lora = ResidualAdapterLinearOnly(self.in_features, self.out_features, scaling=1.0, use_bias_delta=True)
            else:
                lora = LoRAAdapterLinearOnly(self.in_features, self.out_features, rank=rank_lin, dropout=0.0, scaling=1.0)
            if FOR_upcycle_ckpt_GEN_or_USE:
                with torch.no_grad():
                    w_diff = lin.weight.data - self.shared.weight.data
                    b_diff = (lin.bias.data - self.shared.bias.data) if (lin.bias is not None and self.shared.bias is not None) else None
                    lora.init_from_diff(w_diff, b_diff)
            _l.append(lora)
        self.task_lora = ModuleDict_W(_l, self.tasks)
        if FOR_upcycle_ckpt_GEN_or_USE:
            self.verify_approximation(orig_linear_list=l_proj)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # frozen shared average + task-specific delta (task from global_.task)
        y = self.shared(x)
        y = y + self.task_lora[global_.task](x)
        return y
    @torch.no_grad()
    def verify_approximation(self, batch_size: int = 2, in_dim_override: int = None, orig_linear_list: list = None):
        """Log L2 distance of avg-only and avg+LoRA outputs vs each original
        per-task Linear on random input; temporarily switches global_.task."""
        device = next(self.shared.parameters()).device
        dtype = next(self.shared.parameters()).dtype
        d_in = self.in_features if in_dim_override is None else in_dim_override
        x = torch.randn(batch_size, d_in, device=device, dtype=dtype)
        old_task = getattr(global_, 'task', None)
        for i,t in enumerate(self.tasks):
            _log2(orig_linear_list[i], self.task_lora[t])
            global_.task = t
            y_lora = self.forward(x)
            y_avg = self.shared(x)
            assert orig_linear_list is not None, "orig_linear_list must be provided for verification"
            y_orig = orig_linear_list[i](x)
            d_avg = torch.norm((y_avg - y_orig).float()).item()
            d_lora = torch.norm((y_lora - y_orig).float()).item()
            _log1(f"[Linear verify] task={t} rank={self.task_lora[t].rank} L2(avg,orig)={d_avg:.6f} L2(lora,orig)={d_lora:.6f}")
        global_.task = old_task
class Conv_Shared_Plus_TaskLoRA(nn.Module):
    """nn.Conv2d upcycled into a frozen shared Conv2d (task-average weights)
    plus per-task LoRA / full-rank residual conv adapters initialized from
    exact weight diffs (lossless at init when FOR_upcycle_ckpt_GEN_or_USE).

    NOTE(review): when `dont_lora` triggers, __init__ returns early and
    `self.task_lora` is never created, so forward() would raise -- callers
    (upCycle_module) are expected to discard the object in that case.
    """
    def __init__(self, l_conv: list, l_task: list, module_name: str = None):
        """l_conv: one nn.Conv2d per task; l_task: matching task names;
        module_name: key into global_.moduleName_2_adaRank for rank caching."""
        super().__init__()
        # _log1(f"-------- {module_name} --------")
        assert len(l_conv) >= 1
        self.dont_lora = False
        c0: nn.Conv2d = l_conv[0]
        # build shared conv
        self.shared = nn.Conv2d(
            c0.in_channels, c0.out_channels,
            kernel_size=c0.kernel_size, stride=c0.stride,
            padding=c0.padding, dilation=c0.dilation,
            groups=c0.groups, bias=(c0.bias is not None),
            padding_mode=c0.padding_mode,
        )
        if FOR_upcycle_ckpt_GEN_or_USE:
            avg_sd = _average_state_dict(l_conv)
            self.shared.load_state_dict(avg_sd)
        for p in self.shared.parameters():
            p.requires_grad = False
        # per-task LoRA
        self.tasks = l_task
        _l = []
        # cal/read adaptive rank across tasks
        if 1:
            if FOR_upcycle_ckpt_GEN_or_USE:
                # GEN phase: derive ranks from weight diffs and cache globally.
                w_diff_list = []
                for c in l_conv:
                    w_diff_list.append(c.weight.data - self.shared.weight.data)
                if FORCE_SAME_RANK_ACROSS_TASKS:
                    rank_conv = compute_adaptive_rank_for_conv_diffs(w_diff_list)
                    global_.moduleName_2_adaRank[module_name] = rank_conv
                else:
                    ranks_conv = compute_adaptive_rank_for_conv_diffs(w_diff_list, per_task=True)
                    global_.moduleName_2_adaRank[module_name] = ranks_conv
            else:
                # USE phase: read cached ranks.
                r_info = global_.moduleName_2_adaRank[module_name]
                if FORCE_SAME_RANK_ACROSS_TASKS: rank_conv = r_info
                else: ranks_conv = r_info
        if 1: # fallback decision for Conv
            kH, kW = self.shared.kernel_size
            min_dim = min(self.shared.out_channels, self.shared.in_channels * kH * kW )
            if min_dim < DONT_lora_if_dim_lt:
                # print(f"[LoRA fallback][Conv] {module_name} {min_dim=} {DONT_lora_if_dim_lt=} (in={self.shared.in_channels}, out={self.shared.out_channels}, k=({kH},{kW}))")
                self.dont_lora = True; return
        for idx, c in enumerate(l_conv):
            if not FORCE_SAME_RANK_ACROSS_TASKS:
                rank_conv = ranks_conv[idx]
            frac = float(rank_conv) / min(self.shared.out_channels, self.shared.in_channels * kH * kW)
            if frac > DONT_lora_if_rankFrac_gt:
                # near-full rank -> full-rank residual adapter instead of LoRA
                lora = ResidualAdapterConv2dOnly(
                    in_channels=c.in_channels, out_channels=c.out_channels,
                    kernel_size=c.kernel_size, stride=c.stride,
                    padding=c.padding, dilation=c.dilation, groups=c.groups,
                    scaling=1.0, use_bias_delta=True,
                )
            else:
                lora = LoRAAdapterConv2dOnly(
                    in_channels=c.in_channels, out_channels=c.out_channels,
                    kernel_size=c.kernel_size, stride=c.stride,
                    padding=c.padding, dilation=c.dilation, groups=c.groups,
                    rank=rank_conv, dropout=0.0, scaling=1.0,
                )
            if FOR_upcycle_ckpt_GEN_or_USE:
                with torch.no_grad():
                    w_diff = c.weight.data - self.shared.weight.data
                    b_diff = (c.bias.data - self.shared.bias.data) if c.bias is not None and self.shared.bias is not None else None
                    lora.init_from_diff(w_diff, b_diff)
            _l.append(lora)
        self.task_lora = ModuleDict_W(_l, self.tasks)
        if FOR_upcycle_ckpt_GEN_or_USE:
            self.verify_approximation(orig_conv_list=l_conv)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # frozen shared average + task-specific delta (task from global_.task)
        y = self.shared(x)
        y = y + self.task_lora[global_.task](x)
        return y
    @torch.no_grad()
    def verify_approximation(self, spatial_hw=(32, 32), batch_size: int = 2, orig_conv_list: list = None):
        """Log L2 distance of avg-only and avg+LoRA outputs vs each original
        per-task Conv2d on random input; temporarily switches global_.task."""
        device = next(self.shared.parameters()).device
        dtype = next(self.shared.parameters()).dtype
        H, W = spatial_hw
        x = torch.randn(batch_size, self.shared.in_channels, H, W, device=device, dtype=dtype)
        old_task = getattr(global_, 'task', None)
        for i,t in enumerate(self.tasks):
            _log2(orig_conv_list[i], self.task_lora[t])
            global_.task = t
            y_lora = self.forward(x)
            y_avg = self.shared(x)
            assert orig_conv_list is not None, "orig_conv_list must be provided for verification"
            y_orig = orig_conv_list[i](x)
            d_avg = torch.norm((y_avg - y_orig).float()).item()
            d_lora = torch.norm((y_lora - y_orig).float()).item()
            _log1(f"[Conv2d verify] task={t} rank={self.task_lora[t].rank} L2(avg,orig)={d_avg:.6f} L2(lora,orig)={d_lora:.6f}")
        global_.task = old_task