manbeast3b commited on
Commit
5230bad
·
verified ·
1 Parent(s): 6d3ee5b

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +7 -6
src/pipeline.py CHANGED
@@ -131,7 +131,8 @@ def load_pipeline() -> Pipeline:
131
  text_encoder_2 = T5EncoderModel.from_pretrained(
132
  "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
133
  )
134
- vae=AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
 
135
  # transformer = FluxTransformer2DModel.from_pretrained("manbeast3b/transfomer-flux-schnell-int8") # torch_dtype=dtype
136
  pipeline = DiffusionPipeline.from_pretrained(
137
  ckpt_id,
@@ -152,11 +153,11 @@ def load_pipeline() -> Pipeline:
152
  # pipeline.transformer.save_pretrained("/root/.cache/huggingface/hub/transformer-flux")
153
  # exit()
154
 
155
- pipeline.vae.to(memory_format=torch.channels_last)
156
- pipeline.vae = torch.compile(pipeline.vae)
157
- torch.save(pipeline.vae, '/root/.cache/huggingface/hub/compiled_vae.pth')
158
- exit()
159
- pipeline.vae = torch.load('/root/.cache/huggingface/hub/compiled_vae.pth')
160
 
161
  pipeline._exclude_from_cpu_offload = ["vae"]
162
  pipeline.enable_sequential_cpu_offload()
 
131
  text_encoder_2 = T5EncoderModel.from_pretrained(
132
  "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
133
  )
134
+ # vae=AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
135
+ vae = torch.load('/root/.cache/huggingface/hub/compiled_vae.pth')
136
  # transformer = FluxTransformer2DModel.from_pretrained("manbeast3b/transfomer-flux-schnell-int8") # torch_dtype=dtype
137
  pipeline = DiffusionPipeline.from_pretrained(
138
  ckpt_id,
 
153
  # pipeline.transformer.save_pretrained("/root/.cache/huggingface/hub/transformer-flux")
154
  # exit()
155
 
156
+ # pipeline.vae.to(memory_format=torch.channels_last)
157
+ # pipeline.vae = torch.compile(pipeline.vae)
158
+ # torch.save(pipeline.vae, '/root/.cache/huggingface/hub/compiled_vae.pth')
159
+ # exit()
160
+
161
 
162
  pipeline._exclude_from_cpu_offload = ["vae"]
163
  pipeline.enable_sequential_cpu_offload()