# NOTE: the following is non-code residue from the hosting page, kept as a comment
# so the module remains valid Python:
# Aatricks — "Deploy ZeroGPU Gradio Space snapshot" (commit b701455)
"""Transformer components for diffusion models."""
from einops import rearrange
import torch
import torch.nn as nn
from src.Utilities import util
from src.Attention import Attention
from src.Device import Device
from src.cond import Activation, cast
from src.sample import sampling_util
# Default layer factory used throughout this module (Linear/Conv2d/LayerNorm/GroupNorm).
# Presumably these skip weight initialization so checkpoint loading supplies the real
# weights — TODO confirm against src.cond.cast.
ops = cast.disable_weight_init
class FeedForward(nn.Module):
    """Position-wise feed-forward network: up-projection, non-linearity, dropout, down-projection.

    With ``glu=True`` the input projection is a gated GELU (GEGLU); otherwise it
    is a plain Linear followed by GELU.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=ops):
        super().__init__()
        hidden = int(dim * mult)
        out_features = dim_out or dim
        if glu:
            # GEGLU performs its own up-projection internally.
            project_in = Activation.GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(
                operations.Linear(dim, hidden, dtype=dtype, device=device),
                nn.GELU(),
            )
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(hidden, out_features, dtype=dtype, device=device),
        )

    def forward(self, x):
        """Apply the feed-forward stack to ``x``."""
        return self.net(x)
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: attention, cross-attention, feed-forward, each residual.

    ``attn1`` is self-attention unless ``disable_self_attn`` is True, in which
    case it is built as cross-attention over ``context``.  ``attn2``
    cross-attends to ``context`` unless ``switch_temporal_ca_to_sa`` turns it
    into self-attention.  The feed-forward residual is applied only when
    ``inner_dim == dim`` (shapes match).
    """

    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True,
                 checkpoint=True, ff_in=False, inner_dim=None, disable_self_attn=False,
                 disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False,
                 dtype=None, device=None, operations=ops):
        super().__init__()
        self.ff_in = ff_in or inner_dim is not None
        inner_dim = inner_dim or dim
        # The ff residual is valid only when the block preserves feature width.
        self.is_res = inner_dim == dim
        self.disable_self_attn = disable_self_attn
        self.checkpoint = checkpoint
        self.n_heads, self.d_head = n_heads, d_head
        # attn1: self-attention, or cross-attention when self-attn is disabled.
        self.attn1 = Attention.CrossAttention(
            query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
            context_dim=context_dim if disable_self_attn else None,
            dtype=dtype, device=device, operations=operations)
        # attn2: cross-attention (or self-attention for the temporal-SA switch).
        self.attn2 = Attention.CrossAttention(
            query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
            context_dim=None if switch_temporal_ca_to_sa else context_dim,
            dtype=dtype, device=device, operations=operations)
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff,
                              dtype=dtype, device=device, operations=operations)
        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

    def forward(self, x, context=None, transformer_options={}):
        """Run the block, optionally under activation checkpointing."""
        return sampling_util.checkpoint(self._forward, (x, context, transformer_options),
                                        self.parameters(), self.checkpoint)

    def _forward(self, x, context=None, transformer_options={}):
        n = self.norm1(x)
        # BUGFIX: when attn1 was constructed as cross-attention
        # (disable_self_attn=True, context_dim set in __init__), it must receive
        # the conditioning context; the old code always passed context=None,
        # mismatching attn1's configured context_dim.
        n = self.attn1(n, context=context if self.disable_self_attn else None, value=None)
        x = x + n
        if self.attn2 is not None:
            n = self.norm2(x)
            n = self.attn2(n, context=context, value=None)
            x = x + n
        x_skip = x if self.is_res else None
        x = self.ff(self.norm3(x))
        return x + x_skip if x_skip is not None else x
class SpatialTransformer(nn.Module):
    """Transformer over the spatial positions of an image feature map.

    Projects ``in_channels -> inner_dim``, flattens (h, w) into a token
    sequence, runs ``depth`` BasicTransformerBlocks, projects back, and adds
    the original input (residual).  ``use_linear`` selects Linear projections
    applied in token space instead of 1x1 Conv2d applied in image space.
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None,
                 disable_self_attn=False, use_linear=False, use_checkpoint=True,
                 dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = n_heads * d_head
        # Broadcast a single context_dim to every block.
        if context_dim and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.norm = operations.GroupNorm(32, in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
        if use_linear:
            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
            # BUGFIX: proj_out maps inner_dim back to in_channels. The old code
            # declared Linear(in_channels, inner_dim) — transposed arguments that
            # only worked by accident when the two were equal. The weight tensor
            # shape is identical in that case, so checkpoints still load.
            self.proj_out = operations.Linear(inner_dim, in_channels, dtype=dtype, device=device)
        else:
            self.proj_in = operations.Conv2d(in_channels, inner_dim, 1, dtype=dtype, device=device)
            self.proj_out = operations.Conv2d(inner_dim, in_channels, 1, dtype=dtype, device=device)
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout,
                                  context_dim=context_dim[d] if context_dim else None,
                                  disable_self_attn=disable_self_attn, checkpoint=use_checkpoint,
                                  dtype=dtype, device=device, operations=operations)
            for d in range(depth)])
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options=None):
        """Apply the transformer to ``x`` of shape (b, c, h, w); returns the same shape."""
        # BUGFIX: avoid a shared mutable default dict — this method writes
        # "block_index" into transformer_options, which previously leaked state
        # across calls that omitted the argument.
        if transformer_options is None:
            transformer_options = {}
        # Broadcast a single context to every block.
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        # (b, c, h, w) -> (b, h*w, c): one token per spatial position.
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
def count_blocks(state_dict_keys, prefix_string):
    """Count consecutive numbered blocks present in a set of state-dict keys.

    Probes ``prefix_string.format(i)`` for i = 0, 1, 2, ... and returns the
    first index for which no key starts with that prefix.
    """
    index = 0
    while True:
        candidate = prefix_string.format(index)
        if not any(key.startswith(candidate) for key in state_dict_keys):
            return index
        index += 1
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
    """Infer transformer configuration for the block at ``prefix`` from a checkpoint.

    Returns ``(depth, context_dim, use_linear, time_stack)``, or ``None`` when
    no transformer-block weights exist under the prefix.
    """
    tf_prefix = prefix + "1.transformer_blocks."
    if not any(key.startswith(tf_prefix) for key in state_dict_keys):
        return None
    depth = count_blocks(state_dict_keys, tf_prefix + "{}")
    # Width of the cross-attention key projection == conditioning (context) dim.
    context_dim = state_dict[tf_prefix + "0.attn2.to_k.weight"].shape[1]
    # A 2-D proj_in weight means Linear projections; 4-D means 1x1 Conv2d.
    use_linear = len(state_dict[prefix + "1.proj_in.weight"].shape) == 2
    # Temporal (video) stacks may appear under either of two historical names.
    time_stack = (
        prefix + "1.time_stack.0.attn1.to_q.weight" in state_dict
        or prefix + "1.time_mix_blocks.0.attn1.to_q.weight" in state_dict
    )
    return depth, context_dim, use_linear, time_stack