import os
import gc

import torch
import numpy as np
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

import comfy.model_management as mm
from comfy.utils import load_torch_file
import folder_paths

from ..utils import log, dict_to_device
from .motion4d import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder
from .mtv import prepare_motion_embeddings

script_directory = os.path.dirname(os.path.abspath(__file__))
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()

# Default on-disk location for the NLF torchscript model.
local_model_path = os.path.join(folder_paths.models_dir, "nlf", "nlf_l_multi_0.3.2.torchscript")


class DownloadAndLoadNLFModel:
    """ComfyUI node: download (if missing) and load the NLF torchscript model."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "url": (
                    [
                        "https://github.com/isarandi/nlf/releases/download/v0.3.2/nlf_l_multi_0.3.2.torchscript"
                    ],
                )
            },
        }

    RETURN_TYPES = ("NLFMODEL",)
    RETURN_NAMES = ("nlf_model", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, url):
        """Download the model to ``local_model_path`` if absent, then jit-load it in eval mode."""
        if not os.path.exists(local_model_path):
            log.info(f"Downloading NLF model to: {local_model_path}")
            import requests
            os.makedirs(os.path.dirname(local_model_path), exist_ok=True)
            # Timeout so a dead connection cannot hang the node forever.
            response = requests.get(url, timeout=600)
            if response.status_code == 200:
                with open(local_model_path, "wb") as f:
                    f.write(response.content)
            else:
                # Fail loudly here; silently continuing would crash below in
                # torch.jit.load with a confusing missing-file error.
                raise RuntimeError(f"Failed to download file: {response.status_code}")
        model = torch.jit.load(local_model_path).eval()
        return (model,)


class LoadNLFModel:
    """ComfyUI node: load an NLF torchscript model from an explicit path."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "path": ("STRING", {"default": local_model_path}),
            },
        }

    RETURN_TYPES = ("NLFMODEL",)
    RETURN_NAMES = ("nlf_model", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, path):
        """Jit-load the torchscript model at ``path`` in eval mode."""
        model = torch.jit.load(path).eval()
        return (model,)


class LoadVQVAE:
    """ComfyUI node: build the SMPL motion VQ-VAE and load weights from a 'vae' checkpoint."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": (folder_paths.get_filename_list("vae"),
                               {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}),
            },
        }

    RETURN_TYPES = ("VQVAE",)
    RETURN_NAMES = ("vqvae", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, model_name):
        """Instantiate encoder/quantizer/decoder, load the state dict strictly, return the VQ-VAE."""
        model_path = folder_paths.get_full_path("vae", model_name)
        vae_sd = load_torch_file(model_path, safe_load=True)

        # Motion tokenizer: 3-channel joint coords -> 3072-dim codebook space and back.
        motion_encoder = Encoder(
            in_channels=3,
            mid_channels=[128, 512],
            out_channels=3072,
            downsample_time=[2, 2],
            downsample_joint=[1, 1],
        )
        motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072)
        motion_decoder = Decoder(
            in_channels=3072,
            mid_channels=[512, 128],
            out_channels=3,
            upsample_rate=2.0,
            frame_upsample_rate=[2.0, 2.0],
            joint_upsample_rate=[1.0, 1.0],
        )

        vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device)
        vqvae.load_state_dict(vae_sd, strict=True)
        return (vqvae,)


class MTVCrafterEncodePoses:
    """ComfyUI node: normalize NLF joint predictions and encode them to motion tokens."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "vqvae": ("VQVAE", {"tooltip": "VQVAE model"}),
                "poses": ("NLFPRED", {"tooltip": "Input poses for the model"}),
            },
        }

    RETURN_TYPES = ("MTVCRAFTERMOTION", "NLFPRED")
    RETURN_NAMES = ("mtvcrafter_motion", "pose_results")
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, vqvae, poses):
        """Z-score the joints with the dataset mean/std, tokenize with the VQ-VAE,
        and also return the de-normalized reconstruction for preview/debugging."""
        # Dataset statistics shipped with the node; mean/std shape is (24, 3) joints.
        global_mean = np.load(os.path.join(script_directory, "data", "mean.npy"))
        global_std = np.load(os.path.join(script_directory, "data", "std.npy"))

        # First (highest-confidence) detection per frame — assumes one subject. TODO confirm.
        smpl_poses = np.array([pose[0].cpu().numpy() for pose in poses['joints3d_nonparam'][0]])
        norm_poses = torch.tensor((smpl_poses - global_mean) / global_std).unsqueeze(0)
        log.info(f"norm_poses shape: {norm_poses.shape}, dtype: {norm_poses.dtype}")

        vqvae.to(device)
        # Inference only — no_grad avoids building an autograd graph.
        with torch.no_grad():
            motion_tokens, vq_loss = vqvae(norm_poses.to(device), return_vq=True)
            recon = vqvae(norm_poses.to(device))[0][0].to(dtype=torch.float32).cpu().detach()
        recon_motion = recon * global_std + global_mean
        vqvae.to(offload_device)

        poses_dict = {
            'mtv_motion_tokens': motion_tokens,
            'global_mean': global_mean,
            'global_std': global_std,
        }
        return (poses_dict, recon_motion)


class NLFPredict:
    """ComfyUI node: run NLF SMPL detection on a batch of images."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("NLFMODEL",),
                "images": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("NLFPRED", )
    RETURN_NAMES = ("pose_results",)
    FUNCTION = "predict"
    CATEGORY = "WanVideoWrapper"

    def predict(self, model, images):
        """Run detection on NCHW images, offload the model, and collect joints3d_nonparam."""
        model.to(device)
        # ComfyUI IMAGE is NHWC; the torchscript model expects NCHW.
        with torch.no_grad():
            pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device))
        model.to(offload_device)
        pred = dict_to_device(pred, offload_device)

        pose_results = {'joints3d_nonparam': []}
        for key in pose_results.keys():
            pose_results[key].append(pred[key] if key in pred else None)
        return (pose_results,)


class DrawNLFPoses:
    """ComfyUI node: rasterize NLF pose predictions into control images."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "poses": ("NLFPRED", {"tooltip": "Input poses for the model"}),
                "width": ("INT", {"default": 512}),
                "height": ("INT", {"default": 512}),
            },
        }

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("image",)
    FUNCTION = "predict"
    CATEGORY = "WanVideoWrapper"

    def predict(self, poses, width, height):
        """Accept either the NLFPredict dict or a raw pose list and draw control conditions."""
        from .draw_pose import get_control_conditions
        if isinstance(poses, dict):
            pose_input = poses['joints3d_nonparam'][0] if 'joints3d_nonparam' in poses else poses
        else:
            pose_input = poses
        control_conditions = get_control_conditions(pose_input, height, width)
        return (control_conditions,)


NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadNLFModel": DownloadAndLoadNLFModel,
    "NLFPredict": NLFPredict,
    "DrawNLFPoses": DrawNLFPoses,
    "LoadVQVAE": LoadVQVAE,
    "MTVCrafterEncodePoses": MTVCrafterEncodePoses
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadNLFModel": "(Download)Load NLF Model",
    "NLFPredict": "NLF Predict",
    "DrawNLFPoses": "Draw NLF Poses",
    "LoadVQVAE": "Load VQVAE",
    "MTVCrafterEncodePoses": "MTV Crafter Encode Poses"
}