Yash Nagraj committed on
Commit ·
95b2cf2
1
Parent(s): aee1300
Add cross attn if needed for conditional latent diffusion
Browse files - models/blocks.py +24 -7
models/blocks.py
CHANGED
|
@@ -109,10 +109,27 @@ class DownBlock(nn.Module):
|
|
| 109 |
out = out + self.residual_input_conv[i](resnet_input)
|
| 110 |
|
| 111 |
# Self Attention
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
out = out + self.residual_input_conv[i](resnet_input)
|
| 110 |
|
| 111 |
# Self Attention
|
| 112 |
+
if self.attn:
|
| 113 |
+
batch_size, channels, h, w = out.shape
|
| 114 |
+
in_attn = out.reshape(batch_size, channels, h*w)
|
| 115 |
+
in_attn = self.attention_norms[i](in_attn)
|
| 116 |
+
in_attn = in_attn.transpose(1, 2)
|
| 117 |
+
out_attn, _ = self.attention[i](in_attn, in_attn, in_attn)
|
| 118 |
+
out_attn = out.transpose(1, 2).reshape(
|
| 119 |
+
batch_size, channels, h, w)
|
| 120 |
+
out = out + out_attn
|
| 121 |
+
|
| 122 |
+
# Cross Attention
|
| 123 |
+
if self.cross_attn:
|
| 124 |
+
assert context not None, "Context must be given for cross_attn"
|
| 125 |
+
batch_size, channels, h, w = out.shape
|
| 126 |
+
in_attn = out.reshape(batch_size, channels, h*w)
|
| 127 |
+
in_attn = self.attention_norms[i](in_attn)
|
| 128 |
+
in_attn = in_attn.transpose(1, 2)
|
| 129 |
+
out_attn, _ = self.attention[i](in_attn, in_attn, in_attn)
|
| 130 |
+
out_attn = out.transpose(1, 2).reshape(
|
| 131 |
+
batch_size, channels, h, w)
|
| 132 |
+
out = out + out_attn
|
| 133 |
+
|
| 134 |
+
out = self.resnet_down_conv(out)
|
| 135 |
+
return out
|