File size: 7,431 Bytes
2b534de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 | import sys,os
cur_dir = os.path.dirname(os.path.abspath(__file__))
if __name__=='__main__': sys.path.append(os.path.abspath(os.path.join(cur_dir, '..')))
from confs import *
import json
import argparse, os, sys, glob
import cv2
import torch
import numpy as np
from MoE import *
from multiTask_model import *
from lora_layers import *
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A
import time
import copy
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import torchvision
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.bank import Bank
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from transformers import AutoFeatureExtractor
# import clip
from torchvision.transforms import Resize
from fnmatch import fnmatch
from PIL import Image
from torchvision.transforms import PILToTensor
#----------------------------------------------------------------------------
def get_moe():
    """Build and return the (optionally MoE-upcycled) LatentDiffusion model.

    Instantiates the base model from ``LatentDiffusion.yaml``, moves it to CPU,
    and — unless ``FOR_upcycle_ckpt_GEN_or_USE`` is truthy — wraps per-task
    submodules with TaskSpecific_MoE experts before returning it.

    Returns:
        LatentDiffusion: the constructed model (``modelMOE``).
    """
    if 1:  # original always-true grouping block, kept verbatim
        # Fixed seed so every process builds bitwise-identical initial weights.
        seed_everything(42)
        # torch.cuda.set_device(opt.device_ID)
        model :LatentDiffusion = instantiate_from_config(OmegaConf.load(f"LatentDiffusion.yaml").model,)
    if REFNET.ENABLE:
        # Sanity check: the reference branch must actually be flagged as a refNet.
        assert model.model.diffusion_model_refNet.is_refNet
    # NOTE(review): the cuda/cpu selection below is immediately overwritten by
    # the hard-coded "cpu" on the next line — dead code, presumably a debug
    # override left in place; confirm whether GPU placement was intended.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    device = torch.device("cpu")
    model = model.to(device)
    if FOR_upcycle_ckpt_GEN_or_USE:
        # Presumably the points generator is unneeded for ckpt generation/use,
        # so it is dropped to save memory — TODO confirm against callers.
        del model.ptsM_Generator
def average_module_weight(
    src_modules: list,
):
    """Average the weights of multiple modules.

    Args:
        src_modules: modules with identical state-dict layouts (same keys,
            shapes and dtypes).

    Returns:
        dict[str, torch.Tensor]: element-wise mean of the modules' state
        dicts, or ``None`` when *src_modules* is empty (callers must handle
        the ``None`` case themselves).
    """
    if not src_modules:
        return None
    n = len(src_modules)
    # Get the state dict of the first module as template; initialize with zeros.
    avg_state_dict = {}
    first_state_dict = src_modules[0].state_dict()
    for key in first_state_dict:
        avg_state_dict[key] = torch.zeros_like(first_state_dict[key])
    # Sum
    for module in src_modules:
        module_state_dict = module.state_dict()
        for key in avg_state_dict:
            avg_state_dict[key] += module_state_dict[key]
    # Average
    for key in avg_state_dict:
        acc = avg_state_dict[key]
        if acc.is_floating_point() or acc.is_complex():
            acc /= n
        else:
            # BUGFIX: in-place true division (`/=`) raises RuntimeError on
            # integer tensors (e.g. BatchNorm's int64 `num_batches_tracked`
            # buffer) in modern PyTorch. Divide in float and cast back to the
            # original integer dtype instead.
            avg_state_dict[key] = (acc.double() / n).to(acc.dtype)
    return avg_state_dict
def recursive_average_module_weight(
    tgt_module: nn.Module,
    src_modules: list,
    cb,
):
    """Depth-first walk over *tgt_module*'s children.

    Wherever ``cb(child, name, parent)`` is truthy, the child's weights are
    overwritten with the element-wise average of the same-named children of
    *src_modules*; otherwise the walk recurses into the child.

    Returns:
        nn.Module: *tgt_module*, modified in place.
    """
    for name, child in tgt_module.named_children():
        # Gather the attribute with the same name from every source module.
        counterparts = [getattr(src, name) for src in src_modules]
        for counterpart in counterparts:
            assert counterpart is not None, name
        if not cb(child, name, tgt_module):
            recursive_average_module_weight(child, counterparts, cb)
            continue
        print(f"[recursive_average_module_weight] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
        # Average the counterparts' weights and load them into this child.
        child.load_state_dict(average_module_weight(counterparts))
    return tgt_module
def replace_module_with_TaskSpecific(
    tgt_module: nn.Module,
    src_modules: list,
    cb,
    parent_name: str = "",
    depth :int = 0,
):
    """Swap selected children of *tgt_module* for TaskSpecific_MoE wrappers.

    For each direct child where ``cb(child, name, full_name, parent)`` is
    truthy, the child is replaced by a ``TaskSpecific_MoE`` built from the
    same-named children of *src_modules* and the global ``TASKS``.

    Returns:
        nn.Module: *tgt_module*, modified in place.
    """
    for name, child in tgt_module.named_children():
        # Same-named attribute from every source model.
        siblings = []
        for src in src_modules:
            sibling = getattr(src, name)
            assert sibling is not None, name
            siblings.append(sibling)
        # A child must never be wrapped twice.
        assert not isinstance(child, TaskSpecific_MoE)
        full_name = f"{parent_name}.{name}"
        if cb(child, name, full_name, tgt_module):
            print(f"[replace_module_with_TaskSpecific] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
            setattr(tgt_module, name, TaskSpecific_MoE(siblings,TASKS))
        elif depth <= 0:
            # NOTE(review): recursion stops after depth 1, so only the top two
            # levels of the hierarchy are considered — presumably intentional;
            # confirm against how cb selects modules.
            replace_module_with_TaskSpecific(child, siblings, cb, parent_name=full_name, depth=depth + 1)
    return tgt_module
    if not FOR_upcycle_ckpt_GEN_or_USE:
        # Upcycle the dense model into a task-specific MoE model in place.
        modelMOE :LatentDiffusion = model
        del model  # drop the old alias so only `modelMOE` names the network
        if 1: # ensure distinct module instances per task (avoid shared identities)
            # Per-module adapter ranks consumed by replace_modules_lossless.
            with open(PRETRAIN_JSON_PATH, 'r') as f: global_.moduleName_2_adaRank = json.load(f)
            print(f"loaded from {PRETRAIN_JSON_PATH=}")
            # Four deep copies of the UNet — one expert source per task id 0-3.
            _src0 = copy.deepcopy(modelMOE.model.diffusion_model)
            _src1 = copy.deepcopy(modelMOE.model.diffusion_model)
            _src2 = copy.deepcopy(modelMOE.model.diffusion_model)
            _src3 = copy.deepcopy(modelMOE.model.diffusion_model)
            replace_modules_lossless(
                modelMOE.model.diffusion_model,
                [ _src0, _src1, _src2, _src3 ],
                [0,1,2,3],
                parent_name=".model.diffusion_model",
            )
            # Build-time dummy wrapping for task-specific heads so that ckpt keys match
            # NOTE(review): head experts use task subsets ([0,2,3] / [2,3]) that
            # differ from the UNet's [0,1,2,3] — presumably matching which tasks
            # actually use each head; verify against the training config.
            modelMOE.ID_proj_out = TaskSpecific_MoE([
                copy.deepcopy(modelMOE.ID_proj_out),
                copy.deepcopy(modelMOE.ID_proj_out),
                copy.deepcopy(modelMOE.ID_proj_out),
            ], [0,2,3])
            modelMOE.landmark_proj_out = TaskSpecific_MoE([
                copy.deepcopy(modelMOE.landmark_proj_out),
                copy.deepcopy(modelMOE.landmark_proj_out),
                copy.deepcopy(modelMOE.landmark_proj_out),
            ], [0,2,3])
            modelMOE.proj_out_source__head = TaskSpecific_MoE([
                copy.deepcopy(modelMOE.proj_out_source__head),
                copy.deepcopy(modelMOE.proj_out_source__head),
            ], [2,3])
            # Upcycle single refNet using three source refNets, and keep only one
            if REFNET.ENABLE:
                shared_ref = modelMOE.model.diffusion_model_refNet
                # NOTE(review): src0 aliases the target itself (no deepcopy),
                # unlike src1-3 — confirm replace_modules_lossless tolerates this.
                src0 = shared_ref
                src1 = copy.deepcopy(shared_ref)
                src2 = copy.deepcopy(shared_ref)
                src3 = copy.deepcopy(shared_ref)
                replace_modules_lossless(shared_ref, [src0, src1, src2, src3],[0,1,2,3], parent_name=".model.diffusion_model_refNet", for_refnet=True)
        # load from ./modelMOE.ckpt
        # NOTE(review): despite the comment above and the print below, no ckpt
        # load is visible here — presumably removed or done elsewhere; confirm.
        # Stagger processes by rank (rank_ comes from `from confs import *`),
        # presumably to avoid concurrent reads of the same checkpoint file.
        time.sleep(20*rank_)
        print(f"ckpt load over. m,u:")
        # Initialize bank here (after model structure is finalized)
        if REFNET.ENABLE :
            modelMOE.model.bank = Bank(reader=modelMOE.model.diffusion_model,writer=modelMOE.model.diffusion_model_refNet)
        if __name__=='__main__':
            # Debug aid when run as a script: list representative module names.
            for key in sorted( get_representative_moduleNames(modelMOE.state_dict().keys()) ):
                print(f" - {key}")
    # NOTE(review): when FOR_upcycle_ckpt_GEN_or_USE is truthy, `modelMOE` is
    # never assigned and this return raises NameError — confirm whether that
    # path should return `model` instead.
    return modelMOE
|