| import torch
|
| from . import model_base
|
| from . import utils
|
| from . import latent_formats
|
|
|
class ClipTarget:
    """Bundle a tokenizer with a clip object, plus an initially empty params dict."""

    def __init__(self, tokenizer, clip):
        """Store *clip* and *tokenizer* as given and start with no params."""
        self.clip, self.tokenizer = clip, tokenizer
        # Fresh dict per instance so targets never share parameter state.
        self.params = dict()
|
|
|
class BASE:
    """Base class describing one supported model configuration.

    Class-level attributes hold the defaults for a model family; ``matches``
    tests a detected UNet config (and optionally a state dict) against them,
    and the remaining methods build the model object and rewrite state-dict
    key prefixes for loading and saving.
    """

    # Key/value pairs that a candidate UNet config must contain for
    # ``matches`` to return True.
    unet_config = {}
    # Entries written into every instance's ``unet_config`` by ``__init__``,
    # overriding same-named keys from the config passed in.
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    # Keys that must all be present in the state dict when one is passed to
    # ``matches``.
    required_keys = {}

    # Not read by any method of this class; presumably consumed by callers.
    clip_prefix = []
    # Prefix used by ``process_clip_vision_state_dict_for_saving``
    # (None means no prefix mapping is applied).
    clip_vision_prefix = None
    # When not None, ``get_model`` builds ``model_base.SD21UNCLIP`` with it.
    noise_aug_config = None
    # Copied per instance in ``__init__``.
    sampling_settings = {}
    # Class (or factory) that ``__init__`` calls to create the instance's
    # latent format object.
    latent_format = latent_formats.LatentFormat
    # The first entry is the prefix used when saving VAE weights.
    vae_key_prefix = ["first_stage_model."]
    # Every entry is mapped to "" on load; the first entry is used on save.
    text_encoder_key_prefix = ["cond_stage_model."]
    # Not read by any method of this class; presumably consulted by callers
    # when choosing an inference dtype.
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    # Overwritten per instance by ``set_inference_dtype``.
    manual_cast_dtype = None

    @classmethod
    def matches(cls, unet_config, state_dict=None):
        """Return True when *unet_config* (and *state_dict*) fit this class.

        Every key in the class-level ``unet_config`` must be present in the
        given *unet_config* with an equal value. If *state_dict* is provided,
        every entry of ``required_keys`` must also be contained in it.
        """
        for key, expected in cls.unet_config.items():
            if key not in unet_config or expected != unet_config[key]:
                return False
        if state_dict is None:
            return True
        return all(key in state_dict for key in cls.required_keys)

    def model_type(self, state_dict, prefix=""):
        """Return the model type; this base version always yields ``ModelType.EPS``.

        *state_dict* and *prefix* are accepted but unused here.
        """
        return model_base.ModelType.EPS

    def inpaint_model(self):
        """Return True when the configured ``in_channels`` is greater than 4."""
        in_channels = self.unet_config["in_channels"]
        return in_channels > 4

    def __init__(self, unet_config):
        """Create an instance from a copy of *unet_config*.

        The caller's config is not modified: it is copied, then overlaid with
        ``unet_extra_config``. ``sampling_settings`` is copied onto the
        instance and ``latent_format`` is replaced by the result of calling it.
        """
        config = unet_config.copy()
        for name, value in self.unet_extra_config.items():
            config[name] = value
        self.unet_config = config
        # Shadow the class-level attributes with per-instance values.
        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()

    def get_model(self, state_dict, prefix="", device=None):
        """Build and return the model object for this configuration.

        Uses ``model_base.BaseModel`` unless ``noise_aug_config`` is set, in
        which case ``model_base.SD21UNCLIP`` is built with that config. When
        ``inpaint_model()`` is true, ``set_inpaint()`` is called on the result.
        """
        if self.noise_aug_config is None:
            model = model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
        else:
            model = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            model.set_inpaint()
        return model

    def process_clip_state_dict(self, state_dict):
        """Rewrite text-encoder keys of a loaded state dict.

        Each prefix in ``text_encoder_key_prefix`` is mapped to the empty
        string and passed to ``utils.state_dict_prefix_replace`` with
        ``filter_keys=True`` (presumably stripping the prefix and dropping
        keys that do not carry it; confirm against that helper).
        """
        strip_map = dict.fromkeys(self.text_encoder_key_prefix, "")
        return utils.state_dict_prefix_replace(state_dict, strip_map, filter_keys=True)

    def process_unet_state_dict(self, state_dict):
        """Return the UNet state dict unchanged."""
        return state_dict

    def process_vae_state_dict(self, state_dict):
        """Return the VAE state dict unchanged."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Map the empty prefix to the first text-encoder key prefix for saving."""
        return utils.state_dict_prefix_replace(state_dict, {"": self.text_encoder_key_prefix[0]})

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        """Map the empty prefix to ``clip_vision_prefix`` for saving.

        When ``clip_vision_prefix`` is None an empty mapping is passed instead.
        """
        if self.clip_vision_prefix is None:
            prefix_map = {}
        else:
            prefix_map = {"": self.clip_vision_prefix}
        return utils.state_dict_prefix_replace(state_dict, prefix_map)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Map the empty prefix to ``"model.diffusion_model."`` for saving."""
        return utils.state_dict_prefix_replace(state_dict, {"": "model.diffusion_model."})

    def process_vae_state_dict_for_saving(self, state_dict):
        """Map the empty prefix to the first VAE key prefix for saving."""
        return utils.state_dict_prefix_replace(state_dict, {"": self.vae_key_prefix[0]})

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        """Record *dtype* under ``unet_config['dtype']`` and store *manual_cast_dtype*."""
        self.unet_config['dtype'] = dtype
        self.manual_cast_dtype = manual_cast_dtype
|
|
|