Spaces:
Running on Zero
Running on Zero
| import gradio as gr | |
| import argparse | |
| from refnet.sampling import get_noise_schedulers, get_sampler_list | |
| from functools import partial | |
| from backend import * | |
# External URLs rendered as badges in the page header built by init_interface().
# Keys are looked up by name inside the header f-string, so renaming a key
# requires updating that markup as well.
links = {
    # Paper pages, one per model generation.
    "base": "https://arxiv.org/abs/2401.01456",
    "v1": (
        "https://openaccess.thecvf.com/content/WACV2025/html/"
        "Yan_ColorizeDiffusion_Improving_Reference-Based_Sketch_Colorization_"
        "with_Latent_Diffusion_Model_WACV_2025_paper.html"
    ),
    "v1.5": "https://arxiv.org/abs/2502.19937v1",
    "v2": "https://arxiv.org/abs/2504.06895",
    "xl": "https://arxiv.org/abs/2601.04883",
    # Model weights and source repository.
    "weights": "https://huggingface.co/tellurion/ColorizeDiffusionXL/tree/main",
    "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
}
def app_options(argv=None) -> argparse.Namespace:
    """Parse the command-line flags of the demo launcher.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            parameterless call did before; passing a list makes the parser
            usable from tests or other entry points.

    Returns:
        A namespace with ``server_name`` (str), ``server_port`` (int),
        ``share`` (bool) and ``enable_text_manipulation`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--server_name", "-addr", type=str, default="0.0.0.0",
        help="Host name or address forwarded to the Gradio launch call.",
    )
    parser.add_argument(
        "--server_port", "-port", type=int, default=7860,
        help="Port forwarded to the Gradio launch call.",
    )
    parser.add_argument(
        "--share", action="store_true",
        help="Forwarded as the `share` flag of the Gradio launch call.",
    )
    parser.add_argument(
        "--enable_text_manipulation", "-manipulate", action="store_true",
        help="Show the text-prompt manipulation controls in the interface.",
    )
    return parser.parse_args(argv)
def init_interface(opt, *args, **kwargs) -> None:
    """Build the Colorize Diffusion Gradio interface and launch it.

    Args:
        opt: Parsed options object. The attributes read here are
            ``enable_text_manipulation`` (controls visibility of the text
            prompt widgets), ``server_name``, ``server_port`` and ``share``
            (all three forwarded to ``launch``).
        *args: Accepted for call compatibility; not used.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        None. The function ends by calling ``block.launch(...)``.

    NOTE(review): the nesting of the layout containers below was reconstructed
    from a listing whose indentation was lost — confirm the row/column/accordion
    grouping against the running UI.
    """
    sampler_list = get_sampler_list()
    scheduler_list = get_noise_schedulers()

    # Prefer the historical default sampler, but fall back to the first
    # available entry so the dropdown never starts with a value that is not
    # one of its choices.
    preferred_sampler = "DPM++ 3M SDE"
    if preferred_sampler in sampler_list:
        default_sampler = preferred_sampler
    else:
        default_sampler = sampler_list[0] if sampler_list else None
    # Same guard as the model dropdown: an empty list yields no preselection
    # instead of an IndexError.
    default_scheduler = scheduler_list[0] if scheduler_list else None

    # Shared factory for the three image inputs (sketch, reference, background).
    img_block = partial(gr.Image, type="pil", height=300, interactive=True, show_label=True, format="png")

    with gr.Blocks(
        title="Colorize Diffusion",
        css_paths="backend/style.css",
        theme=gr.themes.Ocean(),
        elem_id="main-interface",
        analytics_enabled=False,
        fill_width=True,
    ) as block:
        # ---- Header: title plus paper / weights / repository badges ----
        with gr.Row(elem_id="header-row", equal_height=True, variant="panel"):
            # The "base" badge label now carries the same arXiv id as links['base'].
            gr.Markdown(f"""<div class="header-container">
                <div class="app-header"><span class="emoji">🎨</span><span class="title-text">Colorize Diffusion</span></div>
                <div class="paper-links-icons">
                <a href="{links['base']}" target="_blank">
                <img src="https://img.shields.io/badge/arXiv-2401.01456 (base)-B31B1B?style=flat&logo=arXiv" alt="arXiv Paper">
                </a>
                <a href="{links['v1']}" target="_blank">
                <img src="https://img.shields.io/badge/WACV 2025-v1-0CA4A5?style=flat&logo=Semantic%20Web" alt="WACV 2025">
                </a>
                <a href="{links['v1.5']}" target="_blank">
                <img src="https://img.shields.io/badge/CVPR 2025-v1.5-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2025">
                </a>
                <a href="{links['v2']}" target="_blank">
                <img src="https://img.shields.io/badge/arXiv-2504.06895 (v2)-B31B1B?style=flat&logo=arXiv" alt="arXiv v2 Paper">
                </a>
                <a href="{links['xl']}" target="_blank">
                <img src="https://img.shields.io/badge/CVPR 2026-XL-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2026">
                </a>
                <a href="{links['weights']}" target="_blank">
                <img src="https://img.shields.io/badge/Hugging%20Face-Model%20Weights-FF9D00?style=flat&logo=Hugging%20Face" alt="Model Weights">
                </a>
                <a href="{links['github']}" target="_blank">
                <img src="https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub" alt="GitHub">
                </a>
                <a href="https://github.com/tellurion-kanata/colorizeDiffusion/blob/master/LICENSE" target="_blank">
                <img src="https://img.shields.io/badge/License-CC--BY--NC--SA%204.0-4CAF50?style=flat&logo=Creative%20Commons" alt="License">
                </a>
                </div>
                </div>""")

        with gr.Row(elem_id="content-row", equal_height=False, variant="panel"):
            # ---- Left column: inputs and generation parameters ----
            with gr.Column():
                # Text-manipulation widgets; shown only when enabled on the CLI.
                with gr.Row(visible=opt.enable_text_manipulation):
                    target = gr.Textbox(label="Target prompt", value="", scale=2)
                    anchor = gr.Textbox(label="Anchor prompt", value="", scale=2)
                    control = gr.Textbox(label="Control prompt", value="", scale=2)

                with gr.Row(visible=opt.enable_text_manipulation):
                    target_scale = gr.Slider(label="Target scale", value=0.0, minimum=0, maximum=15.0, step=0.25, scale=2)
                    ts0 = gr.Slider(label="Threshold 0", value=0.5, minimum=0, maximum=1.0, step=0.01)
                    ts1 = gr.Slider(label="Threshold 1", value=0.55, minimum=0, maximum=1.0, step=0.01)
                    ts2 = gr.Slider(label="Threshold 2", value=0.65, minimum=0, maximum=1.0, step=0.01)
                    ts3 = gr.Slider(label="Threshold 3", value=0.95, minimum=0, maximum=1.0, step=0.01)

                with gr.Row(visible=opt.enable_text_manipulation):
                    enhance = gr.Checkbox(label="Enhance manipulation", value=False)
                    add_prompt = gr.Button(value="Add")
                    clear_prompt = gr.Button(value="Clear")
                    vis_button = gr.Button(value="Visualize")
                text_prompt = gr.Textbox(label="Final prompt", value="", lines=3, visible=opt.enable_text_manipulation)

                with gr.Row():
                    sketch_img = img_block(label="Sketch")
                    reference_img = img_block(label="Reference")
                    background_img = img_block(label="Background")

                # Fixed (non-interactive) inputs that are still passed to inference.
                style_enhance = gr.State(False)
                fg_enhance = gr.State(False)

                with gr.Row():
                    bg_enhance = gr.Checkbox(label="Low-level injection", value=False)
                    injection = gr.Checkbox(label="Attention injection", value=False)
                    autofit_size = gr.Checkbox(label="Autofit size", value=False)

                with gr.Row():
                    gs_r = gr.Slider(label="Reference guidance scale", minimum=1, maximum=15.0, value=4.0, step=0.5)
                    strength = gr.Slider(label="Reference strength", minimum=0, maximum=1, value=1, step=0.05)
                    fg_strength = gr.Slider(label="Foreground strength", minimum=0, maximum=1, value=1, step=0.05)
                    bg_strength = gr.Slider(label="Background strength", minimum=0, maximum=1, value=1, step=0.05)

                with gr.Row():
                    gs_s = gr.Slider(label="Sketch guidance scale", minimum=1, maximum=5.0, value=1.0, step=0.1)
                    ctl_scale = gr.Slider(label="Sketch strength", minimum=0, maximum=3, value=1, step=0.05)
                    mask_scale = gr.Slider(label="Background factor", minimum=0, maximum=2, value=1, step=0.05)
                    merge_scale = gr.Slider(label="Merging scale", minimum=0, maximum=1, value=0, step=0.05)

                with gr.Row():
                    bs = gr.Slider(label="Batch size", minimum=1, maximum=4, value=1, step=1, scale=1)
                    width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=32, scale=2)

                with gr.Row():
                    step = gr.Slider(label="Step", minimum=1, maximum=100, value=20, step=1, scale=1)
                    height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=32, scale=2)

                # -1 is the slider's minimum and initial value.
                # NOTE(review): presumably -1 means "pick a random seed" — confirm in backend.
                seed = gr.Slider(label="Seed", minimum=-1, maximum=MAXM_INT32, step=1, value=-1)

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        crop = gr.Checkbox(label="Crop result", value=False, scale=1)
                        remove_fg = gr.Checkbox(label="Remove foreground in background input", value=False, scale=2)
                        rmbg = gr.Checkbox(label="Remove background in result", value=False, scale=2)
                        latent_inpaint = gr.Checkbox(label="Latent copy BG input", value=False, scale=2)

                    with gr.Row():
                        injection_control_scale = gr.Slider(label="Injection fidelity (sketch)", minimum=0.0,
                                                            maximum=2.0, value=0, step=0.05)
                        injection_fidelity = gr.Slider(label="Injection fidelity (reference)", minimum=0.0,
                                                       maximum=1.0, value=0.5, step=0.05)
                        injection_start_step = gr.Slider(label="Injection start step", minimum=0.0, maximum=1.0,
                                                         value=0, step=0.05)

                with gr.Row():
                    reuse_seed = gr.Button(value="Reuse Seed")
                    random_seed = gr.Button(value="Random Seed")

            # ---- Right column: output gallery, model and sampler selection ----
            with gr.Column():
                result_gallery = gr.Gallery(
                    label='Output', show_label=False, elem_id="gallery", preview=True, type="pil", format="png"
                )
                run_button = gr.Button("Generate", variant="primary", size="lg")

                with gr.Row():
                    mask_ts = gr.Slider(label="Reference mask threshold", minimum=0., maximum=1., value=0.5, step=0.01)
                    mask_ss = gr.Slider(label="Sketch mask threshold", minimum=0., maximum=1., value=0.05, step=0.01)
                    pad_scale = gr.Slider(label="Reference padding scale", minimum=1, maximum=2, value=1, step=0.05)

                with gr.Row():
                    available_models = get_available_models()
                    sd_model = gr.Dropdown(choices=available_models, label="Models",
                                           value=available_models[0] if available_models else None)
                    extractor_model = gr.Dropdown(choices=line_extractor_list,
                                                  label="Line extractor", value=default_line_extractor)
                    mask_model = gr.Dropdown(choices=mask_extractor_list, label="Reference mask extractor",
                                             value=default_mask_extractor)

                with gr.Row():
                    sampler = gr.Dropdown(choices=sampler_list, value=default_sampler, label="Sampler")
                    scheduler = gr.Dropdown(choices=scheduler_list, value=default_scheduler, label="Noise scheduler")
                    preprocessor = gr.Dropdown(choices=["none", "extract", "invert", "invert-webui"],
                                               label="Sketch preprocessor", value="invert")

                with gr.Row():
                    deterministic = gr.Checkbox(label="Deterministic batch seed", value=False)
                    save_memory = gr.Checkbox(label="Save memory", value=True)

                # Hidden states for unused advanced controls
                fg_disentangle_scale = gr.State(1.0)
                start_step = gr.State(0.0)
                end_step = gr.State(1.0)
                no_start_step = gr.State(-0.05)
                no_end_step = gr.State(-0.05)
                return_inter = gr.State(False)
                accurate = gr.State(False)
                enc_scale = gr.State(1.0)
                middle_scale = gr.State(1.0)
                low_scale = gr.State(1.0)
                ctl_scale_1 = gr.State(1.0)
                ctl_scale_2 = gr.State(1.0)
                ctl_scale_3 = gr.State(1.0)
                ctl_scale_4 = gr.State(1.0)

        # ---- Event wiring (must be registered inside the Blocks context) ----
        add_prompt.click(fn=apppend_prompt,
                         inputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt],
                         outputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt])
        clear_prompt.click(fn=clear_prompts, outputs=[text_prompt])

        reuse_seed.click(fn=get_last_seed, outputs=[seed])
        random_seed.click(fn=reset_random_seed, outputs=[seed])

        # `.input` is used (not `.change`) for the three selectors below.
        extractor_model.input(fn=switch_extractor, inputs=[extractor_model])
        sd_model.input(fn=load_model, inputs=[sd_model])
        mask_model.input(fn=switch_mask_extractor, inputs=[mask_model])

        # Gradio passes these positionally to `inference`.
        # NOTE(review): the order presumably mirrors the backend signature —
        # do not reorder without checking it.
        ips = [style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
               bs, sketch_img, reference_img, background_img, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
               ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4, fg_strength, bg_strength, merge_scale,
               mask_scale, height, width, seed, save_memory, step, injection, autofit_size,
               remove_fg, rmbg, latent_inpaint, injection_control_scale, injection_fidelity, injection_start_step,
               crop, pad_scale, start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler,
               preprocessor, deterministic, text_prompt, target, anchor, control, target_scale, ts0, ts1, ts2, ts3,
               enhance, accurate, enc_scale, middle_scale, low_scale, strength]

        run_button.click(
            fn=inference,
            inputs=ips,
            outputs=[result_gallery],
        )

        vis_button.click(
            fn=visualize,
            inputs=[reference_img, text_prompt, control, ts0, ts1, ts2, ts3],
            outputs=[result_gallery],
        )

    block.launch(
        server_name=opt.server_name,
        share=opt.share,
        server_port=opt.server_port,
    )
if __name__ == '__main__':
    # Options are parsed before any backend call, so an argparse exit
    # (e.g. on --help or a bad flag) happens before the setup below runs.
    opt = app_options()

    # Select the default line and mask extractors.
    switch_extractor(default_line_extractor)
    switch_mask_extractor(default_mask_extractor)

    # Load the first listed model when at least one is available; otherwise
    # start the interface without loading one.
    if (available_models := get_available_models()):
        load_model(available_models[0])

    init_interface(opt)