heramb04's picture
Update app.py
7f44418 verified
raw
history blame
1.11 kB
import torch
from diffusers import StableDiffusionPipeline
import gradio as gr
def load_pipeline(model_id: str = "runwayml/stable-diffusion-v1-5"):
    """Load a Stable Diffusion pipeline and move it to the best available device.

    Uses CUDA with float16 weights when a GPU is available, otherwise the CPU
    with float32 weights (float16 is not used on CPU here).

    Args:
        model_id: Hub repository id or local path passed to
            ``StableDiffusionPipeline.from_pretrained``. Defaults to the
            model this app has always used.

    Returns:
        The pipeline, already moved to the selected device.
    """
    # NOTE(review): the "runwayml/stable-diffusion-v1-5" repo has reportedly
    # been removed from the Hugging Face Hub; if loading fails, try the mirror
    # "stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm before changing.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    return pipe.to(device)
# Load the model once at import time; every call to generate() reuses it.
pipe = load_pipeline()


def generate(prompt: str, steps: int, scale: float):
    """Run the pipeline on *prompt* and return the first generated PIL image.

    Args:
        prompt: Text prompt passed to the pipeline.
        steps: Number of inference steps (``num_inference_steps``).
        scale: Value passed to the pipeline as ``guidance_scale``.

    Returns:
        The first image in the pipeline output's ``images`` list.
    """
    # Gradio sliders can deliver values as floats (e.g. 50.0); coerce so the
    # pipeline always receives an integer step count and a float scale.
    out = pipe(
        prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(scale),
    )
    return out.images[0]
# Web UI: one prompt box plus two numeric controls in, one PIL image out.
# Component construction order is kept (Textbox, Slider, Slider, Image).
demo = gr.Interface(
    generate,
    [
        gr.Textbox(
            label="Prompt",
            lines=1,
            placeholder="a steampunk robot in a lush jungle",
        ),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Inference Steps"),
        gr.Slider(minimum=1.0, maximum=15.0, value=7.5, step=0.1, label="Guidance Scale"),
    ],
    gr.Image(type="pil"),
    title="Stable Diffusion image generator",
    description="Generates images using Stable Diffusion.",
)
if __name__ == "__main__":
    # Each generation can take a long time (minutes on CPU). Gradio recommends
    # its request queue for long-running predictions; queue() returns the
    # same Blocks object, so launch() can be chained.
    demo.queue().launch(share=True)