Yash Nagraj committed on
Commit
c2857b5
·
1 Parent(s): 95b2cf2

Add MidBlocks and change cross_attn in down blocks

Browse files
Files changed (1) hide show
  1. models/blocks.py +72 -5
models/blocks.py CHANGED
@@ -86,6 +86,11 @@ class DownBlock(nn.Module):
86
  out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
87
  )
88
 
 
 
 
 
 
89
  self.residual_input_conv = nn.ModuleList(
90
  [
91
  nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
@@ -121,15 +126,77 @@ class DownBlock(nn.Module):
121
 
122
  # Cross Attention
123
  if self.cross_attn:
124
- assert context not None, "Context must be given for cross_attn"
125
  batch_size, channels, h, w = out.shape
126
- in_attn = out.reshape(batch_size, channels, h*w)
127
- in_attn = self.attention_norms[i](in_attn)
128
  in_attn = in_attn.transpose(1, 2)
129
- out_attn, _ = self.attention[i](in_attn, in_attn, in_attn)
130
- out_attn = out.transpose(1, 2).reshape(
 
 
 
131
  batch_size, channels, h, w)
132
  out = out + out_attn
133
 
134
  out = self.resnet_down_conv(out)
135
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
87
  )
88
 
89
+ self.context_proj = nn.ModuleList(
90
+ [nn.Linear(context_dim, out_channels)
91
+ for _ in range(num_layers)]
92
+ )
93
+
94
  self.residual_input_conv = nn.ModuleList(
95
  [
96
  nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
 
126
 
127
  # Cross Attention
128
  if self.cross_attn:
129
+ assert context is not None, "Context must be given for cross_attn"
130
  batch_size, channels, h, w = out.shape
131
+ in_attn = out.reshape(batch_size, channels, h * w)
132
+ in_attn = self.cross_attention_norms[i](in_attn)
133
  in_attn = in_attn.transpose(1, 2)
134
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
135
+ context_proj = self.context_proj[i](context)
136
+ out_attn, _ = self.cross_attentions[i](
137
+ in_attn, context_proj, context_proj)
138
+ out_attn = out_attn.transpose(1, 2).reshape(
139
  batch_size, channels, h, w)
140
  out = out + out_attn
141
 
142
  out = self.resnet_down_conv(out)
143
  return out
144
+
145
+
146
+ class MidBlock(nn.Module):
147
+ """
148
+ Mid Block that works with same dimensions, flows like this:
149
+ 1) Resnet block with time embedding
150
+ 2) Self Attention block
151
+ 3) Resnet block with time embedding
152
+ """
153
+
154
+ def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_dim, cross_attn=None, context_dim=None):
155
+ self.in_channels = in_channels
156
+ self.out_channels = out_channels
157
+ self.t_emb_dim = t_emb_dim
158
+ self.cross_attn = cross_attn
159
+ self.context_dim = context_dim
160
+ self.resnet_conv_one = nn.ModuleList([
161
+ nn.Sequential(
162
+ nn.GroupNorm(norm_dim, in_channels if i ==
163
+ 0 else out_channels),
164
+ nn.SiLU(),
165
+ nn.Conv2d(in_channels if i == 0 else out_channels,
166
+ out_channels, 3, 1, 1)
167
+ )
168
+ for i in range(num_layers + 1)
169
+ ])
170
+
171
+ if self.t_emb_dim is not None:
172
+ self.time_emb_layers = nn.ModuleList([
173
+ nn.Sequential(
174
+ nn.SiLU(),
175
+ nn.Linear(t_emb_dim, out_channels)
176
+ )
177
+ for _ in range(num_layers + 1)
178
+ ])
179
+
180
+ self.resnet_conv_two = nn.ModuleList([
181
+ nn.Sequential(
182
+ nn.GroupNorm(norm_dim, out_channels),
183
+ nn.SiLU(),
184
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1)
185
+ ) for _ in range(num_layers + 1)
186
+ ])
187
+
188
+ self.attention_norms = nn.ModuleList(
189
+ [nn.GroupNorm(norm_dim, out_channels) for _ in range(num_layers)]
190
+ )
191
+
192
+ self.attention_heads = nn.ModuleList(
193
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
194
+ for _ in range(num_layers)]
195
+ )
196
+
197
+ if self.cross_attn:
198
+ assert context_dim is not None, "Context must be given for cross attn"
199
+ self.cross_attn_norms = nn.ModuleList(
200
+ [nn.GroupNorm(norm_dim, out_channels)
201
+ for _ in range(num_layers)]
202
+ )