# UniBioTransfer / MoE.py
# Uploaded by scy639 to the Hugging Face Hub ("Upload folder using huggingface_hub", commit 2b534de, verified).
from confs import *
import global_
import torch,copy
import torch.nn as nn
from ldm.modules.attention import FeedForward,CrossAttention
from ldm.modules.diffusionmodules.openaimodel import UNetModel,ResBlock,TimestepEmbedSequential
# import torch.nn.functional as F
# ---------------- Configs ----------------
# Module-level mutable list. Nothing in this chunk reads or appends to it;
# presumably it is filled elsewhere with Conv2d parameter statistics (TODO confirm).
CONV2D_PARAM_STATS: list = list()
def average_module_weight(src_modules: list):
    """Return the element-wise mean of the ``state_dict``s of ``src_modules``.

    Averages the weights of multiple modules (similar to init_model.py). All
    modules must expose the same state-dict keys and tensor shapes (e.g. copies
    of one architecture).

    * Floating-point / complex entries are averaged. Accumulation happens in at
      least float32, so half-precision weights do not lose precision while
      summing. The result is cast back to each entry's original dtype.
    * Non-floating entries (integer or bool buffers such as BatchNorm's
      ``num_batches_tracked``) cannot be divided in place. The first module's
      value is copied instead.

    Args:
        src_modules: modules whose ``state_dict()``s are averaged.

    Returns:
        A new ``{key: tensor}`` dict, or ``None`` if ``src_modules`` is empty.
    """
    if not src_modules:
        return None
    state_dicts = [module.state_dict() for module in src_modules]
    count = len(state_dicts)
    avg_state_dict = {}
    for key, ref in state_dicts[0].items():
        if not (ref.is_floating_point() or ref.is_complex()):
            # `int_tensor /= count` raises in torch, so keep the first value verbatim.
            avg_state_dict[key] = ref.clone()
            continue
        # float16/bfloat16 -> float32; float32/float64/complex keep their own dtype.
        acc_dtype = torch.promote_types(ref.dtype, torch.float32)
        acc = torch.zeros_like(ref, dtype=acc_dtype)
        for state_dict in state_dicts:
            acc += state_dict[key].to(acc_dtype)
        avg_state_dict[key] = (acc / count).to(ref.dtype)
    return avg_state_dict
class ModuleDict_W(nn.Module):  # Wrapper of ModuleDict
    """Task-id -> module container whose ``forward`` dispatches on ``global_.task``.

    ``nn.ModuleDict`` only accepts string keys, so task ids are stored
    internally as ``str(int(k))``. ``keys()`` reports them as ints.

    Args:
        modules: one module per task.
        keys: task ids (anything accepted by ``int()``), parallel to ``modules``.
    """

    def __init__(self, modules: list, keys: list):
        super().__init__()
        assert len(keys) == len(modules), f"{len(keys)=} {len(modules)=}"
        self._keys = [int(k) for k in keys]
        self._moduleDict = nn.ModuleDict({str(k): m for k, m in zip(self._keys, modules)})

    def __getitem__(self, k: int):
        """Return the module registered for task ``k`` (KeyError if absent)."""
        return self._moduleDict[str(int(k))]

    def keys(self):
        """Return a copy of the task ids that still have a module."""
        return list(self._keys)

    def forward(self, *args, **kwargs):
        """Run the module of the currently active task ``global_.task``."""
        cur_task = global_.task
        assert cur_task in self._keys, f"Current task {cur_task} not in available tasks {self._keys}"
        return self._moduleDict[str(int(cur_task))](*args, **kwargs)

    def offload_unused_tasks(self, unused_tasks, method: str):
        """Free the modules of ``unused_tasks``.

        Args:
            unused_tasks: task ids to offload. Ids without a module are skipped.
            method: ``'del'`` removes the module and drops the id from
                ``keys()``, so ``forward`` reports a clear assertion instead of a
                KeyError. ``'cpu'`` moves the module to CPU in place.

        Raises:
            ValueError: if ``method`` is neither ``'del'`` nor ``'cpu'``.
        """
        # Validate up front so a bad method is reported even when no key matches.
        if method not in ('del', 'cpu'):
            raise ValueError(f"unknown offload method {method!r}; expected 'del' or 'cpu'")
        for task in unused_tasks:
            task_id = int(task)
            key = str(task_id)
            if key not in self._moduleDict:
                continue
            if method == 'del':
                del self._moduleDict[key]
                # Keep _keys in sync so keys()/forward no longer advertise the task.
                self._keys = [k for k in self._keys if k != task_id]
            else:
                self._moduleDict[key].to('cpu')
class TaskSpecific_MoE(nn.Module):
    """Hard-routed per-task module: one independent copy of a module per task.

    ``forward`` dispatches on the process-wide ``global_.task``. There is no
    learned gating.

    Args:
        module: either a single ``nn.Module``, deep-copied once per task so each
            task gets its own parameters, or a list/tuple with exactly one
            module per task, used as-is.
        tasks: task ids, in the same order as a list/tuple ``module``.

    Raises:
        ValueError: if ``module`` has an unsupported type, or a list/tuple
            ``module`` does not have one entry per task.
    """

    def __init__(
        self,
        module: nn.Module,  # or list/tuple of Module
        tasks: tuple,
    ):
        super().__init__()
        self.cur_task = None  # only kept for attribute compatibility; forward() ignores it
        self.tasks = tasks
        if isinstance(module, nn.Module):
            modules = [copy.deepcopy(module) for _ in self.tasks]
        elif isinstance(module, (list, tuple)):
            # Raise instead of assert so the check survives `python -O`,
            # matching the ValueError raised for a wrong type below.
            if len(module) != len(self.tasks):
                raise ValueError(f"got {len(module)} and {len(self.tasks)}")
            modules = list(module)
        else:
            raise ValueError(f"got {type(module)}")
        self.tasks_2_module = ModuleDict_W(modules, self.tasks)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Run the sub-module of the currently active task ``global_.task``."""
        cur_task = global_.task
        assert cur_task in self.tasks, f"Current task {cur_task} not in available tasks {self.tasks}"
        return self.tasks_2_module[cur_task](*args, **kwargs)

    def set_task(self, task):
        """Disabled: routing is driven by ``global_.task``; set that instead.

        Raises:
            RuntimeError: always. An explicit raise (rather than ``assert 0``)
                keeps the guard active under ``python -O``.
        """
        raise RuntimeError('set_task is disabled for now; update to gg.task instead')
def is_task_specific_(name: str):
    """Return True if the parameter/module path ``name`` belongs to a per-task branch.

    This is a plain substring test: ``name`` is task-specific as soon as it
    contains any one of the markers below (case-sensitive).
    """
    task_markers = (
        '._moduleDict.',
        'tasks_2_module',
        'task_ffn',
        'task_proj',
        'task_conv',
        'task_gate_mlps',
        'task_lora',
        'encoder_clip_',
        'proj_out_source__',
        'ID_proj_out',
        'landmark_proj_out',
        'learnable_vector',
    )
    for marker in task_markers:
        if marker in name:
            return True
    return False
def tp_param_need_sync(name: str, p: torch.nn.Parameter):
    """Classify a named parameter; returns ``(need_sync, is_task_specific)``.

    * Task-specific parameters (per ``is_task_specific_``) -> ``(False, True)``.
    * Parameters under ``first_stage_model`` / ``face_ID_model`` / the CLIP face
      tokenizer or model -> ``(False, False)``, without reading ``requires_grad``.
    * Any other parameter with ``requires_grad`` False -> ``(False, False)``.
    * Every remaining (shared, trainable) parameter -> ``(True, False)``.
    """
    if is_task_specific_(name):
        return False, True
    # NOTE(review): every name containing 'encoder_clip_face.' also contains
    # 'encoder_clip_', which is_task_specific_ already matched above. So the two
    # encoder_clip_face.* tags below can never fire here. Confirm whether
    # those weights were meant to be classified as (False, False) instead.
    excluded_tags = (
        'first_stage_model',
        'face_ID_model',
        'encoder_clip_face.tokenizer',
        'encoder_clip_face.model',
    )
    excluded = any(tag in name for tag in excluded_tags)
    # Short-circuit keeps the original order: requires_grad is read only when not excluded.
    need_sync = (not excluded) and bool(p.requires_grad)
    return need_sync, False
# Container classes are matched by class name (not isinstance), so classes
# defined outside this module are recognised too.
_OFFLOAD_CONTAINER_CLASS_NAMES = frozenset({
    'TaskSpecific_MoE',
    'FFN_TaskSpecific_Plus_Shared',
    'Linear_TaskSpecific_Plus_Shared',
    'Conv_TaskSpecific_Plus_Shared',
    'FFN_Shared_Plus_TaskLoRA',
    'Linear_Shared_Plus_TaskLoRA',
    'Conv_Shared_Plus_TaskLoRA',
})
# Attributes on those containers that may hold per-task sub-modules.
_OFFLOAD_TASK_ATTRS = (
    'tasks_2_module',
    'task_ffn', 'task_proj', 'task_conv',
    'task_lora_in', 'task_lora_out', 'task_lora',
)
_OFFLOAD_METHODS = ('del', 'cpu')


def offload_unused_tasks(parent: nn.Module, active_task: int, method: str):
    """Delete or move to CPU every per-task sub-module except ``active_task``'s.

    Walks the descendants of ``parent``. When a child's class name is in
    ``_OFFLOAD_CONTAINER_CLASS_NAMES``, its per-task attributes
    (``_OFFLOAD_TASK_ATTRS``) are processed and the walk does not descend into
    that child. Every other child is searched recursively. ``parent`` itself is
    never treated as a container.

    Args:
        parent: root module to search.
        active_task: task id to keep. Every other id in ``TASKS`` (presumably
            from ``confs``) is offloaded.
        method: ``'del'`` drops the sub-modules (``nn.ModuleList`` slots are set
            to ``None``; ``ModuleDict_W`` entries are removed). ``'cpu'`` moves
            them to CPU in place.

    Raises:
        ValueError: if ``method`` is not ``'del'`` or ``'cpu'``. This is a
            subclass of the ``Exception`` raised before.
    """
    # Validate once, up front, so a bad method is reported even if nothing matches.
    if method not in _OFFLOAD_METHODS:
        raise ValueError(f"unknown offload method {method!r}; expected one of {_OFFLOAD_METHODS}")
    unused_tasks = [t for t in TASKS if t != active_task]
    _offload_unused_tasks_in(parent, unused_tasks, method)


def _offload_unused_tasks_in(parent: nn.Module, unused_tasks: list, method: str):
    """Recursive worker for offload_unused_tasks; ``method`` is already validated."""
    for child in parent.children():
        if type(child).__name__ not in _OFFLOAD_CONTAINER_CLASS_NAMES:
            _offload_unused_tasks_in(child, unused_tasks, method)
            continue
        for attr_name in _OFFLOAD_TASK_ATTRS:
            container = getattr(child, attr_name, None)
            if isinstance(container, nn.ModuleList):
                # NOTE(review): the ModuleList is indexed by task id, which
                # assumes task ids are 0-based positions in the list — confirm.
                for task in unused_tasks:
                    if method == 'del':
                        container[task] = None
                    elif container[task] is not None:  # slot already deleted earlier
                        container[task].to('cpu')
            elif isinstance(container, ModuleDict_W):
                container.offload_unused_tasks(unused_tasks, method)
def offload_unused_tasks__LD(modelMOE, task_keep: int, method: str, ):
    """Release every task branch of ``modelMOE`` except ``task_keep``'s.

    Thin entry point that forwards to ``offload_unused_tasks``. The original
    author's intent is to save CUDA memory.

    Args:
        modelMOE: root model to walk.
        task_keep: task id whose branches stay in place.
        method: ``'del'`` to drop the inactive branches, ``'cpu'`` to move them to CPU.
    """
    offload_unused_tasks(parent=modelMOE, active_task=task_keep, method=method)