# Author: chrisjcc — "Update to use ZeroGPU spaces feature" (commit 37a666c, verified)
import os
import io
from PIL import Image
import base64
import torch
from diffusers import EulerDiscreteScheduler
from diffusers import StableDiffusionXLPipeline
from diffusers import StableDiffusion3Pipeline
from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline
import spaces
#from transformers import pipeline
import gradio as gr
# Set Hugging Face API token (needed for gated models); None when the env var is unset.
hf_api_key = os.environ.get('HF_API_KEY')
# Load the Stable Diffusion pipeline
model_id = "sd-legacy/stable-diffusion-v1-5"
# Use the Euler scheduler here instead of the model's default scheduler.
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  # Use float16 on GPU, float32 on CPU
    scheduler=scheduler,
    use_auth_token=hf_api_key  # Required for gated model. NOTE(review): `use_auth_token` is deprecated in recent diffusers in favor of `token=` — confirm against the pinned diffusers version.
)
# --- Alternative model configurations kept for reference (commented out) ---
# Load the Stable Diffusion pipeline
#model_id = "stabilityai/stable-diffusion-3.5-medium"
#pipe = SD3Transformer2DModel.from_pretrained(
#    model_id,
#    subfolder="transformer",
#    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float16 on GPU, float32 on CPU
#    use_auth_token=hf_api_key # Required for gated model
#)
# Load the Stable Diffusion XL pipeline
#model_id = "stabilityai/stable-diffusion-xl-base-1.0"
#pipe = StableDiffusionXLPipeline.from_pretrained(
#    model_id,
#    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float16 on GPU, float32 on CPU
#    use_auth_token=hf_api_key # Required for gated model
#)
# Load the Stable Diffusion pipeline
#model_id = "stabilityai/stable-diffusion-3-medium"
#model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
#pipe = StableDiffusion3Pipeline.from_pretrained(
#    model_id,
#    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float16 on GPU, float32 on CPU,
#    scheduler=scheduler,
#    use_auth_token=hf_api_key # Required for gated model
#)
# Move pipeline to GPU if available (under ZeroGPU, CUDA becomes available inside @spaces.GPU calls).
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)
# Text-to-image endpoint (old transformers-pipeline approach, unused)
#get_completion = pipeline("text-to-image", model="stabilityai/stable-diffusion-xl-base-1.0")
# Helper to turn a base64 string (as sent over the API) back into a PIL image.
def base64_to_pil(img_base64):
    """Decode a base64-encoded image payload into a PIL Image object."""
    raw_bytes = base64.b64decode(img_base64)
    return Image.open(io.BytesIO(raw_bytes))
# Legacy generate() based on the transformers pipeline (unused):
#def generate(prompt):
#    output = get_completion(prompt)
#    result_image = base64_to_pil(output)
#    return result_image
# Generate function
@spaces.GPU(duration=120)  # No-op outside ZeroGPU environments, ensuring compatibility across setups.
def generate(prompt, negative_prompt, steps, guidance, width, height):
    """Generate one image with the module-level Stable Diffusion pipeline.

    Args:
        prompt: Text description of the desired image.
        negative_prompt: Concepts to avoid; an empty string disables it.
        steps: Number of denoising steps (coerced to int).
        guidance: Classifier-free guidance scale (coerced to float).
        width, height: Requested output size in pixels; rounded down to the
            nearest multiple of 8 as required by Stable Diffusion.

    Returns:
        The first generated image as a PIL Image.
    """
    # Round down to a multiple of 8, but clamp to at least 8 so a tiny input
    # can never produce a zero-sized dimension (which would crash the pipeline).
    width = max(8, int(width) - (int(width) % 8))
    height = max(8, int(height) - (int(height) % 8))
    # Generate image with Stable Diffusion
    output = pipe(
        prompt,
        negative_prompt=negative_prompt or None,  # Handle empty negative prompt
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        width=width,
        height=height
    )
    return output.images[0]  # Return the first generated image (PIL format)
# Create Gradio interface: prompt + tuning controls on the left, result image on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with Stable Diffusion")
    prompt = gr.Textbox(label="Your prompt")
    with gr.Row():
        with gr.Column():
            negative_prompt = gr.Textbox(label="Negative prompt")
            # Fixed garbled info text ("In many steps will the denoiser denoise the image?").
            steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, value=25,
                              info="How many denoising steps the model takes to refine the image")
            guidance = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, value=7,
                                 info="Controls how much the text prompt influences the result")
            # Step of 64 keeps sizes valid multiples of 8 for Stable Diffusion.
            width = gr.Slider(label="Width", minimum=64, maximum=512, step=64, value=512)
            height = gr.Slider(label="Height", minimum=64, maximum=512, step=64, value=512)
            btn = gr.Button("Submit")
        with gr.Column():
            output = gr.Image(label="Result")
    btn.click(fn=generate, inputs=[prompt, negative_prompt, steps, guidance, width, height], outputs=[output])
# Launch the app
demo.launch(
    share=True,  # Request a public share link (ignored on Hugging Face Spaces, which serves the app directly)
    #server_port=int(os.environ['PORT3'])
)