| from ldm.modules.attention import * |
| import global_ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from my_py_lib.torch_util import custom_repr_v3 |
| from confs import * |
| import cv2, numpy as np |
| from lmk_util.lmk_extractor import lmkAll_2_lmkMain, get_lmkMain_indices |
| from MoE import * |
| from lora_layers import * |
| import json |
| import copy |
|
|
|
|
|
|
| """ |
| Global knobs for shared experts and routing (no argparse per user preference) |
| """ |
# Global knobs for shared experts and routing (plain module constants, no argparse).
NUM_SHARED_FFN = 8
GATE_TOPK = 2

# Extra token-routed MoE branch that FFN_Shared_Plus_TaskLoRA adds on top of
# the shared FFN output.
EXTRA_MoE_enable: bool = True       # build the gate + experts and run them in forward
EXTRA_MoE_num_ep = 8                # number of extra experts
EXTRA_MoE_inner_divisor = 64        # expert inner width = mid_features // this value
EXTRA_MoE_topK = 2                  # experts selected per token
EXTRA_MoE_add_noise: bool = True    # add Gaussian noise to gate scores in training mode
EXTRA_MoE_noise_std = 0.1           # std of that noise
EXTRA_MoE_en_auxLoss: bool = False  # aux-loss path currently raises when enabled
EXTRA_MoE_aux_coef = 1e-2
EXTRA_MoE_routing_mode = 'sparse'   # 'dense' currently raises in forward

# Landmark selection used by the gate input. None means "use every landmark
# returned by get_lmkMain_indices(include_face_oval=True)".
LMK_PICK_IDX = None
NUM_lmk_pick = len(LMK_PICK_IDX) if LMK_PICK_IDX is not None else len(get_lmkMain_indices(include_face_oval=True))
print(f"{NUM_lmk_pick=}")
IMAGE_SIZE_FOR_LMK_NORM = 512.0
|
|
def _log2(orig_modules, lora_modules):
    """Print parameter statistics for an original module and its adapter(s).

    Args:
        orig_modules: a module exposing ``parameters()``.
        lora_modules: a single module, or a list/tuple of modules. Modules
            that expose a ``rank`` attribute have it included in the report.

    Prints two lines: parameter count and size in MB for the original module,
    then the same for the adapter(s) followed by the rank suffix. Two ranks
    are reported as ``rank_in``/``rank_out``, a single one as ``rank``, and
    any other number of ranks prints no suffix.
    """
    def _count(mods):
        # Returns (number of elements, number of bytes) over all parameters.
        params = [p for m in mods for p in m.parameters()]
        n_elem = sum(p.numel() for p in params)
        n_byte = sum(p.numel() * p.element_size() for p in params)
        return n_elem, n_byte

    if isinstance(lora_modules, (list, tuple)):
        lora_list = list(lora_modules)
    else:
        lora_list = [lora_modules]

    orig_params, orig_size = _count([orig_modules])
    lora_params, lora_size = _count(lora_list)

    ranks = [m.rank for m in lora_list if hasattr(m, 'rank')]
    if len(ranks) == 2:
        rank_str = f" (rank_in={ranks[0]} rank_out={ranks[1]})"
    elif len(ranks) == 1:
        rank_str = f" (rank={ranks[0]})"
    else:
        rank_str = ""

    # The original also carried an append-to-log-file step, but it sat behind
    # an unconditional `continue` and never ran; only stdout is written.
    print(f"orig: {orig_params:,} params, {orig_size/1024/1024:.2f}MB")
    print(f"LoRA: {lora_params:,} params, {lora_size/1024/1024:.2f}MB{rank_str}")
def _log1(msg: str):
    """Print ``msg`` to stdout.

    The original also carried an append-to-log-file step, but it sat behind
    an unconditional ``return`` and never ran, so it has been removed.
    """
    print(msg)
|
|
def build_ffn_gate_input_common(x: torch.Tensor, token_pos_grid__cur, tasks: list):
    """Build the per-token gate input for FFN routing (reusable across FFN classes).

    The gate input is the concatenation, along the last dim, of:
    token features, the per-sample mean feature broadcast to every token,
    a one-hot task indicator, the token positions, and the offsets from each
    token position to each landmark.

    Args:
        x: tensor unpacked as (b, n, d).
        token_pos_grid__cur: tensor whose first two dims must equal (b, n).
            The callers' gate width budgets 2 values per token for it
            (presumably x/y -- TODO confirm against the caller).
        tasks: only ``len(tasks)`` is used, as the width of the one-hot.

    Reads ``global_.task`` (used directly as the one-hot column index, so it
    must be a valid index into a ``len(tasks)``-wide axis) and ``global_.lmk_``
    (indexed as (batch, landmark, coord)).

    Returns:
        (gate_in, ctx) where ``ctx`` is a dict of the intermediate tensors.
    """
    b, n, _ = x.shape
    token_feat = x
    avg_feat = x.mean(dim=1, keepdim=True).expand(-1, n, -1)

    task_1h = x.new_zeros(b, len(tasks))
    task_1h[:, global_.task] = 1
    task_1h = task_1h.unsqueeze(1).expand(-1, n, -1)

    token_pos = token_pos_grid__cur
    assert token_pos.shape[:2] == (b, n), (token_pos.shape, (b, n), )

    lmk = global_.lmk_.to(x.device).float()
    if LMK_PICK_IDX is None:
        assert NUM_lmk_pick==lmk.shape[1]
    else:
        lmk = lmk[:, LMK_PICK_IDX, :]

    # Broadcast token positions against landmarks, then flatten the
    # (landmark, coord) axes into one feature axis per token.
    rel = token_pos.unsqueeze(2) - lmk.unsqueeze(1)
    rel_flat = rel.reshape(b, n, -1)

    gate_in = torch.cat([token_feat, avg_feat, task_1h, token_pos, rel_flat], dim=-1)
    ctx = {'token_feat': token_feat, 'avg_feat': avg_feat, 'task_1h': task_1h, 'token_pos': token_pos, 'lmk': lmk, 'rel': rel, 'rel_flat': rel_flat}
    return gate_in, ctx
|
|
def replace_modules_lossless(
    module: nn.Module,
    src_modules: list,
    l_task: list,
    parent_name: str = "",
    depth :int = 0,
    for_refnet: bool = False,
):
    """Recursively replace the children of ``module`` with multi-task wrappers, in place.

    Policy as currently implemented:
      - FeedForward: ``upCycle_module`` (shared + per-task adapter).
      - CrossAttention: ``to_q``/``to_k``/``to_v`` and, when ``to_out`` is
        indexable, ``to_out[0]`` become ``TaskSpecific_MoE``. The
        CrossAttention node itself is kept and not recursed into further.
      - nn.Conv2d: ``upCycle_module``; its parameter count is recorded in
        ``CONV2D_PARAM_STATS``.
      - nn.LayerNorm / nn.GroupNorm / nn.Linear: ``TaskSpecific_MoE``.
      - anything else: recurse into it.

    Args:
        module: destination module whose children are replaced.
        src_modules: per-task source modules with the same attribute layout
            as ``module``; children are looked up by name on each of them.
        l_task: task identifiers forwarded to the wrappers.
        parent_name: dotted path of ``module``, used to build child names.
        depth: recursion depth; depth 0 resets and finally reports the
            Conv2d statistics.
        for_refnet: only forwarded to recursive calls; not otherwise used here.

    Returns:
        ``module`` (the same object).

    Raises:
        AssertionError: a module with parameters but no children is reached.
        Exception: two entries of ``src_modules`` share the same child object.
    """
    if depth==0:
        CONV2D_PARAM_STATS.clear()

    # Parameter-free modules have nothing to share or specialise.
    if next(module.parameters(), None) is None:
        return module
    if len(list(module.named_children()))==0:
        print('\n!!!! len(list(module.named_children()))==0',module)
        # Explicit raise (not `assert 0`) so the guard survives `python -O`.
        raise AssertionError(f"leaf module with parameters reached: {module!r}")

    for name, child in module.named_children():
        full_name = f"{parent_name}.{name}" if parent_name else f".{name}"
        src_child_modules = [getattr(src_module, name) for src_module in src_modules]
        if len({id(s) for s in src_child_modules}) < len(src_child_modules):
            raise Exception('Duplicate source modules detected!')

        # NOTE(review): every task's source child is replaced by a deep copy of
        # the FIRST task's child, so the per-task weights of the other sources
        # are not used below. Confirm this is intended (e.g. weights are loaded
        # from a checkpoint afterwards) rather than a leftover debug line.
        src_child_modules = [copy.deepcopy(src_child_modules[0]) for _ in src_child_modules]

        if isinstance(child, FeedForward):
            setattr(module, name, upCycle_module(src_child_modules, l_task, module_name=full_name))
            continue

        if isinstance(child, CrossAttention):
            for proj_name in ["to_q", "to_k", "to_v"]:
                src_proj_list = [getattr(s, proj_name) for s in src_child_modules]
                setattr(child, proj_name, TaskSpecific_MoE(list(src_proj_list), tasks=l_task))
            if hasattr(child.to_out, "__getitem__"):
                src_linear0 = [s.to_out[0] for s in src_child_modules]
                child.to_out[0] = TaskSpecific_MoE(list(src_linear0), tasks=l_task)
            continue

        if isinstance(child, nn.Conv2d):
            num_params = sum(p.numel() for p in child.parameters())
            CONV2D_PARAM_STATS.append((num_params, full_name))
            printC(f"shared+LoRA Conv2d",f"{full_name}")
            setattr(module, name, upCycle_module(src_child_modules, l_task, module_name=full_name))
            continue
        elif isinstance(child, (nn.LayerNorm, nn.GroupNorm, nn.Linear)):
            setattr(module, name, TaskSpecific_MoE(list(src_child_modules), tasks=l_task))
            continue
        else:
            replace_modules_lossless(child, src_child_modules, l_task, parent_name=full_name, depth=depth+1, for_refnet=for_refnet)

    if depth==0:
        stats_sorted = sorted(CONV2D_PARAM_STATS, key=lambda x: x[0], reverse=True)
        if gate_("[Conv2d param stats] count, name (sorted desc):"):
            for cnt, n in stats_sorted:
                print(f"  {cnt:12d}  {n}")
    return module
|
|
def upCycle_module(l_modules, l_task, module_name: str = None):
    """Wrap per-task modules as "shared + per-task adapter", or fall back to task-specific.

    Args:
        l_modules: per-task modules; all must have the same type
            (FeedForward, nn.Linear or nn.Conv2d).
        l_task: task identifiers forwarded to the wrapper.
        module_name: dotted module path forwarded to the wrapper.

    Returns:
        The shared-plus-adapter wrapper, or a ``TaskSpecific_MoE`` over
        ``l_modules`` when the wrapper set ``dont_lora`` during construction.

    Raises:
        Exception: the module type is not one of the supported ones.
    """
    assert len({type(m) for m in l_modules}) == 1
    m0 = l_modules[0]
    if isinstance(m0, FeedForward):
        obj = FFN_Shared_Plus_TaskLoRA(l_modules, l_task, module_name=module_name)
    elif isinstance(m0, nn.Linear):
        obj = Linear_Shared_Plus_TaskLoRA(l_modules, l_task, module_name=module_name)
    elif isinstance(m0, nn.Conv2d):
        obj = Conv_Shared_Plus_TaskLoRA(l_modules, l_task, module_name=module_name)
    else:
        raise Exception(module_name,m0)
    if obj.dont_lora:
        # The wrapper returned early from __init__; discard it.
        return TaskSpecific_MoE(list(l_modules), tasks=l_task)
    return obj
|
|
|
|
|
|
|
|
class ResidualAdapterLinearOnly(nn.Module):
    """Full-rank residual adapter for a linear layer.

    Holds a dense weight delta (and optionally a bias delta), both starting
    at zero, and returns only the delta contribution
    ``(x @ delta_weight.T + delta_bias) * scaling``; the caller adds it to the
    shared layer's output.
    """

    def __init__(self, in_features: int, out_features: int, scaling: float = 1.0, use_bias_delta: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Exposed for reporting alongside low-rank adapters.
        self.rank = min(in_features, out_features)
        self.scaling = scaling
        self.use_bias_delta = use_bias_delta
        self.delta_weight = nn.Parameter(torch.zeros(out_features, in_features))
        bias_param = nn.Parameter(torch.zeros(out_features)) if use_bias_delta else None
        self.register_parameter('delta_bias', bias_param)

    @torch.no_grad()
    def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None):
        """Copy the given differences into the delta parameters.

        ``bias_diff`` is ignored when it is None or when the adapter was built
        without a bias delta.
        """
        self.delta_weight.copy_(weight_diff)
        has_bias_slot = self.delta_bias is not None
        if has_bias_slot and bias_diff is not None:
            self.delta_bias.copy_(bias_diff)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled delta output for ``x`` (last dim = in_features)."""
        out = torch.matmul(x, self.delta_weight.t())
        if self.delta_bias is None:
            return out * self.scaling
        return (out + self.delta_bias) * self.scaling
|
|
class ResidualAdapterConv2dOnly(nn.Module):
    """Full-rank residual adapter for a Conv2d layer.

    Holds a dense kernel delta (and optionally a bias delta), both starting at
    zero, and returns only the delta contribution
    ``(conv2d(x, delta_weight) + delta_bias) * scaling``; the caller adds it to
    the shared convolution's output.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple, stride: tuple, padding: tuple, dilation: tuple, groups: int = 1, scaling: float = 1.0, use_bias_delta: bool = True):
        super().__init__()

        def _pair(value):
            # Plain ints are expanded to (v, v); anything else passes through.
            return (value, value) if isinstance(value, int) else value

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        k_h, k_w = kernel_size
        # Exposed for reporting alongside low-rank adapters.
        self.rank = min(out_channels, in_channels * k_h * k_w)
        self.scaling = scaling
        self.use_bias_delta = use_bias_delta
        self.delta_weight = nn.Parameter(torch.zeros(out_channels, in_channels // groups, k_h, k_w))
        bias_param = nn.Parameter(torch.zeros(out_channels)) if use_bias_delta else None
        self.register_parameter('delta_bias', bias_param)

    @torch.no_grad()
    def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None):
        """Copy the given differences into the delta parameters.

        ``bias_diff`` is ignored when it is None or when the adapter was built
        without a bias delta.
        """
        self.delta_weight.copy_(weight_diff)
        has_bias_slot = self.delta_bias is not None
        if has_bias_slot and bias_diff is not None:
            self.delta_bias.copy_(bias_diff)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled delta convolution of ``x``."""
        out = F.conv2d(
            x,
            self.delta_weight,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        if self.delta_bias is None:
            return out * self.scaling
        # Bias is broadcast over batch and spatial dims.
        return (out + self.delta_bias.view(1, -1, 1, 1)) * self.scaling
|
|
|
|
|
|
class Linear_TaskSpecific_Plus_Shared(nn.Module):
    """Per-task linear layers plus one shared linear branch passed through ``zero_module``.

    The output is ``task_proj[global_.task](x) + shared(x)``.
    """

    def __init__(self, l_proj: list, l_task: list):
        super().__init__()
        assert len(l_proj) >= 1
        first = l_proj[0]
        assert isinstance(first, nn.Linear)

        # The shared branch mirrors the first projection's shape and bias setting.
        shared = nn.Linear(first.in_features, first.out_features, bias=first.bias is not None)
        self.shared = zero_module(shared)
        self.tasks = l_task
        self.task_proj = ModuleDict_W(l_proj, self.tasks)

    def forward(self, x):
        """Apply the current task's projection and add the shared branch."""
        task_out = self.task_proj[global_.task](x)
        return task_out + self.shared(x)
|
|
|
|
class Conv_TaskSpecific_Plus_Shared(nn.Module):
    """Per-task Conv2d layers plus one shared Conv2d branch passed through ``zero_module``.

    The output is ``task_conv[global_.task](x) + shared(x)``.
    """

    def __init__(self, l_conv: list, l_task: list):
        super().__init__()
        assert len(l_conv) >= 1
        first = l_conv[0]
        assert isinstance(first, nn.Conv2d)

        # The shared branch mirrors the first convolution's hyper-parameters.
        shared = nn.Conv2d(
            first.in_channels,
            first.out_channels,
            kernel_size=first.kernel_size,
            stride=first.stride,
            padding=first.padding,
            dilation=first.dilation,
            groups=first.groups,
            bias=(first.bias is not None),
            padding_mode=first.padding_mode,
        )
        self.shared = zero_module(shared)
        self.tasks = l_task
        self.task_conv = ModuleDict_W(l_conv, self.tasks)

    def forward(self, x):
        """Apply the current task's convolution and add the shared branch."""
        task_out = self.task_conv[global_.task](x)
        return task_out + self.shared(x)
|
|
|
|
|
|
|
|
def _average_state_dict(modules: list):
    """Return the element-wise mean of the modules' state dicts.

    Args:
        modules: non-empty list of modules whose state dicts contain the keys
            of the first module's state dict.

    Returns:
        Dict mapping each key of ``modules[0].state_dict()`` to the mean
        tensor, in that entry's original dtype.

    Floating-point (and complex) entries are summed and divided in their own
    dtype. Other entries (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) are averaged in float64 and cast back, because an
    in-place true division on an integer tensor raises in PyTorch.
    """
    assert len(modules) > 0
    count = len(modules)
    state_dicts = [m.state_dict() for m in modules]

    avg = {}
    for key, ref in state_dicts[0].items():
        if ref.is_floating_point() or ref.is_complex():
            acc = torch.zeros_like(ref)
            for sd in state_dicts:
                acc += sd[key]
            acc /= count
        else:
            acc = torch.zeros_like(ref, dtype=torch.float64)
            for sd in state_dicts:
                acc += sd[key]
            acc = (acc / count).to(ref.dtype)
        avg[key] = acc
    return avg
|
|
|
|
class FFN_Shared_Plus_TaskLoRA(nn.Module):
    """FeedForward with a frozen shared core, per-task adapters and an optional extra MoE.

    The shared FFN is a deep copy of the first task's FFN (overwritten with the
    cross-task average when ``FOR_upcycle_ckpt_GEN_or_USE`` is truthy) with
    ``requires_grad`` disabled. Each task gets one adapter on the input linear
    and one on the output linear. When ``EXTRA_MoE_enable`` is truthy a gate
    MLP plus a set of small experts is added on top of the output.

    If either linear is smaller than ``DONT_lora_if_dim_lt``, ``__init__``
    sets ``dont_lora = True`` and returns early; such an instance is only
    partially constructed and must not be used for forward passes.
    """

    def __init__(self, l_ffn: list, l_task: list, module_name: str = None):
        """
        Args:
            l_ffn: per-task FeedForward modules.
            l_task: task identifiers; zipped positionally with ``l_ffn`` by
                ``ModuleDict_W``.
            module_name: key into ``global_.moduleName_2_adaRank`` where the
                adapter ranks are stored (truthy ``FOR_upcycle_ckpt_GEN_or_USE``)
                or read from (falsy).
        """
        super().__init__()
        self.module_name = module_name

        assert len(l_ffn) >= 1
        self.tasks = l_task
        self.num_tasks = len(l_task)
        self.dont_lora = False
        f0: FeedForward = l_ffn[0]

        self.shared_ffn: FeedForward = copy.deepcopy(f0)
        if FOR_upcycle_ckpt_GEN_or_USE:
            avg_sd = _average_state_dict(l_ffn)
            self.shared_ffn.load_state_dict(avg_sd)

        for p in self.shared_ffn.parameters():
            p.requires_grad = False

        # Locate the shared input/output linears for the GEGLU and plain layouts.
        self.is_glu = isinstance(self.shared_ffn.net[0], GEGLU)
        if self.is_glu:
            in_linear: nn.Linear = self.shared_ffn.net[0].proj
        else:
            assert isinstance(self.shared_ffn.net[0], nn.Sequential)
            in_linear: nn.Linear = self.shared_ffn.net[0][0]
        out_linear: nn.Linear = self.shared_ffn.net[2]
        self.in_features = in_linear.in_features
        self.mid_features = in_linear.out_features
        self.out_features = out_linear.out_features

        # Adapter ranks: computed from the weight differences and stored, or
        # read back from a previous run.
        if FOR_upcycle_ckpt_GEN_or_USE:
            w_diff_in_list = []
            w_diff_out_list = []
            for f in l_ffn:
                if self.is_glu:
                    tin: nn.Linear = f.net[0].proj
                else:
                    tin: nn.Linear = f.net[0][0]
                tout: nn.Linear = f.net[2]
                w_diff_in_list.append(tin.weight.data - in_linear.weight.data)
                w_diff_out_list.append(tout.weight.data - out_linear.weight.data)
            if FORCE_SAME_RANK_ACROSS_TASKS:
                rank_in = compute_adaptive_rank_for_linear_diffs(w_diff_in_list)
                rank_out = compute_adaptive_rank_for_linear_diffs(w_diff_out_list)
                global_.moduleName_2_adaRank[module_name] = [rank_in, rank_out]
            else:
                ranks_in = compute_adaptive_rank_for_linear_diffs(w_diff_in_list, per_task=True)
                ranks_out = compute_adaptive_rank_for_linear_diffs(w_diff_out_list, per_task=True)
                global_.moduleName_2_adaRank[module_name] = [ranks_in, ranks_out]
        else:
            r_info = global_.moduleName_2_adaRank[module_name]
            if FORCE_SAME_RANK_ACROSS_TASKS:
                rank_in, rank_out = r_info
            else:
                ranks_in, ranks_out = r_info

        # Too-small layers are not worth adapting; signal the caller and stop.
        min_dim_in = min(self.in_features, self.mid_features)
        min_dim_out = min(self.mid_features, self.out_features)
        if (min_dim_in < DONT_lora_if_dim_lt) or (min_dim_out < DONT_lora_if_dim_lt):
            self.dont_lora = True
            return

        _l_in = []
        _l_out = []
        for idx, f in enumerate(l_ffn):
            if self.is_glu:
                tin: nn.Linear = f.net[0].proj
            else:
                tin: nn.Linear = f.net[0][0]
            tout: nn.Linear = f.net[2]
            if not FORCE_SAME_RANK_ACROSS_TASKS:
                rank_in = ranks_in[idx]
                rank_out = ranks_out[idx]
            frac_in = float(rank_in) / min(self.in_features, self.mid_features)
            frac_out = float(rank_out) / min(self.mid_features, self.out_features)
            frac_avg = 0.5 * (frac_in + frac_out)
            if frac_avg > DONT_lora_if_rankFrac_gt:
                # Rank is a large fraction of full rank: store the dense delta instead.
                lora_in = ResidualAdapterLinearOnly(self.in_features, self.mid_features, scaling=1.0, use_bias_delta=True)
                lora_out = ResidualAdapterLinearOnly(tout.in_features, tout.out_features, scaling=1.0, use_bias_delta=True)
            else:
                lora_in = LoRAAdapterLinearOnly(self.in_features, self.mid_features, rank=rank_in, dropout=0.0, scaling=1.0)
                lora_out = LoRAAdapterLinearOnly(tout.in_features, tout.out_features, rank=rank_out, dropout=0.0, scaling=1.0)

            if FOR_upcycle_ckpt_GEN_or_USE:
                # Initialise each adapter from (task weights - shared weights).
                with torch.no_grad():
                    w_diff_in = tin.weight.data - in_linear.weight.data
                    b_diff_in = (tin.bias.data - in_linear.bias.data) if tin.bias is not None else None
                    lora_in.init_from_diff(w_diff_in, b_diff_in)
                    w_diff_out = tout.weight.data - out_linear.weight.data
                    b_diff_out = (tout.bias.data - out_linear.bias.data) if tout.bias is not None else None
                    lora_out.init_from_diff(w_diff_out, b_diff_out)
            _l_in.append(lora_in)
            _l_out.append(lora_out)
        self.task_lora_in = ModuleDict_W(_l_in, self.tasks)
        self.task_lora_out = ModuleDict_W(_l_out, self.tasks)

        self.dropout_p = self.shared_ffn.net[1].p if isinstance(self.shared_ffn.net[1], nn.Dropout) else 0.0
        self.dropout = nn.Dropout(self.dropout_p) if self.dropout_p > 0 else nn.Identity()

        if EXTRA_MoE_enable:
            small_inner = self.mid_features // EXTRA_MoE_inner_divisor
            self.num_moe_expert = EXTRA_MoE_num_ep
            # Must match the concatenation in build_ffn_gate_input_common:
            # token feat + mean feat + task one-hot + 2 position values + landmark offsets.
            gate_in_dim = self.in_features + self.in_features + len(self.tasks) + 2 + 2*NUM_lmk_pick
            hidden = gate_in_dim // 8
            self.moe_gate_mlp = nn.Sequential(
                nn.Linear(gate_in_dim, hidden),
                nn.SiLU(),
                nn.Linear(hidden, self.num_moe_expert),
            )

            if EXTRA_MoE_routing_mode == 'dense':
                self.moe_experts_batched = BatchedFeedForward(
                    dim=self.in_features, dim_out=self.out_features,
                    glu=self.is_glu, dropout=self.dropout_p,
                    inner_dim=small_inner, num_expert=self.num_moe_expert,
                )
            else:
                mult = small_inner / self.in_features
                experts = []
                for _ in range(self.num_moe_expert):
                    expert = FeedForward(self.in_features, dim_out=self.out_features, mult=mult, glu=self.is_glu, dropout=self.dropout_p)
                    experts.append(expert)
                self.moe_experts_list = nn.ModuleList(experts)

        if FOR_upcycle_ckpt_GEN_or_USE:
            self.verify_approximation(orig_ffn_list=l_ffn)

    def forward(self, x: torch.Tensor, token_pos_grid__cur=None):
        """Run shared FFN + current-task adapters, plus the extra MoE when enabled.

        Args:
            x: input tensor; the MoE path unpacks it as (b, n, d).
            token_pos_grid__cur: token positions for the gate input. It is
                passed straight to ``build_ffn_gate_input_common``, which
                reads ``.shape`` on it, so it must be provided whenever
                ``EXTRA_MoE_enable`` is truthy.

        Raises:
            Exception: dense routing or the aux loss is enabled (both are
                marked as not carefully checked yet).
        """
        t = global_.task

        # Input projection: shared linear plus the task delta, then activation.
        if self.is_glu:
            base = self.shared_ffn.net[0].proj(x)
            delta = self.task_lora_in[t](x)
            z = base + delta
            v, gate = z.chunk(2, dim=-1)
            h = v * F.gelu(gate)
        else:
            base = self.shared_ffn.net[0][0](x)
            delta = self.task_lora_in[t](x)
            h = F.gelu(base + delta)
        h = self.dropout(h)

        # Output projection: shared linear plus the task delta.
        y_base = self.shared_ffn.net[2](h)
        y_delta = self.task_lora_out[t](h)
        y = y_base + y_delta

        if EXTRA_MoE_enable:
            gate_in, _ = build_ffn_gate_input_common(x, token_pos_grid__cur, self.tasks)
            scores = self.moe_gate_mlp(gate_in).to(dtype=x.dtype)
            if EXTRA_MoE_add_noise and self.training:
                scores = scores + torch.randn_like(scores) * EXTRA_MoE_noise_std
            v_topk, idx_topk = scores.topk(k=EXTRA_MoE_topK, dim=-1)

            if EXTRA_MoE_routing_mode == 'dense':
                raise Exception('not carefully checked yet')

            # Sparse routing: softmax over the selected scores only, then run
            # each selected expert on just the tokens that picked it.
            weights_topk = torch.softmax(v_topk, dim=-1)
            b, n, d = x.shape
            dim_out = self.out_features
            y_moe_flat = x.new_zeros(b*n, dim_out)
            x_flat = x.reshape(b*n, d)
            unique_experts = torch.unique(idx_topk)
            for j in unique_experts.tolist():
                mask_j = (idx_topk == j)
                sel_token_mask = mask_j.any(dim=-1)
                if not sel_token_mask.any():
                    continue
                flat_pos = sel_token_mask.view(-1).nonzero(as_tuple=False).squeeze(1)
                x_sel = x_flat.index_select(0, flat_pos)
                # Experts are called on a (1, num_selected, d) batch.
                y_sel = self.moe_experts_list[j](x_sel.view(1, x_sel.shape[0], d)).squeeze(0)
                w_tok = (weights_topk * mask_j).sum(dim=-1).view(-1).index_select(0, flat_pos).unsqueeze(-1)
                y_moe_flat.index_add_(0, flat_pos, w_tok * y_sel)
            y = y + y_moe_flat.view(b, n, dim_out)

            if EXTRA_MoE_en_auxLoss and self.training:
                # The load-balancing aux loss was never validated; the code that
                # followed this raise was unreachable and has been removed.
                raise Exception('not carefully checked yet')
        return y

    @torch.no_grad()
    def verify_approximation(self, num_tokens: int = 16, batch_size: int = 2, orig_ffn_list: list = None):
        """Log, per task, how closely shared+adapter reproduces the original FFN.

        Does nothing when ``EXTRA_MoE_enable`` is truthy (the extra experts
        would change the output and ``forward`` would need token positions).
        Otherwise it feeds a random input through the adapted path, the shared
        FFN alone and the task's original FFN, and logs the L2 distances.
        ``global_.task`` is temporarily overwritten and restored afterwards,
        also when an exception occurs.

        Args:
            num_tokens: tokens per sample in the random probe.
            batch_size: samples in the random probe.
            orig_ffn_list: original per-task FFNs, ordered like ``self.tasks``.
        """
        if EXTRA_MoE_enable:
            return
        assert orig_ffn_list is not None, "orig_ffn_list must be provided for verification"
        device = next(self.shared_ffn.parameters()).device
        dtype = next(self.shared_ffn.parameters()).dtype
        x = torch.randn(batch_size, num_tokens, self.in_features, device=device, dtype=dtype)
        old_task = getattr(global_, 'task', None)
        try:
            for i, t in enumerate(self.tasks):
                _log2(orig_ffn_list[i], [self.task_lora_in[t], self.task_lora_out[t]])
                global_.task = t
                y_lora = self.forward(x)
                y_avg = self.shared_ffn(x)
                y_orig = orig_ffn_list[i](x)
                d_avg = torch.norm((y_avg - y_orig).float()).item()
                d_lora = torch.norm((y_lora - y_orig).float()).item()
                _log1(f"[FFN verify] task={t} rank_in={self.task_lora_in[t].rank} rank_out={self.task_lora_out[t].rank} L2(avg,orig)={d_avg:.6f} L2(lora,orig)={d_lora:.6f}")
        finally:
            global_.task = old_task
|
|
|
|
class Linear_Shared_Plus_TaskLoRA(nn.Module):
    """Linear layer with a frozen shared core plus one adapter per task.

    The shared layer is a deep copy of the first task's layer (overwritten
    with the cross-task average when ``FOR_upcycle_ckpt_GEN_or_USE`` is truthy)
    with ``requires_grad`` disabled. The output is
    ``shared(x) + task_lora[global_.task](x)``.

    If the layer is smaller than ``DONT_lora_if_dim_lt``, ``__init__`` sets
    ``dont_lora = True`` and returns early; such an instance has no
    ``task_lora`` and must not be used for forward passes.
    """

    def __init__(self, l_proj: list, l_task: list, module_name: str = None):
        """
        Args:
            l_proj: per-task nn.Linear modules.
            l_task: task identifiers; zipped positionally with ``l_proj`` by
                ``ModuleDict_W``.
            module_name: key into ``global_.moduleName_2_adaRank`` where the
                adapter rank(s) are stored (truthy ``FOR_upcycle_ckpt_GEN_or_USE``)
                or read from (falsy).
        """
        super().__init__()

        assert len(l_proj) >= 1
        self.dont_lora = False
        p0: nn.Linear = l_proj[0]

        self.shared: nn.Linear = copy.deepcopy(p0)
        if FOR_upcycle_ckpt_GEN_or_USE:
            avg_sd = _average_state_dict(l_proj)
            self.shared.load_state_dict(avg_sd)
        for p in self.shared.parameters():
            p.requires_grad = False
        self.in_features = self.shared.in_features
        self.out_features = self.shared.out_features
        self.tasks = l_task

        # Adapter ranks: computed from the weight differences and stored, or
        # read back from a previous run.
        if FOR_upcycle_ckpt_GEN_or_USE:
            w_diff_list = [lin.weight.data - self.shared.weight.data for lin in l_proj]
            if FORCE_SAME_RANK_ACROSS_TASKS:
                rank_lin = compute_adaptive_rank_for_linear_diffs(w_diff_list)
                global_.moduleName_2_adaRank[module_name] = rank_lin
            else:
                ranks_lin = compute_adaptive_rank_for_linear_diffs(w_diff_list, per_task=True)
                global_.moduleName_2_adaRank[module_name] = ranks_lin
        else:
            r_info = global_.moduleName_2_adaRank[module_name]
            if FORCE_SAME_RANK_ACROSS_TASKS:
                rank_lin = r_info
            else:
                ranks_lin = r_info

        # Too-small layers are not worth adapting; signal the caller and stop.
        min_dim = min(self.in_features, self.out_features)
        if min_dim < DONT_lora_if_dim_lt:
            self.dont_lora = True
            return

        _l = []
        for idx, lin in enumerate(l_proj):
            if not FORCE_SAME_RANK_ACROSS_TASKS:
                rank_lin = ranks_lin[idx]
            frac = float(rank_lin) / min(self.in_features, self.out_features)
            if frac > DONT_lora_if_rankFrac_gt:
                # Rank is a large fraction of full rank: store the dense delta instead.
                lora = ResidualAdapterLinearOnly(self.in_features, self.out_features, scaling=1.0, use_bias_delta=True)
            else:
                lora = LoRAAdapterLinearOnly(self.in_features, self.out_features, rank=rank_lin, dropout=0.0, scaling=1.0)
            if FOR_upcycle_ckpt_GEN_or_USE:
                # Initialise the adapter from (task weights - shared weights).
                with torch.no_grad():
                    w_diff = lin.weight.data - self.shared.weight.data
                    b_diff = (lin.bias.data - self.shared.bias.data) if (lin.bias is not None and self.shared.bias is not None) else None
                    lora.init_from_diff(w_diff, b_diff)
            _l.append(lora)
        self.task_lora = ModuleDict_W(_l, self.tasks)
        if FOR_upcycle_ckpt_GEN_or_USE:
            self.verify_approximation(orig_linear_list=l_proj)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the shared projection plus the current task's adapter output."""
        y = self.shared(x)
        y = y + self.task_lora[global_.task](x)
        return y

    @torch.no_grad()
    def verify_approximation(self, batch_size: int = 2, in_dim_override: int = None, orig_linear_list: list = None):
        """Log, per task, how closely shared+adapter reproduces the original layer.

        Feeds a random (batch_size, d_in) input through the adapted path, the
        shared layer alone and the task's original layer, and logs the L2
        distances. ``global_.task`` is temporarily overwritten and restored
        afterwards, also when an exception occurs.

        Args:
            batch_size: rows in the random probe.
            in_dim_override: probe width to use instead of ``in_features``.
            orig_linear_list: original per-task layers, ordered like ``self.tasks``.
        """
        assert orig_linear_list is not None, "orig_linear_list must be provided for verification"
        device = next(self.shared.parameters()).device
        dtype = next(self.shared.parameters()).dtype
        d_in = self.in_features if in_dim_override is None else in_dim_override
        x = torch.randn(batch_size, d_in, device=device, dtype=dtype)
        old_task = getattr(global_, 'task', None)
        try:
            for i, t in enumerate(self.tasks):
                _log2(orig_linear_list[i], self.task_lora[t])
                global_.task = t
                y_lora = self.forward(x)
                y_avg = self.shared(x)
                y_orig = orig_linear_list[i](x)
                d_avg = torch.norm((y_avg - y_orig).float()).item()
                d_lora = torch.norm((y_lora - y_orig).float()).item()
                _log1(f"[Linear verify] task={t} rank={self.task_lora[t].rank} L2(avg,orig)={d_avg:.6f} L2(lora,orig)={d_lora:.6f}")
        finally:
            global_.task = old_task
|
|
class Conv_Shared_Plus_TaskLoRA(nn.Module):
    """Conv2d with a frozen shared core plus one adapter per task.

    The shared convolution is freshly constructed with the first task's
    hyper-parameters (and loaded with the cross-task average when
    ``FOR_upcycle_ckpt_GEN_or_USE`` is truthy) with ``requires_grad`` disabled.
    The output is ``shared(x) + task_lora[global_.task](x)``.

    If the layer is smaller than ``DONT_lora_if_dim_lt``, ``__init__`` sets
    ``dont_lora = True`` and returns early; such an instance has no
    ``task_lora`` and must not be used for forward passes.
    """

    def __init__(self, l_conv: list, l_task: list, module_name: str = None):
        """
        Args:
            l_conv: per-task nn.Conv2d modules.
            l_task: task identifiers; zipped positionally with ``l_conv`` by
                ``ModuleDict_W``.
            module_name: key into ``global_.moduleName_2_adaRank`` where the
                adapter rank(s) are stored (truthy ``FOR_upcycle_ckpt_GEN_or_USE``)
                or read from (falsy).
        """
        super().__init__()

        assert len(l_conv) >= 1
        self.dont_lora = False
        c0: nn.Conv2d = l_conv[0]

        # NOTE(review): unlike the Linear/FFN wrappers this builds a new layer
        # instead of deep-copying c0, so when the average is not loaded the
        # shared weights are the default initialisation -- presumably they are
        # loaded from a checkpoint later; confirm.
        self.shared = nn.Conv2d(
            c0.in_channels, c0.out_channels,
            kernel_size=c0.kernel_size, stride=c0.stride,
            padding=c0.padding, dilation=c0.dilation,
            groups=c0.groups, bias=(c0.bias is not None),
            padding_mode=c0.padding_mode,
        )
        if FOR_upcycle_ckpt_GEN_or_USE:
            avg_sd = _average_state_dict(l_conv)
            self.shared.load_state_dict(avg_sd)
        for p in self.shared.parameters():
            p.requires_grad = False

        self.tasks = l_task
        _l = []

        # Adapter ranks: computed from the weight differences and stored, or
        # read back from a previous run.
        if FOR_upcycle_ckpt_GEN_or_USE:
            w_diff_list = [c.weight.data - self.shared.weight.data for c in l_conv]
            if FORCE_SAME_RANK_ACROSS_TASKS:
                rank_conv = compute_adaptive_rank_for_conv_diffs(w_diff_list)
                global_.moduleName_2_adaRank[module_name] = rank_conv
            else:
                ranks_conv = compute_adaptive_rank_for_conv_diffs(w_diff_list, per_task=True)
                global_.moduleName_2_adaRank[module_name] = ranks_conv
        else:
            r_info = global_.moduleName_2_adaRank[module_name]
            if FORCE_SAME_RANK_ACROSS_TASKS:
                rank_conv = r_info
            else:
                ranks_conv = r_info

        # Too-small layers are not worth adapting; signal the caller and stop.
        kH, kW = self.shared.kernel_size
        min_dim = min(self.shared.out_channels, self.shared.in_channels * kH * kW )
        if min_dim < DONT_lora_if_dim_lt:
            self.dont_lora = True
            return

        for idx, c in enumerate(l_conv):
            if not FORCE_SAME_RANK_ACROSS_TASKS:
                rank_conv = ranks_conv[idx]
            frac = float(rank_conv) / min(self.shared.out_channels, self.shared.in_channels * kH * kW)
            if frac > DONT_lora_if_rankFrac_gt:
                # Rank is a large fraction of full rank: store the dense delta instead.
                lora = ResidualAdapterConv2dOnly(
                    in_channels=c.in_channels, out_channels=c.out_channels,
                    kernel_size=c.kernel_size, stride=c.stride,
                    padding=c.padding, dilation=c.dilation, groups=c.groups,
                    scaling=1.0, use_bias_delta=True,
                )
            else:
                lora = LoRAAdapterConv2dOnly(
                    in_channels=c.in_channels, out_channels=c.out_channels,
                    kernel_size=c.kernel_size, stride=c.stride,
                    padding=c.padding, dilation=c.dilation, groups=c.groups,
                    rank=rank_conv, dropout=0.0, scaling=1.0,
                )
            if FOR_upcycle_ckpt_GEN_or_USE:
                # Initialise the adapter from (task weights - shared weights).
                with torch.no_grad():
                    w_diff = c.weight.data - self.shared.weight.data
                    b_diff = (c.bias.data - self.shared.bias.data) if c.bias is not None and self.shared.bias is not None else None
                    lora.init_from_diff(w_diff, b_diff)
            _l.append(lora)
        self.task_lora = ModuleDict_W(_l, self.tasks)

        if FOR_upcycle_ckpt_GEN_or_USE:
            self.verify_approximation(orig_conv_list=l_conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the shared convolution plus the current task's adapter output."""
        y = self.shared(x)
        y = y + self.task_lora[global_.task](x)
        return y

    @torch.no_grad()
    def verify_approximation(self, spatial_hw=(32, 32), batch_size: int = 2, orig_conv_list: list = None):
        """Log, per task, how closely shared+adapter reproduces the original convolution.

        Feeds a random (batch_size, in_channels, H, W) input through the
        adapted path, the shared convolution alone and the task's original
        convolution, and logs the L2 distances. ``global_.task`` is temporarily
        overwritten and restored afterwards, also when an exception occurs.

        Args:
            spatial_hw: (H, W) of the random probe.
            batch_size: samples in the random probe.
            orig_conv_list: original per-task convolutions, ordered like ``self.tasks``.
        """
        assert orig_conv_list is not None, "orig_conv_list must be provided for verification"
        device = next(self.shared.parameters()).device
        dtype = next(self.shared.parameters()).dtype
        H, W = spatial_hw
        x = torch.randn(batch_size, self.shared.in_channels, H, W, device=device, dtype=dtype)
        old_task = getattr(global_, 'task', None)
        try:
            for i, t in enumerate(self.tasks):
                _log2(orig_conv_list[i], self.task_lora[t])
                global_.task = t
                y_lora = self.forward(x)
                y_avg = self.shared(x)
                y_orig = orig_conv_list[i](x)
                d_avg = torch.norm((y_avg - y_orig).float()).item()
                d_lora = torch.norm((y_lora - y_orig).float()).item()
                _log1(f"[Conv2d verify] task={t} rank={self.task_lora[t].rank} L2(avg,orig)={d_avg:.6f} L2(lora,orig)={d_lora:.6f}")
        finally:
            global_.task = old_task
|
|