# NSFW2 / app.py
# Author: Saravutw — "Update app.py" (commit d1368fa, verified)
import gradio as gr
import numpy as np
import random
from PIL import Image
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from controlnet_aux import CannyDetector
import torch
# Device and precision selection: fp16 on CUDA, fp32 on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/sdxl-turbo"
controlnet_model_id = "diffusers/controlnet-canny-sdxl-1.0"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Canny-edge ControlNet for SDXL; the "fp16" weight variant is only
# requested when a GPU is available (no fp16 variant makes sense on CPU).
controlnet = ControlNetModel.from_pretrained(
    controlnet_model_id,
    torch_dtype=torch_dtype,
    variant="fp16" if torch.cuda.is_available() else None,
    use_safetensors=True
)

# SDXL-Turbo pipeline with the ControlNet attached, moved to the chosen device.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    model_repo_id,
    controlnet=controlnet,
    torch_dtype=torch_dtype,
    variant="fp16" if torch.cuda.is_available() else None,
    use_safetensors=True
).to(device)

# Preprocessor that converts an input image into a Canny edge map.
canny_detector = CannyDetector()

MAX_SEED = np.iinfo(np.int32).max  # upper bound for the UI seed slider (2**31 - 1)
MAX_IMAGE_SIZE = 768  # safer for free tier
def get_canny_image(image):
    """Run the Canny edge detector over *image* and return a PIL image.

    The input is converted to a NumPy array before detection; the detector's
    output is wrapped back into a ``PIL.Image`` for display and for the
    pipeline's ``image`` argument.
    """
    edges = canny_detector(np.array(image))
    return Image.fromarray(edges)
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    control_image,
    controlnet_conditioning_scale,
):
    """Generate an image with SDXL-Turbo, optionally guided by a Canny ControlNet.

    Parameters mirror the Gradio inputs one-to-one.

    Returns:
        tuple: (generated PIL image, seed actually used, processed control image).
    """
    # Gradio sliders may deliver floats; PIL.Image.new, the pipeline's
    # width/height/steps, and manual_seed all require ints.
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    if control_image is not None:
        # Derive Canny edges from the uploaded image; condition at the
        # user-selected strength.
        processed_control_image = get_canny_image(control_image)
        actual_controlnet_conditioning_scale = controlnet_conditioning_scale
    else:
        # No control image: feed a black placeholder and zero the conditioning
        # scale so the ControlNet has no effect on the result.
        processed_control_image = Image.new("RGB", (width, height), (0, 0, 0))
        actual_controlnet_conditioning_scale = 0.0

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        image=processed_control_image,
        controlnet_conditioning_scale=actual_controlnet_conditioning_scale,
    ).images[0]

    return image, seed, processed_control_image
# Prompt-only examples: the control image and negative prompt are deliberately
# left unset (None) so the examples exercise the text-to-image path.
_example_prompts = (
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
)
examples = [[prompt_text, None, None] for prompt_text in _example_prompts]
# ---- Gradio UI -------------------------------------------------------------
# Component creation order determines layout; the wiring at the bottom maps
# every input widget positionally onto infer()'s parameters.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## SDXL Turbo + ControlNet (Canny)")
        with gr.Row():
            # Single-line prompt box next to the Run button.
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        # Generated image and the Canny edge map actually fed to the pipeline.
        result = gr.Image(label="Result", show_label=False)
        processed_control_image_output = gr.Image(label="Processed Control Image", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # Output resolution, capped at MAX_IMAGE_SIZE for the free tier.
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
            with gr.Row():
                # SDXL-Turbo is tuned for very few steps; default is 2.
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=4.0)
                num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=20, step=1, value=2)
            with gr.Row():
                # Optional control image; when empty, infer() disables conditioning.
                control_image = gr.Image(label="Control Image", type="pil", value=None)
                controlnet_conditioning_scale = gr.Slider(label="ControlNet Conditioning Scale", minimum=0.0, maximum=2.0, step=0.05, value=1.0)
        gr.Examples(examples=examples, inputs=[prompt, control_image, negative_prompt])
    # Wire the Run button: input order must match infer()'s parameter order.
    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            control_image,
            controlnet_conditioning_scale,
        ],
        outputs=[result, seed, processed_control_image_output],
    )
demo.launch()