manbeast3b commited on
Commit
ffb28b2
·
verified ·
1 Parent(s): 58b1bdf

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +1 -0
src/pipeline.py CHANGED
@@ -45,6 +45,7 @@ def load_pipeline() -> Pipeline:
45
  torch.cuda.set_per_process_memory_fraction(0.95)
46
  pipeline.text_encoder.to(memory_format=torch.channels_last)
47
  pipeline.transformer.to(memory_format=torch.channels_last)
 
48
 
49
 
50
  pipeline.vae.to(memory_format=torch.channels_last)
 
45
  torch.cuda.set_per_process_memory_fraction(0.95)
46
  pipeline.text_encoder.to(memory_format=torch.channels_last)
47
  pipeline.transformer.to(memory_format=torch.channels_last)
48
+ torch.jit.enable_onednn_fusion(True)
49
 
50
 
51
  pipeline.vae.to(memory_format=torch.channels_last)