# Mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper (commit cf812a0)
import torch
from ..utils import log
import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
from tqdm import tqdm
import gc
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
import folder_paths
import json
import numpy as np
class WanVideoUni3C_ControlnetLoader:
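    """Loads a Uni3C ControlNet checkpoint for WanVideo, optionally casting
    weights to fp8 and compiling the transformer blocks with torch.compile."""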
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' -folder",}),
"base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
"quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e5m2'], {"default": 'disabled', "tooltip": "optional quantization method"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
"attention_mode": ([
"sdpa",
"sageattn",
], {"default": "sdpa"}),
},
"optional": {
"compile_args": ("WANCOMPILEARGS", ),
#"block_swap_args": ("BLOCKSWAPARGS", ),
}
}
RETURN_TYPES = ("WANVIDEOCONTROLNET",)
RETURN_NAMES = ("controlnet", )
FUNCTION = "loadmodel"
CATEGORY = "WanVideoWrapper"
def loadmodel(self, model, base_precision, load_device, quantization, attention_mode, compile_args=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
transformer_load_device = device if load_device == "main_device" else offload_device
base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision]
model_path = folder_paths.get_full_path_or_raise("controlnet", model)
sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)
if not "controlnet_patch_embedding.weight" in sd:
raise ValueError("Invalid ControlNet model")
in_channels = sd["controlnet_patch_embedding.weight"].shape[1]
ffn_dim = sd["controlnet_blocks.0.ffn.0.bias"].shape[0]
controlnet_cfg = {
"in_channels": in_channels,
"conv_out_dim": 5120,
"time_embed_dim": 5120,
"dim": 1024,
"ffn_dim": ffn_dim,
"num_heads": 16,
"num_layers": 20,
"add_channels": 7,
"mid_channels": 256,
"attention_mode": attention_mode,
"quantized": True if quantization != "disabled" else False,
"base_dtype": base_dtype
}
from .controlnet import WanControlNet
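        # Build the model skeleton on the meta device (no memory allocated);
        # real weights are assigned tensor-by-tensor below via accelerate.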
with init_empty_weights():
controlnet = WanControlNet(controlnet_cfg)
controlnet.eval()
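        # If quantization wasn't requested, detect checkpoints already stored
        # in fp8 by inspecting the tensor dtypes and adopt the matching mode.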
if quantization == "disabled":
for k, v in sd.items():
if isinstance(v, torch.Tensor):
if v.dtype == torch.float8_e4m3fn:
quantization = "fp8_e4m3fn"
break
elif v.dtype == torch.float8_e5m2:
quantization = "fp8_e5m2"
break
if "fp8_e4m3fn" in quantization:
dtype = torch.float8_e4m3fn
elif quantization == "fp8_e5m2":
dtype = torch.float8_e5m2
else:
dtype = base_dtype
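        # Parameters whose names contain these keywords are kept in the base
        # dtype even when fp8 quantization is enabled.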
params_to_keep = {"norm", "head", "time_in", "vector_in", "controlnet_patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter", "proj_in"}
log.info("Using accelerate to load and assign controlnet model weights to device...")
param_count = sum(1 for _ in controlnet.named_parameters())
for name, param in tqdm(controlnet.named_parameters(),
desc=f"Loading transformer parameters to {transformer_load_device}",
total=param_count,
leave=True):
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
if "controlnet_patch_embedding" in name:
dtype_to_use = torch.float32
set_module_tensor_to_device(controlnet, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])
del sd
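        # Optional torch.compile setup, driven by settings from the compile args input.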
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
try:
if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
except Exception as e:
log.warning(f"Could not set recompile_limit: {e}")
if compile_args["compile_transformer_blocks_only"]:
for i, block in enumerate(controlnet.controlnet_blocks):
controlnet.controlnet_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
else:
controlnet = torch.compile(controlnet, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
if load_device == "offload_device" and controlnet.device != offload_device:
log.info(f"Moving controlnet model from {controlnet.device} to {offload_device}")
controlnet.to(offload_device)
gc.collect()
mm.soft_empty_cache()
return (controlnet,)
class WanVideoUni3C_embeds:
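    """Packages a loaded Uni3C ControlNet with strength/scheduling settings and
    an optional render latent into a UNI3C_EMBEDS dict for downstream WanVideo nodes."""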
@classmethod
def INPUT_TYPES(s):
return {"required": {
"controlnet": ("WANVIDEOCONTROLNET",),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply the controlnet"}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply the controlnet"}),
},
"optional": {
"render_latent": ("LATENT",),
"render_mask": ("MASK", {"tooltip": "NOT IMPLEMENTED!"}),
},
}
RETURN_TYPES = ("UNI3C_EMBEDS", )
RETURN_NAMES = ("uni3c_embeds",)
FUNCTION = "process"
CATEGORY = "WanVideoWrapper"
def process(self, controlnet, strength, start_percent, end_percent, render_latent=None, render_mask=None):
device = mm.get_torch_device()
latent_mask = latents = None
        if render_latent is not None:
            latents = render_latent["samples"]
        if render_mask is not None:
            raise NotImplementedError("render_mask is not implemented at this time")
            # Unreachable sketch of the intended behavior: resize the mask to
            # the pixel-space video dimensions implied by the latent shape.
            # nframe = latents.shape[2] * 4
            # height = latents.shape[3] * 8
            # width = latents.shape[4] * 8
            # mask = torch.nn.functional.interpolate(
            #     render_mask.unsqueeze(0).unsqueeze(0),  # add batch and channel dims [1,1,T,H,W]
            #     size=(nframe, height, width),
            #     mode='trilinear',
            #     align_corners=False
            # ).squeeze(0)
            # latent_mask = mask.unsqueeze(0).to(device)
            # log.info(f"latent mask shape {latent_mask.shape}")
# # load camera
# cam_info = json.load(open(f"{render_path}/cam_info.json"))
# w2cs = torch.tensor(np.array(cam_info["extrinsic"]), dtype=torch.float32, device=device)
# intrinsic = torch.tensor(np.array(cam_info["intrinsic"]), dtype=torch.float32, device=device)
# intrinsic[0, :] = intrinsic[0, :] / cam_info["width"] * width
# intrinsic[1, :] = intrinsic[1, :] / cam_info["height"] * height
# intrinsic = intrinsic[None].repeat(nframe, 1, 1)
# from .utils import build_cameras, set_initial_camera, traj_map
# focal_length = 1.0
# start_elevation = 5.0
# depth_avg = 0.5
# traj_type = "orbit"
# cam_traj, x_offset, y_offset, z_offset, d_theta, d_phi, d_r = traj_map(traj_type)
# focallength_px = focal_length * width
# K = torch.tensor([[focallength_px, 0, width / 2],
# [0, focallength_px, height / 2],
# [0, 0, 1]], dtype=torch.float32)
# K_inv = K.inverse()
# intrinsic = K[None].repeat(nframe, 1, 1)
# w2c_0, c2w_0 = set_initial_camera(start_elevation, depth_avg)
# w2cs, c2ws, intrinsic = build_cameras(cam_traj=cam_traj,
# w2c_0=w2c_0,
# c2w_0=c2w_0,
# intrinsic=intrinsic,
# nframe=nframe,
# focal_length=focal_length,
# d_theta=d_theta,
# d_phi=d_phi,
# d_r=d_r,
# radius=depth_avg,
# x_offset=x_offset,
# y_offset=y_offset,
# z_offset=z_offset)
# from .camera import get_camera_embedding
# camera_embedding = get_camera_embedding(intrinsic, w2cs, nframe, height, width, normalize=True)
#print("camera embedding shape", camera_embedding.shape)
uni3c_embeds = {
"controlnet": controlnet,
"controlnet_weight": strength,
"start": start_percent,
"end": end_percent,
"render_latent": latents,
"render_mask": latent_mask,
"camera_embedding": None
}
return (uni3c_embeds,)
NODE_CLASS_MAPPINGS = {
"WanVideoUni3C_ControlnetLoader": WanVideoUni3C_ControlnetLoader,
"WanVideoUni3C_embeds": WanVideoUni3C_embeds,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoUni3C_ControlnetLoader": "WanVideo Uni3C Controlnet Loader",
"WanVideoUni3C_embeds": "WanVideo Uni3C Embeds",
}