|
|
import os |
|
|
import random |
|
|
import torch |
|
|
import gradio as gr |
|
|
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler |
|
|
|
|
|
|
|
|
# Hugging Face model repository to load; override with the MODEL_ID environment variable.
# NOTE(review): the runwayml/stable-diffusion-v1-5 repo has reportedly been removed from the
# Hub (a mirror exists at stable-diffusion-v1-5/stable-diffusion-v1-5) — confirm this default
# still resolves before relying on it.
MODEL_ID = os.environ.get("MODEL_ID", "runwayml/stable-diffusion-v1-5")

# Load weights in half precision only when a CUDA device is present; otherwise stay in float32.
if torch.cuda.is_available():
    DTYPE = torch.float16
else:
    DTYPE = torch.float32
|
|
|
|
|
|
|
|
def _load_pipeline():
    """Build the Stable Diffusion pipeline for MODEL_ID and place it on a device.

    Returns:
        The configured ``StableDiffusionPipeline``: moved to CUDA when available,
        otherwise left on CPU with attention slicing enabled.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        use_safetensors=True,
    )
    # Replace the checkpoint's scheduler with a DPM-Solver multistep scheduler
    # built from the same scheduler configuration.
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

    if not torch.cuda.is_available():
        # CPU path: compute attention in slices to keep peak memory down.
        pipeline.enable_attention_slicing()
        return pipeline

    return pipeline.to("cuda")


pipe = _load_pipeline()
|
|
|
|
|
def generate(prompt, negative_prompt="", steps=25, guidance=7.5, width=512, height=512, seed=-1):
    """Generate one image from a text prompt using the module-level ``pipe``.

    Args:
        prompt: Text description of the desired image. Must contain at least
            one non-whitespace character.
        negative_prompt: Text describing what to steer away from; an empty
            string is passed to the pipeline as ``None``.
        steps: Number of inference steps (coerced to ``int``).
        guidance: Guidance scale (coerced to ``float``).
        width: Requested width in pixels; rounded down to a multiple of 8 and
            clamped to at least 256.
        height: Requested height in pixels; same rounding as ``width``.
        seed: Seed for the torch generator. ``-1`` or ``None`` (a cleared
            number field) selects a random seed in ``[0, 2**32 - 1]``.

    Returns:
        ``(image, seed)``: the first entry of the pipeline result's ``images``
        and the integer seed that was actually used, so the image can be
        reproduced from the displayed seed.

    Raises:
        gr.Error: If ``prompt`` is empty, ``None`` or whitespace-only.
    """
    # Whitespace-only prompts would slip past a plain truthiness check.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    # Round down to a multiple of 8 (the pipeline rejects other sizes) and
    # enforce a 256-pixel minimum. Both values are ints from here on.
    width = max(256, int(width) // 8 * 8)
    height = max(256, int(height) // 8 * 8)

    # gr.Number delivers floats by default and None when the field is cleared.
    if seed is None or seed == -1:
        seed = random.randint(0, 2**32 - 1)
    # Normalise once so the value used for seeding is exactly the value returned.
    seed = int(seed)

    generator = torch.Generator(device=pipe.device).manual_seed(seed)

    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        width=width,
        height=height,
        generator=generator,
    )

    return result.images[0], seed
|
|
|
|
|
|
|
|
# Gradio UI: prompt and sampling controls on the left, generated image and seed on the right.
with gr.Blocks(title="Stable Diffusion 1.5") as demo:
    gr.Markdown("# 🖼️ Stable Diffusion 1.5 — Text to Image")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="A cat wearing a spacesuit, 8k, cinematic")
            negative = gr.Textbox(label="Negative Prompt", placeholder="blurry, low quality")
            # step=1: without an explicit step, Gradio derives a fractional increment
            # from the range, and the slider can emit non-integer step counts.
            steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
            guidance = gr.Slider(1.0, 12.0, value=7.5, step=0.1, label="Guidance Scale")
            width = gr.Dropdown([512, 640, 768], value=512, label="Width")
            height = gr.Dropdown([512, 640, 768], value=512, label="Height")
            # precision=0 makes the component round to and deliver an int rather than a float.
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
            btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output = gr.Image(label="Result", type="pil")
            # Integer display so the reported seed can be copied back verbatim.
            used_seed = gr.Number(label="Used Seed", precision=0, interactive=False)

    # Clicking an example only fills the input components; generation still runs via the button.
    examples = gr.Examples(
        examples=[
            ["A cyberpunk city at night, neon lights", "", 25, 7.5, 512, 512, -1],
            ["A cute dragon made of origami paper", "low quality", 30, 8, 512, 512, -1],
        ],
        inputs=[prompt, negative, steps, guidance, width, height, seed],
    )

    btn.click(
        fn=generate,
        inputs=[prompt, negative, steps, guidance, width, height, seed],
        outputs=[output, used_seed],
    )
|
|
|
|
|
def main():
    """Start the Gradio server for ``demo`` using Gradio's default launch settings."""
    demo.launch()


if __name__ == "__main__":
    main()
|
|
|