05185634
Browse files- context_unet_backup.py +9 -1
context_unet_backup.py
CHANGED
|
@@ -127,6 +127,10 @@ class TimestepBlock(ABC, nn.Module):
|
|
| 127 |
"""
|
| 128 |
|
| 129 |
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
def forward(self, x, emb, encoder_out=None):
|
| 131 |
for layer in self:
|
| 132 |
if isinstance(layer, TimestepBlock):
|
|
@@ -134,7 +138,11 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|
| 134 |
elif isinstance(layer, AttentionBlock):
|
| 135 |
x = layer(x, encoder_out)
|
| 136 |
else:
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
return x
|
| 139 |
|
| 140 |
class ResBlock(TimestepBlock):
|
|
|
|
| 127 |
"""
|
| 128 |
|
| 129 |
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
| 130 |
+
def __init__(self, *arg, use_checkpoint=False):
|
| 131 |
+
super().__init__(*args)
|
| 132 |
+
self.use_checkpoint = use_checkpoint
|
| 133 |
+
|
| 134 |
def forward(self, x, emb, encoder_out=None):
|
| 135 |
for layer in self:
|
| 136 |
if isinstance(layer, TimestepBlock):
|
|
|
|
| 138 |
elif isinstance(layer, AttentionBlock):
|
| 139 |
x = layer(x, encoder_out)
|
| 140 |
else:
|
| 141 |
+
if self.use_checkpoint:
|
| 142 |
+
print(f"TimestepEmbedSequential checkpoint working")
|
| 143 |
+
x = checkpoint.checkpoint(layer, x)
|
| 144 |
+
else:
|
| 145 |
+
x = layer(x)
|
| 146 |
return x
|
| 147 |
|
| 148 |
class ResBlock(TimestepBlock):
|