# Uploaded by aliensmn.
# Mirror from https://github.com/kijai/ComfyUI-WanVideoWrapper
# Revision cf812a0 (verified).
from einops import rearrange, repeat
import torch
import torch.nn as nn
from ..wanvideo.modules.attention import attention
def timestep_transform(t, shift=5.0, num_timesteps=1000):
    """Warp timestep(s) ``t`` with a ratio-based shift.

    ``t`` is first expressed as a fraction ``r`` of ``num_timesteps``, mapped
    through ``shift * r / (1 + (shift - 1) * r)`` and scaled back to the
    original range. Only arithmetic operators are used, so any numeric type
    that supports them (Python numbers, tensors) is accepted.

    Args:
        t: Timestep value(s) on the ``[0, num_timesteps]`` scale.
        shift: Shift factor; ``1.0`` leaves ``t`` unchanged.
        num_timesteps: Length of the timestep schedule.

    Returns:
        The shifted timestep(s), on the same scale as ``t``.
    """
    ratio = t / num_timesteps
    warped = shift * ratio / (1 + (shift - 1) * ratio)
    return warped * num_timesteps
def add_noise(
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
    num_timesteps: int = 1000,
) -> torch.FloatTensor:
    """Blend clean samples with noise according to the timestep.

    Compatible with diffusers ``add_noise()``. The result is the linear
    interpolation ``(1 - s) * original_samples + s * noise`` where
    ``s = timesteps / num_timesteps``.

    Args:
        original_samples: Clean samples.
        noise: Noise tensor; its rank decides how many singleton dimensions
            are appended to ``timesteps`` for broadcasting.
        timesteps: Timestep per sample on the ``[0, num_timesteps]`` scale.
        num_timesteps: Length of the timestep schedule. Defaults to 1000,
            which is the value that used to be hard-coded here.

    Returns:
        The noised samples.
    """
    sigma = timesteps.float() / num_timesteps
    # Append singleton axes so the per-sample factor broadcasts over the
    # remaining dimensions of ``noise``.
    sigma = sigma.view(sigma.shape + (1,) * (noise.ndim - 1))
    return (1 - sigma) * original_samples + sigma * noise
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
    """Linearly map ``column`` from ``source_range`` onto ``target_range``.

    Args:
        column: Value(s) to rescale.
        source_range: ``(min, max)`` of the incoming values.
        target_range: ``(min, max)`` of the desired output interval.
        epsilon: Added to the source span so a degenerate range
            (``min == max``) does not divide by zero.

    Returns:
        The rescaled value(s).
    """
    src_lo, src_hi = source_range
    dst_lo, dst_hi = target_range
    unit = (column - src_lo) / (src_hi - src_lo + epsilon)
    return unit * (dst_hi - dst_lo) + dst_lo
def rotate_half(x):
    """Rotate adjacent element pairs of the last dimension.

    The last dimension is read as consecutive pairs ``(a, b)`` and each pair
    is replaced by ``(-b, a)``; the output has the same shape as ``x``.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, mode='mean', attn_bias=None):
    """Compute, per target mask, how much each query token attends to it.

    Softmax attention between ``visual_q`` and ``ref_k`` is computed, then
    for every mask the attention mass falling on the masked reference tokens
    is averaged over the mask and reduced over heads.

    Args:
        visual_q: Query tensor; dims 1 and 2 are swapped before the matmul,
            so the expected layout is (B, x_seqlens, H, head_dim).
        ref_k: Key tensor in the same layout, (B, ref_seqlens, H, head_dim).
            It is cast to the dtype and device of ``visual_q``.
        ref_target_masks: One mask per class over the reference tokens;
            iterated along dim 0. A mask that sums to zero divides by zero.
        mode: ``'mean'`` or ``'max'`` reduction over heads. Any other value
            skips the reduction and keeps the head axis.
        attn_bias: Optional bias added to the attention logits.

    Returns:
        The per-class maps concatenated along dim 0; each map is
        (B, x_seqlens) for the ``'mean'`` and ``'max'`` modes.
    """
    ref_k = ref_k.to(dtype=visual_q.dtype, device=visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    query = (visual_q * scale).transpose(1, 2)
    key = ref_k.transpose(1, 2)

    attn = query @ key.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn_probs = attn.softmax(-1).to(query.dtype)  # B, H, x_seqlens, ref_seqlens

    masks = ref_target_masks.to(query.dtype)
    x_ref_attn_maps = []
    for mask in masks:
        mask = mask[None, None, None, ...]
        # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        masked = (attn_probs * mask).sum(-1) / mask.sum()
        masked = masked.permute(0, 2, 1)  # B, x_seqlens, H

        if mode == 'mean':
            masked = masked.mean(-1)  # B, x_seqlens
        elif mode == 'max':
            # Tensor.max(dim) returns (values, indices); only the values are
            # a tensor that can be concatenated below.
            masked = masked.max(-1).values  # B, x_seqlens

        x_ref_attn_maps.append(masked)

    return torch.concat(x_ref_attn_maps, dim=0)
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
    """Compute per-class reference attention maps, processing heads in chunks.

    The heads are split into ``split_num`` groups that are passed to
    ``calculate_x_ref_attn_map`` one at a time (its default ``'mean'`` head
    reduction is used) and the group results are averaged.

    Args:
        visual_q (torch.tensor): B M H K
        ref_k (torch.tensor): B M H K; only the first ``N_h * N_w`` tokens
            are used as reference keys.
        shape (tuple): (N_t, N_h, N_w)
        ref_target_masks: [class_num, N_h * N_w]. Required despite the
            ``None`` default.
        split_num: Number of head groups. When the head count is not a
            multiple of ``split_num`` the groups differ in size by at most
            one head and are weighted by their size, so every head still
            contributes equally.

    Returns:
        Tensor of shape (class_num, M) on the device/dtype of ``visual_q``.
        NOTE(review): the accumulator has no batch axis, so this presumably
        expects B == 1 -- confirm against the caller.

    Raises:
        ValueError: If ``ref_target_masks`` is ``None``.
    """
    if ref_target_masks is None:
        raise ValueError("ref_target_masks must be provided")

    _, N_h, N_w = shape
    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]
    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(
        class_num, seq_lens, device=visual_q.device, dtype=visual_q.dtype
    )

    split_chunk, leftover = divmod(heads, split_num)
    # Even splits keep the plain "mean of group means"; uneven splits need
    # size weights to stay equal to the mean over all heads.
    uneven = leftover != 0
    start = 0
    for i in range(split_num):
        stop = start + split_chunk + (1 if i < leftover else 0)
        if stop == start:
            # More groups than heads: nothing left to process.
            continue
        chunk_map = calculate_x_ref_attn_map(
            visual_q[:, :, start:stop, :],
            ref_k[:, :, start:stop, :],
            ref_target_masks,
        )
        if uneven:
            chunk_map = chunk_map * (stop - start)
        x_ref_attn_maps += chunk_map
        start = stop

    return x_ref_attn_maps / (heads if uneven else split_num)
class RotaryPositionalEmbedding1D(nn.Module):
    """1D rotary positional embedding with a fixed base of 10000."""

    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    def precompute_freqs_cis_1d(self, pos_indices):
        """Return the rotation angle of every (position, channel) pair.

        Each of the ``head_dim // 2`` frequencies is repeated twice along the
        last axis so the result lines up with the adjacent-pair layout used
        by ``rotate_half``.
        """
        half = self.head_dim // 2
        exponents = torch.arange(0, self.head_dim, 2)[:half].float() / self.head_dim
        inv_freq = (1.0 / (self.base ** exponents)).to(pos_indices.device)
        # Outer product: one angle per position and frequency.
        angles = torch.einsum("..., f -> ... f", pos_indices.float(), inv_freq)
        return repeat(angles, "... n -> ... (n r)", r=2)

    def forward(self, x, pos_indices):
        """1D RoPE.

        Args:
            x (torch.tensor): [B, head, seq, head_dim]
            pos_indices (torch.tensor): [seq,]
        Returns:
            x rotated by its positions, with the same shape and dtype as the
            input (the rotation itself is computed in float32).
        """
        angles = self.precompute_freqs_cis_1d(pos_indices).float().to(x.device)
        cos = rearrange(angles.cos(), 'n d -> 1 1 n d')
        sin = rearrange(angles.sin(), 'n d -> 1 1 n d')
        x_fp32 = x.float()
        rotated = (x_fp32 * cos) + (rotate_half(x_fp32) * sin)
        return rotated.type_as(x)
class AudioProjModel(nn.Module):
    """Project windowed audio features into per-frame context tokens.

    Two inputs with different window lengths (``seq_len`` and ``seq_len_vf``)
    get their own first projection, are concatenated along the frame axis and
    then share the remaining two projections.
    """

    def __init__(
        self,
        seq_len=5,
        seq_len_vf=12,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
        norm_output_audio=False,
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        # Each frame is flattened to window * blocks * channels features.
        self.input_dim = seq_len * blocks * channels
        self.input_dim_vf = seq_len_vf * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        # Projection stack (creation order is kept: it fixes the state-dict
        # layout and the order in which initial weights are drawn).
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
        if norm_output_audio:
            self.norm = nn.LayerNorm(output_dim)
        else:
            self.norm = nn.Identity()

    @staticmethod
    def _flatten_frames(embeds):
        """Merge batch with frames and flatten each frame's window features."""
        per_frame = rearrange(embeds, "bz f w b c -> (bz f) w b c")
        n_frames, window, blocks, channels = per_frame.shape
        return per_frame.view(n_frames, window * blocks * channels)

    def forward(self, audio_embeds, audio_embeds_vf):
        """Return context tokens of shape (bz, frames, context_tokens, output_dim).

        Args:
            audio_embeds: (bz, f, w, b, c) features; per the original comment
                this is the audio of the first frame.
            audio_embeds_vf: (bz, f_vf, w_vf, b, c) features; per the original
                comment this is the audio of the latter frames.
        """
        video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
        B, _, _, _, _ = audio_embeds.shape

        # First projection: one linear layer per input, then back to
        # (bz, frames, intermediate_dim).
        first = torch.relu(self.proj1(self._flatten_frames(audio_embeds)))
        latter = torch.relu(self.proj1_vf(self._flatten_frames(audio_embeds_vf)))
        first = rearrange(first, "(bz f) c -> bz f c", bz=B)
        latter = rearrange(latter, "(bz f) c -> bz f c", bz=B)

        merged = torch.concat([first, latter], dim=1)
        batch_size_c, N_t, C_a = merged.shape
        merged = merged.view(batch_size_c * N_t, C_a)

        # Second and third projection, shared by all frames.
        hidden = torch.relu(self.proj2(merged))
        tokens = self.proj3(hidden).reshape(
            batch_size_c * N_t, self.context_tokens, self.output_dim
        )

        # Normalization and reshape back to one token set per frame.
        tokens = self.norm(tokens)
        return rearrange(tokens, "(bz f) m c -> bz f m c", f=video_length)
#@torch.compiler.disable()
class SingleStreamAttention(nn.Module):
    """Frame-wise cross-attention from video tokens to encoder tokens.

    Queries come from the video tokens of one frame, keys/values from the
    encoder tokens of the matching frame; frames are folded into the batch
    axis so each frame attends only to its own encoder tokens.
    """

    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        attention_mode: str = 'sdpa',
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.encoder_hidden_states_dim = encoder_hidden_states_dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Forwarded unchanged to the project's ``attention`` helper.
        self.attention_mode = attention_mode

        self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
        """Attend every frame of ``x`` to its encoder tokens.

        Args:
            x: Video tokens, (B, N_t * N_h * N_w, dim).
            encoder_hidden_states: Encoder tokens with frames folded into the
                batch axis, (B * frames, tokens, encoder_hidden_states_dim) --
                this is the layout the key/value view below requires.
            shape: ``(N_t, N_h, N_w)`` token grid of ``x``. Required despite
                the ``None`` default.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``shape`` is ``None``.

        When the encoder provides fewer frames than ``B * N_t``, the last
        ``N_h * N_w`` tokens of ``x`` are treated as an extra frame without
        encoder tokens: they are left out of the attention and zeros are
        returned in their place. This is the same detection rule that
        ``SingleStreamMultiAttention.forward`` uses. The previous check
        (token count of ``x`` differing from the grid size) could never
        complete, because the reshape that follows needs exactly
        ``N_t * N_h * N_w`` tokens in ``x``.
        """
        if shape is None:
            raise ValueError("shape=(N_t, N_h, N_w) must be provided")
        N_t, N_h, N_w = shape

        B = x.shape[0]
        S = N_h * N_w
        x_extra = None

        if B * N_t != encoder_hidden_states.shape[0]:
            x_extra = x[:, -S:, :]
            x = x[:, :-S, :]
            N_t = N_t - 1

        # reshape (not view): after the slice above ``x`` is no longer
        # contiguous when B > 1.
        x = x.reshape(B * N_t, S, self.dim)

        # get q for hidden_state
        q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)

        # get kv from encoder_hidden_states, each (B * N_t, tokens, num_heads, head_dim)
        kv = self.kv_linear(encoder_hidden_states)
        encoder_k, encoder_v = kv.view(
            B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim
        ).unbind(2)

        x = attention(q, encoder_k, encoder_v, attention_mode=self.attention_mode)

        # linear transform
        x = self.proj(x.reshape(B * N_t, S, self.dim))
        x = x.view(B, N_t * S, self.dim)

        if x_extra is not None:
            # The extra frame receives no attention output.
            x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)

        return x
class SingleStreamMultiAttention(SingleStreamAttention):
    """Multi-speaker rotary-position cross-attention.

    This implementation generalises the original 2-speaker logic to an
    arbitrary number of voices. Each speaker is allocated a *class_interval*
    wide segment inside a shared *class_range* rotary bucket. The centre of
    each segment is applied to that speaker's KV tokens while queries are
    modulated per-token according to which row of the reference attention
    map dominates the token.
    """

    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        class_range: int = 24,
        class_interval: int = 4,
        attention_mode: str = 'sdpa',
    ) -> None:
        # The base class stores ``attention_mode`` and builds the q/kv/proj
        # layers.
        super().__init__(
            dim=dim,
            encoder_hidden_states_dim=encoder_hidden_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attention_mode=attention_mode,
        )

        # Rotary-embedding layout parameters
        self.class_interval = class_interval
        self.class_range = class_range
        # Number of segments that fit in the range; stored for reference,
        # not enforced in forward().
        self.max_humans = self.class_range // self.class_interval

        # Constant bucket used for background tokens
        self.rope_bak = int(self.class_range // 2)

        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

    def _speaker_rope_ranges(self, human_num):
        """Return the ``(start, end)`` rotary segment of every speaker."""
        if human_num == 2:
            # Exactly two speakers sit at the two ends of the class range.
            return [
                (0, self.class_interval),
                (self.class_range - self.class_interval, self.class_range),
            ]
        # More speakers: consecutive segments starting at 0.
        return [
            (i * self.class_interval, (i + 1) * self.class_interval)
            for i in range(human_num)
        ]

    def _query_positions(self, x_ref_attn_map, rope_ranges):
        """Return one rotary position per query token.

        Every speaker row of ``x_ref_attn_map`` is min/max-normalised into
        that speaker's segment. Each token then takes the value belonging to
        the row with the largest attention; a row index equal to the number
        of speakers selects the constant background position.
        """
        per_speaker = []
        for idx, bucket in enumerate(rope_ranges):
            attn_map = x_ref_attn_map[idx]
            per_speaker.append(
                normalize_and_scale(attn_map, (attn_map.min(), attn_map.max()), bucket)
            )

        num_tokens = x_ref_attn_map.size(1)
        # Background constant bucket
        background = torch.full(
            (num_tokens,),
            self.rope_bak,
            dtype=per_speaker[0].dtype,
            device=per_speaker[0].device,
        )

        # Token-wise speaker dominance
        dominant = x_ref_attn_map.argmax(dim=0)
        candidates = torch.stack(per_speaker + [background], dim=1)
        token_idx = torch.arange(num_tokens, device=dominant.device)
        return candidates[token_idx, dominant]

    @staticmethod
    def _key_positions(n_tokens, n_frames, rope_ranges, dtype, device):
        """Return the rotary position of every encoder token of every frame.

        The ``n_tokens`` encoder tokens of a frame are divided into one
        contiguous block per speaker, and each block gets the centre of that
        speaker's segment. Blocks are split by integer division, so
        ``n_tokens`` does not have to be a multiple of the speaker count.
        The per-frame layout is repeated ``n_frames`` times.
        """
        n_speakers = len(rope_ranges)
        bounds = [(n_tokens * i) // n_speakers for i in range(n_speakers + 1)]
        per_frame = torch.zeros(n_tokens, dtype=dtype, device=device)
        for (start, end), lo, hi in zip(rope_ranges, bounds[:-1], bounds[1:]):
            per_frame[lo:hi] = (start + end) / 2
        return torch.cat([per_frame] * n_frames, dim=0)

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        shape=None,
        x_ref_attn_map=None,
        human_num=None,
    ) -> torch.Tensor:
        """Cross-attend video tokens to the audio tokens of several speakers.

        Args:
            x: Video tokens, (B, N_t * N_h * N_w, dim).
            encoder_hidden_states: Encoder tokens; a leading singleton axis
                is squeezed away, after which the layout is
                (B * frames, tokens, encoder_hidden_states_dim).
            shape: ``(N_t, N_h, N_w)`` token grid of ``x``.
            x_ref_attn_map: Reference attention map with one row per class
                and one column per query token. Required when
                ``human_num`` is 2 or more.
                NOTE(review): presumably one row per speaker plus a trailing
                background row -- confirm against the caller.
            human_num: Number of speakers. ``None``, 0 or 1 falls back to the
                single-speaker attention of the base class.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``shape`` or ``x_ref_attn_map`` is missing in the
                multi-speaker path.
        """
        encoder_hidden_states = encoder_hidden_states.squeeze(0)

        # Single-speaker fall-through
        if human_num is None or human_num <= 1:
            return super().forward(x, encoder_hidden_states, shape)

        if shape is None or x_ref_attn_map is None:
            raise ValueError(
                "shape and x_ref_attn_map must be provided when human_num > 1"
            )

        N_t, N_h, N_w = shape

        # Fewer encoder frames than video frames: the last frame of ``x`` has
        # no encoder tokens, so it is excluded here and zero-filled at the end.
        x_extra = None
        if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
            x_extra = x[:, -N_h * N_w:, :]
            x = x[:, :-N_h * N_w, :]
            N_t = N_t - 1

        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # Query projection (the leading axis is batch * frames from here on)
        B, N, C = x.shape
        q = self.q_linear(x)
        q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        rope_ranges = self._speaker_rope_ranges(human_num)
        normalized_pos = self._query_positions(x_ref_attn_map, rope_ranges)

        # Apply rotary to Q; frames are unfolded so the positions index the
        # full token sequence.
        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        # Keys / Values
        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv = encoder_kv.view(
            B, N_a, 2, self.num_heads, self.head_dim
        ).permute(2, 0, 3, 1, 4)
        encoder_k, encoder_v = encoder_kv.unbind(0)

        # Rotary for keys - assign centre of each speaker bucket to its
        # context tokens
        encoder_pos = self._key_positions(
            N_a, N_t, rope_ranges, encoder_k.dtype, encoder_k.device
        )
        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        # Final attention
        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
        x = attention(
            q, encoder_k, encoder_v, attention_mode=self.attention_mode
        )

        # Linear projection
        x = x.reshape(B, N, C)
        x = self.proj(x)

        # Restore original layout
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
        if x_extra is not None:
            x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)

        return x