# Mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper
import math
import torch
import torch.nn.functional as F
import numpy as np
def get_sample_indices(original_fps,
                       total_frames,
                       target_fps,
                       num_sample,
                       fixed_start=None):
    """Pick num_sample frame indices from a clip recorded at original_fps,
    spaced so the sampled frames approximate target_fps. Sampling starts at
    fixed_start when it is given and non-negative, otherwise at a random
    valid start frame."""
    required_duration = num_sample / target_fps
    required_origin_frames = int(np.ceil(required_duration * original_fps))
    if required_duration > total_frames / original_fps:
        raise ValueError("required_duration must be less than video length")

    if fixed_start is not None and fixed_start >= 0:
        start_frame = fixed_start
    else:
        max_start = total_frames - required_origin_frames
        if max_start < 0:
            raise ValueError("video length is too short")
        start_frame = np.random.randint(0, max_start + 1)

    start_time = start_frame / original_fps
    end_time = start_time + required_duration
    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
    frame_indices = np.round(time_points * original_fps).astype(int)
    frame_indices = np.clip(frame_indices, 0, total_frames - 1)
    return frame_indices
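
# Worked example (illustrative numbers, not from the source): sampling 81
# frames at 16 fps from features stored at 30 fps needs 81 / 16 = 5.0625 s of
# input, i.e. ceil(5.0625 * 30) = 152 source frames; with fixed_start=0 the
# returned indices advance by roughly 30 / 16 = 1.875 source frames each.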
def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """Resample features along the time axis with linear interpolation.

    features: shape [1, T, 512]
    input_fps: frame rate of the input (audio) features, f_a
    output_fps: target (video) frame rate, f_m
    output_len: desired output length; computed from the fps ratio if None
    """
    features = features.transpose(1, 2)  # [1, 512, T]
    seq_len = features.shape[2] / float(input_fps)  # duration in seconds, T / f_a
    if output_len is None:
        output_len = int(seq_len * output_fps)  # f_m * T / f_a
    output_features = F.interpolate(
        features, size=output_len, align_corners=True,
        mode='linear')  # [1, 512, output_len]
    return output_features.transpose(1, 2)  # [1, output_len, 512]
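
# Example (illustrative): 10 s of wav2vec-style features at 50 fps gives
# T = 500, so resampling to 30 fps yields int(500 / 50 * 30) = 300 frames:
#   out = linear_interpolation(torch.randn(1, 500, 512), 50, 30)
#   # out.shape == (1, 300, 512)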
class WanVideoAddS2VEmbeds:
    """Attach S2V (speech-to-video) conditioning, i.e. resampled audio
    embeddings plus optional reference and pose latents, to WanVideo image
    embeds."""
@classmethod
def INPUT_TYPES(s):
return {"required": {
"embeds": ("WANVIDIMAGE_EMBEDS",),
"frame_window_size": ("INT", {"default": 80, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of frames in a single window"}),
"audio_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for audio embeddings"}),
"pose_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage for pose embeddings"}),
"pose_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage for pose embeddings"})
},
"optional": {
"audio_encoder_output": ("AUDIO_ENCODER_OUTPUT",),
"ref_latent": ("LATENT",),
"pose_latent": ("LATENT",),
"vae": ("WANVAE",),
"enable_framepack": ("BOOLEAN", {"default": False, "tooltip": "Enable Framepack sampling loop, not compatible with context windows"})
}
}
RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "INT",)
RETURN_NAMES = ("image_embeds", "audio_frame_count")
FUNCTION = "add"
CATEGORY = "WanVideoWrapper"
    def add(self, embeds, frame_window_size, audio_encoder_output=None, audio_scale=1.0, ref_latent=None, pose_latent=None, vae=None, pose_start_percent=0.0, pose_end_percent=1.0, enable_framepack=False):
        """Resample encoded audio to the video rate, bucket it into
        frame_window_size windows, and store it with the optional reference
        and pose latents under embeds["audio_embeds"]."""
        audio_frame_count = 0
        audio_embed_bucket = None
        num_repeat = None
if audio_encoder_output is not None:
all_layers = audio_encoder_output["encoded_audio_all_layers"]
audio_feat = torch.stack(all_layers, dim=0).squeeze(1) # shape: [num_layers, T, 512]
print("audio_feat in", audio_feat.shape)
input_fps = 50 # determined by the model itself
output_fps = 30 # determined by the model itself
bucket_fps = 16 # target fps for the generation
if input_fps != output_fps:
audio_feat = linear_interpolation(audio_feat, input_fps=input_fps, output_fps=output_fps)
print("audio_feat after interpolation", audio_feat.shape)
audio_feat = audio_feat[:, :embeds["num_frames"] * output_fps // bucket_fps, :]
print("audio_feat after trim", audio_feat.shape)
self.video_rate = output_fps
audio_embed_bucket, num_repeat = self.get_audio_embed_bucket_fps(
audio_feat,
fps=bucket_fps,
batch_frames=frame_window_size
)
print("audio_embed_bucket", audio_embed_bucket.shape)
audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
if len(audio_embed_bucket.shape) == 3:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
elif len(audio_embed_bucket.shape) == 4:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
audio_frame_count = audio_embed_bucket.shape[-1]
print("audio_embed_bucket", audio_embed_bucket.shape)
new_entry = {
"audio_embed_bucket": audio_embed_bucket if audio_encoder_output is not None else None,
"num_repeat": num_repeat if audio_encoder_output is not None else None,
"ref_latent": ref_latent["samples"] if ref_latent is not None else None,
"pose_latent": pose_latent["samples"] if pose_latent is not None else None,
"audio_scale": audio_scale,
"vae": vae,
"pose_start_percent": pose_start_percent,
"pose_end_percent": pose_end_percent,
"enable_framepack": enable_framepack,
"frame_window_size": frame_window_size
}
updated = dict(embeds)
updated["audio_embeds"] = new_entry
return (updated, audio_frame_count)
    def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
        """Group audio embeddings into a bucket covering a whole number of
        batch_frames windows, sampled at the video rate. For every video
        frame, 2 * m + 1 neighbouring audio frames are concatenated along
        the feature dimension."""
        num_layers, audio_frame_num, audio_dim = audio_embed.shape
        return_all_layers = num_layers > 1

        scale = self.video_rate / fps
        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
        bucket_num = min_batch_num * batch_frames
        # Pad the audio so it covers the full set of windows
        pad_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num

        batch_idx = get_sample_indices(
            original_fps=self.video_rate,
            total_frames=audio_frame_num + pad_audio_num,
            target_fps=fps,
            num_sample=bucket_num,
            fixed_start=0)
        batch_audio_eb = []
        audio_sample_stride = int(self.video_rate / fps)
        for bi in batch_idx:
            if bi < audio_frame_num:
                # 2 * m + 1 neighbouring audio frames, clamped to the valid range
                chosen_idx = list(
                    range(bi - m * audio_sample_stride,
                          bi + (m + 1) * audio_sample_stride,
                          audio_sample_stride))
                chosen_idx = [min(max(c, 0), audio_frame_num - 1) for c in chosen_idx]

                if return_all_layers:
                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(
                        start_dim=-2, end_dim=-1)
                else:
                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            else:
                # Padded region beyond the real audio: fill with zeros of the
                # same feature width
                zero_shape = [audio_dim * (2 * m + 1)]
                if return_all_layers:
                    zero_shape = [num_layers] + zero_shape
                frame_audio_embed = torch.zeros(zero_shape, device=audio_embed.device)
            batch_audio_eb.append(frame_audio_embed)

        batch_audio_eb = torch.stack(batch_audio_eb, dim=0)
        return batch_audio_eb, min_batch_num
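
    # Worked example (illustrative numbers): with self.video_rate = 30,
    # fps = 16, batch_frames = 80 and audio_frame_num = 150:
    #   scale = 30 / 16 = 1.875
    #   min_batch_num = int(150 / 150) + 1 = 2 windows
    #   bucket_num = 2 * 80 = 160 sampled positions
    #   pad_audio_num = ceil(160 / 16 * 30) - 150 = 150 zero-padded frames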
NODE_CLASS_MAPPINGS = {
"WanVideoAddS2VEmbeds": WanVideoAddS2VEmbeds,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoAddS2VEmbeds": "WanVideo Add S2V Embeds",
}
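
# Minimal smoke test, not part of the upstream node: run this file directly to
# sanity-check the two resampling helpers with synthetic tensors.
if __name__ == "__main__":
    idx = get_sample_indices(original_fps=30, total_frames=300,
                             target_fps=16, num_sample=81, fixed_start=0)
    print("sample indices:", idx.shape, idx[:5])

    feats = torch.randn(1, 500, 512)  # 10 s of features at 50 fps
    out = linear_interpolation(feats, input_fps=50, output_fps=30)
    print("interpolated:", out.shape)  # expected: (1, 300, 512)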