heramb04 commited on
Commit
1f42979
·
verified ·
1 Parent(s): 2d3511c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -3,25 +3,27 @@ from diffusers import StableDiffusionPipeline
3
  import gradio as gr
4
 
5
  def load_pipeline():
6
-
7
- if torch.cuda.is_available():
8
- device = "cuda"
9
  elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
10
- device = "mps"
11
  else:
12
- device = "cpu"
13
 
14
- # Use float16 on any accelerator, float32 on CPU
15
- dtype = torch.float16 if device != "cpu" else torch.float32
16
  print(f"Using device: {device}, dtype: {dtype}")
17
 
18
-
19
  pipe = StableDiffusionPipeline.from_pretrained(
20
  "runwayml/stable-diffusion-v1-5",
21
  torch_dtype=dtype
22
- )
23
- return pipe.to(device)
 
24
 
 
25
  pipe = load_pipeline()
26
 
27
  def generate(prompt: str, steps: int, scale: float):
@@ -29,17 +31,18 @@ def generate(prompt: str, steps: int, scale: float):
29
  out = pipe(prompt, num_inference_steps=steps, guidance_scale=scale)
30
  return out.images[0]
31
 
 
32
  demo = gr.Interface(
33
  fn=generate,
34
  inputs=[
35
- gr.Textbox(lines=1, placeholder="a steampunk robot in a lush jungle", label="Prompt"),
36
  gr.Slider(1, 100, value=50, step=1, label="Inference Steps"),
37
  gr.Slider(1.0, 15.0, value=7.5, step=0.1, label="Guidance Scale"),
38
  ],
39
  outputs=gr.Image(type="pil"),
40
- title="Stable Diffusion image generator",
41
- description="Generates images using Stable Diffusion."
42
  )
43
 
44
  if __name__ == "__main__":
45
- demo.launch(share=True)
 
3
  import gradio as gr
4
 
5
  def load_pipeline():
6
+ # Auto-detect any available GPU backend or fallback to CPU
7
+ if torch.cuda.is_available():
8
+ device = torch.device("cuda")
9
  elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
10
+ device = torch.device("mps")
11
  else:
12
+ device = torch.device("cpu")
13
 
14
+ # Use float16 precision on any GPU/MPS, float32 on CPU
15
+ dtype = torch.float16 if device.type != "cpu" else torch.float32
16
  print(f"Using device: {device}, dtype: {dtype}")
17
 
18
+ # Load weights & configs from HF at runtime
19
  pipe = StableDiffusionPipeline.from_pretrained(
20
  "runwayml/stable-diffusion-v1-5",
21
  torch_dtype=dtype
22
+ ).to(device)
23
+
24
+ return pipe
25
 
26
+ # Initialize pipeline once
27
  pipe = load_pipeline()
28
 
29
  def generate(prompt: str, steps: int, scale: float):
 
31
  out = pipe(prompt, num_inference_steps=steps, guidance_scale=scale)
32
  return out.images[0]
33
 
34
+ # Build and launch Gradio UI
35
  demo = gr.Interface(
36
  fn=generate,
37
  inputs=[
38
+ gr.Textbox(lines=1, placeholder="Enter prompt…", label="Prompt"),
39
  gr.Slider(1, 100, value=50, step=1, label="Inference Steps"),
40
  gr.Slider(1.0, 15.0, value=7.5, step=0.1, label="Guidance Scale"),
41
  ],
42
  outputs=gr.Image(type="pil"),
43
+ title="Stable Diffusion Image Generator",
44
+ description="Generates images based on your prompt!."
45
  )
46
 
47
  if __name__ == "__main__":
48
+ demo.launch()