| from typing import List, Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from shared.attention import pay_attention |
| from .rope_3d import RotaryPositionalEmbedding |
| from .blocks import RMSNorm_FP32, _take_tensor |
|
|
|
|
def _run_attention(x_list, out_dtype, **attn_kwargs):
    """Run the shared attention kernel on ``[q, k, v]`` and return ``out_dtype``.

    Attention is computed in ``out_dtype`` when that is float16 or bfloat16,
    and in bfloat16 otherwise. The result is cast back to ``out_dtype`` when
    the two differ.

    Args:
        x_list: Mutable list holding exactly the query, key and value tensors.
            It is updated in place with the (possibly cast) tensors and handed
            to ``pay_attention`` as-is; this function keeps no other reference
            to q/k/v while the kernel runs.
        out_dtype: dtype of the returned tensor.
        **attn_kwargs: Forwarded to ``pay_attention``. ``recycle_q`` defaults
            to ``True`` unless the caller supplies it.

    Returns:
        The tensor returned by ``pay_attention``, cast to ``out_dtype``.
    """
    if out_dtype in (torch.float16, torch.bfloat16):
        attn_dtype = out_dtype
    else:
        attn_dtype = torch.bfloat16

    # Unpacking enforces the "exactly three tensors" contract.
    q, k, v = x_list
    # Cast each tensor on its own dtype: checking only q would let a k or v
    # with a different dtype reach the kernel uncast.
    x_list[:] = [t if t.dtype == attn_dtype else t.to(attn_dtype) for t in (q, k, v)]
    # Drop local references so x_list is the only holder of q/k/v in this frame.
    del q, k, v

    attn_kwargs.setdefault("recycle_q", True)
    x = pay_attention(x_list, **attn_kwargs)
    if x.dtype != out_dtype:
        x = x.to(out_dtype)
    return x
|
|
|
|
| def _run_sparse_attention(x_list, out_dtype, shape, bsa_params, **attn_kwargs): |
| raise NotImplementedError("LongCat sparse/BSA attention is not wired to WanGP shared attention.") |
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with q/k RMS norm and a 3D rotary embedding.

    The layer applies a fused qkv projection, normalizes q and k per head
    with ``RMSNorm_FP32``, applies ``RotaryPositionalEmbedding`` to q and k,
    runs attention through ``_run_attention`` (or ``_run_sparse_attention``
    when ``enable_bsa`` is set) and finishes with an output projection.

    ``enable_flashattn3``, ``enable_flashattn2``, ``enable_xformers`` and
    ``scale`` are stored on the instance but are not read by this class.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: Optional[dict] = None,
        cp_split_hw: Optional[List[int]] = None,
    ) -> None:
        """Build the projections, the q/k norms and the rotary embedding.

        Args:
            dim: Channel width of the input; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            enable_flashattn3: Stored only.
            enable_flashattn2: Stored only.
            enable_xformers: Stored only.
            enable_bsa: Route attention through ``_run_sparse_attention``.
            bsa_params: Passed to ``_run_sparse_attention`` when ``enable_bsa``.
            cp_split_hw: Forwarded to ``RotaryPositionalEmbedding``.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers
        self.enable_bsa = enable_bsa
        self.bsa_params = bsa_params
        self.cp_split_hw = cp_split_hw

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.proj = nn.Linear(dim, dim)

        self.rope_3d = RotaryPositionalEmbedding(
            self.head_dim,
            cp_split_hw=cp_split_hw,
        )

    def _project_qkv(self, x):
        """Project the input to per-head q, k, v and normalize q and k.

        Shared prologue of ``forward`` and ``forward_with_kv_cache``.

        Args:
            x: Anything ``_take_tensor`` resolves to a ``(B, N, C)`` tensor.

        Returns:
            ``(q, k, v, (B, N, C), out_dtype)`` where q, k and v are derived
            from a ``(B, N, 3, num_heads, head_dim)`` view of the projection
            and ``out_dtype`` is the dtype of the input tensor.
        """
        x = _take_tensor(x)
        B, N, C = x.shape
        out_dtype = x.dtype
        qkv = self.qkv(x)
        x = None  # drop this frame's reference to the input as early as possible
        if qkv.dtype != out_dtype:
            qkv = qkv.to(out_dtype)

        qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k = self.q_norm(q), self.k_norm(k)
        v = v.contiguous()
        return q, k, v, (B, N, C), out_dtype

    def _process_attn(self, q, k, v, shape, out_dtype):
        """Dispatch q, k, v to the sparse path (``enable_bsa``) or the dense path.

        ``shape`` and ``self.bsa_params`` are only forwarded on the sparse path.
        """
        if self.enable_bsa:
            return _run_sparse_attention([q, k, v], out_dtype, shape, self.bsa_params)
        return _run_attention([q, k, v], out_dtype)

    def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
        """Self-attention over ``x``, optionally split into conditioning and noise tokens.

        Args:
            x: Input resolved by ``_take_tensor`` to a ``(B, N, C)`` tensor.
            shape: Passed to the rotary embedding. Required when
                ``num_cond_latents > 0`` because ``shape[0]`` is indexed
                (``forward_with_kv_cache`` unpacks the same argument as
                ``(T, H, W)``).
            num_cond_latents: When greater than zero, the first
                ``num_cond_latents * (N // shape[0])`` tokens are treated as
                conditioning tokens. Their queries attend only to the
                conditioning keys/values, while the remaining queries attend
                to all keys/values.
            return_kv: Also return clones of k and v taken after the q/k norm
                and before the rotary embedding.

        Returns:
            The projected ``(B, N, C)`` output, or ``(output, (k, v))`` when
            ``return_kv`` is true.
        """
        q, k, v, (B, N, C), out_dtype = self._project_qkv(x)

        if return_kv:
            # Snapshot before rope so the cache holds un-rotated keys.
            k_cache, v_cache = k.clone(), v.clone()

        q, k = self.rope_3d(q, k, shape)

        if num_cond_latents is not None and num_cond_latents > 0:
            num_cond_latents_thw = num_cond_latents * (N // shape[0])

            # Conditioning tokens attend only among themselves.
            q_cond = q[:, :num_cond_latents_thw].contiguous()
            k_cond = k[:, :num_cond_latents_thw].contiguous()
            v_cond = v[:, :num_cond_latents_thw].contiguous()
            x_cond = self._process_attn(q_cond, k_cond, v_cond, shape, out_dtype)
            del q_cond, k_cond, v_cond

            # Remaining tokens attend to the full key/value sequence.
            q_noise = q[:, num_cond_latents_thw:].contiguous()
            x_noise = self._process_attn(q_noise, k, v, shape, out_dtype)
            del q_noise

            x = x_cond.new_empty(B, N, self.num_heads, self.head_dim)
            x[:, :num_cond_latents_thw].copy_(x_cond)
            x[:, num_cond_latents_thw:].copy_(x_noise)
            del x_cond, x_noise
        else:
            x = self._process_attn(q, k, v, shape, out_dtype)
        q = k = v = None

        x = x.reshape(B, N, C)
        x = self.proj(x)

        if return_kv:
            return x, (k_cache, v_cache)
        return x

    def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
        """Self-attention over ``x`` with cached keys/values prepended.

        Args:
            x: Input resolved by ``_take_tensor`` to a ``(B, N, C)`` tensor.
            shape: ``(T, H, W)``; always unpacked, so it must be provided.
            num_cond_latents: When greater than zero, the cache is
                concatenated in front of the new keys/values along dim 1 and
                the rotary embedding is applied for
                ``(T + num_cond_latents, H, W)``.
            kv_cache: ``(k_cache, v_cache)`` pair; always unpacked, so it must
                be provided. A cache with batch size 1 is repeated to ``B``.

        Returns:
            The projected ``(B, N, C)`` output.

        NOTE(review): when ``num_cond_latents`` is ``None`` or ``0`` the cache
        is not used and no rotary embedding is applied in this method —
        confirm that this is the intended behaviour for that case.
        """
        q, k, v, (B, N, C), out_dtype = self._project_qkv(x)

        T, H, W = shape
        k_cache, v_cache = kv_cache

        if num_cond_latents is not None and num_cond_latents > 0:
            # Broadcast a single cached sample only where the cache is consumed.
            if k_cache.shape[0] == 1 and B > 1:
                k_cache = k_cache.repeat(B, 1, 1, 1)
                v_cache = v_cache.repeat(B, 1, 1, 1)
            k_full = torch.cat([k_cache, k], dim=1).contiguous()
            v_full = torch.cat([v_cache, v], dim=1).contiguous()
            # Pad q in front with uninitialized rows of the cache length so q
            # and k_full have the same sequence length for rope; the padding
            # is discarded right after.
            q_padding = torch.cat([torch.empty_like(k_cache), q], dim=1).contiguous()
            del k_cache, v_cache
            q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
            q = q_padding[:, -N:].contiguous()
            del q_padding
        else:
            k_full = k
            v_full = v

        x = self._process_attn(q, k_full, v_full, shape, out_dtype)
        q = k = v = k_full = v_full = None

        x = x.reshape(B, N, C)
        x = self.proj(x)
        return x
|
|
|
|
class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention from ``x`` (queries) to ``cond`` (keys/values).

    q and k are normalized per head with ``RMSNorm_FP32`` before attention.
    ``enable_flashattn3``, ``enable_flashattn2`` and ``enable_xformers`` are
    stored on the instance but are not read by this class.
    """

    def __init__(
        self,
        dim,
        num_heads,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=False,
    ):
        """Build the q / kv / output projections and the q/k norms.

        Args:
            dim: Channel width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            enable_flashattn3: Stored only.
            enable_flashattn2: Stored only.
            enable_xformers: Stored only.
        """
        super().__init__()
        assert dim % num_heads == 0, "d_model must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_linear = nn.Linear(dim, dim)
        self.kv_linear = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)

        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers

    def _process_cross_attn(self, x, cond, kv_seqlen):
        """Attend from ``x`` to ``cond`` and apply the output projection.

        Args:
            x: Resolved by ``_take_tensor`` to a ``(B, N, C)`` tensor with
                ``C == self.dim``.
            cond: Resolved by ``_take_tensor`` to a tensor whose last
                dimension is ``self.dim``; it is viewed as
                ``(B, -1, 2, num_heads, head_dim)`` after the kv projection.
            kv_seqlen: Optional key lengths, forwarded to the attention call
                as ``k_lens`` after normalization (see below).

        Returns:
            The projected ``(B, N, C)`` output in the dtype of ``x``.
        """
        x = _take_tensor(x)
        cond = _take_tensor(cond)
        B, N, C = x.shape
        assert C == self.dim and cond.shape[2] == self.dim
        out_dtype = x.dtype

        q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
        x = None  # release the input reference held by this frame
        if q.dtype != out_dtype:
            q = q.to(out_dtype)
        kv = self.kv_linear(cond).view(B, -1, 2, self.num_heads, self.head_dim)
        cond = None
        if kv.dtype != out_dtype:
            kv = kv.to(out_dtype)
        k, v = kv.unbind(2)
        v = v.contiguous()
        del kv

        q, k = self.q_norm(q), self.k_norm(k)

        # Normalize the container type of the key lengths: a tensor becomes a
        # Python list for B > 1 and is moved to q's device for B == 1; a list
        # becomes a tensor on q's device for B == 1.
        # NOTE(review): presumably these are the forms pay_attention expects
        # for batched vs. single-sample calls — confirm against its signature.
        k_lens = kv_seqlen
        if k_lens is not None:
            if isinstance(k_lens, torch.Tensor):
                k_lens = k_lens.tolist() if B > 1 else k_lens.to(q.device)
            elif isinstance(k_lens, list) and B == 1:
                k_lens = torch.tensor(k_lens, device=q.device)

        # Hand q/k/v over inside a list and drop the local names so the list
        # is the only reference held by this frame during attention.
        qkv_list = [q, k, v]
        del q, k, v
        x = _run_attention(qkv_list, out_dtype, k_lens=k_lens)
        # reshape (not view) so a non-contiguous attention output is still
        # accepted; it is a plain view whenever the layout allows it.
        x = x.reshape(B, N, C)
        x = self.proj(x)
        return x

    def forward_noise(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
        """Cross-attend only the tokens that follow the conditioning prefix.

        Args:
            x: Resolved by ``_take_tensor`` to a ``(B, N, C)`` tensor.
            cond: Conditioning sequence passed to ``_process_cross_attn``.
            kv_seqlen: Key lengths passed to ``_process_cross_attn``.
            num_cond_latents: When ``None`` or ``0`` every token is attended.
                Otherwise the first ``num_cond_latents * (N // shape[0])``
                tokens are skipped.
            shape: Required when ``num_cond_latents`` is non-zero; only
                ``shape[0]`` is read.

        Returns:
            ``(skipped_tokens, output)`` where ``skipped_tokens`` is the number
            of leading tokens left out (``0`` when nothing is skipped) and
            ``output`` covers only the remaining tokens.
        """
        x = _take_tensor(x)
        if num_cond_latents is None or num_cond_latents == 0:
            x_list = [x]
            x = None
            return 0, self._process_cross_attn(x_list, cond, kv_seqlen)

        assert shape is not None, "SHOULD pass in the shape"
        B, N, C = x.shape
        num_cond_latents_thw = num_cond_latents * (N // shape[0])
        x_noise = x[:, num_cond_latents_thw:]
        x = None
        x_list = [x_noise]
        x_noise = None
        return num_cond_latents_thw, self._process_cross_attn(x_list, cond, kv_seqlen)

    def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
        """Cross-attention output for all ``N`` tokens.

        Args:
            x: [B, N, C]
            cond: [B, M, C]
            kv_seqlen: Key lengths passed through to the attention call.
            num_cond_latents: See ``forward_noise``.
            shape: See ``forward_noise``.

        Returns:
            A ``(B, N, C)`` tensor. When a conditioning prefix is skipped, its
            rows are zeros and the remaining rows hold the attention output.
        """
        x = _take_tensor(x)
        B, N, C = x.shape
        x_list = [x]
        x = None
        cond_tokens, output_noise = self.forward_noise(
            x_list, cond, kv_seqlen, num_cond_latents=num_cond_latents, shape=shape
        )
        if cond_tokens == 0:
            return output_noise
        output = output_noise.new_zeros(B, N, C)
        output[:, cond_tokens:].copy_(output_noise)
        return output
|
|