diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..bbcb7a3b482d765050eab4ac90f93d4bf67f1ca9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,16 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +# Auto detect text files and perform LF normalization +* text=auto +MTV/data/mean.npy filter=lfs diff=lfs merge=lfs -text +MTV/data/std.npy filter=lfs diff=lfs merge=lfs -text +configs/T5_tokenizer/spiece.model filter=lfs diff=lfs merge=lfs -text +configs/T5_tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text +example_workflows/example_inputs/MTV_crafter_example_pose.mp4 filter=lfs diff=lfs merge=lfs -text +example_workflows/example_inputs/env.png filter=lfs diff=lfs merge=lfs -text +example_workflows/example_inputs/human.png filter=lfs diff=lfs merge=lfs -text +example_workflows/example_inputs/jeep.mp4 filter=lfs diff=lfs merge=lfs -text +example_workflows/example_inputs/wolf_interpolated.mp4 filter=lfs diff=lfs merge=lfs -text +example_workflows/example_inputs/woman.jpg filter=lfs diff=lfs merge=lfs -text +example_workflows/example_inputs/woman.wav filter=lfs diff=lfs merge=lfs -text +fantasyportrait/models/face_det.onnx filter=lfs diff=lfs merge=lfs -text +fantasyportrait/models/face_landmark.onnx filter=lfs diff=lfs merge=lfs -text +multitalk/encoded_silence.safetensors filter=lfs diff=lfs merge=lfs -text diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000000000000000000000000000000000000..2cbf13eac5b7a649c039eea4d9eb842bd59dabe8 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: [kijai] diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..181cf4bf1f2a0c88a503d56592574b1f6902b1d2 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,25 @@ +name: Publish to Comfy registry +on: + workflow_dispatch: + push: + branches: + - main + paths: + - "pyproject.toml" + +permissions: + issues: write + +jobs: + publish-node: 
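+    # Publishes to the Comfy Registry on manual dispatch or when pyproject.toml changes on main; runs only for the upstream repo and needs the REGISTRY_ACCESS_TOKEN repository secret.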
+ name: Publish Custom Node to registry + runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'kijai' }} + steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Publish Custom Node + uses: Comfy-Org/publish-node-action@v1 + with: + ## Add your own personal access token to your Github Repository secrets and reference it here. + personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d32eaccb39749ed6d9e0b7bd75ec02149d6b59bb --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +output/ +*__pycache__/ +samples*/ +runs/ +checkpoints/ +master_ip +logs/ +*.DS_Store +.idea +tools/ +.vscode/ +convert_* +*.pt \ No newline at end of file diff --git a/ATI/motion.py b/ATI/motion.py new file mode 100644 index 0000000000000000000000000000000000000000..3d615e65c63c28c2a11696148e0fd53e3fb5067d --- /dev/null +++ b/ATI/motion.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple, Union +import numpy as np +import torch + +def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], quant_multi: int = 8, **kwargs): + # tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps. + # frame_size: tuple (W, H) + + tracks = torch.from_numpy(tracks_np).float() + + if tracks.shape[1] == 121: + tracks = torch.permute(tracks, (1, 0, 2, 3)) + + tracks, visibles = tracks[..., :2], tracks[..., 2:3] + short_edge = min(*frame_size) + + tracks = tracks - torch.tensor([*frame_size]).type_as(tracks) / 2 + tracks = tracks / short_edge * 2 + + visibles = visibles * 2 - 1 + + trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape) + + out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4) + out_0 = out_[:1] + out_l = out_[1:] # 121 => 120 | 1 + out_l = torch.repeat_interleave(out_l, 2, dim=0)[1::3] # 120 => 240 => 80 + return torch.cat([out_0, out_l], dim=0) diff --git a/ATI/motion_patch.py b/ATI/motion_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..1adf1e5eae11805d1f86e388e53c31b3c1e0973a --- /dev/null +++ b/ATI/motion_patch.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
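+# ind_sel and merge_final below are gather/merge helpers adapted from VoGE (see the link above ind_sel);
+# patch_motion samples latent features at each track's first-frame position, splats them onto later
+# frames with an exp(-distance * temperature) weight (top-k per pixel), and blends the result into the
+# Wan image-conditioning latents, returning mask + feature channels for ATI motion control.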
+ +from typing import List, Optional, Tuple, Union +import torch + + +# Refer to https://github.com/Angtian/VoGE/blob/main/VoGE/Utils.py +def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1): + """ + :param target: [... (can be k or 1), n > M, ...] + :param ind: [... (k), M] + :param dim: dim to apply index on + :return: sel_target [... (k), M, ...] + """ + assert ( + len(ind.shape) > dim + ), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape)) + + target = target.expand( + *tuple( + [ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)] + + [ + -1, + ] + * (len(target.shape) - dim) + ) + ) + + ind_pad = ind + + if len(target.shape) > dim + 1: + for _ in range(len(target.shape) - (dim + 1)): + ind_pad = ind_pad.unsqueeze(-1) + ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :]) + + return torch.gather(target, dim=dim, index=ind_pad) + + +def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor): + """ + + :param vert_attr: [n, d] or [b, n, d] color or feature of each vertex + :param weight: [b(optional), w, h, M] weight of selected vertices + :param vert_assign: [b(optional), w, h, M] selective index + :return: + """ + target_dim = len(vert_assign.shape) - 1 + if len(vert_attr.shape) == 2: + assert vert_attr.shape[0] > vert_assign.max() + # [n, d] ind: [b(optional), w, h, M]-> [b(optional), w, h, M, d] + # sel_attr = ind_sel( + # vert_attr[(None,) * target_dim], vert_assign.type(torch.long), dim=target_dim + # ) + new_shape = [1] * target_dim + list(vert_attr.shape) + tensor = vert_attr.reshape(new_shape) + sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim) + else: + assert vert_attr.shape[1] > vert_assign.max() + #sel_attr = ind_sel( + # vert_attr[:, *(None,) * (target_dim - 1)], vert_assign.type(torch.long), dim=target_dim + #) + new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:]) + tensor = vert_attr.reshape(new_shape) + sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim) + + # [b(optional), w, h, M] + final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2) + return final_attr + + +def patch_motion( + tracks: torch.FloatTensor, # (B, T, N, 4) + vid: torch.FloatTensor, # (C, T, H, W) + temperature: float = 220.0, + vae_divide: tuple = (4, 16), + topk: int = 2, +): + with torch.no_grad(): + _, T, H, W = vid.shape + N = tracks.shape[2] + _, tracks, visible = torch.split( + tracks, [1, 2, 1], dim=-1 + ) # (B, T, N, 2) | (B, T, N, 1) + tracks_n = tracks / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks.device) + tracks_n = tracks_n.clamp(-1, 1) + visible = visible.clamp(0, 1) + + xx = torch.linspace(-W / min(H, W), W / min(H, W), W) + yy = torch.linspace(-H / min(H, W), H / min(H, W), H) + + grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to( + tracks.device + ) + + tracks_pad = tracks[:, 1:] + visible_pad = visible[:, 1:] + + visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1) + tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum( + 1 + ) / (visible_align + 1e-5) + dist_ = ( + (tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1) + ) # T, H, W, N + weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view( + T - 1, 1, 1, N + ) + vert_weight, vert_index = torch.topk( + weight, k=min(topk, weight.shape[-1]), dim=-1 + ) + + grid_mode = "bilinear" + point_feature = 
torch.nn.functional.grid_sample( + vid[vae_divide[0]:].permute(1, 0, 2, 3)[:1], + tracks_n[:, :1].type(vid.dtype), + mode=grid_mode, + padding_mode="zeros", + align_corners=False, + ) + point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16 + + out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W + out_weight = vert_weight.sum(-1) # T - 1, H, W + + # out feature -> already soft weighted + mix_feature = out_feature + vid[vae_divide[0]:, 1:] * (1 - out_weight.clamp(0, 1)) + + out_feature_full = torch.cat([vid[vae_divide[0]:, :1], mix_feature], dim=1) # C, T, H, W + out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W + return torch.cat([out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full], dim=0) diff --git a/ATI/nodes.py b/ATI/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..cf25eff760c11e2e5a94e49af0b9050da5faba20 --- /dev/null +++ b/ATI/nodes.py @@ -0,0 +1,329 @@ +import json +from .motion import process_tracks +import numpy as np +from typing import List, Tuple +import torch +FIXED_LENGTH = 121 +def pad_pts(tr): + """Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating.""" + pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32) + n = pts.shape[0] + if n < FIXED_LENGTH: + pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32) + pts = np.vstack((pts, pad)) + else: + pts = pts[:FIXED_LENGTH] + return pts.reshape(FIXED_LENGTH, 1, 3) + +def age_to_bgr(ratio: float) -> Tuple[int,int,int]: + """ + Map ratio∈[0,1] through: 0→blue, 1/3→green, 2/3→yellow, 1→red. + Returns (B,G,R) for OpenCV. + """ + if ratio <= 1/3: + # blue→green + t = ratio / (1/3) + b = int(255 * (1 - t)) + g = int(255 * t) + r = 0 + elif ratio <= 2/3: + # green→yellow + t = (ratio - 1/3) / (1/3) + b = 0 + g = 255 + r = int(255 * t) + else: + # yellow→red + t = (ratio - 2/3) / (1/3) + b = 0 + g = int(255 * (1 - t)) + r = 255 + return (r, g, b) + +def paint_point_track( + frames: np.ndarray, + point_tracks: np.ndarray, + visibles: np.ndarray, + min_radius: int = 1, + max_radius: int = 6, + max_retain: int = 50 +) -> np.ndarray: + """ + Draws every past point of each track on each frame, with radius and color + interpolated by the point's age (old→small to new→large). 
+ + Args: + frames: [F, H, W, 3] uint8 RGB + point_tracks:[N, F, 2] float32 – (x,y) in pixel coords + visibles: [N, F] bool – visibility mask + min_radius: radius for the very first point (oldest) + max_radius: radius for the current point (newest) + + Returns: + video: [F, H, W, 3] uint8 RGB + """ + import cv2 + num_points, num_frames = point_tracks.shape[:2] + H, W = frames.shape[1:3] + + video = frames.copy() + + for t in range(num_frames): + # start from the original frame + frame = video[t].copy() + + for i in range(num_points): + # draw every past step τ = 0..t + for τ in range(t + 1): + if not visibles[i, τ]: + continue + + if t - τ > max_retain: + continue + + # sub-pixel offset + clamp + x, y = point_tracks[i, τ] + 0.5 + xi = int(np.clip(x, 0, W - 1)) + yi = int(np.clip(y, 0, H - 1)) + + # age‐ratio in [0,1] + if num_frames > 1: + ratio = 1 - float(t - τ) / max_retain + else: + ratio = 1.0 + + # interpolated radius + radius = int(round(min_radius + (max_radius - min_radius) * ratio)) + + # OpenCV draws in BGR order: + color_rgb = age_to_bgr(ratio) + + # filled circle + cv2.circle(frame, (xi, yi), radius, color_rgb, thickness=-1) + + video[t] = frame + + return video + +def parse_json_tracks(tracks): + tracks_data = [] + try: + # If tracks is a string, try to parse it as JSON + if isinstance(tracks, str): + parsed = json.loads(tracks.replace("'", '"')) + tracks_data.extend(parsed) + else: + # If tracks is a list of strings, parse each one + for track_str in tracks: + parsed = json.loads(track_str.replace("'", '"')) + tracks_data.append(parsed) + + # Check if we have a single track (dict with x,y) or a list of tracks + if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]: + # Single track detected, wrap it in a list + tracks_data = [tracks_data] + elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]: + # Already a list of tracks, nothing to do + pass + else: + # Unexpected format + print(f"Warning: Unexpected track format: {type(tracks_data[0])}") + + except json.JSONDecodeError as e: + print(f"Error parsing tracks JSON: {e}") + tracks_data = [] + + return tracks_data + +class WanVideoATITracks: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": ("WANVIDEOMODEL", ), + "tracks": ("STRING",), + "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), + "temperature": ("FLOAT", {"default": 220.0, "min": 0.0, "max": 1000.0, "step": 0.1}), + "topk": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply ATI"}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply ATI"}), + }, + } + + RETURN_TYPES = ("WANVIDEOMODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "patchmodel" + CATEGORY = "WanVideoWrapper" + + def patchmodel(self, model, tracks, width, height, temperature, topk, start_percent, end_percent): + tracks_data = parse_json_tracks(tracks) + arrs = [] + for track in tracks_data: + pts = pad_pts(track) + arrs.append(pts) + + tracks_np = np.stack(arrs, axis=0) + + processed_tracks = process_tracks(tracks_np, (width, height)) + + patcher = model.clone() + 
patcher.model_options["transformer_options"]["ati_tracks"] = processed_tracks.unsqueeze(0) + patcher.model_options["transformer_options"]["ati_temperature"] = temperature + patcher.model_options["transformer_options"]["ati_topk"] = topk + patcher.model_options["transformer_options"]["ati_start_percent"] = start_percent + patcher.model_options["transformer_options"]["ati_end_percent"] = end_percent + + return (patcher,) + +class WanVideoATITracksVisualize: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "images": ("IMAGE",), + "tracks": ("STRING",), + "min_radius": ("INT", {"default": 1, "min": 0, "max": 100, "step": 1, "tooltip": "radius for the very first point (oldest)"}), + "max_radius": ("INT", {"default": 6, "min": 0, "max": 100, "step": 1, "tooltip": "radius for the current point (newest)"}), + "max_retain": ("INT", {"default": 50, "min": 0, "max": 100, "step": 1, "tooltip": "Maximum number of points to retain"}), + }, + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "patchmodel" + CATEGORY = "WanVideoWrapper" + + def patchmodel(self, images, tracks, min_radius, max_radius, max_retain): + tracks_data = parse_json_tracks(tracks) + arrs = [] + for track in tracks_data: + pts = pad_pts(track) + arrs.append(pts) + + tracks_np = np.stack(arrs, axis=0) + track = np.repeat(tracks_np, 2, axis=1)[:, ::3] + points = track[:, :, 0, :2].astype(np.float32) + visibles = track[:, :, 0, 2].astype(np.float32) + + if images.shape[0] < points.shape[1]: + repeat_count = (points.shape[1] + images.shape[0] - 1) // images.shape[0] + images = images.repeat(repeat_count, 1, 1, 1) + images = images[:points.shape[1]] + elif images.shape[0] > points.shape[1]: + images = images[:points.shape[1]] + + video_viz = paint_point_track(images.cpu().numpy(), points, visibles, min_radius, max_radius, max_retain) + video_viz = torch.from_numpy(video_viz).float() + + return (video_viz,) + +from comfy import utils +import types +from .motion_patch import patch_motion + +class WanConcatCondPatch: + def __init__(self, tracks, temperature, topk): + self.tracks = tracks + self.temperature = temperature + self.topk = topk + + def __get__(self, obj, objtype=None): + # Create bound method with stored parameters + def wrapped_concat_cond(self_module, *args, **kwargs): + return modified_concat_cond(self_module, self.tracks, self.temperature, self.topk, *args, **kwargs) + return types.MethodType(wrapped_concat_cond, obj) + +def modified_concat_cond(self, tracks, temperature, topk, **kwargs): + noise = kwargs.get("noise", None) + extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels == 0: + return None + + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], 16): + image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + image = utils.resize_to_batch_size(image, noise.shape[0]) + + if not self.image_to_video or extra_channels == image.shape[1]: + return image + + if image.shape[1] > (extra_channels - 4): + image = image[:, :(extra_channels - 4)] + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :4] + 
else: + if mask.shape[1] != 4: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + if mask.shape[1] == 1: + mask = mask.repeat(1, 4, 1, 1, 1) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + image_cond = torch.cat((mask, image), dim=1) + image_cond_ati = patch_motion(tracks.to(image_cond.device, image_cond.dtype), image_cond[0], + temperature=temperature, topk=topk) + + return image_cond_ati.unsqueeze(0) + +class WanVideoATI_comfy: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": ("MODEL", ), + "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), + "tracks": ("STRING",), + "temperature": ("FLOAT", {"default": 220.0, "min": 0.0, "max": 1000.0, "step": 0.1}), + "topk": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), + }, + } + + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("model", ) + FUNCTION = "patchcond" + CATEGORY = "WanVideoWrapper" + + def patchcond(self, model, tracks, width, height, temperature, topk): + + tracks_data = parse_json_tracks(tracks) + arrs = [] + for track in tracks_data: + pts = pad_pts(track) + arrs.append(pts) + + tracks_np = np.stack(arrs, axis=0) + + processed_tracks = process_tracks(tracks_np, (width, height)) + + model_clone = model.clone() + model_clone.add_object_patch( + "concat_cond", + WanConcatCondPatch( + processed_tracks.unsqueeze(0), temperature, topk + ).__get__(model.model, model.model.__class__) + ) + + return (model_clone,) + +NODE_CLASS_MAPPINGS = { + "WanVideoATITracks": WanVideoATITracks, + "WanVideoATITracksVisualize": WanVideoATITracksVisualize, + "WanVideoATI_comfy": WanVideoATI_comfy, + } +NODE_DISPLAY_NAME_MAPPINGS = { + "WanVideoATITracks": "WanVideo ATI Tracks", + "WanVideoATITracksVisualize": "WanVideo ATI Tracks Visualize", + "WanVideoATI_comfy": "WanVideo ATI Comfy", + } diff --git a/HuMo/audio_proj.py b/HuMo/audio_proj.py new file mode 100644 index 0000000000000000000000000000000000000000..8483a3fbc704a9083de056d4e099fe7f5065381b --- /dev/null +++ b/HuMo/audio_proj.py @@ -0,0 +1,87 @@ +import torch +from einops import rearrange +from torch import nn +from einops import rearrange + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class DummyAdapterLayer(nn.Module): + def __init__(self, layer): + super().__init__() + self.layer = layer + + def forward(self, *args, **kwargs): + return self.layer(*args, **kwargs) + + +class AudioProjModel(nn.Module): + def __init__( + self, + seq_len=5, + blocks=13, # add a new parameter blocks + channels=768, # add a new parameter channels + intermediate_dim=512, + output_dim=1536, + context_tokens=16, + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels # update input_dim 
to be the product of blocks and channels. + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.audio_proj_glob_1 = DummyAdapterLayer(nn.Linear(self.input_dim, intermediate_dim)) + self.audio_proj_glob_2 = DummyAdapterLayer(nn.Linear(intermediate_dim, intermediate_dim)) + self.audio_proj_glob_3 = DummyAdapterLayer(nn.Linear(intermediate_dim, context_tokens * output_dim)) + + self.audio_proj_glob_norm = DummyAdapterLayer(nn.LayerNorm(output_dim)) + + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def forward(self, audio_embeds): + video_length = audio_embeds.shape[1] + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds)) + audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds)) + + context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim) + + context_tokens = self.audio_proj_glob_norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens \ No newline at end of file diff --git a/HuMo/nodes.py b/HuMo/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..3a98ba7d15dbc0775053e47172fb616f009e8c53 --- /dev/null +++ b/HuMo/nodes.py @@ -0,0 +1,287 @@ +import folder_paths +import torch +import torch.nn.functional as F +import os +import json +import torchaudio + +from comfy.utils import load_torch_file, common_upscale +import comfy.model_management as mm + +from accelerate import init_empty_weights +from ..utils import set_module_tensor_to_device, log +from ..nodes import WanVideoEncodeLatentBatch + +script_directory = os.path.dirname(os.path.abspath(__file__)) +device = mm.get_torch_device() +offload_device = mm.unet_offload_device() + +def linear_interpolation_fps(features, input_fps, output_fps, output_len=None): + features = features.transpose(1, 2) # [1, C, T] + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') + return output_features.transpose(1, 2) + +def get_audio_emb_window(audio_emb, frame_num, frame0_idx, audio_shift=2): + zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) + zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) + iter_ = 1 + (frame_num - 1) // 4 + audio_emb_wind = [] + for lt_i in range(iter_): + if lt_i == 0: + st = frame0_idx + lt_i - 2 + ed = frame0_idx + lt_i + 3 + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0) + else: + st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift + ed = frame0_idx + 1 + 4 * lt_i + audio_shift + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) 
else zero_audio_embed + for i in range(st, ed) + ], dim=0) + audio_emb_wind.append(wind_feat) + audio_emb_wind = torch.stack(audio_emb_wind, dim=0) + + return audio_emb_wind, ed - audio_shift + +class WhisperModelLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": (folder_paths.get_filename_list("audio_encoders"), {"tooltip": "These models are loaded from the 'ComfyUI/models/audio_encoders' folder",}), + "base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}), + "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}), + }, + } + + RETURN_TYPES = ("WHISPERMODEL",) + RETURN_NAMES = ("whisper_model", ) + FUNCTION = "loadmodel" + CATEGORY = "WanVideoWrapper" + + def loadmodel(self, model, base_precision, load_device): + from transformers import WhisperConfig, WhisperModel, WhisperFeatureExtractor + + base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision] + + if load_device == "offload_device": + transformer_load_device = offload_device + else: + transformer_load_device = device + + config_path = os.path.join(script_directory, "whisper_config.json") + whisper_config = WhisperConfig(**json.load(open(config_path))) + + with init_empty_weights(): + whisper = WhisperModel(whisper_config).eval() + whisper.decoder = None # we only need the encoder + + feature_extractor_config = { + "chunk_length": 30, + "feature_extractor_type": "WhisperFeatureExtractor", + "feature_size": 128, + "hop_length": 160, + "n_fft": 400, + "n_samples": 480000, + "nb_max_frames": 3000, + "padding_side": "right", + "padding_value": 0.0, + "processor_class": "WhisperProcessor", + "return_attention_mask": False, + "sampling_rate": 16000 + } + + feature_extractor = WhisperFeatureExtractor(**feature_extractor_config) + + model_path = folder_paths.get_full_path_or_raise("audio_encoders", model) + sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True) + + for name, param in whisper.named_parameters(): + key = "model." 
+ name + value=sd[key] + set_module_tensor_to_device(whisper, name, device=offload_device, dtype=base_dtype, value=value) + + whisper_model = { + "feature_extractor": feature_extractor, + "model": whisper, + "dtype": base_dtype, + } + + return (whisper_model,) + +class HuMoEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "num_frames": ("INT", {"default": 81, "min": -1, "max": 10000, "step": 1, "tooltip": "The total frame count to generate."}), + "width": ("INT", {"default": 832, "min": 64, "max": 4096, "step": 16}), + "height": ("INT", {"default": 480, "min": 64, "max": 4096, "step": 16}), + "audio_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the audio conditioning"}), + "audio_cfg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "When not 1.0, an extra model pass without audio conditioning is done: slower inference but more motion is allowed"}), + "audio_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The percent of the video to start applying audio conditioning"}), + "audio_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The percent of the video to stop applying audio conditioning"}) + }, + "optional" : { + "whisper_model": ("WHISPERMODEL",), + "vae": ("WANVAE", ), + "reference_images": ("IMAGE", {"tooltip": "reference images for the humo model"}), + "audio": ("AUDIO",), + "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("image_embeds", ) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, num_frames, width, height, audio_scale, audio_cfg_scale, audio_start_percent, audio_end_percent, whisper_model=None, vae=None, reference_images=None, audio=None, tiled_vae=False): + if reference_images is not None and vae is None: + raise ValueError("VAE is required when reference images are provided") + if whisper_model is None and audio is not None: + raise ValueError("Whisper model is required when audio is provided") + model = whisper_model["model"] + feature_extractor = whisper_model["feature_extractor"] + dtype = whisper_model["dtype"] + + sampling_rate = 16000 + + if audio is not None: + audio_input = audio["waveform"][0] + sample_rate = audio["sample_rate"] + + if sample_rate != sampling_rate: + audio_input = torchaudio.functional.resample(audio_input, sample_rate, sampling_rate) + if audio_input.shape[1] == 2: + audio_input = audio_input.mean(dim=0, keepdim=False) + else: + audio_input = audio_input[0] + + model.to(device) + audio_len = len(audio_input) // 640 + + # feature extraction + audio_features = [] + window = 750*640 + for i in range(0, len(audio_input), window): + audio_feature = feature_extractor(audio_input[i:i+window], sampling_rate=sampling_rate, return_tensors="pt").input_features + audio_features.append(audio_feature) + audio_features = torch.cat(audio_features, dim=-1).to(device, dtype) + + # preprocess + window = 3000 + audio_prompts = [] + for i in range(0, audio_features.shape[-1], window): + audio_prompt = model.encoder(audio_features[:,:,i:i+window], output_hidden_states=True).hidden_states + audio_prompt = torch.stack(audio_prompt, dim=2) + audio_prompts.append(audio_prompt) + + model.to(offload_device) + + audio_prompts = torch.cat(audio_prompts, dim=1) + audio_prompts = audio_prompts[:,:audio_len*2] + + feat0 = 
linear_interpolation_fps(audio_prompts[:, :, 0: 8].mean(dim=2), 50, 25) + feat1 = linear_interpolation_fps(audio_prompts[:, :, 8: 16].mean(dim=2), 50, 25) + feat2 = linear_interpolation_fps(audio_prompts[:, :, 16: 24].mean(dim=2), 50, 25) + feat3 = linear_interpolation_fps(audio_prompts[:, :, 24: 32].mean(dim=2), 50, 25) + feat4 = linear_interpolation_fps(audio_prompts[:, :, 32], 50, 25) + audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280] + else: + audio_emb = torch.zeros(num_frames, 5, 1280, device=device) + audio_len = num_frames + + pixel_frame_num = num_frames if num_frames != -1 else audio_len + pixel_frame_num = 4 * ((pixel_frame_num - 1) // 4) + 1 + latent_frame_num = (pixel_frame_num - 1) // 4 + 1 + + log.info(f"HuMo set to generate {pixel_frame_num} frames") + + #audio_emb, _ = get_audio_emb_window(audio_emb, pixel_frame_num, frame0_idx=0) + + num_refs = 0 + if reference_images is not None: + if reference_images.shape[1] != height or reference_images.shape[2] != width: + reference_images_in = common_upscale(reference_images.movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1) + else: + reference_images_in = reference_images + samples, = WanVideoEncodeLatentBatch.encode(self, vae, reference_images_in, tiled_vae, None, None, None, None) + samples = samples["samples"].transpose(0, 2).squeeze(0) + num_refs = samples.shape[1] + + vae.to(device) + zero_frames = torch.zeros(1, 3, pixel_frame_num + 4*num_refs, height, width, device=device, dtype=vae.dtype) + zero_latents = vae.encode(zero_frames, device=device, tiled=tiled_vae)[0].to(offload_device) + + vae.to(offload_device) + mm.soft_empty_cache() + + target_shape = (16, latent_frame_num + num_refs, height // 8, width // 8) + + mask = torch.ones(4, target_shape[1], target_shape[2], target_shape[3], device=offload_device, dtype=vae.dtype) + if reference_images is not None: + mask[:,:-num_refs] = 0 + image_cond = torch.cat([zero_latents[:, :(target_shape[1]-num_refs)], samples], dim=1) + #zero_audio_pad = torch.zeros(num_refs, *audio_emb.shape[1:]).to(audio_emb.device) + #audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0) + else: + image_cond = zero_latents + mask = torch.zeros_like(mask) + image_cond = torch.cat([mask, image_cond], dim=0) + image_cond_neg = torch.cat([mask, zero_latents], dim=0) + + embeds = { + "humo_audio_emb": audio_emb, + "humo_audio_emb_neg": torch.zeros_like(audio_emb, dtype=audio_emb.dtype, device=audio_emb.device), + "humo_image_cond": image_cond, + "humo_image_cond_neg": image_cond_neg, + "humo_reference_count": num_refs, + "target_shape": target_shape, + "num_frames": pixel_frame_num, + "humo_audio_scale": audio_scale, + "humo_audio_cfg_scale": audio_cfg_scale, + "humo_start_percent": audio_start_percent, + "humo_end_percent": audio_end_percent, + } + + return (embeds, ) + +class WanVideoCombineEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "embeds_1": ("WANVIDIMAGE_EMBEDS",), + "embeds_2": ("WANVIDIMAGE_EMBEDS",), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "add" + CATEGORY = "WanVideoWrapper" + EXPERIMENTAL = True + + def add(self, embeds_1, embeds_2): + # Combine the two sets of embeds + combined = {**embeds_1, **embeds_2} + return (combined,) + + +NODE_CLASS_MAPPINGS = { + "WhisperModelLoader": WhisperModelLoader, + "HuMoEmbeds": HuMoEmbeds, + "WanVideoCombineEmbeds": WanVideoCombineEmbeds, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "WhisperModelLoader": "Whisper Model 
Loader", + "HuMoEmbeds": "HuMo Embeds", + "WanVideoCombineEmbeds": "WanVideo Combine Embeds", +} diff --git a/HuMo/whisper_config.json b/HuMo/whisper_config.json new file mode 100644 index 0000000000000000000000000000000000000000..14c6c8cf48b64ebb1cb8b637e2b0fab3a9774972 --- /dev/null +++ b/HuMo/whisper_config.json @@ -0,0 +1,50 @@ +{ + "_name_or_path": "openai/whisper-large-v3", + "activation_dropout": 0.0, + "activation_function": "gelu", + "apply_spec_augment": false, + "architectures": [ + "WhisperForConditionalGeneration" + ], + "attention_dropout": 0.0, + "begin_suppress_tokens": [ + 220, + 50257 + ], + "bos_token_id": 50257, + "classifier_proj_size": 256, + "d_model": 1280, + "decoder_attention_heads": 20, + "decoder_ffn_dim": 5120, + "decoder_layerdrop": 0.0, + "decoder_layers": 32, + "decoder_start_token_id": 50258, + "dropout": 0.0, + "encoder_attention_heads": 20, + "encoder_ffn_dim": 5120, + "encoder_layerdrop": 0.0, + "encoder_layers": 32, + "eos_token_id": 50257, + "init_std": 0.02, + "is_encoder_decoder": true, + "mask_feature_length": 10, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_prob": 0.05, + "max_length": 448, + "max_source_positions": 1500, + "max_target_positions": 448, + "median_filter_width": 7, + "model_type": "whisper", + "num_hidden_layers": 32, + "num_mel_bins": 128, + "pad_token_id": 50256, + "scale_embedding": false, + "torch_dtype": "float16", + "transformers_version": "4.36.0.dev0", + "use_cache": true, + "use_weighted_layer_sum": false, + "vocab_size": 51866 +} diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
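Note on the ATI nodes earlier in this diff: they take tracks as a JSON string. Below is a minimal usage sketch of the expected format and of the preprocessing chain those nodes run (parse_json_tracks, pad_pts, then process_tracks), based only on the code in ATI/nodes.py and ATI/motion.py above; the import paths, the point coordinates, and the 832x480 frame size are illustrative assumptions rather than values taken from the repo.

import json
import numpy as np
from ATI.nodes import parse_json_tracks, pad_pts  # illustrative import path
from ATI.motion import process_tracks             # illustrative import path

# A track is a list of {x, y} points in pixel coordinates; multiple tracks form a list of lists.
tracks_json = json.dumps([
    [{"x": 100.0, "y": 200.0}, {"x": 110.0, "y": 205.0}],  # track 0, padded to 121 points by pad_pts
    [{"x": 400.0, "y": 240.0}, {"x": 395.0, "y": 250.0}],  # track 1
])

tracks_data = parse_json_tracks(tracks_json)                      # list of N tracks
tracks_np = np.stack([pad_pts(t) for t in tracks_data], axis=0)   # (N, 121, 1, 3): x, y, visibility
processed = process_tracks(tracks_np, (832, 480))                 # (81, N, 4): t, x, y, visibility,
                                                                  # coords centered and scaled by the short edge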
diff --git a/MTV/data/mean.npy b/MTV/data/mean.npy new file mode 100644 index 0000000000000000000000000000000000000000..997feeb77c920e6c82ffc8081d3453e413fdafd2 --- /dev/null +++ b/MTV/data/mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ababeaabf5ac096ce7c7714ada14aa1de8355c0016de25695be611d51285141 +size 416 diff --git a/MTV/data/std.npy b/MTV/data/std.npy new file mode 100644 index 0000000000000000000000000000000000000000..2110569fe9427ce05d2feb35781d9a6c1dedd7e2 --- /dev/null +++ b/MTV/data/std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:650e46902a0878e6947be401e4e1995e54a8fd407f2be3ded0dda62bda99a9b3 +size 416 diff --git a/MTV/draw_pose.py b/MTV/draw_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5c606ef9edc74176969b14d5709f0de536ae85 --- /dev/null +++ b/MTV/draw_pose.py @@ -0,0 +1,142 @@ +import cv2 +import math +import torch +import numpy as np +from PIL import Image +from torchvision import transforms + + +def intrinsic_matrix_from_field_of_view(imshape, fov_degrees:float =55 ): # nlf default fov_degrees 55 + imshape = np.array(imshape) + fov_radians = fov_degrees * np.array(np.pi / 180) + larger_side = np.max(imshape) + focal_length = larger_side / (np.tan(fov_radians / 2) * 2) + # intrinsic_matrix 3*3 + return np.array([ + [focal_length, 0, imshape[1] / 2], + [0, focal_length, imshape[0] / 2], + [0, 0, 1], + ]) + + +def p3d_to_p2d(point_3d, height, width): # point3d n*1024*3 + camera_matrix = intrinsic_matrix_from_field_of_view((height,width)) + camera_matrix = np.expand_dims(camera_matrix, axis=0) + camera_matrix = np.expand_dims(camera_matrix, axis=0) # 1*1*3*3 + point_3d = np.expand_dims(point_3d,axis=-1) # n*1024*3*1 + point_2d = (camera_matrix@point_3d).squeeze(-1) + point_2d[:,:,:2] = point_2d[:,:,:2]/point_2d[:,:,2:3] + return point_2d[:,:,:] # n*1024*2 + + +def get_pose_images(smpl_data, offset): + pose_images = [] + for data in smpl_data: + if isinstance(data, np.ndarray): + joints3d = data + else: + joints3d = data.numpy() + canvas = np.zeros(shape=(offset[0], offset[1], 3), dtype=np.uint8) + joints3d = p3d_to_p2d(joints3d, offset[0], offset[1]) + canvas = draw_3d_points(canvas, joints3d[0], stickwidth=int(offset[1]/350)) + pose_images.append(Image.fromarray(canvas)) + return pose_images + + +def get_control_conditions(poses, h, w): + video_transforms = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + control_images = [] + for idx, pose in enumerate(poses): + canvas = np.zeros(shape=(h, w, 3), dtype=np.uint8) + try: + joints3d = p3d_to_p2d(pose, h, w) + canvas = draw_3d_points( + canvas, + joints3d[0], + stickwidth=int(h / 350), + ) + resized_canvas = cv2.resize(canvas, (w, h)) + # Image.fromarray(resized_canvas).save(f'tmp/{idx}_pose.jpg') + control_images.append(resized_canvas) + except Exception as e: + print("wrong:", e) + control_images.append(Image.fromarray(canvas)) + control_pixel_values = np.array(control_images) + control_pixel_values = torch.from_numpy(control_pixel_values).contiguous() / 255. 
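+    # control_pixel_values: (num_poses, H, W, 3) float tensor in [0, 1], one rendered skeleton frame per SMPL pose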
+ print("control_pixel_values.shape", control_pixel_values.shape) + #control_pixel_values = video_transforms(control_pixel_values) + return control_pixel_values + + +def draw_3d_points(canvas, points, stickwidth=2, r=2, draw_line=True): + colors = [ + [255, 0, 0], # 0 + [0, 255, 0], # 1 + [0, 0, 255], # 2 + [255, 0, 255], # 3 + [255, 255, 0], # 4 + [85, 255, 0], # 5 + [0, 75, 255], # 6 + [0, 255, 85], # 7 + [0, 255, 170], # 8 + [170, 0, 255], # 9 + [85, 0, 255], # 10 + [0, 85, 255], # 11 + [0, 255, 255], # 12 + [85, 0, 255], # 13 + [170, 0, 255], # 14 + [255, 0, 255], # 15 + [255, 0, 170], # 16 + [255, 0, 85], # 17 + ] + connetions = [ + [15,12],[12, 16],[16, 18],[18, 20],[20, 22], + [12,17],[17,19],[19,21], + [21,23],[12,9],[9,6], + [6,3],[3,0],[0,1], + [1,4],[4,7],[7,10],[0,2],[2,5],[5,8],[8,11] + ] + connection_colors = [ + [255, 0, 0], # 0 + [0, 255, 0], # 1 + [0, 0, 255], # 2 + [255, 255, 0], # 3 + [255, 0, 255], # 4 + [0, 255, 0], # 5 + [0, 85, 255], # 6 + [255, 175, 0], # 7 + [0, 0, 255], # 8 + [255, 85, 0], # 9 + [0, 255, 85], # 10 + [255, 0, 255], # 11 + [255, 0, 0], # 12 + [0, 175, 255], # 13 + [255, 255, 0], # 14 + [0, 0, 255], # 15 + [0, 255, 0], # 16 + ] + + # draw point + for i in range(len(points)): + x,y = points[i][0:2] + x,y = int(x),int(y) + if i==13 or i == 14: + continue + cv2.circle(canvas, (x, y), r, colors[i%17], thickness=-1) + + # draw line + if draw_line: + for i in range(len(connetions)): + point1_idx,point2_idx = connetions[i][0:2] + point1 = points[point1_idx] + point2 = points[point2_idx] + Y = [point2[0],point1[0]] + X = [point2[1],point1[1]] + mX = int(np.mean(X)) + mY = int(np.mean(Y)) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((mY, mX), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, connection_colors[i%17]) + + return canvas diff --git a/MTV/motion4d/__init__.py b/MTV/motion4d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87ac3b3320cc21a738716220771f69fdb78fe34e --- /dev/null +++ b/MTV/motion4d/__init__.py @@ -0,0 +1 @@ +from .vqvae import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder \ No newline at end of file diff --git a/MTV/motion4d/vqvae.py b/MTV/motion4d/vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..30535e007f70214d8bf4287b162c8515ba45453f --- /dev/null +++ b/MTV/motion4d/vqvae.py @@ -0,0 +1,329 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + mid_channels=[128, 512], + out_channels=3072, + downsample_time=[1, 1], + downsample_joint=[1, 1], + num_attention_heads=8, + attention_head_dim=64, + dim=3072, + ): + super(Encoder, self).__init__() + + self.conv_in = nn.Conv2d(in_channels, mid_channels[0], kernel_size=3, stride=1, padding=1) + self.resnet1 = nn.ModuleList([ResBlock(mid_channels[0], mid_channels[0]) for _ in range(3)]) + self.downsample1 = Downsample(mid_channels[0], mid_channels[0], downsample_time[0], downsample_joint[0]) + self.resnet2 = ResBlock(mid_channels[0], mid_channels[1]) + self.resnet3 = nn.ModuleList([ResBlock(mid_channels[1], mid_channels[1]) for _ in range(3)]) + self.downsample2 = Downsample(mid_channels[1], mid_channels[1], downsample_time[1], downsample_joint[1]) + self.conv_out = nn.Conv2d(mid_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = 
self.conv_in(x) + for resnet in self.resnet1: + x = resnet(x) + x = self.downsample1(x) + + x = self.resnet2(x) + for resnet in self.resnet3: + x = resnet(x) + x = self.downsample2(x) + + x = self.conv_out(x) + + return x + + + +class VectorQuantizer(nn.Module): + def __init__(self, nb_code, code_dim): + super().__init__() + self.nb_code = nb_code + self.code_dim = code_dim + self.mu = 0.99 + self.reset_codebook() + self.reset_count = 0 + self.usage = torch.zeros((self.nb_code, 1)) + + def reset_codebook(self): + self.init = False + self.code_sum = None + self.code_count = None + self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda()) + + def _tile(self, x): + nb_code_x, code_dim = x.shape + if nb_code_x < self.nb_code: + n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x + std = 0.01 / np.sqrt(code_dim) + out = x.repeat(n_repeats, 1) + out = out + torch.randn_like(out) * std + else: + out = x + return out + + def preprocess(self, x): + # [bs, c, f, j] -> [bs * f * j, c] + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(-1, x.shape[-1]) + return x + + def quantize(self, x): + # [bs * f * j, dim=3072] + # Calculate latent code x_l + k_w = self.codebook.t() + distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, keepdim=True) + _, code_idx = torch.min(distance, dim=-1) + return code_idx + + def dequantize(self, code_idx): + x = F.embedding(code_idx, self.codebook) # indexing: [bs * f * j, 32] + return x + + def forward(self, x, return_vq=False): + bs, c, f, j = x.shape # SMPL data frames: [bs, 3072, f, j] + + # Preprocess + x = self.preprocess(x) + # return x.view(bs, f*j, c).contiguous(), None + assert x.shape[-1] == self.code_dim + + # quantize and dequantize through bottleneck + code_idx = self.quantize(x) + x_d = self.dequantize(code_idx) + + # Loss + commit_loss = F.mse_loss(x, x_d.detach()) + + # Passthrough + x_d = x + (x_d - x).detach() + + if return_vq: + return x_d.view(bs, f*j, c).contiguous(), commit_loss + # return (x_d, x_d.view(bs, f, j, c).permute(0, 3, 1, 2).contiguous()), commit_loss, perplexity + + # Postprocess + x_d = x_d.view(bs, f, j, c).permute(0, 3, 1, 2).contiguous() + + return x_d, commit_loss + + + + +class Decoder(nn.Module): + def __init__( + self, + in_channels=3072, + mid_channels=[512, 128], + out_channels=3, + upsample_rate=None, + frame_upsample_rate=[1.0, 1.0], + joint_upsample_rate=[1.0, 1.0], + dim=128, + attention_head_dim=64, + num_attention_heads=8, + ): + super(Decoder, self).__init__() + + self.conv_in = nn.Conv2d(in_channels, mid_channels[0], kernel_size=3, stride=1, padding=1) + self.resnet1 = nn.ModuleList([ResBlock(mid_channels[0], mid_channels[0]) for _ in range(3)]) + self.upsample1 = Upsample(mid_channels[0], mid_channels[0], frame_upsample_rate=frame_upsample_rate[0], joint_upsample_rate=joint_upsample_rate[0]) + self.resnet2 = ResBlock(mid_channels[0], mid_channels[1]) + self.resnet3 = nn.ModuleList([ResBlock(mid_channels[1], mid_channels[1]) for _ in range(3)]) + self.upsample2 = Upsample(mid_channels[1], mid_channels[1], frame_upsample_rate=frame_upsample_rate[1], joint_upsample_rate=joint_upsample_rate[1]) + self.conv_out = nn.Conv2d(mid_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = self.conv_in(x) + for resnet in self.resnet1: + x = resnet(x) + x = self.upsample1(x) + + x = self.resnet2(x) + for resnet in self.resnet3: + x = resnet(x) + x = self.upsample2(x) + + x = self.conv_out(x) + + return x + + 
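+# Upsample and Downsample below resample the (frames, joints) grid of the SMPL VQ-VAE; clips with an
+# odd frame count keep the first frame separate from the interpolation/pooling applied to the remaining frames.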
+class Upsample(nn.Module): + def __init__( + self, + in_channels, + out_channels, + upsample_rate=None, + frame_upsample_rate=None, + joint_upsample_rate=None, + ): + super(Upsample, self).__init__() + + self.upsampler = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.upsample_rate = upsample_rate + self.frame_upsample_rate = frame_upsample_rate + self.joint_upsample_rate = joint_upsample_rate + self.upsample_rate = upsample_rate + + def forward(self, inputs): + if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] + + if self.upsample_rate is not None: + # import pdb; pdb.set_trace() + x_first = F.interpolate(x_first, scale_factor=self.upsample_rate) + x_rest = F.interpolate(x_rest, scale_factor=self.upsample_rate) + else: + # import pdb; pdb.set_trace() + # x_first = F.interpolate(x_first, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="bilinear", align_corners=True) + x_rest = F.interpolate(x_rest, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="bilinear", align_corners=True) + x_first = x_first[:, :, None, :] + inputs = torch.cat([x_first, x_rest], dim=2) + elif inputs.shape[2] > 1: + if self.upsample_rate is not None: + inputs = F.interpolate(inputs, scale_factor=self.upsample_rate) + else: + inputs = F.interpolate(inputs, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="bilinear", align_corners=True) + else: + inputs = inputs.squeeze(2) + if self.upsample_rate is not None: + inputs = F.interpolate(inputs, scale_factor=self.upsample_rate) + else: + inputs = F.interpolate(inputs, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="linear", align_corners=True) + inputs = inputs[:, :, None, :, :] + + b, c, t, j = inputs.shape + inputs = inputs.permute(0, 2, 1, 3).reshape(b * t, c, j) + inputs = self.upsampler(inputs) + inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3) + + return inputs + + +class Downsample(nn.Module): + def __init__( + self, + in_channels, + out_channels, + frame_downsample_rate, + joint_downsample_rate + ): + super(Downsample, self).__init__() + + self.frame_downsample_rate = frame_downsample_rate + self.joint_downsample_rate = joint_downsample_rate + self.joint_downsample = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=self.joint_downsample_rate, padding=1) + + def forward(self, x): + # (batch_size, channels, frames, joints) -> (batch_size * joints, channels, frames) + if self.frame_downsample_rate > 1: + batch_size, channels, frames, joints = x.shape + x = x.permute(0, 3, 1, 2).reshape(batch_size * joints, channels, frames) + if x.shape[-1] % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2) + x_rest = F.avg_pool1d(x_rest, kernel_size=self.frame_downsample_rate, stride=self.frame_downsample_rate) + + x = torch.cat([x_first[..., None], x_rest], dim=-1) + # (batch_size * joints, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, joints) + x = x.reshape(batch_size, joints, channels, x.shape[-1]).permute(0, 2, 3, 1) + else: + # (batch_size * joints, channels, frames) -> (batch_size * joints, channels, frames // 2) + x = F.avg_pool1d(x, kernel_size=2, stride=2) + # (batch_size * joints, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> 
(batch_size, channels, frames // 2, height, width) + x = x.reshape(batch_size, joints, channels, x.shape[-1]).permute(0, 2, 3, 1) + + # Pad the tensor + # pad = (0, 1) + # x = F.pad(x, pad, mode="constant", value=0) + batch_size, channels, frames, joints = x.shape + # (batch_size, channels, frames, joints) -> (batch_size * frames, channels, joints) + x = x.permute(0, 2, 1, 3).reshape(batch_size * frames, channels, joints) + x = self.joint_downsample(x) + # (batch_size * frames, channels, joints) -> (batch_size, channels, frames, joints) + x = x.reshape(batch_size, frames, x.shape[1], x.shape[2]).permute(0, 2, 1, 3) + return x + + + +class ResBlock(nn.Module): + def __init__(self, + in_channels, + out_channels, + group_num=32, + max_channels=512): + super(ResBlock, self).__init__() + skip = max(1, max_channels // out_channels - 1) + self.block = nn.Sequential( + nn.GroupNorm(group_num, in_channels, eps=1e-06, affine=True), + nn.SiLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=skip, dilation=skip), + nn.GroupNorm(group_num, out_channels, eps=1e-06, affine=True), + nn.SiLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0), + ) + self.conv_short = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) if in_channels != out_channels else nn.Identity() + + def forward(self, x): + hidden_states = self.block(x) + if hidden_states.shape != x.shape: + x = self.conv_short(x) + x = x + hidden_states + return x + + + +class SMPL_VQVAE(nn.Module): + def __init__(self, encoder, decoder, vq): + super(SMPL_VQVAE, self).__init__() + + self.encoder = encoder + self.decoder = decoder + self.vq = vq + + def to(self, device): + self.encoder = self.encoder.to(device) + self.decoder = self.decoder.to(device) + self.vq = self.vq.to(device) + self.device = device + return self + + def encdec_slice_frames(self, x, frame_batch_size, encdec, return_vq): + num_frames = x.shape[2] + remaining_frames = num_frames % frame_batch_size + x_output = [] + + for i in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + x_intermediate = x[:, :, start_frame:end_frame] + x_intermediate = encdec(x_intermediate) + x_output.append(x_intermediate) + if encdec == self.encoder and self.vq is not None: + x_output, loss = self.vq(torch.cat(x_output, dim=2), return_vq=return_vq) + return x_output, loss + else: + return torch.cat(x_output, dim=2), None, None + + def forward(self, x, return_vq=False): + x = x.permute(0, 3, 1, 2) + x, loss = self.encdec_slice_frames(x, frame_batch_size=8, encdec=self.encoder, return_vq=return_vq) + + if return_vq: + return x, loss + x, _, _ = self.encdec_slice_frames(x, frame_batch_size=2, encdec=self.decoder, return_vq=return_vq) + x = x.permute(0, 2, 3, 1) + + return x, loss diff --git a/MTV/mtv.py b/MTV/mtv.py new file mode 100644 index 0000000000000000000000000000000000000000..915773d096b43f78aa5b250c0b64a198d0b4dc92 --- /dev/null +++ b/MTV/mtv.py @@ -0,0 +1,193 @@ +import torch +import numpy as np +from typing import Union, Tuple + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) +): + """ + Precompute the frequency tensor for complex exponentials (cis) 
with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def get_3d_rotary_pos_embed( + embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. 
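    Example (illustrative values, not from the original source; note that with
    `use_real=True` each returned `cos`/`sin` tensor has the full `embed_dim` as
    its last dimension):

        >>> cos, sin = get_3d_rotary_pos_embed(
        ...     embed_dim=128,
        ...     crops_coords=((0, 0), (30, 45)),
        ...     grid_size=(30, 45),
        ...     temporal_size=13,
        ... )
        >>> cos.shape  # temporal_size * grid_h * grid_w tokens
        torch.Size([17550, 128])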
+ """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +def get_3d_motion_spatial_embed( + embed_dim: int, num_joints: int, joints_mean: np.ndarray, joints_std: np.ndarray, theta: float = 10000.0 +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + assert embed_dim % 2 == 0 and embed_dim % 3 == 0 + + def create_rope_pe(dim, pos, freqs_dtype=torch.float32): + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + + pos_x = joints_mean[:, 0] + pos_y = joints_mean[:, 1] + pos_z = joints_mean[:, 2] + + normalized_pos_x = (pos_x - pos_x.mean()) + normalized_pos_y = (pos_y - pos_y.mean()) + normalized_pos_z = (pos_z - pos_z.mean()) + + freqs_cos_x, freqs_sin_x = create_rope_pe(embed_dim // 3, normalized_pos_x) + freqs_cos_y, freqs_sin_y = create_rope_pe(embed_dim // 3, normalized_pos_y) + freqs_cos_z, freqs_sin_z = create_rope_pe(embed_dim // 3, normalized_pos_z) + + freqs_cos = torch.cat([freqs_cos_x, freqs_cos_y, freqs_cos_z], dim=-1) + freqs_sin = torch.cat([freqs_sin_x, freqs_sin_y, freqs_sin_z], dim=-1) + + return freqs_cos, freqs_sin + +def prepare_motion_embeddings(num_frames, num_joints, joints_mean, joints_std, 
theta=10000, device='cuda'): + time_embed = get_1d_rotary_pos_embed(44, num_frames, theta, use_real=True) + time_embed_cos = time_embed[0][:, None, :].expand(-1, num_joints, -1).reshape(num_frames*num_joints, -1) + time_embed_sin = time_embed[1][:, None, :].expand(-1, num_joints, -1).reshape(num_frames*num_joints, -1) + spatial_motion_embed = get_3d_motion_spatial_embed(84, num_joints, joints_mean, joints_std, theta) + spatial_embed_cos = spatial_motion_embed[0][None, :, :].expand(num_frames, -1, -1).reshape(num_frames*num_joints, -1) + spatial_embed_sin = spatial_motion_embed[1][None, :, :].expand(num_frames, -1, -1).reshape(num_frames*num_joints, -1) + motion_embed_cos = torch.cat([time_embed_cos, spatial_embed_cos], dim=-1).to(device=device) + motion_embed_sin = torch.cat([time_embed_sin, spatial_embed_sin], dim=-1).to(device=device) + return motion_embed_cos, motion_embed_sin + +def apply_rotary_emb(x, freqs_cis): + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out \ No newline at end of file diff --git a/MTV/nlf.py b/MTV/nlf.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MTV/nodes.py b/MTV/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6ae4ca78c20da29748ca614ebc4f01ed1922e5 --- /dev/null +++ b/MTV/nodes.py @@ -0,0 +1,242 @@ +import os +import torch +import gc +from ..utils import log, dict_to_device +import numpy as np +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device + +import comfy.model_management as mm +from comfy.utils import load_torch_file +import folder_paths + +script_directory = os.path.dirname(os.path.abspath(__file__)) +device = mm.get_torch_device() +offload_device = mm.unet_offload_device() + +local_model_path = os.path.join(folder_paths.models_dir, "nlf", "nlf_l_multi_0.3.2.torchscript") + +from .motion4d import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder +from .mtv import prepare_motion_embeddings + +class DownloadAndLoadNLFModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "url": ( + [ + "https://github.com/isarandi/nlf/releases/download/v0.3.2/nlf_l_multi_0.3.2.torchscript" + ], + ) + }, + } + + RETURN_TYPES = ("NLFMODEL",) + RETURN_NAMES = ("nlf_model", ) + FUNCTION = "loadmodel" + CATEGORY = "WanVideoWrapper" + + def loadmodel(self, url): + + if not os.path.exists(local_model_path): + log.info(f"Downloading NLF model to: {local_model_path}") + import requests + os.makedirs(os.path.dirname(local_model_path), exist_ok=True) + response = requests.get(url) + if response.status_code == 200: + with open(local_model_path, "wb") as f: + f.write(response.content) + else: + print("Failed to download file:", response.status_code) + + model = torch.jit.load(local_model_path).eval() + + return (model,) + +class LoadNLFModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "path": ("STRING", {"default": local_model_path}), + }, + } + + RETURN_TYPES = ("NLFMODEL",) + RETURN_NAMES = ("nlf_model", ) + FUNCTION = "loadmodel" + CATEGORY = "WanVideoWrapper" + + def loadmodel(self, path): + model = torch.jit.load(path).eval() + + return model, + +class LoadVQVAE: + @classmethod + def 
INPUT_TYPES(s): + return { + "required": { + "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}), + }, + } + + RETURN_TYPES = ("VQVAE",) + RETURN_NAMES = ("vqvae", ) + FUNCTION = "loadmodel" + CATEGORY = "WanVideoWrapper" + + def loadmodel(self, model_name): + model_path = folder_paths.get_full_path("vae", model_name) + vae_sd = load_torch_file(model_path, safe_load=True) + + # Get motion tokenizer + motion_encoder = Encoder( + in_channels=3, + mid_channels=[128, 512], + out_channels=3072, + downsample_time=[2, 2], + downsample_joint=[1, 1] + ) + motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072) + motion_decoder = Decoder( + in_channels=3072, + mid_channels=[512, 128], + out_channels=3, + upsample_rate=2.0, + frame_upsample_rate=[2.0, 2.0], + joint_upsample_rate=[1.0, 1.0] + ) + + vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device) + vqvae.load_state_dict(vae_sd, strict=True) + + return vqvae, + +class MTVCrafterEncodePoses: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "vqvae": ("VQVAE", {"tooltip": "VQVAE model"}), + "poses": ("NLFPRED", {"tooltip": "Input poses for the model"}), + }, + } + + RETURN_TYPES = ("MTVCRAFTERMOTION", "NLFPRED") + RETURN_NAMES = ("mtvcrafter_motion", "pose_results") + FUNCTION = "encode" + CATEGORY = "WanVideoWrapper" + + def encode(self, vqvae, poses): + + # import pickle + # with open(os.path.join(script_directory, "data", "sampled_data.pkl"), 'rb') as f: + # data_list = pickle.load(f) + # if not isinstance(data_list, list): + # data_list = [data_list] + # print(data_list) + + # smpl_poses = data_list[1]['pose'] + + global_mean = np.load(os.path.join(script_directory, "data", "mean.npy")) #global_mean.shape: (24, 3) + global_std = np.load(os.path.join(script_directory, "data", "std.npy")) + + smpl_poses = [] + for pose in poses['joints3d_nonparam'][0]: + smpl_poses.append(pose[0].cpu().numpy()) + smpl_poses = np.array(smpl_poses) + + norm_poses = torch.tensor((smpl_poses - global_mean) / global_std).unsqueeze(0) + print(f"norm_poses shape: {norm_poses.shape}, dtype: {norm_poses.dtype}") + + vqvae.to(device) + motion_tokens, vq_loss = vqvae(norm_poses.to(device), return_vq=True) + + recon_motion = vqvae(norm_poses.to(device))[0][0].to(dtype=torch.float32).cpu().detach() * global_std + global_mean + vqvae.to(offload_device) + + poses_dict = { + 'mtv_motion_tokens': motion_tokens, + 'global_mean': global_mean, + 'global_std': global_std + } + + return poses_dict, recon_motion + + +class NLFPredict: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": ("NLFMODEL",), + "images": ("IMAGE", {"tooltip": "Input images for the model"}), + }, + } + + RETURN_TYPES = ("NLFPRED", ) + RETURN_NAMES = ("pose_results",) + FUNCTION = "predict" + CATEGORY = "WanVideoWrapper" + + def predict(self, model, images): + + model.to(device) + pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device)) + model.to(offload_device) + + pred = dict_to_device(pred, offload_device) + + pose_results = { + 'joints3d_nonparam': [], + } + # Collect pose data + for key in pose_results.keys(): + if key in pred: + pose_results[key].append(pred[key]) + else: + pose_results[key].append(None) + + return (pose_results,) + +class DrawNLFPoses: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "poses": ("NLFPRED", {"tooltip": "Input poses for the model"}), + "width": ("INT", {"default": 512}), + "height": ("INT", {"default": 512}), + }, + } + 
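    # Editorial sketch (not part of the upstream file): roughly how the MTV nodes
    # defined in this module chain together, written here as direct Python calls
    # rather than a ComfyUI graph. The VQVAE filename is a placeholder.
    #
    #   nlf_model = DownloadAndLoadNLFModel().loadmodel(url)[0]       # torchscript NLF model
    #   poses = NLFPredict().predict(nlf_model, images)[0]            # {'joints3d_nonparam': [...]}
    #   vqvae = LoadVQVAE().loadmodel("mtv_vqvae.safetensors")[0]     # hypothetical file in models/vae
    #   motion, recon = MTVCrafterEncodePoses().encode(vqvae, poses)  # motion tokens + reconstructed joints
    #   control_images = DrawNLFPoses().predict(poses, 512, 512)[0]   # rendered pose control frames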
+ RETURN_TYPES = ("IMAGE", ) + RETURN_NAMES = ("image",) + FUNCTION = "predict" + CATEGORY = "WanVideoWrapper" + + def predict(self, poses, width, height): + from .draw_pose import get_control_conditions + print(type(poses)) + if isinstance(poses, dict): + pose_input = poses['joints3d_nonparam'][0] if 'joints3d_nonparam' in poses else poses + else: + pose_input = poses + control_conditions = get_control_conditions(pose_input, height, width) + + return (control_conditions,) + +NODE_CLASS_MAPPINGS = { + "DownloadAndLoadNLFModel": DownloadAndLoadNLFModel, + "NLFPredict": NLFPredict, + "DrawNLFPoses": DrawNLFPoses, + "LoadVQVAE": LoadVQVAE, + "MTVCrafterEncodePoses": MTVCrafterEncodePoses + } +NODE_DISPLAY_NAME_MAPPINGS = { + "DownloadAndLoadNLFModel": "(Download)Load NLF Model", + "NLFPredict": "NLF Predict", + "DrawNLFPoses": "Draw NLF Poses", + "LoadVQVAE": "Load VQVAE", + "MTVCrafterEncodePoses": "MTV Crafter Encode Poses" +} diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0186083efb82f7233f4c3d6cbdf5a4126f3b189 --- /dev/null +++ b/__init__.py @@ -0,0 +1,113 @@ +try: + from .utils import check_duplicate_nodes, log + duplicate_dirs = check_duplicate_nodes() + if duplicate_dirs: + warning_msg = f"WARNING: Found {len(duplicate_dirs)} other WanVideoWrapper directories:\n" + for dir_path in duplicate_dirs: + warning_msg += f" - {dir_path}\n" + log.warning(warning_msg + "Please remove duplicates to avoid possible conflicts.") +except: + pass + +from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS +from .recammaster.nodes import NODE_CLASS_MAPPINGS as RECAM_MASTER_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as RECAM_MASTER_NODE_DISPLAY_NAME_MAPPINGS +from .skyreels.nodes import NODE_CLASS_MAPPINGS as SKYREELS_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as SKYREELS_NODE_DISPLAY_NAME_MAPPINGS +from .fantasytalking.nodes import NODE_CLASS_MAPPINGS as FANTASYTALKING_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as FANTASYTALKING_NODE_DISPLAY_NAME_MAPPINGS +from .nodes_sampler import NODE_CLASS_MAPPINGS as SAMPLER_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as SAMPLER_NODE_DISPLAY_NAME_MAPPINGS +from .fun_camera.nodes import NODE_CLASS_MAPPINGS as FUN_CAMERA_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as FUN_CAMERA_NODE_DISPLAY_NAME_MAPPINGS +from .uni3c.nodes import NODE_CLASS_MAPPINGS as UNI3C_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as UNI3C_NODE_DISPLAY_NAME_MAPPINGS +from .controlnet.nodes import NODE_CLASS_MAPPINGS as CONTROLNET_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as CONTROLNET_NODE_DISPLAY_NAME_MAPPINGS +from .ATI.nodes import NODE_CLASS_MAPPINGS as ATI_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as ATI_NODE_DISPLAY_NAME_MAPPINGS +from .multitalk.nodes import NODE_CLASS_MAPPINGS as MULTITALK_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as MULTITALK_NODE_DISPLAY_NAME_MAPPINGS +from .nodes_model_loading import NODE_CLASS_MAPPINGS as MODEL_LOADING_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as MODEL_LOADING_NODE_DISPLAY_NAME_MAPPINGS +from .nodes_utility import NODE_CLASS_MAPPINGS as UTILITY_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as UTILITY_NODE_DISPLAY_NAME_MAPPINGS +from .cache_methods.nodes_cache import NODE_CLASS_MAPPINGS as NODE_CACHE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as NODE_CACHE_DISPLAY_NAME_MAPPINGS +from .nodes_deprecated import NODE_CLASS_MAPPINGS as DEPRECATED_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as DEPRECATED_NODE_DISPLAY_NAME_MAPPINGS 
+from .s2v.nodes import NODE_CLASS_MAPPINGS as S2V_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as S2V_NODE_DISPLAY_NAME_MAPPINGS + +try: + from .qwen.qwen import NODE_CLASS_MAPPINGS as QWEN_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as QWEN_NODE_DISPLAY_NAME_MAPPINGS +except Exception as e: + log.warning(f"WanVideoWrapper WARNING: Qwen nodes not available due to error in importing them: {e}") + QWEN_NODE_CLASS_MAPPINGS = {} + QWEN_NODE_DISPLAY_NAME_MAPPINGS = {} + + +try: + from .fantasyportrait.nodes import NODE_CLASS_MAPPINGS as FANTASYPORTRAIT_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as FANTASYPORTRAIT_NODE_DISPLAY_NAME_MAPPINGS +except Exception as e: + log.warning(f"WanVideoWrapper WARNING: FantasyPortrait nodes not available due to error in importing them: {e}") + FANTASYPORTRAIT_NODE_CLASS_MAPPINGS = {} + FANTASYPORTRAIT_NODE_DISPLAY_NAME_MAPPINGS = {} + +try: + from .unianimate.nodes import NODE_CLASS_MAPPINGS as UNIANIMATE_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as UNIANIMATE_NODE_DISPLAY_NAME_MAPPINGS +except Exception as e: + log.warning(f"WanVideoWrapper WARNING: UniAnimate nodes not available due to error in importing them: {e}") + UNIANIMATE_NODE_CLASS_MAPPINGS = {} + UNIANIMATE_NODE_DISPLAY_NAME_MAPPINGS = {} + +try: + from .MTV.nodes import NODE_CLASS_MAPPINGS as MTV_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as MTV_NODE_DISPLAY_NAME_MAPPINGS +except Exception as e: + log.warning(f"WanVideoWrapper WARNING: MTV nodes not available due to error in importing them: {e}") + MTV_NODE_CLASS_MAPPINGS = {} + MTV_NODE_DISPLAY_NAME_MAPPINGS = {} + +try: + from .HuMo.nodes import NODE_CLASS_MAPPINGS as HUMO_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as HUMO_NODE_DISPLAY_NAME_MAPPINGS +except Exception as e: + log.warning(f"WanVideoWrapper WARNING: HuMo nodes not available due to error in importing them: {e}") + HUMO_NODE_CLASS_MAPPINGS = {} + HUMO_NODE_DISPLAY_NAME_MAPPINGS = {} + +try: + from .lynx.nodes import NODE_CLASS_MAPPINGS as LYNX_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as LYNX_NODE_DISPLAY_NAME_MAPPINGS +except Exception as e: + log.warning(f"WanVideoWrapper WARNING: Lynx nodes not available due to error in importing them: {e}") + LYNX_NODE_CLASS_MAPPINGS = {} + LYNX_NODE_DISPLAY_NAME_MAPPINGS = {} + +NODE_CLASS_MAPPINGS.update(RECAM_MASTER_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UNIANIMATE_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(SKYREELS_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FANTASYTALKING_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FANTASYPORTRAIT_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FUN_CAMERA_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UNI3C_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(CONTROLNET_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(ATI_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(MULTITALK_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(MODEL_LOADING_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(NODE_CACHE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(DEPRECATED_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(QWEN_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(MTV_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(S2V_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(HUMO_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(SAMPLER_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(LYNX_NODE_CLASS_MAPPINGS) + +NODE_DISPLAY_NAME_MAPPINGS.update(RECAM_MASTER_NODE_DISPLAY_NAME_MAPPINGS) 
+NODE_DISPLAY_NAME_MAPPINGS.update(UNIANIMATE_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(SKYREELS_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FANTASYTALKING_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FANTASYPORTRAIT_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FUN_CAMERA_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(UNI3C_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(CONTROLNET_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(ATI_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(MULTITALK_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(MODEL_LOADING_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(NODE_CACHE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(DEPRECATED_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(QWEN_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(MTV_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(S2V_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(HUMO_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(SAMPLER_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(LYNX_NODE_DISPLAY_NAME_MAPPINGS) + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] \ No newline at end of file diff --git a/cache_methods/cache_methods.py b/cache_methods/cache_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..1b7bb12b2cbc3c27df886cf2cf6932266c818255 --- /dev/null +++ b/cache_methods/cache_methods.py @@ -0,0 +1,158 @@ +from ..utils import log +import torch + +def set_transformer_cache_method(transformer, timesteps, cache_args=None): + transformer.cache_device = cache_args["cache_device"] + if cache_args["cache_type"] == "TeaCache": + log.info(f"TeaCache: Using cache device: {transformer.cache_device}") + transformer.teacache_state.clear_all() + transformer.enable_teacache = True + transformer.rel_l1_thresh = cache_args["rel_l1_thresh"] + transformer.teacache_start_step = cache_args["start_step"] + transformer.teacache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"] + transformer.teacache_use_coefficients = cache_args["use_coefficients"] + transformer.teacache_mode = cache_args["mode"] + elif cache_args["cache_type"] == "MagCache": + log.info(f"MagCache: Using cache device: {transformer.cache_device}") + transformer.magcache_state.clear_all() + transformer.enable_magcache = True + transformer.magcache_start_step = cache_args["start_step"] + transformer.magcache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"] + transformer.magcache_thresh = cache_args["magcache_thresh"] + transformer.magcache_K = cache_args["magcache_K"] + elif cache_args["cache_type"] == "EasyCache": + log.info(f"EasyCache: Using cache device: {transformer.cache_device}") + transformer.easycache_state.clear_all() + transformer.enable_easycache = True + transformer.easycache_start_step = cache_args["start_step"] + transformer.easycache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"] + transformer.easycache_thresh = cache_args["easycache_thresh"] + return transformer + +class TeaCacheState: + def __init__(self, cache_device='cpu'): + self.cache_device = cache_device + self.states = {} + self._next_pred_id = 0 + + def new_prediction(self, 
cache_device='cpu'): + """Create new prediction state and return its ID""" + self.cache_device = cache_device + pred_id = self._next_pred_id + self._next_pred_id += 1 + self.states[pred_id] = { + 'previous_residual': None, + 'accumulated_rel_l1_distance': 0, + 'previous_modulated_input': None, + 'skipped_steps': [], + } + return pred_id + + def update(self, pred_id, **kwargs): + """Update state for specific prediction""" + if pred_id not in self.states: + return None + for key, value in kwargs.items(): + self.states[pred_id][key] = value + + def get(self, pred_id): + return self.states.get(pred_id, {}) + + def clear_all(self): + self.states = {} + self._next_pred_id = 0 + +class MagCacheState: + def __init__(self, cache_device='cpu'): + self.cache_device = cache_device + self.states = {} + self._next_pred_id = 0 + + def new_prediction(self, cache_device='cpu'): + """Create new prediction state and return its ID""" + self.cache_device = cache_device + pred_id = self._next_pred_id + self._next_pred_id += 1 + self.states[pred_id] = { + 'residual_cache': None, + 'accumulated_ratio': 1.0, + 'accumulated_steps': 0, + 'accumulated_err': 0, + 'skipped_steps': [], + } + return pred_id + + def update(self, pred_id, **kwargs): + """Update state for specific prediction""" + if pred_id not in self.states: + return None + for key, value in kwargs.items(): + self.states[pred_id][key] = value + + def get(self, pred_id): + return self.states.get(pred_id, {}) + + def clear_all(self): + self.states = {} + self._next_pred_id = 0 + +class EasyCacheState: + def __init__(self, cache_device='cpu'): + self.cache_device = cache_device + self.states = {} + self._next_pred_id = 0 + + def new_prediction(self, cache_device='cpu'): + """Create a new prediction state and return its ID.""" + self.cache_device = cache_device + pred_id = self._next_pred_id + self._next_pred_id += 1 + self.states[pred_id] = { + 'previous_raw_input': None, + 'previous_raw_output': None, + 'cache': None, + 'accumulated_error': 0.0, + 'skipped_steps': [], + } + return pred_id + + def update(self, pred_id, **kwargs): + """Update state for a specific prediction.""" + if pred_id not in self.states: + return None + for key, value in kwargs.items(): + self.states[pred_id][key] = value + + def get(self, pred_id): + return self.states.get(pred_id, {}) + + def clear_all(self): + self.states = {} + self._next_pred_id = 0 + +def relative_l1_distance(last_tensor, current_tensor): + l1_distance = torch.abs(last_tensor.to(current_tensor.device) - current_tensor).mean() + norm = torch.abs(last_tensor).mean() + relative_l1_distance = l1_distance / norm + return relative_l1_distance.to(torch.float32).to(current_tensor.device) + +def cache_report(transformer, cache_args): + cache_type = cache_args["cache_type"] + states = ( + transformer.teacache_state.states if cache_type == "TeaCache" else + transformer.magcache_state.states if cache_type == "MagCache" else + transformer.easycache_state.states if cache_type == "EasyCache" else + None + ) + state_names = { + 0: "conditional", + 1: "unconditional" + } + for pred_id, state in states.items(): + name = state_names.get(pred_id, f"prediction_{pred_id}") + if 'skipped_steps' in state: + log.info(f"{cache_type} skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}") + transformer.teacache_state.clear_all() + transformer.magcache_state.clear_all() + transformer.easycache_state.clear_all() + del states \ No newline at end of file diff --git a/cache_methods/nodes_cache.py 
b/cache_methods/nodes_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..fb194befd5f2f191ec69949757a083206a8f2e19 --- /dev/null +++ b/cache_methods/nodes_cache.py @@ -0,0 +1,140 @@ +from comfy import model_management as mm + +class WanVideoTeaCache: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.001, + "tooltip": "Higher values will make TeaCache more aggressive, faster, but may cause artifacts. Good value range for 1.3B: 0.05 - 0.08, for other models 0.15-0.30"}), + "start_step": ("INT", {"default": 1, "min": 0, "max": 9999, "step": 1, "tooltip": "Start percentage of the steps to apply TeaCache"}), + "end_step": ("INT", {"default": -1, "min": -1, "max": 9999, "step": 1, "tooltip": "End steps to apply TeaCache"}), + "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}), + "use_coefficients": ("BOOLEAN", {"default": True, "tooltip": "Use calculated coefficients for more accuracy. When enabled therel_l1_thresh should be about 10 times higher than without"}), + }, + "optional": { + "mode": (["e", "e0"], {"default": "e", "tooltip": "Choice between using e (time embeds, default) or e0 (modulated time embeds)"}), + }, + } + RETURN_TYPES = ("CACHEARGS",) + RETURN_NAMES = ("cache_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = """ +Patch WanVideo model to use TeaCache. Speeds up inference by caching the output and +applying it instead of doing the step. Best results are achieved by choosing the +appropriate coefficients for the model. Early steps should never be skipped, with too +aggressive values this can happen and the motion suffers. Starting later can help with that too. +When NOT using coefficients, the threshold value should be +about 10 times smaller than the value used with coefficients. + +Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4Wan2.1: + + +
++-------------------+--------+---------+--------+ +| Model | Low | Medium | High | ++-------------------+--------+---------+--------+ +| Wan2.1 t2v 1.3B | 0.05 | 0.07 | 0.08 | +| Wan2.1 t2v 14B | 0.14 | 0.15 | 0.20 | +| Wan2.1 i2v 480P | 0.13 | 0.19 | 0.26 | +| Wan2.1 i2v 720P | 0.18 | 0.20 | 0.30 | ++-------------------+--------+---------+--------+ ++""" + + def process(self, rel_l1_thresh, start_step, end_step, cache_device, use_coefficients, mode="e"): + if cache_device == "main_device": + cache_device = mm.get_torch_device() + else: + cache_device = mm.unet_offload_device() + cache_args = { + "cache_type": "TeaCache", + "rel_l1_thresh": rel_l1_thresh, + "start_step": start_step, + "end_step": end_step, + "cache_device": cache_device, + "use_coefficients": use_coefficients, + "mode": mode, + } + return (cache_args,) + +class WanVideoMagCache: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "magcache_thresh": ("FLOAT", {"default": 0.02, "min": 0.0, "max": 0.3, "step": 0.001, "tooltip": "How strongly to cache the output of diffusion model. This value must be non-negative."}), + "magcache_K": ("INT", {"default": 4, "min": 0, "max": 6, "step": 1, "tooltip": "The maxium skip steps of MagCache."}), + "start_step": ("INT", {"default": 1, "min": 0, "max": 9999, "step": 1, "tooltip": "Step to start applying MagCache"}), + "end_step": ("INT", {"default": -1, "min": -1, "max": 9999, "step": 1, "tooltip": "Step to end applying MagCache"}), + "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}), + }, + } + RETURN_TYPES = ("CACHEARGS",) + RETURN_NAMES = ("cache_args",) + FUNCTION = "setargs" + CATEGORY = "WanVideoWrapper" + EXPERIMENTAL = True + DESCRIPTION = "MagCache for WanVideoWrapper, source https://github.com/Zehong-Ma/MagCache" + + def setargs(self, magcache_thresh, magcache_K, start_step, end_step, cache_device): + if cache_device == "main_device": + cache_device = mm.get_torch_device() + else: + cache_device = mm.unet_offload_device() + + cache_args = { + "cache_type": "MagCache", + "magcache_thresh": magcache_thresh, + "magcache_K": magcache_K, + "start_step": start_step, + "end_step": end_step, + "cache_device": cache_device, + } + return (cache_args,) + +class WanVideoEasyCache: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "easycache_thresh": ("FLOAT", {"default": 0.015, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "How strongly to cache the output of diffusion model. 
This value must be non-negative."}), + "start_step": ("INT", {"default": 10, "min": 0, "max": 9999, "step": 1, "tooltip": "Step to start applying EasyCache"}), + "end_step": ("INT", {"default": -1, "min": -1, "max": 9999, "step": 1, "tooltip": "Step to end applying EasyCache"}), + "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}), + }, + } + RETURN_TYPES = ("CACHEARGS",) + RETURN_NAMES = ("cache_args",) + FUNCTION = "setargs" + CATEGORY = "WanVideoWrapper" + EXPERIMENTAL = True + DESCRIPTION = "EasyCache for WanVideoWrapper, source https://github.com/H-EmbodVis/EasyCache" + + def setargs(self, easycache_thresh, start_step, end_step, cache_device): + if cache_device == "main_device": + cache_device = mm.get_torch_device() + else: + cache_device = mm.unet_offload_device() + + cache_args = { + "cache_type": "EasyCache", + "easycache_thresh": easycache_thresh, + "start_step": start_step, + "end_step": end_step, + "cache_device": cache_device, + } + return (cache_args,) + + +NODE_CLASS_MAPPINGS = { + "WanVideoTeaCache": WanVideoTeaCache, + "WanVideoMagCache": WanVideoMagCache, + "WanVideoEasyCache": WanVideoEasyCache, + } +NODE_DISPLAY_NAME_MAPPINGS = { + "WanVideoTeaCache": "WanVideo TeaCache", + "WanVideoMagCache": "WanVideo MagCache", + "WanVideoEasyCache": "WanVideo EasyCache" + } \ No newline at end of file diff --git a/configs/T5_tokenizer/special_tokens_map.json b/configs/T5_tokenizer/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..14855e7052ffbb595057dfd791d293c1c940db2c --- /dev/null +++ b/configs/T5_tokenizer/special_tokens_map.json @@ -0,0 +1,308 @@ +{ + "additional_special_tokens": [ + "