tejani committed on
Commit
c744ef6
·
verified ·
1 Parent(s): 3712100

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -4,25 +4,31 @@ import gradio as gr
4
 
5
  # Load the Stable Diffusion model
6
  model_id = "runwayml/stable-diffusion-v1-5" # Replace with your model if different
7
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32) # Use float32 for CPU
8
- pipe = pipe.to("cpu") # Explicitly set to CPU
9
-
10
- # Enable CPU offloading to save memory (optional but recommended)
11
- pipe.enable_attention_slicing() # Reduces memory usage by slicing attention computation
12
 
13
  # Define the generation function
14
  def generate_image(prompt, seed=None):
15
- # If no seed is provided, generate a random one
16
  if seed is None or seed == "":
 
17
  seed = torch.randint(0, 1000000, (1,)).item()
18
-
 
 
 
 
 
 
 
19
  # Set up the generator with the seed for CPU
20
  generator = torch.Generator(device="cpu").manual_seed(seed)
21
 
22
- # Generate the image with fewer steps for faster CPU execution
23
  image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
24
 
25
- return image, seed # Return the image and the seed used
26
 
27
  # Create Gradio interface
28
  interface = gr.Interface(
 
4
 
5
  # Load the Stable Diffusion model
6
  model_id = "runwayml/stable-diffusion-v1-5" # Replace with your model if different
7
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
8
+ pipe = pipe.to("cpu")
9
+ pipe.enable_attention_slicing() # Reduce memory usage on CPU
 
 
10
 
11
  # Define the generation function
12
  def generate_image(prompt, seed=None):
13
+ # Handle the seed input
14
  if seed is None or seed == "":
15
+ # Generate a random seed if none provided
16
  seed = torch.randint(0, 1000000, (1,)).item()
17
+ else:
18
+ # Convert the seed from string to integer
19
+ try:
20
+ seed = int(seed)
21
+ except ValueError:
22
+ # If conversion fails (e.g., user enters "abc"), use a random seed
23
+ seed = torch.randint(0, 1000000, (1,)).item()
24
+
25
  # Set up the generator with the seed for CPU
26
  generator = torch.Generator(device="cpu").manual_seed(seed)
27
 
28
+ # Generate the image
29
  image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
30
 
31
+ return image, str(seed) # Return seed as string for display
32
 
33
  # Create Gradio interface
34
  interface = gr.Interface(