manbeast3b commited on
Commit
97b5a44
·
verified ·
1 Parent(s): e7abe80

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +11 -10
src/pipeline.py CHANGED
@@ -70,18 +70,19 @@ def decode_latents_to_image(latents, height: int, width: int, vae):
70
  vae_scale_factor = 1
71
  image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
72
 
73
- # Try to load the traced model; trace and save if not found
74
- if os.path.exists(traced_vae_decode_path):
75
- try:
76
- traced_vae_decode = torch.jit.load(traced_vae_decode_path)
77
- # print("Loaded traced VAE decoder from file.")
78
- except Exception as e:
79
- # print(f"Error loading traced VAE decoder: {e}. Retracing...")
80
- traced_vae_decode = trace_and_save_vae_decoder(vae, latents)
81
 
82
- else:
83
- traced_vae_decode = trace_and_save_vae_decoder(vae, latents)
84
 
 
85
  with torch.no_grad():
86
  latents = FluxPipeline._unpack_latents(latents.unsqueeze(0), height, width, vae_scale_factor)
87
  latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
 
70
  vae_scale_factor = 1
71
  image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
72
 
73
+ # # Try to load the traced model; trace and save if not found
74
+ # if os.path.exists(traced_vae_decode_path):
75
+ # try:
76
+ # traced_vae_decode = torch.jit.load(traced_vae_decode_path)
77
+ # # print("Loaded traced VAE decoder from file.")
78
+ # except Exception as e:
79
+ # # print(f"Error loading traced VAE decoder: {e}. Retracing...")
80
+ # traced_vae_decode = trace_and_save_vae_decoder(vae, latents)
81
 
82
+ # else:
83
+ # traced_vae_decode = trace_and_save_vae_decoder(vae, latents)
84
 
85
+ traced_vae_decode = vae.decode
86
  with torch.no_grad():
87
  latents = FluxPipeline._unpack_latents(latents.unsqueeze(0), height, width, vae_scale_factor)
88
  latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor