"""Transformer components for diffusion models."""
from einops import rearrange
import torch
import torch.nn as nn
from src.Utilities import util
from src.Attention import Attention
from src.Device import Device
from src.cond import Activation, cast
from src.sample import sampling_util
ops = cast.disable_weight_init
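
# `ops` is assumed to supply drop-in Linear/Conv2d/LayerNorm/GroupNorm variants with weight
# initialization disabled, the usual pattern when the weights are loaded from a checkpoint anyway.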
class FeedForward(nn.Module):
    """FeedForward network."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out or dim
        project_in = Activation.GEGLU(dim, inner_dim) if glu else nn.Sequential(
            operations.Linear(dim, inner_dim, dtype=dtype, device=device), nn.GELU())
        self.net = nn.Sequential(project_in, nn.Dropout(dropout),
                                 operations.Linear(inner_dim, dim_out, dtype=dtype, device=device))

    def forward(self, x):
        return self.net(x)
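
# Usage sketch (illustrative only, shapes hypothetical): the feed-forward keeps the token
# layout and only changes the channel dimension, dim -> dim * mult -> dim_out.
#   ff = FeedForward(dim=320, mult=4, glu=True)   # inner projection of 1280 channels
#   y = ff(torch.randn(2, 4096, 320))             # y.shape == (2, 4096, 320)
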
class BasicTransformerBlock(nn.Module):
    """Basic Transformer block with self/cross attention."""

    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True,
                 checkpoint=True, ff_in=False, inner_dim=None, disable_self_attn=False,
                 disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False,
                 dtype=None, device=None, operations=ops):
        super().__init__()
        self.ff_in = ff_in or inner_dim is not None
        inner_dim = inner_dim or dim
        self.is_res = inner_dim == dim
        self.disable_self_attn = disable_self_attn
        self.checkpoint = checkpoint
        self.n_heads, self.d_head = n_heads, d_head
        self.attn1 = Attention.CrossAttention(
            query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
            context_dim=context_dim if disable_self_attn else None,
            dtype=dtype, device=device, operations=operations)
        self.attn2 = Attention.CrossAttention(
            query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
            context_dim=None if switch_temporal_ca_to_sa else context_dim,
            dtype=dtype, device=device, operations=operations)
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff,
                              dtype=dtype, device=device, operations=operations)
        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

    def forward(self, x, context=None, transformer_options={}):
        return sampling_util.checkpoint(self._forward, (x, context, transformer_options),
                                        self.parameters(), self.checkpoint)

    def _forward(self, x, context=None, transformer_options={}):
        n = self.norm1(x)
        # When self-attention is disabled, attn1 was constructed with a context_dim and must
        # receive the conditioning here; otherwise it attends over x itself.
        n = self.attn1(n, context=context if self.disable_self_attn else None, value=None)
        x = x + n
        if self.attn2 is not None:
            n = self.norm2(x)
            n = self.attn2(n, context=context, value=None)
            x = x + n
        x_skip = x if self.is_res else None
        x = self.ff(self.norm3(x))
        return x + x_skip if x_skip is not None else x
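
# Data-flow sketch for BasicTransformerBlock (comments only, shapes hypothetical):
#   x: (batch, tokens, inner_dim)
#   x = x + attn1(norm1(x))                   # self-attention (cross-attention if disable_self_attn)
#   x = x + attn2(norm2(x), context=context)  # cross-attention against the conditioning, e.g. text embeddings
#   x = ff(norm3(x)) + x_skip                 # feed-forward, residual only when inner_dim == dim
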
class SpatialTransformer(nn.Module):
    """Spatial Transformer module."""

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None,
                 disable_self_attn=False, use_linear=False, use_checkpoint=True,
                 dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = n_heads * d_head
        context_dim = [context_dim] * depth if context_dim and not isinstance(context_dim, list) else context_dim
        self.norm = operations.GroupNorm(32, in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
        if use_linear:
            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
            # Project back from inner_dim to in_channels (the two match in practice, since
            # inner_dim = n_heads * d_head is chosen equal to in_channels).
            self.proj_out = operations.Linear(inner_dim, in_channels, dtype=dtype, device=device)
        else:
            self.proj_in = operations.Conv2d(in_channels, inner_dim, 1, dtype=dtype, device=device)
            self.proj_out = operations.Conv2d(inner_dim, in_channels, 1, dtype=dtype, device=device)
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout,
                                  context_dim=context_dim[d] if context_dim else None,
                                  disable_self_attn=disable_self_attn, checkpoint=use_checkpoint,
                                  dtype=dtype, device=device, operations=operations)
            for d in range(depth)])
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options={}):
        context = [context] * len(self.transformer_blocks) if not isinstance(context, list) else context
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
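
# Shape-flow sketch for SpatialTransformer.forward (sizes hypothetical):
#   (b, c, h, w) --GroupNorm/proj_in--> (b, h*w, inner_dim) token sequence
#   --depth x BasicTransformerBlock--> (b, h*w, inner_dim)
#   --proj_out--> (b, c, h, w), added back to the input as a residual.
# With use_linear=True the projections run on the flattened tokens; otherwise 1x1
# convolutions are applied before flattening and after unflattening.
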
def count_blocks(state_dict_keys, prefix_string):
    """Count blocks matching prefix."""
    count = 0
    while any(k.startswith(prefix_string.format(count)) for k in state_dict_keys):
        count += 1
    return count
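
# Example (hypothetical keys): with state-dict keys such as
#   "input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight"
#   "input_blocks.1.1.transformer_blocks.1.attn1.to_q.weight"
# count_blocks(keys, "input_blocks.1.1.transformer_blocks.{}") probes indices 0, 1, 2, ...
# and returns 2, the number of consecutive numbered blocks present.
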
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
    """Calculate transformer depth from state dict."""
    transformer_prefix = prefix + "1.transformer_blocks."
    transformer_keys = [k for k in state_dict_keys if k.startswith(transformer_prefix)]
    if not transformer_keys:
        return None
    depth = count_blocks(state_dict_keys, transformer_prefix + "{}")
    context_dim = state_dict[f"{transformer_prefix}0.attn2.to_k.weight"].shape[1]
    use_linear = len(state_dict[f"{prefix}1.proj_in.weight"].shape) == 2
    time_stack = (f"{prefix}1.time_stack.0.attn1.to_q.weight" in state_dict or
                  f"{prefix}1.time_mix_blocks.0.attn1.to_q.weight" in state_dict)
    return depth, context_dim, use_linear, time_stack
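
# Usage sketch (hypothetical checkpoint): for a UNet level prefix such as "input_blocks.1.",
# the function returns (depth, context_dim, use_linear, time_stack), e.g. (1, 768, False, False)
# for an SD1.x-style block, or None when the level contains no transformer_blocks.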