Spaces:
Running on Zero
Running on Zero
multimodalart commited on
Commit ·
07fcbd1
1
Parent(s): 3b5c738
Fix VAE decode dtype mismatch (keep all VAE in fp32 after upcast)
Browse files- SDXL/diff_pipe.py +0 -16
SDXL/diff_pipe.py
CHANGED
|
@@ -648,23 +648,7 @@ class StableDiffusionXLDiffImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixi
|
|
| 648 |
|
| 649 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
| 650 |
def upcast_vae(self):
|
| 651 |
-
dtype = self.vae.dtype
|
| 652 |
self.vae.to(dtype=torch.float32)
|
| 653 |
-
use_torch_2_0_or_xformers = isinstance(
|
| 654 |
-
self.vae.decoder.mid_block.attentions[0].processor,
|
| 655 |
-
(
|
| 656 |
-
AttnProcessor2_0,
|
| 657 |
-
XFormersAttnProcessor,
|
| 658 |
-
LoRAXFormersAttnProcessor,
|
| 659 |
-
LoRAAttnProcessor2_0,
|
| 660 |
-
),
|
| 661 |
-
)
|
| 662 |
-
# if xformers or torch_2_0 is used attention block does not need
|
| 663 |
-
# to be in float32 which can save lots of memory
|
| 664 |
-
if use_torch_2_0_or_xformers:
|
| 665 |
-
self.vae.post_quant_conv.to(dtype)
|
| 666 |
-
self.vae.decoder.conv_in.to(dtype)
|
| 667 |
-
self.vae.decoder.mid_block.to(dtype)
|
| 668 |
|
| 669 |
@torch.no_grad()
|
| 670 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
|
|
|
| 648 |
|
| 649 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
| 650 |
def upcast_vae(self):
|
|
|
|
| 651 |
self.vae.to(dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
|
| 653 |
@torch.no_grad()
|
| 654 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|