File size: 10,025 Bytes
d066167
 
 
 
f206837
d066167
 
4f6bab9
f35094d
 
4f6bab9
d066167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1928ea4
d066167
 
 
 
 
 
 
 
 
 
63d1fdd
 
4f6bab9
63d1fdd
115b3c7
 
63d1fdd
d066167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6bab9
d066167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63d1fdd
 
 
 
 
 
d066167
1928ea4
 
d066167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f206837
d066167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import os
import random
import traceback
import gradio as gr
import spaces
import os.path as osp

from huggingface_hub import hf_hub_download, list_repo_files

# Optional Hub access token taken from the environment (None when unset).
# Only whether it is set is printed, never the token value itself.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
print(f"HF_TOKEN present: {HF_TOKEN is not None}")

from omegaconf import OmegaConf
from refnet.util import instantiate_from_config
from preprocessor import create_model
from .functool import *

# Currently loaded generation model; stays None until load_model() succeeds.
model = None

# Prefix key (see model_types) and checkpoint name of the loaded model;
# both are written by load_model().
model_type = ""
current_checkpoint = ""
# Seed used by the latest inference() call; None until inference() has run.
global_seed = None

# Sketch mask extractor, created once at import time and parked on the CPU;
# inference() moves it to CUDA only while it is being used.
# NOTE(review): `line_extractor` and `mask_extractor` are read by inference()
# but are not initialised here -- in the visible code they are only assigned
# by switch_extractor() / switch_mask_extractor(). Confirm they are set (or
# provided by the star import) before the first inference() call.
smask_extractor = create_model("ISNet-sketch").cpu()

# Inclusive upper bound for randomly drawn seeds in inference().
# NOTE(review): 429496729 is neither 2**31 - 1 nor 2**32 - 1; it looks like
# 4294967295 with the last digit missing -- confirm the intended value.
MAXM_INT32 = 429496729

# HuggingFace model repository
HF_REPO_ID = "tellurion/ColorizeDiffusionXL"
# Local directory that download_model() stores checkpoints in.
MODEL_CACHE_DIR = "models"

# Recognised checkpoint filename prefixes; load_model() picks the first one
# that the checkpoint name starts with and loads "<prefix>.yaml" as config.
model_types = ["sdxl", "xlv2"]

'''
    Gradio UI functions
'''


def get_available_models():
    """List the checkpoint files published in the HuggingFace model repository.

    Returns:
        The repository file names that end in ``.safetensors``, in the order
        the Hub listing reports them. An empty list is returned when listing
        or filtering fails; the failure is printed instead of being raised.
    """
    try:
        repo_files = list_repo_files(HF_REPO_ID, token=HF_TOKEN)
        checkpoints = []
        for name in repo_files:
            if name.endswith(".safetensors"):
                checkpoints.append(name)
        return checkpoints
    except Exception as err:
        # Best effort: the caller gets an empty list rather than an exception.
        print(f"Failed to list models from {HF_REPO_ID}: {err}")
        return []


def download_model(filename):
    """Return a local path for ``filename``, fetching it from the Hub if absent.

    Args:
        filename: Name of the file inside the ``HF_REPO_ID`` repository.

    Returns:
        ``MODEL_CACHE_DIR/filename`` when that path already exists, otherwise
        the path reported by ``hf_hub_download`` after downloading.
    """
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
    cached_path = osp.join(MODEL_CACHE_DIR, filename)

    if not osp.exists(cached_path):
        print(f"Downloading {filename} from {HF_REPO_ID}...")
        gr.Info(f"Downloading {filename}...")
        fetched_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=MODEL_CACHE_DIR,
            token=HF_TOKEN,
        )
        print(f"Downloaded to {fetched_path}")
        return fetched_path

    # Already present locally: skip the network round trip entirely.
    return cached_path


def switch_extractor(type):
    """Replace the global line extractor with a newly created model.

    Args:
        type: Extractor name forwarded to ``create_model``. The parameter name
            shadows the builtin but is kept so existing keyword callers work.

    The global ``line_extractor`` is only rebound once ``create_model``
    returns. Any exception is reported (message, traceback and a UI
    notification) instead of being propagated to the caller.
    """
    global line_extractor
    try:
        line_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # print_exc() writes the traceback itself and returns None, so wrapping
        # it in print() only appended a stray "None" line to the log.
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")


def switch_mask_extractor(type):
    """Replace the global mask extractor with a newly created model.

    Args:
        type: Extractor name forwarded to ``create_model``. The parameter name
            shadows the builtin but is kept so existing keyword callers work.

    The global ``mask_extractor`` is only rebound once ``create_model``
    returns. Any exception is reported (message, traceback and a UI
    notification) instead of being propagated to the caller.
    """
    global mask_extractor
    try:
        mask_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # print_exc() writes the traceback itself and returns None, so wrapping
        # it in print() only appended a stray "None" line to the log.
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")


def apppend_prompt(target, anchor, control, scale, enhance, ts0, ts1, ts2, ts3, prompt):
    """Append one manipulation entry to ``prompt`` and reset the entry fields.

    ``target``, ``anchor`` and ``control`` are stripped of surrounding
    whitespace; a value that is empty after stripping is written as "none".

    Returns:
        A 10-tuple: cleared target/anchor/control texts (""), the reset values
        0.0, False, 0.5, 0.55, 0.65 and 0.95 for the scale, enhance and ts0-ts3
        fields, and finally the stripped prompt with the new entry appended on
        its own line.
    """
    def _or_none(text):
        return "none" if text == "" else text

    target, anchor, control = target.strip(), anchor.strip(), control.strip()
    target, anchor, control = _or_none(target), _or_none(anchor), _or_none(control)

    entry = (f"\n[target] {target}; [anchor] {anchor}; [control] {control}; [scale] {scale!s}; "
             f"[enhanced] {enhance!s}; [ts0] {ts0!s}; [ts1] {ts1!s}; [ts2] {ts2!s}; [ts3] {ts3!s}")
    updated_prompt = (prompt + entry).strip()

    return "", "", "", 0.0, False, 0.5, 0.55, 0.65, 0.95, updated_prompt


def clear_prompts():
    """Return an empty string, i.e. the cleared value for the prompt text."""
    cleared = ""
    return cleared


def load_model(ckpt_name):
    """Load checkpoint ``ckpt_name`` into the global model.

    The model type is the first entry of ``model_types`` that ``ckpt_name``
    starts with. When it differs from the currently loaded type, the model is
    rebuilt from ``configs/inference/<type>.yaml`` before the weights are
    loaded. The parameterization is "eps" when the name contains "eps",
    otherwise "v".

    Args:
        ckpt_name: Checkpoint file name inside the model repository.

    Failures are not raised: the error and traceback are printed and a UI
    notification is shown. The globals are kept consistent on failure:
    ``model_type`` always describes the object stored in ``model`` (or is ""
    when ``model`` is None).
    """
    global model, model_type, current_checkpoint
    config_root = "configs/inference"

    try:
        # Determine model type from filename prefix
        new_model_type = next((key for key in model_types if ckpt_name.startswith(key)), "")
        if not new_model_type:
            # Fail before touching the loaded model instead of trying to open
            # the meaningless path "configs/inference/.yaml".
            raise ValueError(
                f"Cannot infer model type from [{ckpt_name}]; expected a prefix in {model_types}."
            )

        # Download model from HF Hub first, so a network failure leaves the
        # currently loaded model untouched.
        local_path = download_model(ckpt_name)

        if model_type != new_model_type:
            # Drop the old model by rebinding rather than `del`: deleting the
            # global would leave `model` undefined for every other function if
            # building the new one fails. Resetting model_type alongside keeps
            # the two globals in agreement during the rebuild.
            model = None
            model_type = ""
            current_checkpoint = ""
            config_path = osp.join(config_root, f"{new_model_type}.yaml")
            new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
            print(f"Switched to {new_model_type} model, loading weights from [{ckpt_name}]...")
            model = new_model
            # Record the architecture now: if loading the weights fails below,
            # a later call must not load another type's weights into it.
            model_type = new_model_type

        model.parameterization = "eps" if ckpt_name.find("eps") > -1 else "v"
        model.init_from_ckpt(local_path, logging=True)
        model.switch_to_fp16()

        current_checkpoint = ckpt_name
        print(f"Loaded model from [{ckpt_name}], model_type [{model_type}].")
        gr.Info("Loaded model successfully.")

    except Exception as e:
        print(f"Error type: {e}")
        # print_exc() prints the traceback itself and returns None; wrapping it
        # in print() only added a stray "None" line.
        traceback.print_exc()
        gr.Info("Failed in loading model.")


def get_last_seed():
    """Return the seed used by the most recent generation.

    Returns:
        The value of the global ``global_seed``, or -1 while it is still None
        (no generation has run yet). A seed of 0 is a valid seed and is
        returned as 0; a truthiness test would wrongly report it as -1.
    """
    return -1 if global_seed is None else global_seed


def reset_random_seed():
    """Return -1, the seed value that makes ``inference`` draw a random seed."""
    random_seed_sentinel = -1
    return random_seed_sentinel


def visualize(reference, text, *args):
    """Produce heatmaps for ``reference`` using the currently loaded model.

    ``text`` is parsed with ``parse_prompts`` and the result is handed,
    together with the global model, ``reference`` and any extra positional
    arguments, to ``visualize_heatmaps``, whose return value is passed back.
    """
    parsed_prompts = parse_prompts(text)
    return visualize_heatmaps(model, reference, parsed_prompts, *args)


def set_cas_scales(accurate, cas_args):
    """Build the strength configuration from the raw scale arguments.

    Args:
        accurate: When truthy, use the individual scales; otherwise use the
            three level scales multiplied by the overall strength.
        cas_args: Sequence whose first four items are the encoder, middle and
            low scales and the strength; remaining items are the individual
            scales.

    Returns:
        A dict with "level_control" (True for level mode, False for accurate
        mode) and "scales" (a dict keyed "encoder"/"middle"/"low" in level
        mode, or a list of the items after the fourth in accurate mode).
    """
    # Unpacked up front in both modes, so fewer than four items always raises.
    encoder, middle, low, strength = cas_args[:4]

    if accurate:
        return {"level_control": False, "scales": list(cas_args[4:])}

    level_scales = {
        "encoder": encoder * strength,
        "middle": middle * strength,
        "low": low * strength,
    }
    return {"level_control": True, "scales": level_scales}


@spaces.GPU(duration=120)
@torch.no_grad()
def inference(
        style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
        bs, input_s, input_r, input_bg, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
        ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4,
        fg_strength, bg_strength, merge_scale, mask_scale, height, width, seed, low_vram, step,
        injection, autofit_size, remove_fg, rmbg, latent_inpaint, infid_x, infid_r, injstep, crop, pad_scale,
        start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler, preprocess,
        deterministic, text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance, accurate,
        *args
):
    """Run one generation with the global model and return the post-processed result.

    Most parameters are forwarded unchanged to ``preprocessing_inputs``,
    ``parse_prompts``, ``model.generate`` and ``postprocess``. Only the ones
    interpreted here are described:

    Args:
        seed: Values above -1 are used as the seed; otherwise a random seed in
            ``[0, MAXM_INT32]`` is drawn. The seed used is stored in the global
            ``global_seed``.
        autofit_size: When truthy and a sketch is given, ``height``/``width``
            are replaced by a size with the sketch's aspect ratio and roughly
            1024*1024 pixels, rounded to the nearest multiple of 32 and clamped
            to [768, 1536].
        input_s, input_r, input_bg: Sketch, reference and background inputs;
            ``input_s.size`` is read as ``(width, height)`` for auto-fitting.
        bg_enhance, fg_enhance: If either is truthy (and both a white sketch
            and a reference come out of preprocessing), sketch and reference
            masks are extracted and added to the conditioning dict.
        remove_fg: With a background present, mask the background instead of
            the reference and blank out its foreground.
        mask_ts, mask_ss: Thresholds for the reference/background mask and the
            sketch mask respectively.
        ctl_scale, ctl_scale_1..4: The four per-level scales are each
            multiplied by ``ctl_scale`` before being passed on.
        accurate, *args: Passed to ``set_cas_scales`` to build ``strength``.
        rmbg: When truthy, result pixels whose sketch mask is below
            ``mask_ss`` are replaced by ones.

    Returns:
        Whatever ``postprocess`` returns for the generated results.

    NOTE(review): ``line_extractor`` and ``mask_extractor`` are read as
    globals; in the visible code they are only assigned by
    ``switch_extractor`` / ``switch_mask_extractor`` -- confirm they are set
    before the first call.
    """
    global global_seed, line_extractor, mask_extractor
    global_seed = seed if seed > -1 else random.randint(0, MAXM_INT32)
    torch.manual_seed(global_seed)

    # Auto-fit size based on sketch dimensions
    if autofit_size and exists(input_s):
        sketch_w, sketch_h = input_s.size
        aspect_ratio = sketch_w / sketch_h
        target_area = 1024 * 1024
        new_h = int((target_area / aspect_ratio) ** 0.5)
        new_w = int(new_h * aspect_ratio)
        # Round to the nearest multiple of 32, then clamp to the allowed range.
        height = ((new_h + 16) // 32) * 32
        width = ((new_w + 16) // 32) * 32
        height = max(768, min(1536, height))
        width = max(768, min(1536, width))
        gr.Info(f"Auto-fitted size: {width}x{height}")

    smask, rmask, bgmask = None, None, None
    manipulation_params = parse_prompts(text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance)
    inputs = preprocessing_inputs(
        sketch = input_s,
        reference = input_r,
        background = input_bg,
        preprocess = preprocess,
        hook = injection,
        resolution = (height, width),
        extractor = line_extractor,
        pad_scale = pad_scale,
    )
    sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch = inputs

    cond = {"reference": reference, "sketch": sketch, "background": background}
    mask_guided = bg_enhance or fg_enhance

    if exists(white_sketch) and exists(reference) and mask_guided:
        try:
            mask_extractor.cuda()
            smask_extractor.cuda()
            smask = smask_extractor.proceed(
                x=white_sketch, pil_x=input_s, th=height, tw=width, threshold=mask_ss, crop=False
            )

            if exists(background) and remove_fg:
                bgmask = mask_extractor.proceed(x=background, pil_x=input_bg, threshold=mask_ts, dilate=True)
                filtered_background = torch.where(bgmask < mask_ts, background, torch.ones_like(background))
                cond.update({"background": filtered_background, "rmask": bgmask})
            else:
                rmask = mask_extractor.proceed(x=reference, pil_x=input_r, threshold=mask_ts, dilate=True)
                cond.update({"rmask": rmask})
                # Binarize only here: on the remove_fg branch no reference mask
                # exists (rmask stays None), and thresholding None raised a
                # TypeError. The conditioning dict keeps the raw mask; the
                # binarized copy is what postprocess() receives.
                rmask = torch.where(rmask > 0.5, torch.ones_like(rmask), torch.zeros_like(rmask))
            cond.update({"smask": smask})
        finally:
            # Always move the extractors back to the CPU, even when mask
            # extraction fails, so they do not stay resident on the GPU.
            smask_extractor.cpu()
            mask_extractor.cpu()

    scale_strength = set_cas_scales(accurate, args)
    ctl_scales = [ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4]
    ctl_scales = [t * ctl_scale for t in ctl_scales]

    results = model.generate(
        # Colorization mode
        style_enhance = style_enhance,
        bg_enhance = bg_enhance,
        fg_enhance = fg_enhance,
        fg_disentangle_scale = fg_disentangle_scale,
        latent_inpaint = latent_inpaint,

        # Conditional inputs
        cond = cond,
        ctl_scale = ctl_scales,
        merge_scale = merge_scale,
        mask_scale = mask_scale,
        mask_thresh = mask_ts,
        mask_thresh_sketch = mask_ss,

        # Sampling settings
        bs = bs,
        gs = [gs_r, gs_s],
        sampler = sampler,
        scheduler = scheduler,
        start_step = start_step,
        end_step = end_step,
        no_start_step = no_start_step,
        no_end_step = no_end_step,
        strength = scale_strength,
        fg_strength = fg_strength,
        bg_strength = bg_strength,
        seed = global_seed,
        deterministic = deterministic,
        height = height,
        width = width,
        step = step,

        # Injection settings
        injection = injection,
        injection_cfg = infid_r,
        injection_control = infid_x,
        injection_start_step = injstep,
        hook_xr = inject_xr,
        hook_xs = inject_xs,

        # Additional settings
        low_vram = low_vram,
        return_intermediate = return_inter,
        manipulation_params = manipulation_params,
    )

    if rmbg:
        # The sketch mask extractor is the model actually run here, so it is
        # the one that has to be moved to the GPU (the original moved
        # mask_extractor, which this block never uses).
        try:
            smask_extractor.cuda()
            mask = smask_extractor.proceed(x=-sketch, threshold=mask_ss).repeat(results.shape[0], 1, 1, 1)
            results = torch.where(mask >= mask_ss, results, torch.ones_like(results))
        finally:
            smask_extractor.cpu()

    results = postprocess(results, sketch, reference, background, crop, original_shape,
                          mask_guided, smask, rmask, bgmask, mask_ts, mask_ss)
    torch.cuda.empty_cache()
    gr.Info("Generation completed.")
    return results