"""Transformer components for diffusion models.""" from einops import rearrange import torch import torch.nn as nn from src.Utilities import util from src.Attention import Attention from src.Device import Device from src.cond import Activation, cast from src.sample import sampling_util ops = cast.disable_weight_init class FeedForward(nn.Module): """FeedForward network.""" def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=ops): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out or dim project_in = Activation.GEGLU(dim, inner_dim) if glu else nn.Sequential( operations.Linear(dim, inner_dim, dtype=dtype, device=device), nn.GELU()) self.net = nn.Sequential(project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)) def forward(self, x): return self.net(x) class BasicTransformerBlock(nn.Module): """Basic Transformer block with self/cross attention.""" def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops): super().__init__() self.ff_in = ff_in or inner_dim is not None inner_dim = inner_dim or dim self.is_res = inner_dim == dim self.disable_self_attn = disable_self_attn self.checkpoint = checkpoint self.n_heads, self.d_head = n_heads, d_head self.attn1 = Attention.CrossAttention( query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim if disable_self_attn else None, dtype=dtype, device=device, operations=operations) self.attn2 = Attention.CrossAttention( query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=None if switch_temporal_ca_to_sa else context_dim, dtype=dtype, device=device, operations=operations) self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) def forward(self, x, context=None, transformer_options={}): return sampling_util.checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) def _forward(self, x, context=None, transformer_options={}): n = self.norm1(x) n = self.attn1(n, context=None, value=None) x = x + n if self.attn2: n = self.norm2(x) n = self.attn2(n, context=context, value=None) x = x + n x_skip = x if self.is_res else None x = self.ff(self.norm3(x)) return x + x_skip if x_skip is not None else x class SpatialTransformer(nn.Module): """Spatial Transformer module.""" def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None, disable_self_attn=False, use_linear=False, use_checkpoint=True, dtype=None, device=None, operations=ops): super().__init__() inner_dim = n_heads * d_head context_dim = [context_dim] * depth if context_dim and not isinstance(context_dim, list) else context_dim self.norm = operations.GroupNorm(32, in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) if use_linear: self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device) else: self.proj_in = operations.Conv2d(in_channels, inner_dim, 1, 
dtype=dtype, device=device) self.proj_out = operations.Conv2d(inner_dim, in_channels, 1, dtype=dtype, device=device) self.transformer_blocks = nn.ModuleList([ BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d] if context_dim else None, disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations) for d in range(depth)]) self.use_linear = use_linear def forward(self, x, context=None, transformer_options={}): context = [context] * len(self.transformer_blocks) if not isinstance(context, list) else context b, c, h, w = x.shape x_in = x x = self.norm(x) if not self.use_linear: x = self.proj_in(x) x = rearrange(x, "b c h w -> b (h w) c").contiguous() if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): transformer_options["block_index"] = i x = block(x, context=context[i], transformer_options=transformer_options) if self.use_linear: x = self.proj_out(x) x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() if not self.use_linear: x = self.proj_out(x) return x + x_in def count_blocks(state_dict_keys, prefix_string): """Count blocks matching prefix.""" count = 0 while any(k.startswith(prefix_string.format(count)) for k in state_dict_keys): count += 1 return count def calculate_transformer_depth(prefix, state_dict_keys, state_dict): """Calculate transformer depth from state dict.""" transformer_prefix = prefix + "1.transformer_blocks." transformer_keys = [k for k in state_dict_keys if k.startswith(transformer_prefix)] if not transformer_keys: return None depth = count_blocks(state_dict_keys, transformer_prefix + "{}") context_dim = state_dict[f"{transformer_prefix}0.attn2.to_k.weight"].shape[1] use_linear = len(state_dict[f"{prefix}1.proj_in.weight"].shape) == 2 time_stack = (f"{prefix}1.time_stack.0.attn1.to_q.weight" in state_dict or f"{prefix}1.time_mix_blocks.0.attn1.to_q.weight" in state_dict) return depth, context_dim, use_linear, time_stack
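

# Minimal usage sketch (illustrative only, not part of the module API): it runs a
# SpatialTransformer over a dummy feature map with a dummy cross-attention context,
# then reads the depth/context layout back out of the resulting state dict. The
# concrete sizes (320 channels, 8 heads of 40 dims, a 77x768 context) and the
# "input_blocks.1.1." key prefix are assumptions chosen for the example to mimic a
# UNet-style state dict, not values required by this module.
if __name__ == "__main__":
    model = SpatialTransformer(
        in_channels=320, n_heads=8, d_head=40, depth=1,
        context_dim=768, use_linear=True, use_checkpoint=False)
    x = torch.randn(1, 320, 16, 16)     # (batch, channels, height, width)
    context = torch.randn(1, 77, 768)   # (batch, tokens, context_dim)
    out = model(x, context=context)
    print(out.shape)                    # expected: torch.Size([1, 320, 16, 16])

    # calculate_transformer_depth expects UNet-style keys such as
    # "<prefix>1.transformer_blocks.0.attn2.to_k.weight"; fake that layout by
    # re-prefixing this module's own state dict keys.
    sd = {"input_blocks.1.1." + k: v for k, v in model.state_dict().items()}
    print(calculate_transformer_depth("input_blocks.1.", list(sd.keys()), sd))
    # expected: (1, 768, True, False)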