# tellurion's picture
# Clean up dead code and add startup model loading
# 1928ea4
import gradio as gr
import argparse
from refnet.sampling import get_noise_schedulers, get_sampler_list
from functools import partial
from backend import *
# External URLs rendered as clickable badges in the page header built by
# init_interface(); the keys are looked up there by name.
links = {
    # Target of the "arXiv ... (base)" badge.
    "base": "https://arxiv.org/abs/2401.01456",
    # Target of the "WACV 2025 - v1" badge.
    "v1": "https://openaccess.thecvf.com/content/WACV2025/html/Yan_ColorizeDiffusion_Improving_Reference-Based_Sketch_Colorization_with_Latent_Diffusion_Model_WACV_2025_paper.html",
    # Target of the "CVPR 2025 - v1.5" badge.
    "v1.5": "https://arxiv.org/abs/2502.19937v1",
    # Target of the "arXiv ... (v2)" badge.
    "v2": "https://arxiv.org/abs/2504.06895",
    # Target of the "CVPR 2026 - XL" badge.
    "xl": "https://arxiv.org/abs/2601.04883",
    # Hugging Face repository holding the model weights.
    "weights": "https://huggingface.co/tellurion/ColorizeDiffusionXL/tree/main",
    # Source repository.
    "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
}
def app_options(argv=None):
    """Parse the command-line options of the Gradio app.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``app_options()`` callers behave exactly as before.

    Returns:
        argparse.Namespace: carries ``server_name``, ``server_port``,
        ``share`` and ``enable_text_manipulation``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--server_name", "-addr", type=str, default="0.0.0.0",
        help="Server name / bind address handed to the Gradio launch call.",
    )
    parser.add_argument(
        "--server_port", "-port", type=int, default=7860,
        help="Server port handed to the Gradio launch call.",
    )
    parser.add_argument(
        "--share", action="store_true",
        help="Forwarded as the share flag of the Gradio launch call.",
    )
    parser.add_argument(
        "--enable_text_manipulation", "-manipulate", action="store_true",
        help="Show the text-manipulation prompt controls in the interface.",
    )
    return parser.parse_args(argv)
def init_interface(opt, *args, **kwargs) -> None:
    """Build the Colorize Diffusion Gradio interface and launch the server.

    Args:
        opt: Parsed options. Reads ``enable_text_manipulation`` (visibility
            of the prompt controls) and ``server_name`` / ``server_port`` /
            ``share`` (forwarded to ``launch``).
        *args: Accepted for call compatibility; not used.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        None.
    """
    # NOTE(review): the container nesting (Row / Column / Accordion) below
    # was reconstructed from the order of the statements -- confirm it
    # against the intended page layout.
    sampler_list = get_sampler_list()
    scheduler_list = get_noise_schedulers()

    # Shared settings for the three image inputs.
    img_block = partial(gr.Image, type="pil", height=300, interactive=True, show_label=True, format="png")

    with gr.Blocks(
        title="Colorize Diffusion",
        css_paths="backend/style.css",
        theme=gr.themes.Ocean(),
        elem_id="main-interface",
        analytics_enabled=False,
        fill_width=True,
    ) as block:
        # Header: title plus paper / weights / repository / license badges.
        # The "(base)" badge label carries the same arXiv id as links['base'].
        with gr.Row(elem_id="header-row", equal_height=True, variant="panel"):
            gr.Markdown(f"""<div class="header-container">
<div class="app-header"><span class="emoji">🎨</span><span class="title-text">Colorize Diffusion</span></div>
<div class="paper-links-icons">
<a href="{links['base']}" target="_blank">
<img src="https://img.shields.io/badge/arXiv-2401.01456 (base)-B31B1B?style=flat&logo=arXiv" alt="arXiv Paper">
</a>
<a href="{links['v1']}" target="_blank">
<img src="https://img.shields.io/badge/WACV 2025-v1-0CA4A5?style=flat&logo=Semantic%20Web" alt="WACV 2025">
</a>
<a href="{links['v1.5']}" target="_blank">
<img src="https://img.shields.io/badge/CVPR 2025-v1.5-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2025">
</a>
<a href="{links['v2']}" target="_blank">
<img src="https://img.shields.io/badge/arXiv-2504.06895 (v2)-B31B1B?style=flat&logo=arXiv" alt="arXiv v2 Paper">
</a>
<a href="{links['xl']}" target="_blank">
<img src="https://img.shields.io/badge/CVPR 2026-XL-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2026">
</a>
<a href="{links['weights']}" target="_blank">
<img src="https://img.shields.io/badge/Hugging%20Face-Model%20Weights-FF9D00?style=flat&logo=Hugging%20Face" alt="Model Weights">
</a>
<a href="{links['github']}" target="_blank">
<img src="https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub" alt="GitHub">
</a>
<a href="{links['github']}/blob/master/LICENSE" target="_blank">
<img src="https://img.shields.io/badge/License-CC--BY--NC--SA%204.0-4CAF50?style=flat&logo=Creative%20Commons" alt="License">
</a>
</div>
</div>""")

        with gr.Row(elem_id="content-row", equal_height=False, variant="panel"):
            # ---- Left column: inputs and generation parameters ----
            with gr.Column():
                # Prompt controls; shown only with --enable_text_manipulation.
                with gr.Row(visible=opt.enable_text_manipulation):
                    target = gr.Textbox(label="Target prompt", value="", scale=2)
                    anchor = gr.Textbox(label="Anchor prompt", value="", scale=2)
                    control = gr.Textbox(label="Control prompt", value="", scale=2)
                with gr.Row(visible=opt.enable_text_manipulation):
                    target_scale = gr.Slider(label="Target scale", value=0.0, minimum=0, maximum=15.0, step=0.25, scale=2)
                    ts0 = gr.Slider(label="Threshold 0", value=0.5, minimum=0, maximum=1.0, step=0.01)
                    ts1 = gr.Slider(label="Threshold 1", value=0.55, minimum=0, maximum=1.0, step=0.01)
                    ts2 = gr.Slider(label="Threshold 2", value=0.65, minimum=0, maximum=1.0, step=0.01)
                    ts3 = gr.Slider(label="Threshold 3", value=0.95, minimum=0, maximum=1.0, step=0.01)
                with gr.Row(visible=opt.enable_text_manipulation):
                    enhance = gr.Checkbox(label="Enhance manipulation", value=False)
                    add_prompt = gr.Button(value="Add")
                    clear_prompt = gr.Button(value="Clear")
                    vis_button = gr.Button(value="Visualize")
                text_prompt = gr.Textbox(label="Final prompt", value="", lines=3, visible=opt.enable_text_manipulation)

                with gr.Row():
                    sketch_img = img_block(label="Sketch")
                    reference_img = img_block(label="Reference")
                    background_img = img_block(label="Background")

                # Fixed-value inputs with no visible widget.
                style_enhance = gr.State(False)
                fg_enhance = gr.State(False)
                with gr.Row():
                    bg_enhance = gr.Checkbox(label="Low-level injection", value=False)
                    injection = gr.Checkbox(label="Attention injection", value=False)
                    autofit_size = gr.Checkbox(label="Autofit size", value=False)
                with gr.Row():
                    gs_r = gr.Slider(label="Reference guidance scale", minimum=1, maximum=15.0, value=4.0, step=0.5)
                    strength = gr.Slider(label="Reference strength", minimum=0, maximum=1, value=1, step=0.05)
                    fg_strength = gr.Slider(label="Foreground strength", minimum=0, maximum=1, value=1, step=0.05)
                    bg_strength = gr.Slider(label="Background strength", minimum=0, maximum=1, value=1, step=0.05)
                with gr.Row():
                    gs_s = gr.Slider(label="Sketch guidance scale", minimum=1, maximum=5.0, value=1.0, step=0.1)
                    ctl_scale = gr.Slider(label="Sketch strength", minimum=0, maximum=3, value=1, step=0.05)
                    mask_scale = gr.Slider(label="Background factor", minimum=0, maximum=2, value=1, step=0.05)
                    merge_scale = gr.Slider(label="Merging scale", minimum=0, maximum=1, value=0, step=0.05)
                with gr.Row():
                    bs = gr.Slider(label="Batch size", minimum=1, maximum=4, value=1, step=1, scale=1)
                    width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=32, scale=2)
                with gr.Row():
                    step = gr.Slider(label="Step", minimum=1, maximum=100, value=20, step=1, scale=1)
                    height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=32, scale=2)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=MAXM_INT32, step=1, value=-1)

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        crop = gr.Checkbox(label="Crop result", value=False, scale=1)
                        remove_fg = gr.Checkbox(label="Remove foreground in background input", value=False, scale=2)
                        rmbg = gr.Checkbox(label="Remove background in result", value=False, scale=2)
                        latent_inpaint = gr.Checkbox(label="Latent copy BG input", value=False, scale=2)
                    with gr.Row():
                        injection_control_scale = gr.Slider(
                            label="Injection fidelity (sketch)", minimum=0.0, maximum=2.0, value=0, step=0.05
                        )
                        injection_fidelity = gr.Slider(
                            label="Injection fidelity (reference)", minimum=0.0, maximum=1.0, value=0.5, step=0.05
                        )
                        injection_start_step = gr.Slider(
                            label="Injection start step", minimum=0.0, maximum=1.0, value=0, step=0.05
                        )
                with gr.Row():
                    reuse_seed = gr.Button(value="Reuse Seed")
                    random_seed = gr.Button(value="Random Seed")

            # ---- Right column: output gallery and model / sampler selection ----
            with gr.Column():
                result_gallery = gr.Gallery(
                    label='Output', show_label=False, elem_id="gallery", preview=True, type="pil", format="png"
                )
                run_button = gr.Button("Generate", variant="primary", size="lg")
                with gr.Row():
                    mask_ts = gr.Slider(label="Reference mask threshold", minimum=0., maximum=1., value=0.5, step=0.01)
                    mask_ss = gr.Slider(label="Sketch mask threshold", minimum=0., maximum=1., value=0.05, step=0.01)
                    pad_scale = gr.Slider(label="Reference padding scale", minimum=1, maximum=2, value=1, step=0.05)

                with gr.Row():
                    available_models = get_available_models()
                    sd_model = gr.Dropdown(
                        choices=available_models, label="Models",
                        value=available_models[0] if available_models else None,
                    )
                    extractor_model = gr.Dropdown(
                        choices=line_extractor_list, label="Line extractor", value=default_line_extractor
                    )
                    mask_model = gr.Dropdown(
                        choices=mask_extractor_list, label="Reference mask extractor", value=default_mask_extractor
                    )
                with gr.Row():
                    sampler = gr.Dropdown(choices=sampler_list, value="DPM++ 3M SDE", label="Sampler")
                    # Same empty-list guard as the model dropdown above, so an
                    # empty scheduler list does not raise IndexError here.
                    scheduler = gr.Dropdown(
                        choices=scheduler_list,
                        value=scheduler_list[0] if scheduler_list else None,
                        label="Noise scheduler",
                    )
                    preprocessor = gr.Dropdown(
                        choices=["none", "extract", "invert", "invert-webui"],
                        label="Sketch preprocessor", value="invert",
                    )
                with gr.Row():
                    deterministic = gr.Checkbox(label="Deterministic batch seed", value=False)
                    save_memory = gr.Checkbox(label="Save memory", value=True)

                # Hidden states for unused advanced controls
                fg_disentangle_scale = gr.State(1.0)
                start_step = gr.State(0.0)
                end_step = gr.State(1.0)
                no_start_step = gr.State(-0.05)
                no_end_step = gr.State(-0.05)
                return_inter = gr.State(False)
                accurate = gr.State(False)
                enc_scale = gr.State(1.0)
                middle_scale = gr.State(1.0)
                low_scale = gr.State(1.0)
                ctl_scale_1 = gr.State(1.0)
                ctl_scale_2 = gr.State(1.0)
                ctl_scale_3 = gr.State(1.0)
                ctl_scale_4 = gr.State(1.0)

        # ---- Event wiring ----
        # The same ten components are both read and written by the "Add" handler.
        prompt_fields = [target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt]
        add_prompt.click(fn=apppend_prompt, inputs=prompt_fields, outputs=prompt_fields)
        clear_prompt.click(fn=clear_prompts, outputs=[text_prompt])

        reuse_seed.click(fn=get_last_seed, outputs=[seed])
        random_seed.click(fn=reset_random_seed, outputs=[seed])

        extractor_model.input(fn=switch_extractor, inputs=[extractor_model])
        sd_model.input(fn=load_model, inputs=[sd_model])
        mask_model.input(fn=switch_mask_extractor, inputs=[mask_model])

        # Values are handed to `inference` in exactly this order; presumably
        # it must match that function's parameter order -- verify against
        # the backend before reordering.
        ips = [style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
               bs, sketch_img, reference_img, background_img, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
               ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4, fg_strength, bg_strength, merge_scale,
               mask_scale, height, width, seed, save_memory, step, injection, autofit_size,
               remove_fg, rmbg, latent_inpaint, injection_control_scale, injection_fidelity, injection_start_step,
               crop, pad_scale, start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler,
               preprocessor, deterministic, text_prompt, target, anchor, control, target_scale, ts0, ts1, ts2, ts3,
               enhance, accurate, enc_scale, middle_scale, low_scale, strength]

        run_button.click(
            fn=inference,
            inputs=ips,
            outputs=[result_gallery],
        )

        vis_button.click(
            fn=visualize,
            inputs=[reference_img, text_prompt, control, ts0, ts1, ts2, ts3],
            outputs=[result_gallery],
        )

    block.launch(
        server_name=opt.server_name,
        share=opt.share,
        server_port=opt.server_port,
    )
if __name__ == "__main__":
    # Script entry point: read the CLI options, prepare the backend, then
    # build and launch the interface.
    opt = app_options()

    # Select the default line extractor and reference-mask extractor, the
    # same defaults the UI dropdowns start with.
    switch_extractor(default_line_extractor)
    switch_mask_extractor(default_mask_extractor)

    # Load the first available model at startup; skipped when the backend
    # reports no models.
    available_models = get_available_models()
    if available_models:
        load_model(available_models[0])

    init_interface(opt)