| | """ |
| | This file is part of ComfyUI. |
| | Copyright (C) 2024 Comfy |
| | |
| | This program is free software: you can redistribute it and/or modify |
| | it under the terms of the GNU General Public License as published by |
| | the Free Software Foundation, either version 3 of the License, or |
| | (at your option) any later version. |
| | |
| | This program is distributed in the hope that it will be useful, |
| | but WITHOUT ANY WARRANTY; without even the implied warranty of |
| | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| | GNU General Public License for more details. |
| | |
| | You should have received a copy of the GNU General Public License |
| | along with this program. If not, see <https://www.gnu.org/licenses/>. |
| | """ |
| |
|
| | import torch |
| | from . import model_base |
| | from . import utils |
| | from . import latent_formats |
| |
|
class ClipTarget:
    """Pair a tokenizer with a text-encoder (CLIP) object.

    ``params`` starts out as an empty dict; presumably callers fill it in
    with extra settings (not visible here).
    """

    def __init__(self, tokenizer, clip):
        self.tokenizer = tokenizer
        self.clip = clip
        self.params = {}
| |
|
class BASE:
    """Base description of a supported diffusion-model architecture.

    Subclasses override the class-level attributes below. An instance holds
    its own copy of the UNet config (with ``unet_extra_config`` merged in) and
    provides helpers to:
    - test whether a detected config matches;
    - build the model object;
    - rename state-dict keys when loading or saving.
    """

    # Key/value pairs a detected UNet config must contain for matches() to succeed.
    unet_config = {}
    # Entries written into every instance's unet_config by __init__; they
    # override any value the caller supplied for the same key.
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    # Keys that must all be present in the state dict when matches() is given one.
    required_keys = {}

    # NOTE(review): not read in this section; presumably consumed elsewhere.
    clip_prefix = []
    # When not None, prepended to CLIP-vision keys on save.
    clip_vision_prefix = None
    # When not None, get_model() builds an SD21UNCLIP model with this config.
    noise_aug_config = None
    # Copied per instance in __init__; contents are defined by subclasses.
    sampling_settings = {}
    # A class at this level; __init__ replaces it on the instance with an
    # instance of that class.
    latent_format = latent_formats.LatentFormat
    # The first entry is prepended to VAE keys on save.
    vae_key_prefix = ["first_stage_model."]
    # Every entry is stripped from text-encoder keys on load; the first one is
    # re-added on save.
    text_encoder_key_prefix = ["cond_stage_model."]
    # NOTE(review): not read in this section; presumably the dtypes usable for inference.
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    # NOTE(review): not read in this section; presumably a memory-estimate multiplier.
    memory_usage_factor = 2.0

    # Overwritten on the instance by set_inference_dtype().
    manual_cast_dtype = None
    # NOTE(review): the next two attributes are not read in this section.
    custom_operations = None
    scaled_fp8 = None
    # Copied per instance in __init__ so instance changes do not alter the class dict.
    optimizations = {"fp8": False}

    @classmethod
    def matches(cls, unet_config, state_dict=None):
        """Return True if ``unet_config`` (and optionally ``state_dict``) fit this class.

        Every key of the class-level ``unet_config`` must be present in the
        given config with an equal value. When a state dict is supplied,
        every entry of ``required_keys`` must also be one of its keys.
        """
        config_mismatch = any(
            key not in unet_config or cls.unet_config[key] != unet_config[key]
            for key in cls.unet_config
        )
        if config_mismatch:
            return False
        if state_dict is None:
            return True
        return all(key in state_dict for key in cls.required_keys)

    def model_type(self, state_dict, prefix=""):
        """Return the model type; the base implementation always reports EPS."""
        return model_base.ModelType.EPS

    def inpaint_model(self):
        """Return True when the UNet config declares more than 4 input channels.

        Raises KeyError if the instance config has no "in_channels" entry.
        """
        input_channels = self.unet_config["in_channels"]
        return input_channels > 4

    def __init__(self, unet_config):
        """Build an instance around a private copy of ``unet_config``.

        The class-level ``sampling_settings`` and ``optimizations`` are
        copied. ``latent_format`` is instantiated. Finally,
        ``unet_extra_config`` entries are written over the copied config.
        """
        self.unet_config = unet_config.copy()
        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()
        self.optimizations = self.optimizations.copy()
        self.unet_config.update(self.unet_extra_config)

    def get_model(self, state_dict, prefix="", device=None):
        """Construct the model object described by this config.

        An SD21UNCLIP model is built when ``noise_aug_config`` is set;
        otherwise a BaseModel is built. ``set_inpaint()`` is called on the
        result when inpaint_model() is true.
        """
        detected_type = self.model_type(state_dict, prefix)
        if self.noise_aug_config is None:
            model = model_base.BaseModel(self, model_type=detected_type, device=device)
        else:
            model = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=detected_type, device=device)
        if self.inpaint_model():
            model.set_inpaint()
        return model

    def process_clip_state_dict(self, state_dict):
        """Strip every ``text_encoder_key_prefix`` entry from the text-encoder keys.

        ``filter_keys=True`` is passed; presumably keys without a matching
        prefix are dropped. Confirm this in utils.state_dict_prefix_replace.
        """
        strip_map = dict.fromkeys(self.text_encoder_key_prefix, "")
        return utils.state_dict_prefix_replace(state_dict, strip_map, filter_keys=True)

    def process_unet_state_dict(self, state_dict):
        """Return the UNet state dict unchanged (hook for subclasses)."""
        return state_dict

    def process_vae_state_dict(self, state_dict):
        """Return the VAE state dict unchanged (hook for subclasses)."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Prepend the first ``text_encoder_key_prefix`` entry to text-encoder keys."""
        return utils.state_dict_prefix_replace(state_dict, {"": self.text_encoder_key_prefix[0]})

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        """Prepend ``clip_vision_prefix`` to CLIP-vision keys.

        When the prefix is None, an empty replacement map is passed instead.
        """
        if self.clip_vision_prefix is None:
            prefix_map = {}
        else:
            prefix_map = {"": self.clip_vision_prefix}
        return utils.state_dict_prefix_replace(state_dict, prefix_map)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Prepend "model.diffusion_model." to UNet keys."""
        return utils.state_dict_prefix_replace(state_dict, {"": "model.diffusion_model."})

    def process_vae_state_dict_for_saving(self, state_dict):
        """Prepend the first ``vae_key_prefix`` entry to VAE keys."""
        return utils.state_dict_prefix_replace(state_dict, {"": self.vae_key_prefix[0]})

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        """Record the UNet dtype in the config and store the manual-cast dtype."""
        self.unet_config['dtype'] = dtype
        self.manual_cast_dtype = manual_cast_dtype
| |
|