File size: 5,960 Bytes
2b534de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from confs import *
import global_
import torch,copy
import torch.nn as nn
from ldm.modules.attention import FeedForward,CrossAttention
from ldm.modules.diffusionmodules.openaimodel import UNetModel,ResBlock,TimestepEmbedSequential
# import torch.nn.functional as F

# ---------------- Configs ----------------
# Module-level accumulator; unused within this chunk — presumably populated
# elsewhere (e.g. while collecting Conv2d parameter statistics). TODO confirm.
CONV2D_PARAM_STATS = []

def average_module_weight(src_modules: list):
    """Average the parameters/buffers of several identically-shaped modules.

    (Similar to init_model.py.)

    Args:
        src_modules: list of ``nn.Module`` instances sharing the same
            ``state_dict`` keys and tensor shapes.

    Returns:
        A state-dict (key -> tensor) holding the element-wise mean over all
        modules, or ``None`` when ``src_modules`` is empty.
    """
    if not src_modules:
        return None
    n = len(src_modules)
    # Hoist: call state_dict() once per module instead of once per key.
    state_dicts = [m.state_dict() for m in src_modules]
    avg_state_dict = {}
    for key, first in state_dicts[0].items():
        acc = torch.zeros_like(first)
        for sd in state_dicts:
            acc += sd[key]
        # In-place true division (`/=`) on integer tensors (e.g. BatchNorm's
        # num_batches_tracked buffer) raises in modern PyTorch; preserve the
        # original dtype by using floor division for integral tensors.
        if acc.is_floating_point() or acc.is_complex():
            avg_state_dict[key] = acc / n
        else:
            avg_state_dict[key] = torch.div(acc, n, rounding_mode='floor')
    return avg_state_dict

class ModuleDict_W(nn.Module):  # Wrapper of ModuleDict
    """Wrapper around ``nn.ModuleDict`` keyed by integer task ids.

    ``nn.ModuleDict`` requires string keys; this wrapper accepts int keys and
    dispatches ``forward`` to the module of the currently active task, which
    is read from the process-global ``global_.task``.
    """

    def __init__(self, modules: list, keys: list):
        super().__init__()
        assert len(keys) == len(modules), f"{len(keys)=} {len(modules)=}"
        self._keys = [int(k) for k in keys]
        self._moduleDict = nn.ModuleDict({str(int(k)): m for k, m in zip(self._keys, modules)})

    def __getitem__(self, k: int):
        _k = str(int(k))
        return self._moduleDict[_k]

    def keys(self):
        return list(self._keys)

    def forward(self, *args, **kwargs):
        cur_task = global_.task  # active task id is set elsewhere on the global module
        assert cur_task in self._keys, f"Current task {cur_task} not in available tasks {self._keys}"
        return self._moduleDict[str(int(cur_task))](*args, **kwargs)

    def offload_unused_tasks(self, unused_tasks, method: str):
        """Free memory held by inactive tasks (method: 'del' | 'cpu')."""
        for i in unused_tasks:
            _k = str(int(i))
            if _k not in self._moduleDict:
                continue
            if method == 'del':
                # self._moduleDict[_k] = None # should behave the same either way
                del self._moduleDict[_k]
            elif method == 'cpu':
                self._moduleDict[_k].to('cpu')
            else:
                # A bare `raise` here had no active exception and would itself
                # fail with a confusing RuntimeError; name the failure instead.
                raise ValueError(f"unknown offload method {method!r}; expected 'del' or 'cpu'")

class TaskSpecific_MoE(nn.Module):
    """Mixture of per-task expert modules: one independent copy per task.

    Args:
        module: a template ``nn.Module`` (deep-copied once per task) or a
            list of pre-built modules, one per entry in ``tasks``.
        tasks: tuple of integer task ids.
    """

    def __init__(
        self,
        module: nn.Module,  # or list of Module
        tasks: tuple,
    ):
        super().__init__()
        self.cur_task = None  # kept for backward compatibility; dispatch uses global_.task
        self.tasks = tasks
        if isinstance(module, nn.Module):
            modules = [copy.deepcopy(module) for _ in self.tasks]
        elif isinstance(module, list):
            assert len(module) == len(self.tasks), f"got {len(module)} and {len(self.tasks)}"
            modules = module
        else:
            raise ValueError(f"got {type(module)}")
        self.tasks_2_module = ModuleDict_W(modules, self.tasks)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Run the expert belonging to the currently active task."""
        cur_task = global_.task  # active task id is a process-global
        assert cur_task in self.tasks, f"Current task {cur_task} not in available tasks {self.tasks}"
        return self.tasks_2_module[cur_task](*args, **kwargs)

    def set_task(self, task):
        # Previously disabled via `assert 0`, but asserts are stripped under
        # `python -O`, which would have let this silently mutate state.
        # Raise explicitly so the method stays disabled in all modes.
        raise NotImplementedError('set_task is disabled for now; update to gg.task instead')

def is_task_specific_(name: str):
    """Return True when ``name`` refers to a per-task (non-shared) parameter."""
    markers = (
        '._moduleDict.',
        'tasks_2_module',
        'task_ffn',
        'task_proj',
        'task_conv',
        'task_gate_mlps',
        'task_lora',
        'encoder_clip_',
        'proj_out_source__',
        'ID_proj_out',
        'landmark_proj_out',
        'learnable_vector',
    )
    return any(marker in name for marker in markers)
def tp_param_need_sync(name: str, p: torch.nn.Parameter):
    """Classify a named parameter for cross-process synchronization.

    Returns:
        ``(need_sync, is_task_specific)``: task-specific parameters are never
        synced; frozen submodules and parameters without gradients are skipped.
    """
    if is_task_specific_(name):
        return False, True
    frozen_markers = (
        'first_stage_model',
        'face_ID_model',
        'encoder_clip_face.tokenizer',
        'encoder_clip_face.model',
    )
    if any(marker in name for marker in frozen_markers):
        return False, False
    if not p.requires_grad:
        return False, False
    return True, False
def offload_unused_tasks(parent: nn.Module, active_task: int, method: str, ):
    """Recursively free memory held by inactive-task submodules.

    Walks the child tree of ``parent``; inside any task-specific container
    (MoE / task-specific / task-LoRA wrappers, matched by class name) the
    per-task module collections are pruned (``'del'``) or moved to CPU
    (``'cpu'``) for every task except ``active_task``.

    Args:
        parent: root module to walk.
        active_task: task id whose parameters stay resident.
        method: ``'del'`` to drop modules, ``'cpu'`` to offload them.
    """
    unused_tasks = [_t for _t in TASKS if _t != active_task]  # inactive tasks
    # Matching by class *name* (not isinstance) avoids importing every wrapper type.
    task_container_names = {
        'TaskSpecific_MoE',
        'FFN_TaskSpecific_Plus_Shared',
        'Linear_TaskSpecific_Plus_Shared',
        'Conv_TaskSpecific_Plus_Shared',
        'FFN_Shared_Plus_TaskLoRA',
        'Linear_Shared_Plus_TaskLoRA',
        'Conv_Shared_Plus_TaskLoRA',
    }
    per_task_attrs = (  # normalize attribute handling to avoid repetition
        'tasks_2_module',
        'task_ffn', 'task_proj', 'task_conv',
        'task_lora_in', 'task_lora_out', 'task_lora',
    )
    for name, child in parent.named_children():
        # (the old `hasattr(child, '__class__')` guard was always true)
        if child.__class__.__name__ in task_container_names:
            for attr_name in per_task_attrs:
                if not hasattr(child, attr_name):
                    continue
                ml = getattr(child, attr_name)
                if isinstance(ml, nn.ModuleList):
                    # NOTE(review): indexing ml[i] by task id assumes the
                    # ModuleList is ordered/indexed by task id — confirm.
                    for i in unused_tasks:  # move or delete parameters for inactive tasks
                        if method == 'del':
                            ml[i] = None
                        elif method == 'cpu':
                            ml[i].to('cpu')
                        else:
                            # was a message-less `raise Exception`
                            raise ValueError(f"unknown offload method {method!r}; expected 'del' or 'cpu'")
                elif isinstance(ml, ModuleDict_W):
                    ml.offload_unused_tasks(unused_tasks, method)
            # deliberately no recursion into task containers
        else:
            offload_unused_tasks(child, active_task, method)
def offload_unused_tasks__LD(modelMOE, task_keep: int, method: str, ):
    """Remove or offload inactive task-related parameters to save CUDA memory.

    Thin wrapper over :func:`offload_unused_tasks` (method: ``'del'`` | ``'cpu'``).
    """
    offload_unused_tasks(modelMOE, task_keep, method)