vidfom's picture
Upload folder using huggingface_hub
31112ad verified
from typing import List, Optional
import torch
import torch.nn as nn
from shared.attention import pay_attention #, pay_sparse_attention
from .rope_3d import RotaryPositionalEmbedding
from .blocks import RMSNorm_FP32
def _run_attention(x_list, out_dtype, **attn_kwargs):
q, k, v = x_list
if out_dtype in (torch.float16, torch.bfloat16):
attn_dtype = out_dtype
else:
attn_dtype = torch.bfloat16
if q.dtype != attn_dtype:
q = q.to(attn_dtype)
k = k.to(attn_dtype)
v = v.to(attn_dtype)
x_list[:] = [q, k, v]
x = pay_attention(x_list, **attn_kwargs)
x_list[:] = []
if x.dtype != out_dtype:
x = x.to(out_dtype)
return x
def _run_sparse_attention(x_list, out_dtype, shape, bsa_params, **attn_kwargs):
q, k, v = x_list
if out_dtype in (torch.float16, torch.bfloat16):
attn_dtype = out_dtype
else:
attn_dtype = torch.bfloat16
if q.dtype != attn_dtype:
q = q.to(attn_dtype)
k = k.to(attn_dtype)
v = v.to(attn_dtype)
x_list[:] = [q, k, v]
x = pay_sparse_attention(x_list, shape=shape, bsa_params=bsa_params, **attn_kwargs)
x_list[:] = []
if x.dtype != out_dtype:
x = x.to(out_dtype)
return x
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = None,
cp_split_hw: Optional[List[int]] = None
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
self.enable_bsa = enable_bsa
self.bsa_params = bsa_params
self.cp_split_hw = cp_split_hw
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.proj = nn.Linear(dim, dim)
self.rope_3d = RotaryPositionalEmbedding(
self.head_dim,
cp_split_hw=cp_split_hw
)
def _process_attn(self, q, k, v, shape, out_dtype):
"""
function wrapper to do attention with q, k, v
"""
if self.enable_bsa:
return _run_sparse_attention([q, k, v], out_dtype, shape, self.bsa_params)
return _run_attention([q, k, v], out_dtype)
def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
"""
"""
B, N, C = x.shape
out_dtype = x.dtype
qkv = self.qkv(x)
if qkv.dtype != out_dtype:
qkv = qkv.to(out_dtype)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape)
q, k, v = qkv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
if return_kv:
k_cache, v_cache = k.clone(), v.clone()
q, k = self.rope_3d(q, k, shape)
# cond mode
if num_cond_latents is not None and num_cond_latents > 0:
num_cond_latents_thw = num_cond_latents * (N // shape[0])
# process the condition tokens
q_cond = q[:, :num_cond_latents_thw].contiguous()
k_cond = k[:, :num_cond_latents_thw].contiguous()
v_cond = v[:, :num_cond_latents_thw].contiguous()
x_cond = self._process_attn(q_cond, k_cond, v_cond, shape, out_dtype)
# process the noise tokens
q_noise = q[:, num_cond_latents_thw:].contiguous()
x_noise = self._process_attn(q_noise, k, v, shape, out_dtype)
# merge x_cond and x_noise
x = torch.cat([x_cond, x_noise], dim=1).contiguous()
else:
x = self._process_attn(q, k, v, shape, out_dtype)
x_output_shape = (B, N, C)
x = x.reshape(x_output_shape)
x = self.proj(x)
if return_kv:
return x, (k_cache, v_cache)
else:
return x
def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
"""
"""
B, N, C = x.shape
out_dtype = x.dtype
qkv = self.qkv(x)
if qkv.dtype != out_dtype:
qkv = qkv.to(out_dtype)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape)
q, k, v = qkv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
T, H, W = shape
k_cache, v_cache = kv_cache
if k_cache.shape[0] == 1 and B > 1:
k_cache = k_cache.repeat(B, 1, 1, 1)
v_cache = v_cache.repeat(B, 1, 1, 1)
if num_cond_latents is not None and num_cond_latents > 0:
k_full = torch.cat([k_cache, k], dim=1).contiguous()
v_full = torch.cat([v_cache, v], dim=1).contiguous()
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=1).contiguous()
q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
q = q_padding[:, -N:].contiguous()
else:
k_full = k
v_full = v
x = self._process_attn(q, k_full, v_full, shape, out_dtype)
x_output_shape = (B, N, C)
x = x.reshape(x_output_shape)
x = self.proj(x)
return x
class MultiHeadCrossAttention(nn.Module):
def __init__(
self,
dim,
num_heads,
enable_flashattn3=False,
enable_flashattn2=False,
enable_xformers=False,
):
super(MultiHeadCrossAttention, self).__init__()
assert dim % num_heads == 0, "d_model must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_linear = nn.Linear(dim, dim)
self.kv_linear = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
def _process_cross_attn(self, x, cond, kv_seqlen):
B, N, C = x.shape
assert C == self.dim and cond.shape[2] == self.dim
out_dtype = x.dtype
q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
if q.dtype != out_dtype:
q = q.to(out_dtype)
kv = self.kv_linear(cond).view(B, -1, 2, self.num_heads, self.head_dim)
if kv.dtype != out_dtype:
kv = kv.to(out_dtype)
k, v = kv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
k_lens = kv_seqlen
if k_lens is not None:
if isinstance(k_lens, torch.Tensor):
k_lens = k_lens.tolist() if B > 1 else k_lens.to(q.device)
elif isinstance(k_lens, list) and B == 1:
k_lens = torch.tensor(k_lens, device=q.device)
x = _run_attention([q, k, v], out_dtype, k_lens=k_lens, cross_attn=True)
x = x.view(B, N, C)
x = self.proj(x)
return x
def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
"""
x: [B, N, C]
cond: [B, M, C]
"""
if num_cond_latents is None or num_cond_latents == 0:
return self._process_cross_attn(x, cond, kv_seqlen)
else:
B, N, C = x.shape
if num_cond_latents is not None and num_cond_latents > 0:
assert shape is not None, "SHOULD pass in the shape"
num_cond_latents_thw = num_cond_latents * (N // shape[0])
x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
output = torch.cat([
torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
output_noise
], dim=1).contiguous()
else:
raise NotImplementedError
return output