Update src/pipeline.py
Browse files- src/pipeline.py +11 -10
src/pipeline.py
CHANGED
|
@@ -70,18 +70,19 @@ def decode_latents_to_image(latents, height: int, width: int, vae):
|
|
| 70 |
vae_scale_factor = 1
|
| 71 |
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
|
| 72 |
|
| 73 |
-
|
| 74 |
-
if os.path.exists(traced_vae_decode_path):
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
|
| 82 |
-
else:
|
| 83 |
-
|
| 84 |
|
|
|
|
| 85 |
with torch.no_grad():
|
| 86 |
latents = FluxPipeline._unpack_latents(latents.unsqueeze(0), height, width, vae_scale_factor)
|
| 87 |
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
|
|
|
|
| 70 |
vae_scale_factor = 1
|
| 71 |
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
|
| 72 |
|
| 73 |
+
# # Try to load the traced model; trace and save if not found
|
| 74 |
+
# if os.path.exists(traced_vae_decode_path):
|
| 75 |
+
# try:
|
| 76 |
+
# traced_vae_decode = torch.jit.load(traced_vae_decode_path)
|
| 77 |
+
# # print("Loaded traced VAE decoder from file.")
|
| 78 |
+
# except Exception as e:
|
| 79 |
+
# # print(f"Error loading traced VAE decoder: {e}. Retracing...")
|
| 80 |
+
# traced_vae_decode = trace_and_save_vae_decoder(vae, latents)
|
| 81 |
|
| 82 |
+
# else:
|
| 83 |
+
# traced_vae_decode = trace_and_save_vae_decoder(vae, latents)
|
| 84 |
|
| 85 |
+
traced_vae_decode = vae.decode
|
| 86 |
with torch.no_grad():
|
| 87 |
latents = FluxPipeline._unpack_latents(latents.unsqueeze(0), height, width, vae_scale_factor)
|
| 88 |
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
|