heramb04 commited on
Commit
2d3511c
·
verified ·
1 Parent(s): 00ef95d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -4,8 +4,17 @@ import gradio as gr
4
 
5
  def load_pipeline():
6
 
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
- dtype = torch.float16 if device == "cuda" else torch.float32
 
 
 
 
 
 
 
 
 
9
 
10
  pipe = StableDiffusionPipeline.from_pretrained(
11
  "runwayml/stable-diffusion-v1-5",
 
4
 
5
  def load_pipeline():
6
 
7
+ if torch.cuda.is_available():
8
+ device = "cuda"
9
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
10
+ device = "mps"
11
+ else:
12
+ device = "cpu"
13
+
14
+ # Use float16 on any accelerator, float32 on CPU
15
+ dtype = torch.float16 if device != "cpu" else torch.float32
16
+ print(f"Using device: {device}, dtype: {dtype}")
17
+
18
 
19
  pipe = StableDiffusionPipeline.from_pretrained(
20
  "runwayml/stable-diffusion-v1-5",