chrisjcc commited on
Commit
34ffcd0
·
verified ·
1 Parent(s): b0967f1

Update generate function

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -6,7 +6,7 @@ import base64
6
  import torch
7
  from diffusers import StableDiffusionPipeline
8
 
9
- from transformers import pipeline
10
  import gradio as gr
11
 
12
  # Set Hugging Face API (needed for gated models)
@@ -16,7 +16,7 @@ hf_api_key = os.environ.get('HF_API_KEY')
16
  model_id = "runwayml/stable-diffusion-v1-5"
17
  pipe = StableDiffusionPipeline.from_pretrained(
18
  model_id,
19
- torch_dtype=torch.float16, # Use float16 for better performance on GPU
20
  use_auth_token=hf_api_key # Required for gated model
21
  )
22
 
@@ -41,18 +41,21 @@ pipe = pipe.to(device)
41
  # return result_image
42
 
43
  def generate(prompt, negative_prompt, steps, guidance, width, height):
 
 
 
 
44
  # Generate image with Stable Diffusion
45
  output = pipe(
46
  prompt,
47
- negative_prompt=negative_prompt,
48
  num_inference_steps=int(steps),
49
  guidance_scale=float(guidance),
50
- width=int(width),
51
- height=int(height)
52
  )
53
  return output.images[0] # Return the first generated image (PIL format)
54
 
55
-
56
  # Create Gradio interface
57
  with gr.Blocks() as demo:
58
  gr.Markdown("# Image Generation with Stable Diffusion")
 
6
  import torch
7
  from diffusers import StableDiffusionPipeline
8
 
9
+ #from transformers import pipeline
10
  import gradio as gr
11
 
12
  # Set Hugging Face API (needed for gated models)
 
16
  model_id = "runwayml/stable-diffusion-v1-5"
17
  pipe = StableDiffusionPipeline.from_pretrained(
18
  model_id,
19
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float16 on GPU, float32 on CPU
20
  use_auth_token=hf_api_key # Required for gated model
21
  )
22
 
 
41
  # return result_image
42
 
43
  def generate(prompt, negative_prompt, steps, guidance, width, height):
44
+ # Ensure width and height are multiples of 8 (required by Stable Diffusion)
45
+ width = int(width) - (int(width) % 8)
46
+ height = int(height) - (int(height) % 8)
47
+
48
  # Generate image with Stable Diffusion
49
  output = pipe(
50
  prompt,
51
+ negative_prompt=negative_prompt or None, # Handle empty negative prompt
52
  num_inference_steps=int(steps),
53
  guidance_scale=float(guidance),
54
+ width=width,
55
+ height=height
56
  )
57
  return output.images[0] # Return the first generated image (PIL format)
58
 
 
59
  # Create Gradio interface
60
  with gr.Blocks() as demo:
61
  gr.Markdown("# Image Generation with Stable Diffusion")