Spaces:
Configuration error
Configuration error
# Loaders for the HunyuanDiT text encoders: the "CLIP" encoder (built as a BERT model) and the mT5 encoder.
| import os | |
| import torch | |
| from transformers import AutoTokenizer, modeling_utils | |
| from transformers import T5Config, T5EncoderModel, BertConfig, BertModel | |
| from comfy import model_management | |
| import comfy.model_patcher | |
| import comfy.utils | |
class mT5Model(torch.nn.Module):
    """``torch.nn.Module`` wrapper around a Hugging Face ``T5EncoderModel`` (mT5).

    The encoder is built from a JSON config inside ``no_init_weights()``, so
    its weights are not randomly initialised here; they are expected to be
    supplied afterwards through :meth:`load_sd`.
    """

    def __init__(self, textmodel_json_config=None, device="cpu", max_length=256, freeze=True, dtype=None):
        """Build the encoder from a config file.

        Args:
            textmodel_json_config: Path to a ``T5Config`` JSON file; defaults
                to ``config_mt5.json`` located next to this source file.
            device: Stored on ``self.device`` only; weights are not moved here.
            max_length: Stored on ``self.max_length``; not used by this class.
            freeze: If true, call :meth:`freeze` after construction.
            dtype: If not ``None``, cast the encoder weights to this dtype.
        """
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.max_length = max_length
        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "config_mt5.json",
            )
        config = T5Config.from_json_file(textmodel_json_config)
        # Skip random initialisation; real weights are expected from load_sd().
        with modeling_utils.no_init_weights():
            self.transformer = T5EncoderModel(config)
        # Only cast when a dtype was actually requested.
        if dtype is not None:
            self.to(dtype)
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and disable gradients on all parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def load_sd(self, sd):
        """Load ``sd`` into the wrapped encoder non-strictly.

        Keys must match the bare ``T5EncoderModel`` (no ``transformer.`` prefix).

        Returns:
            The ``(missing_keys, unexpected_keys)`` result of
            ``load_state_dict(..., strict=False)``.
        """
        return self.transformer.load_state_dict(sd, strict=False)

    def to(self, *args, **kwargs):
        """Move/cast the wrapped encoder and return ``self``.

        The call is delegated to ``self.transformer.to`` so any override the
        Hugging Face model defines still runs. The wrapper itself is returned,
        as ``torch.nn.Module.to`` does, so ``m = m.to(...)`` keeps an
        ``mT5Model`` (with ``load_sd``/``freeze``) instead of silently
        replacing it with the bare encoder.
        """
        self.transformer.to(*args, **kwargs)
        return self
class hyCLIPModel(torch.nn.Module):
    """``torch.nn.Module`` wrapper around a Hugging Face ``BertModel`` ("CLIP" text encoder).

    The encoder is built from a JSON config inside ``no_init_weights()``, so
    its weights are not randomly initialised here; they are expected to be
    supplied afterwards through :meth:`load_sd`.
    """

    def __init__(self, textmodel_json_config=None, device="cpu", max_length=77, freeze=True, dtype=None):
        """Build the encoder from a config file.

        Args:
            textmodel_json_config: Path to a ``BertConfig`` JSON file; defaults
                to ``config_clip.json`` located next to this source file.
            device: Stored on ``self.device`` only; weights are not moved here.
            max_length: Stored on ``self.max_length``; not used by this class.
            freeze: If true, call :meth:`freeze` after construction.
            dtype: If not ``None``, cast the encoder weights to this dtype.
        """
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.max_length = max_length
        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                "config_clip.json",
            )
        config = BertConfig.from_json_file(textmodel_json_config)
        # Skip random initialisation; real weights are expected from load_sd().
        with modeling_utils.no_init_weights():
            self.transformer = BertModel(config)
        # Only cast when a dtype was actually requested.
        if dtype is not None:
            self.to(dtype)
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and disable gradients on all parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def load_sd(self, sd):
        """Load ``sd`` into the wrapped encoder non-strictly.

        Keys must match the bare ``BertModel`` (no ``transformer.`` prefix).

        Returns:
            The ``(missing_keys, unexpected_keys)`` result of
            ``load_state_dict(..., strict=False)``.
        """
        return self.transformer.load_state_dict(sd, strict=False)

    def to(self, *args, **kwargs):
        """Move/cast the wrapped encoder and return ``self``.

        The call is delegated to ``self.transformer.to`` so any override the
        Hugging Face model defines still runs. The wrapper itself is returned,
        as ``torch.nn.Module.to`` does, so ``m = m.to(...)`` keeps a
        ``hyCLIPModel`` (with ``load_sd``/``freeze``) instead of silently
        replacing it with the bare encoder.
        """
        self.transformer.to(*args, **kwargs)
        return self
class EXM_HyDiT_Tenc_Temp:
    """Bundle of a HunyuanDiT text encoder, its tokenizer and a ComfyUI ModelPatcher.

    ``model_class == "mT5"`` builds an :class:`mT5Model`; any other value
    builds an :class:`hyCLIPModel` (the BERT-based "CLIP" encoder).
    """

    def __init__(self, no_init=False, device="cpu", dtype=None, model_class="mT5", *args, **kwargs):
        """Build the encoder, tokenizer and patcher.

        Args:
            no_init: If true, return immediately and leave the instance empty
                (used by :meth:`clone`, which fills in the attributes itself).
            device: ``"auto"`` uses ComfyUI's text-encoder load/offload devices;
                ``"cpu"`` keeps everything on CPU; a string starting with
                ``"cuda"`` pins load, offload and init to that device; any
                other value loads on ComfyUI's torch device and offloads to CPU.
            dtype: Weight dtype passed to the encoder (also affects the size
                reported to the patcher).
            model_class: ``"mT5"`` for the mT5 encoder, anything else for the
                BERT/CLIP encoder.
            *args, **kwargs: Accepted and ignored, so callers such as
                ``load_clip``/``load_t5`` can forward extra options without
                raising ``TypeError``.
        """
        if no_init:
            return
        # Size handed to the ModelPatcher (presumably bytes): 8 GiB for mT5,
        # 2 GiB otherwise, doubled for float32 weights.
        size = 8 if model_class == "mT5" else 2
        if dtype == torch.float32:
            size *= 2
        size *= 1024 ** 3
        if device == "auto":
            self.load_device = model_management.text_encoder_device()
            self.offload_device = model_management.text_encoder_offload_device()
            self.init_device = "cpu"
        elif device == "cpu":
            size = 0  # doesn't matter
            self.load_device = "cpu"
            self.offload_device = "cpu"
            self.init_device = "cpu"
        elif device.startswith("cuda"):
            # Load and offload device are identical here, hence the warning
            # that VRAM is not freed by default.
            print("Direct CUDA device override!\nVRAM will not be freed by default.")
            size = 0  # not used
            self.load_device = device
            self.offload_device = device
            self.init_device = device
        else:
            self.load_device = model_management.get_torch_device()
            self.offload_device = "cpu"
            self.init_device = "cpu"
        self.dtype = dtype
        self.device = self.load_device

        if model_class == "mT5":
            encoder_cls, tokenizer_dir = mT5Model, "mt5_tokenizer"
        else:
            encoder_cls, tokenizer_dir = hyCLIPModel, "tokenizer"
        self.cond_stage_model = encoder_cls(
            device=self.load_device,
            dtype=self.dtype,
        )
        # Tokenizers are read from folders shipped next to this file. Upstream
        # they correspond to the "t2i/mt5" / "t2i/tokenizer" subfolders of the
        # "Tencent-Hunyuan/HunyuanDiT" repo.
        tokenizer_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            tokenizer_dir,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.patcher = comfy.model_patcher.ModelPatcher(
            self.cond_stage_model,
            load_device=self.load_device,
            offload_device=self.offload_device,
            current_device=self.load_device,
            size=size,
        )

    def clone(self):
        """Return a copy sharing encoder and tokenizer but with a cloned patcher.

        Device/dtype attributes are copied as well, so the clone can be used
        with :meth:`load_model` (which reads ``load_device``).
        """
        n = EXM_HyDiT_Tenc_Temp(no_init=True)
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        # Copy only what exists, so cloning a bare no_init instance still works.
        for attr in ("load_device", "offload_device", "init_device", "dtype", "device"):
            if hasattr(self, attr):
                setattr(n, attr, getattr(self, attr))
        return n

    def load_sd(self, sd):
        """Load ``sd`` into the encoder; returns ``(missing_keys, unexpected_keys)``."""
        return self.cond_stage_model.load_sd(sd)

    def get_sd(self):
        """Return the state dict of the wrapper module.

        NOTE(review): keys here carry the wrapper's ``transformer.`` prefix,
        while ``load_sd`` expects unprefixed keys, so the two do not round-trip
        directly — confirm what callers expect before changing either.
        """
        return self.cond_stage_model.state_dict()

    def load_model(self):
        """Ask ComfyUI to load the patched model (skipped for CPU) and return the patcher."""
        if self.load_device != "cpu":
            model_management.load_model_gpu(self.patcher)
        return self.patcher

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        """Forward ``patches`` and strengths to the underlying ModelPatcher."""
        return self.patcher.add_patches(patches, strength_patch, strength_model)

    def get_key_patches(self):
        """Return the ModelPatcher's key patches."""
        return self.patcher.get_key_patches()
def load_clip(model_path, **kwargs):
    """Create the BERT-based "CLIP" text encoder and load weights from ``model_path``.

    A leading ``"bert."`` is stripped from every checkpoint key that has one;
    other keys are passed through unchanged. A one-line summary is printed
    when keys are missing or unexpected.

    Args:
        model_path: Checkpoint file read with ``comfy.utils.load_torch_file``.
        **kwargs: Forwarded to ``EXM_HyDiT_Tenc_Temp``.

    Returns:
        The populated ``EXM_HyDiT_Tenc_Temp`` instance.
    """
    model = EXM_HyDiT_Tenc_Temp(model_class="clip", **kwargs)
    checkpoint = comfy.utils.load_torch_file(model_path)

    prefix = "bert."
    cut = len(prefix)
    renamed = {
        (name[cut:] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint.items()
    }

    missing, unexpected = model.load_sd(renamed)
    if missing or unexpected:
        print(f"HYDiT: clip missing {len(missing)} keys ({len(unexpected)} extra)")
    return model
def load_t5(model_path, **kwargs):
    """Create the mT5 text encoder and load weights from ``model_path`` as-is.

    A one-line summary is printed when keys are missing or unexpected.

    Args:
        model_path: Checkpoint file read with ``comfy.utils.load_torch_file``.
        **kwargs: Forwarded to ``EXM_HyDiT_Tenc_Temp``.

    Returns:
        The populated ``EXM_HyDiT_Tenc_Temp`` instance.
    """
    model = EXM_HyDiT_Tenc_Temp(model_class="mT5", **kwargs)
    checkpoint = comfy.utils.load_torch_file(model_path)
    missing, unexpected = model.load_sd(checkpoint)
    if missing or unexpected:
        print(f"HYDiT: mT5 missing {len(missing)} keys ({len(unexpected)} extra)")
    return model