File size: 5,960 Bytes
2b534de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | from confs import *
import global_
import torch,copy
import torch.nn as nn
from ldm.modules.attention import FeedForward,CrossAttention
from ldm.modules.diffusionmodules.openaimodel import UNetModel,ResBlock,TimestepEmbedSequential
# import torch.nn.functional as F
# ---------------- Configs ----------------
# Module-level list, starts empty. Judging by the name it collects Conv2d
# parameter statistics; where it is populated/read is not visible here — confirm.
CONV2D_PARAM_STATS = list()
def average_module_weight(src_modules: list):
    """Average the state dicts of several structurally identical modules.

    Similar in spirit to the weight averaging in init_model.py.

    Args:
        src_modules: modules whose ``state_dict()`` share the same keys and
            tensor shapes (e.g. deep copies of one module).

    Returns:
        A new ``{key: tensor}`` dict (suitable for ``load_state_dict``), or
        ``None`` if ``src_modules`` is empty.

    Floating-point and complex entries are the element-wise mean across the
    modules. Entries of any other dtype (integer/bool buffers such as
    BatchNorm's ``num_batches_tracked``) cannot be divided in place, so they
    are copied from the first module instead.
    """
    if not src_modules:
        return None
    state_dicts = [module.state_dict() for module in src_modules]
    n = len(state_dicts)
    avg_state_dict = {}
    for key, ref in state_dicts[0].items():
        if not (ref.is_floating_point() or ref.is_complex()):
            # In-place true division on an integer/bool tensor raises a
            # RuntimeError, so keep the first module's value unchanged.
            avg_state_dict[key] = ref.clone()
            continue
        acc = torch.zeros_like(ref)
        for sd in state_dicts:  # same summation order as module order
            acc += sd[key]
        acc /= n
        avg_state_dict[key] = acc
    return avg_state_dict
class ModuleDict_W(nn.Module):  # Wrapper of ModuleDict
    """Wrapper around ``nn.ModuleDict`` keyed by integer task ids.

    ``nn.ModuleDict`` only accepts string keys, so ids are stored internally
    as ``str(int(k))`` while the public API takes and returns ints.
    ``forward`` dispatches to the module registered for ``global_.task``.
    """

    def __init__(self, modules: list, keys: list):
        super().__init__()
        assert len(keys) == len(modules), f"{len(keys)=} {len(modules)=}"
        self._keys = [int(k) for k in keys]
        self._moduleDict = nn.ModuleDict({str(int(k)): m for k, m in zip(self._keys, modules)})

    def __getitem__(self, k: int):
        """Return the module registered for task ``k`` (KeyError if absent)."""
        return self._moduleDict[str(int(k))]

    def keys(self):
        """Return a copy of the task ids that currently have a module."""
        return list(self._keys)

    def forward(self, *args, **kwargs):
        """Run the module of the current task (``global_.task``)."""
        cur_task = global_.task
        assert cur_task in self._keys, f"Current task {cur_task} not in available tasks {self._keys}"
        return self._moduleDict[str(int(cur_task))](*args, **kwargs)

    def offload_unused_tasks(self, unused_tasks, method: str):
        """Free or offload the modules of ``unused_tasks``.

        Args:
            unused_tasks: iterable of task ids; ids without a module are skipped.
            method: ``'del'`` removes the module and drops its id from
                ``keys()``; ``'cpu'`` moves it to CPU but keeps it registered.

        Raises:
            ValueError: if ``method`` is neither ``'del'`` nor ``'cpu'``.
        """
        # Validate before touching anything so a bad method is always reported,
        # even when none of the listed tasks are present.
        if method not in ('del', 'cpu'):
            raise ValueError(f"method must be 'del' or 'cpu', got {method!r}")
        for i in unused_tasks:
            _k = str(int(i))
            if _k not in self._moduleDict:
                continue
            if method == 'del':
                del self._moduleDict[_k]
                # Keep keys() and forward()'s membership check in sync with
                # what is actually registered.
                self._keys = [k for k in self._keys if k != int(i)]
            else:
                self._moduleDict[_k].to('cpu')
class TaskSpecific_MoE(nn.Module):
    """Hold one independent copy of a module per task and route by task id.

    Args:
        module: a single ``nn.Module`` (deep-copied once per task) or a list
            of modules, one per entry of ``tasks``.
        tasks: the task ids served by this layer.

    ``forward`` dispatches to the copy registered for ``global_.task``.
    """

    def __init__(
        self,
        module: nn.Module,  # or list of Module
        tasks: tuple,
    ):
        super().__init__()
        self.cur_task = None  # not read by forward(); routing uses global_.task
        self.tasks = tasks
        if isinstance(module, nn.Module):
            # An nn.ModuleList is itself an nn.Module, so it is deep-copied as
            # a whole here rather than treated as a per-task list.
            modules = [copy.deepcopy(module) for _ in self.tasks]
        elif isinstance(module, list):
            assert len(module) == len(self.tasks), f"got {len(module)} and {len(self.tasks)}"
            modules = module
        else:
            raise ValueError(f"got {type(module)}")
        self.tasks_2_module = ModuleDict_W(modules, self.tasks)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Run the copy belonging to the current task (``global_.task``)."""
        cur_task = global_.task
        assert cur_task in self.tasks, f"Current task {cur_task} not in available tasks {self.tasks}"
        return self.tasks_2_module[cur_task](*args, **kwargs)

    def set_task(self, task):
        """Disabled: routing is driven by ``global_.task``.

        Raises:
            AssertionError: always. Raised explicitly rather than via
                ``assert`` so it still fires under ``python -O``, where an
                ``assert`` is stripped and the method would silently succeed.
        """
        raise AssertionError('set_task is disabled for now; update to gg.task instead')
def is_task_specific_(name:str):
    """Return True if parameter ``name`` belongs to a per-task (non-shared) branch.

    The decision is a plain substring test: ``name`` is task-specific as soon
    as it contains any of the attribute-name fragments listed below.
    """
    markers = (
        '._moduleDict.',
        'tasks_2_module',
        'task_ffn',
        'task_proj',
        'task_conv',
        'task_gate_mlps',
        'task_lora',
        'encoder_clip_',
        'proj_out_source__',
        'ID_proj_out',
        'landmark_proj_out',
        'learnable_vector',
    )
    return any(marker in name for marker in markers)
def tp_param_need_sync(name: str, p: torch.nn.Parameter):
    """Classify parameter ``name`` as ``(need_sync, task_specific)``.

    Returns:
        * ``(False, True)``  when ``is_task_specific_(name)`` is true;
        * ``(False, False)`` when the name contains one of the excluded
          fragments below, or ``p.requires_grad`` is false;
        * ``(True, False)``  otherwise (trainable, shared parameter).

    NOTE(review): ``is_task_specific_`` already matches any name containing
    ``'encoder_clip_'``, so the two ``encoder_clip_face.*`` exclusions can never
    be reached — such parameters return ``(False, True)``. Confirm intended.
    """
    if is_task_specific_(name):
        return False, True
    excluded = (
        'first_stage_model',
        'face_ID_model',
        'encoder_clip_face.tokenizer',
        'encoder_clip_face.model',
    )
    if any(fragment in name for fragment in excluded) or not p.requires_grad:
        return False, False
    return True, False
def _offload_task_container(container: nn.Module, unused_tasks: list, method: str):
    """Apply ``method`` ('del' or 'cpu') to the inactive-task entries of one container."""
    for attr_name in (
        'tasks_2_module',
        'task_ffn', 'task_proj', 'task_conv',
        'task_lora_in', 'task_lora_out', 'task_lora',
    ):
        if not hasattr(container, attr_name):
            continue
        ml = getattr(container, attr_name)
        if isinstance(ml, nn.ModuleList):
            # NOTE(review): the ModuleList is indexed by task id, i.e. this
            # assumes task ids are valid positions in the list — confirm.
            for i in unused_tasks:
                if method == 'del':
                    ml[i] = None
                elif ml[i] is not None:  # a deleted slot has nothing to move
                    ml[i].to('cpu')
        elif isinstance(ml, ModuleDict_W):
            ml.offload_unused_tasks(unused_tasks, method)


def offload_unused_tasks(parent: nn.Module, active_task: int, method: str, ):
    """Recursively free or offload per-task branches for every task except one.

    Walks ``parent``'s direct children. A child whose class name is one of the
    known task-specific containers has its per-task sub-modules handled for
    every task in ``TASKS`` other than ``active_task``. That container is not
    descended into further. Every other child is recursed into.

    Args:
        parent: root module to walk.
        active_task: the task whose branches are kept.
        method: ``'del'`` to drop the branches, ``'cpu'`` to move them to CPU.

    Raises:
        ValueError: if ``method`` is neither ``'del'`` nor ``'cpu'``.
    """
    # Validate before mutating anything, so a bad method never leaves the
    # model partially offloaded.
    if method not in ('del', 'cpu'):
        raise ValueError(f"method must be 'del' or 'cpu', got {method!r}")
    task_specific_classes = (
        'TaskSpecific_MoE',
        'FFN_TaskSpecific_Plus_Shared',
        'Linear_TaskSpecific_Plus_Shared',
        'Conv_TaskSpecific_Plus_Shared',
        'FFN_Shared_Plus_TaskLoRA',
        'Linear_Shared_Plus_TaskLoRA',
        'Conv_Shared_Plus_TaskLoRA',
    )
    unused_tasks = [_t for _t in TASKS if _t != active_task]  # inactive tasks
    for _name, child in parent.named_children():
        if child.__class__.__name__ in task_specific_classes:
            _offload_task_container(child, unused_tasks, method)
        else:
            offload_unused_tasks(child, active_task, method)
def offload_unused_tasks__LD(modelMOE, task_keep: int, method: str, ):
    """Save CUDA memory by dropping every task branch of ``modelMOE`` but one.

    Thin entry point over ``offload_unused_tasks``: branches of all tasks other
    than ``task_keep`` are deleted (``method='del'``) or moved to CPU
    (``method='cpu'``).
    """
    offload_unused_tasks(parent=modelMOE, active_task=task_keep, method=method)
|