# GLITCH-BITE404's picture
# Update app.py
# ccdcc64 verified
import torch
import gradio as gr
from PIL import Image
import random
from diffusers import (
DiffusionPipeline,
AutoencoderKL,
StableDiffusionControlNetPipeline,
ControlNetModel,
StableDiffusionControlNetImg2ImgPipeline,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler
)
import tempfile
import time
from share_btn import community_icon_html, loading_icon_html, share_js
import user_history
from illusion_style import css
import os
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"

# Everything runs on the CPU in full precision: half precision (float16)
# is poorly supported on CPU, so float32 is used for every component.
device = "cpu"
torch_dtype = torch.float32

# Separately loaded components that are plugged into the base checkpoint.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",
    torch_dtype=torch_dtype,
)
controlnet = ControlNetModel.from_pretrained(
    "monster-labs/control_v1p_sd15_qrcode_monster",
    torch_dtype=torch_dtype,
)

# Safety checker disabled by default to save memory/CPU cycles.
# NOTE(review): nothing in the visible code reads this flag; the pipeline
# below is always built with safety_checker=None regardless of its value.
SAFETY_CHECKER_ENABLED = False

# First-pass text-to-image ControlNet pipeline.
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    safety_checker=None,
    feature_extractor=None,
    torch_dtype=torch_dtype,
)
main_pipe = main_pipe.to(device)

# Second-pass img2img pipeline built from the very same component objects
# as main_pipe, so no weights are loaded a second time.
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
# Sampler map: UI sampler name -> factory that builds a fresh scheduler from
# an existing scheduler's config (used as SAMPLER_MAP[name](pipe.scheduler.config)).
SAMPLER_MAP = {
    # DPM-Solver++ (SDE variant) with Karras sigmas. The diffusers option is
    # `use_karras_sigmas`; the former `use_karras` keyword is not a
    # DPMSolverMultistepScheduler parameter and was not applied, so this
    # sampler previously ran without Karras sigmas despite its name.
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}
def center_crop_resize(img, output_size=(512, 512)):
    """Crop the largest centred square out of *img* and resize it.

    Args:
        img: PIL image (only ``.size``, ``.crop`` and ``.resize`` are used).
        output_size: Target size passed to ``img.resize``; defaults to 512x512.

    Returns:
        The square centre crop, resized to ``output_size``.
    """
    width, height = img.size
    side = min(width, height)
    # Integer offsets, with right/bottom derived from left/top, keep the box
    # exactly side x side. The former float midpoints (e.g. 0.5 .. 511.5 for
    # a 512x511 image) are rounded by PIL and could yield a 512x511 box,
    # which the final resize then stretched.
    left = (width - side) // 2
    top = (height - side) // 2
    box = (left, top, left + side, top + side)
    return img.crop(box).resize(output_size)
def common_upscale(samples, width, height, upscale_method, crop=False):
    """Resize *samples* to ``height`` x ``width`` via ``F.interpolate``.

    Args:
        samples: Batched tensor; a two-element ``size`` means interpolate
            expects 4-D ``(N, C, H, W)`` input.
        width: Target size of the last dimension.
        height: Target size of the second-to-last dimension.
        upscale_method: Interpolation ``mode`` string (e.g. ``"nearest-exact"``).
        crop: Accepted for call compatibility; currently ignored.

    Returns:
        The resized tensor.
    """
    interpolate = torch.nn.functional.interpolate
    target_hw = (height, width)
    return interpolate(samples, size=target_hw, mode=upscale_method)
def upscale(samples, upscale_method, scale_by):
    """Scale the two spatial dimensions of *samples* by ``scale_by``.

    ``samples.shape[2]`` / ``samples.shape[3]`` are treated as height /
    width; each is multiplied by ``scale_by`` and rounded to the nearest
    integer before delegating the resize to ``common_upscale``.
    """
    new_width = round(samples.shape[3] * scale_by)
    new_height = round(samples.shape[2] * scale_by)
    return common_upscale(samples, new_width, new_height, upscale_method)
def check_inputs(prompt: str, control_image: Image.Image):
    """Validate user inputs before the slow inference step is queued.

    Raises:
        gr.Error: If no control image was supplied, or if the prompt is
            missing, empty, or whitespace-only.
    """
    if control_image is None:
        raise gr.Error("Please select or upload an Input Illusion")
    # A whitespace-only prompt carries no text for the model; reject it like
    # an empty one rather than spending a full CPU generation on it.
    if prompt is None or not prompt.strip():
        raise gr.Error("Prompt is required")
# Inference function
def inference(
    control_image: Image.Image,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float = 8.0,
    controlnet_conditioning_scale: float = 1,
    control_guidance_start: float = 1,
    control_guidance_end: float = 1,
    upscaler_strength: float = 0.5,
    seed: int = -1,
    sampler = "DPM++ Karras SDE",
    progress = gr.Progress(track_tqdm=True),
    profile: gr.OAuthProfile | None = None,
):
    """Generate an illusion image with a two-pass ControlNet pipeline.

    Pass 1 (``main_pipe``) runs on a 512x512 centre crop of *control_image*
    and returns latents. Those latents are upscaled 2x with
    ``"nearest-exact"`` interpolation and refined by pass 2 (``image_pipe``,
    img2img) conditioned on a 1024x1024 centre crop of the same image. The
    final image is saved with ``user_history.save_image``.

    Args:
        control_image: Illusion/pattern image used as ControlNet conditioning.
        prompt: Text prompt for both passes.
        negative_prompt: Negative prompt for both passes.
        guidance_scale: Guidance scale forwarded to both passes.
        controlnet_conditioning_scale: ControlNet weight ("Illusion strength").
        control_guidance_start: Start of the ControlNet guidance window,
            forwarded to both passes.
        control_guidance_end: End of the ControlNet guidance window,
            forwarded to both passes.
        upscaler_strength: img2img ``strength`` of the second pass.
        seed: RNG seed; ``-1`` (or ``None``) picks a random 32-bit seed.
        sampler: Key into ``SAMPLER_MAP`` selecting the scheduler for both passes.
        progress: Gradio progress tracker (created with ``track_tqdm=True``).
        profile: OAuth profile forwarded to ``user_history.save_image``.

    Returns:
        ``(image, gr.update(visible=True), gr.update(visible=True), seed_used)``.
    """
    start_time = time.time()
    control_image_small = center_crop_resize(control_image)
    control_image_large = center_crop_resize(control_image, (1024, 1024))

    # image_pipe was constructed from main_pipe.components at import time and
    # keeps its own reference to the original scheduler object. Rebuild the
    # selected scheduler for BOTH pipelines so the upscale pass honours the
    # sampler choice too, instead of always using the base model's default.
    make_scheduler = SAMPLER_MAP[sampler]
    main_pipe.scheduler = make_scheduler(main_pipe.scheduler.config)
    image_pipe.scheduler = make_scheduler(image_pipe.scheduler.config)

    # Slider values may arrive as floats, while torch's manual_seed needs an int.
    if seed is None or int(seed) == -1:
        my_seed = random.randint(0, 2**32 - 1)
    else:
        my_seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(my_seed)

    # Pass 1: low-resolution generation, kept as latents for upscaling.
    # Reducing steps for CPU performance
    out = main_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=control_image_small,
        guidance_scale=float(guidance_scale),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
        control_guidance_start=float(control_guidance_start),
        control_guidance_end=float(control_guidance_end),
        num_inference_steps=12,  # Dropped steps for speed
        output_type="latent",
    )
    latents = out.images if hasattr(out, "images") else out[0]
    upscaled_latents = upscale(latents, "nearest-exact", 2)

    # Pass 2: img2img refinement of the 2x-upscaled latents.
    out_image = image_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        control_image=control_image_large,
        image=upscaled_latents,
        guidance_scale=float(guidance_scale),
        generator=generator,
        num_inference_steps=15,  # Dropped steps for speed
        strength=float(upscaler_strength),
        control_guidance_start=float(control_guidance_start),
        control_guidance_end=float(control_guidance_end),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
    )
    end_time = time.time()
    print(f"Inference took {end_time-start_time}s")

    result = out_image["images"][0]
    user_history.save_image(
        label=prompt,
        image=result,
        profile=profile,
        metadata={"prompt": prompt, "seed": my_seed},
    )
    return result, gr.update(visible=True), gr.update(visible=True), my_seed
# Gradio UI: inputs in the left column, the generated image and a (hidden
# until first result) share button in the right column.
with gr.Blocks(css=css) as app:
    gr.Markdown("<div style='text-align: center;'><h1>Illusion Diffusion CPU 🌀</h1></div>")
    with gr.Row():
        with gr.Column():
            control_image = gr.Image(label="Input Illusion", type="pil")
            controlnet_conditioning_scale = gr.Slider(minimum=0.0, maximum=5.0, step=0.01, value=0.8, label="Illusion strength")
            prompt = gr.Textbox(label="Prompt", placeholder="Medieval village scene")
            negative_prompt = gr.Textbox(label="Negative Prompt", value="low quality")
            with gr.Accordion(label="Advanced Options", open=False):
                guidance_scale = gr.Slider(minimum=0.0, maximum=50.0, step=0.25, value=7.5, label="Guidance Scale")
                # NOTE(review): the UI defaults to "Euler" while inference()
                # defaults to "DPM++ Karras SDE"; confirm which is intended.
                sampler = gr.Dropdown(choices=list(SAMPLER_MAP.keys()), value="Euler")
                # NOTE(review): UI default start is 0, whereas inference()'s
                # signature default for control_guidance_start is 1.
                control_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0, label="Start of ControlNet")
                control_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="End of ControlNet")
                # Forwarded as the img2img `strength` of the second pass.
                # NOTE(review): with the default of 1 the second pass presumably
                # re-noises the upscaled latents completely — confirm intent.
                strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Strength of the upscaler")
                # -1 asks inference() to draw a random seed.
                seed = gr.Slider(minimum=-1, maximum=9999999999, step=1, value=-1, label="Seed")
                # Read-only display of the seed inference() actually used.
                used_seed = gr.Number(label="Last seed used", interactive=False)
            run_btn = gr.Button("Run")
        with gr.Column():
            result_image = gr.Image(label="Output", interactive=False)
            with gr.Group(visible=False) as share_group:
                # NOTE(review): no event handler is attached to this button in
                # the visible code (share_js is imported but not used here).
                share_button = gr.Button("Share to community")
    # Cheap validation runs unqueued; only when it succeeds is the queued,
    # expensive inference() call started.
    run_btn.click(
        check_inputs,
        inputs=[prompt, control_image],
        queue=False
    ).success(
        inference,
        # Order must match inference()'s positional parameters.
        inputs=[control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale, control_start, control_end, strength, seed, sampler],
        # inference() returns (image, update, update, seed): result_image gets
        # the image and then a visibility update; share_group is made visible.
        outputs=[result_image, result_image, share_group, used_seed])
# At most 10 pending requests are queued.
app.queue(max_size=10).launch()