from confs import *
import global_
import copy

import torch
import torch.nn as nn

from ldm.modules.attention import FeedForward, CrossAttention
from ldm.modules.diffusionmodules.openaimodel import UNetModel, ResBlock, TimestepEmbedSequential

# Scratch list for Conv2d parameter statistics (filled in elsewhere at runtime).
CONV2D_PARAM_STATS = []

def average_module_weight(src_modules: list):
    """Average the weights of multiple modules (similar to init_model.py)."""
    if not src_modules:
        return None
    # Zero-initialize an accumulator shaped like the first module's state_dict,
    # sum every module's tensors into it, then divide by the module count.
    avg_state_dict = {}
    first_state_dict = src_modules[0].state_dict()
    for key in first_state_dict:
        avg_state_dict[key] = torch.zeros_like(first_state_dict[key])
    for module in src_modules:
        module_state_dict = module.state_dict()
        for key in avg_state_dict:
            avg_state_dict[key] += module_state_dict[key]
    for key in avg_state_dict:
        avg_state_dict[key] /= len(src_modules)
    return avg_state_dict

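# Usage sketch (hypothetical modules; all entries must share one architecture):
#     experts = [nn.Linear(8, 8) for _ in range(3)]
#     merged = nn.Linear(8, 8)
#     merged.load_state_dict(average_module_weight(experts))
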
class ModuleDict_W(nn.Module):
    """nn.ModuleDict wrapper keyed by integer task IDs.

    nn.ModuleDict only accepts string keys, so integer IDs are normalized to
    strings internally; forward() routes to the branch of the currently
    active task (global_.task).
    """
    def __init__(self, modules: list, keys: list):
        super().__init__()
        assert len(keys) == len(modules), f"{len(keys)=} {len(modules)=}"
        self._keys = [int(k) for k in keys]
        self._moduleDict = nn.ModuleDict({str(int(k)): m for k, m in zip(self._keys, modules)})

    def __getitem__(self, k: int):
        _k = str(int(k))
        return self._moduleDict[_k]

    def keys(self):
        return list(self._keys)

    def forward(self, *args, **kwargs):
        cur_task = global_.task
        assert cur_task in self._keys, f"Current task {cur_task} not in available tasks {self._keys}"
        return self._moduleDict[str(int(cur_task))](*args, **kwargs)

    def offload_unused_tasks(self, unused_tasks, method: str):
        # Free ('del') or move to CPU ('cpu') the branches of inactive tasks.
        for i in unused_tasks:
            _k = str(int(i))
            if _k in self._moduleDict:
                if method == 'del':
                    del self._moduleDict[_k]
                elif method == 'cpu':
                    self._moduleDict[_k].to('cpu')
                else:
                    raise ValueError(f"unknown offload method: {method}")

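# Routing sketch (hypothetical branches/shapes): keys are integer task IDs and
# __call__ dispatches on the globally active task:
#     branches = ModuleDict_W([nn.Linear(8, 8), nn.Linear(8, 8)], keys=[1, 2])
#     global_.task = 2
#     y = branches(torch.randn(4, 8))   # runs the task-2 branch
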
class TaskSpecific_MoE(nn.Module):
    """Task-specific mixture of experts: one independent expert per task.

    Accepts either a template module (deep-copied once per task) or a list of
    pre-built modules, one per task; forward() runs the expert of the
    currently active task (global_.task).
    """
    def __init__(
        self,
        module: nn.Module,
        tasks: tuple,
    ):
        super().__init__()
        self.cur_task = None
        self.tasks = tasks
        if isinstance(module, nn.Module):
            modules = [copy.deepcopy(module) for _ in self.tasks]
        elif isinstance(module, list):
            assert len(module) == len(self.tasks), f"got {len(module)} and {len(self.tasks)}"
            modules = module
        else:
            raise ValueError(f"got {type(module)}")
        self.tasks_2_module = ModuleDict_W(modules, self.tasks)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        cur_task = global_.task
        assert cur_task in self.tasks, f"Current task {cur_task} not in available tasks {self.tasks}"
        return self.tasks_2_module[cur_task](*args, **kwargs)

    def set_task(self, task):
        assert 0, 'set_task is disabled for now; set global_.task instead'
        self.cur_task = task

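# Expert sketch (hypothetical template/shapes): every task receives its own
# deep copy of the template, so per-task training never touches the others:
#     moe = TaskSpecific_MoE(nn.Linear(8, 8), tasks=(0, 1, 2))
#     global_.task = 1
#     y = moe(torch.randn(4, 8))   # dispatches to the task-1 expert
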
def is_task_specific_(name: str):
    """Return True if a parameter name belongs to a task-specific branch."""
    is_task_specific = (
        ('._moduleDict.' in name) or
        ('tasks_2_module' in name) or
        ('task_ffn' in name) or
        ('task_proj' in name) or
        ('task_conv' in name) or
        ('task_gate_mlps' in name) or
        ('task_lora' in name) or
        # Task-specific parameters that live outside the MoE wrappers.
        ('encoder_clip_' in name) or
        ('proj_out_source__' in name) or
        ('ID_proj_out' in name) or
        ('landmark_proj_out' in name) or
        ('learnable_vector' in name)
    )
    return is_task_specific

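# Name-matching sketch (hypothetical parameter names):
#     is_task_specific_('blocks.3.tasks_2_module._moduleDict.2.weight')  # -> True
#     is_task_specific_('blocks.3.proj_in.weight')                       # -> False
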
def tp_param_need_sync(name: str, p: torch.nn.Parameter):
    """Return (need_sync, is_task_specific) for a named parameter.

    Task-specific branches and frozen/auxiliary models (first-stage VAE,
    face-ID model, CLIP face encoder) are excluded from cross-rank sync.
    """
    if is_task_specific_(name):
        return False, True
    if ('first_stage_model' in name or 'face_ID_model' in name
            or 'encoder_clip_face.tokenizer' in name or 'encoder_clip_face.model' in name):
        return False, False
    if not p.requires_grad:
        return False, False
    return True, False
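# Consumption sketch (assumes torch.distributed is initialized; the manual
# grad all-reduce stands in for the project's actual synchronization hook):
#     for name, p in model.named_parameters():
#         need_sync, _is_task_specific = tp_param_need_sync(name, p)
#         if need_sync and p.grad is not None:
#             torch.distributed.all_reduce(p.grad)
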
def offload_unused_tasks(parent: nn.Module, active_task: int, method: str):
    """Recursively free ('del') or move to CPU ('cpu') every task branch
    except the one belonging to active_task."""
    unused_tasks = [_t for _t in TASKS if _t != active_task]
    for name, child in parent.named_children():
        if child.__class__.__name__ in [
            'TaskSpecific_MoE',
            'FFN_TaskSpecific_Plus_Shared',
            'Linear_TaskSpecific_Plus_Shared',
            'Conv_TaskSpecific_Plus_Shared',
            'FFN_Shared_Plus_TaskLoRA',
            'Linear_Shared_Plus_TaskLoRA',
            'Conv_Shared_Plus_TaskLoRA',
        ]:
            for attr_name in [
                'tasks_2_module',
                'task_ffn', 'task_proj', 'task_conv',
                'task_lora_in', 'task_lora_out', 'task_lora',
            ]:
                if hasattr(child, attr_name):
                    ml = getattr(child, attr_name)
                    if isinstance(ml, nn.ModuleList):
                        # ModuleList branches are indexed directly by task ID.
                        for i in unused_tasks:
                            if method == 'del':
                                ml[i] = None
                            elif method == 'cpu':
                                ml[i].to('cpu')
                            else:
                                raise ValueError(f"unknown offload method: {method}")
                    elif isinstance(ml, ModuleDict_W):
                        ml.offload_unused_tasks(unused_tasks, method)
        else:
            offload_unused_tasks(child, active_task, method)
def offload_unused_tasks__LD(modelMOE, task_keep: int, method: str):
    # Thin wrapper for the top-level model ('LD' presumably abbreviates LatentDiffusion).
    offload_unused_tasks(modelMOE, task_keep, method)
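# Offload sketch (hypothetical model object): after building the multi-task
# model, keep only task 1 resident and park the other branches on CPU:
#     offload_unused_tasks__LD(model, task_keep=1, method='cpu')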