Yash Nagraj committed on
Commit ·
c2857b5
1
Parent(s): 95b2cf2
Add MidBlocks and change cross_attn in down blocks
Browse files- models/blocks.py +72 -5
models/blocks.py
CHANGED
|
@@ -86,6 +86,11 @@ class DownBlock(nn.Module):
|
|
| 86 |
out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
|
| 87 |
)
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
self.residual_input_conv = nn.ModuleList(
|
| 90 |
[
|
| 91 |
nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
|
|
@@ -121,15 +126,77 @@ class DownBlock(nn.Module):
|
|
| 121 |
|
| 122 |
# Cross Attention
|
| 123 |
if self.cross_attn:
|
| 124 |
-
assert context not None, "Context must be given for cross_attn"
|
| 125 |
batch_size, channels, h, w = out.shape
|
| 126 |
-
in_attn = out.reshape(batch_size, channels, h*w)
|
| 127 |
-
in_attn = self.
|
| 128 |
in_attn = in_attn.transpose(1, 2)
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
| 131 |
batch_size, channels, h, w)
|
| 132 |
out = out + out_attn
|
| 133 |
|
| 134 |
out = self.resnet_down_conv(out)
|
| 135 |
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
|
| 87 |
)
|
| 88 |
|
| 89 |
+
self.context_proj = nn.ModuleList(
|
| 90 |
+
[nn.Linear(context_dim, out_channels)
|
| 91 |
+
for _ in range(num_layers)]
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
self.residual_input_conv = nn.ModuleList(
|
| 95 |
[
|
| 96 |
nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
|
|
|
|
| 126 |
|
| 127 |
# Cross Attention
|
| 128 |
if self.cross_attn:
|
| 129 |
+
assert context is not None, "Context must be given for cross_attn"
|
| 130 |
batch_size, channels, h, w = out.shape
|
| 131 |
+
in_attn = out.reshape(batch_size, channels, h * w)
|
| 132 |
+
in_attn = self.cross_attention_norms[i](in_attn)
|
| 133 |
in_attn = in_attn.transpose(1, 2)
|
| 134 |
+
assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
|
| 135 |
+
context_proj = self.context_proj[i](context)
|
| 136 |
+
out_attn, _ = self.cross_attentions[i](
|
| 137 |
+
in_attn, context_proj, context_proj)
|
| 138 |
+
out_attn = out_attn.transpose(1, 2).reshape(
|
| 139 |
batch_size, channels, h, w)
|
| 140 |
out = out + out_attn
|
| 141 |
|
| 142 |
out = self.resnet_down_conv(out)
|
| 143 |
return out
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class MidBlock(nn.Module):
|
| 147 |
+
"""
|
| 148 |
+
Mid Block that works with same dimensions, flows like this:
|
| 149 |
+
1) Resnet block with time embedding
|
| 150 |
+
2) Self Attention block
|
| 151 |
+
3) Resnet block with time embedding
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_dim, cross_attn=None, context_dim=None):
|
| 155 |
+
self.in_channels = in_channels
|
| 156 |
+
self.out_channels = out_channels
|
| 157 |
+
self.t_emb_dim = t_emb_dim
|
| 158 |
+
self.cross_attn = cross_attn
|
| 159 |
+
self.context_dim = context_dim
|
| 160 |
+
self.resnet_conv_one = nn.ModuleList([
|
| 161 |
+
nn.Sequential(
|
| 162 |
+
nn.GroupNorm(norm_dim, in_channels if i ==
|
| 163 |
+
0 else out_channels),
|
| 164 |
+
nn.SiLU(),
|
| 165 |
+
nn.Conv2d(in_channels if i == 0 else out_channels,
|
| 166 |
+
out_channels, 3, 1, 1)
|
| 167 |
+
)
|
| 168 |
+
for i in range(num_layers + 1)
|
| 169 |
+
])
|
| 170 |
+
|
| 171 |
+
if self.t_emb_dim is not None:
|
| 172 |
+
self.time_emb_layers = nn.ModuleList([
|
| 173 |
+
nn.Sequential(
|
| 174 |
+
nn.SiLU(),
|
| 175 |
+
nn.Linear(t_emb_dim, out_channels)
|
| 176 |
+
)
|
| 177 |
+
for _ in range(num_layers + 1)
|
| 178 |
+
])
|
| 179 |
+
|
| 180 |
+
self.resnet_conv_two = nn.ModuleList([
|
| 181 |
+
nn.Sequential(
|
| 182 |
+
nn.GroupNorm(norm_dim, out_channels),
|
| 183 |
+
nn.SiLU(),
|
| 184 |
+
nn.Conv2d(out_channels, out_channels, 3, 1, 1)
|
| 185 |
+
) for _ in range(num_layers + 1)
|
| 186 |
+
])
|
| 187 |
+
|
| 188 |
+
self.attention_norms = nn.ModuleList(
|
| 189 |
+
[nn.GroupNorm(norm_dim, out_channels) for _ in range(num_layers)]
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
self.attention_heads = nn.ModuleList(
|
| 193 |
+
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
|
| 194 |
+
for _ in range(num_layers)]
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
if self.cross_attn:
|
| 198 |
+
assert context_dim is not None, "Context must be given for cross attn"
|
| 199 |
+
self.cross_attn_norms = nn.ModuleList(
|
| 200 |
+
[nn.GroupNorm(norm_dim, out_channels)
|
| 201 |
+
for _ in range(num_layers)]
|
| 202 |
+
)
|