from typing import Optional

import torch
import torch.nn as nn
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.attention_processor import Attention, AttnProcessor
from flash_attn.flash_attn_interface import flash_attn_func


class InflatedAttentionProcessor(AttnProcessor):
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 6,  # e.g. a cubemap has 6 views
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Inflated Attention as described in the CubeDiff paper:
        - fold the per-view input `(B*F, N, C)` into `(B, F*N, C)`
        - run self-attention over the joint `F*N` token dimension so that
          all views attend to each other
        """
        residual = hidden_states

        # 1️⃣ Pre-processing
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        BXF, N, C = hidden_states.shape  # original attention input: (B*F, N, C)

        # 2️⃣ Derive the true batch size from `(B*F, N, C)`
        F = num_views
        B = BXF // F

        # 3️⃣ Standard attention computation
        # Note: flash_attn_func below does not consume the mask; this call is
        # kept only for parity with the default AttnProcessor.
        attention_mask = attn.prepare_attention_mask(attention_mask, hidden_states.shape[1], B)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        is_self_attn = False
        if encoder_hidden_states is None:
            # Self-attention: fold the view axis into the sequence axis,
            # `(B*F, N, C) → (B, F*N, C)`.
            hidden_states = hidden_states.view(B, F, N, C)
            hidden_states = hidden_states.reshape(B, F * N, C)
            encoder_hidden_states = hidden_states
            is_self_attn = True
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # flash_attn_func expects (batch, seq_len, heads, head_dim)
        query = attn.head_to_batch_dim(query, out_dim=4).permute(0, 2, 1, 3)
        key = attn.head_to_batch_dim(key, out_dim=4).permute(0, 2, 1, 3)
        value = attn.head_to_batch_dim(value, out_dim=4).permute(0, 2, 1, 3)

        hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
        # Avoid shadowing B: for cross-attention this batch dim is B*F.
        bsz, L, H, D = hidden_states.shape
        hidden_states = hidden_states.view(bsz, L, H * D)

        # Reference: the equivalent vanilla (non-flash) attention path:
        # query = attn.head_to_batch_dim(query)
        # key = attn.head_to_batch_dim(key)
        # value = attn.head_to_batch_dim(value)
        # attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # hidden_states = torch.bmm(attention_probs, value)
        # hidden_states = attn.batch_to_head_dim(hidden_states)

        # 4️⃣ Output projection & dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if is_self_attn:
            # 5️⃣ Restore the original layout `(B, F*N, C) → (B*F, N, C)`
            hidden_states = hidden_states.view(B, F, N, C)
            hidden_states = hidden_states.reshape(BXF, N, C)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
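

# --- Usage sketch (added for illustration; not part of the original module) ---
# A minimal, hedged example of how this processor might be attached to a stock
# diffusers UNet2DConditionModel so that every self-attention layer attends
# jointly across all cubemap faces. The checkpoint id, latent resolution,
# text-embedding shape and dtype below are assumptions for illustration only;
# flash-attn itself requires fp16/bf16 tensors on a CUDA device.
if __name__ == "__main__":
    from diffusers import UNet2DConditionModel

    num_views = 6
    device, dtype = "cuda", torch.float16

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=dtype
    ).to(device)
    unet.set_attn_processor(InflatedAttentionProcessor())

    # Latents are laid out view-major: (batch * num_views, 4, H/8, W/8).
    latents = torch.randn(num_views, 4, 64, 64, device=device, dtype=dtype)
    text_emb = torch.randn(num_views, 77, 768, device=device, dtype=dtype)
    timestep = torch.tensor(10, device=device)

    # diffusers forwards `cross_attention_kwargs` to every attention processor,
    # which is how `num_views` reaches __call__ above.
    noise_pred = unet(
        latents,
        timestep,
        encoder_hidden_states=text_emb,
        cross_attention_kwargs={"num_views": num_views},
    ).sample
    print(noise_pred.shape)  # expected: (num_views, 4, 64, 64)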