manbeast3b committed on
Commit
71265b3
·
verified ·
1 Parent(s): 6a0ebda

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +12 -2
src/pipeline.py CHANGED
@@ -86,8 +86,18 @@ def load_pipeline() -> Pipeline:
86
  pipeline.vae.to(memory_format=torch.channels_last)
87
  pipeline.vae = torch.compile(pipeline.vae)
88
 
89
- pipeline._exclude_from_cpu_offload = ["vae", "transformer"]
90
- pipeline.enable_sequential_cpu_offload()
 
 
 
 
 
 
 
 
 
 
91
  for _ in range(2):
92
  pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
93
 
 
86
  pipeline.vae.to(memory_format=torch.channels_last)
87
  pipeline.vae = torch.compile(pipeline.vae)
88
 
89
+ pipeline._exclude_from_cpu_offload = ["vae"]
90
+ # pipeline.enable_sequential_cpu_offload()
91
def custom_cpu_offload(model, device, offload_buffers=True):
    """Move all of *model*'s parameters (and, by default, its buffers) to *device*.

    Lightweight stand-in for ``pipeline.enable_sequential_cpu_offload()``:
    rather than installing accelerate hooks, it simply relocates the weights
    in place, one tensor at a time.

    Args:
        model: ``torch.nn.Module`` whose weights should be moved.
        device: Target device spec (e.g. ``"cpu"``, ``"cuda"``).
        offload_buffers: When True (default), also move non-parameter
            buffers (e.g. norm running stats). Previously this flag was
            accepted but silently ignored.

    NOTE(review): the original body iterated ``model.state_dict()`` and
    assigned ``.data`` on the returned tensors. ``state_dict()`` (with the
    default ``keep_vars=False``) returns *detached* tensors, so rebinding
    their ``.data`` never touched the module's actual parameters — the model
    was never moved. Iterating ``model.parameters()`` mutates the real
    parameter tensors, which is what the callers below rely on.
    """
    # Rebind each parameter's storage on the target device; done per tensor
    # so peak memory overhead is at most one extra tensor at a time.
    for param in model.parameters():
        param.data = param.data.to(device)
    if offload_buffers:
        for buf in model.buffers():
            buf.data = buf.data.to(device)
96
+
97
+ custom_cpu_offload(pipeline.text_encoder, "cpu")
98
+ custom_cpu_offload(pipeline.text_encoder_2, "cpu")
99
+ custom_cpu_offload(pipeline.transformer, "cpu")
100
+
101
  for _ in range(2):
102
  pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
103