multimodalart commited on
Commit
07fcbd1
·
1 Parent(s): 3b5c738

Fix VAE decode dtype mismatch (keep all VAE in fp32 after upcast)

Browse files
Files changed (1) hide show
  1. SDXL/diff_pipe.py +0 -16
SDXL/diff_pipe.py CHANGED
@@ -648,23 +648,7 @@ class StableDiffusionXLDiffImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixi
648
 
649
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
650
  def upcast_vae(self):
651
- dtype = self.vae.dtype
652
  self.vae.to(dtype=torch.float32)
653
- use_torch_2_0_or_xformers = isinstance(
654
- self.vae.decoder.mid_block.attentions[0].processor,
655
- (
656
- AttnProcessor2_0,
657
- XFormersAttnProcessor,
658
- LoRAXFormersAttnProcessor,
659
- LoRAAttnProcessor2_0,
660
- ),
661
- )
662
- # if xformers or torch_2_0 is used attention block does not need
663
- # to be in float32 which can save lots of memory
664
- if use_torch_2_0_or_xformers:
665
- self.vae.post_quant_conv.to(dtype)
666
- self.vae.decoder.conv_in.to(dtype)
667
- self.vae.decoder.mid_block.to(dtype)
668
 
669
  @torch.no_grad()
670
  @replace_example_docstring(EXAMPLE_DOC_STRING)
 
648
 
649
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
650
  def upcast_vae(self):
 
651
  self.vae.to(dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
 
653
  @torch.no_grad()
654
  @replace_example_docstring(EXAMPLE_DOC_STRING)