import math

import numpy as np
import torch
import torch.nn.functional as F

import folder_paths


def get_sample_indices(original_fps,
                       total_frames,
                       target_fps,
                       num_sample,
                       fixed_start=None):
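    """Return `num_sample` frame indices sampled at `target_fps` from a clip recorded at `original_fps`.

    If `fixed_start` is a non-negative frame index, sampling starts there; otherwise a
    random valid start frame is drawn. Indices are clipped to `[0, total_frames - 1]`.
    """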
    required_duration = num_sample / target_fps
    required_origin_frames = int(np.ceil(required_duration * original_fps))
    if required_duration > total_frames / original_fps:
        raise ValueError("required_duration must be less than video length")

    if fixed_start is not None and fixed_start >= 0:
        start_frame = fixed_start
    else:
        max_start = total_frames - required_origin_frames
        if max_start < 0:
            raise ValueError("video length is too short")
        start_frame = np.random.randint(0, max_start + 1)
    start_time = start_frame / original_fps

    end_time = start_time + required_duration
    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)

    frame_indices = np.round(time_points * original_fps).astype(int)
    frame_indices = np.clip(frame_indices, 0, total_frames - 1)
    return frame_indices


def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """
    features: audio features, shape=[B, T, C]
    input_fps: fps for audio, f_a
    output_fps: fps for video, f_m
    output_len: output video length in frames
    """
    features = features.transpose(1, 2)
    seq_len = features.shape[2] / float(input_fps)
    if output_len is None:
        output_len = int(seq_len * output_fps)
    output_features = F.interpolate(
        features, size=output_len, align_corners=True,
        mode='linear')
    return output_features.transpose(1, 2)


class WanVideoAddS2VEmbeds:
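    """Attach S2V audio/pose conditioning to WanVideo image embeds.

    Optionally resamples encoded audio features to the video rate, buckets them into
    fixed-size frame windows, and stores them together with the reference/pose latents
    and related settings under `embeds["audio_embeds"]`.
    """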
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "embeds": ("WANVIDIMAGE_EMBEDS",),
                    "frame_window_size": ("INT", {"default": 80, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of frames in a single window"}),
                    "audio_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for audio embeddings"}),
                    "pose_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage for pose embeddings"}),
                    "pose_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage for pose embeddings"}),
                },
                "optional": {
                    "audio_encoder_output": ("AUDIO_ENCODER_OUTPUT",),
                    "ref_latent": ("LATENT",),
                    "pose_latent": ("LATENT",),
                    "vae": ("WANVAE",),
                    "enable_framepack": ("BOOLEAN", {"default": False, "tooltip": "Enable Framepack sampling loop, not compatible with context windows"}),
                }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "INT",)
    RETURN_NAMES = ("image_embeds", "audio_frame_count")
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, frame_window_size, audio_encoder_output=None, audio_scale=1.0, ref_latent=None, pose_latent=None, vae=None, pose_start_percent=0.0, pose_end_percent=1.0, enable_framepack=False):
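        """Build the S2V `audio_embeds` entry and attach it to a copy of `embeds`."""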
        audio_frame_count = 0
        if audio_encoder_output is not None:
            all_layers = audio_encoder_output["encoded_audio_all_layers"]
            audio_feat = torch.stack(all_layers, dim=0).squeeze(1)

            print("audio_feat in", audio_feat.shape)
            input_fps = 50    # rate of the encoded audio features
            output_fps = 30   # intermediate rate the audio features are resampled to
            bucket_fps = 16   # video frame rate used when bucketing

            if input_fps != output_fps:
                audio_feat = linear_interpolation(audio_feat, input_fps=input_fps, output_fps=output_fps)
                print("audio_feat after interpolation", audio_feat.shape)

            # Trim the audio features to the portion covered by the requested video frames
            audio_feat = audio_feat[:, :embeds["num_frames"] * output_fps // bucket_fps, :]
            print("audio_feat after trim", audio_feat.shape)

            self.video_rate = output_fps

            audio_embed_bucket, num_repeat = self.get_audio_embed_bucket_fps(
                audio_feat,
                fps=bucket_fps,
                batch_frames=frame_window_size
            )
            print("audio_embed_bucket", audio_embed_bucket.shape)

            audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
            if len(audio_embed_bucket.shape) == 3:
                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
            elif len(audio_embed_bucket.shape) == 4:
                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)

            audio_frame_count = audio_embed_bucket.shape[-1]

            print("audio_embed_bucket", audio_embed_bucket.shape)

        new_entry = {
            "audio_embed_bucket": audio_embed_bucket if audio_encoder_output is not None else None,
            "num_repeat": num_repeat if audio_encoder_output is not None else None,
            "ref_latent": ref_latent["samples"] if ref_latent is not None else None,
            "pose_latent": pose_latent["samples"] if pose_latent is not None else None,
            "audio_scale": audio_scale,
            "vae": vae,
            "pose_start_percent": pose_start_percent,
            "pose_end_percent": pose_end_percent,
            "enable_framepack": enable_framepack,
            "frame_window_size": frame_window_size
        }
        updated = dict(embeds)
        updated["audio_embeds"] = new_entry
        return (updated, audio_frame_count)

    def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
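        """Resample `audio_embed` ([layers, frames, dim]) to `fps` and pad it so it fills
        an integer number of `batch_frames`-sized windows.

        `m` controls how many neighbouring audio frames are concatenated per video frame.
        Returns the bucketed embeddings and the number of windows.
        """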
        num_layers, audio_frame_num, audio_dim = audio_embed.shape

        # Keep per-layer embeddings when more than one encoder layer is provided
        return_all_layers = num_layers > 1

        scale = self.video_rate / fps

        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1

        bucket_num = min_batch_num * batch_frames
        pad_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
        batch_idx = get_sample_indices(
            original_fps=self.video_rate,
            total_frames=audio_frame_num + pad_audio_num,
            target_fps=fps,
            num_sample=bucket_num,
            fixed_start=0)
        batch_audio_eb = []
        audio_sample_stride = int(self.video_rate / fps)
        for bi in batch_idx:
            if bi < audio_frame_num:
                # Gather 2*m + 1 neighbouring audio frames around the sampled index,
                # clamping out-of-range indices to the valid range
                chosen_idx = list(
                    range(bi - m * audio_sample_stride,
                          bi + (m + 1) * audio_sample_stride,
                          audio_sample_stride))
                chosen_idx = [0 if c < 0 else c for c in chosen_idx]
                chosen_idx = [
                    audio_frame_num - 1 if c >= audio_frame_num else c
                    for c in chosen_idx
                ]

                if return_all_layers:
                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(
                        start_dim=-2, end_dim=-1)
                else:
                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            else:
                # Sampled index lies in the padded region: use a zero embedding
                frame_audio_embed = \
                    torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
                    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
            batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
                                   dim=0)

        return batch_audio_eb, min_batch_num


NODE_CLASS_MAPPINGS = {
    "WanVideoAddS2VEmbeds": WanVideoAddS2VEmbeds,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddS2VEmbeds": "WanVideo Add S2V Embeds",
}