# tellurion's picture
# Clean up dead code and add startup model loading
# 1928ea4
import os
import random
import traceback
import gradio as gr
import spaces
import os.path as osp
from huggingface_hub import hf_hub_download, list_repo_files
HF_TOKEN = os.environ.get("HF_TOKEN", None)
print(f"HF_TOKEN present: {HF_TOKEN is not None}")
from omegaconf import OmegaConf
from refnet.util import instantiate_from_config
from preprocessor import create_model
from .functool import *
# --- Module-level state -------------------------------------------------
model = None              # currently loaded diffusion model (set by load_model)
model_type = ""           # entry of `model_types` matching the loaded checkpoint
current_checkpoint = ""   # filename of the checkpoint currently loaded
global_seed = None        # last seed used by inference(); read by get_last_seed()

# Sketch-mask extractor is created eagerly and parked on CPU until needed.
smask_extractor = create_model("ISNet-sketch").cpu()

# Exclusive-ish upper bound for random seed generation.
# Fixed: previous value 429496729 dropped a digit and matched neither the
# int32 maximum (2147483647) nor the uint32 maximum (4294967295).
MAXM_INT32 = 2 ** 31 - 1

# HuggingFace model repository
HF_REPO_ID = "tellurion/ColorizeDiffusionXL"
MODEL_CACHE_DIR = "models"

# Checkpoint filenames are expected to start with one of these prefixes,
# which selects the matching configs/inference/<prefix>.yaml.
model_types = ["sdxl", "xlv2"]
'''
Gradio UI functions
'''
def get_available_models():
    """Return the names of all .safetensors checkpoints in the HF repo.

    On any Hub error (network, auth, missing repo) a warning is printed
    and an empty list is returned so the UI still comes up.
    """
    try:
        repo_files = list_repo_files(HF_REPO_ID, token=HF_TOKEN)
        return [name for name in repo_files if name.endswith(".safetensors")]
    except Exception as e:
        print(f"Failed to list models from {HF_REPO_ID}: {e}")
        return []
def download_model(filename):
    """Download *filename* from the HF Hub into MODEL_CACHE_DIR.

    Returns the local path immediately when the file is already cached;
    otherwise downloads it (authenticated via HF_TOKEN) and returns the
    path reported by hf_hub_download.
    """
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
    local_path = osp.join(MODEL_CACHE_DIR, filename)
    if osp.exists(local_path):
        return local_path
    # Fixed: both messages previously printed the literal "(unknown)"
    # instead of interpolating the checkpoint filename.
    print(f"Downloading {filename} from {HF_REPO_ID}...")
    gr.Info(f"Downloading {filename}...")
    path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        local_dir=MODEL_CACHE_DIR,
        token=HF_TOKEN,
    )
    print(f"Downloaded to {path}")
    return path
def switch_extractor(type):
    """Replace the global line extractor with a freshly created model.

    Parameters
    ----------
    type : str
        Extractor identifier forwarded to ``create_model``.
    """
    global line_extractor
    try:
        line_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Fixed: traceback.print_exc() writes to stderr itself and returns
        # None, so wrapping it in print() only emitted a spurious "None".
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
def switch_mask_extractor(type):
    """Replace the global mask extractor with a freshly created model.

    Parameters
    ----------
    type : str
        Extractor identifier forwarded to ``create_model``.
    """
    global mask_extractor
    try:
        mask_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Fixed: traceback.print_exc() writes to stderr itself and returns
        # None, so wrapping it in print() only emitted a spurious "None".
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
def apppend_prompt(target, anchor, control, scale, enhance, ts0, ts1, ts2, ts3, prompt):
    """Append one manipulation entry to the prompt log and reset the inputs.

    Blank target/anchor/control fields are recorded as "none". Returns the
    cleared UI widget values followed by the updated prompt text.
    """
    target, anchor, control = (field.strip() or "none" for field in (target, anchor, control))
    entry = (
        f"\n[target] {target}; [anchor] {anchor}; [control] {control}; [scale] {str(scale)}; "
        f"[enhanced] {str(enhance)}; [ts0] {str(ts0)}; [ts1] {str(ts1)}; [ts2] {str(ts2)}; [ts3] {str(ts3)}"
    )
    # Leading-newline entry; strip() drops it when the log was empty.
    return "", "", "", 0.0, False, 0.5, 0.55, 0.65, 0.95, (prompt + entry).strip()
def clear_prompts():
    """Reset the prompt textbox to an empty string."""
    return ""
def load_model(ckpt_name):
    """Load checkpoint *ckpt_name*, switching architectures when needed.

    The model type ("sdxl" / "xlv2") is inferred from the filename prefix;
    the matching config under configs/inference/ is only re-instantiated
    when the type actually changes. Updates the module globals `model`,
    `model_type` and `current_checkpoint`. Errors are reported to the
    console and the Gradio UI rather than raised.
    """
    global model, model_type, current_checkpoint
    config_root = "configs/inference"
    try:
        # Determine model type from filename prefix
        new_model_type = ""
        for key in model_types:
            if ckpt_name.startswith(key):
                new_model_type = key
                break
        if not new_model_type:
            # Fail early with a clear message instead of attempting to
            # load the nonexistent config "configs/inference/.yaml".
            raise ValueError(f"Cannot infer model type from checkpoint name [{ckpt_name}]")

        if model_type != new_model_type:
            if exists(model):
                del model
            config_path = osp.join(config_root, f"{new_model_type}.yaml")
            new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
            print(f"Switched to {new_model_type} model, loading weights from [{ckpt_name}]...")
            model = new_model

        # Download model from HF Hub
        local_path = download_model(ckpt_name)
        # "eps" checkpoints encode the parameterization in the filename;
        # everything else uses v-prediction.
        model.parameterization = "eps" if ckpt_name.find("eps") > -1 else "v"
        model.init_from_ckpt(local_path, logging=True)
        model.switch_to_fp16()
        model_type = new_model_type
        current_checkpoint = ckpt_name
        print(f"Loaded model from [{ckpt_name}], model_type [{model_type}].")
        gr.Info("Loaded model successfully.")
    except Exception as e:
        print(f"Error type: {e}")
        # Fixed: traceback.print_exc() writes to stderr itself and returns
        # None, so wrapping it in print() only emitted a spurious "None".
        traceback.print_exc()
        gr.Info("Failed in loading model.")
def get_last_seed():
    """Return the seed of the most recent generation, or -1 if none yet."""
    if global_seed:
        return global_seed
    return -1
def reset_random_seed():
    """Value used to restore the seed widget to "random" (-1)."""
    return -1
def visualize(reference, text, *args):
    """Render attention heatmaps for *reference* under the parsed prompt text."""
    manipulation = parse_prompts(text)
    return visualize_heatmaps(model, reference, manipulation, *args)
def set_cas_scales(accurate, cas_args):
    """Build the cross-attention scale spec consumed by model.generate.

    In the default ("level") mode the first three values of *cas_args* are
    per-level scales multiplied by a global strength (the fourth value).
    In accurate mode the remaining values are passed through verbatim.
    """
    enc, mid, low, strength = cas_args[:4]
    if accurate:
        return {
            "level_control": False,
            "scales": list(cas_args[4:]),
        }
    return {
        "level_control": True,
        "scales": {
            "encoder": enc * strength,
            "middle": mid * strength,
            "low": low * strength,
        },
    }
@spaces.GPU(duration=120)
@torch.no_grad()
def inference(
    style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
    bs, input_s, input_r, input_bg, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
    ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4,
    fg_strength, bg_strength, merge_scale, mask_scale, height, width, seed, low_vram, step,
    injection, autofit_size, remove_fg, rmbg, latent_inpaint, infid_x, infid_r, injstep, crop, pad_scale,
    start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler, preprocess,
    deterministic, text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance, accurate,
    *args
):
    """Main generation entry point wired directly to the Gradio UI.

    Seeds the RNG, preprocesses the sketch/reference/background inputs,
    optionally extracts fore/background masks, runs model.generate, and
    postprocesses the results. The trailing *args carry the cross-attention
    scale controls consumed by set_cas_scales. Returns the postprocessed
    image batch. Runs under @torch.no_grad() on a ZeroGPU worker.

    NOTE(review): `line_extractor` and `mask_extractor` are globals set by
    switch_extractor / switch_mask_extractor — presumably initialized at
    startup elsewhere; verify before first call.
    """
    global global_seed, line_extractor, mask_extractor
    # seed == -1 means "random": draw a fresh seed and remember it so the
    # UI can recall it via get_last_seed().
    global_seed = seed if seed > -1 else random.randint(0, MAXM_INT32)
    torch.manual_seed(global_seed)

    # Auto-fit size based on sketch dimensions
    if autofit_size and exists(input_s):
        sketch_w, sketch_h = input_s.size
        aspect_ratio = sketch_w / sketch_h
        # Target ~1 MP at the sketch's aspect ratio, rounded to a multiple
        # of 32 (the +16 makes the division round to nearest).
        target_area = 1024 * 1024
        new_h = int((target_area / aspect_ratio) ** 0.5)
        new_w = int(new_h * aspect_ratio)
        height = ((new_h + 16) // 32) * 32
        width = ((new_w + 16) // 32) * 32
        # Clamp both sides to the supported [768, 1536] range.
        height = max(768, min(1536, height))
        width = max(768, min(1536, width))
        gr.Info(f"Auto-fitted size: {width}x{height}")

    # Masks: sketch mask, reference mask, background mask (filled below
    # only when mask-guided enhancement is active).
    smask, rmask, bgmask = None, None, None
    manipulation_params = parse_prompts(text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance)
    inputs = preprocessing_inputs(
        sketch = input_s,
        reference = input_r,
        background = input_bg,
        preprocess = preprocess,
        hook = injection,
        resolution = (height, width),
        extractor = line_extractor,
        pad_scale = pad_scale,
    )
    sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch = inputs
    cond = {"reference": reference, "sketch": sketch, "background": background}

    # Mask-guided generation: extract fore/background masks on GPU, then
    # move the extractors back to CPU to free VRAM for the main model.
    mask_guided = bg_enhance or fg_enhance
    if exists(white_sketch) and exists(reference) and mask_guided:
        mask_extractor.cuda()
        smask_extractor.cuda()
        smask = smask_extractor.proceed(
            x=white_sketch, pil_x=input_s, th=height, tw=width, threshold=mask_ss, crop=False
        )
        if exists(background) and remove_fg:
            # Mask out the foreground of the separate background image.
            bgmask = mask_extractor.proceed(x=background, pil_x=input_bg, threshold=mask_ts, dilate=True)
            filtered_background = torch.where(bgmask < mask_ts, background, torch.ones_like(background))
            cond.update({"background": filtered_background, "rmask": bgmask})
        else:
            # Otherwise derive the removal mask from the reference itself,
            # then binarize it at 0.5.
            rmask = mask_extractor.proceed(x=reference, pil_x=input_r, threshold=mask_ts, dilate=True)
            cond.update({"rmask": rmask})
            rmask = torch.where(rmask > 0.5, torch.ones_like(rmask), torch.zeros_like(rmask))
        cond.update({"smask": smask})
        smask_extractor.cpu()
        mask_extractor.cpu()

    scale_strength = set_cas_scales(accurate, args)
    # Per-block control scales, all modulated by the global ctl_scale.
    ctl_scales = [ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4]
    ctl_scales = [t * ctl_scale for t in ctl_scales]
    results = model.generate(
        # Colorization mode
        style_enhance = style_enhance,
        bg_enhance = bg_enhance,
        fg_enhance = fg_enhance,
        fg_disentangle_scale = fg_disentangle_scale,
        latent_inpaint = latent_inpaint,
        # Conditional inputs
        cond = cond,
        ctl_scale = ctl_scales,
        merge_scale = merge_scale,
        mask_scale = mask_scale,
        mask_thresh = mask_ts,
        mask_thresh_sketch = mask_ss,
        # Sampling settings
        bs = bs,
        gs = [gs_r, gs_s],
        sampler = sampler,
        scheduler = scheduler,
        start_step = start_step,
        end_step = end_step,
        no_start_step = no_start_step,
        no_end_step = no_end_step,
        strength = scale_strength,
        fg_strength = fg_strength,
        bg_strength = bg_strength,
        seed = global_seed,
        deterministic = deterministic,
        height = height,
        width = width,
        step = step,
        # Injection settings
        injection = injection,
        injection_cfg = infid_r,
        injection_control = infid_x,
        injection_start_step = injstep,
        hook_xr = inject_xr,
        hook_xs = inject_xs,
        # Additional settings
        low_vram = low_vram,
        return_intermediate = return_inter,
        manipulation_params = manipulation_params,
    )

    # Optional background removal on the generated results, using a mask
    # extracted from the (inverted) sketch and broadcast over the batch.
    if rmbg:
        mask_extractor.cuda()
        mask = smask_extractor.proceed(x=-sketch, threshold=mask_ss).repeat(results.shape[0], 1, 1, 1)
        results = torch.where(mask >= mask_ss, results, torch.ones_like(results))
        mask_extractor.cpu()

    results = postprocess(results, sketch, reference, background, crop, original_shape,
                          mask_guided, smask, rmask, bgmask, mask_ts, mask_ss)
    torch.cuda.empty_cache()
    gr.Info("Generation completed.")
    return results