"""Gradio UI glue for ColorizeDiffusionXL.

Wires a Gradio front end to a reference-based sketch colorization model
hosted on HuggingFace Hub: model discovery/download, checkpoint loading,
prompt assembly, and the main `inference` entry point.
"""
import os
import random
import traceback
import gradio as gr
import spaces
import os.path as osp
from huggingface_hub import hf_hub_download, list_repo_files

HF_TOKEN = os.environ.get("HF_TOKEN", None)
print(f"HF_TOKEN present: {HF_TOKEN is not None}")

from omegaconf import OmegaConf
from refnet.util import instantiate_from_config
from preprocessor import create_model
from .functool import *

# Mutable module-level state shared across Gradio callbacks.
model = None                # currently loaded diffusion model (set by load_model)
model_type = ""             # entry of `model_types` matching the loaded checkpoint
current_checkpoint = ""     # filename of the currently loaded checkpoint
global_seed = None          # last seed actually used by `inference`
smask_extractor = create_model("ISNet-sketch").cpu()
# NOTE(review): `line_extractor` and `mask_extractor` are only bound by the
# switch_* callbacks below but are read by `inference` — presumably the UI
# invokes the switchers at startup; confirm against the app wiring.

# Inclusive upper bound for randomly drawn seeds (2**32 - 1).
# FIX: the original value 429496729 had a dropped trailing digit.
MAXM_INT32 = 4294967295

# HuggingFace model repository
HF_REPO_ID = "tellurion/ColorizeDiffusionXL"
MODEL_CACHE_DIR = "models"
model_types = ["sdxl", "xlv2"]


'''
    Gradio UI functions
'''
def get_available_models():
    """Fetch available .safetensors files from HuggingFace Hub."""
    try:
        files = list_repo_files(HF_REPO_ID, token=HF_TOKEN)
        return [f for f in files if f.endswith(".safetensors")]
    except Exception as e:
        # Best-effort: an unreachable hub just yields an empty model list.
        print(f"Failed to list models from {HF_REPO_ID}: {e}")
        return []


def download_model(filename):
    """Download a checkpoint from HuggingFace Hub if not already cached.

    Returns the local filesystem path of the checkpoint.
    """
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
    local_path = osp.join(MODEL_CACHE_DIR, filename)
    if osp.exists(local_path):
        return local_path

    # FIX: original f-strings printed a literal placeholder instead of
    # interpolating the checkpoint name.
    print(f"Downloading {filename} from {HF_REPO_ID}...")
    gr.Info(f"Downloading {filename}...")
    path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        local_dir=MODEL_CACHE_DIR,
        token=HF_TOKEN,
    )
    print(f"Downloaded to {path}")
    return path


def switch_extractor(type):
    """Rebind the global line extractor to a freshly created `type` model.

    NOTE(review): the parameter shadows the builtin `type`; kept as-is so
    existing keyword callers keep working.
    """
    global line_extractor
    try:
        line_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # FIX: print_exc() already writes the traceback and returns None;
        # wrapping it in print() emitted a spurious "None" line.
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")


def switch_mask_extractor(type):
    """Rebind the global mask extractor to a freshly created `type` model."""
    global mask_extractor
    try:
        mask_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # FIX: same spurious print(None) as in switch_extractor.
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")


def apppend_prompt(target, anchor, control, scale, enhance, ts0, ts1, ts2, ts3, prompt):
    """Append one manipulation instruction to the running prompt string.

    Empty target/anchor/control fields default to "none". Returns the reset
    values for the input widgets followed by the updated prompt text.
    NOTE(review): the misspelled name ("apppend") is kept — it is the public
    callback name the UI binds to.
    """
    target = target.strip() or "none"
    anchor = anchor.strip() or "none"
    control = control.strip() or "none"
    new_p = (
        f"\n[target] {target}; [anchor] {anchor}; [control] {control}; [scale] {str(scale)}; "
        f"[enhanced] {str(enhance)}; [ts0] {str(ts0)}; [ts1] {str(ts1)}; [ts2] {str(ts2)}; [ts3] {str(ts3)}"
    )
    return "", "", "", 0.0, False, 0.5, 0.55, 0.65, 0.95, (prompt + new_p).strip()


def clear_prompts():
    """Reset the prompt textbox."""
    return ""


def load_model(ckpt_name):
    """Load checkpoint `ckpt_name`, rebuilding the network if its type changed."""
    global model, model_type, current_checkpoint
    config_root = "configs/inference"
    try:
        # Determine model type from filename prefix
        new_model_type = ""
        for key in model_types:
            if ckpt_name.startswith(key):
                new_model_type = key
                break

        # Rebuild the network only when the architecture family changes;
        # same-type checkpoint switches just reload weights below.
        if model_type != new_model_type:
            if exists(model):
                del model
            config_path = osp.join(config_root, f"{new_model_type}.yaml")
            new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
            print(f"Switched to {new_model_type} model, loading weights from [{ckpt_name}]...")
            model = new_model

        # Download model from HF Hub (no-op when already cached locally).
        local_path = download_model(ckpt_name)

        # Checkpoints with "eps" in the name use epsilon prediction, else v.
        model.parameterization = "eps" if ckpt_name.find("eps") > -1 else "v"
        model.init_from_ckpt(local_path, logging=True)
        model.switch_to_fp16()
        model_type = new_model_type
        current_checkpoint = ckpt_name
        print(f"Loaded model from [{ckpt_name}], model_type [{model_type}].")
        gr.Info("Loaded model successfully.")
    except Exception as e:
        print(f"Error type: {e}")
        # FIX: same spurious print(None) as in switch_extractor.
        traceback.print_exc()
        gr.Info("Failed in loading model.")


def get_last_seed():
    """Return the seed used by the last generation, or -1 if none yet."""
    return global_seed or -1


def reset_random_seed():
    """Return the sentinel that makes `inference` draw a fresh random seed."""
    return -1


def visualize(reference, text, *args):
    """Render attention heatmaps for the current model and parsed prompt."""
    return visualize_heatmaps(model, reference, parse_prompts(text), *args)


def set_cas_scales(accurate, cas_args):
    """Build the cross-attention scale spec consumed by `model.generate`.

    When `accurate` is False the first four entries of `cas_args`
    (encoder, middle, low, strength) are combined into per-level scales;
    otherwise the remaining entries are passed through verbatim.
    """
    enc_scale, middle_scale, low_scale, strength = cas_args[:4]
    if not accurate:
        scale_strength = {
            "level_control": True,
            "scales": {
                "encoder": enc_scale * strength,
                "middle": middle_scale * strength,
                "low": low_scale * strength,
            },
        }
    else:
        scale_strength = {
            "level_control": False,
            "scales": list(cas_args[4:]),
        }
    return scale_strength


@spaces.GPU(duration=120)
@torch.no_grad()
def inference(
        style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
        bs, input_s, input_r, input_bg, mask_ts, mask_ss, gs_r, gs_s,
        ctl_scale, ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4,
        fg_strength, bg_strength, merge_scale, mask_scale,
        height, width, seed, low_vram, step, injection, autofit_size,
        remove_fg, rmbg, latent_inpaint, infid_x, infid_r, injstep,
        crop, pad_scale, start_step, end_step, no_start_step, no_end_step,
        return_inter, sampler, scheduler, preprocess, deterministic,
        text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3,
        enhance, accurate, *args,
):
    """Run reference-based sketch colorization and return postprocessed images.

    `seed == -1` draws a fresh random seed; trailing `*args` carry the
    cross-attention scale widgets consumed by `set_cas_scales`.
    """
    global global_seed, line_extractor, mask_extractor
    global_seed = seed if seed > -1 else random.randint(0, MAXM_INT32)
    torch.manual_seed(global_seed)

    # Auto-fit size based on sketch dimensions
    if autofit_size and exists(input_s):
        sketch_w, sketch_h = input_s.size
        aspect_ratio = sketch_w / sketch_h
        target_area = 1024 * 1024
        new_h = int((target_area / aspect_ratio) ** 0.5)
        new_w = int(new_h * aspect_ratio)
        # Round to multiples of 32, then clamp to the supported range.
        height = ((new_h + 16) // 32) * 32
        width = ((new_w + 16) // 32) * 32
        height = max(768, min(1536, height))
        width = max(768, min(1536, width))
        gr.Info(f"Auto-fitted size: {width}x{height}")

    smask, rmask, bgmask = None, None, None
    manipulation_params = parse_prompts(
        text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance
    )
    inputs = preprocessing_inputs(
        sketch = input_s,
        reference = input_r,
        background = input_bg,
        preprocess = preprocess,
        hook = injection,
        resolution = (height, width),
        extractor = line_extractor,
        pad_scale = pad_scale,
    )
    sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch = inputs
    cond = {"reference": reference, "sketch": sketch, "background": background}

    mask_guided = bg_enhance or fg_enhance
    if exists(white_sketch) and exists(reference) and mask_guided:
        mask_extractor.cuda()
        smask_extractor.cuda()
        smask = smask_extractor.proceed(
            x=white_sketch, pil_x=input_s, th=height, tw=width, threshold=mask_ss, crop=False
        )
        if exists(background) and remove_fg:
            # Strip detected foreground from the background reference.
            bgmask = mask_extractor.proceed(x=background, pil_x=input_bg, threshold=mask_ts, dilate=True)
            filtered_background = torch.where(bgmask < mask_ts, background, torch.ones_like(background))
            cond.update({"background": filtered_background, "rmask": bgmask})
        else:
            rmask = mask_extractor.proceed(x=reference, pil_x=input_r, threshold=mask_ts, dilate=True)
            cond.update({"rmask": rmask})
            # Binarized copy kept for postprocess (cond holds the soft mask).
            rmask = torch.where(rmask > 0.5, torch.ones_like(rmask), torch.zeros_like(rmask))
        cond.update({"smask": smask})
        smask_extractor.cpu()
        mask_extractor.cpu()

    scale_strength = set_cas_scales(accurate, args)

    # Per-level ControlNet scales, modulated by the global control scale.
    ctl_scales = [ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4]
    ctl_scales = [t * ctl_scale for t in ctl_scales]

    results = model.generate(
        # Colorization mode
        style_enhance = style_enhance,
        bg_enhance = bg_enhance,
        fg_enhance = fg_enhance,
        fg_disentangle_scale = fg_disentangle_scale,
        latent_inpaint = latent_inpaint,

        # Conditional inputs
        cond = cond,
        ctl_scale = ctl_scales,
        merge_scale = merge_scale,
        mask_scale = mask_scale,
        mask_thresh = mask_ts,
        mask_thresh_sketch = mask_ss,

        # Sampling settings
        bs = bs,
        gs = [gs_r, gs_s],
        sampler = sampler,
        scheduler = scheduler,
        start_step = start_step,
        end_step = end_step,
        no_start_step = no_start_step,
        no_end_step = no_end_step,
        strength = scale_strength,
        fg_strength = fg_strength,
        bg_strength = bg_strength,
        seed = global_seed,
        deterministic = deterministic,
        height = height,
        width = width,
        step = step,

        # Injection settings
        injection = injection,
        injection_cfg = infid_r,
        injection_control = infid_x,
        injection_start_step = injstep,
        hook_xr = inject_xr,
        hook_xs = inject_xs,

        # Additional settings
        low_vram = low_vram,
        return_intermediate = return_inter,
        manipulation_params = manipulation_params,
    )

    if rmbg:
        # NOTE(review): this moves `mask_extractor` to CUDA but then runs
        # `smask_extractor` — confirm which extractor is intended here.
        mask_extractor.cuda()
        mask = smask_extractor.proceed(x=-sketch, threshold=mask_ss).repeat(results.shape[0], 1, 1, 1)
        results = torch.where(mask >= mask_ss, results, torch.ones_like(results))
        mask_extractor.cpu()

    results = postprocess(
        results, sketch, reference, background, crop, original_shape,
        mask_guided, smask, rmask, bgmask, mask_ts, mask_ss,
    )
    torch.cuda.empty_cache()
    gr.Info("Generation completed.")
    return results