Spaces:
Running on Zero
Running on Zero
File size: 13,710 Bytes
d066167 1928ea4 d066167 1928ea4 d066167 115b3c7 d066167 bc388ad 1928ea4 bc388ad | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 | import gradio as gr
import argparse
from refnet.sampling import get_noise_schedulers, get_sampler_list
from functools import partial
from backend import *
# External pages opened by the header badges: the papers for each model
# generation, the released weights, and the source repository.
_ARXIV_ABS = "https://arxiv.org/abs"

links = {
    "base": f"{_ARXIV_ABS}/2401.01456",
    "v1": (
        "https://openaccess.thecvf.com/content/WACV2025/html/"
        "Yan_ColorizeDiffusion_Improving_Reference-Based_Sketch_Colorization"
        "_with_Latent_Diffusion_Model_WACV_2025_paper.html"
    ),
    "v1.5": f"{_ARXIV_ABS}/2502.19937v1",
    "v2": f"{_ARXIV_ABS}/2504.06895",
    "xl": f"{_ARXIV_ABS}/2601.04883",
    "weights": "https://huggingface.co/tellurion/ColorizeDiffusionXL/tree/main",
    "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
}
def app_options(argv=None):
    """Parse the options that configure how the Gradio server is launched.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is the original behaviour; passing a list lets callers and tests
            build the options without touching the process arguments.

    Returns:
        argparse.Namespace with the attributes ``server_name`` (str,
        default ``"0.0.0.0"``), ``server_port`` (int, default ``7860``),
        ``share`` (bool) and ``enable_text_manipulation`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--server_name", "-addr", type=str, default="0.0.0.0")
    parser.add_argument("--server_port", "-port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--enable_text_manipulation", "-manipulate", action="store_true")
    return parser.parse_args(argv)
def _header_html():
    """Return the header HTML: the app title followed by the badge links.

    The arXiv ids printed on the "base" and "v2" badges are taken from the
    corresponding entries of ``links`` so that a badge label always names the
    same paper its anchor opens.
    """
    base_id = links["base"].rstrip("/").rsplit("/", 1)[-1]
    v2_id = links["v2"].rstrip("/").rsplit("/", 1)[-1]
    return f"""<div class="header-container">
<div class="app-header"><span class="emoji">🎨</span><span class="title-text">Colorize Diffusion</span></div>
<div class="paper-links-icons">
<a href="{links['base']}" target="_blank">
<img src="https://img.shields.io/badge/arXiv-{base_id} (base)-B31B1B?style=flat&logo=arXiv" alt="arXiv Paper">
</a>
<a href="{links['v1']}" target="_blank">
<img src="https://img.shields.io/badge/WACV 2025-v1-0CA4A5?style=flat&logo=Semantic%20Web" alt="WACV 2025">
</a>
<a href="{links['v1.5']}" target="_blank">
<img src="https://img.shields.io/badge/CVPR 2025-v1.5-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2025">
</a>
<a href="{links['v2']}" target="_blank">
<img src="https://img.shields.io/badge/arXiv-{v2_id} (v2)-B31B1B?style=flat&logo=arXiv" alt="arXiv v2 Paper">
</a>
<a href="{links['xl']}" target="_blank">
<img src="https://img.shields.io/badge/CVPR 2026-XL-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2026">
</a>
<a href="{links['weights']}" target="_blank">
<img src="https://img.shields.io/badge/Hugging%20Face-Model%20Weights-FF9D00?style=flat&logo=Hugging%20Face" alt="Model Weights">
</a>
<a href="{links['github']}" target="_blank">
<img src="https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub" alt="GitHub">
</a>
<a href="https://github.com/tellurion-kanata/colorizeDiffusion/blob/master/LICENSE" target="_blank">
<img src="https://img.shields.io/badge/License-CC--BY--NC--SA%204.0-4CAF50?style=flat&logo=Creative%20Commons" alt="License">
</a>
</div>
</div>"""


def init_interface(opt, *args, **kwargs) -> None:
    """Build the Gradio UI, wire its event handlers, and launch the server.

    Args:
        opt: Options object providing ``enable_text_manipulation`` (toggles
            visibility of the prompt-manipulation controls) and
            ``server_name`` / ``server_port`` / ``share`` (forwarded to
            ``Blocks.launch``).
        *args: Accepted for call compatibility; not used.
        **kwargs: Accepted for call compatibility; not used.
    """
    sampler_list = get_sampler_list()
    scheduler_list = get_noise_schedulers()

    # All three image inputs share the same widget configuration.
    img_block = partial(gr.Image, type="pil", height=300, interactive=True, show_label=True, format="png")

    with gr.Blocks(
        title="Colorize Diffusion",
        css_paths="backend/style.css",
        theme=gr.themes.Ocean(),
        elem_id="main-interface",
        analytics_enabled=False,
        fill_width=True,
    ) as block:
        with gr.Row(elem_id="header-row", equal_height=True, variant="panel"):
            gr.Markdown(_header_html())

        with gr.Row(elem_id="content-row", equal_height=False, variant="panel"):
            # ---- Left column: inputs and generation parameters ----
            with gr.Column():
                # Text-manipulation controls; shown only when enabled via CLI.
                with gr.Row(visible=opt.enable_text_manipulation):
                    target = gr.Textbox(label="Target prompt", value="", scale=2)
                    anchor = gr.Textbox(label="Anchor prompt", value="", scale=2)
                    control = gr.Textbox(label="Control prompt", value="", scale=2)
                with gr.Row(visible=opt.enable_text_manipulation):
                    target_scale = gr.Slider(label="Target scale", value=0.0, minimum=0, maximum=15.0, step=0.25, scale=2)
                    ts0 = gr.Slider(label="Threshold 0", value=0.5, minimum=0, maximum=1.0, step=0.01)
                    ts1 = gr.Slider(label="Threshold 1", value=0.55, minimum=0, maximum=1.0, step=0.01)
                    ts2 = gr.Slider(label="Threshold 2", value=0.65, minimum=0, maximum=1.0, step=0.01)
                    ts3 = gr.Slider(label="Threshold 3", value=0.95, minimum=0, maximum=1.0, step=0.01)
                with gr.Row(visible=opt.enable_text_manipulation):
                    enhance = gr.Checkbox(label="Enhance manipulation", value=False)
                    add_prompt = gr.Button(value="Add")
                    clear_prompt = gr.Button(value="Clear")
                    vis_button = gr.Button(value="Visualize")
                text_prompt = gr.Textbox(label="Final prompt", value="", lines=3, visible=opt.enable_text_manipulation)

                with gr.Row():
                    sketch_img = img_block(label="Sketch")
                    reference_img = img_block(label="Reference")
                    background_img = img_block(label="Background")

                # Fixed (non-UI) values that are still passed to `inference`.
                style_enhance = gr.State(False)
                fg_enhance = gr.State(False)
                with gr.Row():
                    bg_enhance = gr.Checkbox(label="Low-level injection", value=False)
                    injection = gr.Checkbox(label="Attention injection", value=False)
                    autofit_size = gr.Checkbox(label="Autofit size", value=False)
                with gr.Row():
                    gs_r = gr.Slider(label="Reference guidance scale", minimum=1, maximum=15.0, value=4.0, step=0.5)
                    strength = gr.Slider(label="Reference strength", minimum=0, maximum=1, value=1, step=0.05)
                    fg_strength = gr.Slider(label="Foreground strength", minimum=0, maximum=1, value=1, step=0.05)
                    bg_strength = gr.Slider(label="Background strength", minimum=0, maximum=1, value=1, step=0.05)
                with gr.Row():
                    gs_s = gr.Slider(label="Sketch guidance scale", minimum=1, maximum=5.0, value=1.0, step=0.1)
                    ctl_scale = gr.Slider(label="Sketch strength", minimum=0, maximum=3, value=1, step=0.05)
                    mask_scale = gr.Slider(label="Background factor", minimum=0, maximum=2, value=1, step=0.05)
                    merge_scale = gr.Slider(label="Merging scale", minimum=0, maximum=1, value=0, step=0.05)
                with gr.Row():
                    bs = gr.Slider(label="Batch size", minimum=1, maximum=4, value=1, step=1, scale=1)
                    width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=32, scale=2)
                with gr.Row():
                    step = gr.Slider(label="Step", minimum=1, maximum=100, value=20, step=1, scale=1)
                    height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=32, scale=2)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=MAXM_INT32, step=1, value=-1)

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        crop = gr.Checkbox(label="Crop result", value=False, scale=1)
                        remove_fg = gr.Checkbox(label="Remove foreground in background input", value=False, scale=2)
                        rmbg = gr.Checkbox(label="Remove background in result", value=False, scale=2)
                        latent_inpaint = gr.Checkbox(label="Latent copy BG input", value=False, scale=2)
                    with gr.Row():
                        injection_control_scale = gr.Slider(
                            label="Injection fidelity (sketch)", minimum=0.0, maximum=2.0, value=0, step=0.05
                        )
                        injection_fidelity = gr.Slider(
                            label="Injection fidelity (reference)", minimum=0.0, maximum=1.0, value=0.5, step=0.05
                        )
                        injection_start_step = gr.Slider(
                            label="Injection start step", minimum=0.0, maximum=1.0, value=0, step=0.05
                        )
                # NOTE(review): placed after the accordion next to the other seed
                # controls; confirm this matches the intended layout nesting.
                with gr.Row():
                    reuse_seed = gr.Button(value="Reuse Seed")
                    random_seed = gr.Button(value="Random Seed")

            # ---- Right column: output, run button, and model selection ----
            with gr.Column():
                result_gallery = gr.Gallery(
                    label="Output", show_label=False, elem_id="gallery", preview=True, type="pil", format="png"
                )
                run_button = gr.Button("Generate", variant="primary", size="lg")
                with gr.Row():
                    mask_ts = gr.Slider(label="Reference mask threshold", minimum=0., maximum=1., value=0.5, step=0.01)
                    mask_ss = gr.Slider(label="Sketch mask threshold", minimum=0., maximum=1., value=0.05, step=0.01)
                    pad_scale = gr.Slider(label="Reference padding scale", minimum=1, maximum=2, value=1, step=0.05)
                with gr.Row():
                    available_models = get_available_models()
                    sd_model = gr.Dropdown(
                        choices=available_models,
                        label="Models",
                        value=available_models[0] if available_models else None,
                    )
                    extractor_model = gr.Dropdown(
                        choices=line_extractor_list, label="Line extractor", value=default_line_extractor
                    )
                    mask_model = gr.Dropdown(
                        choices=mask_extractor_list, label="Reference mask extractor", value=default_mask_extractor
                    )
                with gr.Row():
                    sampler = gr.Dropdown(choices=sampler_list, value="DPM++ 3M SDE", label="Sampler")
                    # Guard the empty case the same way the model dropdown does
                    # instead of raising IndexError while building the UI.
                    scheduler = gr.Dropdown(
                        choices=scheduler_list,
                        value=scheduler_list[0] if scheduler_list else None,
                        label="Noise scheduler",
                    )
                    preprocessor = gr.Dropdown(
                        choices=["none", "extract", "invert", "invert-webui"],
                        label="Sketch preprocessor",
                        value="invert",
                    )
                with gr.Row():
                    deterministic = gr.Checkbox(label="Deterministic batch seed", value=False)
                    save_memory = gr.Checkbox(label="Save memory", value=True)

                # Hidden states for unused advanced controls
                fg_disentangle_scale = gr.State(1.0)
                start_step = gr.State(0.0)
                end_step = gr.State(1.0)
                no_start_step = gr.State(-0.05)
                no_end_step = gr.State(-0.05)
                return_inter = gr.State(False)
                accurate = gr.State(False)
                enc_scale = gr.State(1.0)
                middle_scale = gr.State(1.0)
                low_scale = gr.State(1.0)
                ctl_scale_1 = gr.State(1.0)
                ctl_scale_2 = gr.State(1.0)
                ctl_scale_3 = gr.State(1.0)
                ctl_scale_4 = gr.State(1.0)

        # ---- Event wiring (must happen inside the Blocks context) ----
        # The prompt builder reads and rewrites the same set of components.
        prompt_fields = [target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt]
        add_prompt.click(fn=apppend_prompt, inputs=prompt_fields, outputs=prompt_fields)
        clear_prompt.click(fn=clear_prompts, outputs=[text_prompt])

        reuse_seed.click(fn=get_last_seed, outputs=[seed])
        random_seed.click(fn=reset_random_seed, outputs=[seed])

        extractor_model.input(fn=switch_extractor, inputs=[extractor_model])
        sd_model.input(fn=load_model, inputs=[sd_model])
        mask_model.input(fn=switch_mask_extractor, inputs=[mask_model])

        # Values are passed to `inference` positionally, so the order of this
        # list must not be changed independently of that function.
        ips = [style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
               bs, sketch_img, reference_img, background_img, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
               ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4, fg_strength, bg_strength, merge_scale,
               mask_scale, height, width, seed, save_memory, step, injection, autofit_size,
               remove_fg, rmbg, latent_inpaint, injection_control_scale, injection_fidelity, injection_start_step,
               crop, pad_scale, start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler,
               preprocessor, deterministic, text_prompt, target, anchor, control, target_scale, ts0, ts1, ts2, ts3,
               enhance, accurate, enc_scale, middle_scale, low_scale, strength]

        run_button.click(
            fn=inference,
            inputs=ips,
            outputs=[result_gallery],
        )

        vis_button.click(
            fn=visualize,
            inputs=[reference_img, text_prompt, control, ts0, ts1, ts2, ts3],
            outputs=[result_gallery],
        )

    block.launch(
        server_name=opt.server_name,
        share=opt.share,
        server_port=opt.server_port,
    )
if __name__ == "__main__":
    # Startup sequence: read the CLI flags, activate the default line and
    # mask extractors, load the first available model (when there is one),
    # and finally build and serve the interface.
    cli_opts = app_options()

    switch_extractor(default_line_extractor)
    switch_mask_extractor(default_mask_extractor)

    if models := get_available_models():
        load_model(models[0])

    init_interface(cli_opts)
|