|
|
|
|
|
import torch |
|
|
from ..utils import log |
|
|
import comfy.model_management as mm |
|
|
from comfy.utils import load_torch_file |
|
|
from tqdm import tqdm |
|
|
import gc |
|
|
|
|
|
from accelerate import init_empty_weights |
|
|
from accelerate.utils import set_module_tensor_to_device |
|
|
import folder_paths |
|
|
|
|
|
class WanVideoControlnetLoader:
    """ComfyUI node: loads a WanVideo ControlNet checkpoint into a model object."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' -folder",}),
                "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}),
                "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2', 'fp8_e4m3fn_fast_no_ffn'], {"default": 'disabled', "tooltip": "optional quantization method"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
            },
        }

    RETURN_TYPES = ("WANVIDEOCONTROLNET",)
    RETURN_NAMES = ("controlnet", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Loads ControlNet model from 'https://huggingface.co/collections/TheDenk/wan21-controlnets-68302b430411dafc0d74d2fc'"

    def loadmodel(self, model, base_precision, load_device, quantization):
        """Load a WanVideo ControlNet state dict and build the module.

        Args:
            model: filename under 'ComfyUI/models/controlnet'.
            base_precision: dtype name for non-quantized parameters.
            load_device: "main_device" or "offload_device" — initial placement.
            quantization: quantization mode; "disabled" auto-detects fp8 checkpoints.

        Returns:
            Single-element tuple holding the constructed controlnet module.

        Raises:
            ValueError: if the checkpoint is not a WanVideo ControlNet.
        """
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        transformer_load_device = device if load_device == "main_device" else offload_device

        base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision]

        model_path = folder_paths.get_full_path_or_raise("controlnet", model)
        sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)

        # Fix: validate BEFORE indexing the state dict so an unsupported file
        # fails with the intended ValueError instead of an opaque KeyError.
        if "control_encoder.0.0.weight" not in sd or "controlnet_blocks.0.bias" not in sd:
            raise ValueError("Invalid ControlNet model")

        # Infer the architecture variant from the state dict contents.
        num_layers = 8 if "blocks.7.scale_shift_table" in sd else 6
        out_proj_dim = sd["controlnet_blocks.0.bias"].shape[0]
        downscale_coef = 16 if out_proj_dim == 3072 else 8
        vae_channels = 48 if out_proj_dim == 3072 else 16

        controlnet_cfg = {
            "added_kv_proj_dim": None,
            "attention_head_dim": 128,
            "cross_attn_norm": None,
            "downscale_coef": downscale_coef,
            "eps": 1e-06,
            "ffn_dim": 8960,
            "freq_dim": 256,
            "image_dim": None,
            "in_channels": 3,
            "num_attention_heads": 12,
            "num_layers": num_layers,
            "out_proj_dim": out_proj_dim,
            "patch_size": [1, 2, 2],
            "qk_norm": "rms_norm_across_heads",
            "rope_max_seq_len": 1024,
            "text_dim": 4096,
            "vae_channels": vae_channels
        }
        # Fix: use the module logger instead of print() for consistency with
        # the other log calls in this method.
        log.info(f"Loading WanControlnet with config: {controlnet_cfg}")

        from .wan_controlnet import WanControlnet

        # Instantiate on the meta device: no memory is allocated until the
        # real tensors are assigned below.
        with init_empty_weights():
            controlnet = WanControlnet(**controlnet_cfg)
        controlnet.eval()

        # If no quantization was requested, auto-detect checkpoints that were
        # already saved in fp8 and keep their storage dtype.
        if quantization == "disabled":
            for k, v in sd.items():
                if isinstance(v, torch.Tensor):
                    if v.dtype == torch.float8_e4m3fn:
                        quantization = "fp8_e4m3fn"
                        break
                    elif v.dtype == torch.float8_e5m2:
                        quantization = "fp8_e5m2"
                        break

        if "fp8_e4m3fn" in quantization:
            dtype = torch.float8_e4m3fn
        elif quantization == "fp8_e5m2":
            dtype = torch.float8_e5m2
        else:
            dtype = base_dtype

        # Parameters whose names contain these substrings stay in base
        # precision even when fp8 quantization is active.
        params_to_keep = {"norm", "head", "time_in", "vector_in", "controlnet_patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter"}

        log.info("Using accelerate to load and assign controlnet model weights to device...")
        param_count = sum(1 for _ in controlnet.named_parameters())
        for name, param in tqdm(controlnet.named_parameters(),
                desc=f"Loading transformer parameters to {transformer_load_device}",
                total=param_count,
                leave=True):
            dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
            if "controlnet_patch_embedding" in name:
                # Patch embedding is kept in full fp32 regardless of mode.
                dtype_to_use = torch.float32
            set_module_tensor_to_device(controlnet, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])

        # The state dict has been fully consumed; free it before any move.
        del sd

        # NOTE(review): relies on WanControlnet exposing a .device attribute —
        # a plain nn.Module does not; confirm against wan_controlnet.py.
        if load_device == "offload_device" and controlnet.device != offload_device:
            log.info(f"Moving controlnet model from {controlnet.device} to {offload_device}")
            controlnet.to(offload_device)
            gc.collect()
            mm.soft_empty_cache()

        return (controlnet,)
|
|
|
|
|
class WanVideoControlnetApply:
    """ComfyUI node: attaches ControlNet conditioning to a WanVideo model patcher."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("WANVIDEOMODEL", ),
                "controlnet": ("WANVIDEOCONTROLNET", ),
                "control_images": ("IMAGE", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.0001, "tooltip": "controlnet strength"}),
                "control_stride": ("INT", {"default": 3, "min": 1, "max": 8, "step": 1, "tooltip": "controlnet stride"}),
                "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply controlnet"}),
                "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply controlnet"}),
            }
        }

    RETURN_TYPES = ("WANVIDEOMODEL",)
    RETURN_NAMES = ("model", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, model, controlnet, control_images, strength, control_stride, control_start_percent, control_end_percent):
        """Clone the model patcher and store ControlNet settings in its transformer_options.

        Returns a single-element tuple holding the patched model clone.
        """
        patcher = model.clone()
        # Reuse the existing transformer_options dict, creating it on demand.
        transformer_options = patcher.model_options.setdefault('transformer_options', {})

        # Reorder image batch into a video tensor and rescale 0..1 -> -1..1.
        # assumes control_images is (frames, H, W, C) — TODO confirm against callers
        frames = control_images.permute(3, 0, 1, 2).unsqueeze(0).contiguous()
        frames = frames * 2.0 - 1.0

        transformer_options["controlnet"] = {
            "controlnet": controlnet,
            "control_latents": frames,
            "controlnet_strength": strength,
            "control_stride": control_stride,
            "controlnet_start": control_start_percent,
            "controlnet_end": control_end_percent,
        }

        return (patcher,)
|
|
|
|
|
# ComfyUI registration: maps the node type name (used in saved workflows)
# to the implementing class.
NODE_CLASS_MAPPINGS = {
    "WanVideoControlnetLoader": WanVideoControlnetLoader,
    "WanVideoControlnet": WanVideoControlnetApply,
}
# Human-readable titles shown in the ComfyUI node picker; keys must match
# NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoControlnetLoader": "WanVideo Controlnet Loader",
    "WanVideoControlnet": "WanVideo Controlnet Apply",
}
|
|
|
|
|
|