File size: 10,833 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

from ..._common import precision
from ...functional import geglu, matmul, softmax, split
from ...layers import Conv2d, GroupNorm, LayerNorm, Linear
from ...module import Module, ModuleList


class AttentionBlock(Module):
    """Multi-head self-attention over the spatial positions of a 4-D input.

    The input is group-normalised, flattened into a sequence of
    ``height * width`` positions, sent through one fused Q/K/V projection
    and scaled dot-product attention, projected again, added back to the
    input and finally divided by ``rescale_output_factor``.

    Args:
        channels: Size of the channel axis of the input; also the width of
            each of the Q, K and V projections.
        num_head_channels: Channels handled by each attention head. When
            ``None`` a single head covering all ``channels`` is used.
        num_groups: Number of groups handed to ``GroupNorm``.
        rescale_output_factor: Divisor applied to ``output + residual``.
        eps: Epsilon handed to ``GroupNorm``.
    """

    def __init__(self,
                 channels: int,
                 num_head_channels: Optional[int] = None,
                 num_groups: int = 32,
                 rescale_output_factor: float = 1.0,
                 eps: float = 1e-5):
        super().__init__()
        self.channels = channels

        if num_head_channels is None:
            # Single-head fallback: the one head spans every channel. The
            # head size must be a real integer because it is used to build
            # the reshape target in ``transpose_for_scores``.
            self.num_heads = 1
            self.num_head_size = channels
        else:
            self.num_heads = channels // num_head_channels
            self.num_head_size = num_head_channels

        self.group_norm = GroupNorm(num_channels=channels,
                                    num_groups=num_groups,
                                    eps=eps,
                                    affine=True)

        # Fused projection: Q, K and V are produced side by side and split
        # apart in ``forward``.
        self.qkv = Linear(channels, channels * 3)
        self.rescale_output_factor = rescale_output_factor
        # NOTE(review): the third positional argument is kept exactly as in
        # the original call; presumably it lands on Linear's bias flag
        # (truthy) - confirm against the Linear signature.
        self.proj_attn = Linear(channels, channels, 1)

    def transpose_for_scores(self, projection):
        """Split the last axis into heads and move the heads to axis 1.

        (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D), where ``H`` is
        ``self.num_heads`` and ``D`` is ``self.num_head_size``.
        """
        new_projection_shape = projection.size()[:-1] + (self.num_heads,
                                                         self.num_head_size)
        return projection.view(new_projection_shape).permute([0, 2, 1, 3])

    def forward(self, hidden_states):
        """Apply normalisation, self-attention and the rescaled residual.

        Args:
            hidden_states: Input that unpacks into
                ``(batch, channel, height, width)``; it must not report
                ``is_dynamic()``.

        Returns:
            ``(attention_output + hidden_states) / rescale_output_factor``
            with the attention output viewed back to
            ``[batch, channel, height, width]``.
        """
        assert not hidden_states.is_dynamic()

        residual = hidden_states
        batch, channel, height, width = hidden_states.size()

        # Normalise, then flatten the spatial axes into a sequence axis:
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C).
        hidden_states = self.group_norm(hidden_states)
        hidden_states = hidden_states.view([batch, channel,
                                            height * width]).transpose(1, 2)

        # One projection, then split into equally sized Q, K and V chunks.
        qkv_proj = self.qkv(hidden_states)
        query_proj, key_proj, value_proj = split(qkv_proj, channel, dim=2)

        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        # Scores and softmax are built inside the float32 precision context
        # (presumably for numerical stability - confirm with `precision`).
        with precision('float32'):
            attention_scores = matmul(query_states,
                                      key_states.transpose(-1, -2))
            # Scale by sqrt(per-head width) = sqrt(channels / num_heads).
            attention_scores = attention_scores / math.sqrt(
                self.channels / self.num_heads)
            attention_probs = softmax(attention_scores, dim=-1)

        # Weighted sum of the values, then merge the heads back:
        # (B, H, T, D) -> (B, T, H, D) -> (B, T, C).
        hidden_states = matmul(attention_probs, value_states)
        hidden_states = hidden_states.permute([0, 2, 1, 3])

        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels, )
        hidden_states = hidden_states.view(new_hidden_states_shape)

        # Output projection, then restore the original 4-D layout.
        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.transpose(-1, -2).view(
            [batch, channel, height, width])

        # Residual connection followed by the output rescale.
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states


def _transpose_for_scores(tensor, heads):
    """Split the feature axis into ``heads`` and move the heads to axis 1.

    ``tensor`` must unpack into ``(batch, seq_len, dim)``. The result is
    laid out as ``(batch, heads, seq_len, dim // heads)``.
    """
    batch_size, seq_len, dim = tensor.size()
    head_dim = dim // heads
    per_head = tensor.view([batch_size, seq_len, heads, head_dim])
    return per_head.permute([0, 2, 1, 3])


def _attention(query, key, value, scale):
    """Scaled dot-product attention followed by a head/sequence axis swap.

    Computes ``softmax((query @ key^T) * scale) @ value`` with the softmax
    taken over the last axis. The result is rank 4 (the final permute names
    four axes); axes 1 and 2 are swapped before returning, so inputs laid
    out as (batch, heads, seq, head_dim) - which is what the callers in this
    file pass - come back as (batch, seq, heads, head_dim).
    """
    scores = matmul(query, key.transpose(-1, -2)) * scale
    probs = softmax(scores, dim=-1)
    context = matmul(probs, value)
    return context.permute([0, 2, 1, 3])


class SelfAttention(Module):
    """Multi-head self-attention using one fused Q/K/V projection.

    Args:
        query_dim: Feature size of the input sequence and of the output.
        heads: Number of attention heads.
        dim_head: Feature size of each head; attention scores are multiplied
            by ``dim_head ** -0.5``.
    """

    def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64):
        super().__init__()
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.scale = dim_head**-0.5
        # Not read anywhere inside this class.
        self._slice_size = None

        # A single bias-free projection yields Q, K and V side by side.
        self.to_qkv = Linear(query_dim, 3 * self.inner_dim, bias=False)
        self.to_out = Linear(self.inner_dim, query_dim)

    def forward(self, hidden_states, mask=None):
        """Attend ``hidden_states`` to itself and project the result.

        Args:
            hidden_states: Input that must not report ``is_dynamic()``; it
                is split along axis 2 after the fused projection.
            mask: Accepted for interface compatibility; this method does not
                use it.

        Returns:
            The output of ``to_out`` applied to the merged attention heads.
        """
        assert not hidden_states.is_dynamic()

        fused = self.to_qkv(hidden_states)

        # Split into Q, K, V chunks of inner_dim features each, then give
        # every chunk its own heads axis.
        q_part, k_part, v_part = split(fused, self.inner_dim, dim=2)
        q_heads = _transpose_for_scores(q_part, self.heads)
        k_heads = _transpose_for_scores(k_part, self.heads)
        v_heads = _transpose_for_scores(v_part, self.heads)

        attended = _attention(q_heads, k_heads, v_heads, self.scale)

        # Fold the heads axis back into the feature axis.
        batch_size, seq_len, num_heads, head_dim = attended.size()
        merged = attended.view([batch_size, seq_len, num_heads * head_dim])
        return self.to_out(merged)


class CrossAttention(Module):
    """Multi-head attention whose keys and values come from ``context``.

    When no context is supplied to ``forward`` the layer attends the input
    to itself.

    Args:
        query_dim: Feature size of the query input and of the output.
        context_dim: Feature size of the context input; falls back to
            ``query_dim`` when ``None``.
        heads: Number of attention heads.
        dim_head: Feature size of each head; attention scores are multiplied
            by ``dim_head ** -0.5``.
    """

    def __init__(self,
                 query_dim: int,
                 context_dim: Optional[int] = None,
                 heads: int = 8,
                 dim_head: int = 64):
        super().__init__()
        if context_dim is None:
            context_dim = query_dim

        self.heads = heads
        self.inner_dim = dim_head * heads
        self.scale = dim_head**-0.5
        # Not read anywhere inside this class.
        self._slice_size = None

        # Queries get their own projection; keys and values share a fused one.
        self.to_q = Linear(query_dim, self.inner_dim, bias=False)
        self.to_kv = Linear(context_dim, 2 * self.inner_dim, bias=False)
        self.to_out = Linear(self.inner_dim, query_dim)

    def forward(self, hidden_states, context=None, mask=None):
        """Attend ``hidden_states`` to ``context`` and project the result.

        Args:
            hidden_states: Query input; must not report ``is_dynamic()``.
            context: Source of keys and values; ``hidden_states`` is used
                when it is ``None``. Must not report ``is_dynamic()``.
            mask: Accepted for interface compatibility; this method does not
                use it.

        Returns:
            The output of ``to_out`` applied to the merged attention heads.
        """
        assert not hidden_states.is_dynamic()
        q_proj = self.to_q(hidden_states)

        if context is None:
            context = hidden_states
        assert not context.is_dynamic()
        kv_proj = self.to_kv(context)

        q_heads = _transpose_for_scores(q_proj, self.heads)
        # The fused projection holds K then V, inner_dim features each.
        k_part, v_part = split(kv_proj, self.inner_dim, dim=2)
        k_heads = _transpose_for_scores(k_part, self.heads)
        v_heads = _transpose_for_scores(v_part, self.heads)

        attended = _attention(q_heads, k_heads, v_heads, self.scale)

        # Fold the heads axis back into the feature axis.
        batch_size, seq_len, num_heads, head_dim = attended.size()
        merged = attended.view([batch_size, seq_len, num_heads * head_dim])
        return self.to_out(merged)


class FeedForward(Module):
    """Two-layer feed-forward block with a ``geglu`` activation in between.

    Args:
        dim: Feature size of the input.
        dim_out: Feature size of the output; falls back to ``dim`` when
            ``None``.
        mult: Expansion factor; the inner width is ``int(dim * mult)``.
    """

    def __init__(self, dim: int, dim_out: Optional[int] = None, mult: int = 4):
        super().__init__()
        hidden_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        # proj_in emits twice the inner width while proj_out consumes the
        # inner width, so geglu is expected to halve the feature axis.
        self.proj_in = Linear(dim, hidden_dim * 2)
        self.proj_out = Linear(hidden_dim, dim_out)

    def forward(self, hidden_states):
        """Return ``proj_out(geglu(proj_in(hidden_states)))``."""
        expanded = self.proj_in(hidden_states)
        gated = geglu(expanded)
        return self.proj_out(gated)


class BasicTransformerBlock(Module):
    """Pre-norm transformer block: self-attn, cross-attn, feed-forward.

    Each of the three sub-layers is applied to a layer-normalised copy of
    the running hidden states and its output is added back to them.

    Args:
        dim: Feature size of the hidden states.
        n_heads: Number of attention heads for both attention layers.
        d_head: Feature size of each attention head.
        context_dim: Feature size of the cross-attention context; forwarded
            unchanged to ``CrossAttention``.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        # Sub-modules are created in the same order as before so that the
        # registration order of the block is unchanged.
        self.attn1 = SelfAttention(query_dim=dim,
                                   heads=n_heads,
                                   dim_head=d_head)
        self.ff = FeedForward(dim)
        # Behaves as self-attention when forward() receives no context.
        self.attn2 = CrossAttention(query_dim=dim,
                                    context_dim=context_dim,
                                    heads=n_heads,
                                    dim_head=d_head)
        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)
        self.norm3 = LayerNorm(dim)

    def forward(self, hidden_states, context=None):
        """Run the three residual sub-layers and return the hidden states.

        Args:
            hidden_states: Input passed through each sub-layer in turn.
            context: Optional context forwarded to the second attention
                layer.
        """
        # 1) self-attention
        self_attn_out = self.attn1(self.norm1(hidden_states))
        hidden_states = self_attn_out + hidden_states

        # 2) attention against `context`
        cross_attn_out = self.attn2(self.norm2(hidden_states), context=context)
        hidden_states = cross_attn_out + hidden_states

        # 3) feed-forward
        ff_out = self.ff(self.norm3(hidden_states))
        return ff_out + hidden_states


class Transformer2DModel(Module):
    """Stack of transformer blocks applied to a 4-D feature map.

    The input is group-normalised, projected to
    ``num_attention_heads * attention_head_dim`` channels with a 1x1
    convolution, flattened into a sequence of spatial positions, run through
    the transformer blocks, reshaped back, projected to ``in_channels`` with
    another 1x1 convolution and added to the original input.

    Args:
        num_attention_heads: Number of heads in every attention layer.
        attention_head_dim: Feature size of each attention head.
        in_channels: Channel count of the input and output.
            NOTE(review): the default ``None`` is passed straight to
            ``GroupNorm`` and ``Conv2d``; presumably callers always supply a
            value - confirm.
        num_layers: Number of ``BasicTransformerBlock`` instances.
        norm_num_groups: Number of groups for the input ``GroupNorm``.
        cross_attention_dim: Context feature size handed to each block.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = GroupNorm(num_groups=norm_num_groups,
                              num_channels=in_channels,
                              eps=1e-6,
                              affine=True)

        # 1x1 convolution: channel projection only, no spatial mixing.
        self.proj_in = Conv2d(in_channels,
                              inner_dim,
                              kernel_size=(1, 1),
                              stride=(1, 1),
                              padding=(0, 0))

        blocks = [
            BasicTransformerBlock(inner_dim,
                                  num_attention_heads,
                                  attention_head_dim,
                                  context_dim=cross_attention_dim)
            for _ in range(num_layers)
        ]
        self.transformer_blocks = ModuleList(blocks)

        # 1x1 convolution back to the input channel count.
        self.proj_out = Conv2d(inner_dim,
                               in_channels,
                               kernel_size=(1, 1),
                               stride=(1, 1),
                               padding=(0, 0))

    def forward(self, hidden_states, context=None):
        """Apply the transformer stack and add the input as a residual.

        Args:
            hidden_states: Input that unpacks into four sizes
                ``(batch, channels, height, width)``; it must not report
                ``is_dynamic()``.
            context: Optional context forwarded to every block.

        Returns:
            ``proj_out`` of the transformed features plus the original
            input.
        """
        assert not hidden_states.is_dynamic()
        batch, _, height, width = hidden_states.size()
        residual = hidden_states

        projected = self.proj_in(self.norm(hidden_states))
        inner_dim = projected.size()[1]

        # Channels-last, then flatten the spatial grid:
        # (B, C, H, W) -> (B, H, W, C) -> (B, H*W, C).
        channels_last = projected.permute([0, 2, 3, 1])
        sequence = channels_last.view([batch, height * width, inner_dim])

        for block in self.transformer_blocks:
            sequence = block(sequence, context=context)

        # Undo the flattening: (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W).
        grid = sequence.view([batch, height, width, inner_dim])
        feature_map = grid.permute([0, 3, 1, 2])

        return self.proj_out(feature_map) + residual