Update src/pipeline.py
Browse files — src/pipeline.py (+2, −4)
src/pipeline.py
CHANGED
|
@@ -208,9 +208,7 @@ def empty_cache():
|
|
| 208 |
print(f"Flush took: {time.time() - start}")
|
| 209 |
|
| 210 |
def load_pipeline() -> Pipeline:
|
| 211 |
-
buffer = torch.empty((1024, 1024), device="cuda")
|
| 212 |
empty_cache()
|
| 213 |
-
buffer = torch.empty((1024, 1024), device="cuda")
|
| 214 |
dtype, device = torch.bfloat16, "cuda"
|
| 215 |
|
| 216 |
text_encoder_2 = T5EncoderModel.from_pretrained(
|
|
@@ -228,11 +226,11 @@ def load_pipeline() -> Pipeline:
|
|
| 228 |
torch.cuda.set_per_process_memory_fraction(0.99)
|
| 229 |
pipeline.text_encoder.to(memory_format=torch.channels_last)
|
| 230 |
pipeline.transformer.to(memory_format=torch.channels_last)
|
| 231 |
-
|
| 232 |
-
|
| 233 |
pipeline.vae.to(memory_format=torch.channels_last)
|
|
|
|
| 234 |
pipeline.vae = torch.compile(pipeline.vae)
|
| 235 |
|
|
|
|
| 236 |
pipeline._exclude_from_cpu_offload = ["vae"]
|
| 237 |
pipeline.enable_sequential_cpu_offload()
|
| 238 |
for _ in range(2):
|
|
|
|
| 208 |
print(f"Flush took: {time.time() - start}")
|
| 209 |
|
| 210 |
def load_pipeline() -> Pipeline:
|
|
|
|
| 211 |
empty_cache()
|
|
|
|
| 212 |
dtype, device = torch.bfloat16, "cuda"
|
| 213 |
|
| 214 |
text_encoder_2 = T5EncoderModel.from_pretrained(
|
|
|
|
| 226 |
torch.cuda.set_per_process_memory_fraction(0.99)
|
| 227 |
pipeline.text_encoder.to(memory_format=torch.channels_last)
|
| 228 |
pipeline.transformer.to(memory_format=torch.channels_last)
|
|
|
|
|
|
|
| 229 |
pipeline.vae.to(memory_format=torch.channels_last)
|
| 230 |
+
pipeline.vae.enable_tiling()
|
| 231 |
pipeline.vae = torch.compile(pipeline.vae)
|
| 232 |
|
| 233 |
+
|
| 234 |
pipeline._exclude_from_cpu_offload = ["vae"]
|
| 235 |
pipeline.enable_sequential_cpu_offload()
|
| 236 |
for _ in range(2):
|