# Hugging Face Spaces page residue: "Spaces: Running on Zero" (ZeroGPU badge).
| import os | |
| import random | |
| import traceback | |
| import gradio as gr | |
| import spaces | |
| import os.path as osp | |
| from huggingface_hub import hf_hub_download, list_repo_files | |
| HF_TOKEN = os.environ.get("HF_TOKEN", None) | |
| print(f"HF_TOKEN present: {HF_TOKEN is not None}") | |
| from omegaconf import OmegaConf | |
| from refnet.util import instantiate_from_config | |
| from preprocessor import create_model | |
| from .functool import * | |
# Mutable module-level state shared by the Gradio callbacks below.
model = None              # loaded diffusion model (set by load_model)
model_type = ""           # entry of model_types matching the loaded checkpoint
current_checkpoint = ""   # filename of the currently loaded checkpoint
global_seed = None        # seed used by the most recent inference() call
# Sketch mask extractor; kept on CPU until inference() moves it to the GPU.
smask_extractor = create_model("ISNet-sketch").cpu()
# Upper bound for randomly drawn seeds (fed to torch.manual_seed).
# Fix: the previous literal 429496729 was missing a digit — it matched neither
# the signed 32-bit max (2147483647) nor the unsigned one (4294967295).
# The name says INT32, so use the signed 32-bit maximum.
MAXM_INT32 = 2147483647

# HuggingFace model repository
HF_REPO_ID = "tellurion/ColorizeDiffusionXL"
MODEL_CACHE_DIR = "models"

# Checkpoint filename prefixes that select which inference config to load.
model_types = ["sdxl", "xlv2"]
| ''' | |
| Gradio UI functions | |
| ''' | |
def get_available_models():
    """Return the .safetensors checkpoint filenames available in the HF repo.

    On any Hub error the failure is logged and an empty list is returned so
    the UI dropdown still renders.
    """
    try:
        repo_files = list_repo_files(HF_REPO_ID, token=HF_TOKEN)
    except Exception as e:
        print(f"Failed to list models from {HF_REPO_ID}: {e}")
        return []
    return [name for name in repo_files if name.endswith(".safetensors")]
def download_model(filename):
    """Download a model from HuggingFace Hub if not already cached.

    Returns the local filesystem path to the checkpoint. Already-downloaded
    files under MODEL_CACHE_DIR are returned without touching the network.
    """
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
    local_path = osp.join(MODEL_CACHE_DIR, filename)
    if osp.exists(local_path):
        return local_path
    # Fix: the progress messages previously printed a literal "(unknown)"
    # instead of the checkpoint being fetched.
    print(f"Downloading {filename} from {HF_REPO_ID}...")
    gr.Info(f"Downloading {filename}...")
    path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        local_dir=MODEL_CACHE_DIR,
        token=HF_TOKEN,
    )
    print(f"Downloaded to {path}")
    return path
def switch_extractor(type):
    """Swap the global line extractor for the requested type; UI-notify on failure."""
    global line_extractor
    try:
        line_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Fix: print(traceback.print_exc()) printed the traceback followed by
        # a spurious "None" (print_exc writes to stderr and returns None).
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
def switch_mask_extractor(type):
    """Swap the global mask extractor for the requested type; UI-notify on failure."""
    global mask_extractor
    try:
        mask_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Fix: print(traceback.print_exc()) printed the traceback followed by
        # a spurious "None" (print_exc writes to stderr and returns None).
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
def apppend_prompt(target, anchor, control, scale, enhance, ts0, ts1, ts2, ts3, prompt):
    """Append one manipulation entry to the prompt text.

    Blank target/anchor/control fields become "none". Returns the widget
    reset values (empty strings, default sliders) plus the updated prompt.
    """
    # Normalize the three text fields: strip whitespace, fall back to "none".
    target, anchor, control = (field.strip() or "none" for field in (target, anchor, control))
    entry = (
        f"\n[target] {target}; [anchor] {anchor}; [control] {control}; [scale] {scale}; "
        f"[enhanced] {enhance}; [ts0] {ts0}; [ts1] {ts1}; [ts2] {ts2}; [ts3] {ts3}"
    )
    return "", "", "", 0.0, False, 0.5, 0.55, 0.65, 0.95, (prompt + entry).strip()
def clear_prompts():
    """Reset the prompt textbox to empty."""
    return ""
def load_model(ckpt_name):
    """Load checkpoint `ckpt_name`, rebuilding the architecture if its type changed.

    The config to instantiate is chosen by the checkpoint's filename prefix
    (see model_types). Weights are (re)loaded on every call; the network is
    only re-instantiated when the prefix differs from the current model_type.
    """
    global model, model_type, current_checkpoint
    config_root = "configs/inference"
    try:
        # Determine model type from filename prefix
        new_model_type = ""
        for key in model_types:
            if ckpt_name.startswith(key):
                new_model_type = key
                break
        if not new_model_type:
            # Fix: previously an unrecognized prefix fell through and tried to
            # load the nonexistent "configs/inference/.yaml"; fail clearly instead.
            raise ValueError(f"Unrecognized checkpoint prefix: {ckpt_name}")
        if model_type != new_model_type:
            if exists(model):
                # Fix: was "del model", which unbinds the global entirely — if
                # instantiation below failed, later exists(model) calls raised
                # NameError. Rebinding to None releases the old model equally
                # well and keeps the name defined.
                model = None
            config_path = osp.join(config_root, f"{new_model_type}.yaml")
            new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
            print(f"Switched to {new_model_type} model, loading weights from [{ckpt_name}]...")
            model = new_model
        # Download model from HF Hub
        local_path = download_model(ckpt_name)
        # Checkpoints with "eps" in the name use epsilon-parameterization,
        # everything else v-parameterization.
        model.parameterization = "eps" if ckpt_name.find("eps") > -1 else "v"
        model.init_from_ckpt(local_path, logging=True)
        model.switch_to_fp16()
        model_type = new_model_type
        current_checkpoint = ckpt_name
        print(f"Loaded model from [{ckpt_name}], model_type [{model_type}].")
        gr.Info("Loaded model successfully.")
    except Exception as e:
        print(f"Error type: {e}")
        # Fix: print(traceback.print_exc()) appended a spurious "None".
        traceback.print_exc()
        gr.Info("Failed in loading model.")
def get_last_seed():
    """Return the seed of the last generation, or -1 if none recorded.

    Note: mirrors the original truthiness check — both None and 0 map to -1.
    """
    return global_seed if global_seed else -1
def reset_random_seed():
    """Reset the seed widget to -1 (meaning: draw a random seed)."""
    return -1
def visualize(reference, text, *args):
    """Render attention heatmaps for `reference` under the parsed prompt text."""
    parsed = parse_prompts(text)
    return visualize_heatmaps(model, reference, parsed, *args)
def set_cas_scales(accurate, cas_args):
    """Build the cross-attention scale configuration from the UI sliders.

    In "accurate" mode the per-layer scales (cas_args[4:]) are passed through
    verbatim; otherwise the three level scales (cas_args[:3]) are modulated by
    the overall strength slider (cas_args[3]).
    """
    if accurate:
        return {
            "level_control": False,
            "scales": list(cas_args[4:]),
        }
    enc, mid, low, strength = cas_args[:4]
    return {
        "level_control": True,
        "scales": {
            "encoder": enc * strength,
            "middle": mid * strength,
            "low": low * strength,
        },
    }
def inference(
    style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
    bs, input_s, input_r, input_bg, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
    ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4,
    fg_strength, bg_strength, merge_scale, mask_scale, height, width, seed, low_vram, step,
    injection, autofit_size, remove_fg, rmbg, latent_inpaint, infid_x, infid_r, injstep, crop, pad_scale,
    start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler, preprocess,
    deterministic, text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance, accurate,
    *args
):
    """Main generation callback: colorize sketch `input_s` guided by reference
    `input_r` (and optional background `input_bg`) with the loaded model.

    The positional parameters mirror the Gradio widget order; the trailing
    *args are the cross-attention scale sliders consumed by set_cas_scales.
    Returns the post-processed result batch from postprocess().
    """
    global global_seed, line_extractor, mask_extractor
    # Honor a non-negative UI seed, otherwise draw one; it is stored globally
    # so get_last_seed() can report it afterwards.
    global_seed = seed if seed > -1 else random.randint(0, MAXM_INT32)
    torch.manual_seed(global_seed)
    # Auto-fit size based on sketch dimensions
    if autofit_size and exists(input_s):
        # assumes input_s is a PIL image ((w, h) from .size) — TODO confirm
        sketch_w, sketch_h = input_s.size
        aspect_ratio = sketch_w / sketch_h
        # Target roughly a 1024x1024 pixel budget at the sketch's aspect ratio.
        target_area = 1024 * 1024
        new_h = int((target_area / aspect_ratio) ** 0.5)
        new_w = int(new_h * aspect_ratio)
        # Round to the nearest multiple of 32, then clamp to [768, 1536].
        height = ((new_h + 16) // 32) * 32
        width = ((new_w + 16) // 32) * 32
        height = max(768, min(1536, height))
        width = max(768, min(1536, width))
        gr.Info(f"Auto-fitted size: {width}x{height}")
    smask, rmask, bgmask = None, None, None
    manipulation_params = parse_prompts(text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance)
    inputs = preprocessing_inputs(
        sketch = input_s,
        reference = input_r,
        background = input_bg,
        preprocess = preprocess,
        hook = injection,
        resolution = (height, width),
        extractor = line_extractor,
        pad_scale = pad_scale,
    )
    sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch = inputs
    cond = {"reference": reference, "sketch": sketch, "background": background}
    # Mask-guided generation is active when either enhancement toggle is on.
    mask_guided = bg_enhance or fg_enhance
    if exists(white_sketch) and exists(reference) and mask_guided:
        # Move the mask extractors to the GPU only for this section; both are
        # returned to the CPU before sampling below.
        mask_extractor.cuda()
        smask_extractor.cuda()
        smask = smask_extractor.proceed(
            x=white_sketch, pil_x=input_s, th=height, tw=width, threshold=mask_ss, crop=False
        )
        if exists(background) and remove_fg:
            # Remove the foreground from the background image: pixels whose
            # mask value reaches mask_ts are replaced with constant ones.
            bgmask = mask_extractor.proceed(x=background, pil_x=input_bg, threshold=mask_ts, dilate=True)
            filtered_background = torch.where(bgmask < mask_ts, background, torch.ones_like(background))
            cond.update({"background": filtered_background, "rmask": bgmask})
        else:
            rmask = mask_extractor.proceed(x=reference, pil_x=input_r, threshold=mask_ts, dilate=True)
            cond.update({"rmask": rmask})
            # Binarize the local copy at 0.5 for postprocess(); cond keeps the
            # soft mask stored just above.
            rmask = torch.where(rmask > 0.5, torch.ones_like(rmask), torch.zeros_like(rmask))
        cond.update({"smask": smask})
        smask_extractor.cpu()
        mask_extractor.cpu()
    scale_strength = set_cas_scales(accurate, args)
    # Per-level control scales, all modulated by the global ctl_scale slider.
    ctl_scales = [ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4]
    ctl_scales = [t * ctl_scale for t in ctl_scales]
    results = model.generate(
        # Colorization mode
        style_enhance = style_enhance,
        bg_enhance = bg_enhance,
        fg_enhance = fg_enhance,
        fg_disentangle_scale = fg_disentangle_scale,
        latent_inpaint = latent_inpaint,
        # Conditional inputs
        cond = cond,
        ctl_scale = ctl_scales,
        merge_scale = merge_scale,
        mask_scale = mask_scale,
        mask_thresh = mask_ts,
        mask_thresh_sketch = mask_ss,
        # Sampling settings
        bs = bs,
        gs = [gs_r, gs_s],
        sampler = sampler,
        scheduler = scheduler,
        start_step = start_step,
        end_step = end_step,
        no_start_step = no_start_step,
        no_end_step = no_end_step,
        strength = scale_strength,
        fg_strength = fg_strength,
        bg_strength = bg_strength,
        seed = global_seed,
        deterministic = deterministic,
        height = height,
        width = width,
        step = step,
        # Injection settings
        injection = injection,
        injection_cfg = infid_r,
        injection_control = infid_x,
        injection_start_step = injstep,
        hook_xr = inject_xr,
        hook_xs = inject_xs,
        # Additional settings
        low_vram = low_vram,
        return_intermediate = return_inter,
        manipulation_params = manipulation_params,
    )
    if rmbg:
        # Remove the background from the generated results via a sketch mask.
        # NOTE(review): mask_extractor is moved to the GPU here but the call is
        # on smask_extractor — confirm which extractor is intended.
        mask_extractor.cuda()
        mask = smask_extractor.proceed(x=-sketch, threshold=mask_ss).repeat(results.shape[0], 1, 1, 1)
        results = torch.where(mask >= mask_ss, results, torch.ones_like(results))
        mask_extractor.cpu()
    results = postprocess(results, sketch, reference, background, crop, original_shape,
                          mask_guided, smask, rmask, bgmask, mask_ts, mask_ss)
    torch.cuda.empty_cache()
    gr.Info("Generation completed.")
    return results