from typing import Optional
import torch
import torch.nn as nn
from diffusers.models.attention_processor import Attention, AttnProcessor
from flash_attn.flash_attn_interface import flash_attn_func

class InflatedAttentionProcessor(AttnProcessor):
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 6,  # e.g. a cubemap has 6 views
        *args,
        **kwargs,
    ) -> torch.Tensor:
            """
            实现 CubeDiff 论文中的 Inflated Attention:
            - 将输入 `B, N, C` 转换为 `B, F*N, C`
            - 在 `F*N` 维度上进行 Self-Attention
            """
            residual = hidden_states

            # 1️⃣ 预处理
            if attn.spatial_norm is not None:
                hidden_states = attn.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim

            if input_ndim == 4:
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)


            BXF, N, C = hidden_states.shape  # 原始注意力输入
            # 2️⃣ **变换 `B, N, C → B, F*N, C`**
            F = num_views
            B=BXF//F
            

            # 3️⃣ **标准 Attention 计算**
            attention_mask = attn.prepare_attention_mask(attention_mask, hidden_states.shape[1], B)

            if attn.group_norm is not None:
                hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            is_self_attn=False

            if encoder_hidden_states is None:
                hidden_states = hidden_states.view(B, F, N, C)
                hidden_states = hidden_states.reshape(B, F * N, C)
                encoder_hidden_states = hidden_states
                is_self_attn=True
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
                
            query = attn.to_q(hidden_states)
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

            query = attn.head_to_batch_dim(query,out_dim=4).permute(0,2,1,3)
            key = attn.head_to_batch_dim(key,out_dim=4).permute(0,2,1,3)
            value = attn.head_to_batch_dim(value,out_dim=4).permute(0,2,1,3)
            
            hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
            B,L,H,D=hidden_states.shape
            hidden_states = hidden_states.view(B,L,H*D)
            
            # query = attn.head_to_batch_dim(query)
            # key = attn.head_to_batch_dim(key)
            # value = attn.head_to_batch_dim(value)

            # attention_probs = attn.get_attention_scores(query, key, attention_mask)
            # hidden_states = torch.bmm(attention_probs, value)
            # hidden_states = attn.batch_to_head_dim(hidden_states)
            
            # 4️⃣ **线性投影 & Dropout**
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)

            if is_self_attn:
                # 5️⃣ **还原形状 `B, F*N, C → B, N, C`**
                hidden_states = hidden_states.view(B, F, N, C)
                hidden_states = hidden_states.reshape(BXF, N, C)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

            if attn.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / attn.rescale_output_factor

            return hidden_states
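

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original processor): registers the inflated
    # processor on a diffusers UNet so the 6 cube faces, stacked along the batch axis,
    # attend to each other jointly in every self-attention layer. The checkpoint name,
    # shapes, and dtype below are illustrative assumptions; flash-attn requires a CUDA
    # device and fp16/bf16 tensors.
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
    ).to("cuda")
    unet.set_attn_processor(InflatedAttentionProcessor())

    # Dummy forward pass: 6 cube-face latents as a (B*F, C, H, W) batch with B=1, F=6.
    latents = torch.randn(6, 4, 64, 64, dtype=torch.float16, device="cuda")
    timestep = torch.tensor([10], device="cuda")
    text_emb = torch.randn(6, 77, 768, dtype=torch.float16, device="cuda")

    with torch.no_grad():
        out = unet(latents, timestep, encoder_hidden_states=text_emb).sample
    print(out.shape)  # expected: torch.Size([6, 4, 64, 64])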