Yash Nagraj committed
Commit 95b2cf2 · 1 Parent(s): aee1300

Add cross attn if needed for conditional latent diffusion

Files changed (1):
  models/blocks.py  +24 -7
models/blocks.py CHANGED
@@ -109,10 +109,27 @@ class DownBlock(nn.Module):
             out = out + self.residual_input_conv[i](resnet_input)
 
             # Self Attention
-            batch_size, channels, h, w = out.shape
-            in_attn = out.reshape(batch_size, channels, h*w)
-            in_attn = self.attention_norms[i](in_attn)
-            in_attn = in_attn.transpose(1, 2)
-            out_attn, _ = self.attention[i](in_attn, in_attn, in_attn)
-            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
-            out = out + out_attn
+            if self.attn:
+                batch_size, channels, h, w = out.shape
+                in_attn = out.reshape(batch_size, channels, h*w)
+                in_attn = self.attention_norms[i](in_attn)
+                in_attn = in_attn.transpose(1, 2)
+                out_attn, _ = self.attention[i](in_attn, in_attn, in_attn)
+                out_attn = out_attn.transpose(1, 2).reshape(
+                    batch_size, channels, h, w)
+                out = out + out_attn
+
+            # Cross Attention
+            if self.cross_attn:
+                assert context is not None, "Context must be given for cross_attn"
+                batch_size, channels, h, w = out.shape
+                in_attn = out.reshape(batch_size, channels, h*w)
+                in_attn = self.attention_norms[i](in_attn)
+                in_attn = in_attn.transpose(1, 2)
+                out_attn, _ = self.attention[i](in_attn, in_attn, in_attn)
+                out_attn = out_attn.transpose(1, 2).reshape(
+                    batch_size, channels, h, w)
+                out = out + out_attn
+
+        out = self.resnet_down_conv(out)
+        return out
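Note that in this commit the cross-attention branch asserts that context is given but does not yet consume it: it calls the same attention_norms/attention modules with in_attn as query, key and value, exactly like the self-attention branch. For reference, below is a minimal, self-contained sketch of a cross-attention step that actually uses the context as keys/values, the usual pattern in conditional latent diffusion. The module names (cross_attn_norm, context_proj, cross_attn) and all dimensions are assumptions for illustration only, not part of this commit or repository.

import torch
import torch.nn as nn

class CrossAttnSketch(nn.Module):
    """Sketch of one cross-attention step: spatial tokens query the context tokens."""
    def __init__(self, channels=64, context_dim=512, num_heads=4):
        super().__init__()
        self.cross_attn_norm = nn.GroupNorm(8, channels)
        # Project the context to the feature-map channel dimension so it can serve as key/value.
        self.context_proj = nn.Linear(context_dim, channels)
        self.cross_attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, out, context):
        assert context is not None, "Context must be given for cross_attn"
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.cross_attn_norm(in_attn)
        in_attn = in_attn.transpose(1, 2)            # (B, H*W, C): spatial tokens as queries
        ctx = self.context_proj(context)             # (B, seq_len, C): context as keys/values
        out_attn, _ = self.cross_attn(in_attn, ctx, ctx)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        return out + out_attn                        # residual connection, as in the diff

# Example usage (also a sketch):
# block = CrossAttnSketch()
# x = torch.randn(2, 64, 16, 16)        # feature map from the DownBlock
# text_ctx = torch.randn(2, 77, 512)    # e.g. text-encoder token embeddings
# y = block(x, text_ctx)                # (2, 64, 16, 16), same shape as x

Using the context only as key/value keeps the spatial shape of the feature map unchanged while letting every spatial position attend to the conditioning tokens.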