"""Transformer components for diffusion models."""
from einops import rearrange
import torch
import torch.nn as nn
from src.Utilities import util
from src.Attention import Attention
from src.Device import Device
from src.cond import Activation, cast
from src.sample import sampling_util

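# `cast.disable_weight_init` is assumed (per the import above) to expose
# Linear/Conv2d/LayerNorm/GroupNorm variants that skip weight initialization,
# since the weights are expected to be overwritten by a loaded checkpoint.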
ops = cast.disable_weight_init


class FeedForward(nn.Module):
    """FeedForward network."""
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out or dim
        project_in = Activation.GEGLU(dim, inner_dim) if glu else nn.Sequential(
            operations.Linear(dim, inner_dim, dtype=dtype, device=device), nn.GELU())
        self.net = nn.Sequential(project_in, nn.Dropout(dropout),
                                 operations.Linear(inner_dim, dim_out, dtype=dtype, device=device))

    def forward(self, x):
        return self.net(x)

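# Illustrative usage sketch for FeedForward, kept as comments so nothing runs at
# import time (the tensor sizes below are assumptions, not taken from this file):
#
#     ff = FeedForward(dim=320, mult=4)        # Linear -> GELU -> Dropout -> Linear
#     tokens = torch.randn(2, 64, 320)         # (batch, sequence, dim)
#     out = ff(tokens)                         # -> (2, 64, 320); dim_out defaults to dim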

class BasicTransformerBlock(nn.Module):
    """Basic Transformer block with self/cross attention."""
    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True,
                 checkpoint=True, ff_in=False, inner_dim=None, disable_self_attn=False,
                 disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False,
                 dtype=None, device=None, operations=ops):
        super().__init__()
        self.ff_in = ff_in or inner_dim is not None
        inner_dim = inner_dim or dim
        self.is_res = inner_dim == dim
        self.disable_self_attn = disable_self_attn
        self.checkpoint = checkpoint
        self.n_heads, self.d_head = n_heads, d_head

        self.attn1 = Attention.CrossAttention(
            query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
            context_dim=context_dim if disable_self_attn else None,
            dtype=dtype, device=device, operations=operations)
        
        if disable_temporal_crossattention:
            self.attn2 = None
        else:
            self.attn2 = Attention.CrossAttention(
                query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                context_dim=None if switch_temporal_ca_to_sa else context_dim,
                dtype=dtype, device=device, operations=operations)
        
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff,
                              dtype=dtype, device=device, operations=operations)
        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

    def forward(self, x, context=None, transformer_options={}):
        return sampling_util.checkpoint(self._forward, (x, context, transformer_options),
                                        self.parameters(), self.checkpoint)

    def _forward(self, x, context=None, transformer_options={}):
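        # Self-attention over x (or cross-attention over `context` when
        # disable_self_attn is set), then cross-attention over `context`, then
        # feed-forward; the feed-forward residual is only added when is_res.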
        n = self.norm1(x)
        n = self.attn1(n, context=context if self.disable_self_attn else None, value=None)
        x = x + n
        
        if self.attn2 is not None:
            n = self.norm2(x)
            n = self.attn2(n, context=context, value=None)
            x = x + n
        
        x_skip = x if self.is_res else None
        x = self.ff(self.norm3(x))
        return x + x_skip if x_skip is not None else x

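# Illustrative usage sketch for BasicTransformerBlock (hypothetical sizes; the
# context width of 768 is an assumption, e.g. a CLIP text-embedding dimension):
#
#     block = BasicTransformerBlock(dim=320, n_heads=8, d_head=40, context_dim=768)
#     x = torch.randn(2, 64, 320)              # flattened spatial tokens
#     ctx = torch.randn(2, 77, 768)            # conditioning tokens
#     y = block(x, context=ctx)                # -> (2, 64, 320)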

class SpatialTransformer(nn.Module):
    """Spatial Transformer module."""
    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None,
                 disable_self_attn=False, use_linear=False, use_checkpoint=True,
                 dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = n_heads * d_head
        context_dim = [context_dim] * depth if context_dim and not isinstance(context_dim, list) else context_dim
        
        self.norm = operations.GroupNorm(32, in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
        
        if use_linear:
            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
            self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
        else:
            self.proj_in = operations.Conv2d(in_channels, inner_dim, 1, dtype=dtype, device=device)
            self.proj_out = operations.Conv2d(inner_dim, in_channels, 1, dtype=dtype, device=device)
        
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout,
                                  context_dim=context_dim[d] if context_dim else None,
                                  disable_self_attn=disable_self_attn, checkpoint=use_checkpoint,
                                  dtype=dtype, device=device, operations=operations)
            for d in range(depth)])
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options={}):
        context = [context] * len(self.transformer_blocks) if not isinstance(context, list) else context
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
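        # The linear projection is applied after flattening to (b, h*w, c) tokens;
        # the 1x1 conv projection is applied while the tensor is still NCHW.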
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in

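# Illustrative usage sketch for SpatialTransformer (channel counts are assumptions,
# chosen so that n_heads * d_head == in_channels as in typical UNet configs):
#
#     st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40,
#                             depth=1, context_dim=768, use_linear=True)
#     feat = torch.randn(2, 320, 32, 32)       # NCHW feature map
#     ctx = torch.randn(2, 77, 768)
#     out = st(feat, context=ctx)              # -> (2, 320, 32, 32), residual included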

def count_blocks(state_dict_keys, prefix_string):
    """Count blocks matching prefix."""
    count = 0
    while any(k.startswith(prefix_string.format(count)) for k in state_dict_keys):
        count += 1
    return count

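# Example with hypothetical key names: blocks are counted by formatting the prefix
# with 0, 1, 2, ... until no key matches.
#
#     keys = ["input_blocks.0.weight", "input_blocks.1.weight", "middle_block.weight"]
#     count_blocks(keys, "input_blocks.{}.")   # -> 2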

def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
    """Calculate transformer depth from state dict."""
    transformer_prefix = prefix + "1.transformer_blocks."
    transformer_keys = [k for k in state_dict_keys if k.startswith(transformer_prefix)]
    if not transformer_keys:
        return None
    
    depth = count_blocks(state_dict_keys, transformer_prefix + "{}")
    context_dim = state_dict[f"{transformer_prefix}0.attn2.to_k.weight"].shape[1]
    use_linear = len(state_dict[f"{prefix}1.proj_in.weight"].shape) == 2
    time_stack = (f"{prefix}1.time_stack.0.attn1.to_q.weight" in state_dict or
                  f"{prefix}1.time_mix_blocks.0.attn1.to_q.weight" in state_dict)
    return depth, context_dim, use_linear, time_stack
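# Illustrative usage sketch (the prefix "input_blocks.1." and the state dict are
# hypothetical; the real key layout depends on the UNet checkpoint being probed):
#
#     keys = list(state_dict.keys())
#     info = calculate_transformer_depth("input_blocks.1.", keys, state_dict)
#     if info is not None:
#         depth, context_dim, use_linear, time_stack = info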