Update src/pipeline.py
Browse files- src/pipeline.py +12 -2
src/pipeline.py
CHANGED
|
@@ -86,8 +86,18 @@ def load_pipeline() -> Pipeline:
|
|
| 86 |
pipeline.vae.to(memory_format=torch.channels_last)
|
| 87 |
pipeline.vae = torch.compile(pipeline.vae)
|
| 88 |
|
| 89 |
-
pipeline._exclude_from_cpu_offload = ["vae"
|
| 90 |
-
pipeline.enable_sequential_cpu_offload()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
for _ in range(2):
|
| 92 |
pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
|
| 93 |
|
|
|
|
| 86 |
pipeline.vae.to(memory_format=torch.channels_last)
|
| 87 |
pipeline.vae = torch.compile(pipeline.vae)
|
| 88 |
|
| 89 |
+
pipeline._exclude_from_cpu_offload = ["vae"]
|
| 90 |
+
# pipeline.enable_sequential_cpu_offload()
|
| 91 |
+
def custom_cpu_offload(model, device, offload_buffers=True):
|
| 92 |
+
state_dict = model.state_dict()
|
| 93 |
+
filtered_state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
|
| 94 |
+
for name, param in filtered_state_dict.items():
|
| 95 |
+
param.data = param.to(device)
|
| 96 |
+
|
| 97 |
+
custom_cpu_offload(pipeline.text_encoder, "cpu")
|
| 98 |
+
custom_cpu_offload(pipeline.text_encoder_2, "cpu")
|
| 99 |
+
custom_cpu_offload(pipeline.transformer, "cpu")
|
| 100 |
+
|
| 101 |
for _ in range(2):
|
| 102 |
pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
|
| 103 |
|