"""Axial and memory attention layers operating on ``(B, T, H, W, D)`` tensors.

Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
"""

from collections import namedtuple
from typing import Optional

import numpy as np
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F

from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

# Rank of the optional low-rank ("domain adapter") branch added to the QKV projection.
_LORA_RANK = 8


def _diagonal_bias(num_frames: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    """Return a ``(T, T)`` additive bias: ``-inf`` everywhere, ``0`` on the diagonal."""
    bias = torch.full((num_frames, num_frames), float("-inf"), dtype=dtype, device=device)
    bias.fill_diagonal_(0)
    return bias


def _causal_reference_bias(
    num_frames: int, reference_length: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Return the causal ``(T, T)`` additive bias used by ``TemporalAxialAttention``.

    Row ``i`` may attend to columns ``j <= i``. The rows selected by the slice
    ``[T - reference_length:]`` are then fully masked, and finally the diagonal
    is reopened, so those rows attend only to themselves.
    """
    frame = torch.arange(num_frames, device=device)
    is_future = frame.unsqueeze(0) > frame.unsqueeze(1)  # [i, j] is True when j > i
    bias = torch.zeros((num_frames, num_frames), dtype=dtype, device=device)
    bias = bias.masked_fill(is_future, float("-inf"))
    bias[(num_frames - reference_length):] = float("-inf")
    bias.fill_diagonal_(0)
    return bias


def _memory_bias(
    num_frames: int, reference_length: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Return the ``(T, T)`` additive bias used by ``MemTemporalAxialAttention``.

    With ``T_origin = T - reference_length``: rows ``[:T_origin]`` may attend to
    columns ``[T_origin:]`` and to themselves; every other row attends only to
    itself.
    """
    bias = _diagonal_bias(num_frames, dtype, device)
    first_reference = num_frames - reference_length
    bias[:first_reference, first_reference:] = 0
    return bias


class TemporalAxialAttention(nn.Module):
    """Multi-head self-attention along the ``T`` axis of a ``(B, T, H, W, D)`` tensor.

    Each ``(B, H, W)`` position is folded into the batch dimension and attends
    over its own ``T`` tokens. Queries and keys are rotated with ``rotary_emb``
    before attention.

    Args:
        dim: Size of the last input dimension (input features of ``to_qkv``).
        heads: Number of attention heads.
        dim_head: Feature size per head.
        reference_length: Number of trailing rows of the causal mask that are
            restricted to attend only to themselves.
        rotary_emb: Rotary embedding object; ``rotate_queries_or_keys`` and
            ``freqs`` are used in ``forward``.
        is_causal: Use the causal mask (ignored when
            ``is_temporal_independent`` is true).
        is_temporal_independent: If true, every token attends only to itself.
        use_domain_adapter: If true, add a rank-8 low-rank branch
            (``lora_B(lora_A(x))``) to the QKV projection.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        reference_length: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
        is_temporal_independent: bool = False,
        use_domain_adapter: bool = False,
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.head_dim = dim_head

        # Submodules are created in the original order so seeded initialisation
        # and state_dict layout are unchanged.
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.use_domain_adapter = use_domain_adapter
        if self.use_domain_adapter:
            self.lora_A = nn.Linear(dim, _LORA_RANK, bias=False)
            self.lora_B = nn.Linear(_LORA_RANK, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal
        self.is_temporal_independent = is_temporal_independent
        self.reference_length = reference_length

    def forward(self, x: torch.Tensor):
        """Apply temporal attention.

        Args:
            x: Tensor of shape ``(B, T, H, W, D)``.

        Returns:
            Tensor of shape ``(B, T, H, W, dim)``.
        """
        B, T, H, W, _ = x.shape

        qkv = self.to_qkv(x)
        if self.use_domain_adapter:
            # Adding to the fused projection is elementwise-identical to adding
            # to q, k and v separately after the split.
            qkv = qkv + self.lora_B(self.lora_A(x))
        q, k, v = (
            rearrange(t, "B T H W (h d) -> (B H W) h T d", h=self.heads)
            for t in qkv.chunk(3, dim=-1)
        )

        q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
        k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        if self.is_temporal_independent:
            attn_bias = _diagonal_bias(T, q.dtype, q.device)
        elif self.is_causal:
            attn_bias = _causal_reference_bias(T, self.reference_length, q.dtype, q.device)
        else:
            attn_bias = None

        # Errors are allowed to propagate; the previous bare ``except`` dropped
        # into pdb, which hangs non-interactive runs and hides the real failure.
        x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)

        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        return self.to_out(x)


class SpatialAxialAttention(nn.Module):
    """Multi-head self-attention over the ``H x W`` grid of a ``(B, T, H, W, D)`` tensor.

    Each ``(B, T)`` slice is folded into the batch dimension and attends over its
    own ``H * W`` tokens without masking. Queries and keys receive axial rotary
    frequencies obtained from ``rotary_emb.get_axial_freqs(H, W)``.

    Args:
        dim: Size of the last input dimension (input features of ``to_qkv``).
        heads: Number of attention heads.
        dim_head: Feature size per head.
        rotary_emb: Rotary embedding object providing ``get_axial_freqs``.
        use_domain_adapter: If true, add a rank-8 low-rank branch
            (``lora_B(lora_A(x))``) to the QKV projection.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
        use_domain_adapter: bool = False,
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.head_dim = dim_head

        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.use_domain_adapter = use_domain_adapter
        if self.use_domain_adapter:
            self.lora_A = nn.Linear(dim, _LORA_RANK, bias=False)
            self.lora_B = nn.Linear(_LORA_RANK, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb

    def forward(self, x: torch.Tensor):
        """Apply spatial attention.

        Args:
            x: Tensor of shape ``(B, T, H, W, D)``.

        Returns:
            Tensor of shape ``(B, T, H, W, dim)``.
        """
        B, T, H, W, _ = x.shape

        qkv = self.to_qkv(x)
        if self.use_domain_adapter:
            qkv = qkv + self.lora_B(self.lora_A(x))
        q, k, v = (
            rearrange(t, "B T H W (h d) -> (B T) h H W d", h=self.heads)
            for t in qkv.chunk(3, dim=-1)
        )

        # Rotary frequencies are applied while H and W are still separate axes.
        freqs = self.rotary_emb.get_axial_freqs(H, W)
        q = apply_rotary_emb(freqs, q)
        k = apply_rotary_emb(freqs, k)

        # prepare for attn: flatten the grid into one token axis
        q, k, v = (
            rearrange(t, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
            for t in (q, k, v)
        )

        x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=False)

        x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        return self.to_out(x)


class MemTemporalAxialAttention(nn.Module):
    """Temporal attention in which leading frames read from trailing reference frames.

    Each ``(B, H, W)`` position attends along ``T`` under the mask built by
    ``_memory_bias``: the first ``T - reference_length`` frames attend to the
    last ``reference_length`` frames and to themselves, while the trailing frames
    attend only to themselves.

    NOTE: ``rotary_emb`` and ``is_causal`` are stored but not used by
    ``forward``; no rotary rotation is applied to queries or keys here.

    Args:
        dim: Size of the last input dimension (input features of ``to_qkv``).
        heads: Number of attention heads.
        dim_head: Feature size per head.
        rotary_emb: Stored on the module; unused in ``forward``.
        is_causal: Stored on the module; unused in ``forward``.
        reference_length: Number of trailing reference frames. Defaults to 3,
            the value that was previously hard-coded.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
        reference_length: int = 3,
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.head_dim = dim_head

        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal
        self.reference_length = reference_length

    def forward(self, x: torch.Tensor):
        """Apply memory-masked temporal attention.

        Args:
            x: Tensor of shape ``(B, T, H, W, D)``.

        Returns:
            Tensor of shape ``(B, T, H, W, dim)``.
        """
        B, T, H, W, _ = x.shape

        q, k, v = (
            rearrange(t, "B T H W (h d) -> (B H W) h T d", h=self.heads).contiguous()
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        attn_bias = _memory_bias(T, self.reference_length, q.dtype, q.device)

        # Errors are allowed to propagate instead of dropping into a debugger.
        x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)

        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        return self.to_out(x)


class MemFullAttention(nn.Module):
    """Attention from frame tokens to reference-frame tokens over the joint ``(T, H, W)`` axis.

    Keys and values are taken only from reference frames; no attention mask is
    used. Two layouts are supported, selected by ``relative_embedding`` in
    ``forward``.

    NOTE: ``rotary_emb``, ``is_causal`` and ``store`` are stored but not used by
    ``forward``.

    Args:
        dim: Size of the last input dimension (input features of ``to_qkv``).
        heads: Number of attention heads.
        dim_head: Feature size per head.
        reference_length: Default number of reference frames, used when
            ``forward`` is called without ``reference_length``.
        rotary_emb: Stored on the module; unused in ``forward``.
        is_causal: Stored on the module; unused in ``forward``.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        reference_length: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.head_dim = dim_head

        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal
        self.reference_length = reference_length
        self.store = None

    def _flatten_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, T, H, W, heads * d)`` to contiguous ``(B, heads, T * H * W, d)``."""
        return rearrange(t, "B T H W (h d) -> B h (T H W) d", h=self.heads).contiguous()

    def forward(
        self,
        x: torch.Tensor,
        relative_embedding: bool = False,
        extra_condition: Optional[torch.Tensor] = None,
        state_embed_only_on_qk: bool = False,
        reference_length: Optional[int] = None,
    ):
        """Attend from frame tokens to reference-frame tokens.

        Args:
            x: Tensor of shape ``(B, T, H, W, D)``.
            relative_embedding: If true, ``T`` is split into consecutive groups
                of ``reference_length + 1`` frames. Within each group the keys
                and values are the frames after the first one. The queries are
                the first frame of the group, except for the last group where
                all of its frames are queries. If false, all ``T`` frames query
                the last ``reference_length`` frames.
            extra_condition: Tensor added to ``x`` before computing queries and
                keys; required when ``state_embed_only_on_qk`` is true.
            state_embed_only_on_qk: If true, queries and keys are computed from
                ``x + extra_condition`` and values from ``x`` alone.
            reference_length: Number of reference frames; defaults to
                ``self.reference_length`` when ``None``.

        Returns:
            Tensor ``(B, T_out, H, W, dim)``. ``T_out == T`` when
            ``relative_embedding`` is false; otherwise
            ``T_out == T // (reference_length + 1) + reference_length``.

        Raises:
            ValueError: If ``state_embed_only_on_qk`` is true without
                ``extra_condition``, or if ``relative_embedding`` is true and
                ``T`` is not a positive multiple of ``reference_length + 1``.
        """
        B, T, H, W, _ = x.shape

        if reference_length is None:
            # Previously a missing value crashed on ``T - None``.
            reference_length = self.reference_length

        if state_embed_only_on_qk:
            if extra_condition is None:
                raise ValueError(
                    "extra_condition must be provided when state_embed_only_on_qk is True"
                )
            q, k = self.to_qkv(x + extra_condition).chunk(3, dim=-1)[:2]
            v = self.to_qkv(x).chunk(3, dim=-1)[2]
        else:
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        if relative_embedding:
            length = reference_length + 1
            n_frames = T // length
            if n_frames == 0 or T % length != 0:
                # The original enforced this implicitly through an otherwise
                # unused ``x.reshape(B, n_frames, length, H, W, D)``.
                raise ValueError(
                    f"T={T} must be a positive multiple of reference_length + 1 = {length}"
                )

            outputs = []
            for i in range(n_frames):
                start = i * length
                stop = start + length
                # The last group queries with all of its frames; earlier groups
                # query with their first frame only.
                q_stop = T if i == n_frames - 1 else start + 1

                q_i = self._flatten_heads(q[:, start:q_stop])
                k_i = self._flatten_heads(k[:, start + 1:stop])
                v_i = self._flatten_heads(v[:, start + 1:stop])

                x_i = F.scaled_dot_product_attention(query=q_i, key=k_i, value=v_i)
                x_i = rearrange(x_i, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
                outputs.append(x_i.to(q.dtype))

            x = torch.cat(outputs, dim=1)
        else:
            first_reference = T - reference_length
            q = self._flatten_heads(q)
            k = self._flatten_heads(k[:, first_reference:])
            v = self._flatten_heads(v[:, first_reference:])

            x = F.scaled_dot_product_attention(query=q, key=k, value=v)
            x = rearrange(x, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
            x = x.to(q.dtype)

        # linear proj
        return self.to_out(x)