tellurion Claude Sonnet 4.6 commited on
Commit
1928ea4
·
1 Parent(s): f8a6cb1

Clean up dead code and add startup model loading

Browse files

- Remove unused `new` parameter from preprocess_sketch, preprocessing_inputs, postprocess
- Simplify load_model globals check (model is always defined)
- Load first available model at startup

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +7 -1
  2. backend/appfunc.py +3 -3
  3. backend/functool.py +6 -11
app.py CHANGED
@@ -11,7 +11,7 @@ links = {
11
  "v1.5": "https://arxiv.org/abs/2502.19937v1",
12
  "v2": "https://arxiv.org/abs/2504.06895",
13
  "xl": "https://arxiv.org/abs/2601.04883",
14
- "weights": "https://huggingface.co/tellurion/colorizer/tree/main",
15
  "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
16
  }
17
 
@@ -53,6 +53,9 @@ def init_interface(opt, *args, **kwargs) -> None:
53
  <a href="{links['v2']}" target="_blank">
54
  <img src="https://img.shields.io/badge/arXiv-2504.06895 (v2)-B31B1B?style=flat&logo=arXiv" alt="arXiv v2 Paper">
55
  </a>
 
 
 
56
  <a href="{links['weights']}" target="_blank">
57
  <img src="https://img.shields.io/badge/Hugging%20Face-Model%20Weights-FF9D00?style=flat&logo=Hugging%20Face" alt="Model Weights">
58
  </a>
@@ -219,4 +222,7 @@ if __name__ == '__main__':
219
  opt = app_options()
220
  switch_extractor(default_line_extractor)
221
  switch_mask_extractor(default_mask_extractor)
 
 
 
222
  init_interface(opt)
 
11
  "v1.5": "https://arxiv.org/abs/2502.19937v1",
12
  "v2": "https://arxiv.org/abs/2504.06895",
13
  "xl": "https://arxiv.org/abs/2601.04883",
14
+ "weights": "https://huggingface.co/tellurion/ColorizeDiffusionXL/tree/main",
15
  "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
16
  }
17
 
 
53
  <a href="{links['v2']}" target="_blank">
54
  <img src="https://img.shields.io/badge/arXiv-2504.06895 (v2)-B31B1B?style=flat&logo=arXiv" alt="arXiv v2 Paper">
55
  </a>
56
+ <a href="{links['xl']}" target="_blank">
57
+ <img src="https://img.shields.io/badge/CVPR 2026-XL-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2026">
58
+ </a>
59
  <a href="{links['weights']}" target="_blank">
60
  <img src="https://img.shields.io/badge/Hugging%20Face-Model%20Weights-FF9D00?style=flat&logo=Hugging%20Face" alt="Model Weights">
61
  </a>
 
222
  opt = app_options()
223
  switch_extractor(default_line_extractor)
224
  switch_mask_extractor(default_mask_extractor)
225
+ available_models = get_available_models()
226
+ if available_models:
227
+ load_model(available_models[0])
228
  init_interface(opt)
backend/appfunc.py CHANGED
@@ -26,7 +26,7 @@ smask_extractor = create_model("ISNet-sketch").cpu()
26
  MAXM_INT32 = 429496729
27
 
28
  # HuggingFace model repository
29
- HF_REPO_ID = "tellurion/sdxl"
30
  MODEL_CACHE_DIR = "models"
31
 
32
  model_types = ["sdxl", "xlv2"]
@@ -115,8 +115,8 @@ def load_model(ckpt_name):
115
  new_model_type = key
116
  break
117
 
118
- if model_type != new_model_type or not "model" in globals():
119
- if "model" in globals() and exists(model):
120
  del model
121
  config_path = osp.join(config_root, f"{new_model_type}.yaml")
122
  new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
 
26
  MAXM_INT32 = 429496729
27
 
28
  # HuggingFace model repository
29
+ HF_REPO_ID = "tellurion/ColorizeDiffusionXL"
30
  MODEL_CACHE_DIR = "models"
31
 
32
  model_types = ["sdxl", "xlv2"]
 
115
  new_model_type = key
116
  break
117
 
118
+ if model_type != new_model_type:
119
+ if exists(model):
120
  del model
121
  config_path = osp.join(config_root, f"{new_model_type}.yaml")
122
  new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
backend/functool.py CHANGED
@@ -67,7 +67,7 @@ def lineart_standard(x: Image.Image):
67
  result = to_tensor(intensity.clip(0, 255).astype(np.uint8))
68
  return result
69
 
70
- def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None, new=False):
71
  w, h = sketch.size
72
  th, tw = resolution
73
  r = min(th/h, tw/w)
@@ -82,11 +82,7 @@ def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None, new
82
  sketch = extractor.proceed(resize((768, 768))(sketch)).repeat(1, 3, 1, 1)
83
 
84
  sketch, original_shape = pad_image(resize((int(h*r), int(w*r)))(sketch), th, tw)
85
- if new:
86
- sketch = ((sketch + 1) / 2.).clamp(0, 1)
87
- white_sketch = 1 - sketch
88
- else:
89
- white_sketch = -sketch
90
  return sketch, original_shape, white_sketch
91
 
92
 
@@ -100,14 +96,13 @@ def preprocessing_inputs(
100
  resolution: tuple[int, int],
101
  extractor: nn.Module,
102
  pad_scale: float = 1.,
103
- new = False
104
  ):
105
  extractor = extractor.cuda()
106
  h, w = resolution
107
  if exists(sketch):
108
- sketch, original_shape, white_sketch = preprocess_sketch(sketch, resolution, preprocess, extractor, new)
109
  else:
110
- sketch = torch.zeros([1, 3, h, w], device="cuda") if new else -torch.ones([1, 3, h, w], device="cuda")
111
  white_sketch = None
112
  original_shape = (0, 0, h, w)
113
 
@@ -134,9 +129,9 @@ def preprocessing_inputs(
134
  return sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch
135
 
136
  def postprocess(results, sketch, reference, background, crop, original_shape,
137
- mask_guided, smask, rmask, bgmask, mask_ts, mask_ss, new=False):
138
  results = to_numpy(results)
139
- sketch = to_numpy(sketch, not new)[0]
140
 
141
  results_list = []
142
  for result in results:
 
67
  result = to_tensor(intensity.clip(0, 255).astype(np.uint8))
68
  return result
69
 
70
+ def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None):
71
  w, h = sketch.size
72
  th, tw = resolution
73
  r = min(th/h, tw/w)
 
82
  sketch = extractor.proceed(resize((768, 768))(sketch)).repeat(1, 3, 1, 1)
83
 
84
  sketch, original_shape = pad_image(resize((int(h*r), int(w*r)))(sketch), th, tw)
85
+ white_sketch = -sketch
 
 
 
 
86
  return sketch, original_shape, white_sketch
87
 
88
 
 
96
  resolution: tuple[int, int],
97
  extractor: nn.Module,
98
  pad_scale: float = 1.,
 
99
  ):
100
  extractor = extractor.cuda()
101
  h, w = resolution
102
  if exists(sketch):
103
+ sketch, original_shape, white_sketch = preprocess_sketch(sketch, resolution, preprocess, extractor)
104
  else:
105
+ sketch = -torch.ones([1, 3, h, w], device="cuda")
106
  white_sketch = None
107
  original_shape = (0, 0, h, w)
108
 
 
129
  return sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch
130
 
131
  def postprocess(results, sketch, reference, background, crop, original_shape,
132
+ mask_guided, smask, rmask, bgmask, mask_ts, mask_ss):
133
  results = to_numpy(results)
134
+ sketch = to_numpy(sketch, True)[0]
135
 
136
  results_list = []
137
  for result in results: