| import sys,os |
| cur_dir = os.path.dirname(os.path.abspath(__file__)) |
| if __name__=='__main__': sys.path.append(os.path.abspath(os.path.join(cur_dir, '..'))) |
|
|
| from confs import * |
| import json |
| import argparse, os, sys, glob |
| import cv2 |
| import torch |
| import numpy as np |
| from MoE import * |
| from multiTask_model import * |
| from lora_layers import * |
| from omegaconf import OmegaConf |
| from PIL import Image |
| from tqdm import tqdm, trange |
| from itertools import islice |
| from einops import rearrange |
| from torchvision.utils import make_grid |
| from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A |
| import time |
| import copy |
| from pytorch_lightning import seed_everything |
| from torch import autocast |
| from contextlib import contextmanager, nullcontext |
| import torchvision |
| from ldm.models.diffusion.ddpm import LatentDiffusion |
| from ldm.models.diffusion.bank import Bank |
| from ldm.util import instantiate_from_config |
|
|
| from ldm.models.diffusion.ddim import DDIMSampler |
|
|
| from transformers import AutoFeatureExtractor |
|
|
| |
| from torchvision.transforms import Resize |
| from fnmatch import fnmatch |
|
|
|
|
| from PIL import Image |
| from torchvision.transforms import PILToTensor |
| |
|
|
|
|
def get_moe():
    """Instantiate the ``LatentDiffusion`` model and convert it to its MoE layout.

    Steps, in order:

    1. Seed every RNG with 42 and build the model from the ``model`` section
       of ``LatentDiffusion.yaml``. The model is kept on the CPU.
    2. If ``FOR_upcycle_ckpt_GEN_or_USE`` is truthy, drop ``ptsM_Generator``
       from the model.
    3. Otherwise load the JSON file at ``PRETRAIN_JSON_PATH`` into
       ``global_.moduleName_2_adaRank``, run ``replace_modules_lossless`` on
       the diffusion UNet (and on the refNet when ``REFNET.ENABLE``), and wrap
       the ``ID_proj_out`` / ``landmark_proj_out`` / ``proj_out_source__head``
       heads in ``TaskSpecific_MoE``.
    4. When ``REFNET.ENABLE``, attach a ``Bank`` whose reader is the diffusion
       UNet and whose writer is the refNet.
    5. When executed as a script, print the representative module names.

    Returns:
        The converted ``LatentDiffusion`` instance.

    NOTE(review): block nesting was reconstructed from a listing whose
    indentation had been stripped; every reconstruction choice is equivalent
    when ``FOR_upcycle_ckpt_GEN_or_USE`` is falsy, but please confirm it
    against the original layout.

    NOTE(review): ``modelMOE`` is only bound when
    ``FOR_upcycle_ckpt_GEN_or_USE`` is falsy, so with a truthy flag this
    function fails on an unbound local before returning. Presumably the
    "GEN" branch was removed at some point - confirm and restore or guard.
    """
    seed_everything(42)

    model: LatentDiffusion = instantiate_from_config(
        OmegaConf.load("LatentDiffusion.yaml").model,
    )
    if REFNET.ENABLE:
        assert model.model.diffusion_model_refNet.is_refNet

    # The model is always kept on the CPU here (the previous CUDA probe was
    # immediately overwritten, so it has been dropped).
    model = model.to(torch.device("cpu"))
    if FOR_upcycle_ckpt_GEN_or_USE:
        del model.ptsM_Generator

    def average_module_weight(
        src_modules: list,
    ):
        """Return the element-wise mean of the state dicts of ``src_modules``.

        Args:
            src_modules: modules whose ``state_dict()`` share the same keys.

        Returns:
            A new ``{key: tensor}`` dict holding the averaged tensors, or
            ``None`` when ``src_modules`` is empty.

        Floating-point (and complex) entries are summed and divided in their
        own dtype. Integer and bool entries (for example BatchNorm's
        ``num_batches_tracked``) cannot be divided in place - PyTorch refuses
        to cast the float quotient back to the integer tensor - so they are
        averaged in float64 and then cast back to the original dtype (which
        truncates toward zero for integers).
        """
        if not src_modules:
            return None

        state_dicts = [module.state_dict() for module in src_modules]
        count = len(state_dicts)

        avg_state_dict = {}
        for key, reference in state_dicts[0].items():
            if reference.is_floating_point() or reference.is_complex():
                total = torch.zeros_like(reference)
                for state_dict in state_dicts:
                    total += state_dict[key]
                total /= count
            else:
                # Integer/bool tensors: accumulate in float64, cast back after.
                total = torch.zeros_like(reference, dtype=torch.float64)
                for state_dict in state_dicts:
                    total += state_dict[key]
                total = (total / count).to(reference.dtype)
            avg_state_dict[key] = total
        return avg_state_dict

    def recursive_average_module_weight(
        tgt_module: nn.Module,
        src_modules: list,
        cb,
    ):
        """Walk ``tgt_module`` and load averaged source weights where ``cb`` says so.

        Args:
            tgt_module: module modified in place.
            src_modules: modules exposing the same child names as
                ``tgt_module``.
            cb: predicate called as ``cb(child, name, tgt_module)``. When it
                returns a truthy value the child receives the average of the
                matching source children; otherwise the walk descends into
                that child.

        Returns:
            ``tgt_module`` itself.
        """
        for name, child in tgt_module.named_children():
            src_child_modules = []
            for src_module in src_modules:
                src_child = getattr(src_module, name)
                assert src_child is not None, name
                src_child_modules.append(src_child)

            if cb(child, name, tgt_module):
                print(f"[recursive_average_module_weight] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
                avg_weights = average_module_weight(src_child_modules)
                child.load_state_dict(avg_weights)
            else:
                recursive_average_module_weight(child, src_child_modules, cb)
        return tgt_module

    def replace_module_with_TaskSpecific(
        tgt_module: nn.Module,
        src_modules: list,
        cb,
        parent_name: str = "",
        depth: int = 0,
    ):
        """Replace selected children of ``tgt_module`` with ``TaskSpecific_MoE``.

        Args:
            tgt_module: module modified in place.
            src_modules: modules exposing the same child names as
                ``tgt_module``; their matching children are handed to the
                ``TaskSpecific_MoE`` wrapper.
            cb: predicate called as ``cb(child, name, full_name, tgt_module)``.
            parent_name: dotted prefix used to build ``full_name``.
            depth: current recursion depth; children are only descended into
                while ``depth <= 0``, i.e. one level below the initial call.

        Returns:
            ``tgt_module`` itself.
        """
        for name, child in tgt_module.named_children():
            src_child_modules = []
            for src_module in src_modules:
                src_child = getattr(src_module, name)
                assert src_child is not None, name
                src_child_modules.append(src_child)
            assert not isinstance(child, TaskSpecific_MoE)
            full_name = f"{parent_name}.{name}"
            if cb(child, name, full_name, tgt_module):
                print(f"[replace_module_with_TaskSpecific] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
                setattr(tgt_module, name, TaskSpecific_MoE(src_child_modules, TASKS))
            elif depth <= 0:
                replace_module_with_TaskSpecific(
                    child,
                    src_child_modules,
                    cb,
                    parent_name=full_name,
                    depth=depth + 1,
                )
        return tgt_module

    if not FOR_upcycle_ckpt_GEN_or_USE:
        modelMOE: LatentDiffusion = model
        del model

        with open(PRETRAIN_JSON_PATH, 'r') as f:
            global_.moduleName_2_adaRank = json.load(f)
        print(f"loaded from {PRETRAIN_JSON_PATH=}")

        # Main UNet: four independent deep copies act as the sources.
        unet = modelMOE.model.diffusion_model
        replace_modules_lossless(
            unet,
            [copy.deepcopy(unet) for _ in range(4)],
            [0, 1, 2, 3],
            parent_name=".model.diffusion_model",
        )

        # Projection heads: one deep copy per listed id (presumably task ids -
        # confirm against TaskSpecific_MoE).
        modelMOE.ID_proj_out = TaskSpecific_MoE(
            [copy.deepcopy(modelMOE.ID_proj_out) for _ in range(3)],
            [0, 2, 3],
        )
        modelMOE.landmark_proj_out = TaskSpecific_MoE(
            [copy.deepcopy(modelMOE.landmark_proj_out) for _ in range(3)],
            [0, 2, 3],
        )
        modelMOE.proj_out_source__head = TaskSpecific_MoE(
            [copy.deepcopy(modelMOE.proj_out_source__head) for _ in range(2)],
            [2, 3],
        )

        if REFNET.ENABLE:
            shared_ref = modelMOE.model.diffusion_model_refNet
            # NOTE(review): unlike the main UNet above, the first source here
            # is the target module itself rather than a deep copy - confirm
            # that this aliasing is intentional.
            ref_sources = [shared_ref] + [copy.deepcopy(shared_ref) for _ in range(3)]
            replace_modules_lossless(
                shared_ref,
                ref_sources,
                [0, 1, 2, 3],
                parent_name=".model.diffusion_model_refNet",
                for_refnet=True,
            )

        # Delay proportional to rank_ (presumably to stagger work across
        # processes - confirm).
        time.sleep(20 * rank_)
        print("ckpt load over. m,u:")

    if REFNET.ENABLE:
        modelMOE.model.bank = Bank(
            reader=modelMOE.model.diffusion_model,
            writer=modelMOE.model.diffusion_model_refNet,
        )
    if __name__ == '__main__':
        for key in sorted(get_representative_moduleNames(modelMOE.state_dict().keys())):
            print(f" - {key}")
    return modelMOE
|
|
|
|