Update src/pipeline.py
Browse files- src/pipeline.py +33 -33
src/pipeline.py
CHANGED
|
@@ -726,16 +726,16 @@ class FluxPipeline(
|
|
| 726 |
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
| 727 |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 728 |
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 729 |
-
print("=============== printing all the shapes right now ======================")
|
| 730 |
-
print(latents.shape)
|
| 731 |
-
print(timestep)
|
| 732 |
-
print(guidance)
|
| 733 |
-
print(pooled_prompt_embeds.shape)
|
| 734 |
-
print(prompt_embeds.shape)
|
| 735 |
-
print(text_ids.shape)
|
| 736 |
-
print(latent_image_ids.shape)
|
| 737 |
-
print("=================== thats all folks for now ============================")
|
| 738 |
-
exit()
|
| 739 |
noise_pred = self.transformer(
|
| 740 |
hidden_states=latents,
|
| 741 |
timestep=timestep / 1000,
|
|
@@ -858,45 +858,45 @@ def load_pipeline() -> Pipeline:
|
|
| 858 |
model_name = "manbeast3b/Flux.1.Schnell-full-quant1"
|
| 859 |
revision = "e7ddf488a4ea8a3cba05db5b8d06e7e0feb826a2"
|
| 860 |
|
| 861 |
-
hub_model_dir = os.path.join(
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
)
|
| 868 |
-
transformer = FluxTransformer2DModel.from_pretrained(
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
).to(memory_format=torch.channels_last)
|
| 873 |
|
| 874 |
pipeline = FluxPipeline.from_pretrained(
|
| 875 |
ckpt_id,
|
| 876 |
revision=ckpt_revision,
|
| 877 |
# text_encoder_2=text_encoder_2,
|
| 878 |
-
transformer=transformer,
|
| 879 |
# vae=vae,
|
| 880 |
torch_dtype=torch.bfloat16
|
| 881 |
)
|
| 882 |
# pipeline.vae = torch.compile(vae)
|
| 883 |
pipeline.to("cuda")
|
| 884 |
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
|
| 888 |
|
| 889 |
-
# # transformer = torch._inductor.aoti_load_package(path)
|
| 890 |
# transformer = torch._inductor.aoti_load_package(path)
|
| 891 |
-
|
|
|
|
| 892 |
|
| 893 |
-
|
| 894 |
-
|
| 895 |
|
| 896 |
-
|
| 897 |
-
|
| 898 |
|
| 899 |
-
|
| 900 |
|
| 901 |
warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
|
| 902 |
for _ in range(1):
|
|
|
|
| 726 |
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
| 727 |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 728 |
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 729 |
+
# print("=============== printing all the shapes right now ======================")
|
| 730 |
+
# print(latents.shape)
|
| 731 |
+
# print(timestep)
|
| 732 |
+
# print(guidance)
|
| 733 |
+
# print(pooled_prompt_embeds.shape)
|
| 734 |
+
# print(prompt_embeds.shape)
|
| 735 |
+
# print(text_ids.shape)
|
| 736 |
+
# print(latent_image_ids.shape)
|
| 737 |
+
# print("=================== thats all folks for now ============================")
|
| 738 |
+
# exit()
|
| 739 |
noise_pred = self.transformer(
|
| 740 |
hidden_states=latents,
|
| 741 |
timestep=timestep / 1000,
|
|
|
|
| 858 |
model_name = "manbeast3b/Flux.1.Schnell-full-quant1"
|
| 859 |
revision = "e7ddf488a4ea8a3cba05db5b8d06e7e0feb826a2"
|
| 860 |
|
| 861 |
+
# hub_model_dir = os.path.join(
|
| 862 |
+
# HF_HUB_CACHE,
|
| 863 |
+
# f"models--{model_name.replace('/', '--')}",
|
| 864 |
+
# "snapshots",
|
| 865 |
+
# revision,
|
| 866 |
+
# "transformer"
|
| 867 |
+
# )
|
| 868 |
+
# transformer = FluxTransformer2DModel.from_pretrained(
|
| 869 |
+
# hub_model_dir,
|
| 870 |
+
# torch_dtype=torch.bfloat16,
|
| 871 |
+
# use_safetensors=False
|
| 872 |
+
# ).to(memory_format=torch.channels_last)
|
| 873 |
|
| 874 |
pipeline = FluxPipeline.from_pretrained(
|
| 875 |
ckpt_id,
|
| 876 |
revision=ckpt_revision,
|
| 877 |
# text_encoder_2=text_encoder_2,
|
| 878 |
+
transformer=None, #transformer,
|
| 879 |
# vae=vae,
|
| 880 |
torch_dtype=torch.bfloat16
|
| 881 |
)
|
| 882 |
# pipeline.vae = torch.compile(vae)
|
| 883 |
pipeline.to("cuda")
|
| 884 |
|
| 885 |
+
path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.la_schnella_transformer/snapshots/8e704256516ef1cbd6730fedd019f5b1a71d38d3/flux_la_schnell_ez.so.pt2")
|
| 886 |
+
inputs1 = get_example_inputs()
|
| 887 |
+
print(f"AoT pre compiled path is {path}")
|
| 888 |
|
|
|
|
| 889 |
# transformer = torch._inductor.aoti_load_package(path)
|
| 890 |
+
transformer = torch._inductor.aoti_load_package(path)
|
| 891 |
+
print(f"{transformer(**inputs1)[0].shape=}")
|
| 892 |
|
| 893 |
+
for _ in range(3):
|
| 894 |
+
_ = transformer(**inputs1)[0]
|
| 895 |
|
| 896 |
+
time = benchmark_fn(f, transformer, **inputs1)
|
| 897 |
+
print(f"{time=} seconds.")
|
| 898 |
|
| 899 |
+
pipeline.transformer = transformer
|
| 900 |
|
| 901 |
warmup_ = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
|
| 902 |
for _ in range(1):
|