Spaces:
Running
Running
Sean Powell commited on
Commit ·
a5222e6
1
Parent(s): 22facc4
Select between cuda or mps based on environment.
Browse files- utils/pipes.py +6 -2
utils/pipes.py
CHANGED
|
@@ -5,20 +5,24 @@ from utils import tiling
|
|
| 5 |
|
| 6 |
|
| 7 |
def create_stable_diffusion_xl_pipeline():
|
|
|
|
|
|
|
| 8 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 9 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
| 10 |
)
|
| 11 |
-
pipe = pipe.to(
|
| 12 |
pipe.vae.disable_tiling()
|
| 13 |
tiling.enable_circular_tiling([pipe.vae, pipe.text_encoder, pipe.unet])
|
| 14 |
return pipe
|
| 15 |
|
| 16 |
|
| 17 |
def create_stable_diffusion_xl_img2img_pipe():
|
|
|
|
|
|
|
| 18 |
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
| 19 |
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
| 20 |
)
|
| 21 |
-
pipe = pipe.to(
|
| 22 |
pipe.vae.disable_tiling()
|
| 23 |
tiling.enable_circular_tiling([pipe.vae, pipe.text_encoder_2, pipe.unet])
|
| 24 |
return pipe
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def create_stable_diffusion_xl_pipeline():
|
| 8 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "mps")
|
| 9 |
+
|
| 10 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 11 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
| 12 |
)
|
| 13 |
+
pipe = pipe.to(device)
|
| 14 |
pipe.vae.disable_tiling()
|
| 15 |
tiling.enable_circular_tiling([pipe.vae, pipe.text_encoder, pipe.unet])
|
| 16 |
return pipe
|
| 17 |
|
| 18 |
|
| 19 |
def create_stable_diffusion_xl_img2img_pipe():
|
| 20 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "mps")
|
| 21 |
+
|
| 22 |
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
| 23 |
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
| 24 |
)
|
| 25 |
+
pipe = pipe.to(device)
|
| 26 |
pipe.vae.disable_tiling()
|
| 27 |
tiling.enable_circular_tiling([pipe.vae, pipe.text_encoder_2, pipe.unet])
|
| 28 |
return pipe
|