Update src/pipeline.py
Browse files- src/pipeline.py +1 -0
src/pipeline.py
CHANGED
|
@@ -45,6 +45,7 @@ def load_pipeline() -> Pipeline:
|
|
| 45 |
torch.cuda.set_per_process_memory_fraction(0.95)
|
| 46 |
pipeline.text_encoder.to(memory_format=torch.channels_last)
|
| 47 |
pipeline.transformer.to(memory_format=torch.channels_last)
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
pipeline.vae.to(memory_format=torch.channels_last)
|
|
|
|
| 45 |
torch.cuda.set_per_process_memory_fraction(0.95)
|
| 46 |
pipeline.text_encoder.to(memory_format=torch.channels_last)
|
| 47 |
pipeline.transformer.to(memory_format=torch.channels_last)
|
| 48 |
+
torch.jit.enable_onednn_fusion(True)
|
| 49 |
|
| 50 |
|
| 51 |
pipeline.vae.to(memory_format=torch.channels_last)
|