Xsmos commited on
Commit
20f8cb6
·
verified ·
1 Parent(s): 3df929e
Files changed (1) hide show
  1. context_unet_backup.py +9 -1
context_unet_backup.py CHANGED
@@ -127,6 +127,10 @@ class TimestepBlock(ABC, nn.Module):
127
  """
128
 
129
  class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
 
 
 
 
130
  def forward(self, x, emb, encoder_out=None):
131
  for layer in self:
132
  if isinstance(layer, TimestepBlock):
@@ -134,7 +138,11 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
134
  elif isinstance(layer, AttentionBlock):
135
  x = layer(x, encoder_out)
136
  else:
137
- x = layer(x)
 
 
 
 
138
  return x
139
 
140
  class ResBlock(TimestepBlock):
 
127
  """
128
 
129
  class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
130
+ def __init__(self, *arg, use_checkpoint=False):
131
+ super().__init__(*args)
132
+ self.use_checkpoint = use_checkpoint
133
+
134
  def forward(self, x, emb, encoder_out=None):
135
  for layer in self:
136
  if isinstance(layer, TimestepBlock):
 
138
  elif isinstance(layer, AttentionBlock):
139
  x = layer(x, encoder_out)
140
  else:
141
+ if self.use_checkpoint:
142
+ print(f"TimestepEmbedSequential checkpoint working")
143
+ x = checkpoint.checkpoint(layer, x)
144
+ else:
145
+ x = layer(x)
146
  return x
147
 
148
  class ResBlock(TimestepBlock):