import gradio as gr
import argparse

from refnet.sampling import get_noise_schedulers, get_sampler_list
from functools import partial
from backend import *


# URLs for the related papers, the pretrained weights and the code repository.
links = {
    "base": "https://arxiv.org/abs/2401.01456",
    "v1": "https://openaccess.thecvf.com/content/WACV2025/html/Yan_ColorizeDiffusion_Improving_Reference-Based_Sketch_Colorization_with_Latent_Diffusion_Model_WACV_2025_paper.html",
    "v1.5": "https://arxiv.org/abs/2502.19937v1",
    "v2": "https://arxiv.org/abs/2504.06895",
    "xl": "https://arxiv.org/abs/2601.04883",
    "weights": "https://huggingface.co/tellurion/ColorizeDiffusionXL/tree/main",
    "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
}


def app_options():
    """Parse the command-line options of the web application.

    Returns:
        The ``argparse.Namespace`` with ``server_name`` (default
        ``"0.0.0.0"``), ``server_port`` (default ``7860``), ``share`` and
        ``enable_text_manipulation`` (both flags, off by default).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--server_name", '-addr', type=str, default="0.0.0.0")
    parser.add_argument("--server_port", '-port', type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--enable_text_manipulation", '-manipulate', action="store_true")
    return parser.parse_args()


def init_interface(opt, *args, **kwargs) -> None:
    """Build the Gradio interface, wire its events and launch the server.

    Args:
        opt: Parsed options; this function reads ``enable_text_manipulation``
            (visibility of the prompt-manipulation widgets) and
            ``server_name`` / ``server_port`` / ``share`` (launch settings).
        *args: Accepted for call compatibility; not used.
        **kwargs: Accepted for call compatibility; not used.
    """
    sampler_list = get_sampler_list()
    scheduler_list = get_noise_schedulers()

    # Shared configuration for the three image inputs.
    img_block = partial(
        gr.Image,
        type="pil",
        height=300,
        interactive=True,
        show_label=True,
        format="png",
    )

    with gr.Blocks(
        title="Colorize Diffusion",
        css_paths="backend/style.css",
        theme=gr.themes.Ocean(),
        elem_id="main-interface",
        analytics_enabled=False,
        fill_width=True,
    ) as block:
        with gr.Row(elem_id="header-row", equal_height=True, variant="panel"):
            gr.Markdown(f"""
🎨Colorize Diffusion
""")

        with gr.Row(elem_id="content-row", equal_height=False, variant="panel"):
            # ---- Left column: inputs and generation settings ----
            with gr.Column():
                # Text-manipulation widgets, shown only when enabled via CLI.
                with gr.Row(visible=opt.enable_text_manipulation):
                    target = gr.Textbox(label="Target prompt", value="", scale=2)
                    anchor = gr.Textbox(label="Anchor prompt", value="", scale=2)
                    control = gr.Textbox(label="Control prompt", value="", scale=2)

                with gr.Row(visible=opt.enable_text_manipulation):
                    target_scale = gr.Slider(label="Target scale", value=0.0, minimum=0, maximum=15.0, step=0.25, scale=2)
                    ts0 = gr.Slider(label="Threshold 0", value=0.5, minimum=0, maximum=1.0, step=0.01)
                    ts1 = gr.Slider(label="Threshold 1", value=0.55, minimum=0, maximum=1.0, step=0.01)
                    ts2 = gr.Slider(label="Threshold 2", value=0.65, minimum=0, maximum=1.0, step=0.01)
                    ts3 = gr.Slider(label="Threshold 3", value=0.95, minimum=0, maximum=1.0, step=0.01)

                with gr.Row(visible=opt.enable_text_manipulation):
                    enhance = gr.Checkbox(label="Enhance manipulation", value=False)
                    add_prompt = gr.Button(value="Add")
                    clear_prompt = gr.Button(value="Clear")
                    vis_button = gr.Button(value="Visualize")

                # NOTE(review): placed at column level (it carries its own
                # `visible` flag) — confirm against the intended layout.
                text_prompt = gr.Textbox(label="Final prompt", value="", lines=3, visible=opt.enable_text_manipulation)

                with gr.Row():
                    sketch_img = img_block(label="Sketch")
                    reference_img = img_block(label="Reference")
                    background_img = img_block(label="Background")

                # Fixed values passed to `inference` without a visible control.
                style_enhance = gr.State(False)
                fg_enhance = gr.State(False)

                with gr.Row():
                    bg_enhance = gr.Checkbox(label="Low-level injection", value=False)
                    injection = gr.Checkbox(label="Attention injection", value=False)
                    autofit_size = gr.Checkbox(label="Autofit size", value=False)

                with gr.Row():
                    gs_r = gr.Slider(label="Reference guidance scale", minimum=1, maximum=15.0, value=4.0, step=0.5)
                    strength = gr.Slider(label="Reference strength", minimum=0, maximum=1, value=1, step=0.05)
                    fg_strength = gr.Slider(label="Foreground strength", minimum=0, maximum=1, value=1, step=0.05)
                    bg_strength = gr.Slider(label="Background strength", minimum=0, maximum=1, value=1, step=0.05)

                with gr.Row():
                    gs_s = gr.Slider(label="Sketch guidance scale", minimum=1, maximum=5.0, value=1.0, step=0.1)
                    ctl_scale = gr.Slider(label="Sketch strength", minimum=0, maximum=3, value=1, step=0.05)
                    mask_scale = gr.Slider(label="Background factor", minimum=0, maximum=2, value=1, step=0.05)
                    merge_scale = gr.Slider(label="Merging scale", minimum=0, maximum=1, value=0, step=0.05)

                with gr.Row():
                    bs = gr.Slider(label="Batch size", minimum=1, maximum=4, value=1, step=1, scale=1)
                    width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=32, scale=2)

                with gr.Row():
                    step = gr.Slider(label="Step", minimum=1, maximum=100, value=20, step=1, scale=1)
                    height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=32, scale=2)

                # NOTE(review): placed at column level (no `scale`, unlike the
                # sliders in the rows above) — confirm against the intended layout.
                seed = gr.Slider(label="Seed", minimum=-1, maximum=MAXM_INT32, step=1, value=-1)

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        crop = gr.Checkbox(label="Crop result", value=False, scale=1)
                        remove_fg = gr.Checkbox(label="Remove foreground in background input", value=False, scale=2)
                        rmbg = gr.Checkbox(label="Remove background in result", value=False, scale=2)
                        latent_inpaint = gr.Checkbox(label="Latent copy BG input", value=False, scale=2)

                    with gr.Row():
                        injection_control_scale = gr.Slider(label="Injection fidelity (sketch)", minimum=0.0, maximum=2.0, value=0, step=0.05)
                        injection_fidelity = gr.Slider(label="Injection fidelity (reference)", minimum=0.0, maximum=1.0, value=0.5, step=0.05)
                        injection_start_step = gr.Slider(label="Injection start step", minimum=0.0, maximum=1.0, value=0, step=0.05)

                # NOTE(review): whether this row belongs inside the accordion
                # could not be determined — confirm against the intended layout.
                with gr.Row():
                    reuse_seed = gr.Button(value="Reuse Seed")
                    random_seed = gr.Button(value="Random Seed")

            # ---- Right column: output, model selection and samplers ----
            with gr.Column():
                result_gallery = gr.Gallery(
                    label='Output',
                    show_label=False,
                    elem_id="gallery",
                    preview=True,
                    type="pil",
                    format="png",
                )
                run_button = gr.Button("Generate", variant="primary", size="lg")

                with gr.Row():
                    mask_ts = gr.Slider(label="Reference mask threshold", minimum=0., maximum=1., value=0.5, step=0.01)
                    mask_ss = gr.Slider(label="Sketch mask threshold", minimum=0., maximum=1., value=0.05, step=0.01)
                    pad_scale = gr.Slider(label="Reference padding scale", minimum=1, maximum=2, value=1, step=0.05)

                with gr.Row():
                    available_models = get_available_models()
                    sd_model = gr.Dropdown(
                        choices=available_models,
                        label="Models",
                        value=available_models[0] if available_models else None,
                    )
                    extractor_model = gr.Dropdown(choices=line_extractor_list, label="Line extractor", value=default_line_extractor)
                    mask_model = gr.Dropdown(choices=mask_extractor_list, label="Reference mask extractor", value=default_mask_extractor)

                with gr.Row():
                    sampler = gr.Dropdown(choices=sampler_list, value="DPM++ 3M SDE", label="Sampler")
                    # Guard the empty case the same way the model dropdown
                    # does, instead of raising IndexError while building the UI.
                    scheduler = gr.Dropdown(
                        choices=scheduler_list,
                        value=scheduler_list[0] if scheduler_list else None,
                        label="Noise scheduler",
                    )
                    preprocessor = gr.Dropdown(
                        choices=["none", "extract", "invert", "invert-webui"],
                        label="Sketch preprocessor",
                        value="invert",
                    )

                with gr.Row():
                    deterministic = gr.Checkbox(label="Deterministic batch seed", value=False)
                    save_memory = gr.Checkbox(label="Save memory", value=True)

                # Hidden states for unused advanced controls
                fg_disentangle_scale = gr.State(1.0)
                start_step = gr.State(0.0)
                end_step = gr.State(1.0)
                no_start_step = gr.State(-0.05)
                no_end_step = gr.State(-0.05)
                return_inter = gr.State(False)
                accurate = gr.State(False)
                enc_scale = gr.State(1.0)
                middle_scale = gr.State(1.0)
                low_scale = gr.State(1.0)
                ctl_scale_1 = gr.State(1.0)
                ctl_scale_2 = gr.State(1.0)
                ctl_scale_3 = gr.State(1.0)
                ctl_scale_4 = gr.State(1.0)

        # ---- Event wiring ----
        add_prompt.click(
            fn=apppend_prompt,
            inputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt],
            outputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt],
        )
        clear_prompt.click(fn=clear_prompts, outputs=[text_prompt])

        reuse_seed.click(fn=get_last_seed, outputs=[seed])
        random_seed.click(fn=reset_random_seed, outputs=[seed])

        extractor_model.input(fn=switch_extractor, inputs=[extractor_model])
        sd_model.input(fn=load_model, inputs=[sd_model])
        mask_model.input(fn=switch_mask_extractor, inputs=[mask_model])

        # The order of this list is the positional argument order handed to
        # `inference`; do not reorder.
        ips = [
            style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale, bs,
            sketch_img, reference_img, background_img, mask_ts, mask_ss,
            gs_r, gs_s, ctl_scale, ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4,
            fg_strength, bg_strength, merge_scale, mask_scale, height, width, seed,
            save_memory, step, injection, autofit_size, remove_fg, rmbg, latent_inpaint,
            injection_control_scale, injection_fidelity, injection_start_step,
            crop, pad_scale, start_step, end_step, no_start_step, no_end_step,
            return_inter, sampler, scheduler, preprocessor, deterministic,
            text_prompt, target, anchor, control, target_scale,
            ts0, ts1, ts2, ts3, enhance, accurate,
            enc_scale, middle_scale, low_scale, strength,
        ]

        run_button.click(
            fn=inference,
            inputs=ips,
            outputs=[result_gallery],
        )

        vis_button.click(
            fn=visualize,
            inputs=[reference_img, text_prompt, control, ts0, ts1, ts2, ts3],
            outputs=[result_gallery],
        )

    block.launch(
        server_name=opt.server_name,
        share=opt.share,
        server_port=opt.server_port,
    )


if __name__ == '__main__':
    opt = app_options()

    # Load the default extractors (and the first available model, if any)
    # before the interface is built.
    switch_extractor(default_line_extractor)
    switch_mask_extractor(default_mask_extractor)

    available_models = get_available_models()
    if available_models:
        load_model(available_models[0])

    init_interface(opt)