Sean Powell commited on
Commit
a5222e6
·
1 Parent(s): 22facc4

Select between cuda or mps based on environment.

Browse files
Files changed (1) hide show
  1. utils/pipes.py +6 -2
utils/pipes.py CHANGED
@@ -5,20 +5,24 @@ from utils import tiling
5
 
6
 
7
  def create_stable_diffusion_xl_pipeline():
 
 
8
  pipe = StableDiffusionXLPipeline.from_pretrained(
9
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
10
  )
11
- pipe = pipe.to("mps")
12
  pipe.vae.disable_tiling()
13
  tiling.enable_circular_tiling([pipe.vae, pipe.text_encoder, pipe.unet])
14
  return pipe
15
 
16
 
17
  def create_stable_diffusion_xl_img2img_pipe():
 
 
18
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
19
  "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
20
  )
21
- pipe = pipe.to("mps")
22
  pipe.vae.disable_tiling()
23
  tiling.enable_circular_tiling([pipe.vae, pipe.text_encoder_2, pipe.unet])
24
  return pipe
 
5
 
6
 
7
  def create_stable_diffusion_xl_pipeline():
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "mps")
9
+
10
  pipe = StableDiffusionXLPipeline.from_pretrained(
11
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
12
  )
13
+ pipe = pipe.to(device)
14
  pipe.vae.disable_tiling()
15
  tiling.enable_circular_tiling([pipe.vae, pipe.text_encoder, pipe.unet])
16
  return pipe
17
 
18
 
19
  def create_stable_diffusion_xl_img2img_pipe():
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "mps")
21
+
22
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
23
  "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
24
  )
25
+ pipe = pipe.to(device)
26
  pipe.vae.disable_tiling()
27
  tiling.enable_circular_tiling([pipe.vae, pipe.text_encoder_2, pipe.unet])
28
  return pipe