manbeast3b committed on
Commit
25cf65a
·
verified ·
1 Parent(s): 6ae76cc

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +2 -4
src/pipeline.py CHANGED
@@ -208,9 +208,7 @@ def empty_cache():
208
  print(f"Flush took: {time.time() - start}")
209
 
210
  def load_pipeline() -> Pipeline:
211
- buffer = torch.empty((1024, 1024), device="cuda")
212
  empty_cache()
213
- buffer = torch.empty((1024, 1024), device="cuda")
214
  dtype, device = torch.bfloat16, "cuda"
215
 
216
  text_encoder_2 = T5EncoderModel.from_pretrained(
@@ -228,11 +226,11 @@ def load_pipeline() -> Pipeline:
228
  torch.cuda.set_per_process_memory_fraction(0.99)
229
  pipeline.text_encoder.to(memory_format=torch.channels_last)
230
  pipeline.transformer.to(memory_format=torch.channels_last)
231
-
232
-
233
  pipeline.vae.to(memory_format=torch.channels_last)
 
234
  pipeline.vae = torch.compile(pipeline.vae)
235
 
 
236
  pipeline._exclude_from_cpu_offload = ["vae"]
237
  pipeline.enable_sequential_cpu_offload()
238
  for _ in range(2):
 
208
  print(f"Flush took: {time.time() - start}")
209
 
210
  def load_pipeline() -> Pipeline:
 
211
  empty_cache()
 
212
  dtype, device = torch.bfloat16, "cuda"
213
 
214
  text_encoder_2 = T5EncoderModel.from_pretrained(
 
226
  torch.cuda.set_per_process_memory_fraction(0.99)
227
  pipeline.text_encoder.to(memory_format=torch.channels_last)
228
  pipeline.transformer.to(memory_format=torch.channels_last)
 
 
229
  pipeline.vae.to(memory_format=torch.channels_last)
230
+ pipeline.vae.enable_tiling()
231
  pipeline.vae = torch.compile(pipeline.vae)
232
 
233
+
234
  pipeline._exclude_from_cpu_offload = ["vae"]
235
  pipeline.enable_sequential_cpu_offload()
236
  for _ in range(2):