"""Factory helpers that build float16 Stable Diffusion XL pipelines.

Each factory can optionally enable circular tiling via ``utils.tiling``.
"""

import torch
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLPipeline,
)

from utils import tiling

# Hugging Face Hub model identifiers used by the factories below.
_FP16_FIX_VAE_ID = "madebyollin/sdxl-vae-fp16-fix"
_SDXL_BASE_ID = "stabilityai/stable-diffusion-xl-base-1.0"
_SDXL_REFINER_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"


def _load_fp16_fix_vae():
    """Load the ``sdxl-vae-fp16-fix`` autoencoder in float16.

    NOTE(review): presumably used instead of the checkpoint's own VAE because
    that one is not float16-safe -- confirm against the model card.
    """
    return AutoencoderKL.from_pretrained(_FP16_FIX_VAE_ID, torch_dtype=torch.float16)


def _apply_circular_tiling(pipe, text_encoder):
    """Turn off the VAE's built-in tiling and enable circular tiling.

    Circular tiling is applied to ``pipe.vae``, the given ``text_encoder`` and
    ``pipe.unet``.

    Args:
        pipe: An SDXL diffusers pipeline exposing ``vae`` and ``unet``.
        text_encoder: The text-encoder module to patch alongside the VAE/UNet.
    """
    # Presumably disabled so the VAE's own tiled decoding does not interfere
    # with the circular (seamless) patching below -- confirm.
    pipe.vae.disable_tiling()
    # Exact mechanism lives in utils.tiling (assumed to switch conv layers
    # to circular padding -- verify there).
    tiling.enable_circular_tiling([pipe.vae, text_encoder, pipe.unet])


def create_stable_diffusion_xl_pipeline(device, enable_tiling=True):
    """Build the SDXL base text-to-image pipeline in float16.

    Args:
        device: Target passed to ``pipe.to`` (e.g. ``"cuda"`` or ``"mps"``).
        enable_tiling: When True, disable the VAE's built-in tiling and apply
            circular tiling to the VAE, ``text_encoder`` and UNet.

    Returns:
        The configured ``StableDiffusionXLPipeline`` moved to ``device``.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        _SDXL_BASE_ID,
        vae=_load_fp16_fix_vae(),
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    if enable_tiling:
        _apply_circular_tiling(pipe, pipe.text_encoder)
    # Move last so anything the tiling patch touched ends up on `device` too.
    return pipe.to(device)


def create_stable_diffusion_xl_img2img_pipe(device, enable_tiling=True):
    """Build the SDXL refiner image-to-image pipeline in float16.

    Args:
        device: Target passed to ``pipe.to``.
        enable_tiling: When True (the default, matching previous behaviour),
            disable the VAE's built-in tiling and apply circular tiling to the
            VAE, ``text_encoder_2`` and UNet.

    Returns:
        The configured ``StableDiffusionXLImg2ImgPipeline`` moved to ``device``.
    """
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        _SDXL_REFINER_ID,
        vae=_load_fp16_fix_vae(),
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    if enable_tiling:
        # The refiner path patches text_encoder_2 rather than text_encoder;
        # presumably the refiner checkpoint has no first text encoder -- confirm.
        _apply_circular_tiling(pipe, pipe.text_encoder_2)
    return pipe.to(device)


def create_stable_diffusion_xl_inpainting_pipe(device="mps", enable_tiling=True):
    """Build the SDXL base inpainting pipeline in float16.

    Args:
        device: Target passed to ``pipe.to``. Defaults to ``"mps"``, the
            previously hard-coded value, so existing zero-argument calls keep
            working.
        enable_tiling: When True (the default, matching previous behaviour),
            disable the VAE's built-in tiling and apply circular tiling to the
            VAE, ``text_encoder`` and UNet.

    Returns:
        The configured ``StableDiffusionXLInpaintPipeline`` moved to ``device``.
    """
    # NOTE(review): unlike the other factories this keeps the checkpoint's own
    # VAE rather than the fp16-fix VAE; confirm whether that is intentional.
    pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
        _SDXL_BASE_ID,
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    if enable_tiling:
        _apply_circular_tiling(pipe, pipe.text_encoder)
    # Patch before moving, consistent with the other factories.
    return pipe.to(device)