manbeast3b committed on
Commit
1d65bbb
·
verified ·
1 Parent(s): 096ddbe

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +33 -33
src/pipeline.py CHANGED
@@ -726,16 +726,16 @@ class FluxPipeline(
726
  self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
727
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
728
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
729
- print("=============== printing all the shapes right now ======================")
730
- print(latents.shape)
731
- print(timestep)
732
- print(guidance)
733
- print(pooled_prompt_embeds.shape)
734
- print(prompt_embeds.shape)
735
- print(text_ids.shape)
736
- print(latent_image_ids.shape)
737
- print("=================== thats all folks for now ============================")
738
- exit()
739
  noise_pred = self.transformer(
740
  hidden_states=latents,
741
  timestep=timestep / 1000,
@@ -858,45 +858,45 @@ def load_pipeline() -> Pipeline:
858
  model_name = "manbeast3b/Flux.1.Schnell-full-quant1"
859
  revision = "e7ddf488a4ea8a3cba05db5b8d06e7e0feb826a2"
860
 
861
- hub_model_dir = os.path.join(
862
- HF_HUB_CACHE,
863
- f"models--{model_name.replace('/', '--')}",
864
- "snapshots",
865
- revision,
866
- "transformer"
867
- )
868
- transformer = FluxTransformer2DModel.from_pretrained(
869
- hub_model_dir,
870
- torch_dtype=torch.bfloat16,
871
- use_safetensors=False
872
- ).to(memory_format=torch.channels_last)
873
 
874
  pipeline = FluxPipeline.from_pretrained(
875
  ckpt_id,
876
  revision=ckpt_revision,
877
  # text_encoder_2=text_encoder_2,
878
- transformer=transformer,
879
  # vae=vae,
880
  torch_dtype=torch.bfloat16
881
  )
882
  # pipeline.vae = torch.compile(vae)
883
  pipeline.to("cuda")
884
 
885
- # path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.la_schnella_transformer/snapshots/8bfd89bd9e2099e70a1155403bf8aabb0a3177df/flux_la_schnell_aten.so.pt2")
886
- # inputs1 = get_example_inputs()
887
- # print(f"AoT pre compiled path is {path}")
888
 
889
- # # transformer = torch._inductor.aoti_load_package(path)
890
  # transformer = torch._inductor.aoti_load_package(path)
891
- # print(f"{transformer(**inputs1)[0].shape=}")
 
892
 
893
- # for _ in range(3):
894
- # _ = transformer(**inputs1)[0]
895
 
896
- # time = benchmark_fn(f, transformer, **inputs1)
897
- # print(f"{time=} seconds.")
898
 
899
- # pipeline.transformer = transformer
900
 
901
  warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
902
  for _ in range(1):
 
726
  self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
727
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
728
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
729
+ # print("=============== printing all the shapes right now ======================")
730
+ # print(latents.shape)
731
+ # print(timestep)
732
+ # print(guidance)
733
+ # print(pooled_prompt_embeds.shape)
734
+ # print(prompt_embeds.shape)
735
+ # print(text_ids.shape)
736
+ # print(latent_image_ids.shape)
737
+ # print("=================== thats all folks for now ============================")
738
+ # exit()
739
  noise_pred = self.transformer(
740
  hidden_states=latents,
741
  timestep=timestep / 1000,
 
858
  model_name = "manbeast3b/Flux.1.Schnell-full-quant1"
859
  revision = "e7ddf488a4ea8a3cba05db5b8d06e7e0feb826a2"
860
 
861
+ # hub_model_dir = os.path.join(
862
+ # HF_HUB_CACHE,
863
+ # f"models--{model_name.replace('/', '--')}",
864
+ # "snapshots",
865
+ # revision,
866
+ # "transformer"
867
+ # )
868
+ # transformer = FluxTransformer2DModel.from_pretrained(
869
+ # hub_model_dir,
870
+ # torch_dtype=torch.bfloat16,
871
+ # use_safetensors=False
872
+ # ).to(memory_format=torch.channels_last)
873
 
874
  pipeline = FluxPipeline.from_pretrained(
875
  ckpt_id,
876
  revision=ckpt_revision,
877
  # text_encoder_2=text_encoder_2,
878
+ transformer=None, #transformer,
879
  # vae=vae,
880
  torch_dtype=torch.bfloat16
881
  )
882
  # pipeline.vae = torch.compile(vae)
883
  pipeline.to("cuda")
884
 
885
+ path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.la_schnella_transformer/snapshots/8e704256516ef1cbd6730fedd019f5b1a71d38d3/flux_la_schnell_ez.so.pt2")
886
+ inputs1 = get_example_inputs()
887
+ print(f"AoT pre compiled path is {path}")
888
 
 
889
  # transformer = torch._inductor.aoti_load_package(path)
890
+ transformer = torch._inductor.aoti_load_package(path)
891
+ print(f"{transformer(**inputs1)[0].shape=}")
892
 
893
+ for _ in range(3):
894
+ _ = transformer(**inputs1)[0]
895
 
896
+ time = benchmark_fn(f, transformer, **inputs1)
897
+ print(f"{time=} seconds.")
898
 
899
+ pipeline.transformer = transformer
900
 
901
  warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
902
  for _ in range(1):