| import os |
| import torch |
| import gc |
| from ..utils import log, dict_to_device |
| import numpy as np |
| from accelerate import init_empty_weights |
| from accelerate.utils import set_module_tensor_to_device |
|
|
| import comfy.model_management as mm |
| from comfy.utils import load_torch_file |
| import folder_paths |
|
|
# Absolute directory containing this module.
script_directory = os.path.dirname(os.path.abspath(__file__))

# Devices reported by ComfyUI's model management: one for running models,
# one for parking them afterwards.
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()

# Default on-disk location of the NLF TorchScript checkpoint.
local_model_path = os.path.join(
    folder_paths.models_dir,
    "nlf",
    "nlf_l_multi_0.3.2.torchscript",
)
|
|
| from .motion4d import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder |
| from .mtv import prepare_motion_embeddings |
|
|
class DownloadAndLoadNLFModel:
    """ComfyUI node that downloads the NLF TorchScript checkpoint if it is
    missing from ``local_model_path`` and then loads it."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input spec: a single-choice list with the release URL."""
        return {
            "required": {
                "url": (
                    [
                        "https://github.com/isarandi/nlf/releases/download/v0.3.2/nlf_l_multi_0.3.2.torchscript"
                    ],
                )
            },
        }

    RETURN_TYPES = ("NLFMODEL",)
    RETURN_NAMES = ("nlf_model", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, url):
        """Load the NLF model, downloading it from ``url`` first when needed.

        Args:
            url: Location of the TorchScript checkpoint. Only used when
                ``local_model_path`` does not exist yet.

        Returns:
            A one-element tuple holding the loaded model in eval mode.

        Raises:
            RuntimeError: If the checkpoint has to be downloaded and the
                download fails.
        """
        if not os.path.exists(local_model_path):
            log.info(f"Downloading NLF model to: {local_model_path}")
            self._download(url, local_model_path)

        model = torch.jit.load(local_model_path).eval()
        return (model,)

    @staticmethod
    def _download(url, destination, chunk_size=1024 * 1024, timeout=60):
        """Stream ``url`` to ``destination`` without leaving a partial file behind."""
        import requests

        os.makedirs(os.path.dirname(destination), exist_ok=True)

        # Write to a sibling temp file first: an interrupted download must not
        # leave a truncated checkpoint at the final path, because the
        # existence check in loadmodel would then treat it as complete.
        partial_path = destination + ".part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        if chunk:
                            f.write(chunk)
            os.replace(partial_path, destination)
        except (requests.RequestException, OSError) as err:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise RuntimeError(
                f"Failed to download NLF model from {url}: {err}"
            ) from err
|
|
class LoadNLFModel:
    """ComfyUI node that loads an NLF TorchScript model from a given path."""

    RETURN_TYPES = ("NLFMODEL",)
    RETURN_NAMES = ("nlf_model", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: one string field defaulting to ``local_model_path``."""
        path_field = ("STRING", {"default": local_model_path})
        return {
            "required": {
                "path": path_field,
            },
        }

    def loadmodel(self, path):
        """Load the TorchScript file at ``path`` and return it in eval mode.

        Returns:
            A one-element tuple holding the loaded model.
        """
        scripted = torch.jit.load(path)
        model = scripted.eval()
        return (model,)
|
|
class LoadVQVAE:
    """ComfyUI node that builds the SMPL motion VQ-VAE and loads its weights."""

    RETURN_TYPES = ("VQVAE",)
    RETURN_NAMES = ("vqvae", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: a choice among the files registered under "vae"."""
        vae_files = folder_paths.get_filename_list("vae")
        options = {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}
        return {
            "required": {
                "model_name": (vae_files, options),
            },
        }

    def loadmodel(self, model_name):
        """Instantiate the VQ-VAE on the compute device and load ``model_name`` into it.

        Args:
            model_name: File name resolved through ``folder_paths`` in the
                "vae" category.

        Returns:
            A one-element tuple holding the ``SMPL_VQVAE`` instance.
        """
        checkpoint_path = folder_paths.get_full_path("vae", model_name)
        state_dict = load_torch_file(checkpoint_path, safe_load=True)

        # The architecture is hard-coded; weights are loaded with strict=True,
        # so the checkpoint keys must match this configuration exactly.
        encoder_config = dict(
            in_channels=3,
            mid_channels=[128, 512],
            out_channels=3072,
            downsample_time=[2, 2],
            downsample_joint=[1, 1],
        )
        decoder_config = dict(
            in_channels=3072,
            mid_channels=[512, 128],
            out_channels=3,
            upsample_rate=2.0,
            frame_upsample_rate=[2.0, 2.0],
            joint_upsample_rate=[1.0, 1.0],
        )

        encoder = Encoder(**encoder_config)
        quantizer = VectorQuantizer(nb_code=8192, code_dim=3072)
        decoder = Decoder(**decoder_config)

        vqvae = SMPL_VQVAE(encoder, decoder, quantizer)
        vqvae = vqvae.to(device)
        vqvae.load_state_dict(state_dict, strict=True)

        return (vqvae,)
|
|
class MTVCrafterEncodePoses:
    """ComfyUI node that normalizes NLF 3D joints and encodes them with the VQ-VAE."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: a VQVAE model and NLF pose predictions."""
        return {
            "required": {
                "vqvae": ("VQVAE", {"tooltip": "VQVAE model"}),
                "poses": ("NLFPRED", {"tooltip": "Input poses for the model"}),
            },
        }

    RETURN_TYPES = ("MTVCRAFTERMOTION", "NLFPRED")
    RETURN_NAMES = ("mtvcrafter_motion", "pose_results")
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, vqvae, poses):
        """Encode the predicted joints into motion tokens and a reconstruction.

        Args:
            vqvae: Model called as ``vqvae(x, return_vq=True)`` and ``vqvae(x)``.
            poses: Dict whose ``'joints3d_nonparam'`` entry is a one-element
                list of per-frame joint tensors; only index 0 of each frame is
                used (presumably the first detected person -- TODO confirm).

        Returns:
            A tuple ``(poses_dict, recon_motion)``. ``poses_dict`` holds the
            motion tokens plus the normalization mean/std; ``recon_motion`` is
            the de-normalized VQ-VAE reconstruction on the CPU.

        Raises:
            ValueError: If ``poses['joints3d_nonparam'][0]`` is ``None``,
                i.e. the predictor produced no joints.
        """
        frames = poses['joints3d_nonparam'][0]
        if frames is None:
            # Fail early with a readable message instead of a TypeError from
            # iterating None further down.
            raise ValueError(
                "poses['joints3d_nonparam'] contains no predictions (None); "
                "cannot encode motion."
            )

        global_mean = np.load(os.path.join(script_directory, "data", "mean.npy"))
        global_std = np.load(os.path.join(script_directory, "data", "std.npy"))

        smpl_poses = np.array([pose[0].cpu().numpy() for pose in frames])

        norm_poses = torch.tensor((smpl_poses - global_mean) / global_std).unsqueeze(0)
        print(f"norm_poses shape: {norm_poses.shape}, dtype: {norm_poses.dtype}")

        vqvae.to(device)
        try:
            # Inference only: no autograd graph is needed for either pass.
            with torch.no_grad():
                model_input = norm_poses.to(device)  # transfer once, reuse below
                motion_tokens, _ = vqvae(model_input, return_vq=True)
                recon = vqvae(model_input)[0][0].to(dtype=torch.float32).cpu().detach()
        finally:
            # Always free the compute device, even when the forward pass fails.
            vqvae.to(offload_device)

        recon_motion = recon * global_std + global_mean

        poses_dict = {
            'mtv_motion_tokens': motion_tokens,
            'global_mean': global_mean,
            'global_std': global_std
        }

        return poses_dict, recon_motion
|
|
|
|
class NLFPredict:
    """ComfyUI node that runs the NLF model on a batch of images."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: an NLF model and an image batch."""
        return {"required": {
            "model": ("NLFMODEL",),
            "images": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("NLFPRED", )
    RETURN_NAMES = ("pose_results",)
    FUNCTION = "predict"
    CATEGORY = "WanVideoWrapper"

    def predict(self, model, images):
        """Detect poses in ``images`` and return the non-parametric 3D joints.

        Args:
            model: Object exposing ``detect_smpl_batched``.
            images: 4-D tensor; axes are permuted with ``(0, 3, 1, 2)`` before
                the call (presumably channels-last to channels-first --
                TODO confirm).

        Returns:
            A one-element tuple holding ``{'joints3d_nonparam': [value]}``,
            where ``value`` is the model's entry for that key, or ``None``
            when the key is absent from the prediction.
        """
        model.to(device)
        try:
            # Inference only: avoid building an autograd graph.
            with torch.no_grad():
                pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device))
        finally:
            # Always free the compute device, even when detection fails.
            model.to(offload_device)

        pred = dict_to_device(pred, offload_device)

        pose_results = {
            'joints3d_nonparam': [],
        }

        for key in pose_results:
            pose_results[key].append(pred[key] if key in pred else None)

        return (pose_results,)
|
|
class DrawNLFPoses:
    """ComfyUI node that renders NLF poses into control images."""

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("image",)
    FUNCTION = "predict"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: poses plus the output width and height."""
        required = {
            "poses": ("NLFPRED", {"tooltip": "Input poses for the model"}),
            "width": ("INT", {"default": 512}),
            "height": ("INT", {"default": 512}),
        }
        return {"required": required}

    def predict(self, poses, width, height):
        """Draw ``poses`` via ``get_control_conditions``.

        Args:
            poses: Either a dict with a ``'joints3d_nonparam'`` list (its
                first element is drawn) or any other object, which is passed
                through unchanged.
            width: Forwarded as the third argument of the drawing helper.
            height: Forwarded as the second argument of the drawing helper.

        Returns:
            A one-element tuple holding the helper's result.
        """
        from .draw_pose import get_control_conditions

        print(type(poses))

        # Only a dict that carries the joints key is unwrapped; everything
        # else (including a dict without that key) is drawn as-is.
        carries_joints = isinstance(poses, dict) and 'joints3d_nonparam' in poses
        pose_input = poses['joints3d_nonparam'][0] if carries_joints else poses

        rendered = get_control_conditions(pose_input, height, width)
        return (rendered,)
|
|
# Node identifier -> implementing class.
# NOTE(review): LoadNLFModel is not listed in this mapping -- confirm
# whether leaving it unregistered is intentional.
NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadNLFModel": DownloadAndLoadNLFModel,
    "NLFPredict": NLFPredict,
    "DrawNLFPoses": DrawNLFPoses,
    "LoadVQVAE": LoadVQVAE,
    "MTVCrafterEncodePoses": MTVCrafterEncodePoses,
}

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadNLFModel": "(Download)Load NLF Model",
    "NLFPredict": "NLF Predict",
    "DrawNLFPoses": "Draw NLF Poses",
    "LoadVQVAE": "Load VQVAE",
    "MTVCrafterEncodePoses": "MTV Crafter Encode Poses",
}
|
|