| import folder_paths |
| import math |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
|
|
def get_sample_indices(original_fps,
                       total_frames,
                       target_fps,
                       num_sample,
                       fixed_start=None):
    """Choose ``num_sample`` source-frame indices spaced at ``target_fps``.

    A window lasting ``num_sample / target_fps`` seconds is laid over a
    sequence of ``total_frames`` frames recorded at ``original_fps``. The
    window is sampled at ``num_sample`` evenly spaced instants (end point
    excluded) and each instant is mapped to the nearest source frame.

    Args:
        original_fps: Frame rate of the source sequence.
        total_frames: Number of frames available in the source sequence.
        target_fps: Frame rate the returned samples should correspond to.
        num_sample: Number of indices to return.
        fixed_start: First source frame of the window. When ``None`` or
            negative, a start frame is drawn uniformly at random from the
            positions where the whole window still fits.

    Returns:
        Integer ``numpy`` array of length ``num_sample`` with values clipped
        to ``[0, total_frames - 1]``.

    Raises:
        ValueError: If the window is longer than the source sequence.
    """
    window_seconds = num_sample / target_fps
    window_src_frames = int(np.ceil(window_seconds * original_fps))
    if window_seconds > total_frames / original_fps:
        raise ValueError("required_duration must be less than video length")

    if fixed_start is not None and fixed_start >= 0:
        first_frame = fixed_start
    else:
        latest_start = total_frames - window_src_frames
        if latest_start < 0:
            raise ValueError("video length is too short")
        # randint's upper bound is exclusive, hence the + 1.
        first_frame = np.random.randint(0, latest_start + 1)

    t_begin = first_frame / original_fps
    t_end = t_begin + window_seconds
    sample_times = np.linspace(t_begin, t_end, num_sample, endpoint=False)

    # Map each sample time to the nearest source frame; a window that runs
    # past the end of the sequence repeats the last frame.
    nearest = np.round(sample_times * original_fps).astype(int)
    return np.clip(nearest, 0, total_frames - 1)
|
|
def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """Resample a feature sequence along time with linear interpolation.

    Args:
        features: 3-D tensor laid out as ``[N, T, C]`` (the original note
            gives ``[1, T, 512]`` as an example); dimension 1 is resampled.
        input_fps: Rate of the incoming features (audio rate, f_a).
        output_fps: Rate to resample to (video rate, f_m).
        output_len: Explicit number of output steps. When ``None`` it is
            derived as ``int(T / input_fps * output_fps)``.

    Returns:
        Tensor laid out as ``[N, output_len, C]``.
    """
    # F.interpolate works on the last dimension, so move time there.
    time_last = features.transpose(1, 2)
    duration = time_last.shape[2] / float(input_fps)
    target_len = int(duration * output_fps) if output_len is None else output_len
    resampled = F.interpolate(
        time_last, size=target_len, mode='linear', align_corners=True)
    return resampled.transpose(1, 2)
|
|
class WanVideoAddS2VEmbeds:
    """Node that attaches S2V audio, reference and pose inputs to image embeds.

    ``add`` returns a shallow copy of the incoming embeds dict with an extra
    ``"audio_embeds"`` entry, together with the number of bucketed audio
    frames (``0`` when no audio encoder output is supplied).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the required and optional input declarations of the node."""
        return {
            "required": {
                "embeds": ("WANVIDIMAGE_EMBEDS",),
                "frame_window_size": ("INT", {"default": 80, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of frames in a single window"}),
                "audio_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for audio embeddings"}),
                "pose_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage for pose embeddings"}),
                "pose_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage for pose embeddings"}),
            },
            "optional": {
                "audio_encoder_output": ("AUDIO_ENCODER_OUTPUT",),
                "ref_latent": ("LATENT",),
                "pose_latent": ("LATENT",),
                "vae": ("WANVAE",),
                "enable_framepack": ("BOOLEAN", {"default": False, "tooltip": "Enable Framepack sampling loop, not compatible with context windows"}),
            },
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "INT",)
    RETURN_NAMES = ("image_embeds", "audio_frame_count")
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, frame_window_size, audio_encoder_output=None, audio_scale=1.0, ref_latent=None, pose_latent=None, vae=None, pose_start_percent=0.0, pose_end_percent=1.0, enable_framepack=False):
        """Build the ``"audio_embeds"`` entry and return it inside a copy of ``embeds``.

        Args:
            embeds: Mapping with the existing image embeds; ``"num_frames"``
                is read from it when audio is supplied. It is not mutated.
            frame_window_size: Number of frames per window; also used as the
                bucket batch size.
            audio_encoder_output: Optional mapping whose
                ``"encoded_audio_all_layers"`` value is a sequence of
                per-layer tensors (assumed ``[1, T, dim]`` each — TODO confirm).
            audio_scale: Stored unchanged in the entry.
            ref_latent: Optional mapping; its ``"samples"`` value is stored.
            pose_latent: Optional mapping; its ``"samples"`` value is stored.
            vae: Stored unchanged in the entry.
            pose_start_percent: Stored unchanged in the entry.
            pose_end_percent: Stored unchanged in the entry.
            enable_framepack: Stored unchanged in the entry.

        Returns:
            Tuple ``(updated_embeds, audio_frame_count)``.
        """
        audio_frame_count = 0
        # Stay None when no audio is supplied so the entry below can be
        # built without repeating the "is audio present" check.
        audio_embed_bucket = None
        num_repeat = None

        if audio_encoder_output is not None:
            all_layers = audio_encoder_output["encoded_audio_all_layers"]
            # Stack layers first, then drop the singleton dimension at index 1.
            audio_feat = torch.stack(all_layers, dim=0).squeeze(1)

            print("audio_feat in", audio_feat.shape)
            # Hard-coded rates: features are treated as 50 steps per second,
            # resampled to 30, and finally bucketed at 16 frames per second.
            input_fps = 50
            output_fps = 30
            bucket_fps = 16

            if input_fps != output_fps:
                audio_feat = linear_interpolation(audio_feat, input_fps=input_fps, output_fps=output_fps)
            print("audio_feat after interpolation", audio_feat.shape)

            # Keep only the audio that covers ``num_frames`` frames at the
            # bucket rate.
            max_audio_steps = embeds["num_frames"] * output_fps // bucket_fps
            audio_feat = audio_feat[:, :max_audio_steps, :]
            print("audio_feat after trim", audio_feat.shape)

            # get_audio_embed_bucket_fps reads this attribute.
            self.video_rate = output_fps

            audio_embed_bucket, num_repeat = self.get_audio_embed_bucket_fps(
                audio_feat,
                fps=bucket_fps,
                batch_frames=frame_window_size
            )
            print("audio_embed_bucket", audio_embed_bucket.shape)

            # Add a leading batch dimension and move the frame axis last.
            audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
            if audio_embed_bucket.ndim == 3:
                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
            elif audio_embed_bucket.ndim == 4:
                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)

            audio_frame_count = audio_embed_bucket.shape[-1]

            print("audio_embed_bucket", audio_embed_bucket.shape)

        new_entry = {
            "audio_embed_bucket": audio_embed_bucket,
            "num_repeat": num_repeat,
            "ref_latent": ref_latent["samples"] if ref_latent is not None else None,
            "pose_latent": pose_latent["samples"] if pose_latent is not None else None,
            "audio_scale": audio_scale,
            "vae": vae,
            "pose_start_percent": pose_start_percent,
            "pose_end_percent": pose_end_percent,
            "enable_framepack": enable_framepack,
            "frame_window_size": frame_window_size
        }
        # Shallow copy so the caller's dict is left untouched.
        updated = dict(embeds)
        updated["audio_embeds"] = new_entry
        return (updated, audio_frame_count)

    def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
        """Resample audio features to one embedding per frame, zero-padded to whole batches.

        Requires ``self.video_rate`` (the rate of ``audio_embed``) to be set;
        ``add`` sets it before calling this method.

        Args:
            audio_embed: Tensor of shape ``[num_layers, audio_frame_num, audio_dim]``.
            fps: Frame rate of the output bucket.
            batch_frames: Number of frames per batch; the output length is a
                multiple of this value.
            m: Number of neighbouring audio steps taken on each side of the
                selected step; each frame embedding concatenates ``2 * m + 1``
                steps.

        Returns:
            Tuple ``(bucket, min_batch_num)``. ``bucket`` has shape
            ``[min_batch_num * batch_frames, num_layers, audio_dim * (2 * m + 1)]``
            when ``num_layers > 1``, otherwise
            ``[min_batch_num * batch_frames, audio_dim * (2 * m + 1)]``.
            Frames past the end of the audio are zeros.
        """
        num_layers, audio_frame_num, audio_dim = audio_embed.shape
        return_all_layers = num_layers > 1

        scale = self.video_rate / fps

        # Always at least one batch; enough batches to cover all audio steps.
        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1

        bucket_num = min_batch_num * batch_frames
        # Length the audio would have if it filled every batch completely.
        padded_total = math.ceil(bucket_num / fps * self.video_rate)
        batch_idx = get_sample_indices(
            original_fps=self.video_rate,
            total_frames=padded_total,
            target_fps=fps,
            num_sample=bucket_num,
            fixed_start=0)

        audio_sample_stride = int(self.video_rate / fps)
        window_dim = audio_dim * (2 * m + 1)

        # Padding frame built once with the dtype and device of the input,
        # so half-precision features are not mixed with float32 zeros.
        pad_shape = [num_layers, window_dim] if return_all_layers else [window_dim]
        zero_frame = audio_embed.new_zeros(pad_shape)

        last_valid = audio_frame_num - 1
        batch_audio_eb = []
        for bi in batch_idx:
            bi = int(bi)
            if bi >= audio_frame_num:
                # Index lies in the padded region beyond the real audio.
                batch_audio_eb.append(zero_frame)
                continue

            # Neighbouring steps around ``bi``, clamped to the valid range.
            chosen_idx = [
                min(max(c, 0), last_valid)
                for c in range(bi - m * audio_sample_stride,
                               bi + (m + 1) * audio_sample_stride,
                               audio_sample_stride)
            ]

            if return_all_layers:
                frame_audio_embed = audio_embed[:, chosen_idx].flatten(
                    start_dim=-2, end_dim=-1)
            else:
                frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            batch_audio_eb.append(frame_audio_embed)

        batch_audio_eb = torch.stack(batch_audio_eb, dim=0)

        return batch_audio_eb, min_batch_num
|
|
|
|
| |
# Node registration tables: internal node identifier -> implementing class,
# and internal node identifier -> display name (presumably read by the host
# application's node loader — confirm against the loader).
NODE_CLASS_MAPPINGS = {"WanVideoAddS2VEmbeds": WanVideoAddS2VEmbeds}

NODE_DISPLAY_NAME_MAPPINGS = {"WanVideoAddS2VEmbeds": "WanVideo Add S2V Embeds"}