import torch
from torch import nn
from typing import Optional, Tuple
from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding

class FinalLayer(nn.Module):
    """
    Final layer of the diffusion model: applies timestep/condition-adaptive layer norm, then projects to the output logits.
    """
    def __init__(self, in_ch, out_ch=None, dropout=0.0):
        super().__init__()
        out_ch = in_ch if out_ch is None else out_ch
        self.linear = nn.Linear(in_ch, out_ch)
        self.norm = AdaLayerNormTC(in_ch, 2 * in_ch, dropout)

    def forward(self, x, t, cond=None):
        assert cond is not None
        x = self.norm(x, t, cond)
        x = self.linear(x)
        return x


class AdaLayerNormTC(nn.Module):
    """
    Norm layer modified to incorporate timestep and condition embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings, dropout):
        super().__init__()
        self.emb = CombinedTimestepLabelEmbeddings(
            num_embeddings, embedding_dim, dropout
        )
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(
            embedding_dim, elementwise_affine=False, eps=torch.finfo(torch.float16).eps
        )

    def forward(self, x, timestep, cond):
        emb = self.linear(self.silu(self.emb(timestep, cond, hidden_dtype=None)))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return x


class PEmbeder(nn.Module):
    """
    Learnable positional embedding layer; adds a position embedding to each token of the input sequence.
    """
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.embed.weight, mode="fan_in")

    def forward(self, x, idx=None):
        if idx is None:
            idx = torch.arange(x.shape[1], device=x.device, dtype=torch.long)
        return x + self.embed(idx)
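
# Example usage of PEmbeder (a hedged sketch; the shapes and names below are
# illustrative assumptions, not values taken from this repository):
#   pe = PEmbeder(vocab_size=seq_len, d_model=d_model)
#   x = pe(x)                  # x: (B, seq_len, d_model); adds embeddings for positions 0..seq_len-1
#   x = pe(x, idx=custom_idx)  # or supply explicit position indices of shape (seq_len,)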

class CombinedTimestepLabelEmbeddings(nn.Module):
    """Modified from diffusers.models.embeddings.CombinedTimestepLabelEmbeddings."""
    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None, label_free=False):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
        force_drop_ids = None  # training mode: class_embedder applies its own random label dropout
        if label_free:  # inference mode: force-drop every label so class_embedder uses the null-class embedding
            force_drop_ids = torch.ones_like(class_labels, dtype=torch.bool, device=class_labels.device)
        class_labels = self.class_embedder(class_labels, force_drop_ids)  # (N, D)
        conditioning = timesteps_emb + class_labels  # (N, D)
        return conditioning
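
# Example usage of CombinedTimestepLabelEmbeddings (a hedged sketch; the sizes
# are illustrative assumptions, not values taken from this repository):
#   emb = CombinedTimestepLabelEmbeddings(num_classes=10, embedding_dim=32)
#   cond = emb(torch.randint(0, 1000, (4,)), torch.randint(0, 10, (4,)))  # -> (4, 32)
# Passing label_free=True force-drops every label so the null-class embedding is
# used instead, i.e. the unconditional (label-free) path at inference time.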


class MyAdaLayerNormZero(nn.Module):
    """
    Adaptive layer norm zero (adaLN-Zero), borrowed from diffusers.models.attention.AdaLayerNormZero.
    Extended to incorporate additional gate parameters (gate_2, gate_3) for intermediate attention layers.
    """

    def __init__(self, embedding_dim, num_embeddings, class_dropout_prob):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(
            num_embeddings, embedding_dim, class_dropout_prob
        )
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 8 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None, label_free=False):
        emb_t_cls = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype, label_free=label_free)
        emb = self.linear(self.silu(emb_t_cls))
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            gate_2,
            gate_3,
        ) = emb.chunk(8, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3
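
# Hedged sketch of how a DiT-style transformer block would typically consume the
# outputs of MyAdaLayerNormZero (the block itself, and the names attn1/attn2/attn3,
# norm2, and ff, are assumptions for illustration and are not defined in this file):
#   normed, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3 = adaln(x, t, y)
#   x = x + gate_msa[:, None] * attn1(normed)   # gated self-attention
#   x = x + gate_2[:, None] * attn2(x)          # gated intermediate attention
#   x = x + gate_3[:, None] * attn3(x)          # gated intermediate attention
#   h = norm2(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
#   x = x + gate_mlp[:, None] * ff(h)           # gated feed-forward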


class VisAttnProcessor:
    r"""
    This code is adapted from diffusers.models.attention_processor.AttnProcessor.
    Used for visualizing the attention maps when testing, NOT for training.
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Removed
        # if len(args) > 0 or kwargs.get("scale", None) is not None:
        #     deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
        #     deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query) # (40, 160, 16)
        key = attn.head_to_batch_dim(key) # (40, 256, 16)
        value = attn.head_to_batch_dim(value)  # (40, 256, 16)
        
        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                # Convert a boolean mask into an additive float mask (-inf at masked positions).
                attn_mask = torch.zeros_like(attention_mask, dtype=query.dtype, device=query.device)
                attn_mask = attn_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            else:
                attn_mask = attention_mask
                assert attn_mask.dtype == query.dtype, f"query and attention_mask must have the same dtype, but got {query.dtype} and {attention_mask.dtype}."
        else:
            attn_mask = None
        attention_probs = attn.get_attention_scores(query, key, attn_mask) # (40, 160, 256)
        hidden_states = torch.bmm(attention_probs, value) # (40, 160, 16)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        attention_probs = attention_probs.reshape(batch_size, attn.heads, query.shape[1], sequence_length)

        return hidden_states, attention_probs
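

if __name__ == "__main__":
    # Minimal smoke test (a hedged sketch, not part of the original module): the
    # batch size, sequence length, widths, and class count below are illustrative
    # assumptions chosen only to exercise the layers defined above.
    torch.manual_seed(0)
    batch, seq_len, dim, num_classes = 2, 16, 32, 10

    x = torch.randn(batch, seq_len, dim)
    t = torch.randint(0, 1000, (batch,))
    y = torch.randint(0, num_classes, (batch,))

    # Add learnable positional embeddings over seq_len positions.
    x = PEmbeder(seq_len, dim)(x)

    # adaLN-Zero: modulated features plus the gating/shift/scale tensors.
    adaln = MyAdaLayerNormZero(dim, num_classes, class_dropout_prob=0.1)
    out, gate_msa, *_ = adaln(x, t, y)
    print(out.shape, gate_msa.shape)  # torch.Size([2, 16, 32]) torch.Size([2, 32])

    # Final projection conditioned on timestep and class label.
    final = FinalLayer(dim, out_ch=8)
    print(final(x, t, cond=y).shape)  # torch.Size([2, 16, 8])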