File size: 7,431 Bytes
2b534de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import sys,os
# Absolute path of the directory that contains this file.
cur_dir = os.path.dirname(os.path.abspath(__file__))
# When executed directly as a script, append the parent directory to the
# module search path (presumably so the project-local modules imported below
# can be resolved from there -- confirm against the repository layout).
if __name__ == '__main__':
    sys.path.append(os.path.abspath(os.path.join(cur_dir, os.pardir)))

from confs import *
import json
import argparse, os, sys, glob
import cv2
import torch
import numpy as np
from MoE import *
from multiTask_model import *
from lora_layers import *
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A
import time
import copy
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import torchvision
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.bank import Bank
from ldm.util import instantiate_from_config

from ldm.models.diffusion.ddim import DDIMSampler

from transformers import AutoFeatureExtractor

# import clip
from torchvision.transforms import Resize
from fnmatch import fnmatch


from PIL import Image
from torchvision.transforms import PILToTensor
#----------------------------------------------------------------------------


def get_moe():
    """Build the LatentDiffusion model and wrap it into its task-specific MoE layout.

    The function, in order:

    1. seeds all RNGs with 42 and instantiates the model described by the
       ``model`` section of ``LatentDiffusion.yaml``;
    2. moves the model to the CPU;
    3. loads ``PRETRAIN_JSON_PATH`` into ``global_.moduleName_2_adaRank``;
    4. calls ``replace_modules_lossless`` on ``model.diffusion_model`` with
       four independent deep copies as sources and task ids ``[0, 1, 2, 3]``;
    5. wraps ``ID_proj_out`` / ``landmark_proj_out`` (task ids ``[0, 2, 3]``)
       and ``proj_out_source__head`` (task ids ``[2, 3]``) in
       ``TaskSpecific_MoE``;
    6. when ``REFNET.ENABLE`` is set, applies ``replace_modules_lossless`` to
       ``model.diffusion_model_refNet`` as well and attaches a ``Bank`` that
       reads from the UNet and writes from the refNet;
    7. sleeps ``20 * rank_`` seconds before returning.

    No checkpoint weights are loaded here; only the module structure is built.

    Returns:
        The restructured ``LatentDiffusion`` instance (on CPU).

    Raises:
        NotImplementedError: if ``FOR_upcycle_ckpt_GEN_or_USE`` is truthy.
            That configuration never assigned ``modelMOE`` and therefore
            always ended in an ``UnboundLocalError``; it now fails up front
            with an explicit message instead.
    """
    if FOR_upcycle_ckpt_GEN_or_USE:
        # The former truthy branch only executed `del model.ptsM_Generator`
        # and then fell through to code that reads `modelMOE`, which is bound
        # exclusively on the falsy path. Fail before the expensive model
        # build rather than after it.
        raise NotImplementedError(
            "get_moe() does not build modelMOE when "
            "FOR_upcycle_ckpt_GEN_or_USE is truthy; only the falsy (USE) "
            "path is implemented here."
        )

    seed_everything(42)
    # torch.cuda.set_device(opt.device_ID)
    model: LatentDiffusion = instantiate_from_config(OmegaConf.load("LatentDiffusion.yaml").model)
    if REFNET.ENABLE:
        assert model.model.diffusion_model_refNet.is_refNet

    # The model is always built on the CPU: the original CUDA-availability
    # probe was unconditionally overwritten by a CPU device, so it is dropped.
    model = model.to(torch.device("cpu"))

    # NOTE(review): the three helpers below are not called anywhere in this
    # function as visible; presumably they belong to the checkpoint-generation
    # path -- confirm before removing them.
    def average_module_weight(
        src_modules: list,
    ):
        """Average the ``state_dict()`` entries of ``src_modules`` key by key.

        Keys are taken from the first module. Returns ``None`` for an empty
        list, otherwise a dict mapping each key to the averaged tensor.
        Integer-typed entries are averaged with floor division, because an
        in-place true division cannot store its float result in an integer
        tensor and raises ``RuntimeError``.
        """
        if not src_modules:
            return None
        count = len(src_modules)
        # Zero-initialised accumulators shaped/typed like the first module's entries.
        avg_state_dict = {
            key: torch.zeros_like(value)
            for key, value in src_modules[0].state_dict().items()
        }
        # Sum
        for module in src_modules:
            module_state_dict = module.state_dict()
            for key, total in avg_state_dict.items():
                total += module_state_dict[key]
        # Average
        for key, total in avg_state_dict.items():
            if total.is_floating_point() or total.is_complex():
                total /= count
            else:
                # Keep the integer dtype of the entry.
                avg_state_dict[key] = torch.div(total, count, rounding_mode="floor")
        return avg_state_dict

    def recursive_average_module_weight(
        tgt_module: nn.Module,
        src_modules: list,
        cb,
    ):
        """Walk ``tgt_module`` and load averaged source weights where ``cb`` says so.

        For every direct child, the same-named children of all ``src_modules``
        are collected. If ``cb(child, name, tgt_module)`` is truthy the child
        receives the average of those source children; otherwise the walk
        recurses into the child. Returns ``tgt_module``.
        """
        for name, child in tgt_module.named_children():
            # Get corresponding modules from source models
            src_child_modules = []
            for src_module in src_modules:
                src_child = getattr(src_module, name)
                assert src_child is not None, name
                src_child_modules.append(src_child)
            # assert not isinstance(child, TaskSpecific_MoE)
            if cb(child, name, tgt_module):
                print(f"[recursive_average_module_weight] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
                # Average & load
                avg_weights = average_module_weight(src_child_modules)
                child.load_state_dict(avg_weights)
            else:
                recursive_average_module_weight(child, src_child_modules, cb)
        return tgt_module

    def replace_module_with_TaskSpecific(
        tgt_module: nn.Module,  # tgt module
        src_modules: list,
        cb,
        parent_name: str = "",
        depth: int = 0,
    ):
        """Replace selected children of ``tgt_module`` with ``TaskSpecific_MoE``.

        For every direct child, the same-named children of all ``src_modules``
        are collected. If ``cb(child, name, full_name, tgt_module)`` is truthy
        the child is replaced by ``TaskSpecific_MoE(src_children, TASKS)``;
        otherwise the walk recurses, but only while ``depth <= 0`` (i.e. it
        inspects at most two levels below the initial module). Returns
        ``tgt_module``.
        """
        for name, child in tgt_module.named_children():
            # Get corresponding modules from source models
            src_child_modules = []
            for src_module in src_modules:
                src_child = getattr(src_module, name)
                assert src_child is not None, name
                src_child_modules.append(src_child)
            assert not isinstance(child, TaskSpecific_MoE)
            full_name = f"{parent_name}.{name}"
            if cb(child, name, full_name, tgt_module):
                print(f"[replace_module_with_TaskSpecific] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
                setattr(tgt_module, name, TaskSpecific_MoE(src_child_modules, TASKS))
            elif depth <= 0:
                replace_module_with_TaskSpecific(child, src_child_modules, cb, parent_name=full_name, depth=depth + 1)
        return tgt_module

    modelMOE: LatentDiffusion = model
    del model

    with open(PRETRAIN_JSON_PATH, 'r') as f:
        global_.moduleName_2_adaRank = json.load(f)
    print(f"loaded from {PRETRAIN_JSON_PATH=}")

    # Ensure distinct module instances per task (avoid shared identities):
    # every source handed to replace_modules_lossless is its own deep copy.
    unet_sources = [copy.deepcopy(modelMOE.model.diffusion_model) for _ in range(4)]
    replace_modules_lossless(
        modelMOE.model.diffusion_model,
        unet_sources,
        [0, 1, 2, 3],
        parent_name=".model.diffusion_model",
    )

    # Build-time dummy wrapping for task-specific heads so that ckpt keys match
    modelMOE.ID_proj_out = TaskSpecific_MoE(
        [copy.deepcopy(modelMOE.ID_proj_out) for _ in range(3)],
        [0, 2, 3],
    )
    modelMOE.landmark_proj_out = TaskSpecific_MoE(
        [copy.deepcopy(modelMOE.landmark_proj_out) for _ in range(3)],
        [0, 2, 3],
    )
    modelMOE.proj_out_source__head = TaskSpecific_MoE(
        [copy.deepcopy(modelMOE.proj_out_source__head) for _ in range(2)],
        [2, 3],
    )

    # Upcycle the single refNet in place; only this one refNet instance is kept.
    if REFNET.ENABLE:
        shared_ref = modelMOE.model.diffusion_model_refNet
        # NOTE(review): unlike the UNet above, the first source here is the
        # target module itself rather than a deep copy -- confirm that
        # replace_modules_lossless tolerates target/source aliasing.
        ref_sources = [shared_ref] + [copy.deepcopy(shared_ref) for _ in range(3)]
        replace_modules_lossless(shared_ref, ref_sources, [0, 1, 2, 3], parent_name=".model.diffusion_model_refNet", for_refnet=True)

    # load from ./modelMOE.ckpt
    # NOTE(review): no checkpoint is actually loaded in this function; the
    # rank-proportional sleep presumably staggers processes -- confirm.
    time.sleep(20 * rank_)
    print("ckpt load over. m,u:")

    # Initialize bank here (after model structure is finalized)
    if REFNET.ENABLE:
        modelMOE.model.bank = Bank(reader=modelMOE.model.diffusion_model, writer=modelMOE.model.diffusion_model_refNet)
    if __name__ == '__main__':
        for key in sorted(get_representative_moduleNames(modelMOE.state_dict().keys())):
            print(f"  - {key}")
    return modelMOE