ColabWan / models /longcat /modules /attention.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
9.05 kB
from typing import List, Optional
import torch
import torch.nn as nn
from shared.attention import pay_attention
from .rope_3d import RotaryPositionalEmbedding
from .blocks import RMSNorm_FP32, _take_tensor
def _run_attention(x_list, out_dtype, **attn_kwargs):
    """Run dense attention on ``[q, k, v]`` through the shared ``pay_attention`` helper.

    Args:
        x_list: mutable list holding exactly ``[q, k, v]``. If any of the three
            tensors is not in the attention dtype, the list's contents are
            replaced in place by cast tensors.
        out_dtype: dtype of the returned tensor. For float16/bfloat16 the
            attention runs in that same dtype; for any other dtype it runs in
            bfloat16.
        **attn_kwargs: forwarded to ``pay_attention``; ``recycle_q`` defaults to
            True unless the caller sets it explicitly.

    Returns:
        The ``pay_attention`` result, cast to ``out_dtype`` if needed.
    """
    if out_dtype in (torch.float16, torch.bfloat16):
        attn_dtype = out_dtype
    else:
        # Non-half outputs (e.g. float32) still run the kernel in bf16.
        attn_dtype = torch.bfloat16
    q, k, v = x_list  # also enforces exactly three entries
    # Check every tensor, not only q: k/v can come from a different projection
    # path than q (e.g. cross-attention) and must match the kernel dtype too.
    if q.dtype != attn_dtype or k.dtype != attn_dtype or v.dtype != attn_dtype:
        # Tensor.to returns the tensor itself when its dtype already matches.
        x_list[:] = [q.to(attn_dtype), k.to(attn_dtype), v.to(attn_dtype)]
    # Drop the local references so this frame does not keep q/k/v alive while
    # pay_attention runs.
    del q, k, v
    attn_kwargs.setdefault("recycle_q", True)
    x = pay_attention(x_list, **attn_kwargs)
    if x.dtype != out_dtype:
        x = x.to(out_dtype)
    return x
def _run_sparse_attention(x_list, out_dtype, shape, bsa_params, **attn_kwargs):
raise NotImplementedError("LongCat sparse/BSA attention is not wired to WanGP shared attention.")
class Attention(nn.Module):
    """Multi-head self-attention with per-head QK RMS-norm and 3D rotary embeddings.

    Tokens may be split into a leading block of condition tokens and the
    remaining noise tokens. Condition tokens attend only among themselves,
    while noise tokens attend to the whole sequence (see ``forward``).

    The ``enable_flashattn3`` / ``enable_flashattn2`` / ``enable_xformers``
    flags are stored but not read within this class; dense attention always goes
    through ``_run_attention``. ``enable_bsa`` routes to
    ``_run_sparse_attention``, which currently raises ``NotImplementedError``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: dict = None,
        cp_split_hw: Optional[List[int]] = None
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Not read within this class.
        self.scale = self.head_dim**-0.5
        # Backend flags: stored only, never read within this class.
        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers
        self.enable_bsa = enable_bsa
        self.bsa_params = bsa_params
        self.cp_split_hw = cp_split_hw
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.proj = nn.Linear(dim, dim)
        self.rope_3d = RotaryPositionalEmbedding(
            self.head_dim,
            cp_split_hw=cp_split_hw
        )

    def _process_attn(self, q, k, v, shape, out_dtype):
        """Dispatch ``q``/``k``/``v`` to sparse (BSA) or dense attention.

        ``shape`` is only forwarded to the sparse path; the dense path ignores it.
        """
        if self.enable_bsa:
            return _run_sparse_attention([q, k, v], out_dtype, shape, self.bsa_params)
        return _run_attention([q, k, v], out_dtype)

    def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
        """Self-attention over ``x``.

        Args:
            x: input unwrapped by ``_take_tensor`` into a ``(B, N, C)`` tensor
                (presumably a tensor or a one-element holder list; TODO confirm).
            shape: latent grid ``(T, H, W)`` passed to the rotary embedding;
                ``shape[0]`` also gives the tokens-per-frame count ``N // T``.
            num_cond_latents: number of leading latent frames treated as
                condition tokens; ``None`` or 0 disables the split.
            return_kv: when True, also return clones of ``k`` and ``v`` taken
                after QK-norm but before the rotary embedding.

        Returns:
            A ``(B, N, C)`` tensor, or ``(output, (k_cache, v_cache))`` when
            ``return_kv`` is True.
        """
        x = _take_tensor(x)
        B, N, C = x.shape
        out_dtype = x.dtype
        qkv = self.qkv(x)
        x = None  # release the input early to lower peak memory
        if qkv.dtype != out_dtype:
            qkv = qkv.to(out_dtype)
        qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k = self.q_norm(q), self.k_norm(k)
        v = v.contiguous()
        del qkv
        if return_kv:
            # Cached keys are pre-RoPE so a later call can re-apply RoPE over an
            # extended grid (see forward_with_kv_cache).
            k_cache, v_cache = k.clone(), v.clone()
        q, k = self.rope_3d(q, k, shape)
        if num_cond_latents is not None and num_cond_latents > 0:
            # Number of leading condition tokens: frames * tokens-per-frame.
            num_cond_latents_thw = num_cond_latents * (N // shape[0])
            # Condition tokens attend only among themselves.
            q_cond = q[:, :num_cond_latents_thw].contiguous()
            k_cond = k[:, :num_cond_latents_thw].contiguous()
            v_cond = v[:, :num_cond_latents_thw].contiguous()
            x_cond = self._process_attn(q_cond, k_cond, v_cond, shape, out_dtype)
            q_cond = k_cond = v_cond = None
            # Noise tokens attend to the full sequence (condition + noise).
            q_noise = q[:, num_cond_latents_thw:].contiguous()
            x_noise = self._process_attn(q_noise, k, v, shape, out_dtype)
            q_noise = None
            # Stitch both parts back into one (B, N, heads, head_dim) buffer.
            x = x_cond.new_empty(B, N, self.num_heads, self.head_dim)
            x[:, :num_cond_latents_thw].copy_(x_cond)
            x[:, num_cond_latents_thw:].copy_(x_noise)
            del x_cond, x_noise
        else:
            x = self._process_attn(q, k, v, shape, out_dtype)
        q = k = v = None
        x = x.reshape(B, N, C)
        x = self.proj(x)
        if return_kv:
            return x, (k_cache, v_cache)
        return x

    def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
        """Self-attention for new tokens, with a cached condition prefix.

        Args:
            x: input unwrapped by ``_take_tensor`` into a ``(B, N, C)`` tensor.
            shape: latent grid ``(T, H, W)`` of the tokens in ``x``.
            num_cond_latents: number of cached condition frames. When > 0, the
                cached keys/values are prepended and RoPE is applied over the
                extended grid ``(T + num_cond_latents, H, W)``.
            kv_cache: ``(k_cache, v_cache)`` pair, presumably as produced by
                ``forward(..., return_kv=True)`` (pre-RoPE keys); a batch of 1
                is repeated to ``B``.

        Returns:
            A ``(B, N, C)`` tensor.
        """
        x = _take_tensor(x)
        B, N, C = x.shape
        out_dtype = x.dtype
        qkv = self.qkv(x)
        x = None  # release the input early to lower peak memory
        if qkv.dtype != out_dtype:
            qkv = qkv.to(out_dtype)
        qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k = self.q_norm(q), self.k_norm(k)
        v = v.contiguous()
        del qkv
        T, H, W = shape
        k_cache, v_cache = kv_cache
        if k_cache.shape[0] == 1 and B > 1:
            # Broadcast a single-sample cache across the batch.
            k_cache = k_cache.repeat(B, 1, 1, 1)
            v_cache = v_cache.repeat(B, 1, 1, 1)
        if num_cond_latents is not None and num_cond_latents > 0:
            k_full = torch.cat([k_cache, k], dim=1).contiguous()
            v_full = torch.cat([v_cache, v], dim=1).contiguous()
            # Pad q in front with a cache-sized placeholder so q and k share the
            # same token layout for RoPE; the placeholder rows are discarded.
            q_padding = torch.cat([torch.empty_like(k_cache), q], dim=1).contiguous()
            q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
            q = q_padding[:, -N:].contiguous()
            del q_padding
        else:
            # No cached condition: plain self-attention. RoPE must still be
            # applied, matching forward()'s unconditioned path; previously this
            # fallback ran attention with no positional encoding at all.
            q, k = self.rope_3d(q, k, shape)
            k_full = k
            v_full = v
        # Release references not needed by the attention call.
        k = v = k_cache = v_cache = None
        x = self._process_attn(q, k_full, v_full, shape, out_dtype)
        q = k_full = v_full = None
        x = x.reshape(B, N, C)
        x = self.proj(x)
        return x
class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention from latent tokens to a conditioning sequence.

    Queries are projected from ``x`` and keys/values from ``cond``; q and k are
    RMS-normalized per head. Leading condition-latent tokens of ``x`` can be
    excluded from cross-attention; their output rows are zero (see ``forward``).
    The ``enable_*`` backend flags are stored but not read within this class.
    """

    def __init__(
        self,
        dim,
        num_heads,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=False,
    ):
        super().__init__()
        assert dim % num_heads == 0, "d_model must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_linear = nn.Linear(dim, dim)
        self.kv_linear = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        # Backend flags: stored only, never read within this class.
        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers

    def _process_cross_attn(self, x, cond, kv_seqlen):
        """Cross-attend every token of ``x`` to ``cond``.

        Args:
            x: query input unwrapped by ``_take_tensor`` into ``(B, N, C)``.
            cond: key/value input unwrapped by ``_take_tensor``; its last dim
                must be ``C``. Its projection is viewed with batch ``B``, so it
                presumably has batch ``B`` (or a token count that splits evenly
                into ``B`` groups); TODO confirm against callers.
            kv_seqlen: optional per-sample key lengths, forwarded to
                ``pay_attention`` as ``k_lens``. Normalized to a Python list
                when ``B > 1`` and to a tensor on q's device when ``B == 1``.

        Returns:
            A ``(B, N, C)`` tensor.
        """
        x = _take_tensor(x)
        cond = _take_tensor(cond)
        B, N, C = x.shape
        assert C == self.dim and cond.shape[2] == self.dim
        out_dtype = x.dtype
        q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
        x = None  # release the input early to lower peak memory
        if q.dtype != out_dtype:
            q = q.to(out_dtype)
        kv = self.kv_linear(cond).view(B, -1, 2, self.num_heads, self.head_dim)
        cond = None
        if kv.dtype != out_dtype:
            kv = kv.to(out_dtype)
        k, v = kv.unbind(2)
        v = v.contiguous()
        del kv
        q, k = self.q_norm(q), self.k_norm(k)
        k_lens = kv_seqlen
        if k_lens is not None:
            if isinstance(k_lens, torch.Tensor):
                k_lens = k_lens.tolist() if B > 1 else k_lens.to(q.device)
            elif isinstance(k_lens, list) and B == 1:
                k_lens = torch.tensor(k_lens, device=q.device)
        qkv_list = [q, k, v]
        del q, k, v
        x = _run_attention(qkv_list, out_dtype, k_lens=k_lens)
        # reshape (not view) so a non-contiguous attention output is accepted,
        # consistent with Attention.forward.
        x = x.reshape(B, N, C)
        x = self.proj(x)
        return x

    def forward_noise(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
        """Cross-attend only the noise tokens of ``x``.

        Returns:
            ``(num_cond_tokens, output)`` where ``num_cond_tokens`` is the number
            of leading tokens skipped (0 when there is no condition) and
            ``output`` covers the remaining tokens only.
        """
        x = _take_tensor(x)
        if num_cond_latents is None or num_cond_latents == 0:
            x_list = [x]
            x = None
            return 0, self._process_cross_attn(x_list, cond, kv_seqlen)
        assert shape is not None, "SHOULD pass in the shape"
        B, N, C = x.shape
        # Leading condition tokens: frames * tokens-per-frame (N // shape[0]).
        num_cond_latents_thw = num_cond_latents * (N // shape[0])
        x_noise = x[:, num_cond_latents_thw:]
        x = None
        x_list = [x_noise]
        x_noise = None
        return num_cond_latents_thw, self._process_cross_attn(x_list, cond, kv_seqlen)

    def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
        """Cross-attention over ``x``; condition-latent rows of the output are zero.

        x: [B, N, C]
        cond: [B, M, C]
        """
        x = _take_tensor(x)
        B, N, C = x.shape
        x_list = [x]
        x = None
        cond_tokens, output_noise = self.forward_noise(x_list, cond, kv_seqlen, num_cond_latents=num_cond_latents, shape=shape)
        if cond_tokens == 0:
            return output_noise
        # Condition tokens receive no cross-attention contribution.
        output = output_noise.new_zeros(B, N, C)
        output[:, cond_tokens:].copy_(output_noise)
        return output