|
|
| import torch |
| from ..utils import log |
| import comfy.model_management as mm |
| from comfy.utils import ProgressBar, load_torch_file |
| from tqdm import tqdm |
| import gc |
|
|
| from accelerate import init_empty_weights |
| from accelerate.utils import set_module_tensor_to_device |
| import folder_paths |
|
|
| import json |
| import numpy as np |
|
|
class WanVideoUni3C_ControlnetLoader:
    """ComfyUI node that loads a Uni3C ControlNet checkpoint for WanVideo.

    The checkpoint is read from the ``controlnet`` model folder. The network is
    instantiated under ``accelerate.init_empty_weights`` and each parameter is
    then assigned from the state dict with ``set_module_tensor_to_device``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' -folder",}),
                "base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
                "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e5m2'], {"default": 'disabled', "tooltip": "optional quantization method"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
                "attention_mode": ([
                    "sdpa",
                    "sageattn",
                ], {"default": "sdpa"}),
            },
            "optional": {
                "compile_args": ("WANCOMPILEARGS", ),
            }
        }

    RETURN_TYPES = ("WANVIDEOCONTROLNET",)
    RETURN_NAMES = ("controlnet", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    @staticmethod
    def _detect_quantization(sd, quantization):
        """Return the effective quantization mode for state dict ``sd``.

        An explicit choice is returned unchanged. When ``quantization`` is
        "disabled", the first fp8 tensor found in ``sd`` selects the mode, so
        pre-quantized checkpoints are treated as quantized.
        """
        if quantization != "disabled":
            return quantization
        for value in sd.values():
            if isinstance(value, torch.Tensor):
                if value.dtype == torch.float8_e4m3fn:
                    return "fp8_e4m3fn"
                if value.dtype == torch.float8_e5m2:
                    return "fp8_e5m2"
        return quantization

    @staticmethod
    def _compile(controlnet, compile_args):
        """Apply ``torch.compile`` as configured by ``compile_args``.

        Returns the model to use afterwards: the same module with compiled
        ``controlnet_blocks`` when ``compile_transformer_blocks_only`` is set,
        otherwise the compiled wrapper of the whole module.
        """
        torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
        try:
            if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
                torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
        except Exception as e:
            # Best effort: older torch versions may not expose this setting.
            log.warning(f"Could not set recompile_limit: {e}")

        compile_kwargs = dict(
            fullgraph=compile_args["fullgraph"],
            dynamic=compile_args["dynamic"],
            backend=compile_args["backend"],
            mode=compile_args["mode"],
        )
        if compile_args["compile_transformer_blocks_only"]:
            for i, block in enumerate(controlnet.controlnet_blocks):
                controlnet.controlnet_blocks[i] = torch.compile(block, **compile_kwargs)
            return controlnet
        return torch.compile(controlnet, **compile_kwargs)

    def loadmodel(self, model, base_precision, load_device, quantization, attention_mode, compile_args=None):
        """Load a Uni3C ControlNet checkpoint and return it in a 1-tuple.

        Args:
            model: File name inside the ``controlnet`` model folder.
            base_precision: "fp32", "bf16" or "fp16"; dtype for parameters that
                are not quantized.
            load_device: "main_device" or "offload_device"; device the weights
                are loaded onto.
            quantization: "disabled", "fp8_e4m3fn" or "fp8_e5m2". When
                "disabled", an fp8 checkpoint is detected from its tensor dtypes.
            attention_mode: Forwarded to the ControlNet config.
            compile_args: Optional torch.compile settings (WANCOMPILEARGS).

        Returns:
            ``(controlnet,)``: the loaded, possibly compiled, module.

        Raises:
            ValueError: If the checkpoint lacks ``controlnet_patch_embedding.weight``.
        """
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        transformer_load_device = device if load_device == "main_device" else offload_device

        base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision]

        model_path = folder_paths.get_full_path_or_raise("controlnet", model)
        sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)

        if "controlnet_patch_embedding.weight" not in sd:
            raise ValueError("Invalid ControlNet model")

        in_channels = sd["controlnet_patch_embedding.weight"].shape[1]
        ffn_dim = sd["controlnet_blocks.0.ffn.0.bias"].shape[0]

        # Resolve auto-detected fp8 checkpoints BEFORE building the config, so the
        # "quantized" flag agrees with the dtype the weights are actually loaded in.
        quantization = self._detect_quantization(sd, quantization)
        if "fp8_e4m3fn" in quantization:
            dtype = torch.float8_e4m3fn
        elif quantization == "fp8_e5m2":
            dtype = torch.float8_e5m2
        else:
            dtype = base_dtype

        controlnet_cfg = {
            "in_channels": in_channels,
            "conv_out_dim": 5120,
            "time_embed_dim": 5120,
            "dim": 1024,
            "ffn_dim": ffn_dim,
            "num_heads": 16,
            "num_layers": 20,
            "add_channels": 7,
            "mid_channels": 256,
            "attention_mode": attention_mode,
            "quantized": quantization != "disabled",
            "base_dtype": base_dtype
        }

        from .controlnet import WanControlNet

        with init_empty_weights():
            controlnet = WanControlNet(controlnet_cfg)
        controlnet.eval()

        # Parameters whose names contain any of these substrings stay in base precision.
        params_to_keep = {"norm", "head", "time_in", "vector_in", "controlnet_patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter", "proj_in"}

        log.info("Using accelerate to load and assign controlnet model weights to device...")
        param_count = sum(1 for _ in controlnet.named_parameters())
        for name, _ in tqdm(controlnet.named_parameters(),
                desc=f"Loading transformer parameters to {transformer_load_device}",
                total=param_count,
                leave=True):
            dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
            if "controlnet_patch_embedding" in name:
                # The patch embedding is always kept in full precision.
                dtype_to_use = torch.float32
            set_module_tensor_to_device(controlnet, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])

        del sd

        if compile_args is not None:
            controlnet = self._compile(controlnet, compile_args)

        # NOTE(review): relies on WanControlNet exposing a `.device` attribute
        # (not provided by nn.Module itself); confirm in controlnet.py.
        if load_device == "offload_device" and controlnet.device != offload_device:
            log.info(f"Moving controlnet model from {controlnet.device} to {offload_device}")
            controlnet.to(offload_device)
            gc.collect()
            mm.soft_empty_cache()

        return (controlnet,)
|
|
class WanVideoUni3C_embeds:
    """ComfyUI node that bundles a Uni3C ControlNet with its conditioning settings.

    The result is a plain dict emitted as ``UNI3C_EMBEDS``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "controlnet": ("WANVIDEOCONTROLNET",),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply the controlnet"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply the controlnet"}),
            },
            "optional": {
                "render_latent": ("LATENT",),
                "render_mask": ("MASK", {"tooltip": "NOT IMPLEMENTED!"}),
            },
        }

    RETURN_TYPES = ("UNI3C_EMBEDS", )
    RETURN_NAMES = ("uni3c_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, controlnet, strength, start_percent, end_percent, render_latent=None, render_mask=None):
        """Build the Uni3C embeds dict.

        Args:
            controlnet: Model returned by the Uni3C ControlNet loader.
            strength: Stored as ``controlnet_weight``.
            start_percent: Stored as ``start``.
            end_percent: Stored as ``end``.
            render_latent: Optional ComfyUI LATENT dict; its ``"samples"`` entry
                is stored as ``render_latent``.
            render_mask: Not supported; passing anything but ``None`` raises.

        Returns:
            A 1-tuple holding the embeds dict.

        Raises:
            NotImplementedError: If ``render_mask`` is given.
        """
        latents = None
        if render_latent is not None:
            latents = render_latent["samples"]

        if render_mask is not None:
            # Mask conditioning is not supported yet. The previous unreachable
            # resize sketch referenced undefined frame/height/width names and was removed.
            raise NotImplementedError("render_mask is not implemented at this time")

        uni3c_embeds = {
            "controlnet": controlnet,
            "controlnet_weight": strength,
            "start": start_percent,
            "end": end_percent,
            "render_latent": latents,
            "render_mask": None,
            "camera_embedding": None
        }

        return (uni3c_embeds,)
| |
# Registration tables ComfyUI reads to discover this module's nodes:
# node identifier -> implementing class, and node identifier -> UI label.
NODE_CLASS_MAPPINGS = dict(
    WanVideoUni3C_ControlnetLoader=WanVideoUni3C_ControlnetLoader,
    WanVideoUni3C_embeds=WanVideoUni3C_embeds,
)
NODE_DISPLAY_NAME_MAPPINGS = dict(
    WanVideoUni3C_ControlnetLoader="WanVideo Uni3C Controlnet Loader",
    WanVideoUni3C_embeds="WanVideo Uni3C Embeds",
)
|
|
| |