Spaces:
Running on Zero
Running on Zero
Clean up dead code and add startup model loading
Browse files

- Remove unused `new` parameter from preprocess_sketch, preprocessing_inputs, and postprocess
- Simplify load_model globals check (model is always defined)
- Load first available model at startup
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- app.py +7 -1
- backend/appfunc.py +3 -3
- backend/functool.py +6 -11
app.py
CHANGED
|
@@ -11,7 +11,7 @@ links = {
|
|
| 11 |
"v1.5": "https://arxiv.org/abs/2502.19937v1",
|
| 12 |
"v2": "https://arxiv.org/abs/2504.06895",
|
| 13 |
"xl": "https://arxiv.org/abs/2601.04883",
|
| 14 |
-
"weights": "https://huggingface.co/tellurion/
|
| 15 |
"github": "https://github.com/tellurion-kanata/colorizeDiffusion",
|
| 16 |
}
|
| 17 |
|
|
@@ -53,6 +53,9 @@ def init_interface(opt, *args, **kwargs) -> None:
|
|
| 53 |
<a href="{links['v2']}" target="_blank">
|
| 54 |
<img src="https://img.shields.io/badge/arXiv-2504.06895 (v2)-B31B1B?style=flat&logo=arXiv" alt="arXiv v2 Paper">
|
| 55 |
</a>
|
|
|
|
|
|
|
|
|
|
| 56 |
<a href="{links['weights']}" target="_blank">
|
| 57 |
<img src="https://img.shields.io/badge/Hugging%20Face-Model%20Weights-FF9D00?style=flat&logo=Hugging%20Face" alt="Model Weights">
|
| 58 |
</a>
|
|
@@ -219,4 +222,7 @@ if __name__ == '__main__':
|
|
| 219 |
opt = app_options()
|
| 220 |
switch_extractor(default_line_extractor)
|
| 221 |
switch_mask_extractor(default_mask_extractor)
|
|
|
|
|
|
|
|
|
|
| 222 |
init_interface(opt)
|
|
|
|
| 11 |
"v1.5": "https://arxiv.org/abs/2502.19937v1",
|
| 12 |
"v2": "https://arxiv.org/abs/2504.06895",
|
| 13 |
"xl": "https://arxiv.org/abs/2601.04883",
|
| 14 |
+
"weights": "https://huggingface.co/tellurion/ColorizeDiffusionXL/tree/main",
|
| 15 |
"github": "https://github.com/tellurion-kanata/colorizeDiffusion",
|
| 16 |
}
|
| 17 |
|
|
|
|
| 53 |
<a href="{links['v2']}" target="_blank">
|
| 54 |
<img src="https://img.shields.io/badge/arXiv-2504.06895 (v2)-B31B1B?style=flat&logo=arXiv" alt="arXiv v2 Paper">
|
| 55 |
</a>
|
| 56 |
+
<a href="{links['xl']}" target="_blank">
|
| 57 |
+
<img src="https://img.shields.io/badge/CVPR 2026-XL-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2026">
|
| 58 |
+
</a>
|
| 59 |
<a href="{links['weights']}" target="_blank">
|
| 60 |
<img src="https://img.shields.io/badge/Hugging%20Face-Model%20Weights-FF9D00?style=flat&logo=Hugging%20Face" alt="Model Weights">
|
| 61 |
</a>
|
|
|
|
| 222 |
opt = app_options()
|
| 223 |
switch_extractor(default_line_extractor)
|
| 224 |
switch_mask_extractor(default_mask_extractor)
|
| 225 |
+
available_models = get_available_models()
|
| 226 |
+
if available_models:
|
| 227 |
+
load_model(available_models[0])
|
| 228 |
init_interface(opt)
|
backend/appfunc.py
CHANGED
|
@@ -26,7 +26,7 @@ smask_extractor = create_model("ISNet-sketch").cpu()
|
|
| 26 |
MAXM_INT32 = 429496729
|
| 27 |
|
| 28 |
# HuggingFace model repository
|
| 29 |
-
HF_REPO_ID = "tellurion/
|
| 30 |
MODEL_CACHE_DIR = "models"
|
| 31 |
|
| 32 |
model_types = ["sdxl", "xlv2"]
|
|
@@ -115,8 +115,8 @@ def load_model(ckpt_name):
|
|
| 115 |
new_model_type = key
|
| 116 |
break
|
| 117 |
|
| 118 |
-
if model_type != new_model_type:
|
| 119 |
-
if "model" in globals() and exists(model):
|
| 120 |
del model
|
| 121 |
config_path = osp.join(config_root, f"{new_model_type}.yaml")
|
| 122 |
new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
|
|
|
|
| 26 |
MAXM_INT32 = 429496729
|
| 27 |
|
| 28 |
# HuggingFace model repository
|
| 29 |
+
HF_REPO_ID = "tellurion/ColorizeDiffusionXL"
|
| 30 |
MODEL_CACHE_DIR = "models"
|
| 31 |
|
| 32 |
model_types = ["sdxl", "xlv2"]
|
|
|
|
| 115 |
new_model_type = key
|
| 116 |
break
|
| 117 |
|
| 118 |
+
if model_type != new_model_type:
|
| 119 |
+
if exists(model):
|
| 120 |
del model
|
| 121 |
config_path = osp.join(config_root, f"{new_model_type}.yaml")
|
| 122 |
new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
|
backend/functool.py
CHANGED
|
@@ -67,7 +67,7 @@ def lineart_standard(x: Image.Image):
|
|
| 67 |
result = to_tensor(intensity.clip(0, 255).astype(np.uint8))
|
| 68 |
return result
|
| 69 |
|
| 70 |
-
def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None, new=False):
|
| 71 |
w, h = sketch.size
|
| 72 |
th, tw = resolution
|
| 73 |
r = min(th/h, tw/w)
|
|
@@ -82,11 +82,7 @@ def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None, new
|
|
| 82 |
sketch = extractor.proceed(resize((768, 768))(sketch)).repeat(1, 3, 1, 1)
|
| 83 |
|
| 84 |
sketch, original_shape = pad_image(resize((int(h*r), int(w*r)))(sketch), th, tw)
|
| 85 |
-
|
| 86 |
-
sketch = ((sketch + 1) / 2.).clamp(0, 1)
|
| 87 |
-
white_sketch = 1 - sketch
|
| 88 |
-
else:
|
| 89 |
-
white_sketch = -sketch
|
| 90 |
return sketch, original_shape, white_sketch
|
| 91 |
|
| 92 |
|
|
@@ -100,14 +96,13 @@ def preprocessing_inputs(
|
|
| 100 |
resolution: tuple[int, int],
|
| 101 |
extractor: nn.Module,
|
| 102 |
pad_scale: float = 1.,
|
| 103 |
-
new = False
|
| 104 |
):
|
| 105 |
extractor = extractor.cuda()
|
| 106 |
h, w = resolution
|
| 107 |
if exists(sketch):
|
| 108 |
-
sketch, original_shape, white_sketch = preprocess_sketch(sketch, resolution, preprocess, extractor, new)
|
| 109 |
else:
|
| 110 |
-
sketch =
|
| 111 |
white_sketch = None
|
| 112 |
original_shape = (0, 0, h, w)
|
| 113 |
|
|
@@ -134,9 +129,9 @@ def preprocessing_inputs(
|
|
| 134 |
return sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch
|
| 135 |
|
| 136 |
def postprocess(results, sketch, reference, background, crop, original_shape,
|
| 137 |
-
mask_guided, smask, rmask, bgmask, mask_ts, mask_ss, new=False):
|
| 138 |
results = to_numpy(results)
|
| 139 |
-
sketch = to_numpy(sketch,
|
| 140 |
|
| 141 |
results_list = []
|
| 142 |
for result in results:
|
|
|
|
| 67 |
result = to_tensor(intensity.clip(0, 255).astype(np.uint8))
|
| 68 |
return result
|
| 69 |
|
| 70 |
+
def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None):
|
| 71 |
w, h = sketch.size
|
| 72 |
th, tw = resolution
|
| 73 |
r = min(th/h, tw/w)
|
|
|
|
| 82 |
sketch = extractor.proceed(resize((768, 768))(sketch)).repeat(1, 3, 1, 1)
|
| 83 |
|
| 84 |
sketch, original_shape = pad_image(resize((int(h*r), int(w*r)))(sketch), th, tw)
|
| 85 |
+
white_sketch = -sketch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
return sketch, original_shape, white_sketch
|
| 87 |
|
| 88 |
|
|
|
|
| 96 |
resolution: tuple[int, int],
|
| 97 |
extractor: nn.Module,
|
| 98 |
pad_scale: float = 1.,
|
|
|
|
| 99 |
):
|
| 100 |
extractor = extractor.cuda()
|
| 101 |
h, w = resolution
|
| 102 |
if exists(sketch):
|
| 103 |
+
sketch, original_shape, white_sketch = preprocess_sketch(sketch, resolution, preprocess, extractor)
|
| 104 |
else:
|
| 105 |
+
sketch = -torch.ones([1, 3, h, w], device="cuda")
|
| 106 |
white_sketch = None
|
| 107 |
original_shape = (0, 0, h, w)
|
| 108 |
|
|
|
|
| 129 |
return sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch
|
| 130 |
|
| 131 |
def postprocess(results, sketch, reference, background, crop, original_shape,
|
| 132 |
+
mask_guided, smask, rmask, bgmask, mask_ts, mask_ss):
|
| 133 |
results = to_numpy(results)
|
| 134 |
+
sketch = to_numpy(sketch, True)[0]
|
| 135 |
|
| 136 |
results_list = []
|
| 137 |
for result in results:
|