import torch
from torch import nn
from typing import Optional
from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding
class FinalLayer(nn.Module):
    """
    Final layer of the diffusion model that outputs the final logits.
    """

    def __init__(self, in_ch, out_ch=None, dropout=0.0):
        super().__init__()
        out_ch = in_ch if out_ch is None else out_ch
        self.linear = nn.Linear(in_ch, out_ch)
        self.norm = AdaLayerNormTC(in_ch, 2 * in_ch, dropout)

    def forward(self, x, t, cond=None):
        assert cond is not None
        x = self.norm(x, t, cond)
        x = self.linear(x)
        return x
class AdaLayerNormTC(nn.Module):
    """
    Norm layer modified to incorporate timestep and condition embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings, dropout):
        super().__init__()
        self.emb = CombinedTimestepLabelEmbeddings(
            num_embeddings, embedding_dim, dropout
        )
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(
            embedding_dim, elementwise_affine=False, eps=torch.finfo(torch.float16).eps
        )

    def forward(self, x, timestep, cond):
        emb = self.linear(self.silu(self.emb(timestep, cond, hidden_dtype=None)))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return x
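

# Minimal usage sketch for FinalLayer and its AdaLayerNormTC norm. The sizes,
# batch, and timestep range below are placeholder assumptions for illustration,
# not values taken from the original configuration.
def _final_layer_example():
    layer = FinalLayer(in_ch=16, out_ch=8)   # norm conditions on 2 * in_ch = 32 classes
    x = torch.randn(2, 100, 16)              # (batch, tokens, in_ch)
    t = torch.randint(0, 1000, (2,))         # diffusion timesteps, one per sample
    cond = torch.randint(0, 32, (2,))        # class labels
    # AdaLayerNormTC turns (t, cond) into scale/shift: norm(x) * (1 + scale) + shift
    return layer(x, t, cond)                 # (2, 100, 8)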
class PEmbeder(nn.Module):
    """
    Positional embedding layer.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.embed.weight, mode="fan_in")

    def forward(self, x, idx=None):
        if idx is None:
            idx = torch.arange(x.shape[1], device=x.device, dtype=torch.long)
        return x + self.embed(idx)
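

# Minimal usage sketch for PEmbeder (placeholder sizes, for illustration only):
# each token position 0..seq_len-1 receives its own learned additive embedding.
def _pembeder_example():
    pos_embed = PEmbeder(vocab_size=160, d_model=16)  # needs vocab_size >= seq_len when idx is None
    x = torch.randn(2, 160, 16)                       # (batch, seq_len, d_model)
    return pos_embed(x)                               # same shape, with positional offsets added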
class CombinedTimestepLabelEmbeddings(nn.Module):
    '''Modified from diffusers.models.embeddings.CombinedTimestepLabelEmbeddings'''

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None, label_free=False):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        force_drop_ids = None  # training mode
        if label_free:  # inference mode: force_drop_ids is set to all ones so every label is dropped in class_embedder
            force_drop_ids = torch.ones_like(class_labels, dtype=torch.bool, device=class_labels.device)
        class_labels = self.class_embedder(class_labels, force_drop_ids)  # (N, D)

        conditioning = timesteps_emb + class_labels  # (N, D)
        return conditioning
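

# Minimal usage sketch for CombinedTimestepLabelEmbeddings (placeholder sizes).
# label_free=True forces every label to be dropped, which is how the
# unconditional branch of classifier-free guidance can be obtained at inference.
def _cond_embedding_example():
    emb = CombinedTimestepLabelEmbeddings(num_classes=32, embedding_dim=16, class_dropout_prob=0.1)
    timesteps = torch.randint(0, 1000, (2,))
    labels = torch.randint(0, 32, (2,))
    cond = emb(timesteps, labels, hidden_dtype=torch.float32)                      # (2, 16)
    uncond = emb(timesteps, labels, hidden_dtype=torch.float32, label_free=True)   # labels dropped
    return cond, uncond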
class MyAdaLayerNormZero(nn.Module):
    """
    Adaptive layer norm zero (adaLN-Zero), borrowed from diffusers.models.attention.AdaLayerNormZero.
    Extended to incorporate scale parameters (gate_2, gate_3) for intermediate attention layers.
    """

    def __init__(self, embedding_dim, num_embeddings, class_dropout_prob):
        super().__init__()
        self.emb = CombinedTimestepLabelEmbeddings(
            num_embeddings, embedding_dim, class_dropout_prob
        )
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 8 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None, label_free=False):
        emb_t_cls = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype, label_free=label_free)
        emb = self.linear(self.silu(emb_t_cls))
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            gate_2,
            gate_3,
        ) = emb.chunk(8, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3
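

# Minimal usage sketch for MyAdaLayerNormZero (placeholder sizes). The eight
# chunks drive shift/scale/gating for the attention and MLP sub-layers; gate_2
# and gate_3 are the extra gates for the intermediate attention layers.
def _adaln_zero_example():
    norm = MyAdaLayerNormZero(embedding_dim=16, num_embeddings=32, class_dropout_prob=0.1)
    x = torch.randn(2, 100, 16)
    timesteps = torch.randint(0, 1000, (2,))
    labels = torch.randint(0, 32, (2,))
    x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3 = norm(
        x, timesteps, labels, hidden_dtype=torch.float32
    )
    # A transformer block would typically continue with x + gate_msa[:, None] * attn(x_mod)
    return x_mod.shape, gate_msa.shape   # (2, 100, 16), (2, 16)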
class VisAttnProcessor:
    r"""
    This code is adapted from diffusers.models.attention_processor.AttnProcessor.
    Used for visualizing the attention maps when testing, NOT for training.
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        # Removed
        # if len(args) > 0 or kwargs.get("scale", None) is not None:
        #     deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
        #     deprecate("scale", "1.0.0", deprecation_message)
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)  # (40, 160, 16)
        key = attn.head_to_batch_dim(key)  # (40, 256, 16)
        value = attn.head_to_batch_dim(value)  # (40, 256, 16)

        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                # Convert a boolean mask to an additive float mask (-inf where masked out).
                attn_mask = torch.zeros_like(attention_mask, dtype=query.dtype, device=query.device)
                attn_mask = attn_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            else:
                attn_mask = attention_mask
                assert attn_mask.dtype == query.dtype, (
                    f"query and attention_mask must have the same dtype, "
                    f"but got {query.dtype} and {attention_mask.dtype}."
                )
        else:
            attn_mask = None

        attention_probs = attn.get_attention_scores(query, key, attn_mask)  # (40, 160, 256)
        hidden_states = torch.bmm(attention_probs, value)  # (40, 160, 16)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        # Keep the per-head attention maps for visualization.
        attention_probs = attention_probs.reshape(batch_size, attn.heads, query.shape[1], sequence_length)

        return hidden_states, attention_probs
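

# Minimal usage sketch (assumed, not from the original pipeline): attach the
# processor to a diffusers Attention module so the call returns both the output
# and the per-head attention maps. Sizes are placeholders; the exact Attention
# constructor arguments may differ across diffusers versions.
def _vis_attn_example():
    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=16, cross_attention_dim=16, heads=4, dim_head=8)
    attn.set_processor(VisAttnProcessor())
    hidden_states = torch.randn(2, 160, 16)           # queries
    encoder_hidden_states = torch.randn(2, 256, 16)   # keys/values (condition tokens)
    out, attn_maps = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
    return out.shape, attn_maps.shape                 # (2, 160, 16), (2, 4, 160, 256)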