Spaces:
Configuration error
Configuration error
| #credit to Acly for this module | |
| #from https://github.com/Acly/comfyui-inpaint-nodes | |
| import torch | |
| import torch.nn.functional as F | |
| import comfy | |
| from comfy.model_base import BaseModel | |
| from comfy.model_patcher import ModelPatcher | |
| from comfy.model_management import cast_to_device | |
| from .log import log_node_warn, log_node_error, log_node_info | |
# Inpaint
# Keep a reference to ComfyUI's stock weight-merge function so that
# calculate_weight_patched can hand every non-Fooocus patch back to it.
# Whichever of comfy.lora.calculate_weight / ModelPatcher.calculate_weight
# exists in the installed ComfyUI is captured here, before any patching.
if not hasattr(comfy.lora, "calculate_weight"):
    original_calculate_weight = ModelPatcher.calculate_weight
else:
    original_calculate_weight = comfy.lora.calculate_weight

# Flipped to True by inject_patched_calculate_weight so the patch is installed only once.
injected_model_patcher_calculate_weight = False
class InpaintHead(torch.nn.Module):
    """Single 3x3 convolution producing 320 feature channels from a 5-channel input.

    ``head`` is allocated with ``torch.empty`` (uninitialised values), so it is
    presumably populated from a checkpoint before use — confirm at the loader.
    The logic lives in ``forward`` (not an overridden ``__call__``) so that
    ``nn.Module.__call__`` runs it together with any registered hooks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Conv weight layout (out_channels=320, in_channels=5, kH=3, kW=3).
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))

    def forward(self, x):
        """Replicate-pad ``x`` by one pixel per side, then convolve with ``head``.

        The 1-pixel pad with an unpadded 3x3 conv keeps the spatial size of ``x``;
        the result has 320 channels. ``x`` must have 5 channels.
        """
        padded = F.pad(x, (1, 1, 1, 1), "replicate")
        return F.conv2d(padded, weight=self.head)
def calculate_weight_patched(patches, weight, key, intermediate_type=torch.float32, **kwargs):
    """Merge Fooocus inpaint patches into ``weight`` and delegate all other patches.

    Drop-in replacement for ComfyUI's ``calculate_weight``, installed by
    ``inject_patched_calculate_weight``. Each entry ``p`` of ``patches`` is read as
    ``alpha = p[0]`` and ``v = p[1]``. It is a Fooocus patch when ``v`` is a 2-tuple
    ``("fooocus", payload)`` whose ``payload[0..2]`` are ``(w1, w_min, w_max)``.

    For a Fooocus patch with ``alpha != 0`` and ``w1.shape == weight.shape``:
    ``w1`` is dequantised as ``w1 / 255 * (w_max - w_min) + w_min`` in float32,
    then ``alpha * w1`` (cast to ``weight.dtype``) is added to ``weight`` in place.
    A shape mismatch is logged and that patch is skipped. All non-Fooocus entries
    are passed, in their original order, to ``original_calculate_weight`` afterwards.

    Args:
        patches: iterable of patch entries as described above.
        weight: tensor to update; Fooocus patches modify it in place.
        key: weight name; used in the mismatch warning and forwarded on delegation.
        intermediate_type: dtype forwarded to the original function.
        **kwargs: ``intermediate_dtype`` is accepted as an alias for
            ``intermediate_type`` (NOTE(review): the keyword ComfyUI's own
            ``calculate_weight`` uses — verify against the installed version).
            Any other keyword (e.g. ``original_weights``) is forwarded unchanged
            to the original function.

    Returns:
        The original function's result when non-Fooocus patches remain,
        otherwise ``weight`` itself.
    """
    # Callers that pass the dtype under ComfyUI's keyword name must not hit a TypeError.
    if "intermediate_dtype" in kwargs:
        intermediate_type = kwargs.pop("intermediate_dtype")

    remaining = []
    for p in patches:
        alpha = p[0]
        v = p[1]
        is_fooocus_patch = isinstance(v, tuple) and len(v) == 2 and v[0] == "fooocus"
        if not is_fooocus_patch:
            remaining.append(p)
            continue
        if alpha == 0.0:
            continue

        payload = v[1]
        w1 = cast_to_device(payload[0], weight.device, torch.float32)
        if w1.shape != weight.shape:
            # Previously swallowed silently; report it so a mismatched patch is visible.
            log_node_warn(
                "comfyui-inpaint-nodes",
                f"Shape mismatch {key}, weight not merged ({w1.shape} != {weight.shape})",
            )
            continue
        w_min = cast_to_device(payload[1], weight.device, torch.float32)
        w_max = cast_to_device(payload[2], weight.device, torch.float32)
        w1 = (w1 / 255.0) * (w_max - w_min) + w_min
        weight += alpha * cast_to_device(w1, weight.device, weight.dtype)

    if remaining:
        return original_calculate_weight(remaining, weight, key, intermediate_type, **kwargs)
    return weight
def inject_patched_calculate_weight():
    """Install the Fooocus-aware weight merge into ComfyUI, at most once per process.

    If ``comfy.lora.calculate_weight`` exists, it is replaced by
    ``calculate_weight_patched``. Otherwise ``ModelPatcher.calculate_weight`` is
    replaced. A plain function stored on a class is bound as a method, so there
    the first positional argument is the ``ModelPatcher`` instance. That branch
    therefore installs a wrapper that takes ``self`` explicitly and keeps it when
    delegating to the captured original.
    """
    global injected_model_patcher_calculate_weight
    if injected_model_patcher_calculate_weight:
        return

    print(
        "[comfyui-inpaint-nodes] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight"
    )
    if hasattr(comfy.lora, "calculate_weight"):
        comfy.lora.calculate_weight = calculate_weight_patched
    else:
        def _calculate_weight_method(self, patches, weight, key, *args, **kwargs):
            # Same split and order as calculate_weight_patched: Fooocus patches are
            # merged first, then the rest go to the original *method*, with `self`.
            # NOTE(review): assumes ComfyUI invokes this as self.calculate_weight(...).
            fooocus_patches = []
            other_patches = []
            for p in patches:
                v = p[1]
                if isinstance(v, tuple) and len(v) == 2 and v[0] == "fooocus":
                    fooocus_patches.append(p)
                else:
                    other_patches.append(p)
            if fooocus_patches:
                # Only Fooocus entries are passed, so no delegation happens inside.
                weight = calculate_weight_patched(fooocus_patches, weight, key)
            if other_patches:
                return original_calculate_weight(self, other_patches, weight, key, *args, **kwargs)
            return weight

        ModelPatcher.calculate_weight = _calculate_weight_method
    injected_model_patcher_calculate_weight = True
class InpaintWorker:
    """Apply a Fooocus inpaint patch (head model + weight patch) to a cloned model.

    ``node_name`` is passed to the log helpers for every message this class emits.
    """

    def __init__(self, node_name):
        # None is normalised to "" before being handed to the log helpers.
        self.node_name = node_name if node_name is not None else ""

    def load_fooocus_patch(self, lora: dict, to_load: dict):
        """Pick the ``lora`` entries whose keys appear among ``to_load``'s values.

        Args:
            lora: patch data keyed by weight name.
            to_load: mapping whose *values* are the weight names to look up in ``lora``.

        Returns:
            ``{key: ("fooocus", lora[key])}`` for every matched key. This is the
            tagged form that ``calculate_weight_patched`` recognises.
        """
        patch_dict = {}
        loaded_keys = set()
        for key in to_load.values():
            # Explicit None check: bool() on a multi-element tensor raises, and a
            # present-but-falsy value must not be treated as missing.
            value = lora.get(key)
            if value is not None:
                patch_dict[key] = ("fooocus", value)
                loaded_keys.add(key)

        not_loaded = sum(1 for x in lora if x not in loaded_keys)
        if not_loaded:
            log_node_info(self.node_name,
                f"{len(loaded_keys)} Lora keys loaded, {not_loaded} remaining keys not found in model."
            )
        return patch_dict

    def patch(self, model, latent, patch):
        """Return ``(m,)`` where ``m`` is a clone of ``model`` with the inpaint patch applied.

        Args:
            model: object exposing ``.model`` (a BaseModel), ``.clone()``,
                ``set_model_input_block_patch`` and ``add_patches``.
            latent: dict with ``"samples"`` and ``"noise_mask"`` entries.
            patch: ``(inpaint_head_model, inpaint_lora)`` pair.

        The head model is moved to the latent's device/dtype in place. Its output
        is added to the hidden state at input block index 0. The lora weights are
        registered as "fooocus" patches, and the patched weight merge is injected.
        """
        base_model: BaseModel = model.model
        latent_pixels = base_model.process_latent_in(latent["samples"])
        # Round the mask, then max-pool 8x8 so any masked pixel marks its cell
        # (presumably matching an 8x latent downscale — confirm for the VAE in use).
        noise_mask = latent["noise_mask"].round()
        latent_mask = F.max_pool2d(noise_mask, (8, 8)).round().to(latent_pixels)

        inpaint_head_model, inpaint_lora = patch
        # Mask channel(s) first, then latent channels, along dim 1.
        feed = torch.cat([latent_mask, latent_pixels], dim=1)
        inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
        inpaint_head_feature = inpaint_head_model(feed)

        def input_block_patch(h, transformer_options):
            # Inject the precomputed head feature only at block index 0.
            if transformer_options["block"][1] == 0:
                h = h + inpaint_head_feature.to(h)
            return h

        # Accept both lora-style names and raw state-dict names for the unet weights.
        lora_keys = comfy.lora.model_lora_keys_unet(base_model, {})
        lora_keys.update({x: x for x in base_model.state_dict().keys()})
        loaded_lora = self.load_fooocus_patch(inpaint_lora, lora_keys)

        m = model.clone()
        m.set_model_input_block_patch(input_block_patch)
        patched = m.add_patches(loaded_lora, 1.0)

        not_patched_count = sum(1 for x in loaded_lora if x not in patched)
        if not_patched_count > 0:
            log_node_error(self.node_name, f"Failed to patch {not_patched_count} keys")

        # The "fooocus" patches above are only understood by the patched merge.
        inject_patched_calculate_weight()
        return (m,)