from typing import List, Optional

import torch
import torch.nn as nn

# NOTE: pay_sparse_attention is assumed to live alongside pay_attention; it is used below.
from shared.attention import pay_attention, pay_sparse_attention
from .rope_3d import RotaryPositionalEmbedding
from .blocks import RMSNorm_FP32


def _run_attention(x_list, out_dtype, **attn_kwargs):
    """Run dense attention on [q, k, v], casting to a half-precision dtype if needed.

    The attention kernel runs in fp16/bf16 (bf16 when the caller's dtype is fp32);
    the result is cast back to ``out_dtype``. ``x_list`` is cleared after the call to
    release the q/k/v references early.
    """
    q, k, v = x_list
    if out_dtype in (torch.float16, torch.bfloat16):
        attn_dtype = out_dtype
    else:
        attn_dtype = torch.bfloat16
    if q.dtype != attn_dtype:
        q = q.to(attn_dtype)
        k = k.to(attn_dtype)
        v = v.to(attn_dtype)
        x_list[:] = [q, k, v]
    x = pay_attention(x_list, **attn_kwargs)
    x_list[:] = []
    if x.dtype != out_dtype:
        x = x.to(out_dtype)
    return x


def _run_sparse_attention(x_list, out_dtype, shape, bsa_params, **attn_kwargs):
    """Block-sparse variant of ``_run_attention``; forwards ``shape`` and ``bsa_params`` to the kernel."""
    q, k, v = x_list
    if out_dtype in (torch.float16, torch.bfloat16):
        attn_dtype = out_dtype
    else:
        attn_dtype = torch.bfloat16
    if q.dtype != attn_dtype:
        q = q.to(attn_dtype)
        k = k.to(attn_dtype)
        v = v.to(attn_dtype)
        x_list[:] = [q, k, v]
    x = pay_sparse_attention(x_list, shape=shape, bsa_params=bsa_params, **attn_kwargs)
    x_list[:] = []
    if x.dtype != out_dtype:
        x = x.to(out_dtype)
    return x


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: Optional[dict] = None,
        cp_split_hw: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers
        self.enable_bsa = enable_bsa
        self.bsa_params = bsa_params
        self.cp_split_hw = cp_split_hw

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.proj = nn.Linear(dim, dim)

        self.rope_3d = RotaryPositionalEmbedding(
            self.head_dim,
            cp_split_hw=cp_split_hw,
        )

    def _process_attn(self, q, k, v, shape, out_dtype):
        """Dispatch q/k/v to block-sparse or dense attention depending on ``enable_bsa``."""
        if self.enable_bsa:
            return _run_sparse_attention([q, k, v], out_dtype, shape, self.bsa_params)
        return _run_attention([q, k, v], out_dtype)

    def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
        """
        Self-attention over 3D latent tokens.

        x: [B, N, C] with N == T * H * W and shape == (T, H, W), used for 3D RoPE.
        num_cond_latents: number of leading conditioning frames; when > 0, the
            conditioning tokens attend only among themselves while the noise tokens
            attend over the full sequence.
        return_kv: if True, also return the pre-RoPE (k, v) for use as a KV cache.
        """
        B, N, C = x.shape
        out_dtype = x.dtype
        qkv = self.qkv(x)
        if qkv.dtype != out_dtype:
            qkv = qkv.to(out_dtype)

        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape)
        q, k, v = qkv.unbind(2)
        q, k = self.q_norm(q), self.k_norm(k)

        if return_kv:
            # Cache K/V before RoPE so they can be re-rotated against new positions later.
            k_cache, v_cache = k.clone(), v.clone()

        q, k = self.rope_3d(q, k, shape)

        if num_cond_latents is not None and num_cond_latents > 0:
            # Tokens belonging to the first `num_cond_latents` frames.
            num_cond_latents_thw = num_cond_latents * (N // shape[0])

            # Conditioning tokens attend only to each other ...
            q_cond = q[:, :num_cond_latents_thw].contiguous()
            k_cond = k[:, :num_cond_latents_thw].contiguous()
            v_cond = v[:, :num_cond_latents_thw].contiguous()
            x_cond = self._process_attn(q_cond, k_cond, v_cond, shape, out_dtype)

            # ... while noise tokens attend to the full (cond + noise) sequence.
            q_noise = q[:, num_cond_latents_thw:].contiguous()
            x_noise = self._process_attn(q_noise, k, v, shape, out_dtype)

            x = torch.cat([x_cond, x_noise], dim=1).contiguous()
        else:
            x = self._process_attn(q, k, v, shape, out_dtype)

        x_output_shape = (B, N, C)
        x = x.reshape(x_output_shape)
        x = self.proj(x)

        if return_kv:
            return x, (k_cache, v_cache)
        else:
            return x

    def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
        """
        Self-attention for the noise latents only, reusing cached (pre-RoPE) K/V of the
        conditioning latents, e.g. as produced by ``forward(..., return_kv=True)``.

        x: [B, N, C] noise tokens with N == T * H * W and shape == (T, H, W).
        kv_cache: (k_cache, v_cache), expected to hold num_cond_latents * H * W tokens each.
        """
        B, N, C = x.shape
        out_dtype = x.dtype
        qkv = self.qkv(x)
        if qkv.dtype != out_dtype:
            qkv = qkv.to(out_dtype)

        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape)
        q, k, v = qkv.unbind(2)
        q, k = self.q_norm(q), self.k_norm(k)

        T, H, W = shape
        k_cache, v_cache = kv_cache
        # Broadcast a batch-1 cache across the batch if needed.
        if k_cache.shape[0] == 1 and B > 1:
            k_cache = k_cache.repeat(B, 1, 1, 1)
            v_cache = v_cache.repeat(B, 1, 1, 1)

        if num_cond_latents is not None and num_cond_latents > 0:
            k_full = torch.cat([k_cache, k], dim=1).contiguous()
            v_full = torch.cat([v_cache, v], dim=1).contiguous()
            # Prepend placeholder queries so RoPE positions cover the full
            # (cond + noise) sequence, then keep only the noise queries.
            q_padding = torch.cat([torch.empty_like(k_cache), q], dim=1).contiguous()
            q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
            q = q_padding[:, -N:].contiguous()
        else:
            k_full = k
            v_full = v

        x = self._process_attn(q, k_full, v_full, shape, out_dtype)

        x_output_shape = (B, N, C)
        x = x.reshape(x_output_shape)
        x = self.proj(x)

        return x

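# Usage sketch (illustrative, inferred from the two methods above; variable names are
# hypothetical): `forward(..., return_kv=True)` returns K/V captured before RoPE, and
# `forward_with_kv_cache` expects such a cache holding exactly `num_cond_latents`
# frames of conditioning tokens, concatenating it in front of the new K/V and
# re-applying RoPE over (T + num_cond_latents, H, W).
#
#   _, kv_cache = attn(x_cond, shape=(num_cond, H, W), return_kv=True)
#   out = attn.forward_with_kv_cache(
#       x_noise, shape=(T_noise, H, W), num_cond_latents=num_cond, kv_cache=kv_cache
#   )

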
class MultiHeadCrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=False,
    ):
        super().__init__()
        assert dim % num_heads == 0, "d_model must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_linear = nn.Linear(dim, dim)
        self.kv_linear = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)

        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers

    def _process_cross_attn(self, x, cond, kv_seqlen):
        """Cross-attention from x ([B, N, C]) to cond ([B, M, C]); kv_seqlen gives the valid lengths of cond."""
        B, N, C = x.shape
        assert C == self.dim and cond.shape[2] == self.dim
        out_dtype = x.dtype

        q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
        if q.dtype != out_dtype:
            q = q.to(out_dtype)
        kv = self.kv_linear(cond).view(B, -1, 2, self.num_heads, self.head_dim)
        if kv.dtype != out_dtype:
            kv = kv.to(out_dtype)
        k, v = kv.unbind(2)

        q, k = self.q_norm(q), self.k_norm(k)

        # Normalize the key-length spec: a Python list when B > 1, a tensor on q.device when B == 1.
        k_lens = kv_seqlen
        if k_lens is not None:
            if isinstance(k_lens, torch.Tensor):
                k_lens = k_lens.tolist() if B > 1 else k_lens.to(q.device)
            elif isinstance(k_lens, list) and B == 1:
                k_lens = torch.tensor(k_lens, device=q.device)

        x = _run_attention([q, k, v], out_dtype, k_lens=k_lens, cross_attn=True)
        x = x.view(B, N, C)
        x = self.proj(x)
        return x

    def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
        """
        x: [B, N, C]
        cond: [B, M, C]

        When num_cond_latents > 0, only the noise latents receive cross-attention;
        the leading conditioning tokens get zeros.
        """
        if num_cond_latents is None or num_cond_latents == 0:
            return self._process_cross_attn(x, cond, kv_seqlen)

        B, N, C = x.shape
        assert shape is not None, "shape must be provided when num_cond_latents > 0"
        num_cond_latents_thw = num_cond_latents * (N // shape[0])
        x_noise = x[:, num_cond_latents_thw:]
        output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen)
        output = torch.cat([
            torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
            output_noise,
        ], dim=1).contiguous()

        return output
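

if __name__ == "__main__":
    # Minimal smoke test (sketch only, not part of the model). It assumes the
    # `shared.attention` backend and the relative imports above are available, that
    # `shape` is (T, H, W) with N == T * H * W, and that `kv_seqlen=None` is accepted
    # by the cross-attention path; adjust sizes and dtypes to your setup.
    torch.manual_seed(0)
    dim, heads = 64, 4
    T, H, W = 2, 4, 4
    x = torch.randn(1, T * H * W, dim, dtype=torch.bfloat16)

    self_attn = Attention(dim, heads).to(torch.bfloat16)
    out = self_attn(x, shape=(T, H, W))
    print("self-attention:", out.shape)

    cond = torch.randn(1, 8, dim, dtype=torch.bfloat16)
    cross_attn = MultiHeadCrossAttention(dim, heads).to(torch.bfloat16)
    out = cross_attn(x, cond, kv_seqlen=None)
    print("cross-attention:", out.shape)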